From 0fe24afe198b6b5555965fad3a6f42f4bf1f10c9 Mon Sep 17 00:00:00 2001 From: Neko Ayaka Date: Sun, 22 Jun 2025 22:48:22 +0800 Subject: [PATCH 01/14] chore: lint --- apps/orpheus-tts/src/main.rs | 63 ++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/apps/orpheus-tts/src/main.rs b/apps/orpheus-tts/src/main.rs index 977ee31..c0ba359 100644 --- a/apps/orpheus-tts/src/main.rs +++ b/apps/orpheus-tts/src/main.rs @@ -16,7 +16,7 @@ use tracing_chrome::ChromeLayerBuilder; use tracing_subscriber::prelude::*; // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43 -const STOP_TOKEN_ID: u32 = 128258; +const STOP_TOKEN_ID: u32 = 128_258; #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] enum Voice { @@ -39,21 +39,22 @@ enum Voice { } impl Voice { - fn as_str(&self) -> &'static str { + const fn as_str(self) -> &'static str { match self { - Voice::Tara => "tara", - Voice::Leah => "leah", - Voice::Jess => "jess", - Voice::Leo => "leo", - Voice::Dan => "dan", - Voice::Mia => "mia", - Voice::Zac => "zac", - Voice::Zoe => "zoe", + Self::Tara => "tara", + Self::Leah => "leah", + Self::Jess => "jess", + Self::Leo => "leo", + Self::Dan => "dan", + Self::Mia => "mia", + Self::Zac => "zac", + Self::Zoe => "zoe", } } } #[derive(Parser)] +#[allow(clippy::struct_excessive_bools)] struct Args { #[arg(long)] cpu: bool, @@ -154,12 +155,14 @@ pub trait Sample { } impl Sample for f32 { + #[allow(clippy::cast_possible_truncation)] fn to_i16(&self) -> i16 { (self.clamp(-1.0, 1.0) * 32767.0) as i16 } } impl Sample for f64 { + #[allow(clippy::cast_possible_truncation)] fn to_i16(&self) -> i16 { (self.clamp(-1.0, 1.0) * 32767.0) as i16 } @@ -171,6 +174,7 @@ impl Sample for i16 { } } +#[allow(clippy::missing_panics_doc)] pub fn write_pcm_as_wav( w: &mut W, samples: &[S], @@ -178,9 +182,9 @@ pub fn write_pcm_as_wav( ) -> std::io::Result<()> { let len = 12u32; // header let len = len + 24u32; // fmt - let len = len + samples.len() as u32 * 2 + 8; // data + let len = len + u32::try_from(samples.len()).unwrap() * 2 + 8; // data let n_channels = 1u16; - let bytes_per_second = sample_rate * 2 * n_channels as u32; + let bytes_per_second = sample_rate * 2 * u32::from(n_channels); w.write_all(b"RIFF")?; w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes w.write_all(b"WAVE")?; @@ -197,21 +201,21 @@ pub fn write_pcm_as_wav( // Data block w.write_all(b"data")?; - w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?; - for sample in samples.iter() { - w.write_all(&sample.to_i16().to_le_bytes())? 
+ w.write_all(&(u32::try_from(samples.len()).unwrap() * 2).to_le_bytes())?; + for sample in samples { + w.write_all(&sample.to_i16().to_le_bytes())?; } Ok(()) } struct Model { - model: LlamaModel, + llama: LlamaModel, tokenizer: Tokenizer, logits_processor: LogitsProcessor, cache: Cache, device: Device, verbose_prompt: bool, - snac_model: SnacModel, + snac: SnacModel, out_file: String, voice: Voice, } @@ -226,15 +230,14 @@ impl Model { .build()?; let model_id = match args.orpheus_model_id { - Some(model_id) => model_id.to_string(), + Some(model_id) => model_id, None => match args.which_orpheus_model { WhichOrpheusModel::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(), }, }; - let revision = match args.revision { - Some(r) => r, - None => "main".to_string(), - }; + let revision = args + .revision + .map_or_else(|| "main".to_string(), |r| r); let repo = api.repo(hf_hub::Repo::with_revision(model_id, hf_hub::RepoType::Model, revision)); let model_file = match args.orpheus_model_file { Some(m) => vec![m.into()], @@ -292,16 +295,16 @@ impl Model { println!("load the model in {:?}", start_time.elapsed()); let cache = Cache::new(true, dtype, &config, &device)?; - let snac_model = load_snac_model(hf_token.clone().as_str(), &device)?; + let snac_model = load_snac_model(hf_token.as_str(), &device)?; Ok(Self { - model, + llama: model, tokenizer, logits_processor, cache, device, verbose_prompt: args.verbose_prompt, - snac_model, + snac: snac_model, out_file: args.out_file, voice: args.voice, }) @@ -319,7 +322,7 @@ impl Model { .map_err(E::msg)?; // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82 - let mut tokens = [&[128259], tokens.get_ids(), &[128009, 128260, 128261, 128257]].concat(); + let mut tokens = [&[128_259], tokens.get_ids(), &[128_009, 128_260, 128_261, 128_257]].concat(); if self.verbose_prompt { println!("prompt tokens: {tokens:?}"); } @@ -342,7 +345,7 @@ impl Model { let context = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(context, device)?.unsqueeze(0)?; let logits = self - .model + .llama .forward(&input, context_index, &mut cache)?; let logits = logits.squeeze(0)?; index_pos += context.len(); @@ -354,7 +357,7 @@ impl Model { Some(tok) => { let tok = tok.parse::()?; // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63 - let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096); + let tok = tok - 10 - ((u32::try_from(audio_tokens.len()).unwrap() % 7) * 4096); audio_tokens.push(tok); }, None => { @@ -393,9 +396,7 @@ impl Model { let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?; let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?; let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?; - let pcm = self - .snac_model - .decode(&[&codes0, &codes1, &codes2])?; + let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?; println!("decoded to pcm {pcm:?}"); let mut output = std::fs::File::create(&self.out_file)?; From acbf711bb0c6e1aba56b910f4b10077e4893884f Mon Sep 17 00:00:00 2001 From: RainbowBird Date: Sun, 22 Jun 2025 17:45:56 +0800 Subject: [PATCH 02/14] feat: api --- Cargo.lock | 357 ++++++++++++++++++ Cargo.toml | 3 +- .../Cargo.toml | 39 ++ .../silero-vad-whisper-realtime-api/README.md | 225 +++++++++++ .../melfilters.bytes | Bin 0 -> 64320 bytes .../melfilters128.bytes | Bin 0 -> 102912 bytes .../src/api.rs | 67 ++++ .../src/asr.rs | 287 
++++++++++++++ .../src/audio_manager.rs | 229 +++++++++++ .../src/main.rs | 114 ++++++ .../src/vad.rs | 85 +++++ .../src/whisper.rs | 206 ++++++++++ 12 files changed, 1611 insertions(+), 1 deletion(-) create mode 100644 apps/silero-vad-whisper-realtime-api/Cargo.toml create mode 100644 apps/silero-vad-whisper-realtime-api/README.md create mode 100644 apps/silero-vad-whisper-realtime-api/melfilters.bytes create mode 100644 apps/silero-vad-whisper-realtime-api/melfilters128.bytes create mode 100644 apps/silero-vad-whisper-realtime-api/src/api.rs create mode 100644 apps/silero-vad-whisper-realtime-api/src/asr.rs create mode 100644 apps/silero-vad-whisper-realtime-api/src/audio_manager.rs create mode 100644 apps/silero-vad-whisper-realtime-api/src/main.rs create mode 100644 apps/silero-vad-whisper-realtime-api/src/vad.rs create mode 100644 apps/silero-vad-whisper-realtime-api/src/whisper.rs diff --git a/Cargo.lock b/Cargo.lock index 9ea27ca..51b1f5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -119,6 +119,17 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "async-trait" +version = "0.1.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -131,6 +142,134 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core 0.4.5", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "matchit 0.7.3", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +dependencies = [ + "axum-core 0.5.2", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "multer", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + 
"http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-extra" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c794b30c904f0a1c2fb7740f7df7f7972dfaa14ef6f57cb6178dc63e5dca2f04" +dependencies = [ + "axum 0.7.9", + "axum-core 0.4.5", + "bytes", + "fastrand", + "futures-util", + "headers", + "http", + "http-body", + "http-body-util", + "mime", + "multer", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -220,6 +359,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.18.1" @@ -538,6 +686,15 @@ dependencies = [ "windows", ] +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -587,6 +744,16 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "cudarc" version = "0.16.4" @@ -680,6 +847,16 @@ dependencies = [ "syn", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "dirs" version = "5.0.1" @@ -1211,6 +1388,16 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -1285,6 +1472,30 @@ version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" +[[package]] +name = "headers" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" +dependencies = [ + "base64 0.22.1", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.5.0" @@ -1355,12 +1566,24 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.4.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c" + [[package]] name = "httparse" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.6.0" @@ -1374,6 +1597,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -1762,6 +1986,18 @@ dependencies = [ "libc", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "memchr" version = "2.7.5" @@ -1814,6 +2050,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1861,6 +2107,23 @@ dependencies = [ "syn", ] +[[package]] +name = "multer" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http", + "httparse", + "memchr", + "mime", + "spin", + "version_check", +] + [[package]] name = "multimap" version = "0.10.1" @@ -2827,6 +3090,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -2848,6 +3121,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2951,6 +3235,38 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "silero-vad-whisper-realtime-api" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum 0.8.4", + "axum-extra", + "byteorder", + "candle-core", + "candle-nn", + "candle-onnx", + "candle-transformers", + "clap", + "cpal", + "crossbeam-channel", + "futures", + "hf-hub", + "rand 0.9.1", + "rubato", + "serde", + "serde_json", + "symphonia", + "tokenizers", + "tokio", + "tokio-stream", + "tower", + "tower-http", + "tracing", + "tracing-chrome", + "tracing-subscriber", +] + [[package]] name = "slab" version = "0.4.10" @@ -2984,6 +3300,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.9.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -3384,6 +3706,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.15" @@ -3427,6 +3760,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -3437,14 +3771,24 @@ checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" dependencies = [ "bitflags 2.9.1", "bytes", + "futures-core", "futures-util", "http", "http-body", + "http-body-util", + "http-range-header", + "httpdate", "iri-string", + "mime", + "mime_guess", + "percent-encoding", "pin-project-lite", + "tokio", + "tokio-util", "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -3465,6 +3809,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -3543,6 +3888,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + [[package]] name = "ug" version = "0.4.0" @@ -3591,6 +3942,12 @@ dependencies = [ "ug", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.18" diff --git a/Cargo.toml b/Cargo.toml index 13b467d..d935d72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ version = "0.0.1" authors = ["Project AIRI"] edition = "2024" -rust-version = "1.80" +rust-version = "1.85" readme = "README.md" homepage = "https://github.com/proj-airi/candle-examples" repository = "https://github.com/proj-airi/candle-examples" @@ -17,6 +17,7 @@ members = [ "apps/silero-vad-realtime", "apps/silero-vad-realtime-minimum", "apps/silero-vad-whisper-realtime", + "apps/silero-vad-whisper-realtime-api", "apps/whisper-realtime", ] diff --git a/apps/silero-vad-whisper-realtime-api/Cargo.toml b/apps/silero-vad-whisper-realtime-api/Cargo.toml new file mode 100644 index 0000000..514a9c5 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "silero-vad-whisper-realtime-api" +version = "0.1.0" +edition = "2024" + +[dependencies] +anyhow = "1.0.98" +byteorder = "1.5.0" +candle-core = { version = "0.9.1" } +candle-nn = { version = "0.9.1" } +candle-transformers = { version = "0.9.1" } +candle-onnx = { version = "0.9.1" } +clap = { version = "4.5.38", features = ["derive"] } +cpal = "0.15.3" +hf-hub = "0.4.2" +rand = "0.9.1" +rubato = "0.16.2" +serde_json = "1.0.140" +symphonia = "0.5.4" +tokenizers = "0.21.1" +tracing-chrome = "0.7.2" +tracing-subscriber = "0.3.19" +tracing = "0.1.41" +tokio = "1.45.1" +crossbeam-channel = "0.5.15" +axum = { version = "0.8.4", 
features = ["multipart"] } +serde = { version = "1.0.219", features = ["derive"] } + +# Server-sent events and HTTP utilities +futures = "0.3.31" +tokio-stream = "0.1.17" +axum-extra = { version = "0.9.5", features = ["typed-header"] } +tower = "0.5.1" +tower-http = { version = "0.6.2", features = ["fs", "cors"] } + +[features] +default = [] +metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"] +cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] diff --git a/apps/silero-vad-whisper-realtime-api/README.md b/apps/silero-vad-whisper-realtime-api/README.md new file mode 100644 index 0000000..8208fd7 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/README.md @@ -0,0 +1,225 @@ +# ASR API - OpenAI Compatible Audio Transcription Service + +🎤 一个兼容OpenAI格式的语音转录API服务,支持实时流式响应(SSE),集成了Silero VAD和Whisper模型。 + +## ✨ 功能特性 + +- 🔄 **兼容OpenAI API**: 完全兼容OpenAI `/v1/audio/transcriptions` 端点格式 +- 📡 **Server-Sent Events (SSE)**: 支持流式响应,实时获取转录结果 +- 🎯 **语音活动检测**: 集成Silero VAD,智能检测语音片段 +- 🧠 **Whisper转录**: 使用Candle框架实现的高效Whisper模型 +- 🚀 **高性能**: 支持GPU加速(CUDA/Metal) +- 🌐 **现代Web界面**: 包含完整的测试页面 + +## 🚀 快速开始 + +### 1. 启动服务器 + +```bash +# 进入项目目录 +cd apps/asr-api + +# 安装依赖并启动 +cargo run --release +``` + +服务器将在 `http://localhost:3000` 启动。 + +### 2. 测试API + +打开浏览器访问测试页面: +``` +http://localhost:3000/test.html +``` + +或者使用curl命令: + +```bash +# 基础转录 +curl -X POST http://localhost:3000/v1/audio/transcriptions \ + -F "file=@your_audio.wav" \ + -F "model=whisper-1" + +# 流式转录 +curl -X POST "http://localhost:3000/v1/audio/transcriptions?stream=true" \ + -F "file=@your_audio.wav" \ + -F "model=whisper-1" \ + --no-buffer +``` + +## 📋 API文档 + +### POST `/v1/audio/transcriptions` + +转录音频文件为文本。 + +#### 请求参数 + +| 参数 | 类型 | 必需 | 描述 | +|------|------|------|------| +| `file` | File | ✅ | 要转录的音频文件 | +| `model` | String | ❌ | 模型名称 (默认: "whisper-1") | +| `language` | String | ❌ | 音频语言 | +| `prompt` | String | ❌ | 提示文本 | +| `response_format` | String | ❌ | 响应格式 (默认: "json") | +| `temperature` | Float | ❌ | 采样温度 (默认: 0.0) | +| `stream` | Boolean | ❌ | 启用流式响应 (Query参数) | + +#### 支持的音频格式 + +- WAV +- MP3 +- FLAC +- M4A +- 以及Symphonia支持的其他格式 + +#### 响应格式 + +**标准响应 (JSON)**: +```json +{ + "text": "转录的文本内容" +} +``` + +**流式响应 (SSE)**: +``` +data: {"text": "Processing audio chunk 1 of 4...", "timestamp": 0.5} + +data: {"text": "Processing audio chunk 2 of 4...", "timestamp": 1.0} + +data: {"text": "转录完成的文本", "timestamp": 2.5} +``` + +**错误响应**: +```json +{ + "error": { + "message": "错误描述", + "type": "invalid_request_error", + "param": "file", + "code": null + } +} +``` + +## 🛠️ 开发指南 + +### 项目结构 + +``` +apps/asr-api/ +├── src/ +│ ├── main.rs # 主服务器文件 +│ ├── vad.rs # VAD处理器 +│ ├── whisper.rs # Whisper处理器 +│ └── audio_manager.rs # 音频缓冲管理 +├── melfilters.bytes # Mel滤波器数据 +├── melfilters128.bytes # 128维Mel滤波器数据 +├── test.html # 测试页面 +├── Cargo.toml # 依赖配置 +└── README.md # 文档 +``` + +### 核心组件 + +1. **VAD处理器**: 使用Silero VAD模型检测语音活动 +2. **Whisper处理器**: 使用Candle实现的Whisper模型进行转录 +3. **音频管理器**: 处理音频缓冲和格式转换 +4. **Web服务器**: 基于Axum的高性能HTTP服务器 + +### 自定义配置 + +可以通过修改 `AppState::new()` 方法来调整以下参数: + +- VAD阈值 (默认: 0.3) +- Whisper模型 (默认: Tiny) +- 设备选择 (自动选择GPU/CPU) + +### 添加新功能 + +1. **支持更多音频格式**: 修改 `convert_audio_to_pcm` 函数 +2. **自定义VAD参数**: 在 `VADProcessor::new` 中调整参数 +3. 
**更大的Whisper模型**: 在 `WhisperProcessor::new` 中选择不同模型 + +## 🔧 高级配置 + +### 环境变量 + +```bash +# 设置日志级别 +export RUST_LOG=debug + +# 强制使用CPU +export CANDLE_FORCE_CPU=1 +``` + +### GPU加速 + +#### CUDA支持 +```bash +cargo run --release --features cuda +``` + +#### Metal支持 (macOS) +```bash +cargo run --release --features metal +``` + +## 📊 性能优化 + +### 推荐配置 + +- **内存**: 最少8GB RAM +- **GPU**: NVIDIA GTX 1060 6GB+ 或 Apple M1+ +- **存储**: SSD推荐,用于模型加载 + +### 批处理优化 + +对于大量文件处理,建议: + +1. 使用更大的Whisper模型获得更好质量 +2. 启用GPU加速 +3. 调整VAD参数减少误检 + +## 🚨 常见问题 + +### Q: 转录准确率不高怎么办? +A: 尝试以下方法: +- 使用更大的Whisper模型 (medium/large) +- 确保音频质量良好 (16kHz采样率) +- 调整VAD阈值 +- 提供语言参数 + +### Q: 服务器启动慢? +A: 首次启动需要下载模型文件,这是正常现象。模型会缓存到本地。 + +### Q: 支持实时语音输入吗? +A: 目前只支持文件上传,实时语音输入可以参考 `silero-vad-whisper-realtime` 项目。 + +### Q: 如何批量处理文件? +A: 可以编写脚本调用API,或者扩展当前代码支持批处理端点。 + +## 🤝 贡献指南 + +欢迎提交Issue和Pull Request! + +1. Fork项目 +2. 创建功能分支 +3. 提交改动 +4. 发起Pull Request + +## 📄 许可证 + +本项目采用与父项目相同的许可证。 + +## 🙏 致谢 + +- [Candle](https://github.com/huggingface/candle) - 高性能ML框架 +- [Axum](https://github.com/tokio-rs/axum) - 现代Web框架 +- [OpenAI](https://openai.com/) - API设计参考 +- [Silero VAD](https://github.com/snakers4/silero-vad) - VAD模型 + +--- + +🎯 **提示**: 第一次运行时会自动下载模型文件,请确保网络连接正常。 diff --git a/apps/silero-vad-whisper-realtime-api/melfilters.bytes b/apps/silero-vad-whisper-realtime-api/melfilters.bytes new file mode 100644 index 0000000000000000000000000000000000000000..0874829e2088c94e3a4f001725e7040fae57fa0c GIT binary patch literal 64320 zcmeI*c~p;S9|rKmB$Oqxy%HkL2rZU=_jQ|^;_V${O=THd)?TtC+k1veYBZT?K@2rx zsYsbO6kb^(lq@qE(}~G4Lue#3rWtGWHuJaNUH{*ga~$V9-{+t2=UmshpXb-_#KgqJ zBz50KS@#eBP=EqP1VUDtB5}$w_QZ$}FH_(zfsmCESl2R>Jz*Rupb*%Xa|Zv(+s2+K zj5wk|-E)_mw>XA9VH7B!5QuAUCy#h8Wlt1F98q9jj+fk)I*vVI6eyq&IPWt{hMKvu zCki8uC{Xu#mkbEBVow+a3Md5D6mF28ZMqA^EgVrmdx5z2DRT6LUF?bWNg#kq;Op7h z($US6Jy9WXroBMj&l%;tH=#m%2m(+*A&}en2fo8Ad^n=OUjo}wud*kM0tFNTOM|1t z=YD6{6NM2+6mWDNhETI?_JmQOfI{HP4=Gr5WEXp)Fye>;=56jFAaDbF!YEKcA+XcF ztsK`Wf;~|faYTWWXI<;Qvoe!CVH7B!5cqM1uMB$X$(|^TIHJHalVx?^?drjvFbWh< z2*i9BBVBSEu_p>6jwrCTYf{~JAoKYhNJfDI3V|238S?Ylk?e`Wh$9O8(jof^V zziWj_pnytX=;1MPPqYVrexgF+OnZTwu^nX6q!9K*`y>!RC9w8UHmWK%$)-_md`D41 zai+CENY(|h>Rgm;*{hm8(K-oSzCM9mr@`{v!=Zfkst<~Dtp%><6(BNmoy@rx%QHpm zByf3^z^Ip>inSOkA2he+nWBQ?OnU*B$K?o_79!os<9McMp9TUL6F7hKs7P~_a@pu! 
zvQbH~!PtGgP66EoQbr%Zk_EvsW^)P86y3wX{gnd!cQut)dw(bwG>KqWlu{h(EZ`r| z3g6zam+!2%lV#Ve*%h6`z}=Ms6E?4gf3BOXIMiG^CuQ;fdr)d|sIx#^^;fX-cbA?% zFYt?3B)g(>7`VGq;Ml8lB(ImU`_NzETbsnLD784$S>TaJ3W{r-WtVsVK+&a#?267| z;O#xuIL;F?yeN5h_ZsY z&mE*kcbBURvhdN(nS2gYigBo`z^JXhV$#@joNM|I`6~D>zOx?0vqjf1aPN8rlJdu) zpl22SoMtIkElkG3ip4xz>Q&> zAEb;bKtNa<*)8KRYBSogFAb98O}Yt$+if<~dd)=P&(|^bnTZ^q{1y5;+~c#DZb9I_ z4HEF2zC?t4vl{aPZo_8pV{AB`fX;ze*_Q_C@h05_j?UgHir0i;ro(j%thx{L@^~zo zU&+4c76k6wAc5k7SYi7z2=o6{guJ{8Jh9pe?@LyEe`$~&Z_-U5|NIBy=7kS&F5v(= z%=#U}lG0J0Is`Q%9C*g)76$HHufUX-E{506o$%hajqp$U5sr;cB5Ps@EN7IkFZJs2 z1|0=z%#Mq;p9a7$@*q0CEJfq&FHu(Nf>ZWm*%=)J!JU-?V|JGqd_&vevE4f4n4iV7 z%ejag5RM*$THtP{{ruf8N<9vB6sVaRDNLWc!1JdBL=L@xs`KgS(QGzKOr8jbPu$rV z9YevLl>*=I*(j91Gai*E`O*%_rChdK)EwAm-B3I`$b z)OxhFIgaHX+i@r;0IMc6LHjl~e4o)V6x>-U(61mt#4YQArsk{QemMif3scd2atOlS zwZVg`FW4ES9)~/utqUOq5I&oVza%-w|i&@2o{T!-P~-SM&cZ4nXun$KrC27@~r z71-lC(Qxa0j@UJ@H%iTx!YpV9TxM)V`)h$Hs_csAZ}y667iadys2~SA3Pf+}Bg$Ks zi~K$YME4Gd)6og z(#K#(wKc>ihsB?^rukIIf+EGB9UQggJuza$T3-mW6?1%&7O)lUk40+ zbXRPjx==*f@8PqWF+^V1PM}~{uAz9u7V+e_S`j^b0KQBPgzt)2%wDn-&pLYIv%FRq zcqvz8yEyX9(JmcaSS1kZ)L+~R`=2r2f&GaC!Y6-_ zh)C7JgT3)n7SF0PsHa1y!wG*PNa}duT?UKTU-zFfvPZ?&+m?Soi$q-K$ tJrJ4RR+wDsgqCwhBlOlt#4mF~ibY$@^R5)}%l3uGe zLq~@uhZc<*G=`)YgBpjhhiS7$jfOE>dp_#DpT6IJUs%ue_;CH!|6135=T!!SVWr=x zA;Q346rg}ffwH<6oDq{cEKuNY0l!nloD+tD0uq5ENt-z%5+9Z*P%qG{d7HkJpa2C- z3k1~Wp?~Bpn0|v53Q*ukk}Y=$V?Y5Nft(dPcsJ28Vp(kgzZ(Ud5w#=0?i(R6yV}4r zS0kMGin;;;wLan(_c%>m0d`ISeSzgCx^b82Gw}_@1^P!$&WN0d zRSFmeq+aEmFb5Qn2pr9e;fzRpSfW6^K+XzB-rpGo3P=Qk<|M+&!oYnZF=B}Vp53gu zOBe$R=m=y_*~49;W5lxB0&hM)&lyoW0_?sK0{@w8&NEjdocM~m0)suey!k9w3;tg$ zbtAye8zJy2urqf_Bb@k(+5#eQ0q-ViM}Xby2wdy^kTarV#4-ime2)mu31dJ3i9l}o zD$a<+hb0Qs3wS*3!23I+KmmzB+Nr5%_d_f06NwQ^6gZk^$z8%2P(Vi@%RiI5M8}9_ zwFL&*W^+c=jsUxFgutvT&3NW&gcDy;S0HGPov0WzhM(n9Hv;Uu5dsm{ow-XI;lx+e z7I3my$-9Z#5n%T^0zZzZ;Ed=Pu}pz})2lcqi~$8C0t>3bI3p4tmMBm!kUgan@9&HP z1tbE)%pLK&YZvYli4jW_FkjY$yM!^IfR4cT_7}KIbc|S5TVQCHWdh@V;?KoWI|l4t zN5C`8o@Xu{BbF(UzH>LvT#Nw)Bm)0E93#Ha`tdW55+jx<5FXu>XD-Ho0y+ZiGtThL zrDMdh+5*|PJjJw=?wk{~W5Dip1ajM*;Ed=Pu}py@US_5nRCJ@P(UJ(oYr0I9(a~>A~9l#0#zAP;8dBxIbjqiAQ3o{n1%5jQ#mIR zBbF!-TvP_fNn1E4i~?}6w(2_p2-Y2HyJPApL%Jt;&S3U}q%i&3C}M4;p& zD{-w)n0OT$&9j%pi6ylK-uW~fbCzxpt7f0)*-PyxuzMW=%Y*$z`?+2`d+9i_thPYz zAEj`~jS>ZqVz^V(jsm-v2owY)YwMeg6q~!W=T4D0u|xrjy${jr!eS8;ypB7CaiD++ zfoZYdY10e*#qtnm?i3S7d`y9Qfnhc0urqj>7@eQXox(^^z=XiD)&)j~t{p|%yP7DO z-;}dr!ikS5V08Wx58^%$iLPcme=!aekO*uwyN!$1Q$@&!Iov4{CzjL}2#)vF)()_J z^RrPj;n0WQ&tB~`uzQ(6h?x=M&mLlPTw7u3^pKzZl3}r?v_S05jWBNa6*m&z5l%zg zI4er0fz8VVlB|xx=}EAto_UXV6&Vz3$_tFxRg7LwMvL^&OPm$u(?9@I0+G$ndJncP zM`F!jQIb)}Suq90=gJG1&)Q=w-BpIWfdL}azKFA;d?E;7O5l!Lvhmw(#kd_DAoeY| zg#E8t@osMli_eu7xbYI+zOC}$ch_5_?5TtZ%Hgaiod`BB6IgDWf`ItW!t2cMIG*po zbC?W^HKhfHxGqCVPaE;i5It;GQG7jh3x<2g)Mj4!Av z5L@SpB8v(<@7h+BTW6xVdn$K}s$pQ?Qh}{m0h-;pBsk8q7GEEFjBAU=a<@puSX5cS zzQ7%oZ+=ew$T4=J&FO4-1>fXuQ8^Io-K4;Q#r9f6={oeDYY=|(Y7x*c5*y6!7)|~K z7AT;wz`#9qT2^Tmnx?cANA~@IPWSurET(WC*t%RG^|$@cl-fvy4=J5n`?1(@L|r2qIqE3a)B=fp49$Py%u4X zcQB#Z&j`0#hv0taIWKZHRuvVXj@#vNJob#e+9@w^AAauq$ z?OgN>c(@m0q4#6-)OMiq^S1o@n4FDOMFm!DcGgOgeDTrjOc-2V;95mG#`X4t&F)^@ zF^Xn_ZJQSOD{6#~hnGEE6BFQLU4fUg3eauRBAh5Y$9XZ$#)`rM(_R*8(O)e<=Ae9B z8Fm9po+rTD%NdrAh5Ss6!nt7U4GXM%?X1N=b;IfGZSa|W3%O4YWBQh}@hWQCuL%E=GH{$_FX065*eB84i~F@qXxJ3>|u#J4W$r zu>FPwhTlJ_W!D8^hanjQnwG+H)D9fLFof?t%b97Ij87F7sD4@Fy*{Q~+mtdIPmzKt zb1qo_Py3BQ^8%+Fdd&NED-qN7R^1~0h-$^bR3%mU;mv5dax9>{_i8tFJE*1b0hCH u3g?5Zzb)|1j5_0>->+(6#h&Q5?Ms+@?tx|VP1tBV;mz+h^?&~VuK6Ejlbd+} literal 0 HcmV?d00001 diff --git a/apps/silero-vad-whisper-realtime-api/src/api.rs 
b/apps/silero-vad-whisper-realtime-api/src/api.rs new file mode 100644 index 0000000..3162330 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/api.rs @@ -0,0 +1,67 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize)] +pub struct TranscriptionResponse { + pub text: String, +} + +#[derive(Debug, Serialize)] +pub struct StreamChunk { + pub text: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub timestamp: Option, +} + +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub error: ErrorDetail, +} + +#[derive(Debug, Serialize)] +pub struct ErrorDetail { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, + pub param: Option, + pub code: Option, +} + +pub fn default_model() -> String { + "whisper-1".to_string() +} + +pub fn default_response_format() -> String { + "json".to_string() +} + +pub fn default_temperature() -> f32 { + 0.0 +} + +#[derive(Debug, Deserialize)] +pub struct TranscriptionRequest { + /// The audio file object to transcribe + /// In multipart form, this would be the file field + + /// ID of the model to use. Only whisper-1 is currently available. + #[serde(default = "default_model")] + pub model: String, + + /// The language of the input audio + pub language: Option, + + /// An optional text to guide the model's style or continue a previous audio segment + pub prompt: Option, + + /// The format of the transcript output + #[serde(default = "default_response_format")] + pub response_format: String, + + /// The sampling temperature, between 0 and 1 + #[serde(default = "default_temperature")] + pub temperature: f32, + + /// Enable streaming response + #[serde(default)] + pub stream: bool, +} diff --git a/apps/silero-vad-whisper-realtime-api/src/asr.rs b/apps/silero-vad-whisper-realtime-api/src/asr.rs new file mode 100644 index 0000000..df9a00b --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/asr.rs @@ -0,0 +1,287 @@ +use std::{collections::HashMap, sync::Arc}; + +use crate::AppState; +use crate::api::{default_model, default_response_format, default_temperature, ErrorDetail, ErrorResponse, StreamChunk, TranscriptionRequest, TranscriptionResponse}; + +use crate::audio_manager::AudioBuffer; +use anyhow::Result; +use axum::{ + Json, + extract::{Multipart, Query, State}, + http::StatusCode, + response::{ + IntoResponse, Response, + sse::{Event, KeepAlive, Sse}, + }, +}; +use futures::stream::{self, Stream}; +use symphonia::{ + core::{ + audio::{AudioBufferRef, Signal}, + codecs::DecoderOptions, + formats::FormatOptions, + io::MediaSourceStream, + meta::MetadataOptions, + probe::Hint, + }, + default::get_probe, +}; + +// Main transcription endpoint +pub async fn transcribe_audio( + State(state): State>, + Query(params): Query>, + mut multipart: Multipart, +) -> Result)> { + // Parse query parameters for streaming + let stream_enabled = params + .get("stream") + .map(|s| s.parse::().unwrap_or(false)) + .unwrap_or(false); + + // Extract audio file from multipart form + let audio_data = match extract_audio_from_multipart(&mut multipart).await { + Ok(data) => data, + Err(e) => { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to extract audio file: {}", e), + error_type: "invalid_request_error".to_string(), + param: Some("file".to_string()), + code: None, + }, + }), + )); + }, + }; + + // Parse request parameters + let request = TranscriptionRequest { + model: params + .get("model") + .cloned() + .unwrap_or_else(default_model), + 
language: params.get("language").cloned(), + prompt: params.get("prompt").cloned(), + response_format: params + .get("response_format") + .cloned() + .unwrap_or_else(default_response_format), + temperature: params + .get("temperature") + .and_then(|s| s.parse().ok()) + .unwrap_or_else(default_temperature), + stream: stream_enabled, + }; + + println!("Request: {:?}", request); + + // Convert audio to PCM format + let pcm_data = match convert_audio_to_pcm(&audio_data).await { + Ok(data) => data, + Err(e) => { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to process audio file: {}", e), + error_type: "invalid_request_error".to_string(), + param: Some("file".to_string()), + code: None, + }, + }), + )); + }, + }; + + println!("Audio data length: {:?}", pcm_data.len()); + + if request.stream { + // Return streaming response + let stream = create_transcription_stream(state, pcm_data).await; + let sse = Sse::new(stream).keep_alive(KeepAlive::default()); + + Ok(sse.into_response()) + } else { + // Return single response + match transcribe_audio_complete(state, pcm_data).await { + Ok(transcript) => Ok(Json(TranscriptionResponse { text: transcript }).into_response()), + Err(e) => Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: ErrorDetail { + message: format!("Transcription failed: {}", e), + error_type: "server_error".to_string(), + param: None, + code: None, + }, + }), + )), + } + } +} + +// Extract audio file from multipart form data +async fn extract_audio_from_multipart(multipart: &mut Multipart) -> Result> { + while let Some(field) = multipart.next_field().await? { + if let Some(name) = field.name() { + if name == "file" { + let data = field.bytes().await?; + return Ok(data.to_vec()); + } + } + } + anyhow::bail!("No file field found in multipart data") +} + +// Convert various audio formats to PCM +async fn convert_audio_to_pcm(audio_data: &[u8]) -> Result> { + let cursor = std::io::Cursor::new(audio_data.to_vec()); + let media_source = MediaSourceStream::new(Box::new(cursor), Default::default()); + + let mut hint = Hint::new(); + hint.mime_type("audio/wav"); // You might want to detect this automatically + + let meta_opts: MetadataOptions = Default::default(); + let fmt_opts: FormatOptions = Default::default(); + + let probed = get_probe().format(&hint, media_source, &fmt_opts, &meta_opts)?; + + let mut format = probed.format; + let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL) + .ok_or_else(|| anyhow::anyhow!("No audio track found"))?; + + let dec_opts: DecoderOptions = Default::default(); + let mut decoder = symphonia::default::get_codecs().make(&track.codec_params, &dec_opts)?; + + let track_id = track.id; + let mut pcm_data = Vec::new(); + + // Decode the audio + while let Ok(packet) = format.next_packet() { + if packet.track_id() != track_id { + continue; + } + + match decoder.decode(&packet)? 
{ + AudioBufferRef::F32(buf) => { + for &sample in buf.chan(0) { + pcm_data.push(sample); + } + }, + AudioBufferRef::S16(buf) => { + for &sample in buf.chan(0) { + pcm_data.push(f32::from(sample) / f32::from(i16::MAX)); + } + }, + AudioBufferRef::S32(buf) => { + for &sample in buf.chan(0) { + pcm_data.push(sample as f32 / i32::MAX as f32); + } + }, + _ => { + anyhow::bail!("Unsupported audio format"); + }, + } + } + + Ok(pcm_data) +} + +// Process complete audio file and return full transcript +pub async fn transcribe_audio_complete( + state: Arc, + audio_data: Vec, +) -> Result { + let sample_rate = 16000; + + // Process audio through VAD and Whisper + let mut vad = state.vad.lock().await; + let mut whisper = state.whisper.lock().await; + let mut audio_buffer = AudioBuffer::new(10000, 100, 500, sample_rate); + + let mut transcripts = Vec::new(); + let mut frame_buffer = Vec::::new(); + + // Process in chunks + for chunk in audio_data.chunks(1024) { + frame_buffer.extend_from_slice(chunk); + + // Process 512-sample frames + while frame_buffer.len() >= 512 { + let frame: Vec = frame_buffer.drain(..512).collect(); + let speech_prob = vad.process_chunk(&frame)?; + let is_speech = vad.is_speech(speech_prob); + + if let Some(complete_audio) = audio_buffer.add_chunk(&frame, is_speech) { + let transcript = whisper.transcribe(&complete_audio)?; + if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") { + transcripts.push(transcript.trim().to_string()); + } + } + } + } + + Ok(transcripts.join(" ")) +} + +// Create streaming transcription response +pub async fn create_transcription_stream( + state: Arc, + audio_data: Vec, +) -> impl Stream> { + let sample_rate = 16000; + + stream::unfold((state, audio_data, 0, AudioBuffer::new(10000, 100, 500, sample_rate)), move |(state, audio_data, mut processed, mut audio_buffer)| async move { + if processed >= audio_data.len() { + return None; + } + + // Process audio in chunks suitable for VAD (512 samples at a time) + let chunk_size = 512.min(audio_data.len() - processed); + let chunk = &audio_data[processed..processed + chunk_size]; + processed += chunk_size; + + // Process through VAD and Whisper processors + let mut whisper_result = None; + + // Process through VAD + let mut vad = state.vad.lock().await; + if let Ok(speech_prob) = vad.process_chunk(chunk) { + let is_speech = vad.is_speech(speech_prob); + + // Add to audio buffer and check if we have complete audio + if let Some(complete_audio) = audio_buffer.add_chunk(chunk, is_speech) { + // Release VAD lock before acquiring Whisper lock + drop(vad); + + // Process complete audio through Whisper + let mut whisper = state.whisper.lock().await; + if let Ok(transcript) = whisper.transcribe(&complete_audio) { + if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") { + whisper_result = Some(transcript.trim().to_string()); + } + } + } + } + + // Create event with actual transcription or progress update + let event_data = if let Some(transcript) = whisper_result { + StreamChunk { text: transcript, timestamp: Some(processed as f64 / sample_rate as f64) } + } else { + StreamChunk { + text: format!("Processing... 
({:.1}%)", (processed as f64 / audio_data.len() as f64) * 100.0), + timestamp: Some(processed as f64 / sample_rate as f64), + } + }; + + let event = Event::default().json_data(event_data).unwrap(); + + Some((Ok(event), (state.clone(), audio_data, processed, audio_buffer))) + }) +} diff --git a/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs b/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs new file mode 100644 index 0000000..94fba1c --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs @@ -0,0 +1,229 @@ +use std::time::Instant; + +use anyhow::Result; +use cpal::{ + InputCallbackInfo, + traits::{DeviceTrait, HostTrait, StreamTrait}, +}; +use rubato::{FastFixedIn, PolynomialDegree, Resampler}; + +pub struct AudioManager { + _stream: cpal::Stream, + audio_rx: crossbeam_channel::Receiver>, + resampler: Option>, + buffered_pcm: Vec, +} + +impl AudioManager { + pub fn new( + device_name: Option, + target_sample_rate: u32, + ) -> Result { + let host = cpal::default_host(); + let device = match device_name { + None => host.default_input_device(), + Some(name) => host + .input_devices()? + .find(|d| d.name().map(|n| n == name).unwrap_or(false)), + } + .ok_or_else(|| anyhow::anyhow!("No input device found"))?; + + println!("Using audio input device: {}", device.name()?); + + let config = device.default_input_config()?; + let channel_count = config.channels() as usize; + let device_sample_rate = config.sample_rate().0; + + println!("Device sample rate: {device_sample_rate}Hz, Target: {target_sample_rate}Hz"); + + let (tx, rx) = crossbeam_channel::unbounded(); + + let stream = device.build_input_stream( + &config.into(), + move |data: &[f32], _: &InputCallbackInfo| { + // Extract mono audio (first channel only) + let mono_data = data + .iter() + .step_by(channel_count) + .copied() + .collect::>(); + + if !mono_data.is_empty() { + let _ = tx.send(mono_data); + } + }, + |err| eprintln!("Audio stream error: {err}"), + None, + )?; + + stream.play()?; + + let resampler = if device_sample_rate == target_sample_rate { + None + } else { + let resample_ratio = f64::from(target_sample_rate) / f64::from(device_sample_rate); + + Some(FastFixedIn::new( + resample_ratio, + 10.0, // max_resample_ratio_relative + PolynomialDegree::Septic, + 1024, // chunk_size + 1, // channels + )?) 
+ }; + + Ok(Self { _stream: stream, audio_rx: rx, resampler, buffered_pcm: Vec::new() }) + } + + #[allow(clippy::future_not_send, clippy::unused_async)] + pub async fn receive_audio(&mut self) -> Result> { + let chunk = self.audio_rx.recv()?; + + if let Some(ref mut resampler) = self.resampler { + // Need to resample + self.buffered_pcm.extend_from_slice(&chunk); + + let mut resampled_audio = Vec::new(); + + // Process in chunks of 1024 samples + let full_chunks = self.buffered_pcm.len() / 1024; + let remainder = self.buffered_pcm.len() % 1024; + + for chunk_idx in 0..full_chunks { + let chunk_slice = &self.buffered_pcm[chunk_idx * 1024..(chunk_idx + 1) * 1024]; + let resampled = resampler.process(&[chunk_slice], None)?; + resampled_audio.extend_from_slice(&resampled[0]); + } + + // Handle remainder + if remainder == 0 { + self.buffered_pcm.clear(); + } else { + self + .buffered_pcm + .copy_within(full_chunks * 1024.., 0); + self.buffered_pcm.truncate(remainder); + } + + Ok(resampled_audio) + } else { + // No resampling needed, return the chunk directly + Ok(chunk) + } + } +} + +pub struct AudioBuffer { + buffer: Vec, + max_duration_samples: usize, + min_speech_duration_samples: usize, + min_silence_duration_samples: usize, + is_recording: bool, + silence_start: Option, + speech_start: Option, + samples_since_speech_start: usize, + samples_since_silence_start: usize, + sample_rate: usize, +} + +impl AudioBuffer { + pub const fn new( + max_duration_ms: u64, + min_speech_duration_ms: u64, + min_silence_duration_ms: u64, + sample_rate: u32, + ) -> Self { + let sample_rate = sample_rate as usize; + Self { + buffer: Vec::new(), + max_duration_samples: (max_duration_ms * sample_rate as u64 / 1000) as usize, + min_speech_duration_samples: (min_speech_duration_ms * sample_rate as u64 / 1000) as usize, + min_silence_duration_samples: (min_silence_duration_ms * sample_rate as u64 / 1000) as usize, + is_recording: false, + silence_start: None, + speech_start: None, + samples_since_speech_start: 0, + samples_since_silence_start: 0, + sample_rate, + } + } + + pub fn add_chunk( + &mut self, + chunk: &[f32], + is_speech: bool, + ) -> Option> { + if is_speech { + #[allow(clippy::if_not_else)] + if !self.is_recording { + if self.speech_start.is_none() { + self.speech_start = Some(Instant::now()); + self.samples_since_speech_start = 0; + } + + self.samples_since_speech_start += chunk.len(); + + if self.samples_since_speech_start >= self.min_speech_duration_samples { + self.is_recording = true; + self.silence_start = None; + self.samples_since_silence_start = 0; + println!("🚀 Started recording"); + } + } else { + // Reset silence tracking + self.silence_start = None; + self.samples_since_silence_start = 0; + } + } else { + // Reset speech tracking + self.speech_start = None; + self.samples_since_speech_start = 0; + + if self.is_recording { + if self.silence_start.is_none() { + self.silence_start = Some(Instant::now()); + self.samples_since_silence_start = 0; + } + + self.samples_since_silence_start += chunk.len(); + + if self.samples_since_silence_start >= self.min_silence_duration_samples { + // End of speech detected + if !self.buffer.is_empty() { + let result = self.buffer.clone(); + self.reset(); + #[allow(clippy::cast_precision_loss)] + let duration_secs = result.len() as f32 / self.sample_rate as f32; + println!("🔇 Stopped recording, {duration_secs:.2}s"); + return Some(result); + } + + self.reset(); + } + } + } + + if self.is_recording { + self.buffer.extend_from_slice(chunk); + + // Check if buffer 
exceeds max duration + if self.buffer.len() >= self.max_duration_samples { + let result = self.buffer.clone(); + self.reset(); + println!("⏰ Max duration reached, {} samples", result.len()); + return Some(result); + } + } + + None + } + + fn reset(&mut self) { + self.buffer.clear(); + self.is_recording = false; + self.silence_start = None; + self.speech_start = None; + self.samples_since_speech_start = 0; + self.samples_since_silence_start = 0; + } +} diff --git a/apps/silero-vad-whisper-realtime-api/src/main.rs b/apps/silero-vad-whisper-realtime-api/src/main.rs new file mode 100644 index 0000000..545ac20 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/main.rs @@ -0,0 +1,114 @@ +use std::sync::Arc; + +use anyhow::Result; +use axum::{ + Json, Router, + response::IntoResponse, + routing::{get, post}, +}; +use candle_core::Device; +use tokio::sync::Mutex; +use tower::ServiceBuilder; +use tower_http::cors::CorsLayer; + +use crate::{ + asr::transcribe_audio, + vad::VADProcessor, + whisper::{WhichWhisperModel, WhisperProcessor}, +}; + +mod api; +mod asr; +mod audio_manager; +mod vad; +mod whisper; + +// Application state +struct AppState { + vad: Arc>, + whisper: Arc>, + device: Device, +} + +impl AppState { + async fn new() -> Result { + // Determine device to use - allow override via environment variable + let device = if std::env::var("CANDLE_FORCE_CPU").is_ok() { + candle_core::Device::Cpu + } else if candle_core::utils::cuda_is_available() { + candle_core::Device::new_cuda(0)? + } else if candle_core::utils::metal_is_available() { + candle_core::Device::new_metal(0)? + } else { + candle_core::Device::Cpu + }; + + println!("🚀 Using device: {device:?}"); + + // Get VAD threshold from environment or use default + let vad_threshold = std::env::var("VAD_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(0.3); + + // Get Whisper model from environment or use default + let whisper_model = match std::env::var("WHISPER_MODEL").as_deref() { + Ok("tiny") => WhichWhisperModel::Tiny, + Ok("base") => WhichWhisperModel::Base, + Ok("small") => WhichWhisperModel::Small, + Ok("medium") => WhichWhisperModel::Medium, + Ok("large") => WhichWhisperModel::Large, + Ok("large-v2") => WhichWhisperModel::LargeV2, + Ok("large-v3") => WhichWhisperModel::LargeV3, + _ => WhichWhisperModel::Tiny, + }; + + println!("🎯 VAD threshold: {vad_threshold}"); + println!("🧠 Whisper model: {whisper_model:?}"); + + // Initialize VAD and Whisper processors + let vad = VADProcessor::new(candle_core::Device::Cpu, vad_threshold)?; + let whisper = WhisperProcessor::new(whisper_model, device.clone())?; + + Ok(Self { vad: Arc::new(Mutex::new(vad)), whisper: Arc::new(Mutex::new(whisper)), device }) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize tracing + tracing_subscriber::fmt::init(); + + // Initialize application state + let state = AppState::new().await?; + + // Build application routes + let app = Router::new() + .route("/", get(health_check)) + .route("/v1/audio/transcriptions", post(transcribe_audio)) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) + .into_inner(), + ) + .with_state(Arc::new(state)); + + // Start server + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; + println!("🚀 ASR API server running on http://0.0.0.0:3000"); + println!("📝 Available endpoints:"); + println!(" GET / - Health check"); + println!(" POST /v1/audio/transcriptions - Audio transcription (OpenAI compatible)"); + + axum::serve(listener, app).await?; + Ok(()) +} 
+ +// Health check endpoint +async fn health_check() -> impl IntoResponse { + Json(serde_json::json!({ + "status": "ok", + "service": "ASR API", + "version": "1.0.0" + })) +} diff --git a/apps/silero-vad-whisper-realtime-api/src/vad.rs b/apps/silero-vad-whisper-realtime-api/src/vad.rs new file mode 100644 index 0000000..0c6aa9c --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/vad.rs @@ -0,0 +1,85 @@ +use std::collections::HashMap; + +use anyhow::Result; +use candle_core::{DType, Device, Tensor}; +use candle_onnx::simple_eval; + +pub struct VADProcessor { + model: candle_onnx::onnx::ModelProto, + frame_size: usize, + context_size: usize, + sample_rate: Tensor, + state: Tensor, + context: Tensor, + device: Device, + threshold: f32, +} + +impl VADProcessor { + pub fn new( + device: Device, + threshold: f32, + ) -> Result { + let api = hf_hub::api::sync::Api::new()?; + let model_path = api + .model("onnx-community/silero-vad".into()) + .get("onnx/model.onnx")?; + + let model = candle_onnx::read_file(model_path)?; + + let sample_rate_value = 16000i64; + let (frame_size, context_size) = (512, 64); + + Ok(Self { + model, + frame_size, + context_size, + sample_rate: Tensor::new(sample_rate_value, &device)?, + state: Tensor::zeros((2, 1, 128), DType::F32, &device)?, + context: Tensor::zeros((1, context_size), DType::F32, &device)?, + device, + threshold, + }) + } + + pub fn process_chunk( + &mut self, + chunk: &[f32], + ) -> Result { + if chunk.len() != self.frame_size { + return Ok(0.0); + } + + let next_context = Tensor::from_slice(&chunk[self.frame_size - self.context_size..], (1, self.context_size), &self.device)?; + let chunk_tensor = Tensor::from_vec(chunk.to_vec(), (1, self.frame_size), &self.device)?; + + let input = Tensor::cat(&[&self.context, &chunk_tensor], 1)?; + let inputs: HashMap = HashMap::from_iter([("input".to_string(), input), ("sr".to_string(), self.sample_rate.clone()), ("state".to_string(), self.state.clone())]); + + let outputs = simple_eval(&self.model, inputs)?; + let graph = self.model.graph.as_ref().unwrap(); + let out_names = &graph.output; + + let output = outputs + .get(&out_names[0].name) + .ok_or_else(|| anyhow::anyhow!("Missing VAD output tensor: {}", &out_names[0].name))? + .clone(); + + self.state = outputs + .get(&out_names[1].name) + .ok_or_else(|| anyhow::anyhow!("Missing VAD state tensor: {}", &out_names[1].name))? 
+ .clone(); + + self.context = next_context; + + let speech_prob = output.flatten_all()?.to_vec1::()?[0]; + Ok(speech_prob) + } + + pub fn is_speech( + &self, + prob: f32, + ) -> bool { + prob >= self.threshold + } +} diff --git a/apps/silero-vad-whisper-realtime-api/src/whisper.rs b/apps/silero-vad-whisper-realtime-api/src/whisper.rs new file mode 100644 index 0000000..4b3ba32 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/whisper.rs @@ -0,0 +1,206 @@ +use anyhow::Result; +use byteorder::{ByteOrder, LittleEndian}; +use candle_core::{Device, IndexOp, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::whisper::{self as whisper_model, Config, audio}; +use clap::ValueEnum; +use hf_hub::{Repo, RepoType, api::sync::Api}; +use tokenizers::Tokenizer; + +pub enum WhisperModel { + Normal(whisper_model::model::Whisper), +} + +impl WhisperModel { + pub fn encoder_forward( + &mut self, + x: &Tensor, + flush: bool, + ) -> candle_core::Result { + match self { + Self::Normal(model) => model.encoder.forward(x, flush), + } + } + + pub fn decoder_forward( + &mut self, + x: &Tensor, + encoder_out: &Tensor, + flush: bool, + ) -> candle_core::Result { + match self { + Self::Normal(model) => model.decoder.forward(x, encoder_out, flush), + } + } + + pub fn decoder_final_linear( + &self, + x: &Tensor, + ) -> candle_core::Result { + match self { + Self::Normal(model) => model.decoder.final_linear(x), + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)] +pub enum WhichWhisperModel { + Tiny, + #[value(name = "tiny.en")] + TinyEn, + Base, + #[value(name = "base.en")] + BaseEn, + Small, + #[value(name = "small.en")] + SmallEn, + Medium, + #[value(name = "medium.en")] + MediumEn, + Large, + LargeV2, + LargeV3, + LargeV3Turbo, + #[value(name = "distil-medium.en")] + DistilMediumEn, + #[value(name = "distil-large-v2")] + DistilLargeV2, +} + +impl WhichWhisperModel { + pub const fn model_and_revision(self) -> (&'static str, &'static str) { + match self { + Self::Tiny => ("openai/whisper-tiny", "main"), + Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"), + Self::Base => ("openai/whisper-base", "refs/pr/22"), + Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"), + Self::Small => ("openai/whisper-small", "main"), + Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"), + Self::Medium => ("openai/whisper-medium", "main"), + Self::MediumEn => ("openai/whisper-medium.en", "main"), + Self::Large => ("openai/whisper-large", "refs/pr/36"), + Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"), + Self::LargeV3 => ("openai/whisper-large-v3", "main"), + Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"), + Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"), + Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"), + } + } +} + +pub struct WhisperProcessor { + pub model: WhisperModel, + pub tokenizer: Tokenizer, + pub config: Config, + pub mel_filters: Vec, + pub device: Device, +} + +impl WhisperProcessor { + pub fn new( + model: WhichWhisperModel, + device: Device, + ) -> Result { + // Load the Whisper model based on the provided model type + let api = Api::new()?; + let (model_id, revision) = model.model_and_revision(); + let repo = api.repo(Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string())); + + let config_filename = repo.get("config.json")?; + let tokenizer_filename = repo.get("tokenizer.json")?; + let model_filename = repo.get("model.safetensors")?; + + let config: Config = 
serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?; + + let name = model_filename.display(); + println!("Loading Whisper model from: {name:?}"); + + let var_builder = unsafe { VarBuilder::from_mmaped_safetensors(&[model_filename], whisper_model::DTYPE, &device)? }; + + let model = WhisperModel::Normal(whisper_model::model::Whisper::load(&var_builder, config.clone())?); + + let mel_bytes = match config.num_mel_bins { + 80 => include_bytes!("../melfilters.bytes").as_slice(), + 128 => include_bytes!("../melfilters128.bytes").as_slice(), + num_mel_bins => anyhow::bail!("Unsupported number of mel bins: {}", num_mel_bins), + }; + + let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; + ::read_f32_into(mel_bytes, &mut mel_filters); + + Ok(Self { model, tokenizer, config, mel_filters, device }) + } + + pub fn transcribe( + &mut self, + audio: &[f32], + ) -> Result { + // Convert PCM to mel spectrogram + let mel = audio::pcm_to_mel(&self.config, audio, &self.mel_filters); + let mel_len = mel.len(); + let mel = Tensor::from_vec(mel, (1, self.config.num_mel_bins, mel_len / self.config.num_mel_bins), &self.device)?; + + // Run inference + let audio_features = self.model.encoder_forward(&mel, true)?; + // Simple greedy decoding + let tokens = self.decode_greedy(&audio_features)?; + + let text = self + .tokenizer + .decode(&tokens, true) + .map_err(anyhow::Error::msg)?; + + Ok(text) + } + + fn decode_greedy( + &mut self, + audio_features: &Tensor, + ) -> Result> { + let mut tokens = vec![self.token_id(whisper_model::SOT_TOKEN)?, self.token_id(whisper_model::TRANSCRIBE_TOKEN)?, self.token_id(whisper_model::NO_TIMESTAMPS_TOKEN)?]; + + let max_len = 50; // Short sequence for real-time processing + + for i in 0..max_len { + let tokens_t = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; + let ys = self + .model + .decoder_forward(&tokens_t, audio_features, i == 0)?; + + let (_, seq_len, _) = ys.dims3()?; + let logits = self + .model + .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)? + .i(0)? + .i(0)?; + + // Get most likely token + let logits_v: Vec = logits.to_vec1()?; + let next_token = logits_v + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.total_cmp(b)) + .map(|(i, _)| u32::try_from(i).unwrap()) + .unwrap(); + + if next_token == self.token_id(whisper_model::EOT_TOKEN)? 
{ + break; + } + + tokens.push(next_token); + } + + Ok(tokens) + } + + fn token_id( + &self, + token: &str, + ) -> Result { + self + .tokenizer + .token_to_id(token) + .ok_or_else(|| anyhow::anyhow!("Token not found: {}", token)) + } +} From 068d184f2e2f83365efa747315e841a7578470f6 Mon Sep 17 00:00:00 2001 From: RainbowBird Date: Sun, 22 Jun 2025 17:58:11 +0800 Subject: [PATCH 03/14] fix: form data --- .../src/main.rs | 4 +- .../src/{asr.rs => router.rs} | 48 ++++++++++++------- 2 files changed, 32 insertions(+), 20 deletions(-) rename apps/silero-vad-whisper-realtime-api/src/{asr.rs => router.rs} (87%) diff --git a/apps/silero-vad-whisper-realtime-api/src/main.rs b/apps/silero-vad-whisper-realtime-api/src/main.rs index 545ac20..e9337d9 100644 --- a/apps/silero-vad-whisper-realtime-api/src/main.rs +++ b/apps/silero-vad-whisper-realtime-api/src/main.rs @@ -12,13 +12,13 @@ use tower::ServiceBuilder; use tower_http::cors::CorsLayer; use crate::{ - asr::transcribe_audio, + router::transcribe_audio, vad::VADProcessor, whisper::{WhichWhisperModel, WhisperProcessor}, }; mod api; -mod asr; +mod router; mod audio_manager; mod vad; mod whisper; diff --git a/apps/silero-vad-whisper-realtime-api/src/asr.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs similarity index 87% rename from apps/silero-vad-whisper-realtime-api/src/asr.rs rename to apps/silero-vad-whisper-realtime-api/src/router.rs index df9a00b..848539e 100644 --- a/apps/silero-vad-whisper-realtime-api/src/asr.rs +++ b/apps/silero-vad-whisper-realtime-api/src/router.rs @@ -1,13 +1,13 @@ use std::{collections::HashMap, sync::Arc}; use crate::AppState; -use crate::api::{default_model, default_response_format, default_temperature, ErrorDetail, ErrorResponse, StreamChunk, TranscriptionRequest, TranscriptionResponse}; +use crate::api::{ErrorDetail, ErrorResponse, StreamChunk, TranscriptionRequest, TranscriptionResponse, default_model, default_response_format, default_temperature}; use crate::audio_manager::AudioBuffer; use anyhow::Result; use axum::{ Json, - extract::{Multipart, Query, State}, + extract::{Multipart, State}, http::StatusCode, response::{ IntoResponse, Response, @@ -30,26 +30,19 @@ use symphonia::{ // Main transcription endpoint pub async fn transcribe_audio( State(state): State>, - Query(params): Query>, mut multipart: Multipart, ) -> Result)> { - // Parse query parameters for streaming - let stream_enabled = params - .get("stream") - .map(|s| s.parse::().unwrap_or(false)) - .unwrap_or(false); - - // Extract audio file from multipart form - let audio_data = match extract_audio_from_multipart(&mut multipart).await { + // Extract both audio file and parameters from multipart form + let (audio_data, params) = match extract_multipart_data(&mut multipart).await { Ok(data) => data, Err(e) => { return Err(( StatusCode::BAD_REQUEST, Json(ErrorResponse { error: ErrorDetail { - message: format!("Failed to extract audio file: {}", e), + message: format!("Failed to extract form data: {}", e), error_type: "invalid_request_error".to_string(), - param: Some("file".to_string()), + param: Some("form".to_string()), code: None, }, }), @@ -57,7 +50,15 @@ pub async fn transcribe_audio( }, }; - // Parse request parameters + println!("params: {:?}", params); + + // Parse query parameters for streaming + let stream_enabled = params + .get("stream") + .map(|s| s.parse::().unwrap_or(false)) + .unwrap_or(false); + + // Parse request parameters from form data let request = TranscriptionRequest { model: params .get("model") @@ -123,17 +124,28 @@ pub 
async fn transcribe_audio( } } -// Extract audio file from multipart form data -async fn extract_audio_from_multipart(multipart: &mut Multipart) -> Result> { +// Extract both audio file and parameters from multipart form data +async fn extract_multipart_data(multipart: &mut Multipart) -> Result<(Vec, HashMap)> { + let mut audio_data = None; + let mut params = HashMap::new(); + while let Some(field) = multipart.next_field().await? { if let Some(name) = field.name() { + let name = name.to_string(); // Clone the name first to avoid borrow conflict if name == "file" { + // Extract audio file let data = field.bytes().await?; - return Ok(data.to_vec()); + audio_data = Some(data.to_vec()); + } else { + // Extract form parameters + let value = field.text().await?; + params.insert(name, value); } } } - anyhow::bail!("No file field found in multipart data") + + let audio = audio_data.ok_or_else(|| anyhow::anyhow!("No file field found in multipart data"))?; + Ok((audio, params)) } // Convert various audio formats to PCM From e9add22952dd4c0ca9a6ed305b045cc7bc11c1bc Mon Sep 17 00:00:00 2001 From: RainbowBird Date: Sun, 22 Jun 2025 18:18:07 +0800 Subject: [PATCH 04/14] feat: enhance Whisper model management and transcription process - Introduced dynamic loading of Whisper models using RwLock for improved performance. - Updated the transcribe_audio function to utilize the new model management system. - Simplified request handling by removing unused TranscriptionRequest construction. - Enhanced error handling for model loading failures during streaming transcription. --- .../src/main.rs | 85 ++++++++++++++----- .../src/router.rs | 75 +++++++++------- 2 files changed, 105 insertions(+), 55 deletions(-) diff --git a/apps/silero-vad-whisper-realtime-api/src/main.rs b/apps/silero-vad-whisper-realtime-api/src/main.rs index e9337d9..babb87f 100644 --- a/apps/silero-vad-whisper-realtime-api/src/main.rs +++ b/apps/silero-vad-whisper-realtime-api/src/main.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use anyhow::Result; use axum::{ @@ -7,7 +7,7 @@ use axum::{ routing::{get, post}, }; use candle_core::Device; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, RwLock}; use tower::ServiceBuilder; use tower_http::cors::CorsLayer; @@ -18,21 +18,22 @@ use crate::{ }; mod api; -mod router; mod audio_manager; +mod router; mod vad; mod whisper; -// Application state +// Application state with dynamic model loading struct AppState { vad: Arc>, - whisper: Arc>, device: Device, + // Use RwLock for read-heavy workload (checking cache) + whisper_models: Arc>>>>, } impl AppState { async fn new() -> Result { - // Determine device to use - allow override via environment variable + // Determine device to use let device = if std::env::var("CANDLE_FORCE_CPU").is_ok() { candle_core::Device::Cpu } else if candle_core::utils::cuda_is_available() { @@ -51,26 +52,66 @@ impl AppState { .and_then(|s| s.parse().ok()) .unwrap_or(0.3); - // Get Whisper model from environment or use default - let whisper_model = match std::env::var("WHISPER_MODEL").as_deref() { - Ok("tiny") => WhichWhisperModel::Tiny, - Ok("base") => WhichWhisperModel::Base, - Ok("small") => WhichWhisperModel::Small, - Ok("medium") => WhichWhisperModel::Medium, - Ok("large") => WhichWhisperModel::Large, - Ok("large-v2") => WhichWhisperModel::LargeV2, - Ok("large-v3") => WhichWhisperModel::LargeV3, - _ => WhichWhisperModel::Tiny, - }; - println!("🎯 VAD threshold: {vad_threshold}"); - println!("🧠 Whisper model: {whisper_model:?}"); - // 
Initialize VAD and Whisper processors + // Initialize VAD processor (always use CPU for VAD) let vad = VADProcessor::new(candle_core::Device::Cpu, vad_threshold)?; - let whisper = WhisperProcessor::new(whisper_model, device.clone())?; - Ok(Self { vad: Arc::new(Mutex::new(vad)), whisper: Arc::new(Mutex::new(whisper)), device }) + Ok(Self { + vad: Arc::new(Mutex::new(vad)), + device, + whisper_models: Arc::new(RwLock::new(HashMap::new())), + }) + } + + // Get or create Whisper processor for the specified model + pub async fn get_whisper_processor( + &self, + model_name: &str, + ) -> Result>> { + // First, try to read from cache + { + let models = self.whisper_models.read().await; + if let Some(processor) = models.get(model_name) { + println!("🔄 Using cached Whisper model: {}", model_name); + return Ok(processor.clone()); + } + } + + // If not in cache, create new model + println!("🧠 Loading new Whisper model: {}", model_name); + let whisper_model = Self::parse_model_name(model_name)?; + let processor = Arc::new(Mutex::new(WhisperProcessor::new(whisper_model, self.device.clone())?)); + + // Add to cache + { + let mut models = self.whisper_models.write().await; + models.insert(model_name.to_string(), processor.clone()); + } + + println!("✅ Whisper model loaded and cached: {}", model_name); + Ok(processor) + } + + // Parse model name string to WhichWhisperModel enum + fn parse_model_name(model_name: &str) -> Result { + match model_name.to_lowercase().as_str() { + "tiny" => Ok(WhichWhisperModel::Tiny), + "tiny.en" => Ok(WhichWhisperModel::TinyEn), + "base" => Ok(WhichWhisperModel::Base), + "base.en" => Ok(WhichWhisperModel::BaseEn), + "small" => Ok(WhichWhisperModel::Small), + "small.en" => Ok(WhichWhisperModel::SmallEn), + "medium" => Ok(WhichWhisperModel::Medium), + "medium.en" => Ok(WhichWhisperModel::MediumEn), + "large" => Ok(WhichWhisperModel::Large), + "large-v2" => Ok(WhichWhisperModel::LargeV2), + "large-v3" => Ok(WhichWhisperModel::LargeV3), + "large-v3-turbo" => Ok(WhichWhisperModel::LargeV3Turbo), + "distil-medium.en" => Ok(WhichWhisperModel::DistilMediumEn), + "distil-large-v2" => Ok(WhichWhisperModel::DistilLargeV2), + _ => anyhow::bail!("Unsupported Whisper model: {}. 
Supported models: tiny, base, small, medium, large, large-v2, large-v3", model_name), + } } } diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs index 848539e..5947084 100644 --- a/apps/silero-vad-whisper-realtime-api/src/router.rs +++ b/apps/silero-vad-whisper-realtime-api/src/router.rs @@ -27,7 +27,7 @@ use symphonia::{ default::get_probe, }; -// Main transcription endpoint +// Main transcription endpoint - remove unused TranscriptionRequest construction pub async fn transcribe_audio( State(state): State>, mut multipart: Multipart, @@ -50,34 +50,21 @@ pub async fn transcribe_audio( }, }; - println!("params: {:?}", params); + println!("Request params: {:?}", params); - // Parse query parameters for streaming + // Parse streaming parameter from form data let stream_enabled = params .get("stream") .map(|s| s.parse::().unwrap_or(false)) .unwrap_or(false); - // Parse request parameters from form data - let request = TranscriptionRequest { - model: params - .get("model") - .cloned() - .unwrap_or_else(default_model), - language: params.get("language").cloned(), - prompt: params.get("prompt").cloned(), - response_format: params - .get("response_format") - .cloned() - .unwrap_or_else(default_response_format), - temperature: params - .get("temperature") - .and_then(|s| s.parse().ok()) - .unwrap_or_else(default_temperature), - stream: stream_enabled, - }; + // Get model name from parameters and clone it to make it owned + let model_name = params + .get("model") + .map(|s| s.clone()) // Clone to make it owned + .unwrap_or_else(|| "tiny".to_string()); // Use tiny as default - println!("Request: {:?}", request); + println!("Using model: {}, streaming: {}", model_name, stream_enabled); // Convert audio to PCM format let pcm_data = match convert_audio_to_pcm(&audio_data).await { @@ -97,17 +84,16 @@ pub async fn transcribe_audio( }, }; - println!("Audio data length: {:?}", pcm_data.len()); + println!("Audio data length: {} samples", pcm_data.len()); - if request.stream { + if stream_enabled { // Return streaming response - let stream = create_transcription_stream(state, pcm_data).await; + let stream = create_transcription_stream(state, model_name, pcm_data).await?; let sse = Sse::new(stream).keep_alive(KeepAlive::default()); - Ok(sse.into_response()) } else { // Return single response - match transcribe_audio_complete(state, pcm_data).await { + match transcribe_audio_complete(state, model_name, pcm_data).await { Ok(transcript) => Ok(Json(TranscriptionResponse { text: transcript }).into_response()), Err(e) => Err(( StatusCode::INTERNAL_SERVER_ERROR, @@ -208,13 +194,17 @@ async fn convert_audio_to_pcm(audio_data: &[u8]) -> Result> { // Process complete audio file and return full transcript pub async fn transcribe_audio_complete( state: Arc, + model_name: String, // Change to owned String audio_data: Vec, ) -> Result { let sample_rate = 16000; + // Get the appropriate Whisper processor for this model + let whisper_processor = state.get_whisper_processor(&model_name).await?; + // Process audio through VAD and Whisper let mut vad = state.vad.lock().await; - let mut whisper = state.whisper.lock().await; + let mut whisper = whisper_processor.lock().await; let mut audio_buffer = AudioBuffer::new(10000, 100, 500, sample_rate); let mut transcripts = Vec::new(); @@ -245,11 +235,30 @@ pub async fn transcribe_audio_complete( // Create streaming transcription response pub async fn create_transcription_stream( state: Arc, + model_name: String, // Change 
to owned String audio_data: Vec, -) -> impl Stream> { +) -> Result>, (StatusCode, Json)> { + // Get the appropriate Whisper processor for this model + let whisper_processor = match state.get_whisper_processor(&model_name).await { + Ok(processor) => processor, + Err(e) => { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to load model '{}': {}", model_name, e), + error_type: "invalid_request_error".to_string(), + param: Some("model".to_string()), + code: None, + }, + }), + )); + }, + }; + let sample_rate = 16000; - stream::unfold((state, audio_data, 0, AudioBuffer::new(10000, 100, 500, sample_rate)), move |(state, audio_data, mut processed, mut audio_buffer)| async move { + Ok(stream::unfold((state, whisper_processor, audio_data, 0, AudioBuffer::new(10000, 100, 500, sample_rate)), move |(state, whisper_processor, audio_data, mut processed, mut audio_buffer)| async move { if processed >= audio_data.len() { return None; } @@ -273,7 +282,7 @@ pub async fn create_transcription_stream( drop(vad); // Process complete audio through Whisper - let mut whisper = state.whisper.lock().await; + let mut whisper = whisper_processor.lock().await; if let Ok(transcript) = whisper.transcribe(&complete_audio) { if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") { whisper_result = Some(transcript.trim().to_string()); @@ -294,6 +303,6 @@ pub async fn create_transcription_stream( let event = Event::default().json_data(event_data).unwrap(); - Some((Ok(event), (state.clone(), audio_data, processed, audio_buffer))) - }) + Some((Ok(event), (state.clone(), whisper_processor.clone(), audio_data, processed, audio_buffer))) + })) } From cc3a6fba0c1ed19fe176dc06c773ccfb3cccec80 Mon Sep 17 00:00:00 2001 From: RainbowBird Date: Sun, 22 Jun 2025 18:30:43 +0800 Subject: [PATCH 05/14] feat: debug info --- .../src/main.rs | 8 +- .../src/router.rs | 127 +++++++++++++++--- 2 files changed, 111 insertions(+), 24 deletions(-) diff --git a/apps/silero-vad-whisper-realtime-api/src/main.rs b/apps/silero-vad-whisper-realtime-api/src/main.rs index babb87f..f5f3ce4 100644 --- a/apps/silero-vad-whisper-realtime-api/src/main.rs +++ b/apps/silero-vad-whisper-realtime-api/src/main.rs @@ -78,18 +78,22 @@ impl AppState { } } - // If not in cache, create new model + // If not in cache, create new model with timing + let loading_start = std::time::Instant::now(); println!("🧠 Loading new Whisper model: {}", model_name); + let whisper_model = Self::parse_model_name(model_name)?; let processor = Arc::new(Mutex::new(WhisperProcessor::new(whisper_model, self.device.clone())?)); + let loading_time = loading_start.elapsed(); + // Add to cache { let mut models = self.whisper_models.write().await; models.insert(model_name.to_string(), processor.clone()); } - println!("✅ Whisper model loaded and cached: {}", model_name); + println!("✅ Whisper model loaded and cached: {} ({:.2}ms)", model_name, loading_time.as_secs_f64() * 1000.0); Ok(processor) } diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs index 5947084..fa743e2 100644 --- a/apps/silero-vad-whisper-realtime-api/src/router.rs +++ b/apps/silero-vad-whisper-realtime-api/src/router.rs @@ -1,7 +1,7 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Instant}; use crate::AppState; -use crate::api::{ErrorDetail, ErrorResponse, StreamChunk, TranscriptionRequest, TranscriptionResponse, default_model, 
default_response_format, default_temperature}; +use crate::api::{ErrorDetail, ErrorResponse, StreamChunk, TranscriptionResponse}; use crate::audio_manager::AudioBuffer; use anyhow::Result; @@ -27,11 +27,52 @@ use symphonia::{ default::get_probe, }; +// Performance statistics struct +#[derive(Debug)] +struct ProcessingStats { + total_duration: std::time::Duration, + audio_conversion_duration: std::time::Duration, + model_loading_duration: std::time::Duration, + vad_processing_duration: std::time::Duration, + whisper_transcription_duration: std::time::Duration, + audio_length_seconds: f32, +} + +impl ProcessingStats { + fn new() -> Self { + Self { + total_duration: std::time::Duration::ZERO, + audio_conversion_duration: std::time::Duration::ZERO, + model_loading_duration: std::time::Duration::ZERO, + vad_processing_duration: std::time::Duration::ZERO, + whisper_transcription_duration: std::time::Duration::ZERO, + audio_length_seconds: 0.0, + } + } + + fn print_summary(&self) { + println!("📊 Processing Statistics:"); + println!(" 📁 Audio conversion: {:.2}ms", self.audio_conversion_duration.as_secs_f64() * 1000.0); + println!(" 🧠 Model loading: {:.2}ms", self.model_loading_duration.as_secs_f64() * 1000.0); + println!(" 🎯 VAD processing: {:.2}ms", self.vad_processing_duration.as_secs_f64() * 1000.0); + println!(" 🗣️ Whisper transcription: {:.2}ms", self.whisper_transcription_duration.as_secs_f64() * 1000.0); + println!(" ⏱️ Total processing: {:.2}ms", self.total_duration.as_secs_f64() * 1000.0); + println!(" 🎵 Audio length: {:.2}s", self.audio_length_seconds); + if self.audio_length_seconds > 0.0 { + let real_time_factor = self.total_duration.as_secs_f64() / self.audio_length_seconds as f64; + println!(" ⚡ Real-time factor: {:.2}x", real_time_factor); + } + } +} + // Main transcription endpoint - remove unused TranscriptionRequest construction pub async fn transcribe_audio( State(state): State>, mut multipart: Multipart, ) -> Result)> { + let start_time = Instant::now(); + let mut stats = ProcessingStats::new(); + // Extract both audio file and parameters from multipart form let (audio_data, params) = match extract_multipart_data(&mut multipart).await { Ok(data) => data, @@ -66,7 +107,8 @@ pub async fn transcribe_audio( println!("Using model: {}, streaming: {}", model_name, stream_enabled); - // Convert audio to PCM format + // Convert audio to PCM format with timing + let conversion_start = Instant::now(); let pcm_data = match convert_audio_to_pcm(&audio_data).await { Ok(data) => data, Err(e) => { @@ -83,29 +125,39 @@ pub async fn transcribe_audio( )); }, }; + stats.audio_conversion_duration = conversion_start.elapsed(); + stats.audio_length_seconds = pcm_data.len() as f32 / 16000.0; // Assuming 16kHz sample rate - println!("Audio data length: {} samples", pcm_data.len()); + println!("Audio data length: {} samples ({:.2}s)", pcm_data.len(), stats.audio_length_seconds); if stream_enabled { // Return streaming response - let stream = create_transcription_stream(state, model_name, pcm_data).await?; + let stream = create_transcription_stream(state, model_name, pcm_data, stats).await?; let sse = Sse::new(stream).keep_alive(KeepAlive::default()); Ok(sse.into_response()) } else { // Return single response - match transcribe_audio_complete(state, model_name, pcm_data).await { - Ok(transcript) => Ok(Json(TranscriptionResponse { text: transcript }).into_response()), - Err(e) => Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: ErrorDetail { - message: format!("Transcription 
failed: {}", e), - error_type: "server_error".to_string(), - param: None, - code: None, - }, - }), - )), + match transcribe_audio_complete(state, model_name, pcm_data, &mut stats).await { + Ok(transcript) => { + stats.total_duration = start_time.elapsed(); + stats.print_summary(); + Ok(Json(TranscriptionResponse { text: transcript }).into_response()) + }, + Err(e) => { + stats.total_duration = start_time.elapsed(); + stats.print_summary(); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: ErrorDetail { + message: format!("Transcription failed: {}", e), + error_type: "server_error".to_string(), + param: None, + code: None, + }, + }), + )) + }, } } } @@ -196,11 +248,14 @@ pub async fn transcribe_audio_complete( state: Arc, model_name: String, // Change to owned String audio_data: Vec, + stats: &mut ProcessingStats, ) -> Result { let sample_rate = 16000; - // Get the appropriate Whisper processor for this model + // Get the appropriate Whisper processor for this model with timing + let model_loading_start = Instant::now(); let whisper_processor = state.get_whisper_processor(&model_name).await?; + stats.model_loading_duration = model_loading_start.elapsed(); // Process audio through VAD and Whisper let mut vad = state.vad.lock().await; @@ -210,6 +265,9 @@ pub async fn transcribe_audio_complete( let mut transcripts = Vec::new(); let mut frame_buffer = Vec::::new(); + let vad_start = Instant::now(); + let mut whisper_total_time = std::time::Duration::ZERO; + // Process in chunks for chunk in audio_data.chunks(1024) { frame_buffer.extend_from_slice(chunk); @@ -221,7 +279,11 @@ pub async fn transcribe_audio_complete( let is_speech = vad.is_speech(speech_prob); if let Some(complete_audio) = audio_buffer.add_chunk(&frame, is_speech) { + // Measure Whisper transcription time + let whisper_start = Instant::now(); let transcript = whisper.transcribe(&complete_audio)?; + whisper_total_time += whisper_start.elapsed(); + if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") { transcripts.push(transcript.trim().to_string()); } @@ -229,6 +291,9 @@ pub async fn transcribe_audio_complete( } } + stats.vad_processing_duration = vad_start.elapsed() - whisper_total_time; + stats.whisper_transcription_duration = whisper_total_time; + Ok(transcripts.join(" ")) } @@ -237,8 +302,12 @@ pub async fn create_transcription_stream( state: Arc, model_name: String, // Change to owned String audio_data: Vec, + mut stats: ProcessingStats, ) -> Result>, (StatusCode, Json)> { - // Get the appropriate Whisper processor for this model + let stream_start = Instant::now(); + + // Get the appropriate Whisper processor for this model with timing + let model_loading_start = Instant::now(); let whisper_processor = match state.get_whisper_processor(&model_name).await { Ok(processor) => processor, Err(e) => { @@ -255,11 +324,15 @@ pub async fn create_transcription_stream( )); }, }; + stats.model_loading_duration = model_loading_start.elapsed(); let sample_rate = 16000; - Ok(stream::unfold((state, whisper_processor, audio_data, 0, AudioBuffer::new(10000, 100, 500, sample_rate)), move |(state, whisper_processor, audio_data, mut processed, mut audio_buffer)| async move { + Ok(stream::unfold((state, whisper_processor, audio_data, 0, AudioBuffer::new(10000, 100, 500, sample_rate), stats, stream_start), move |(state, whisper_processor, audio_data, mut processed, mut audio_buffer, mut stats, stream_start)| async move { if processed >= audio_data.len() { + // Print final statistics for streaming + 
stats.total_duration = stream_start.elapsed(); + stats.print_summary(); return None; } @@ -272,6 +345,7 @@ pub async fn create_transcription_stream( let mut whisper_result = None; // Process through VAD + let vad_chunk_start = Instant::now(); let mut vad = state.vad.lock().await; if let Ok(speech_prob) = vad.process_chunk(chunk) { let is_speech = vad.is_speech(speech_prob); @@ -280,15 +354,24 @@ pub async fn create_transcription_stream( if let Some(complete_audio) = audio_buffer.add_chunk(chunk, is_speech) { // Release VAD lock before acquiring Whisper lock drop(vad); + let vad_chunk_time = vad_chunk_start.elapsed(); + stats.vad_processing_duration += vad_chunk_time; // Process complete audio through Whisper + let whisper_chunk_start = Instant::now(); let mut whisper = whisper_processor.lock().await; if let Ok(transcript) = whisper.transcribe(&complete_audio) { + let whisper_chunk_time = whisper_chunk_start.elapsed(); + stats.whisper_transcription_duration += whisper_chunk_time; + if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") { whisper_result = Some(transcript.trim().to_string()); + println!("🎯 Chunk transcribed in {:.2}ms: \"{}\"", whisper_chunk_time.as_secs_f64() * 1000.0, transcript.trim()); } } } + } else { + stats.vad_processing_duration += vad_chunk_start.elapsed(); } // Create event with actual transcription or progress update @@ -303,6 +386,6 @@ pub async fn create_transcription_stream( let event = Event::default().json_data(event_data).unwrap(); - Some((Ok(event), (state.clone(), whisper_processor.clone(), audio_data, processed, audio_buffer))) + Some((Ok(event), (state.clone(), whisper_processor.clone(), audio_data, processed, audio_buffer, stats, stream_start))) })) } From 9710b891cfb9d193c17da25b818f2f1dac75eed1 Mon Sep 17 00:00:00 2001 From: RainbowBird Date: Sun, 22 Jun 2025 22:09:40 +0800 Subject: [PATCH 06/14] fix: ci --- .../src/api.rs | 42 +------ .../src/audio_manager.rs | 113 ------------------ .../src/router.rs | 5 +- 3 files changed, 3 insertions(+), 157 deletions(-) diff --git a/apps/silero-vad-whisper-realtime-api/src/api.rs b/apps/silero-vad-whisper-realtime-api/src/api.rs index 3162330..9483018 100644 --- a/apps/silero-vad-whisper-realtime-api/src/api.rs +++ b/apps/silero-vad-whisper-realtime-api/src/api.rs @@ -1,4 +1,4 @@ -use serde::{Deserialize, Serialize}; +use serde::Serialize; #[derive(Debug, Serialize)] pub struct TranscriptionResponse { @@ -25,43 +25,3 @@ pub struct ErrorDetail { pub param: Option, pub code: Option, } - -pub fn default_model() -> String { - "whisper-1".to_string() -} - -pub fn default_response_format() -> String { - "json".to_string() -} - -pub fn default_temperature() -> f32 { - 0.0 -} - -#[derive(Debug, Deserialize)] -pub struct TranscriptionRequest { - /// The audio file object to transcribe - /// In multipart form, this would be the file field - - /// ID of the model to use. Only whisper-1 is currently available. 
- #[serde(default = "default_model")] - pub model: String, - - /// The language of the input audio - pub language: Option, - - /// An optional text to guide the model's style or continue a previous audio segment - pub prompt: Option, - - /// The format of the transcript output - #[serde(default = "default_response_format")] - pub response_format: String, - - /// The sampling temperature, between 0 and 1 - #[serde(default = "default_temperature")] - pub temperature: f32, - - /// Enable streaming response - #[serde(default)] - pub stream: bool, -} diff --git a/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs b/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs index 94fba1c..b7a5670 100644 --- a/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs +++ b/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs @@ -1,118 +1,5 @@ use std::time::Instant; -use anyhow::Result; -use cpal::{ - InputCallbackInfo, - traits::{DeviceTrait, HostTrait, StreamTrait}, -}; -use rubato::{FastFixedIn, PolynomialDegree, Resampler}; - -pub struct AudioManager { - _stream: cpal::Stream, - audio_rx: crossbeam_channel::Receiver>, - resampler: Option>, - buffered_pcm: Vec, -} - -impl AudioManager { - pub fn new( - device_name: Option, - target_sample_rate: u32, - ) -> Result { - let host = cpal::default_host(); - let device = match device_name { - None => host.default_input_device(), - Some(name) => host - .input_devices()? - .find(|d| d.name().map(|n| n == name).unwrap_or(false)), - } - .ok_or_else(|| anyhow::anyhow!("No input device found"))?; - - println!("Using audio input device: {}", device.name()?); - - let config = device.default_input_config()?; - let channel_count = config.channels() as usize; - let device_sample_rate = config.sample_rate().0; - - println!("Device sample rate: {device_sample_rate}Hz, Target: {target_sample_rate}Hz"); - - let (tx, rx) = crossbeam_channel::unbounded(); - - let stream = device.build_input_stream( - &config.into(), - move |data: &[f32], _: &InputCallbackInfo| { - // Extract mono audio (first channel only) - let mono_data = data - .iter() - .step_by(channel_count) - .copied() - .collect::>(); - - if !mono_data.is_empty() { - let _ = tx.send(mono_data); - } - }, - |err| eprintln!("Audio stream error: {err}"), - None, - )?; - - stream.play()?; - - let resampler = if device_sample_rate == target_sample_rate { - None - } else { - let resample_ratio = f64::from(target_sample_rate) / f64::from(device_sample_rate); - - Some(FastFixedIn::new( - resample_ratio, - 10.0, // max_resample_ratio_relative - PolynomialDegree::Septic, - 1024, // chunk_size - 1, // channels - )?) 
- }; - - Ok(Self { _stream: stream, audio_rx: rx, resampler, buffered_pcm: Vec::new() }) - } - - #[allow(clippy::future_not_send, clippy::unused_async)] - pub async fn receive_audio(&mut self) -> Result> { - let chunk = self.audio_rx.recv()?; - - if let Some(ref mut resampler) = self.resampler { - // Need to resample - self.buffered_pcm.extend_from_slice(&chunk); - - let mut resampled_audio = Vec::new(); - - // Process in chunks of 1024 samples - let full_chunks = self.buffered_pcm.len() / 1024; - let remainder = self.buffered_pcm.len() % 1024; - - for chunk_idx in 0..full_chunks { - let chunk_slice = &self.buffered_pcm[chunk_idx * 1024..(chunk_idx + 1) * 1024]; - let resampled = resampler.process(&[chunk_slice], None)?; - resampled_audio.extend_from_slice(&resampled[0]); - } - - // Handle remainder - if remainder == 0 { - self.buffered_pcm.clear(); - } else { - self - .buffered_pcm - .copy_within(full_chunks * 1024.., 0); - self.buffered_pcm.truncate(remainder); - } - - Ok(resampled_audio) - } else { - // No resampling needed, return the chunk directly - Ok(chunk) - } - } -} - pub struct AudioBuffer { buffer: Vec, max_duration_samples: usize, diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs index fa743e2..5420bd0 100644 --- a/apps/silero-vad-whisper-realtime-api/src/router.rs +++ b/apps/silero-vad-whisper-realtime-api/src/router.rs @@ -65,7 +65,6 @@ impl ProcessingStats { } } -// Main transcription endpoint - remove unused TranscriptionRequest construction pub async fn transcribe_audio( State(state): State>, mut multipart: Multipart, @@ -244,7 +243,7 @@ async fn convert_audio_to_pcm(audio_data: &[u8]) -> Result> { } // Process complete audio file and return full transcript -pub async fn transcribe_audio_complete( +async fn transcribe_audio_complete( state: Arc, model_name: String, // Change to owned String audio_data: Vec, @@ -298,7 +297,7 @@ pub async fn transcribe_audio_complete( } // Create streaming transcription response -pub async fn create_transcription_stream( +async fn create_transcription_stream( state: Arc, model_name: String, // Change to owned String audio_data: Vec, From aff50a03cb6c94bad9da482eb91652727d65cb0a Mon Sep 17 00:00:00 2001 From: Neko Ayaka Date: Sun, 22 Jun 2025 22:15:54 +0800 Subject: [PATCH 07/14] chore: fmt --- .../src/api.rs | 8 +-- .../src/audio_manager.rs | 18 +++--- .../src/main.rs | 7 ++- .../src/router.rs | 59 ++++++++++--------- .../src/vad.rs | 14 ++--- .../src/whisper.rs | 8 +-- 6 files changed, 59 insertions(+), 55 deletions(-) diff --git a/apps/silero-vad-whisper-realtime-api/src/api.rs b/apps/silero-vad-whisper-realtime-api/src/api.rs index 9483018..a448d7b 100644 --- a/apps/silero-vad-whisper-realtime-api/src/api.rs +++ b/apps/silero-vad-whisper-realtime-api/src/api.rs @@ -7,7 +7,7 @@ pub struct TranscriptionResponse { #[derive(Debug, Serialize)] pub struct StreamChunk { - pub text: String, + pub text: String, #[serde(skip_serializing_if = "Option::is_none")] pub timestamp: Option, } @@ -19,9 +19,9 @@ pub struct ErrorResponse { #[derive(Debug, Serialize)] pub struct ErrorDetail { - pub message: String, + pub message: String, #[serde(rename = "type")] pub error_type: String, - pub param: Option, - pub code: Option, + pub param: Option, + pub code: Option, } diff --git a/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs b/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs index b7a5670..ba4cd81 100644 --- a/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs 
+++ b/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs @@ -1,16 +1,16 @@ use std::time::Instant; pub struct AudioBuffer { - buffer: Vec, - max_duration_samples: usize, - min_speech_duration_samples: usize, + buffer: Vec, + max_duration_samples: usize, + min_speech_duration_samples: usize, min_silence_duration_samples: usize, - is_recording: bool, - silence_start: Option, - speech_start: Option, - samples_since_speech_start: usize, - samples_since_silence_start: usize, - sample_rate: usize, + is_recording: bool, + silence_start: Option, + speech_start: Option, + samples_since_speech_start: usize, + samples_since_silence_start: usize, + sample_rate: usize, } impl AudioBuffer { diff --git a/apps/silero-vad-whisper-realtime-api/src/main.rs b/apps/silero-vad-whisper-realtime-api/src/main.rs index f5f3ce4..cd106c8 100644 --- a/apps/silero-vad-whisper-realtime-api/src/main.rs +++ b/apps/silero-vad-whisper-realtime-api/src/main.rs @@ -2,7 +2,8 @@ use std::{collections::HashMap, sync::Arc}; use anyhow::Result; use axum::{ - Json, Router, + Json, + Router, response::IntoResponse, routing::{get, post}, }; @@ -25,8 +26,8 @@ mod whisper; // Application state with dynamic model loading struct AppState { - vad: Arc>, - device: Device, + vad: Arc>, + device: Device, // Use RwLock for read-heavy workload (checking cache) whisper_models: Arc>>>>, } diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs index 5420bd0..5c0e802 100644 --- a/apps/silero-vad-whisper-realtime-api/src/router.rs +++ b/apps/silero-vad-whisper-realtime-api/src/router.rs @@ -1,16 +1,13 @@ use std::{collections::HashMap, sync::Arc, time::Instant}; -use crate::AppState; -use crate::api::{ErrorDetail, ErrorResponse, StreamChunk, TranscriptionResponse}; - -use crate::audio_manager::AudioBuffer; use anyhow::Result; use axum::{ Json, extract::{Multipart, State}, http::StatusCode, response::{ - IntoResponse, Response, + IntoResponse, + Response, sse::{Event, KeepAlive, Sse}, }, }; @@ -27,26 +24,32 @@ use symphonia::{ default::get_probe, }; +use crate::{ + AppState, + api::{ErrorDetail, ErrorResponse, StreamChunk, TranscriptionResponse}, + audio_manager::AudioBuffer, +}; + // Performance statistics struct #[derive(Debug)] struct ProcessingStats { - total_duration: std::time::Duration, - audio_conversion_duration: std::time::Duration, - model_loading_duration: std::time::Duration, - vad_processing_duration: std::time::Duration, + total_duration: std::time::Duration, + audio_conversion_duration: std::time::Duration, + model_loading_duration: std::time::Duration, + vad_processing_duration: std::time::Duration, whisper_transcription_duration: std::time::Duration, - audio_length_seconds: f32, + audio_length_seconds: f32, } impl ProcessingStats { fn new() -> Self { Self { - total_duration: std::time::Duration::ZERO, - audio_conversion_duration: std::time::Duration::ZERO, - model_loading_duration: std::time::Duration::ZERO, - vad_processing_duration: std::time::Duration::ZERO, + total_duration: std::time::Duration::ZERO, + audio_conversion_duration: std::time::Duration::ZERO, + model_loading_duration: std::time::Duration::ZERO, + vad_processing_duration: std::time::Duration::ZERO, whisper_transcription_duration: std::time::Duration::ZERO, - audio_length_seconds: 0.0, + audio_length_seconds: 0.0, } } @@ -80,10 +83,10 @@ pub async fn transcribe_audio( StatusCode::BAD_REQUEST, Json(ErrorResponse { error: ErrorDetail { - message: format!("Failed to extract form data: {}", e), + 
message: format!("Failed to extract form data: {}", e), error_type: "invalid_request_error".to_string(), - param: Some("form".to_string()), - code: None, + param: Some("form".to_string()), + code: None, }, }), )); @@ -115,10 +118,10 @@ pub async fn transcribe_audio( StatusCode::BAD_REQUEST, Json(ErrorResponse { error: ErrorDetail { - message: format!("Failed to process audio file: {}", e), + message: format!("Failed to process audio file: {}", e), error_type: "invalid_request_error".to_string(), - param: Some("file".to_string()), - code: None, + param: Some("file".to_string()), + code: None, }, }), )); @@ -149,10 +152,10 @@ pub async fn transcribe_audio( StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse { error: ErrorDetail { - message: format!("Transcription failed: {}", e), + message: format!("Transcription failed: {}", e), error_type: "server_error".to_string(), - param: None, - code: None, + param: None, + code: None, }, }), )) @@ -314,10 +317,10 @@ async fn create_transcription_stream( StatusCode::BAD_REQUEST, Json(ErrorResponse { error: ErrorDetail { - message: format!("Failed to load model '{}': {}", model_name, e), + message: format!("Failed to load model '{}': {}", model_name, e), error_type: "invalid_request_error".to_string(), - param: Some("model".to_string()), - code: None, + param: Some("model".to_string()), + code: None, }, }), )); @@ -378,7 +381,7 @@ async fn create_transcription_stream( StreamChunk { text: transcript, timestamp: Some(processed as f64 / sample_rate as f64) } } else { StreamChunk { - text: format!("Processing... ({:.1}%)", (processed as f64 / audio_data.len() as f64) * 100.0), + text: format!("Processing... ({:.1}%)", (processed as f64 / audio_data.len() as f64) * 100.0), timestamp: Some(processed as f64 / sample_rate as f64), } }; diff --git a/apps/silero-vad-whisper-realtime-api/src/vad.rs b/apps/silero-vad-whisper-realtime-api/src/vad.rs index 0c6aa9c..01a1d5b 100644 --- a/apps/silero-vad-whisper-realtime-api/src/vad.rs +++ b/apps/silero-vad-whisper-realtime-api/src/vad.rs @@ -5,14 +5,14 @@ use candle_core::{DType, Device, Tensor}; use candle_onnx::simple_eval; pub struct VADProcessor { - model: candle_onnx::onnx::ModelProto, - frame_size: usize, + model: candle_onnx::onnx::ModelProto, + frame_size: usize, context_size: usize, - sample_rate: Tensor, - state: Tensor, - context: Tensor, - device: Device, - threshold: f32, + sample_rate: Tensor, + state: Tensor, + context: Tensor, + device: Device, + threshold: f32, } impl VADProcessor { diff --git a/apps/silero-vad-whisper-realtime-api/src/whisper.rs b/apps/silero-vad-whisper-realtime-api/src/whisper.rs index 4b3ba32..6957e63 100644 --- a/apps/silero-vad-whisper-realtime-api/src/whisper.rs +++ b/apps/silero-vad-whisper-realtime-api/src/whisper.rs @@ -89,11 +89,11 @@ impl WhichWhisperModel { } pub struct WhisperProcessor { - pub model: WhisperModel, - pub tokenizer: Tokenizer, - pub config: Config, + pub model: WhisperModel, + pub tokenizer: Tokenizer, + pub config: Config, pub mel_filters: Vec, - pub device: Device, + pub device: Device, } impl WhisperProcessor { From ec102831e77593808fc2d60df2638e7383b5d220 Mon Sep 17 00:00:00 2001 From: RainbowBird Date: Sun, 22 Jun 2025 22:21:33 +0800 Subject: [PATCH 08/14] docs: en --- .../silero-vad-whisper-realtime-api/README.md | 219 +++++++++--------- 1 file changed, 106 insertions(+), 113 deletions(-) diff --git a/apps/silero-vad-whisper-realtime-api/README.md b/apps/silero-vad-whisper-realtime-api/README.md index 8208fd7..34e35c8 100644 --- 
a/apps/silero-vad-whisper-realtime-api/README.md +++ b/apps/silero-vad-whisper-realtime-api/README.md @@ -1,101 +1,94 @@ # ASR API - OpenAI Compatible Audio Transcription Service -🎤 一个兼容OpenAI格式的语音转录API服务,支持实时流式响应(SSE),集成了Silero VAD和Whisper模型。 +🎤 An OpenAI-compatible speech transcription API service with real-time streaming response (SSE), integrated with Silero VAD and Whisper models. -## ✨ 功能特性 +## ✨ Features -- 🔄 **兼容OpenAI API**: 完全兼容OpenAI `/v1/audio/transcriptions` 端点格式 -- 📡 **Server-Sent Events (SSE)**: 支持流式响应,实时获取转录结果 -- 🎯 **语音活动检测**: 集成Silero VAD,智能检测语音片段 -- 🧠 **Whisper转录**: 使用Candle框架实现的高效Whisper模型 -- 🚀 **高性能**: 支持GPU加速(CUDA/Metal) -- 🌐 **现代Web界面**: 包含完整的测试页面 +- 🔄 **OpenAI API Compatible**: Full compatibility with OpenAI `/v1/audio/transcriptions` endpoint format +- 📡 **Server-Sent Events (SSE)**: Supports streaming responses for real-time transcription results +- 🎯 **Voice Activity Detection**: Integrated with Silero VAD for intelligent speech segment detection +- 🧠 **Whisper Transcription**: High-performance Whisper model implementation using Candle framework +- 🚀 **High Performance**: Supports GPU acceleration (CUDA/Metal) +- 🌐 **Modern Web Interface**: Includes complete testing page -## 🚀 快速开始 +## 🚀 Quick Start -### 1. 启动服务器 +### 1. Start the Server ```bash -# 进入项目目录 +# Navigate to project directory cd apps/asr-api -# 安装依赖并启动 +# Install dependencies and start cargo run --release ``` -服务器将在 `http://localhost:3000` 启动。 +The server will start at `http://localhost:3000`. -### 2. 测试API - -打开浏览器访问测试页面: -``` -http://localhost:3000/test.html -``` - -或者使用curl命令: +### 2. Test API ```bash -# 基础转录 +# Basic transcription curl -X POST http://localhost:3000/v1/audio/transcriptions \ -F "file=@your_audio.wav" \ -F "model=whisper-1" -# 流式转录 +# Streaming transcription curl -X POST "http://localhost:3000/v1/audio/transcriptions?stream=true" \ -F "file=@your_audio.wav" \ -F "model=whisper-1" \ --no-buffer ``` -## 📋 API文档 +## 📋 API Documentation ### POST `/v1/audio/transcriptions` -转录音频文件为文本。 +Transcribe audio file to text. 
-#### 请求参数 +#### Request Parameters -| 参数 | 类型 | 必需 | 描述 | -|------|------|------|------| -| `file` | File | ✅ | 要转录的音频文件 | -| `model` | String | ❌ | 模型名称 (默认: "whisper-1") | -| `language` | String | ❌ | 音频语言 | -| `prompt` | String | ❌ | 提示文本 | -| `response_format` | String | ❌ | 响应格式 (默认: "json") | -| `temperature` | Float | ❌ | 采样温度 (默认: 0.0) | -| `stream` | Boolean | ❌ | 启用流式响应 (Query参数) | +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `file` | File | ✅ | Audio file to transcribe | +| `model` | String | ❌ | Model name (default: "whisper-1") | +| `language` | String | ❌ | Audio language | +| `prompt` | String | ❌ | Prompt text | +| `response_format` | String | ❌ | Response format (default: "json") | +| `temperature` | Float | ❌ | Sampling temperature (default: 0.0) | +| `stream` | Boolean | ❌ | Enable streaming response (Query parameter) | -#### 支持的音频格式 +#### Supported Audio Formats - WAV - MP3 - FLAC - M4A -- 以及Symphonia支持的其他格式 +- And other formats supported by Symphonia -#### 响应格式 +#### Response Format -**标准响应 (JSON)**: +**Standard Response (JSON)**: ```json { - "text": "转录的文本内容" + "text": "Transcribed text content" } ``` -**流式响应 (SSE)**: +**Streaming Response (SSE)**: ``` data: {"text": "Processing audio chunk 1 of 4...", "timestamp": 0.5} data: {"text": "Processing audio chunk 2 of 4...", "timestamp": 1.0} -data: {"text": "转录完成的文本", "timestamp": 2.5} +data: {"text": "Completed transcription text", "timestamp": 2.5} ``` -**错误响应**: +**Error Response**: ```json { "error": { - "message": "错误描述", + "message": "Error description", "type": "invalid_request_error", "param": "file", "code": null @@ -103,123 +96,123 @@ data: {"text": "转录完成的文本", "timestamp": 2.5} } ``` -## 🛠️ 开发指南 +## 🛠️ Development Guide -### 项目结构 +### Project Structure ``` apps/asr-api/ ├── src/ -│ ├── main.rs # 主服务器文件 -│ ├── vad.rs # VAD处理器 -│ ├── whisper.rs # Whisper处理器 -│ └── audio_manager.rs # 音频缓冲管理 -├── melfilters.bytes # Mel滤波器数据 -├── melfilters128.bytes # 128维Mel滤波器数据 -├── test.html # 测试页面 -├── Cargo.toml # 依赖配置 -└── README.md # 文档 +│ ├── main.rs # Main server file +│ ├── vad.rs # VAD processor +│ ├── whisper.rs # Whisper processor +│ └── audio_manager.rs # Audio buffer management +├── melfilters.bytes # Mel filter data +├── melfilters128.bytes # 128-dim Mel filter data +├── test.html # Test page +├── Cargo.toml # Dependencies configuration +└── README.md # Documentation ``` -### 核心组件 +### Core Components -1. **VAD处理器**: 使用Silero VAD模型检测语音活动 -2. **Whisper处理器**: 使用Candle实现的Whisper模型进行转录 -3. **音频管理器**: 处理音频缓冲和格式转换 -4. **Web服务器**: 基于Axum的高性能HTTP服务器 +1. **VAD Processor**: Uses Silero VAD model for voice activity detection +2. **Whisper Processor**: Uses Candle-implemented Whisper model for transcription +3. **Audio Manager**: Handles audio buffering and format conversion +4. **Web Server**: High-performance HTTP server based on Axum -### 自定义配置 +### Custom Configuration -可以通过修改 `AppState::new()` 方法来调整以下参数: +You can adjust the following parameters by modifying the `AppState::new()` method: -- VAD阈值 (默认: 0.3) -- Whisper模型 (默认: Tiny) -- 设备选择 (自动选择GPU/CPU) +- VAD threshold (default: 0.3) +- Whisper model (default: Tiny) +- Device selection (auto-select GPU/CPU) -### 添加新功能 +### Adding New Features -1. **支持更多音频格式**: 修改 `convert_audio_to_pcm` 函数 -2. **自定义VAD参数**: 在 `VADProcessor::new` 中调整参数 -3. **更大的Whisper模型**: 在 `WhisperProcessor::new` 中选择不同模型 +1. **Support more audio formats**: Modify `convert_audio_to_pcm` function +2. 
**Custom VAD parameters**: Adjust parameters in `VADProcessor::new` +3. **Larger Whisper models**: Select different models in `WhisperProcessor::new` -## 🔧 高级配置 +## 🔧 Advanced Configuration -### 环境变量 +### Environment Variables ```bash -# 设置日志级别 +# Set log level export RUST_LOG=debug -# 强制使用CPU +# Force CPU usage export CANDLE_FORCE_CPU=1 ``` -### GPU加速 +### GPU Acceleration -#### CUDA支持 +#### CUDA Support ```bash cargo run --release --features cuda ``` -#### Metal支持 (macOS) +#### Metal Support (macOS) ```bash cargo run --release --features metal ``` -## 📊 性能优化 +## 📊 Performance Optimization -### 推荐配置 +### Recommended Configuration -- **内存**: 最少8GB RAM -- **GPU**: NVIDIA GTX 1060 6GB+ 或 Apple M1+ -- **存储**: SSD推荐,用于模型加载 +- **Memory**: Minimum 8GB RAM +- **GPU**: NVIDIA GTX 1060 6GB+ or Apple M1+ +- **Storage**: SSD recommended for model loading -### 批处理优化 +### Batch Processing Optimization -对于大量文件处理,建议: +For processing large numbers of files, consider: -1. 使用更大的Whisper模型获得更好质量 -2. 启用GPU加速 -3. 调整VAD参数减少误检 +1. Use larger Whisper models for better quality +2. Enable GPU acceleration +3. Adjust VAD parameters to reduce false positives -## 🚨 常见问题 +## 🚨 FAQ -### Q: 转录准确率不高怎么办? -A: 尝试以下方法: -- 使用更大的Whisper模型 (medium/large) -- 确保音频质量良好 (16kHz采样率) -- 调整VAD阈值 -- 提供语言参数 +### Q: How to improve transcription accuracy? +A: Try the following methods: +- Use larger Whisper models (medium/large) +- Ensure good audio quality (16kHz sampling rate) +- Adjust VAD threshold +- Provide language parameter -### Q: 服务器启动慢? -A: 首次启动需要下载模型文件,这是正常现象。模型会缓存到本地。 +### Q: Server starts slowly? +A: First startup requires downloading model files, which is normal. Models are cached locally. -### Q: 支持实时语音输入吗? -A: 目前只支持文件上传,实时语音输入可以参考 `silero-vad-whisper-realtime` 项目。 +### Q: Does it support real-time voice input? +A: Currently only supports file upload. For real-time voice input, refer to the `silero-vad-whisper-realtime` project. -### Q: 如何批量处理文件? -A: 可以编写脚本调用API,或者扩展当前代码支持批处理端点。 +### Q: How to batch process files? +A: You can write scripts to call the API, or extend the current code to support batch processing endpoints. -## 🤝 贡献指南 +## 🤝 Contributing -欢迎提交Issue和Pull Request! +Issues and Pull Requests are welcome! -1. Fork项目 -2. 创建功能分支 -3. 提交改动 -4. 发起Pull Request +1. Fork the project +2. Create feature branch +3. Commit changes +4. Submit Pull Request -## 📄 许可证 +## 📄 License -本项目采用与父项目相同的许可证。 +This project uses the same license as the parent project. -## 🙏 致谢 +## 🙏 Acknowledgments -- [Candle](https://github.com/huggingface/candle) - 高性能ML框架 -- [Axum](https://github.com/tokio-rs/axum) - 现代Web框架 -- [OpenAI](https://openai.com/) - API设计参考 -- [Silero VAD](https://github.com/snakers4/silero-vad) - VAD模型 +- [Candle](https://github.com/huggingface/candle) - High-performance ML framework +- [Axum](https://github.com/tokio-rs/axum) - Modern web framework +- [OpenAI](https://openai.com/) - API design reference +- [Silero VAD](https://github.com/snakers4/silero-vad) - VAD model --- -🎯 **提示**: 第一次运行时会自动下载模型文件,请确保网络连接正常。 +🎯 **Tip**: Model files will be automatically downloaded on first run. Please ensure stable network connection. 
From 4308ba2b261ae09aec5e5c280f5382b81e1b23a3 Mon Sep 17 00:00:00 2001
From: RainbowBird
Date: Sun, 22 Jun 2025 22:24:39 +0800
Subject: [PATCH 09/14] docs: follow best practices

---
 .../silero-vad-whisper-realtime-api/README.md | 226 +++++-------------
 1 file changed, 54 insertions(+), 172 deletions(-)

diff --git a/apps/silero-vad-whisper-realtime-api/README.md b/apps/silero-vad-whisper-realtime-api/README.md
index 34e35c8..caa5073 100644
--- a/apps/silero-vad-whisper-realtime-api/README.md
+++ b/apps/silero-vad-whisper-realtime-api/README.md
@@ -1,52 +1,71 @@
-# ASR API - OpenAI Compatible Audio Transcription Service
+# silero-vad-whisper-realtime-api
 
-🎤 An OpenAI-compatible speech transcription API service with real-time streaming response (SSE), integrated with Silero VAD and Whisper models.
+> An OpenAI-compatible speech transcription API service with real-time streaming response (SSE), integrated with Silero VAD and Whisper models.
+>
+> This service provides a `/v1/audio/transcriptions` endpoint that is fully compatible with OpenAI's API format, supporting both standard JSON responses and streaming Server-Sent Events.
 
-## ✨ Features
+## Getting started
 
-- 🔄 **OpenAI API Compatible**: Full compatibility with OpenAI `/v1/audio/transcriptions` endpoint format
-- 📡 **Server-Sent Events (SSE)**: Supports streaming responses for real-time transcription results
-- 🎯 **Voice Activity Detection**: Integrated with Silero VAD for intelligent speech segment detection
-- 🧠 **Whisper Transcription**: High-performance Whisper model implementation using Candle framework
-- 🚀 **High Performance**: Supports GPU acceleration (CUDA/Metal)
-- 🌐 **Modern Web Interface**: Includes complete testing page
+```
+git clone https://github.com/proj-airi/candle-examples.git
+cd candle-examples/apps/silero-vad-whisper-realtime-api
+```
 
-## 🚀 Quick Start
+## Build
 
-### 1. Start the Server
+```
+cargo fetch --locked
+cargo clean
+```
 
-```bash
-# Navigate to project directory
-cd apps/asr-api
+### NVIDIA CUDA
 
-# Install dependencies and start
-cargo run --release
+```
+cargo build --package silero-vad-whisper-realtime-api --features cuda
+```
+
+### macOS Metal
+
+```
+cargo build --package silero-vad-whisper-realtime-api --features metal
+```
+
+### CPU Only
+
+```
+cargo build --package silero-vad-whisper-realtime-api
+```
+
+## Run
+
+### Any platform
+
+```shell
+cargo run --package silero-vad-whisper-realtime-api --release
 ```
 
 The server will start at `http://localhost:3000`.
 
-### 2. Test API
+## Usage
+
+### Basic transcription
 
 ```bash
-# Basic transcription
 curl -X POST http://localhost:3000/v1/audio/transcriptions \
   -F "file=@your_audio.wav" \
   -F "model=whisper-1"
+```
+
+### Streaming transcription
 
-# Streaming transcription
+```bash
 curl -X POST "http://localhost:3000/v1/audio/transcriptions?stream=true" \
   -F "file=@your_audio.wav" \
   -F "model=whisper-1" \
   --no-buffer
 ```
 
-## 📋 API Documentation
-
-### POST `/v1/audio/transcriptions`
-
-Transcribe audio file to text.
-
-#### Request Parameters
+## API Parameters
 
 | Parameter | Type | Required | Description |
 |-----------|------|----------|-------------|
 | `file` | File | ✅ | Audio file to transcribe |
 | `model` | String | ❌ | Model name (default: "whisper-1") |
 | `language` | String | ❌ | Audio language |
| `prompt` | String | ❌ | Prompt text | | `response_format` | String | ❌ | Response format (default: "json") | | `temperature` | Float | ❌ | Sampling temperature (default: 0.0) | -| `stream` | Boolean | ❌ | Enable streaming response (Query parameter) | - -#### Supported Audio Formats - -- WAV -- MP3 -- FLAC -- M4A -- And other formats supported by Symphonia - -#### Response Format +| `stream` | Boolean | ❌ | Enable streaming response (query parameter) | -**Standard Response (JSON)**: -```json -{ - "text": "Transcribed text content" -} -``` +## Supported Audio Formats -**Streaming Response (SSE)**: -``` -data: {"text": "Processing audio chunk 1 of 4...", "timestamp": 0.5} +- WAV, MP3, FLAC, M4A +- Any format supported by Symphonia -data: {"text": "Processing audio chunk 2 of 4...", "timestamp": 1.0} - -data: {"text": "Completed transcription text", "timestamp": 2.5} -``` - -**Error Response**: -```json -{ - "error": { - "message": "Error description", - "type": "invalid_request_error", - "param": "file", - "code": null - } -} -``` - -## 🛠️ Development Guide - -### Project Structure - -``` -apps/asr-api/ -├── src/ -│ ├── main.rs # Main server file -│ ├── vad.rs # VAD processor -│ ├── whisper.rs # Whisper processor -│ └── audio_manager.rs # Audio buffer management -├── melfilters.bytes # Mel filter data -├── melfilters128.bytes # 128-dim Mel filter data -├── test.html # Test page -├── Cargo.toml # Dependencies configuration -└── README.md # Documentation -``` - -### Core Components - -1. **VAD Processor**: Uses Silero VAD model for voice activity detection -2. **Whisper Processor**: Uses Candle-implemented Whisper model for transcription -3. **Audio Manager**: Handles audio buffering and format conversion -4. **Web Server**: High-performance HTTP server based on Axum - -### Custom Configuration - -You can adjust the following parameters by modifying the `AppState::new()` method: - -- VAD threshold (default: 0.3) -- Whisper model (default: Tiny) -- Device selection (auto-select GPU/CPU) - -### Adding New Features - -1. **Support more audio formats**: Modify `convert_audio_to_pcm` function -2. **Custom VAD parameters**: Adjust parameters in `VADProcessor::new` -3. **Larger Whisper models**: Select different models in `WhisperProcessor::new` - -## 🔧 Advanced Configuration - -### Environment Variables +## Environment Variables ```bash # Set log level @@ -147,72 +92,9 @@ export RUST_LOG=debug export CANDLE_FORCE_CPU=1 ``` -### GPU Acceleration - -#### CUDA Support -```bash -cargo run --release --features cuda -``` - -#### Metal Support (macOS) -```bash -cargo run --release --features metal -``` - -## 📊 Performance Optimization +## Acknowledgements -### Recommended Configuration - -- **Memory**: Minimum 8GB RAM -- **GPU**: NVIDIA GTX 1060 6GB+ or Apple M1+ -- **Storage**: SSD recommended for model loading - -### Batch Processing Optimization - -For processing large numbers of files, consider: - -1. Use larger Whisper models for better quality -2. Enable GPU acceleration -3. Adjust VAD parameters to reduce false positives - -## 🚨 FAQ - -### Q: How to improve transcription accuracy? -A: Try the following methods: -- Use larger Whisper models (medium/large) -- Ensure good audio quality (16kHz sampling rate) -- Adjust VAD threshold -- Provide language parameter - -### Q: Server starts slowly? -A: First startup requires downloading model files, which is normal. Models are cached locally. - -### Q: Does it support real-time voice input? -A: Currently only supports file upload. 
For real-time voice input, refer to the `silero-vad-whisper-realtime` project. - -### Q: How to batch process files? -A: You can write scripts to call the API, or extend the current code to support batch processing endpoints. - -## 🤝 Contributing - -Issues and Pull Requests are welcome! - -1. Fork the project -2. Create feature branch -3. Commit changes -4. Submit Pull Request - -## 📄 License - -This project uses the same license as the parent project. - -## 🙏 Acknowledgments - -- [Candle](https://github.com/huggingface/candle) - High-performance ML framework -- [Axum](https://github.com/tokio-rs/axum) - Modern web framework +- [candle](https://github.com/huggingface/candle) - High-performance ML framework +- [axum](https://github.com/tokio-rs/axum) - Modern web framework - [OpenAI](https://openai.com/) - API design reference -- [Silero VAD](https://github.com/snakers4/silero-vad) - VAD model - ---- - -🎯 **Tip**: Model files will be automatically downloaded on first run. Please ensure stable network connection. +- [Silero VAD](https://github.com/snakers4/silero-vad) - Voice activity detection model From e43d64cc5159ec12e73704f502ad0a4b8e8559a3 Mon Sep 17 00:00:00 2001 From: RainbowBird Date: Sun, 22 Jun 2025 22:27:38 +0800 Subject: [PATCH 10/14] chore: fix Co-authored-by: Neko Ayaka --- .../src/main.rs | 5 ++-- .../src/router.rs | 24 +++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/apps/silero-vad-whisper-realtime-api/src/main.rs b/apps/silero-vad-whisper-realtime-api/src/main.rs index cd106c8..902a3a8 100644 --- a/apps/silero-vad-whisper-realtime-api/src/main.rs +++ b/apps/silero-vad-whisper-realtime-api/src/main.rs @@ -130,7 +130,7 @@ async fn main() -> Result<()> { // Build application routes let app = Router::new() - .route("/", get(health_check)) + .route("/healthz", get(health_check)) .route("/v1/audio/transcriptions", post(transcribe_audio)) .layer( ServiceBuilder::new() @@ -140,10 +140,11 @@ async fn main() -> Result<()> { .with_state(Arc::new(state)); // Start server + // TODO: use `PORT` as port from environment variables let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; println!("🚀 ASR API server running on http://0.0.0.0:3000"); println!("📝 Available endpoints:"); - println!(" GET / - Health check"); + println!(" GET /healthz - Health check"); println!(" POST /v1/audio/transcriptions - Audio transcription (OpenAI compatible)"); axum::serve(listener, app).await?; diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs index 5c0e802..8cec1af 100644 --- a/apps/silero-vad-whisper-realtime-api/src/router.rs +++ b/apps/silero-vad-whisper-realtime-api/src/router.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc, time::Instant}; +use std::{collections::HashMap, sync::Arc, time::{Duration, Instant}}; use anyhow::Result; use axum::{ @@ -33,22 +33,22 @@ use crate::{ // Performance statistics struct #[derive(Debug)] struct ProcessingStats { - total_duration: std::time::Duration, - audio_conversion_duration: std::time::Duration, - model_loading_duration: std::time::Duration, - vad_processing_duration: std::time::Duration, - whisper_transcription_duration: std::time::Duration, + total_duration: Duration, + audio_conversion_duration: Duration, + model_loading_duration: Duration, + vad_processing_duration: Duration, + whisper_transcription_duration: Duration, audio_length_seconds: f32, } impl ProcessingStats { fn new() -> Self { Self { - total_duration: 
-      audio_conversion_duration: std::time::Duration::ZERO,
-      model_loading_duration: std::time::Duration::ZERO,
-      vad_processing_duration: std::time::Duration::ZERO,
-      whisper_transcription_duration: std::time::Duration::ZERO,
+      total_duration: Duration::ZERO,
+      audio_conversion_duration: Duration::ZERO,
+      model_loading_duration: Duration::ZERO,
+      vad_processing_duration: Duration::ZERO,
+      whisper_transcription_duration: Duration::ZERO,
       audio_length_seconds: 0.0,
     }
   }
@@ -268,7 +268,7 @@ async fn transcribe_audio_complete(
   let mut frame_buffer = Vec::::new();

   let vad_start = Instant::now();
-  let mut whisper_total_time = std::time::Duration::ZERO;
+  let mut whisper_total_time = Duration::ZERO;

   // Process in chunks
   for chunk in audio_data.chunks(1024) {

From 75330426d2581e64da0a0b9b8548d7ba18391d49 Mon Sep 17 00:00:00 2001
From: RainbowBird
Date: Sun, 22 Jun 2025 22:36:10 +0800
Subject: [PATCH 11/14] fix: fmt

---
 apps/silero-vad-whisper-realtime-api/src/router.rs | 6 +++++-
 rustfmt.toml | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs
index 8cec1af..ba541d3 100644
--- a/apps/silero-vad-whisper-realtime-api/src/router.rs
+++ b/apps/silero-vad-whisper-realtime-api/src/router.rs
@@ -1,4 +1,8 @@
-use std::{collections::HashMap, sync::Arc, time::{Duration, Instant}};
+use std::{
+  collections::HashMap,
+  sync::Arc,
+  time::{Duration, Instant},
+};

 use anyhow::Result;
 use axum::{
diff --git a/rustfmt.toml b/rustfmt.toml
index 0efea85..8644cdd 100644
--- a/rustfmt.toml
+++ b/rustfmt.toml
@@ -15,7 +15,7 @@ match_block_trailing_comma = true
# overflow_delimited_expr = true # nightly
# reorder_impl_items = true # nightly
# spaces_around_ranges = true # nightly
-# unstable_features = true # nightly
+unstable_features = true # nightly
use_field_init_shorthand = true
use_try_shorthand = true
struct_field_align_threshold = 999

From 24bfbefa9c1b6932ad3d90913569dab93b0f4157 Mon Sep 17 00:00:00 2001
From: Neko Ayaka
Date: Sun, 22 Jun 2025 22:28:21 +0800
Subject: [PATCH 12/14] chore: lint

---
 .../silero-vad-whisper-realtime-api/src/main.rs | 4 ++--
 .../src/router.rs | 17 ++++++++---------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/apps/silero-vad-whisper-realtime-api/src/main.rs b/apps/silero-vad-whisper-realtime-api/src/main.rs
index 902a3a8..d1b2fa5 100644
--- a/apps/silero-vad-whisper-realtime-api/src/main.rs
+++ b/apps/silero-vad-whisper-realtime-api/src/main.rs
@@ -74,14 +74,14 @@ impl AppState {
     {
       let models = self.whisper_models.read().await;
       if let Some(processor) = models.get(model_name) {
-        println!("🔄 Using cached Whisper model: {}", model_name);
+        println!("🔄 Using cached Whisper model: {model_name}");
         return Ok(processor.clone());
       }
     }

     // If not in cache, create new model with timing
     let loading_start = std::time::Instant::now();
-    println!("🧠 Loading new Whisper model: {}", model_name);
+    println!("🧠 Loading new Whisper model: {model_name}");

     let whisper_model = Self::parse_model_name(model_name)?;
     let processor = Arc::new(Mutex::new(WhisperProcessor::new(whisper_model, self.device.clone())?));
diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs
index ba541d3..742a255 100644
--- a/apps/silero-vad-whisper-realtime-api/src/router.rs
+++ b/apps/silero-vad-whisper-realtime-api/src/router.rs
@@ -67,7 +67,7 @@ impl ProcessingStats {
     println!(" 🎵 Audio length: {:.2}s", self.audio_length_seconds);
     if self.audio_length_seconds > 0.0 {
       let real_time_factor = self.total_duration.as_secs_f64() / self.audio_length_seconds as f64;
-      println!(" ⚡ Real-time factor: {:.2}x", real_time_factor);
+      println!(" ⚡ Real-time factor: {real_time_factor:.2}x");
     }
   }
 }
@@ -87,7 +87,7 @@ pub async fn transcribe_audio(
         StatusCode::BAD_REQUEST,
         Json(ErrorResponse {
           error: ErrorDetail {
-            message: format!("Failed to extract form data: {}", e),
+            message: format!("Failed to extract form data: {e}"),
             error_type: "invalid_request_error".to_string(),
             param: Some("form".to_string()),
             code: None,
@@ -97,7 +97,7 @@ pub async fn transcribe_audio(
     },
   };

-  println!("Request params: {:?}", params);
+  println!("Request params: {params:?}");

   // Parse streaming parameter from form data
   let stream_enabled = params
@@ -107,11 +107,10 @@ pub async fn transcribe_audio(
   // Get model name from parameters and clone it to make it owned
   let model_name = params
-    .get("model")
-    .map(|s| s.clone()) // Clone to make it owned
+    .get("model").cloned() // Clone to make it owned
     .unwrap_or_else(|| "tiny".to_string()); // Use tiny as default

-  println!("Using model: {}, streaming: {}", model_name, stream_enabled);
+  println!("Using model: {model_name}, streaming: {stream_enabled}");

   // Convert audio to PCM format with timing
   let conversion_start = Instant::now();
@@ -122,7 +121,7 @@ pub async fn transcribe_audio(
         StatusCode::BAD_REQUEST,
         Json(ErrorResponse {
           error: ErrorDetail {
-            message: format!("Failed to process audio file: {}", e),
+            message: format!("Failed to process audio file: {e}"),
             error_type: "invalid_request_error".to_string(),
             param: Some("file".to_string()),
             code: None,
@@ -156,7 +155,7 @@ pub async fn transcribe_audio(
           StatusCode::INTERNAL_SERVER_ERROR,
           Json(ErrorResponse {
             error: ErrorDetail {
-              message: format!("Transcription failed: {}", e),
+              message: format!("Transcription failed: {e}"),
               error_type: "server_error".to_string(),
               param: None,
               code: None,

From 87b957d58bcba78de0f970048742df1c8bc2455b Mon Sep 17 00:00:00 2001
From: Neko Ayaka
Date: Sun, 22 Jun 2025 22:37:25 +0800
Subject: [PATCH 13/14] chore: lint

---
 .../src/main.rs | 1 +
 .../src/router.rs | 63 ++++++++++---------
 2 files changed, 35 insertions(+), 29 deletions(-)

diff --git a/apps/silero-vad-whisper-realtime-api/src/main.rs b/apps/silero-vad-whisper-realtime-api/src/main.rs
index d1b2fa5..11d5819 100644
--- a/apps/silero-vad-whisper-realtime-api/src/main.rs
+++ b/apps/silero-vad-whisper-realtime-api/src/main.rs
@@ -33,6 +33,7 @@ struct AppState {
 }

 impl AppState {
+  #[allow(clippy::unused_async)]
   async fn new() -> Result {
     // Determine device to use
     let device = if std::env::var("CANDLE_FORCE_CPU").is_ok() {
diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs
index 742a255..adc23da 100644
--- a/apps/silero-vad-whisper-realtime-api/src/router.rs
+++ b/apps/silero-vad-whisper-realtime-api/src/router.rs
@@ -21,7 +21,7 @@ use symphonia::{
     audio::{AudioBufferRef, Signal},
     codecs::DecoderOptions,
     formats::FormatOptions,
-    io::MediaSourceStream,
+    io::{MediaSourceStream, MediaSourceStreamOptions},
     meta::MetadataOptions,
     probe::Hint,
   },
 }
@@ -46,7 +46,7 @@ struct ProcessingStats {
 }

 impl ProcessingStats {
-  fn new() -> Self {
+  const fn new() -> Self {
     Self {
       total_duration: Duration::ZERO,
       audio_conversion_duration: Duration::ZERO,
@@ -66,7 +66,7 @@ impl ProcessingStats {
     println!(" ⏱️ Total processing: {:.2}ms", self.total_duration.as_secs_f64() * 1000.0);
     println!(" 🎵 Audio length: {:.2}s", self.audio_length_seconds);
     if self.audio_length_seconds > 0.0 {
-      let real_time_factor = self.total_duration.as_secs_f64() / self.audio_length_seconds as f64;
+      let real_time_factor = self.total_duration.as_secs_f64() / f64::from(self.audio_length_seconds);
       println!(" ⚡ Real-time factor: {real_time_factor:.2}x");
     }
   }
@@ -77,7 +77,7 @@ pub async fn transcribe_audio(
   mut multipart: Multipart,
 ) -> Result)> {
   let start_time = Instant::now();
-  let mut stats = ProcessingStats::new();
+  let mut processing_stats = ProcessingStats::new();

   // Extract both audio file and parameters from multipart form
   let (audio_data, params) = match extract_multipart_data(&mut multipart).await {
@@ -102,8 +102,7 @@ pub async fn transcribe_audio(
   // Parse streaming parameter from form data
   let stream_enabled = params
     .get("stream")
-    .map(|s| s.parse::().unwrap_or(false))
-    .unwrap_or(false);
+    .is_some_and(|s| s.parse::().unwrap_or(false));

   // Get model name from parameters and clone it to make it owned
   let model_name = params
@@ -130,27 +129,28 @@ pub async fn transcribe_audio(
       ));
     },
   };
-  stats.audio_conversion_duration = conversion_start.elapsed();
-  stats.audio_length_seconds = pcm_data.len() as f32 / 16000.0; // Assuming 16kHz sample rate
-  println!("Audio data length: {} samples ({:.2}s)", pcm_data.len(), stats.audio_length_seconds);
+  processing_stats.audio_conversion_duration = conversion_start.elapsed();
+  processing_stats.audio_length_seconds = pcm_data.len() as f32 / 16000.0; // Assuming 16kHz sample rate
+
+  println!("Audio data length: {} samples ({:.2}s)", pcm_data.len(), processing_stats.audio_length_seconds);

   if stream_enabled {
     // Return streaming response
-    let stream = create_transcription_stream(state, model_name, pcm_data, stats).await?;
+    let stream = create_transcription_stream(state, model_name, pcm_data, processing_stats).await?;
     let sse = Sse::new(stream).keep_alive(KeepAlive::default());

     Ok(sse.into_response())
   } else {
     // Return single response
-    match transcribe_audio_complete(state, model_name, pcm_data, &mut stats).await {
+    match transcribe_audio_complete(state, model_name, pcm_data, &mut processing_stats).await {
       Ok(transcript) => {
-        stats.total_duration = start_time.elapsed();
-        stats.print_summary();
+        processing_stats.total_duration = start_time.elapsed();
+        processing_stats.print_summary();
         Ok(Json(TranscriptionResponse { text: transcript }).into_response())
       },
       Err(e) => {
-        stats.total_duration = start_time.elapsed();
-        stats.print_summary();
+        processing_stats.total_duration = start_time.elapsed();
+        processing_stats.print_summary();
         Err((
           StatusCode::INTERNAL_SERVER_ERROR,
           Json(ErrorResponse {
@@ -192,15 +192,16 @@ async fn extract_multipart_data(multipart: &mut Multipart) -> Result<(Vec, H
 }

 // Convert various audio formats to PCM
+#[allow(clippy::unused_async)]
 async fn convert_audio_to_pcm(audio_data: &[u8]) -> Result> {
   let cursor = std::io::Cursor::new(audio_data.to_vec());
-  let media_source = MediaSourceStream::new(Box::new(cursor), Default::default());
+  let media_source = MediaSourceStream::new(Box::new(cursor), MediaSourceStreamOptions::default());

   let mut hint = Hint::new();
   hint.mime_type("audio/wav"); // You might want to detect this automatically
-  let meta_opts: MetadataOptions = Default::default();
-  let fmt_opts: FormatOptions = Default::default();
+  let meta_opts = MetadataOptions::default();
+  let fmt_opts = FormatOptions::default();

   let probed = get_probe().format(&hint, media_source, &fmt_opts, &meta_opts)?;
@@ -211,7 +212,7 @@ async fn convert_audio_to_pcm(audio_data: &[u8]) -> Result> {
     .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
     .ok_or_else(|| anyhow::anyhow!("No audio track found"))?;

-  let dec_opts: DecoderOptions = Default::default();
+  let dec_opts = DecoderOptions::default();
   let mut decoder = symphonia::default::get_codecs().make(&track.codec_params, &dec_opts)?;

   let track_id = track.id;
@@ -236,6 +237,7 @@ async fn convert_audio_to_pcm(audio_data: &[u8]) -> Result> {
       },
       AudioBufferRef::S32(buf) => {
         for &sample in buf.chan(0) {
+          #[allow(clippy::cast_precision_loss)]
           pcm_data.push(sample as f32 / i32::MAX as f32);
         }
       },
@@ -253,14 +255,14 @@ async fn transcribe_audio_complete(
   state: Arc,
   model_name: String, // Change to owned String
   audio_data: Vec,
-  stats: &mut ProcessingStats,
+  processing_stats: &mut ProcessingStats,
 ) -> Result {
   let sample_rate = 16000;

   // Get the appropriate Whisper processor for this model with timing
   let model_loading_start = Instant::now();
   let whisper_processor = state.get_whisper_processor(&model_name).await?;
-  stats.model_loading_duration = model_loading_start.elapsed();
+  processing_stats.model_loading_duration = model_loading_start.elapsed();

   // Process audio through VAD and Whisper
   let mut vad = state.vad.lock().await;
@@ -296,8 +298,8 @@ async fn transcribe_audio_complete(
     }
   }

-  stats.vad_processing_duration = vad_start.elapsed() - whisper_total_time;
-  stats.whisper_transcription_duration = whisper_total_time;
+  processing_stats.vad_processing_duration = vad_start.elapsed() - whisper_total_time;
+  processing_stats.whisper_transcription_duration = whisper_total_time;

   Ok(transcripts.join(" "))
 }
@@ -307,7 +309,7 @@ async fn create_transcription_stream(
   state: Arc,
   model_name: String, // Change to owned String
   audio_data: Vec,
-  mut stats: ProcessingStats,
+  mut processing_stats: ProcessingStats,
 ) -> Result>, (StatusCode, Json)> {
   let stream_start = Instant::now();
@@ -329,11 +331,11 @@ async fn create_transcription_stream(
       ));
     },
   };
-  stats.model_loading_duration = model_loading_start.elapsed();
+  processing_stats.model_loading_duration = model_loading_start.elapsed();

   let sample_rate = 16000;

-  Ok(stream::unfold((state, whisper_processor, audio_data, 0, AudioBuffer::new(10000, 100, 500, sample_rate), stats, stream_start), move |(state, whisper_processor, audio_data, mut processed, mut audio_buffer, mut stats, stream_start)| async move {
+  Ok(stream::unfold((state, whisper_processor, audio_data, 0, AudioBuffer::new(10000, 100, 500, sample_rate), processing_stats, stream_start), move |(state, whisper_processor, audio_data, mut processed, mut audio_buffer, mut stats, stream_start)| async move {
     if processed >= audio_data.len() {
       // Print final statistics for streaming
       stats.total_duration = stream_start.elapsed();
@@ -381,11 +383,14 @@ async fn create_transcription_stream(
     }

     // Create event with actual transcription or progress update
     let event_data = if let Some(transcript) = whisper_result {
-      StreamChunk { text: transcript, timestamp: Some(processed as f64 / sample_rate as f64) }
+      #[allow(clippy::cast_precision_loss)]
+      StreamChunk { text: transcript, timestamp: Some(processed as f64 / f64::from(sample_rate)) }
     } else {
       StreamChunk {
-        text: format!("Processing... ({:.1}%)", (processed as f64 / audio_data.len() as f64) * 100.0),
-        timestamp: Some(processed as f64 / sample_rate as f64),
+        #[allow(clippy::cast_precision_loss)]
+        text: format!("Processing... ({:.1}%)", (processed as f64 / audio_data.len() as f64) * 100.0),
+        #[allow(clippy::cast_precision_loss)]
+        timestamp: Some(processed as f64 / f64::from(sample_rate)),
       }
     };

From 28bcbef78e28a1f0888843c69b62dcca121696ed Mon Sep 17 00:00:00 2001
From: Neko Ayaka
Date: Sun, 22 Jun 2025 22:52:03 +0800
Subject: [PATCH 14/14] chore: lint

---
 apps/silero-vad-whisper-realtime-api/src/router.rs | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs
index adc23da..40ce4c2 100644
--- a/apps/silero-vad-whisper-realtime-api/src/router.rs
+++ b/apps/silero-vad-whisper-realtime-api/src/router.rs
@@ -72,6 +72,7 @@ impl ProcessingStats {
   }
 }

+#[allow(clippy::cast_precision_loss)]
 pub async fn transcribe_audio(
   State(state): State>,
   mut multipart: Multipart,
@@ -251,6 +252,7 @@ async fn convert_audio_to_pcm(audio_data: &[u8]) -> Result> {
 }

 // Process complete audio file and return full transcript
+#[allow(clippy::significant_drop_tightening)]
 async fn transcribe_audio_complete(
   state: Arc,
   model_name: String, // Change to owned String
@@ -382,6 +384,7 @@ async fn create_transcription_stream(
   }

   // Create event with actual transcription or progress update
+  #[allow(clippy::option_if_let_else)]
   let event_data = if let Some(transcript) = whisper_result {
     #[allow(clippy::cast_precision_loss)]
     StreamChunk { text: transcript, timestamp: Some(processed as f64 / f64::from(sample_rate)) }
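
With the series applied, the service exposes `GET /healthz` and `POST /v1/audio/transcriptions` on port 3000. A minimal smoke test could look like the sketch below; it assumes the multipart field names `file`, `model`, and `stream` handled by `router.rs`, and `sample.wav` is a placeholder for any audio file Symphonia can decode.

```bash
# Liveness probe (route renamed from `/` to `/healthz` in this series)
curl http://localhost:3000/healthz

# One-shot transcription; the response body is JSON shaped like {"text": "..."}
curl -X POST http://localhost:3000/v1/audio/transcriptions \
  -F "file=@sample.wav" \
  -F "model=tiny"

# Streaming transcription over SSE; each `data:` event carries text and timestamp fields
curl -N -X POST http://localhost:3000/v1/audio/transcriptions \
  -F "file=@sample.wav" \
  -F "model=tiny" \
  -F "stream=true"
```

The non-streaming call maps to `TranscriptionResponse`, while `stream=true` switches the handler to the SSE path whose events serialize `StreamChunk` (`text` plus an optional `timestamp` derived from the number of processed samples).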