diff --git a/CHANGELOG.md b/CHANGELOG.md index bd7a668..35d0066 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,45 @@ -# UNRELEASED +# Changelog -# 0.1.2 (January 6th, 2022) +All notable changes to this workspace will be documented in this file. -FEATURES +## Unreleased +## 0.2.0 - 2026-04-08 +### `soundevents` + +- Added `predict_raw_scores_batch_flat` and `predict_raw_scores_batch_into` for lower-allocation batched raw-score access. +- Expanded batched inference coverage with regression tests that verify flat and buffer-reuse paths against sequential inference. +- Removed redundant input validation in `classify_batch` while preserving the existing error behavior for invalid batches. +- Tightened crate metadata and docs.rs configuration so feature-gated APIs, including `Classifier::tiny`, render correctly on published docs. +- Added packaged third-party notices for bundled CED model artifacts. + +### `soundevents-dataset` + +- Packaged the dual-license texts with the published crate and aligned crate metadata for docs.rs and crates.io discovery. +- Kept the crate on its Rust 1.59 / edition 2021 compatibility track while removing the in-source `deny(warnings)` footgun. +- Added packaged third-party notices for bundled AudioSet ontology and label metadata. + +### Workspace + +- Included license files in published package contents for both crates. +- Upgraded README examples from ignored snippets to compile-checked doctests across the workspace. + +## 0.1.0 - 2026-04-08 + +### `soundevents` + +- Initial public release of the ONNX Runtime wrapper for CED AudioSet classifiers. +- Added file, memory, and bundled-model loading paths plus configurable graph optimization. +- Added ranked top-k helpers, raw-score accessors, and chunked inference with mean/max aggregation. +- Added equal-length batch APIs for clip inference and chunked window batching for higher-throughput services. 
+ +### `soundevents-dataset` + +- Initial public release of the typed AudioSet dataset companion crate. +- Included both the 527-class rated label set and the full 632-entry ontology as `&'static` generated data. +- Kept the crate `no_std`-friendly, allocation-free at runtime, and compatible with Rust 1.59. + +### `xtask` + +- Added code generation for the rated label set and ontology modules from upstream AudioSet source data. diff --git a/Cargo.toml b/Cargo.toml index 8e6bfa7..462c842 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,12 @@ resolver = "3" thiserror = { version = "2", default-features = false } serde = "1" -soundevents-dataset = { version = "0.1", path = "soundevents-dataset", default-features = false } +soundevents-dataset = { version = "0.2", path = "soundevents-dataset", default-features = false } + +[workspace.package] +license = "MIT OR Apache-2.0" +repository = "https://github.com/findit-ai/soundevents" +homepage = "https://github.com/findit-ai/soundevents" [profile.bench] opt-level = 3 diff --git a/LICENSE-APACHE b/LICENSE-APACHE index 16fe87b..096b193 100644 --- a/LICENSE-APACHE +++ b/LICENSE-APACHE @@ -186,7 +186,7 @@ APPENDIX: How to apply the Apache License to your work. same "printed page" as the copyright notice for easier identification within third-party archives. -Copyright [yyyy] [name of copyright owner] +Copyright (c) 2026 The FinDIT studio developers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/LICENSE-MIT b/LICENSE-MIT index e69282e..2f6b69c 100644 --- a/LICENSE-MIT +++ b/LICENSE-MIT @@ -1,4 +1,4 @@ -Copyright (c) 2015 The Rust Project Developers +Copyright (c) 2026 The FinDIT studio developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated diff --git a/README.md b/README.md index f262e0d..dd7777d 100644 --- a/README.md +++ b/README.md @@ -22,20 +22,25 @@ Production-oriented Rust inference for [CED](https://arxiv.org/abs/2308.11957) A - **Drop-in CED inference** — load any [CED](https://arxiv.org/abs/2308.11957) AudioSet ONNX model (or use the bundled `tiny` variant) and run it directly on `&[f32]` PCM samples. No Python, no preprocessing pipeline. - **Typed labels, not bare integers** — every prediction comes back as an [`EventPrediction`] carrying a `&'static RatedSoundEvent` from [`soundevents-dataset`](./soundevents-dataset), so you get the canonical AudioSet name, the `/m/...` id, the model class index, and the confidence in one struct. - **Compile-time class-count guarantee** — the `NUM_CLASSES = 527` constant comes from the rated dataset at codegen time. If a model returns the wrong number of classes you get a typed [`ClassifierError::UnexpectedClassCount`] instead of a silent mismatch. -- **Long-clip chunking built in** — `classify_chunked` / `classify_all_chunked` window the input at a configurable hop, run inference on each chunk, and aggregate the per-chunk confidences with either `Mean` or `Max`. Defaults match CED's 10 s training window (160 000 samples at 16 kHz). +- **Long-clip chunking built in** — `classify_chunked` / `classify_all_chunked` window the input at a configurable hop, run inference on each chunk, and aggregate the per-chunk confidences with either `Mean` or `Max`. Defaults match CED's 10 s training window (160 000 samples at 16 kHz), and fixed-size chunk batches can now be packed into one model call. 
- **Top-k via a tiny min-heap** — `classify(samples, k)` does not allocate a full 527-element scores vector to find the top results. +- **Batch-ready low-level API** — `predict_raw_scores_batch`, `predict_raw_scores_batch_flat`, `predict_raw_scores_batch_into`, `classify_all_batch`, and `classify_batch` accept equal-length clip batches for service-layer batching. - **Bring-your-own model or bundle one** — load from a path, from in-memory bytes, or enable the `bundled-tiny` feature to embed `models/tiny.onnx` directly into your binary. ## Quick start ```toml [dependencies] -soundevents = "0.1" +soundevents = "0.2" ``` -```rust,ignore +```rust,no_run use soundevents::{Classifier, Options}; +fn load_mono_16k_audio(_: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> { + Ok(vec![0.0; 16_000]) +} + fn main() -> Result<(), Box<dyn std::error::Error>> { let mut classifier = Classifier::from_file("soundevents/models/tiny.onnx")?; @@ -61,9 +66,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { `Classifier::classify_chunked` slides a window over the input and aggregates each chunk's per-class confidences. The defaults (10 s window, 10 s hop, mean aggregation) match CED's training setup; tune them for overlap or peak-pooling. 
-```rust,ignore +```rust,no_run use soundevents::{ChunkAggregation, ChunkingOptions, Classifier}; +fn load_long_clip() -> Result<Vec<f32>, Box<dyn std::error::Error>> { + Ok(vec![0.0; 320_000]) +} + fn main() -> Result<(), Box<dyn std::error::Error>> { let mut classifier = Classifier::from_file("soundevents/models/tiny.onnx")?; let samples: Vec<f32> = load_long_clip()?; @@ -71,6 +80,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { let opts = ChunkingOptions::default() // 5 s overlap (50%) between adjacent windows .with_hop_samples(80_000) + // Batch up to 4 equal-length windows per session.run() + .with_batch_size(4) // Keep the loudest detection in any window instead of averaging .with_aggregation(ChunkAggregation::Max); @@ -111,18 +122,29 @@ If upstream releases new weights, or you cloned without the model files, refetch The script downloads the `*.onnx` artifact from each `mispeech/ced-*` Hugging Face repo and writes it as `soundevents/models/<size>.onnx`. +See [THIRD_PARTY_NOTICES.md](THIRD_PARTY_NOTICES.md) for upstream model +sources and attribution details. + ### Bundled tiny model Enable the `bundled-tiny` feature to embed `models/tiny.onnx` into your binary — useful for CLI tools and self-contained services where you don't want to ship a separate model file. ```toml -soundevents = { version = "0.1", features = ["bundled-tiny"] } +soundevents = { version = "0.2", features = ["bundled-tiny"] } ``` -```rust,ignore +```rust +# #[cfg(feature = "bundled-tiny")] use soundevents::{Classifier, Options}; +# fn main() -> Result<(), Box<dyn std::error::Error>> { +# #[cfg(feature = "bundled-tiny")] +# { let mut classifier = Classifier::tiny(Options::default())?; +# let _ = &mut classifier; +# } +# Ok(()) +# } ``` ## Features @@ -139,6 +161,8 @@ The full input/output contract: | `DEFAULT_CHUNK_SAMPLES` | `160_000` | Default 10 s window/hop for chunked inference. | | `NUM_CLASSES` | `527` | Number of CED output classes — derived at compile time from `RatedSoundEvent::events().len()`. 
| +For low-level batching, every clip in `predict_raw_scores_batch*` / `classify_*_batch` must be non-empty and have the same sample count. `predict_raw_scores_batch_flat` returns one row-major `Vec<f32>`, and `predict_raw_scores_batch_into` lets callers reuse their own output buffer to avoid per-call result allocations. `classify_chunked` uses the same equal-length restriction internally when `ChunkingOptions::batch_size() > 1`, which is naturally satisfied for fixed-size windows and automatically falls back to smaller batches for the final short tail chunk. + ## Development Regenerate the dataset from upstream sources: @@ -162,6 +186,8 @@ cargo test Apache License (Version 2.0). See [LICENSE-APACHE](LICENSE-APACHE), [LICENSE-MIT](LICENSE-MIT) for details. +Bundled third-party model attributions and source licenses are documented in +[THIRD_PARTY_NOTICES.md](THIRD_PARTY_NOTICES.md). Copyright (c) 2026 FinDIT studio authors. diff --git a/soundevents-dataset/Cargo.toml b/soundevents-dataset/Cargo.toml index 645a160..bceff26 100644 --- a/soundevents-dataset/Cargo.toml +++ b/soundevents-dataset/Cargo.toml @@ -1,12 +1,17 @@ [package] name = "soundevents-dataset" -version = "0.1.0" +version = "0.2.0" +# Intentionally kept on edition 2021 / MSRV 1.59 so this no_std static-data +# crate remains usable from older toolchains. edition = "2021" -repository = "https://github.com/findit-ai/soundevents" -homepage = "https://github.com/findit-ai/soundevents" -documentation = "https://docs.rs/soundevents" +documentation = "https://docs.rs/soundevents-dataset" description = "Audio Set Ontology aims to provide a comprehensive set of categories to describe sound events." 
-license = "MIT OR Apache-2.0" +license.workspace = true +repository.workspace = true +homepage.workspace = true +readme = "README.md" +keywords = ["audioset", "sound-events", "ontology", "dataset", "no-std"] +categories = ["data-structures", "multimedia::audio", "no-std", "no-std::no-alloc"] rust-version = "1.59.0" [features] diff --git a/soundevents-dataset/LICENSE-APACHE b/soundevents-dataset/LICENSE-APACHE new file mode 100644 index 0000000..4922fee --- /dev/null +++ b/soundevents-dataset/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf of + any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ +Copyright (c) 2026 The FinDIT studio developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/soundevents-dataset/LICENSE-MIT b/soundevents-dataset/LICENSE-MIT new file mode 100644 index 0000000..2f6b69c --- /dev/null +++ b/soundevents-dataset/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2026 The FinDIT studio developers + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
diff --git a/soundevents-dataset/README.md b/soundevents-dataset/README.md index 59b468f..3a34f0d 100644 --- a/soundevents-dataset/README.md +++ b/soundevents-dataset/README.md @@ -10,9 +10,9 @@ Typed, zero-allocation Rust access to [Google's AudioSet](https://research.googl [Build][CI-url] [codecov][codecov-url] -[docs.rs][doc-url] -[crates.io][crates-url] -[crates.io][crates-url] +[docs.rs][doc-url] +[crates.io][crates-url] +[crates.io][crates-url] license @@ -21,17 +21,17 @@ Typed, zero-allocation Rust access to [Google's AudioSet](https://research.googl ```toml [dependencies] -soundevents-dataset = "0.1" +soundevents-dataset = "0.2" ``` By default this pulls in the [`rated`](#rated--audioset-rated-label-set-527-entries) module — the 527-class label set used by released AudioSet/YAMNet/VGGish models. To use the [`ontology`](#ontology--full-audioset-taxonomy-632-entries) view instead (or in addition), pick the features explicitly: ```toml # Just the full AudioSet ontology, no rated set. -soundevents-dataset = { version = "0.1", default-features = false, features = ["std", "ontology"] } +soundevents-dataset = { version = "0.2", default-features = false, features = ["std", "ontology"] } # Both views. -soundevents-dataset = { version = "0.1", features = ["ontology"] } +soundevents-dataset = { version = "0.2", features = ["ontology"] } ``` ## Two views, two modules @@ -45,76 +45,12 @@ The two are independent: each lives in its own module, has its own `&'static` co ### `rated` — AudioSet rated label set (527 entries) -```rust,ignore -use soundevents_dataset::rated::RatedSoundEvent; - -// Look up by name, alias, or AudioSet id (case-insensitive). -let speech = RatedSoundEvent::from_key("Speech"); -assert_eq!(speech.len(), 1); -assert_eq!(speech[0].name(), "Speech"); -assert_eq!(speech[0].index(), 0); // class index in the released model output - -// Decode a model's argmax: `scores: [f32; 527]`. 
-let predicted = scores.iter().enumerate() - .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) - .unwrap().0; -let label = RatedSoundEvent::from_index(predicted).unwrap(); -println!("predicted: {}", label.name()); - -// AudioSet id lookup uses the same case-insensitive map. -assert_eq!( - RatedSoundEvent::from_key("/m/09x0r"), - RatedSoundEvent::from_key("/M/09X0R"), -); - -// Iterate every rated class in CSV order. -for entry in RatedSoundEvent::events() { - println!("{:>3} {} {}", entry.index(), entry.id(), entry.name()); -} -``` - `RatedSoundEvent` exposes the same metadata accessors as `SoundEvent` (`id`, `name`, `description`, `aliases`, `citation_uri`, `children`, `restrictions`) plus a rated-only [`index()`](https://docs.rs/soundevents-dataset) — the integer 0..527 used as the position in released AudioSet models' output vectors. Walking `children()` stays inside the rated namespace: any ontology child that is *not* in the rated set is dropped, so the hierarchy remains self-consistent. -### `ontology` — full AudioSet taxonomy (632 entries) - -```rust,ignore -use soundevents_dataset::ontology::{SoundEvent}; -use soundevents_dataset::Restriction; - -// Walk the hierarchy from an abstract container down to leaves. -let voice = SoundEvent::from_key("Human voice")[0]; -assert!(voice.restrictions().contains(&Restriction::Abstract)); -for child in voice.children() { - println!("- {}", child.name()); -} - -// `from_key` returns a slice — most aliases are unique (one match), a few -// ambiguous ones like "Inside" map to several entries. -let inside = SoundEvent::from_key("Inside"); -assert!(inside.len() > 1); - -// Stable 64-bit code (SipHash of the canonical name) for compact storage. 
-let code = SoundEvent::from_key("Speech")[0].encode(); -let round_tripped = SoundEvent::from_code(code).unwrap(); -assert_eq!(round_tripped.name(), "Speech"); -``` - ### Case-insensitive, separator-distinct lookup `from_key` is keyed by [`UncasedStr`](https://docs.rs/uncased), so any case form of an alias resolves to the same entry without us having to enumerate every possibility: -```rust,ignore -use soundevents_dataset::rated::RatedSoundEvent; - -let queries = [ - "man speaking", "MAN SPEAKING", "Man Speaking", "mAn SpEaKiNg", - "man_speaking", "manSpeaking", "Man-Speaking", -]; -for q in queries { - assert_eq!(RatedSoundEvent::from_key(q)[0].id(), "/m/05zppz"); -} -``` - Separator styles are still indexed independently (`"man speaking"` ≠ `"man_speaking"` ≠ `"man-speaking"` ≠ `"manSpeaking"`), so you only pay for the four shapes the codegen actually emits — every case variant of each shape collapses into one phf bucket. ## Features @@ -139,16 +75,17 @@ cargo xtask codegen #### License -`soundevents` is under the terms of both the MIT license and the +`soundevents-dataset` is under the terms of both the MIT license and the Apache License (Version 2.0). See [LICENSE-APACHE](LICENSE-APACHE), [LICENSE-MIT](LICENSE-MIT) for details. +Bundled AudioSet metadata attribution and upstream license details are +documented in [THIRD_PARTY_NOTICES.md](THIRD_PARTY_NOTICES.md). Copyright (c) 2026 FinDIT studio authors. 
[Github-url]: https://github.com/Findit-AI/soundevents [CI-url]: https://github.com/Findit-AI/soundevents/actions/workflows/ci.yml -[doc-url]: https://docs.rs/soundevents -[crates-url]: https://crates.io/crates/soundevents +[doc-url]: https://docs.rs/soundevents-dataset +[crates-url]: https://crates.io/crates/soundevents-dataset [codecov-url]: https://app.codecov.io/gh/Findit-AI/soundevents/ - diff --git a/soundevents-dataset/THIRD_PARTY_NOTICES.md b/soundevents-dataset/THIRD_PARTY_NOTICES.md new file mode 100644 index 0000000..e7ba0d1 --- /dev/null +++ b/soundevents-dataset/THIRD_PARTY_NOTICES.md @@ -0,0 +1,32 @@ +# Third-Party Notices for `soundevents-dataset` + +This crate redistributes AudioSet metadata from Google and generates static +Rust lookup tables from those upstream data files. + +## AudioSet ontology + +The file `assets/ontology.json` is sourced from the official AudioSet ontology +repository and is used to generate `src/ontology/generated.rs`. + +- Upstream repository: <https://github.com/audioset/ontology> +- Upstream data file: <https://github.com/audioset/ontology/blob/master/ontology.json> + +- The upstream repository states: + "The ontology is made available by Google Inc. under a Creative Commons + Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license." + +## AudioSet released label metadata + +The file `assets/class_labels_indices.csv` is sourced from the AudioSet release +metadata and is used to generate `src/rated/generated.rs`. + +- AudioSet download page: <https://research.google.com/audioset/download.html> +- The AudioSet download page states: + "The dataset is made available by Google Inc. under a Creative Commons + Attribution 4.0 International (CC BY 4.0) license, while the ontology is + available under a Creative Commons Attribution-ShareAlike 4.0 International + (CC BY-SA 4.0) license." + +These upstream files are redistributed for interoperability with AudioSet-based +models. When further redistributing this crate or derived artifacts, ensure the +applicable upstream attribution and license terms continue to be satisfied. 
diff --git a/soundevents-dataset/src/lib.rs b/soundevents-dataset/src/lib.rs index ba36036..0180944 100644 --- a/soundevents-dataset/src/lib.rs +++ b/soundevents-dataset/src/lib.rs @@ -2,7 +2,7 @@ #![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, allow(unused_attributes))] -#![deny(missing_docs, warnings)] +#![deny(missing_docs)] #![forbid(unsafe_code)] #[cfg(feature = "ontology")] diff --git a/soundevents/Cargo.toml b/soundevents/Cargo.toml index e47d456..4e35cbb 100644 --- a/soundevents/Cargo.toml +++ b/soundevents/Cargo.toml @@ -1,21 +1,30 @@ [package] name = "soundevents" -version = "0.1.0" +version = "0.2.0" edition = "2024" description = "Production-oriented Rust inference wrapper for CED AudioSet classifiers." -license = "MIT OR Apache-2.0" -repository = "https://github.com/findit-ai/soundevents" -homepage = "https://github.com/findit-ai/soundevents" +license.workspace = true +repository.workspace = true +homepage.workspace = true documentation = "https://docs.rs/soundevents" +readme = "README.md" +keywords = ["audioset", "audio-classification", "ced", "sound-events", "onnx"] +categories = ["multimedia::audio", "science"] rust-version = "1.85" include = [ "Cargo.toml", + "README.md", + "THIRD_PARTY_NOTICES.md", "build.rs", + "LICENSE-APACHE", + "LICENSE-MIT", "models/tiny.onnx", "src/**", ] [features] +# Tiny CED is ~6.4 MB, so bundling stays opt-in instead of inflating every +# binary by default. 
default = [] bundled-tiny = [] @@ -33,3 +42,7 @@ unexpected_cfgs = { level = "warn", check-cfg = [ 'cfg(all_tests)', 'cfg(tarpaulin)', ] } + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] diff --git a/soundevents/LICENSE-APACHE b/soundevents/LICENSE-APACHE new file mode 100644 index 0000000..4922fee --- /dev/null +++ b/soundevents/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf of + any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ +Copyright (c) 2026 The FinDIT studio developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/soundevents/LICENSE-MIT b/soundevents/LICENSE-MIT new file mode 100644 index 0000000..2f6b69c --- /dev/null +++ b/soundevents/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2026 The FinDIT studio developers + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
diff --git a/soundevents/THIRD_PARTY_NOTICES.md b/soundevents/THIRD_PARTY_NOTICES.md new file mode 100644 index 0000000..cbc1e70 --- /dev/null +++ b/soundevents/THIRD_PARTY_NOTICES.md @@ -0,0 +1,21 @@ +# Third-Party Notices for `soundevents` + +This crate redistributes third-party model artifacts in addition to the +project's own Rust source code. + +## CED model artifacts + +The published crate bundles `models/tiny.onnx`, an ONNX export of the CED-Tiny +audio classification model. + +- Upstream model card: +- Upstream repository referenced by the model card: + +- Paper referenced by the model card: + +- License reported by the upstream model card at the time this crate was + packaged: Apache-2.0 + +The repository may also contain additional CED ONNX variants for development +or benchmarking. See the upstream model cards for the full provenance and +license terms of those artifacts. diff --git a/soundevents/src/lib.rs b/soundevents/src/lib.rs index f84adb4..39baa14 100644 --- a/soundevents/src/lib.rs +++ b/soundevents/src/lib.rs @@ -1,7 +1,7 @@ #![doc = include_str!("../README.md")] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, allow(unused_attributes))] -#![deny(missing_docs, warnings)] +#![deny(missing_docs)] #![forbid(unsafe_code)] use ort::{ @@ -114,6 +114,7 @@ pub enum ChunkAggregation { pub struct ChunkingOptions { window_samples: usize, hop_samples: usize, + batch_size: usize, aggregation: ChunkAggregation, } @@ -122,6 +123,7 @@ impl Default for ChunkingOptions { Self { window_samples: DEFAULT_CHUNK_SAMPLES, hop_samples: DEFAULT_CHUNK_SAMPLES, + batch_size: 1, aggregation: ChunkAggregation::Mean, } } @@ -140,6 +142,12 @@ impl ChunkingOptions { self.hop_samples } + /// Returns the maximum number of equal-length chunks to batch into one model call. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn batch_size(&self) -> usize { + self.batch_size + } + /// Returns the aggregation strategy. 
#[cfg_attr(not(tarpaulin), inline(always))] pub const fn aggregation(&self) -> ChunkAggregation { @@ -160,6 +168,13 @@ impl ChunkingOptions { self } + /// Sets the chunk batch size used by batched chunked inference. + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + /// Sets the aggregation strategy. #[cfg_attr(not(tarpaulin), inline(always))] pub const fn with_aggregation(mut self, aggregation: ChunkAggregation) -> Self { @@ -186,16 +201,21 @@ pub enum ClassifierError { /// Empty audio was passed to the classifier. #[error("input audio is empty; expected mono {SAMPLE_RATE_HZ} Hz samples")] EmptyInput, + /// An empty batch was passed to the classifier. + #[error("input batch is empty; expected at least one mono {SAMPLE_RATE_HZ} Hz clip")] + EmptyBatch, /// Model output is empty. #[error("model returned empty output")] EmptyOutput, /// The model returned an unexpected output shape. #[error( - "unexpected model output shape {shape:?}; expected batch-one scores for {expected} classes" + "unexpected model output shape {shape:?}; expected batch scores for {expected_batch} x {expected_classes}" )] UnexpectedOutputShape { + /// The expected batch size. + expected_batch: usize, /// The expected number of classes. - expected: usize, + expected_classes: usize, /// The actual output shape returned by the model. shape: Vec, }, @@ -207,6 +227,26 @@ pub enum ClassifierError { /// The actual number of classes returned by the model. actual: usize, }, + /// Batch members must have the same length to be packed into one tensor. + #[error( + "batched inference requires equal clip lengths; expected {expected} samples, got {actual}" + )] + MismatchedBatchLength { + /// The clip length established by the first batch member. + expected: usize, + /// The mismatched clip length encountered later in the batch. 
+ actual: usize, + }, + /// The requested batch is too large to pack or buffer safely. + #[error( + "batched inference request is too large to allocate safely (batch={batch_size}, item_len={item_len})" + )] + BatchTooLarge { + /// Number of items in the batch. + batch_size: usize, + /// Length of each item in elements. + item_len: usize, + }, /// A model class index could not be resolved to a rated entry. #[error("no rated sound event exists for model class index {index}")] MissingRatedEventIndex { @@ -215,13 +255,15 @@ pub enum ClassifierError { }, /// Invalid chunking parameters were provided. #[error( - "chunking options require non-zero window and hop sizes (window={window_samples}, hop={hop_samples})" + "chunking options require non-zero window, hop, and batch sizes (window={window_samples}, hop={hop_samples}, batch={batch_size})" )] InvalidChunkingOptions { /// The chunk window size in samples. window_samples: usize, /// The chunk hop size in samples. hop_samples: usize, + /// The chunk batch size. + batch_size: usize, }, } @@ -306,6 +348,7 @@ pub struct Classifier { session: Session, input_name: SmolStr, output_name: SmolStr, + input_scratch: Vec, confidence_scratch: Vec, } @@ -353,22 +396,137 @@ impl Classifier { self.with_raw_scores(samples_16k, |raw_scores| Ok(raw_scores.to_vec())) } + /// Run the model on a batch of equal-length mono 16 kHz clips. + /// + /// Every clip in `batch_16k` must be non-empty and have the same number of + /// samples. This low-level API is intended for service-layer batching and + /// for chunked inference over fixed windows. If you need a single + /// row-major buffer instead of one allocation per clip, prefer + /// [`predict_raw_scores_batch_flat`](Self::predict_raw_scores_batch_flat) or + /// [`predict_raw_scores_batch_into`](Self::predict_raw_scores_batch_into). 
+ pub fn predict_raw_scores_batch( + &mut self, + batch_16k: &[&[f32]], + ) -> Result<Vec<Vec<f32>>, ClassifierError> { + self.with_raw_scores_batch(batch_16k, |raw_scores, batch_size| { + Ok( + raw_scores + .chunks_exact(NUM_CLASSES) + .take(batch_size) + .map(|scores| scores.to_vec()) + .collect(), + ) + }) + } + + /// Run the model on a batch of equal-length mono 16 kHz clips and return + /// all raw scores in one row-major buffer. + /// + /// The returned vector contains `batch_16k.len() * NUM_CLASSES` elements in + /// `[batch, class]` order, so callers can iterate with + /// `.chunks_exact(NUM_CLASSES)`. + pub fn predict_raw_scores_batch_flat( + &mut self, + batch_16k: &[&[f32]], + ) -> Result<Vec<f32>, ClassifierError> { + self.with_raw_scores_batch(batch_16k, |raw_scores, _| Ok(raw_scores.to_vec())) + } + + /// Run the model on a batch of equal-length mono 16 kHz clips and write the + /// raw scores into a caller-provided row-major buffer. + /// + /// `out` is cleared before writing and then filled with + /// `batch_16k.len() * NUM_CLASSES` elements in `[batch, class]` order. 
+ pub fn predict_raw_scores_batch_into( + &mut self, + batch_16k: &[&[f32]], + out: &mut Vec, + ) -> Result<(), ClassifierError> { + self.with_raw_scores_batch(batch_16k, |raw_scores, batch_size| { + out.clear(); + let total_scores = checked_batch_len(batch_size, NUM_CLASSES)?; + out + .try_reserve(total_scores) + .map_err(|_| ClassifierError::BatchTooLarge { + batch_size, + item_len: NUM_CLASSES, + })?; + out.extend_from_slice(raw_scores); + Ok(()) + }) + } + fn with_raw_scores( &mut self, samples_16k: &[f32], f: impl FnOnce(&[f32]) -> Result, ) -> Result { - ensure_non_empty(samples_16k)?; + self.with_raw_scores_batch(&[samples_16k], |raw_scores, _| f(raw_scores)) + } + + fn with_raw_scores_batch( + &mut self, + batch_16k: &[&[f32]], + f: impl FnOnce(&[f32], usize) -> Result, + ) -> Result { + let chunk_len = validate_batch_inputs(batch_16k)?; + self.with_validated_raw_scores_batch(batch_16k, batch_16k.len(), chunk_len, f) + } - let input_ref = TensorRef::from_array_view(([1usize, samples_16k.len()], samples_16k))?; + fn with_validated_raw_scores_batch( + &mut self, + batch_16k: &[&[f32]], + batch_size: usize, + chunk_len: usize, + f: impl FnOnce(&[f32], usize) -> Result, + ) -> Result { + let total_samples = checked_batch_len(batch_size, chunk_len)?; + + self.input_scratch.clear(); + self + .input_scratch + .try_reserve(total_samples) + .map_err(|_| ClassifierError::BatchTooLarge { + batch_size, + item_len: chunk_len, + })?; + for clip in batch_16k { + self.input_scratch.extend_from_slice(clip); + } + + let input_ref = + TensorRef::from_array_view(([batch_size, chunk_len], self.input_scratch.as_slice()))?; let outputs = self .session .run(ort::inputs![self.input_name.as_str() => input_ref])?; let (shape, raw_scores) = outputs[self.output_name.as_str()].try_extract_tensor::()?; - validate_output(shape, raw_scores)?; + validate_output(shape, raw_scores, batch_size)?; - f(raw_scores) + f(raw_scores, batch_size) + } + + /// Classify a batch of equal-length mono 16 
kHz clips and return every class in model order. + pub fn classify_all_batch( + &mut self, + batch_16k: &[&[f32]], + ) -> Result>, ClassifierError> { + self.with_raw_scores_batch(batch_16k, |raw_scores, batch_size| { + raw_scores + .chunks_exact(NUM_CLASSES) + .take(batch_size) + .map(|row| { + row + .iter() + .copied() + .enumerate() + .map(|(class_index, raw_score)| { + EventPrediction::from_confidence(class_index, sigmoid(raw_score)) + }) + .collect() + }) + .collect() + }) } /// Classify a mono 16 kHz clip and return every class in model order. @@ -405,6 +563,33 @@ impl Classifier { }) } + /// Classify a batch of equal-length mono 16 kHz clips and return the top `k` classes for each clip. + pub fn classify_batch( + &mut self, + batch_16k: &[&[f32]], + top_k: usize, + ) -> Result>, ClassifierError> { + let chunk_len = validate_batch_inputs(batch_16k)?; + let batch_size = batch_16k.len(); + + if top_k == 0 { + return Ok((0..batch_size).map(|_| Vec::new()).collect()); + } + + self.with_validated_raw_scores_batch( + batch_16k, + batch_size, + chunk_len, + |raw_scores, batch_size| { + raw_scores + .chunks_exact(NUM_CLASSES) + .take(batch_size) + .map(|row| top_k_from_scores(row.iter().copied().enumerate(), top_k, sigmoid)) + .collect() + }, + ) + } + /// Classify a long clip by chunking it into windows and aggregating chunk confidences. 
pub fn classify_all_chunked( &mut self, @@ -475,6 +660,7 @@ impl Classifier { session, input_name, output_name, + input_scratch: Vec::new(), confidence_scratch: Vec::with_capacity(NUM_CLASSES), }) } @@ -503,46 +689,66 @@ fn fill_aggregated_confidences( ensure_non_empty(samples_16k)?; validate_chunking(options)?; - let mut chunks = chunk_slices(samples_16k, options.window_samples(), options.hop_samples()); - let Some(first_chunk) = chunks.next() else { - return Err(ClassifierError::EmptyInput); - }; - classifier.with_raw_scores(first_chunk, |raw_scores| { - aggregated.clear(); - aggregated.extend(raw_scores.iter().copied().map(sigmoid)); - Ok(()) - })?; - let mut chunk_count = 1usize; + let mut chunk_count = 0usize; + let mut initialized = false; + + for batch in chunk_batches(samples_16k, options) { + accumulate_chunk_batch( + classifier, + aggregated, + &batch, + options.aggregation(), + initialized, + )?; + chunk_count += batch.len(); + initialized = true; + } + + if matches!(options.aggregation(), ChunkAggregation::Mean) && chunk_count > 1 { + let denominator = chunk_count as f32; + for aggregate in aggregated.iter_mut() { + *aggregate /= denominator; + } + } + + Ok(()) +} + +fn accumulate_chunk_batch( + classifier: &mut Classifier, + aggregated: &mut Vec, + batch: &[&[f32]], + aggregation: ChunkAggregation, + initialized: bool, +) -> Result<(), ClassifierError> { + classifier.with_raw_scores_batch(batch, |raw_scores, batch_size| { + for (row_index, row) in raw_scores + .chunks_exact(NUM_CLASSES) + .take(batch_size) + .enumerate() + { + if !initialized && row_index == 0 { + aggregated.clear(); + aggregated.extend(row.iter().copied().map(sigmoid)); + continue; + } - for chunk in chunks { - classifier.with_raw_scores(chunk, |raw_scores| { - match options.aggregation() { + match aggregation { ChunkAggregation::Mean => { - for (aggregate, raw_score) in aggregated.iter_mut().zip(raw_scores.iter().copied()) { + for (aggregate, raw_score) in 
aggregated.iter_mut().zip(row.iter().copied()) { *aggregate += sigmoid(raw_score); } } ChunkAggregation::Max => { - for (aggregate, raw_score) in aggregated.iter_mut().zip(raw_scores.iter().copied()) { + for (aggregate, raw_score) in aggregated.iter_mut().zip(row.iter().copied()) { *aggregate = aggregate.max(sigmoid(raw_score)); } } } - - Ok(()) - })?; - - chunk_count += 1; - } - - if matches!(options.aggregation(), ChunkAggregation::Mean) && chunk_count > 1 { - let denominator = chunk_count as f32; - for aggregate in aggregated.iter_mut() { - *aggregate /= denominator; } - } - Ok(()) + Ok(()) + }) } #[cfg_attr(not(tarpaulin), inline(always))] @@ -594,39 +800,119 @@ fn ensure_non_empty(samples_16k: &[f32]) -> Result<(), ClassifierError> { } fn validate_chunking(options: ChunkingOptions) -> Result<(), ClassifierError> { - if options.window_samples() == 0 || options.hop_samples() == 0 { + if options.window_samples() == 0 || options.hop_samples() == 0 || options.batch_size() == 0 { return Err(ClassifierError::InvalidChunkingOptions { window_samples: options.window_samples(), hop_samples: options.hop_samples(), + batch_size: options.batch_size(), }); } Ok(()) } -fn validate_output(shape: &ort::value::Shape, raw_scores: &[f32]) -> Result<(), ClassifierError> { +#[cfg_attr(not(tarpaulin), inline(always))] +fn checked_batch_len(batch_size: usize, item_len: usize) -> Result { + batch_size + .checked_mul(item_len) + .ok_or(ClassifierError::BatchTooLarge { + batch_size, + item_len, + }) +} + +fn validate_output( + shape: &ort::value::Shape, + raw_scores: &[f32], + expected_batch_size: usize, +) -> Result<(), ClassifierError> { if raw_scores.is_empty() { return Err(ClassifierError::EmptyOutput); } - if raw_scores.len() != NUM_CLASSES { - return Err(ClassifierError::UnexpectedClassCount { - expected: NUM_CLASSES, - actual: raw_scores.len(), + let expected_values = checked_batch_len(expected_batch_size, NUM_CLASSES)?; + if raw_scores.len() != expected_values { + if 
raw_scores.len() % expected_batch_size.max(1) == 0 { + return Err(ClassifierError::UnexpectedClassCount { + expected: NUM_CLASSES, + actual: raw_scores.len() / expected_batch_size.max(1), + }); + } + + return Err(ClassifierError::UnexpectedOutputShape { + expected_batch: expected_batch_size, + expected_classes: NUM_CLASSES, + shape: shape.to_vec(), }); } + // Keep shape validation strict: released CED exports use `[batch, 527]` + // (and some runtimes collapse batch-one to `[527]`). Accepting arbitrary + // shapes with the same element count would make tensor-packing bugs much + // harder to catch. match &shape[..] { - [classes] if *classes as usize == NUM_CLASSES => Ok(()), - [batch, classes] if *batch == 1 && *classes as usize == NUM_CLASSES => Ok(()), - _ if shape.num_elements() == NUM_CLASSES => Ok(()), + [classes] if expected_batch_size == 1 && *classes as usize == NUM_CLASSES => Ok(()), + [batch, classes] + if *batch as usize == expected_batch_size && *classes as usize == NUM_CLASSES => + { + Ok(()) + } _ => Err(ClassifierError::UnexpectedOutputShape { - expected: NUM_CLASSES, + expected_batch: expected_batch_size, + expected_classes: NUM_CLASSES, shape: shape.to_vec(), }), } } +/// Validates a batch of clips and returns the common clip length in samples. +fn validate_batch_inputs(batch_16k: &[&[f32]]) -> Result { + let Some(first) = batch_16k.first() else { + return Err(ClassifierError::EmptyBatch); + }; + + ensure_non_empty(first)?; + let expected = first.len(); + + for clip in &batch_16k[1..] { + ensure_non_empty(clip)?; + if clip.len() != expected { + return Err(ClassifierError::MismatchedBatchLength { + expected, + actual: clip.len(), + }); + } + } + + Ok(expected) +} + +/// Groups consecutive equal-length chunks into batches so one tensor never +/// mixes the usual short tail chunk with full-size windows. 
+fn chunk_batches(samples: &[f32], options: ChunkingOptions) -> impl Iterator<Item = Vec<&[f32]>> { + let mut chunks = + chunk_slices(samples, options.window_samples(), options.hop_samples()).peekable(); + + std::iter::from_fn(move || { + let first = chunks.next()?; + let first_len = first.len(); + let mut batch = Vec::with_capacity(options.batch_size()); + batch.push(first); + + while batch.len() < options.batch_size() { + let Some(next_len) = chunks.peek().map(|chunk| chunk.len()) else { + break; + }; + if next_len != first_len { + break; + } + batch.push(chunks.next().expect("peeked chunk must exist")); + } + + Some(batch) + }) +} + fn chunk_slices( samples: &[f32], window_samples: usize, @@ -655,6 +941,17 @@ fn sigmoid(x: f32) -> f32 { mod tests { use super::*; + #[cfg(feature = "bundled-tiny")] + fn pseudo_audio(len: usize, mut seed: u64) -> Vec<f32> { + let mut samples = Vec::with_capacity(len); + for _ in 0..len { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + let value = ((seed >> 40) as u32) as f32 / u32::MAX as f32; + samples.push(value * 2.0 - 1.0); + } + samples + } + #[test] fn rated_indices_round_trip() { for event in RatedSoundEvent::events() { @@ -681,9 +978,137 @@ assert_eq!(options.window_samples(), DEFAULT_CHUNK_SAMPLES); assert_eq!(options.hop_samples(), DEFAULT_CHUNK_SAMPLES); + assert_eq!(options.batch_size(), 1); assert_eq!(options.aggregation(), ChunkAggregation::Mean); } + #[test] + fn chunking_options_can_configure_batch_size() { + let options = ChunkingOptions::default().with_batch_size(8); + assert_eq!(options.batch_size(), 8); + } + + #[test] + fn validate_batch_inputs_requires_equal_non_empty_clips() { + assert!(matches!( + validate_batch_inputs(&[]), + Err(ClassifierError::EmptyBatch) + )); + assert!(matches!( + validate_batch_inputs(&[&[]]), + Err(ClassifierError::EmptyInput) + )); + assert!(matches!( + validate_batch_inputs(&[&[0.0, 1.0], &[0.0]]), + Err(ClassifierError::MismatchedBatchLength { + expected: 2, + actual: 1, + 
}) + )); + } + + #[test] + fn checked_batch_len_reports_overflow() { + assert!(matches!( + checked_batch_len(usize::MAX, 2), + Err(ClassifierError::BatchTooLarge { + batch_size, + item_len: 2, + }) if batch_size == usize::MAX + )); + } + + #[cfg(feature = "bundled-tiny")] + #[test] + fn batched_predict_raw_scores_matches_sequential_inference() { + let clip_a = pseudo_audio(SAMPLE_RATE_HZ * 2, 0x1234_5678); + let clip_b = pseudo_audio(SAMPLE_RATE_HZ * 2, 0x9abc_def0); + + let mut sequential = Classifier::tiny(Options::default()).expect("load bundled classifier"); + let seq_a = sequential + .predict_raw_scores(&clip_a) + .expect("sequential clip a"); + let seq_b = sequential + .predict_raw_scores(&clip_b) + .expect("sequential clip b"); + + let mut batched = Classifier::tiny(Options::default()).expect("load bundled classifier"); + let batch = batched + .predict_raw_scores_batch(&[&clip_a, &clip_b]) + .expect("batched inference"); + + assert_eq!(batch.len(), 2); + assert_eq!(batch[0].len(), seq_a.len()); + assert_eq!(batch[1].len(), seq_b.len()); + + for (expected, actual) in seq_a.iter().zip(batch[0].iter()) { + assert!((expected - actual).abs() < 1e-6); + } + for (expected, actual) in seq_b.iter().zip(batch[1].iter()) { + assert!((expected - actual).abs() < 1e-6); + } + } + + #[cfg(feature = "bundled-tiny")] + #[test] + fn flat_and_into_batch_raw_scores_match_sequential_inference() { + let clip_a = pseudo_audio(SAMPLE_RATE_HZ * 2, 0x1357_9bdf); + let clip_b = pseudo_audio(SAMPLE_RATE_HZ * 2, 0x2468_ace0); + + let mut sequential = Classifier::tiny(Options::default()).expect("load bundled classifier"); + let seq_a = sequential + .predict_raw_scores(&clip_a) + .expect("sequential clip a"); + let seq_b = sequential + .predict_raw_scores(&clip_b) + .expect("sequential clip b"); + + let mut batched = Classifier::tiny(Options::default()).expect("load bundled classifier"); + let flat = batched + .predict_raw_scores_batch_flat(&[&clip_a, &clip_b]) + .expect("flat batched 
inference"); + + assert_eq!(flat.len(), 2 * NUM_CLASSES); + for (expected, actual) in seq_a.iter().zip(flat[..NUM_CLASSES].iter()) { + assert!((expected - actual).abs() < 1e-6); + } + for (expected, actual) in seq_b.iter().zip(flat[NUM_CLASSES..].iter()) { + assert!((expected - actual).abs() < 1e-6); + } + + let mut into = vec![1.0; 7]; + batched + .predict_raw_scores_batch_into(&[&clip_a, &clip_b], &mut into) + .expect("into batched inference"); + assert_eq!(into, flat); + } + + #[cfg(feature = "bundled-tiny")] + #[test] + fn chunked_batching_matches_batch_size_one() { + let clip = pseudo_audio(DEFAULT_CHUNK_SAMPLES * 2 + 40_000, 0x0ddc_0ffe); + let single_opts = ChunkingOptions::default() + .with_hop_samples(DEFAULT_CHUNK_SAMPLES / 2) + .with_batch_size(1); + let batched_opts = single_opts.with_batch_size(4); + + let mut single = Classifier::tiny(Options::default()).expect("load bundled classifier"); + let single_predictions = single + .classify_all_chunked(&clip, single_opts) + .expect("chunked single-batch inference"); + + let mut batched = Classifier::tiny(Options::default()).expect("load bundled classifier"); + let batched_predictions = batched + .classify_all_chunked(&clip, batched_opts) + .expect("chunked batched inference"); + + assert_eq!(single_predictions.len(), batched_predictions.len()); + for (expected, actual) in single_predictions.iter().zip(batched_predictions.iter()) { + assert_eq!(expected.index(), actual.index()); + assert!((expected.confidence() - actual.confidence()).abs() < 1e-6); + } + } + #[test] fn top_k_selection_returns_descending_predictions() { let predictions = top_k_from_scores(