diff --git a/Cargo.lock b/Cargo.lock index 9ea27ca..51b1f5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -119,6 +119,17 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "async-trait" +version = "0.1.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -131,6 +142,134 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core 0.4.5", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "matchit 0.7.3", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +dependencies = [ + "axum-core 0.5.2", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "multer", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-extra" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c794b30c904f0a1c2fb7740f7df7f7972dfaa14ef6f57cb6178dc63e5dca2f04" +dependencies = [ + "axum 0.7.9", + "axum-core 0.4.5", + "bytes", + "fastrand", + "futures-util", + "headers", + "http", + "http-body", + "http-body-util", + "mime", + "multer", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -220,6 +359,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.18.1" @@ -538,6 +686,15 @@ dependencies = [ "windows", ] +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -587,6 +744,16 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "cudarc" version = "0.16.4" @@ -680,6 +847,16 @@ dependencies = [ "syn", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "dirs" version = "5.0.1" @@ -1211,6 +1388,16 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -1285,6 +1472,30 @@ version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" +[[package]] +name = "headers" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" +dependencies = [ + "base64 0.22.1", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.5.0" @@ -1355,12 +1566,24 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c" + [[package]] name = "httparse" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.6.0" @@ -1374,6 +1597,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -1762,6 +1986,18 @@ dependencies = [ "libc", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "memchr" version = "2.7.5" @@ -1814,6 +2050,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1861,6 +2107,23 @@ dependencies = [ "syn", ] +[[package]] +name = "multer" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http", + "httparse", + "memchr", + "mime", + "spin", + "version_check", +] + [[package]] name = "multimap" version = "0.10.1" @@ -2827,6 +3090,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_plain" version = "1.0.2" @@ -2848,6 +3121,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2951,6 +3235,38 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "silero-vad-whisper-realtime-api" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum 0.8.4", + "axum-extra", + "byteorder", + "candle-core", + "candle-nn", + "candle-onnx", + "candle-transformers", + "clap", + "cpal", + "crossbeam-channel", + "futures", + "hf-hub", + "rand 0.9.1", + "rubato", + "serde", + "serde_json", + "symphonia", + "tokenizers", + "tokio", + "tokio-stream", + "tower", + "tower-http", + "tracing", + "tracing-chrome", + "tracing-subscriber", +] + [[package]] name = "slab" version = "0.4.10" @@ -2984,6 +3300,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -3384,6 +3706,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.15" @@ -3427,6 +3760,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -3437,14 +3771,24 @@ checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" dependencies = [ "bitflags 2.9.1", "bytes", + "futures-core", "futures-util", "http", "http-body", + 
"http-body-util", + "http-range-header", + "httpdate", "iri-string", + "mime", + "mime_guess", + "percent-encoding", "pin-project-lite", + "tokio", + "tokio-util", "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -3465,6 +3809,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -3543,6 +3888,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + [[package]] name = "ug" version = "0.4.0" @@ -3591,6 +3942,12 @@ dependencies = [ "ug", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.18" diff --git a/Cargo.toml b/Cargo.toml index 13b467d..d935d72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ version = "0.0.1" authors = ["Project AIRI"] edition = "2024" -rust-version = "1.80" +rust-version = "1.85" readme = "README.md" homepage = "https://github.com/proj-airi/candle-examples" repository = "https://github.com/proj-airi/candle-examples" @@ -17,6 +17,7 @@ members = [ "apps/silero-vad-realtime", "apps/silero-vad-realtime-minimum", "apps/silero-vad-whisper-realtime", + "apps/silero-vad-whisper-realtime-api", "apps/whisper-realtime", ] diff --git a/apps/orpheus-tts/src/main.rs b/apps/orpheus-tts/src/main.rs index 977ee31..c0ba359 100644 --- a/apps/orpheus-tts/src/main.rs +++ b/apps/orpheus-tts/src/main.rs @@ -16,7 +16,7 @@ use tracing_chrome::ChromeLayerBuilder; use tracing_subscriber::prelude::*; // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43 -const STOP_TOKEN_ID: u32 = 128258; +const STOP_TOKEN_ID: u32 = 128_258; #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] enum Voice { @@ -39,21 +39,22 @@ enum Voice { } impl Voice { - fn as_str(&self) -> &'static str { + const fn as_str(self) -> &'static str { match self { - Voice::Tara => "tara", - Voice::Leah => "leah", - Voice::Jess => "jess", - Voice::Leo => "leo", - Voice::Dan => "dan", - Voice::Mia => "mia", - Voice::Zac => "zac", - Voice::Zoe => "zoe", + Self::Tara => "tara", + Self::Leah => "leah", + Self::Jess => "jess", + Self::Leo => "leo", + Self::Dan => "dan", + Self::Mia => "mia", + Self::Zac => "zac", + Self::Zoe => "zoe", } } } #[derive(Parser)] +#[allow(clippy::struct_excessive_bools)] struct Args { #[arg(long)] cpu: bool, @@ -154,12 +155,14 @@ pub trait Sample { } impl Sample for f32 { + #[allow(clippy::cast_possible_truncation)] fn to_i16(&self) -> i16 { (self.clamp(-1.0, 1.0) * 32767.0) as i16 } } impl Sample for f64 { + #[allow(clippy::cast_possible_truncation)] fn to_i16(&self) -> i16 { (self.clamp(-1.0, 1.0) * 32767.0) as i16 } @@ -171,6 +174,7 @@ impl Sample for i16 { } } +#[allow(clippy::missing_panics_doc)] pub fn write_pcm_as_wav( w: &mut W, samples: &[S], @@ -178,9 +182,9 @@ pub fn write_pcm_as_wav( ) -> std::io::Result<()> { let len = 12u32; // header let len = len + 24u32; // fmt - let len = len + 
samples.len() as u32 * 2 + 8; // data + let len = len + u32::try_from(samples.len()).unwrap() * 2 + 8; // data let n_channels = 1u16; - let bytes_per_second = sample_rate * 2 * n_channels as u32; + let bytes_per_second = sample_rate * 2 * u32::from(n_channels); w.write_all(b"RIFF")?; w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes w.write_all(b"WAVE")?; @@ -197,21 +201,21 @@ pub fn write_pcm_as_wav( // Data block w.write_all(b"data")?; - w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?; - for sample in samples.iter() { - w.write_all(&sample.to_i16().to_le_bytes())? + w.write_all(&(u32::try_from(samples.len()).unwrap() * 2).to_le_bytes())?; + for sample in samples { + w.write_all(&sample.to_i16().to_le_bytes())?; } Ok(()) } struct Model { - model: LlamaModel, + llama: LlamaModel, tokenizer: Tokenizer, logits_processor: LogitsProcessor, cache: Cache, device: Device, verbose_prompt: bool, - snac_model: SnacModel, + snac: SnacModel, out_file: String, voice: Voice, } @@ -226,15 +230,14 @@ impl Model { .build()?; let model_id = match args.orpheus_model_id { - Some(model_id) => model_id.to_string(), + Some(model_id) => model_id, None => match args.which_orpheus_model { WhichOrpheusModel::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(), }, }; - let revision = match args.revision { - Some(r) => r, - None => "main".to_string(), - }; + let revision = args + .revision + .map_or_else(|| "main".to_string(), |r| r); let repo = api.repo(hf_hub::Repo::with_revision(model_id, hf_hub::RepoType::Model, revision)); let model_file = match args.orpheus_model_file { Some(m) => vec![m.into()], @@ -292,16 +295,16 @@ impl Model { println!("load the model in {:?}", start_time.elapsed()); let cache = Cache::new(true, dtype, &config, &device)?; - let snac_model = load_snac_model(hf_token.clone().as_str(), &device)?; + let snac_model = load_snac_model(hf_token.as_str(), &device)?; Ok(Self { - model, + llama: model, tokenizer, logits_processor, cache, device, verbose_prompt: args.verbose_prompt, - snac_model, + snac: snac_model, out_file: args.out_file, voice: args.voice, }) @@ -319,7 +322,7 @@ impl Model { .map_err(E::msg)?; // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82 - let mut tokens = [&[128259], tokens.get_ids(), &[128009, 128260, 128261, 128257]].concat(); + let mut tokens = [&[128_259], tokens.get_ids(), &[128_009, 128_260, 128_261, 128_257]].concat(); if self.verbose_prompt { println!("prompt tokens: {tokens:?}"); } @@ -342,7 +345,7 @@ impl Model { let context = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(context, device)?.unsqueeze(0)?; let logits = self - .model + .llama .forward(&input, context_index, &mut cache)?; let logits = logits.squeeze(0)?; index_pos += context.len(); @@ -354,7 +357,7 @@ impl Model { Some(tok) => { let tok = tok.parse::()?; // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63 - let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096); + let tok = tok - 10 - ((u32::try_from(audio_tokens.len()).unwrap() % 7) * 4096); audio_tokens.push(tok); }, None => { @@ -393,9 +396,7 @@ impl Model { let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?; let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?; let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?; - let pcm = self - .snac_model - .decode(&[&codes0, &codes1, &codes2])?; + let 
pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?; println!("decoded to pcm {pcm:?}"); let mut output = std::fs::File::create(&self.out_file)?; diff --git a/apps/silero-vad-whisper-realtime-api/Cargo.toml b/apps/silero-vad-whisper-realtime-api/Cargo.toml new file mode 100644 index 0000000..514a9c5 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "silero-vad-whisper-realtime-api" +version = "0.1.0" +edition = "2024" + +[dependencies] +anyhow = "1.0.98" +byteorder = "1.5.0" +candle-core = { version = "0.9.1" } +candle-nn = { version = "0.9.1" } +candle-transformers = { version = "0.9.1" } +candle-onnx = { version = "0.9.1" } +clap = { version = "4.5.38", features = ["derive"] } +cpal = "0.15.3" +hf-hub = "0.4.2" +rand = "0.9.1" +rubato = "0.16.2" +serde_json = "1.0.140" +symphonia = "0.5.4" +tokenizers = "0.21.1" +tracing-chrome = "0.7.2" +tracing-subscriber = "0.3.19" +tracing = "0.1.41" +tokio = "1.45.1" +crossbeam-channel = "0.5.15" +axum = { version = "0.8.4", features = ["multipart"] } +serde = { version = "1.0.219", features = ["derive"] } + +# Server-sent events and HTTP utilities +futures = "0.3.31" +tokio-stream = "0.1.17" +axum-extra = { version = "0.9.5", features = ["typed-header"] } +tower = "0.5.1" +tower-http = { version = "0.6.2", features = ["fs", "cors"] } + +[features] +default = [] +metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"] +cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] diff --git a/apps/silero-vad-whisper-realtime-api/README.md b/apps/silero-vad-whisper-realtime-api/README.md new file mode 100644 index 0000000..caa5073 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/README.md @@ -0,0 +1,100 @@ +# silero-vad-whisper-realtime-api + +> An OpenAI-compatible speech transcription API service with real-time streaming response (SSE), integrated with Silero VAD and Whisper models. +> +> This service provides a `/v1/audio/transcriptions` endpoint that is fully compatible with OpenAI's API format, supporting both standard JSON responses and streaming Server-Sent Events. + +## Getting started + +``` +git clone https://github.com/proj-airi/candle-examples.git +cd apps/silero-vad-whisper-realtime-api +``` + +## Build + +``` +cargo fetch --locked +cargo clean +``` + +### NVIDIA CUDA + +``` +cargo build --package silero-vad-whisper-realtime-api --features cuda +``` + +### macOS Metal + +``` +cargo build --package silero-vad-whisper-realtime-api --features metal +``` + +### CPU Only + +``` +cargo build --package silero-vad-whisper-realtime-api +``` + +## Run + +### Any platforms + +```shell +cargo run --package silero-vad-whisper-realtime-api --release +``` + +The server will start at `http://localhost:3000`. 
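+
+Once it is running, a quick sanity check is the health endpoint, which responds with a small JSON payload along the lines of the snippet below (see the `health_check` handler in `src/main.rs`):
+
+```bash
+curl http://localhost:3000/healthz
+# {"status":"ok","service":"ASR API","version":"1.0.0"}
+```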
+ +## Usage + +### Basic transcription + +```bash +curl -X POST http://localhost:3000/v1/audio/transcriptions \ + -F "file=@your_audio.wav" \ + -F "model=whisper-1" +``` + +### Streaming transcription + +```bash +curl -X POST "http://localhost:3000/v1/audio/transcriptions?stream=true" \ + -F "file=@your_audio.wav" \ + -F "model=whisper-1" \ + --no-buffer +``` + +## API Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `file` | File | ✅ | Audio file to transcribe | +| `model` | String | ❌ | Model name (default: "whisper-1") | +| `language` | String | ❌ | Audio language | +| `prompt` | String | ❌ | Prompt text | +| `response_format` | String | ❌ | Response format (default: "json") | +| `temperature` | Float | ❌ | Sampling temperature (default: 0.0) | +| `stream` | Boolean | ❌ | Enable streaming response (query parameter) | + +## Supported Audio Formats + +- WAV, MP3, FLAC, M4A +- Any format supported by Symphonia + +## Environment Variables + +```bash +# Set log level +export RUST_LOG=debug + +# Force CPU usage +export CANDLE_FORCE_CPU=1 +``` + +## Acknowledgements + +- [candle](https://github.com/huggingface/candle) - High-performance ML framework +- [axum](https://github.com/tokio-rs/axum) - Modern web framework +- [OpenAI](https://openai.com/) - API design reference +- [Silero VAD](https://github.com/snakers4/silero-vad) - Voice activity detection model diff --git a/apps/silero-vad-whisper-realtime-api/melfilters.bytes b/apps/silero-vad-whisper-realtime-api/melfilters.bytes new file mode 100644 index 0000000..0874829 Binary files /dev/null and b/apps/silero-vad-whisper-realtime-api/melfilters.bytes differ diff --git a/apps/silero-vad-whisper-realtime-api/melfilters128.bytes b/apps/silero-vad-whisper-realtime-api/melfilters128.bytes new file mode 100644 index 0000000..f287c5b Binary files /dev/null and b/apps/silero-vad-whisper-realtime-api/melfilters128.bytes differ diff --git a/apps/silero-vad-whisper-realtime-api/src/api.rs b/apps/silero-vad-whisper-realtime-api/src/api.rs new file mode 100644 index 0000000..a448d7b --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/api.rs @@ -0,0 +1,27 @@ +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct TranscriptionResponse { + pub text: String, +} + +#[derive(Debug, Serialize)] +pub struct StreamChunk { + pub text: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub timestamp: Option, +} + +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub error: ErrorDetail, +} + +#[derive(Debug, Serialize)] +pub struct ErrorDetail { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, + pub param: Option, + pub code: Option, +} diff --git a/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs b/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs new file mode 100644 index 0000000..ba4cd81 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/audio_manager.rs @@ -0,0 +1,116 @@ +use std::time::Instant; + +pub struct AudioBuffer { + buffer: Vec, + max_duration_samples: usize, + min_speech_duration_samples: usize, + min_silence_duration_samples: usize, + is_recording: bool, + silence_start: Option, + speech_start: Option, + samples_since_speech_start: usize, + samples_since_silence_start: usize, + sample_rate: usize, +} + +impl AudioBuffer { + pub const fn new( + max_duration_ms: u64, + min_speech_duration_ms: u64, + min_silence_duration_ms: u64, + sample_rate: u32, + ) -> Self { + let sample_rate = sample_rate as 
usize; + Self { + buffer: Vec::new(), + max_duration_samples: (max_duration_ms * sample_rate as u64 / 1000) as usize, + min_speech_duration_samples: (min_speech_duration_ms * sample_rate as u64 / 1000) as usize, + min_silence_duration_samples: (min_silence_duration_ms * sample_rate as u64 / 1000) as usize, + is_recording: false, + silence_start: None, + speech_start: None, + samples_since_speech_start: 0, + samples_since_silence_start: 0, + sample_rate, + } + } + + pub fn add_chunk( + &mut self, + chunk: &[f32], + is_speech: bool, + ) -> Option> { + if is_speech { + #[allow(clippy::if_not_else)] + if !self.is_recording { + if self.speech_start.is_none() { + self.speech_start = Some(Instant::now()); + self.samples_since_speech_start = 0; + } + + self.samples_since_speech_start += chunk.len(); + + if self.samples_since_speech_start >= self.min_speech_duration_samples { + self.is_recording = true; + self.silence_start = None; + self.samples_since_silence_start = 0; + println!("🚀 Started recording"); + } + } else { + // Reset silence tracking + self.silence_start = None; + self.samples_since_silence_start = 0; + } + } else { + // Reset speech tracking + self.speech_start = None; + self.samples_since_speech_start = 0; + + if self.is_recording { + if self.silence_start.is_none() { + self.silence_start = Some(Instant::now()); + self.samples_since_silence_start = 0; + } + + self.samples_since_silence_start += chunk.len(); + + if self.samples_since_silence_start >= self.min_silence_duration_samples { + // End of speech detected + if !self.buffer.is_empty() { + let result = self.buffer.clone(); + self.reset(); + #[allow(clippy::cast_precision_loss)] + let duration_secs = result.len() as f32 / self.sample_rate as f32; + println!("🔇 Stopped recording, {duration_secs:.2}s"); + return Some(result); + } + + self.reset(); + } + } + } + + if self.is_recording { + self.buffer.extend_from_slice(chunk); + + // Check if buffer exceeds max duration + if self.buffer.len() >= self.max_duration_samples { + let result = self.buffer.clone(); + self.reset(); + println!("⏰ Max duration reached, {} samples", result.len()); + return Some(result); + } + } + + None + } + + fn reset(&mut self) { + self.buffer.clear(); + self.is_recording = false; + self.silence_start = None; + self.speech_start = None; + self.samples_since_speech_start = 0; + self.samples_since_silence_start = 0; + } +} diff --git a/apps/silero-vad-whisper-realtime-api/src/main.rs b/apps/silero-vad-whisper-realtime-api/src/main.rs new file mode 100644 index 0000000..11d5819 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/main.rs @@ -0,0 +1,162 @@ +use std::{collections::HashMap, sync::Arc}; + +use anyhow::Result; +use axum::{ + Json, + Router, + response::IntoResponse, + routing::{get, post}, +}; +use candle_core::Device; +use tokio::sync::{Mutex, RwLock}; +use tower::ServiceBuilder; +use tower_http::cors::CorsLayer; + +use crate::{ + router::transcribe_audio, + vad::VADProcessor, + whisper::{WhichWhisperModel, WhisperProcessor}, +}; + +mod api; +mod audio_manager; +mod router; +mod vad; +mod whisper; + +// Application state with dynamic model loading +struct AppState { + vad: Arc>, + device: Device, + // Use RwLock for read-heavy workload (checking cache) + whisper_models: Arc>>>>, +} + +impl AppState { + #[allow(clippy::unused_async)] + async fn new() -> Result { + // Determine device to use + let device = if std::env::var("CANDLE_FORCE_CPU").is_ok() { + candle_core::Device::Cpu + } else if candle_core::utils::cuda_is_available() { + 
candle_core::Device::new_cuda(0)? + } else if candle_core::utils::metal_is_available() { + candle_core::Device::new_metal(0)? + } else { + candle_core::Device::Cpu + }; + + println!("🚀 Using device: {device:?}"); + + // Get VAD threshold from environment or use default + let vad_threshold = std::env::var("VAD_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(0.3); + + println!("🎯 VAD threshold: {vad_threshold}"); + + // Initialize VAD processor (always use CPU for VAD) + let vad = VADProcessor::new(candle_core::Device::Cpu, vad_threshold)?; + + Ok(Self { + vad: Arc::new(Mutex::new(vad)), + device, + whisper_models: Arc::new(RwLock::new(HashMap::new())), + }) + } + + // Get or create Whisper processor for the specified model + pub async fn get_whisper_processor( + &self, + model_name: &str, + ) -> Result>> { + // First, try to read from cache + { + let models = self.whisper_models.read().await; + if let Some(processor) = models.get(model_name) { + println!("🔄 Using cached Whisper model: {model_name}"); + return Ok(processor.clone()); + } + } + + // If not in cache, create new model with timing + let loading_start = std::time::Instant::now(); + println!("🧠 Loading new Whisper model: {model_name}"); + + let whisper_model = Self::parse_model_name(model_name)?; + let processor = Arc::new(Mutex::new(WhisperProcessor::new(whisper_model, self.device.clone())?)); + + let loading_time = loading_start.elapsed(); + + // Add to cache + { + let mut models = self.whisper_models.write().await; + models.insert(model_name.to_string(), processor.clone()); + } + + println!("✅ Whisper model loaded and cached: {} ({:.2}ms)", model_name, loading_time.as_secs_f64() * 1000.0); + Ok(processor) + } + + // Parse model name string to WhichWhisperModel enum + fn parse_model_name(model_name: &str) -> Result { + match model_name.to_lowercase().as_str() { + "tiny" => Ok(WhichWhisperModel::Tiny), + "tiny.en" => Ok(WhichWhisperModel::TinyEn), + "base" => Ok(WhichWhisperModel::Base), + "base.en" => Ok(WhichWhisperModel::BaseEn), + "small" => Ok(WhichWhisperModel::Small), + "small.en" => Ok(WhichWhisperModel::SmallEn), + "medium" => Ok(WhichWhisperModel::Medium), + "medium.en" => Ok(WhichWhisperModel::MediumEn), + "large" => Ok(WhichWhisperModel::Large), + "large-v2" => Ok(WhichWhisperModel::LargeV2), + "large-v3" => Ok(WhichWhisperModel::LargeV3), + "large-v3-turbo" => Ok(WhichWhisperModel::LargeV3Turbo), + "distil-medium.en" => Ok(WhichWhisperModel::DistilMediumEn), + "distil-large-v2" => Ok(WhichWhisperModel::DistilLargeV2), + _ => anyhow::bail!("Unsupported Whisper model: {}. 
Supported models: tiny, base, small, medium, large, large-v2, large-v3", model_name), + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize tracing + tracing_subscriber::fmt::init(); + + // Initialize application state + let state = AppState::new().await?; + + // Build application routes + let app = Router::new() + .route("/healthz", get(health_check)) + .route("/v1/audio/transcriptions", post(transcribe_audio)) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) + .into_inner(), + ) + .with_state(Arc::new(state)); + + // Start server + // TODO: use `PORT` as port from environment variables + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; + println!("🚀 ASR API server running on http://0.0.0.0:3000"); + println!("📝 Available endpoints:"); + println!(" GET /healthz - Health check"); + println!(" POST /v1/audio/transcriptions - Audio transcription (OpenAI compatible)"); + + axum::serve(listener, app).await?; + Ok(()) +} + +// Health check endpoint +async fn health_check() -> impl IntoResponse { + Json(serde_json::json!({ + "status": "ok", + "service": "ASR API", + "version": "1.0.0" + })) +} diff --git a/apps/silero-vad-whisper-realtime-api/src/router.rs b/apps/silero-vad-whisper-realtime-api/src/router.rs new file mode 100644 index 0000000..40ce4c2 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/router.rs @@ -0,0 +1,404 @@ +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, Instant}, +}; + +use anyhow::Result; +use axum::{ + Json, + extract::{Multipart, State}, + http::StatusCode, + response::{ + IntoResponse, + Response, + sse::{Event, KeepAlive, Sse}, + }, +}; +use futures::stream::{self, Stream}; +use symphonia::{ + core::{ + audio::{AudioBufferRef, Signal}, + codecs::DecoderOptions, + formats::FormatOptions, + io::{MediaSourceStream, MediaSourceStreamOptions}, + meta::MetadataOptions, + probe::Hint, + }, + default::get_probe, +}; + +use crate::{ + AppState, + api::{ErrorDetail, ErrorResponse, StreamChunk, TranscriptionResponse}, + audio_manager::AudioBuffer, +}; + +// Performance statistics struct +#[derive(Debug)] +struct ProcessingStats { + total_duration: Duration, + audio_conversion_duration: Duration, + model_loading_duration: Duration, + vad_processing_duration: Duration, + whisper_transcription_duration: Duration, + audio_length_seconds: f32, +} + +impl ProcessingStats { + const fn new() -> Self { + Self { + total_duration: Duration::ZERO, + audio_conversion_duration: Duration::ZERO, + model_loading_duration: Duration::ZERO, + vad_processing_duration: Duration::ZERO, + whisper_transcription_duration: Duration::ZERO, + audio_length_seconds: 0.0, + } + } + + fn print_summary(&self) { + println!("📊 Processing Statistics:"); + println!(" 📁 Audio conversion: {:.2}ms", self.audio_conversion_duration.as_secs_f64() * 1000.0); + println!(" 🧠 Model loading: {:.2}ms", self.model_loading_duration.as_secs_f64() * 1000.0); + println!(" 🎯 VAD processing: {:.2}ms", self.vad_processing_duration.as_secs_f64() * 1000.0); + println!(" 🗣️ Whisper transcription: {:.2}ms", self.whisper_transcription_duration.as_secs_f64() * 1000.0); + println!(" ⏱️ Total processing: {:.2}ms", self.total_duration.as_secs_f64() * 1000.0); + println!(" 🎵 Audio length: {:.2}s", self.audio_length_seconds); + if self.audio_length_seconds > 0.0 { + let real_time_factor = self.total_duration.as_secs_f64() / f64::from(self.audio_length_seconds); + println!(" ⚡ Real-time factor: {real_time_factor:.2}x"); + } + } +} + 
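+// Overview of the handler below: the multipart form is split into the `file`
+// bytes and the remaining text fields, the audio is decoded to mono PCM with
+// Symphonia (assumed to already be 16 kHz), fed through Silero VAD in
+// 512-sample frames, and every speech segment emitted by `AudioBuffer` is
+// transcribed with the selected Whisper model; the result is returned either
+// as a single JSON body or as an SSE stream, depending on the `stream` field.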
+#[allow(clippy::cast_precision_loss)] +pub async fn transcribe_audio( + State(state): State>, + mut multipart: Multipart, +) -> Result)> { + let start_time = Instant::now(); + let mut processing_stats = ProcessingStats::new(); + + // Extract both audio file and parameters from multipart form + let (audio_data, params) = match extract_multipart_data(&mut multipart).await { + Ok(data) => data, + Err(e) => { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to extract form data: {e}"), + error_type: "invalid_request_error".to_string(), + param: Some("form".to_string()), + code: None, + }, + }), + )); + }, + }; + + println!("Request params: {params:?}"); + + // Parse streaming parameter from form data + let stream_enabled = params + .get("stream") + .is_some_and(|s| s.parse::().unwrap_or(false)); + + // Get model name from parameters and clone it to make it owned + let model_name = params + .get("model").cloned() // Clone to make it owned + .unwrap_or_else(|| "tiny".to_string()); // Use tiny as default + + println!("Using model: {model_name}, streaming: {stream_enabled}"); + + // Convert audio to PCM format with timing + let conversion_start = Instant::now(); + let pcm_data = match convert_audio_to_pcm(&audio_data).await { + Ok(data) => data, + Err(e) => { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to process audio file: {e}"), + error_type: "invalid_request_error".to_string(), + param: Some("file".to_string()), + code: None, + }, + }), + )); + }, + }; + + processing_stats.audio_conversion_duration = conversion_start.elapsed(); + processing_stats.audio_length_seconds = pcm_data.len() as f32 / 16000.0; // Assuming 16kHz sample rate + + println!("Audio data length: {} samples ({:.2}s)", pcm_data.len(), processing_stats.audio_length_seconds); + + if stream_enabled { + // Return streaming response + let stream = create_transcription_stream(state, model_name, pcm_data, processing_stats).await?; + let sse = Sse::new(stream).keep_alive(KeepAlive::default()); + Ok(sse.into_response()) + } else { + // Return single response + match transcribe_audio_complete(state, model_name, pcm_data, &mut processing_stats).await { + Ok(transcript) => { + processing_stats.total_duration = start_time.elapsed(); + processing_stats.print_summary(); + Ok(Json(TranscriptionResponse { text: transcript }).into_response()) + }, + Err(e) => { + processing_stats.total_duration = start_time.elapsed(); + processing_stats.print_summary(); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: ErrorDetail { + message: format!("Transcription failed: {e}"), + error_type: "server_error".to_string(), + param: None, + code: None, + }, + }), + )) + }, + } + } +} + +// Extract both audio file and parameters from multipart form data +async fn extract_multipart_data(multipart: &mut Multipart) -> Result<(Vec, HashMap)> { + let mut audio_data = None; + let mut params = HashMap::new(); + + while let Some(field) = multipart.next_field().await? 
{ + if let Some(name) = field.name() { + let name = name.to_string(); // Clone the name first to avoid borrow conflict + if name == "file" { + // Extract audio file + let data = field.bytes().await?; + audio_data = Some(data.to_vec()); + } else { + // Extract form parameters + let value = field.text().await?; + params.insert(name, value); + } + } + } + + let audio = audio_data.ok_or_else(|| anyhow::anyhow!("No file field found in multipart data"))?; + Ok((audio, params)) +} + +// Convert various audio formats to PCM +#[allow(clippy::unused_async)] +async fn convert_audio_to_pcm(audio_data: &[u8]) -> Result> { + let cursor = std::io::Cursor::new(audio_data.to_vec()); + let media_source = MediaSourceStream::new(Box::new(cursor), MediaSourceStreamOptions::default()); + + let mut hint = Hint::new(); + hint.mime_type("audio/wav"); // You might want to detect this automatically + + let meta_opts = MetadataOptions::default(); + let fmt_opts = FormatOptions::default(); + + let probed = get_probe().format(&hint, media_source, &fmt_opts, &meta_opts)?; + + let mut format = probed.format; + let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL) + .ok_or_else(|| anyhow::anyhow!("No audio track found"))?; + + let dec_opts = DecoderOptions::default(); + let mut decoder = symphonia::default::get_codecs().make(&track.codec_params, &dec_opts)?; + + let track_id = track.id; + let mut pcm_data = Vec::new(); + + // Decode the audio + while let Ok(packet) = format.next_packet() { + if packet.track_id() != track_id { + continue; + } + + match decoder.decode(&packet)? { + AudioBufferRef::F32(buf) => { + for &sample in buf.chan(0) { + pcm_data.push(sample); + } + }, + AudioBufferRef::S16(buf) => { + for &sample in buf.chan(0) { + pcm_data.push(f32::from(sample) / f32::from(i16::MAX)); + } + }, + AudioBufferRef::S32(buf) => { + for &sample in buf.chan(0) { + #[allow(clippy::cast_precision_loss)] + pcm_data.push(sample as f32 / i32::MAX as f32); + } + }, + _ => { + anyhow::bail!("Unsupported audio format"); + }, + } + } + + Ok(pcm_data) +} + +// Process complete audio file and return full transcript +#[allow(clippy::significant_drop_tightening)] +async fn transcribe_audio_complete( + state: Arc, + model_name: String, // Change to owned String + audio_data: Vec, + processing_stats: &mut ProcessingStats, +) -> Result { + let sample_rate = 16000; + + // Get the appropriate Whisper processor for this model with timing + let model_loading_start = Instant::now(); + let whisper_processor = state.get_whisper_processor(&model_name).await?; + processing_stats.model_loading_duration = model_loading_start.elapsed(); + + // Process audio through VAD and Whisper + let mut vad = state.vad.lock().await; + let mut whisper = whisper_processor.lock().await; + let mut audio_buffer = AudioBuffer::new(10000, 100, 500, sample_rate); + + let mut transcripts = Vec::new(); + let mut frame_buffer = Vec::::new(); + + let vad_start = Instant::now(); + let mut whisper_total_time = Duration::ZERO; + + // Process in chunks + for chunk in audio_data.chunks(1024) { + frame_buffer.extend_from_slice(chunk); + + // Process 512-sample frames + while frame_buffer.len() >= 512 { + let frame: Vec = frame_buffer.drain(..512).collect(); + let speech_prob = vad.process_chunk(&frame)?; + let is_speech = vad.is_speech(speech_prob); + + if let Some(complete_audio) = audio_buffer.add_chunk(&frame, is_speech) { + // Measure Whisper transcription time + let whisper_start = Instant::now(); + let 
transcript = whisper.transcribe(&complete_audio)?; + whisper_total_time += whisper_start.elapsed(); + + if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") { + transcripts.push(transcript.trim().to_string()); + } + } + } + } + + processing_stats.vad_processing_duration = vad_start.elapsed() - whisper_total_time; + processing_stats.whisper_transcription_duration = whisper_total_time; + + Ok(transcripts.join(" ")) +} + +// Create streaming transcription response +async fn create_transcription_stream( + state: Arc, + model_name: String, // Change to owned String + audio_data: Vec, + mut processing_stats: ProcessingStats, +) -> Result>, (StatusCode, Json)> { + let stream_start = Instant::now(); + + // Get the appropriate Whisper processor for this model with timing + let model_loading_start = Instant::now(); + let whisper_processor = match state.get_whisper_processor(&model_name).await { + Ok(processor) => processor, + Err(e) => { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: ErrorDetail { + message: format!("Failed to load model '{model_name}': {e}"), + error_type: "invalid_request_error".to_string(), + param: Some("model".to_string()), + code: None, + }, + }), + )); + }, + }; + processing_stats.model_loading_duration = model_loading_start.elapsed(); + + let sample_rate = 16000; + + Ok(stream::unfold((state, whisper_processor, audio_data, 0, AudioBuffer::new(10000, 100, 500, sample_rate), processing_stats, stream_start), move |(state, whisper_processor, audio_data, mut processed, mut audio_buffer, mut stats, stream_start)| async move { + if processed >= audio_data.len() { + // Print final statistics for streaming + stats.total_duration = stream_start.elapsed(); + stats.print_summary(); + return None; + } + + // Process audio in chunks suitable for VAD (512 samples at a time) + let chunk_size = 512.min(audio_data.len() - processed); + let chunk = &audio_data[processed..processed + chunk_size]; + processed += chunk_size; + + // Process through VAD and Whisper processors + let mut whisper_result = None; + + // Process through VAD + let vad_chunk_start = Instant::now(); + let mut vad = state.vad.lock().await; + if let Ok(speech_prob) = vad.process_chunk(chunk) { + let is_speech = vad.is_speech(speech_prob); + + // Add to audio buffer and check if we have complete audio + if let Some(complete_audio) = audio_buffer.add_chunk(chunk, is_speech) { + // Release VAD lock before acquiring Whisper lock + drop(vad); + let vad_chunk_time = vad_chunk_start.elapsed(); + stats.vad_processing_duration += vad_chunk_time; + + // Process complete audio through Whisper + let whisper_chunk_start = Instant::now(); + let mut whisper = whisper_processor.lock().await; + if let Ok(transcript) = whisper.transcribe(&complete_audio) { + let whisper_chunk_time = whisper_chunk_start.elapsed(); + stats.whisper_transcription_duration += whisper_chunk_time; + + if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") { + whisper_result = Some(transcript.trim().to_string()); + println!("🎯 Chunk transcribed in {:.2}ms: \"{}\"", whisper_chunk_time.as_secs_f64() * 1000.0, transcript.trim()); + } + } + } + } else { + stats.vad_processing_duration += vad_chunk_start.elapsed(); + } + + // Create event with actual transcription or progress update + #[allow(clippy::option_if_let_else)] + let event_data = if let Some(transcript) = whisper_result { + #[allow(clippy::cast_precision_loss)] + StreamChunk { text: transcript, timestamp: Some(processed as f64 / 
f64::from(sample_rate)) } + } else { + StreamChunk { + #[allow(clippy::cast_precision_loss)] + text: format!("Processing... ({:.1}%)", (processed as f64 / audio_data.len() as f64) * 100.0), + #[allow(clippy::cast_precision_loss)] + timestamp: Some(processed as f64 / f64::from(sample_rate)), + } + }; + + let event = Event::default().json_data(event_data).unwrap(); + + Some((Ok(event), (state.clone(), whisper_processor.clone(), audio_data, processed, audio_buffer, stats, stream_start))) + })) +} diff --git a/apps/silero-vad-whisper-realtime-api/src/vad.rs b/apps/silero-vad-whisper-realtime-api/src/vad.rs new file mode 100644 index 0000000..01a1d5b --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/vad.rs @@ -0,0 +1,85 @@ +use std::collections::HashMap; + +use anyhow::Result; +use candle_core::{DType, Device, Tensor}; +use candle_onnx::simple_eval; + +pub struct VADProcessor { + model: candle_onnx::onnx::ModelProto, + frame_size: usize, + context_size: usize, + sample_rate: Tensor, + state: Tensor, + context: Tensor, + device: Device, + threshold: f32, +} + +impl VADProcessor { + pub fn new( + device: Device, + threshold: f32, + ) -> Result { + let api = hf_hub::api::sync::Api::new()?; + let model_path = api + .model("onnx-community/silero-vad".into()) + .get("onnx/model.onnx")?; + + let model = candle_onnx::read_file(model_path)?; + + let sample_rate_value = 16000i64; + let (frame_size, context_size) = (512, 64); + + Ok(Self { + model, + frame_size, + context_size, + sample_rate: Tensor::new(sample_rate_value, &device)?, + state: Tensor::zeros((2, 1, 128), DType::F32, &device)?, + context: Tensor::zeros((1, context_size), DType::F32, &device)?, + device, + threshold, + }) + } + + pub fn process_chunk( + &mut self, + chunk: &[f32], + ) -> Result { + if chunk.len() != self.frame_size { + return Ok(0.0); + } + + let next_context = Tensor::from_slice(&chunk[self.frame_size - self.context_size..], (1, self.context_size), &self.device)?; + let chunk_tensor = Tensor::from_vec(chunk.to_vec(), (1, self.frame_size), &self.device)?; + + let input = Tensor::cat(&[&self.context, &chunk_tensor], 1)?; + let inputs: HashMap = HashMap::from_iter([("input".to_string(), input), ("sr".to_string(), self.sample_rate.clone()), ("state".to_string(), self.state.clone())]); + + let outputs = simple_eval(&self.model, inputs)?; + let graph = self.model.graph.as_ref().unwrap(); + let out_names = &graph.output; + + let output = outputs + .get(&out_names[0].name) + .ok_or_else(|| anyhow::anyhow!("Missing VAD output tensor: {}", &out_names[0].name))? + .clone(); + + self.state = outputs + .get(&out_names[1].name) + .ok_or_else(|| anyhow::anyhow!("Missing VAD state tensor: {}", &out_names[1].name))? 
+ .clone(); + + self.context = next_context; + + let speech_prob = output.flatten_all()?.to_vec1::()?[0]; + Ok(speech_prob) + } + + pub fn is_speech( + &self, + prob: f32, + ) -> bool { + prob >= self.threshold + } +} diff --git a/apps/silero-vad-whisper-realtime-api/src/whisper.rs b/apps/silero-vad-whisper-realtime-api/src/whisper.rs new file mode 100644 index 0000000..6957e63 --- /dev/null +++ b/apps/silero-vad-whisper-realtime-api/src/whisper.rs @@ -0,0 +1,206 @@ +use anyhow::Result; +use byteorder::{ByteOrder, LittleEndian}; +use candle_core::{Device, IndexOp, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::whisper::{self as whisper_model, Config, audio}; +use clap::ValueEnum; +use hf_hub::{Repo, RepoType, api::sync::Api}; +use tokenizers::Tokenizer; + +pub enum WhisperModel { + Normal(whisper_model::model::Whisper), +} + +impl WhisperModel { + pub fn encoder_forward( + &mut self, + x: &Tensor, + flush: bool, + ) -> candle_core::Result { + match self { + Self::Normal(model) => model.encoder.forward(x, flush), + } + } + + pub fn decoder_forward( + &mut self, + x: &Tensor, + encoder_out: &Tensor, + flush: bool, + ) -> candle_core::Result { + match self { + Self::Normal(model) => model.decoder.forward(x, encoder_out, flush), + } + } + + pub fn decoder_final_linear( + &self, + x: &Tensor, + ) -> candle_core::Result { + match self { + Self::Normal(model) => model.decoder.final_linear(x), + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)] +pub enum WhichWhisperModel { + Tiny, + #[value(name = "tiny.en")] + TinyEn, + Base, + #[value(name = "base.en")] + BaseEn, + Small, + #[value(name = "small.en")] + SmallEn, + Medium, + #[value(name = "medium.en")] + MediumEn, + Large, + LargeV2, + LargeV3, + LargeV3Turbo, + #[value(name = "distil-medium.en")] + DistilMediumEn, + #[value(name = "distil-large-v2")] + DistilLargeV2, +} + +impl WhichWhisperModel { + pub const fn model_and_revision(self) -> (&'static str, &'static str) { + match self { + Self::Tiny => ("openai/whisper-tiny", "main"), + Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"), + Self::Base => ("openai/whisper-base", "refs/pr/22"), + Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"), + Self::Small => ("openai/whisper-small", "main"), + Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"), + Self::Medium => ("openai/whisper-medium", "main"), + Self::MediumEn => ("openai/whisper-medium.en", "main"), + Self::Large => ("openai/whisper-large", "refs/pr/36"), + Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"), + Self::LargeV3 => ("openai/whisper-large-v3", "main"), + Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"), + Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"), + Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"), + } + } +} + +pub struct WhisperProcessor { + pub model: WhisperModel, + pub tokenizer: Tokenizer, + pub config: Config, + pub mel_filters: Vec, + pub device: Device, +} + +impl WhisperProcessor { + pub fn new( + model: WhichWhisperModel, + device: Device, + ) -> Result { + // Load the Whisper model based on the provided model type + let api = Api::new()?; + let (model_id, revision) = model.model_and_revision(); + let repo = api.repo(Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string())); + + let config_filename = repo.get("config.json")?; + let tokenizer_filename = repo.get("tokenizer.json")?; + let model_filename = repo.get("model.safetensors")?; + + let config: Config = 
serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?; + + let name = model_filename.display(); + println!("Loading Whisper model from: {name:?}"); + + let var_builder = unsafe { VarBuilder::from_mmaped_safetensors(&[model_filename], whisper_model::DTYPE, &device)? }; + + let model = WhisperModel::Normal(whisper_model::model::Whisper::load(&var_builder, config.clone())?); + + let mel_bytes = match config.num_mel_bins { + 80 => include_bytes!("../melfilters.bytes").as_slice(), + 128 => include_bytes!("../melfilters128.bytes").as_slice(), + num_mel_bins => anyhow::bail!("Unsupported number of mel bins: {}", num_mel_bins), + }; + + let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; + ::read_f32_into(mel_bytes, &mut mel_filters); + + Ok(Self { model, tokenizer, config, mel_filters, device }) + } + + pub fn transcribe( + &mut self, + audio: &[f32], + ) -> Result { + // Convert PCM to mel spectrogram + let mel = audio::pcm_to_mel(&self.config, audio, &self.mel_filters); + let mel_len = mel.len(); + let mel = Tensor::from_vec(mel, (1, self.config.num_mel_bins, mel_len / self.config.num_mel_bins), &self.device)?; + + // Run inference + let audio_features = self.model.encoder_forward(&mel, true)?; + // Simple greedy decoding + let tokens = self.decode_greedy(&audio_features)?; + + let text = self + .tokenizer + .decode(&tokens, true) + .map_err(anyhow::Error::msg)?; + + Ok(text) + } + + fn decode_greedy( + &mut self, + audio_features: &Tensor, + ) -> Result> { + let mut tokens = vec![self.token_id(whisper_model::SOT_TOKEN)?, self.token_id(whisper_model::TRANSCRIBE_TOKEN)?, self.token_id(whisper_model::NO_TIMESTAMPS_TOKEN)?]; + + let max_len = 50; // Short sequence for real-time processing + + for i in 0..max_len { + let tokens_t = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; + let ys = self + .model + .decoder_forward(&tokens_t, audio_features, i == 0)?; + + let (_, seq_len, _) = ys.dims3()?; + let logits = self + .model + .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)? + .i(0)? + .i(0)?; + + // Get most likely token + let logits_v: Vec = logits.to_vec1()?; + let next_token = logits_v + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.total_cmp(b)) + .map(|(i, _)| u32::try_from(i).unwrap()) + .unwrap(); + + if next_token == self.token_id(whisper_model::EOT_TOKEN)? { + break; + } + + tokens.push(next_token); + } + + Ok(tokens) + } + + fn token_id( + &self, + token: &str, + ) -> Result { + self + .tokenizer + .token_to_id(token) + .ok_or_else(|| anyhow::anyhow!("Token not found: {}", token)) + } +} diff --git a/rustfmt.toml b/rustfmt.toml index 0efea85..8644cdd 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -15,7 +15,7 @@ match_block_trailing_comma = true # overflow_delimited_expr = true # nightly # reorder_impl_items = true # nightly # spaces_around_ranges = true # nightly -# unstable_features = true # nightly +unstable_features = true # nightly use_field_init_shorthand = true use_try_shorthand = true struct_field_align_threshold = 999