diff --git a/Cargo.lock b/Cargo.lock index 9a31e1a..82f0e45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -740,12 +740,95 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + [[package]] name = "futures-core" version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -1283,6 +1366,21 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndarray" version = "0.17.1" @@ -1574,7 +1672,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5df903c0d2c07b56950f1058104ab0c8557159f2741782223704de9be73c3c" dependencies = [ "libloading 0.9.0", - "ndarray", + "ndarray 0.17.1", "ort-sys", "smallvec", "tracing", @@ -1600,7 +1698,7 @@ checksum = "a7667842fd2f3b97b029a30fb9a00138867c6915229f5acd6bd809d08250d2ee" dependencies = [ "eyre", "hound", - "ndarray", + "ndarray 0.17.1", "ort", "rustfft", "serde", @@ -1663,12 +1761,38 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "pin-project" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkg-config" version = "0.3.32" @@ -2121,6 +2245,12 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + [[package]] name = "smallvec" version = "1.15.1" @@ -2493,6 +2623,26 @@ dependencies = [ "strength_reduce", ] +[[package]] +name = "typed-builder" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd9d30e3a08026c78f246b173243cf07b3696d274debd26680773b6773c2afc7" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c36781cc0e46a83726d9879608e4cf6c2505237e263a8eb8c24502989cfdb28" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.22" @@ -2622,6 +2772,20 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "voice_activity_detector" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a78926100321e74c3d74d389e875f26f395b71528f8ea8b7b505c7f31063dcf3" +dependencies = [ + "futures", + "ndarray 0.16.1", + "ort", + "pin-project", + "thiserror 2.0.17", + "typed-builder", +] + [[package]] name = "voxtype" version = "0.5.2" @@ -2652,6 +2816,7 @@ dependencies = [ "tracing", "tracing-subscriber", "ureq 2.12.1", + "voice_activity_detector", "which", "whisper-rs", ] diff --git a/Cargo.toml b/Cargo.toml index 62e9cef..022f46b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,9 @@ whisper-rs = "0.15.1" # Parakeet speech-to-text (optional, ONNX-based) parakeet-rs = { version = "0.2.9", optional = true } +# Voice Activity Detection for Parakeet (optional, uses bundled Silero VAD) +voice_activity_detector = { version = "0.2", optional = true } + # CPU count for thread detection num_cpus = "1.16" @@ -86,7 +89,7 @@ gpu-cuda = ["whisper-rs/cuda"] gpu-metal = ["whisper-rs/metal"] gpu-hipblas = ["whisper-rs/hipblas"] # Parakeet backend (ONNX-based, alternative to Whisper) -parakeet = ["dep:parakeet-rs"] +parakeet = ["dep:parakeet-rs", "dep:voice_activity_detector"] parakeet-cuda = ["parakeet", "parakeet-rs/cuda"] parakeet-tensorrt = ["parakeet", "parakeet-rs/tensorrt"] parakeet-rocm = ["parakeet", "parakeet-rs/rocm"] diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 6788919..3392cc3 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -1590,6 +1590,103 @@ If Whisper transcribes "vox type" (or "Vox Type"), it will be replaced with "vox --- +## [vad] + +Voice Activity Detection settings. VAD filters silence-only recordings before transcription, preventing Whisper hallucinations when processing silence. + +### enabled + +**Type:** Boolean +**Default:** `false` +**Required:** No + +Enable Voice Activity Detection. When enabled, recordings with no detected speech are rejected before transcription, and the "Cancelled" audio feedback is played. + +**Example:** +```toml +[vad] +enabled = true +``` + +**CLI override:** +```bash +voxtype --vad daemon +``` + +### threshold + +**Type:** Float +**Default:** `0.5` +**Required:** No + +Speech detection threshold from 0.0 to 1.0. Higher values require more confident speech detection (stricter), lower values are more permissive. + +**Example:** +```toml +[vad] +enabled = true +threshold = 0.6 # More strict, may reject soft speech +``` + +**CLI override:** +```bash +voxtype --vad --vad-threshold 0.3 daemon # More permissive +``` + +### min_speech_duration_ms + +**Type:** Integer +**Default:** `100` +**Required:** No + +Minimum speech duration in milliseconds. Recordings with less detected speech than this are rejected. Helps filter out brief noise spikes. + +**Example:** +```toml +[vad] +enabled = true +min_speech_duration_ms = 200 # Require at least 200ms of speech +``` + +### model + +**Type:** String (path) +**Default:** Auto-detected based on engine +**Required:** No + +Path to a custom VAD model file. If not set, uses the default model location (`~/.local/share/voxtype/models/`). + +- **Whisper engine:** Uses `ggml-silero-vad.bin` (GGML format) +- **Parakeet engine:** Uses bundled Silero model (no external file needed) + +**Example:** +```toml +[vad] +enabled = true +model = "/custom/path/to/vad-model.bin" +``` + +### Setup + +Download the VAD model before enabling: + +```bash +voxtype setup vad +``` + +This downloads the appropriate model for your configured transcription engine. + +### Example Configuration + +```toml +[vad] +enabled = true +threshold = 0.5 +min_speech_duration_ms = 100 +``` + +--- + ## [status] Controls status display icons for Waybar and other tray integrations. @@ -1770,6 +1867,8 @@ Most configuration options can be overridden via command line: | whisper.model | `--model` | | output.mode = "clipboard" | `--clipboard` | | output.mode = "paste" | `--paste` | +| vad.enabled | `--vad` | +| vad.threshold | `--vad-threshold` | | status.icon_theme | `--icon-theme` (status subcommand) | | Verbosity | `-v`, `-vv`, `-q` | diff --git a/docs/USER_MANUAL.md b/docs/USER_MANUAL.md index ce7310a..d7614a9 100644 --- a/docs/USER_MANUAL.md +++ b/docs/USER_MANUAL.md @@ -14,6 +14,7 @@ Voxtype is a push-to-talk voice-to-text tool for Linux. Optimized for Wayland, w - [Transcription Engines](#transcription-engines) - [Multi-Model Support](#multi-model-support) - [Improving Transcription Accuracy](#improving-transcription-accuracy) +- [Voice Activity Detection](#voice-activity-detection) - [Whisper Models](#whisper-models) - [Remote Whisper Servers](#remote-whisper-servers) - [CLI Backend (whisper-cli)](#cli-backend-whisper-cli) @@ -694,6 +695,82 @@ voxtype --initial-prompt "Discussion about Terraform and AWS Lambda" daemon --- +## Voice Activity Detection + +Voice Activity Detection (VAD) filters silence-only recordings before sending them to Whisper. This prevents Whisper "hallucinations" where it generates text from silence, such as "(music playing)" or random phrases. + +### When to Use VAD + +Enable VAD if you experience: + +- **Hallucinations from silence**: Whisper generating text when you didn't speak +- **Accidental recordings**: You press the hotkey but don't say anything +- **Noisy environments**: Where brief non-speech sounds trigger recordings + +VAD is disabled by default to preserve existing behavior. It's opt-in. + +### Setup + +First, download the VAD model: + +```bash +voxtype setup vad +``` + +Then enable VAD in your config: + +```toml +[vad] +enabled = true +``` + +Or use the CLI flag for a single session: + +```bash +voxtype --vad daemon +``` + +### How It Works + +After you release the push-to-talk key: + +1. Voxtype captures the audio +2. VAD analyzes the audio for speech content +3. If speech is detected, proceeds to transcription +4. If no speech is detected, plays the "Cancelled" audio feedback and returns to idle + +### Configuration Options + +```toml +[vad] +enabled = true +threshold = 0.5 # 0.0-1.0, higher = stricter +min_speech_duration_ms = 100 # Minimum speech required +``` + +**threshold**: How confident VAD must be that speech is present. Lower values (0.3) are more permissive, higher values (0.7) are stricter but may reject quiet speech. + +**min_speech_duration_ms**: Minimum duration of detected speech required. Helps filter out brief noise spikes. + +### CLI Overrides + +```bash +# Enable VAD for this session +voxtype --vad daemon + +# Enable with custom threshold +voxtype --vad --vad-threshold 0.3 daemon +``` + +### Backends + +VAD uses different backends depending on your transcription engine: + +- **Whisper engine**: Uses whisper-rs built-in VAD (GGML format, requires model download) +- **Parakeet engine**: Uses bundled Silero VAD (no separate download needed) + +--- + ## Whisper Models ### Model Comparison diff --git a/src/cli.rs b/src/cli.rs index 23e6c63..70548a9 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -94,6 +94,14 @@ pub struct Cli { #[arg(long, value_name = "DRIVERS")] pub driver: Option, + /// Enable Voice Activity Detection to filter silence-only recordings + #[arg(long)] + pub vad: bool, + + /// VAD speech detection threshold (0.0-1.0, requires --vad) + #[arg(long, value_name = "THRESHOLD")] + pub vad_threshold: Option, + #[command(subcommand)] pub command: Option, } @@ -435,6 +443,13 @@ pub enum SetupAction { status: bool, }, + /// Download Voice Activity Detection model + Vad { + /// Force re-download even if model exists + #[arg(long)] + force: bool, + }, + /// Compositor integration (fixes modifier key interference) Compositor { #[command(subcommand)] @@ -797,7 +812,10 @@ mod tests { let cli = Cli::parse_from(["voxtype", "record", "start", "--file=out.txt"]); match cli.command { Some(Commands::Record { action }) => { - assert_eq!(action.output_mode_override(), Some(OutputModeOverride::File)); + assert_eq!( + action.output_mode_override(), + Some(OutputModeOverride::File) + ); assert_eq!(action.file_path(), Some("out.txt")); } _ => panic!("Expected Record command"), @@ -821,7 +839,10 @@ mod tests { let cli = Cli::parse_from(["voxtype", "record", "start", "--file"]); match cli.command { Some(Commands::Record { action }) => { - assert_eq!(action.output_mode_override(), Some(OutputModeOverride::File)); + assert_eq!( + action.output_mode_override(), + Some(OutputModeOverride::File) + ); assert_eq!(action.file_path(), Some("")); // Empty string means use config path } _ => panic!("Expected Record command"), @@ -830,11 +851,21 @@ mod tests { #[test] fn test_record_start_model_and_output_override() { - let cli = Cli::parse_from(["voxtype", "record", "start", "--model", "large-v3-turbo", "--clipboard"]); + let cli = Cli::parse_from([ + "voxtype", + "record", + "start", + "--model", + "large-v3-turbo", + "--clipboard", + ]); match cli.command { Some(Commands::Record { action }) => { assert_eq!(action.model_override(), Some("large-v3-turbo")); - assert_eq!(action.output_mode_override(), Some(OutputModeOverride::Clipboard)); + assert_eq!( + action.output_mode_override(), + Some(OutputModeOverride::Clipboard) + ); } _ => panic!("Expected Record command"), } @@ -845,7 +876,10 @@ mod tests { let cli = Cli::parse_from(["voxtype", "record", "start", "--file=/tmp/output.txt"]); match cli.command { Some(Commands::Record { action }) => { - assert_eq!(action.output_mode_override(), Some(OutputModeOverride::File)); + assert_eq!( + action.output_mode_override(), + Some(OutputModeOverride::File) + ); assert_eq!(action.file_path(), Some("/tmp/output.txt")); } _ => panic!("Expected Record command"), @@ -868,7 +902,10 @@ mod tests { let cli = Cli::parse_from(["voxtype", "record", "toggle", "--file=out.txt"]); match cli.command { Some(Commands::Record { action }) => { - assert_eq!(action.output_mode_override(), Some(OutputModeOverride::File)); + assert_eq!( + action.output_mode_override(), + Some(OutputModeOverride::File) + ); assert_eq!(action.file_path(), Some("out.txt")); } _ => panic!("Expected Record command"), @@ -880,7 +917,10 @@ mod tests { let cli = Cli::parse_from(["voxtype", "record", "toggle", "--file"]); match cli.command { Some(Commands::Record { action }) => { - assert_eq!(action.output_mode_override(), Some(OutputModeOverride::File)); + assert_eq!( + action.output_mode_override(), + Some(OutputModeOverride::File) + ); assert_eq!(action.file_path(), Some("")); // Empty string means use config path } _ => panic!("Expected Record command"), @@ -904,17 +944,9 @@ mod tests { #[test] fn test_record_start_file_mutually_exclusive_with_paste() { - let result = Cli::try_parse_from([ - "voxtype", - "record", - "start", - "--file=out.txt", - "--paste", - ]); - assert!( - result.is_err(), - "Should not allow both --file and --paste" - ); + let result = + Cli::try_parse_from(["voxtype", "record", "start", "--file=out.txt", "--paste"]); + assert!(result.is_err(), "Should not allow both --file and --paste"); } #[test] @@ -934,17 +966,9 @@ mod tests { #[test] fn test_record_start_file_mutually_exclusive_with_type() { - let result = Cli::try_parse_from([ - "voxtype", - "record", - "start", - "--file=out.txt", - "--type", - ]); - assert!( - result.is_err(), - "Should not allow both --file and --type" - ); + let result = + Cli::try_parse_from(["voxtype", "record", "start", "--file=out.txt", "--type"]); + assert!(result.is_err(), "Should not allow both --file and --type"); } #[test] @@ -1064,11 +1088,21 @@ mod tests { #[test] fn test_record_start_profile_with_output_mode() { // Profile can be used together with output mode overrides - let cli = Cli::parse_from(["voxtype", "record", "start", "--profile", "slack", "--clipboard"]); + let cli = Cli::parse_from([ + "voxtype", + "record", + "start", + "--profile", + "slack", + "--clipboard", + ]); match cli.command { Some(Commands::Record { action }) => { assert_eq!(action.profile(), Some("slack")); - assert_eq!(action.output_mode_override(), Some(OutputModeOverride::Clipboard)); + assert_eq!( + action.output_mode_override(), + Some(OutputModeOverride::Clipboard) + ); } _ => panic!("Expected Record command"), } @@ -1083,7 +1117,12 @@ mod tests { let cli = Cli::parse_from(["voxtype", "setup", "dms", "--install"]); match cli.command { Some(Commands::Setup { - action: Some(SetupAction::Dms { install, uninstall, qml }), + action: + Some(SetupAction::Dms { + install, + uninstall, + qml, + }), .. }) => { assert!(install, "should have install=true"); @@ -1099,7 +1138,12 @@ mod tests { let cli = Cli::parse_from(["voxtype", "setup", "dms", "--uninstall"]); match cli.command { Some(Commands::Setup { - action: Some(SetupAction::Dms { install, uninstall, qml }), + action: + Some(SetupAction::Dms { + install, + uninstall, + qml, + }), .. }) => { assert!(!install, "should have install=false"); @@ -1115,7 +1159,12 @@ mod tests { let cli = Cli::parse_from(["voxtype", "setup", "dms", "--qml"]); match cli.command { Some(Commands::Setup { - action: Some(SetupAction::Dms { install, uninstall, qml }), + action: + Some(SetupAction::Dms { + install, + uninstall, + qml, + }), .. }) => { assert!(!install, "should have install=false"); @@ -1131,7 +1180,12 @@ mod tests { let cli = Cli::parse_from(["voxtype", "setup", "dms"]); match cli.command { Some(Commands::Setup { - action: Some(SetupAction::Dms { install, uninstall, qml }), + action: + Some(SetupAction::Dms { + install, + uninstall, + qml, + }), .. }) => { assert!(!install, "should have install=false"); diff --git a/src/config.rs b/src/config.rs index edbe765..6597548 100644 --- a/src/config.rs +++ b/src/config.rs @@ -205,6 +205,21 @@ on_transcription = true # Custom word replacements (case-insensitive) # replacements = { "vox type" = "voxtype" } +# [vad] +# Voice Activity Detection - filter silence-only recordings before transcription +# Prevents Whisper hallucinations when processing silence +# +# Enable VAD (default: false) +# enabled = false +# +# Speech detection threshold (0.0-1.0, default: 0.5) +# Higher values require more confident speech detection +# threshold = 0.5 +# +# Minimum speech duration in milliseconds (default: 100) +# Recordings with less speech than this are rejected +# min_speech_duration_ms = 100 + # [status] # Status display icons for Waybar/tray integrations # @@ -274,6 +289,11 @@ pub struct Config { #[serde(default)] pub text: TextConfig, + /// Voice Activity Detection configuration + /// When enabled, filters silence-only recordings before transcription + #[serde(default)] + pub vad: VadConfig, + /// Status display configuration (icons for Waybar/tray integrations) #[serde(default)] pub status: StatusConfig, @@ -716,7 +736,6 @@ pub struct WhisperConfig { pub initial_prompt: Option, // --- Multi-model settings --- - /// Secondary model to use when hotkey.model_modifier is held /// Example: "large-v3-turbo" for difficult audio #[serde(default)] @@ -773,9 +792,7 @@ impl WhisperConfig { } // Fall back to deprecated `backend` with warning if let Some(backend) = self.backend { - tracing::warn!( - "DEPRECATED: [whisper] backend is deprecated, use 'mode' instead" - ); + tracing::warn!("DEPRECATED: [whisper] backend is deprecated, use 'mode' instead"); tracing::warn!( " Change 'backend = \"{}\"' to 'mode = \"{}\"' in config.toml", match backend { @@ -873,6 +890,53 @@ pub enum TranscriptionEngine { Parakeet, } +/// Voice Activity Detection configuration +/// +/// VAD filters silence-only recordings before transcription to prevent +/// Whisper hallucinations when processing silence. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct VadConfig { + /// Enable Voice Activity Detection (default: false) + /// When enabled, recordings with no detected speech are rejected before transcription + #[serde(default)] + pub enabled: bool, + + /// Speech detection threshold (0.0-1.0, default: 0.5) + /// Higher values require more confident speech detection + #[serde(default = "default_vad_threshold")] + pub threshold: f32, + + /// Minimum speech duration in milliseconds (default: 100) + /// Recordings with less speech than this are rejected + #[serde(default = "default_min_speech_duration_ms")] + pub min_speech_duration_ms: u32, + + /// Path to VAD model file (optional) + /// If not set, uses the default model location (~/.local/share/voxtype/models/) + /// Auto-selects appropriate model based on transcription engine + #[serde(default)] + pub model: Option, +} + +fn default_vad_threshold() -> f32 { + 0.5 +} + +fn default_min_speech_duration_ms() -> u32 { + 100 +} + +impl Default for VadConfig { + fn default() -> Self { + Self { + enabled: false, + threshold: default_vad_threshold(), + min_speech_duration_ms: default_min_speech_duration_ms(), + model: None, + } + } +} + /// Text processing configuration #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct TextConfig { @@ -1217,6 +1281,7 @@ impl Default for Config { engine: TranscriptionEngine::default(), parakeet: None, text: TextConfig::default(), + vad: VadConfig::default(), status: StatusConfig::default(), state_file: Some("auto".to_string()), profiles: HashMap::new(), @@ -1974,7 +2039,10 @@ mod tests { let config: Config = toml::from_str(toml_str).unwrap(); assert_eq!(config.engine, TranscriptionEngine::Parakeet); assert!(config.parakeet.is_some()); - assert_eq!(config.parakeet.as_ref().unwrap().model, "parakeet-tdt-0.6b-v3"); + assert_eq!( + config.parakeet.as_ref().unwrap().model, + "parakeet-tdt-0.6b-v3" + ); } #[test] @@ -2002,15 +2070,39 @@ mod tests { #[test] fn test_output_driver_from_str() { - assert_eq!("wtype".parse::().unwrap(), OutputDriver::Wtype); - assert_eq!("dotool".parse::().unwrap(), OutputDriver::Dotool); - assert_eq!("ydotool".parse::().unwrap(), OutputDriver::Ydotool); - assert_eq!("clipboard".parse::().unwrap(), OutputDriver::Clipboard); - assert_eq!("xclip".parse::().unwrap(), OutputDriver::Xclip); + assert_eq!( + "wtype".parse::().unwrap(), + OutputDriver::Wtype + ); + assert_eq!( + "dotool".parse::().unwrap(), + OutputDriver::Dotool + ); + assert_eq!( + "ydotool".parse::().unwrap(), + OutputDriver::Ydotool + ); + assert_eq!( + "clipboard".parse::().unwrap(), + OutputDriver::Clipboard + ); + assert_eq!( + "xclip".parse::().unwrap(), + OutputDriver::Xclip + ); // Case insensitive - assert_eq!("WTYPE".parse::().unwrap(), OutputDriver::Wtype); - assert_eq!("Ydotool".parse::().unwrap(), OutputDriver::Ydotool); - assert_eq!("XCLIP".parse::().unwrap(), OutputDriver::Xclip); + assert_eq!( + "WTYPE".parse::().unwrap(), + OutputDriver::Wtype + ); + assert_eq!( + "Ydotool".parse::().unwrap(), + OutputDriver::Ydotool + ); + assert_eq!( + "XCLIP".parse::().unwrap(), + OutputDriver::Xclip + ); // Invalid assert!("invalid".parse::().is_err()); } @@ -2380,11 +2472,17 @@ mod tests { assert_eq!(config.profiles.len(), 2); let slack = config.get_profile("slack").unwrap(); - assert_eq!(slack.post_process_command, Some("cleanup-for-slack.sh".to_string())); + assert_eq!( + slack.post_process_command, + Some("cleanup-for-slack.sh".to_string()) + ); assert!(slack.output_mode.is_none()); let code = config.get_profile("code").unwrap(); - assert_eq!(code.post_process_command, Some("cleanup-for-code.sh".to_string())); + assert_eq!( + code.post_process_command, + Some("cleanup-for-code.sh".to_string()) + ); assert_eq!(code.output_mode, Some(OutputMode::Clipboard)); } @@ -2413,7 +2511,10 @@ mod tests { let config: Config = toml::from_str(toml_str).unwrap(); let slow = config.get_profile("slow").unwrap(); - assert_eq!(slow.post_process_command, Some("slow-llm-command".to_string())); + assert_eq!( + slow.post_process_command, + Some("slow-llm-command".to_string()) + ); assert_eq!(slow.post_process_timeout_ms, Some(60000)); } diff --git a/src/daemon.rs b/src/daemon.rs index 2eb4087..42fa3e2 100644 --- a/src/daemon.rs +++ b/src/daemon.rs @@ -23,7 +23,12 @@ use tokio::process::Command; use tokio::signal::unix::{signal, SignalKind}; /// Send a desktop notification with optional engine icon -async fn send_notification(title: &str, body: &str, show_engine_icon: bool, engine: crate::config::TranscriptionEngine) { +async fn send_notification( + title: &str, + body: &str, + show_engine_icon: bool, + engine: crate::config::TranscriptionEngine, +) { let title = if show_engine_icon { format!("{} {}", crate::output::engine_icon(engine), title) } else { @@ -31,12 +36,7 @@ async fn send_notification(title: &str, body: &str, show_engine_icon: bool, engi }; let _ = Command::new("notify-send") - .args([ - "--app-name=Voxtype", - "--expire-time=2000", - &title, - body, - ]) + .args(["--app-name=Voxtype", "--expire-time=2000", &title, body]) .stdout(Stdio::null()) .stderr(Stdio::null()) .status() @@ -316,10 +316,16 @@ pub struct Daemon { audio_feedback: Option, text_processor: TextProcessor, post_processor: Option, + // Voice Activity Detection (filters silence-only recordings) + vad: Option>, // Model manager for multi-model support model_manager: Option, // Background task for loading model on-demand - model_load_task: Option, crate::error::TranscribeError>>>, + model_load_task: Option< + tokio::task::JoinHandle< + std::result::Result, crate::error::TranscribeError>, + >, + >, // Background task for transcription (allows cancel during transcription) transcription_task: Option>, } @@ -371,6 +377,27 @@ impl Daemon { PostProcessor::new(cfg) }); + // Initialize VAD if enabled + let vad = if config.vad.enabled { + match crate::vad::create_vad(&config) { + Ok(Some(v)) => { + tracing::info!( + "Voice Activity Detection enabled (threshold: {:.2}, min_speech: {}ms)", + config.vad.threshold, + config.vad.min_speech_duration_ms + ); + Some(v) + } + Ok(None) => None, + Err(e) => { + tracing::warn!("Failed to initialize VAD: {}", e); + None + } + } + } else { + None + }; + Self { config, config_path, @@ -379,6 +406,7 @@ impl Daemon { audio_feedback, text_processor, post_processor, + vad, model_manager: None, model_load_task: None, transcription_task: None, @@ -499,7 +527,13 @@ impl Daemon { // Send notification if enabled if self.config.output.notification.on_recording_stop { - send_notification("Recording Stopped", "Transcribing...", self.config.output.notification.show_engine_icon, self.config.engine).await; + send_notification( + "Recording Stopped", + "Transcribing...", + self.config.output.notification.show_engine_icon, + self.config.engine, + ) + .await; } // Stop recording and get samples @@ -515,6 +549,32 @@ impl Daemon { return false; } + // Voice Activity Detection: skip if no speech detected + if let Some(ref vad) = self.vad { + match vad.detect(&samples) { + Ok(result) if !result.has_speech => { + tracing::debug!( + "No speech detected ({:.1}% speech, {:.2}s), skipping transcription", + result.speech_ratio * 100.0, + result.speech_duration_secs + ); + self.play_feedback(SoundEvent::Cancelled); + self.reset_to_idle(state).await; + return false; + } + Ok(result) => { + tracing::debug!( + "Speech detected: {:.1}% speech ({:.2}s)", + result.speech_ratio * 100.0, + result.speech_duration_secs + ); + } + Err(e) => { + tracing::warn!("VAD failed, proceeding anyway: {}", e); + } + } + } + tracing::info!("Transcribing {:.1}s of audio...", audio_duration); *state = State::Transcribing { audio: samples.clone(), @@ -583,9 +643,7 @@ impl Daemon { // Apply post-processing command (profile overrides default) let final_text = if let Some(profile) = active_profile { if let Some(ref cmd) = profile.post_process_command { - let timeout_ms = profile - .post_process_timeout_ms - .unwrap_or(30000); + let timeout_ms = profile.post_process_timeout_ms.unwrap_or(30000); let profile_config = crate::config::PostProcessConfig { command: cmd.clone(), timeout_ms, @@ -652,18 +710,15 @@ impl Daemon { }; let file_mode = &self.config.output.file_mode; - match write_transcription_to_file(&output_path, &final_text, file_mode).await + match write_transcription_to_file(&output_path, &final_text, file_mode) + .await { Ok(()) => { let mode_str = match file_mode { FileMode::Overwrite => "wrote", FileMode::Append => "appended", }; - tracing::info!( - "{} transcription to {:?}", - mode_str, - output_path - ); + tracing::info!("{} transcription to {:?}", mode_str, output_path); } Err(e) => { tracing::error!( @@ -720,7 +775,8 @@ impl Daemon { &final_text, self.config.output.notification.show_engine_icon, self.config.engine, - ).await; + ) + .await; } *state = State::Idle; @@ -801,7 +857,10 @@ impl Daemon { let mut hotkey_listener = if self.config.hotkey.enabled { tracing::info!("Hotkey: {}", self.config.hotkey.key); let secondary_model = self.config.whisper.secondary_model.clone(); - Some(hotkey::create_listener(&self.config.hotkey, secondary_model)?) + Some(hotkey::create_listener( + &self.config.hotkey, + secondary_model, + )?) } else { tracing::info!( "Built-in hotkey disabled, use 'voxtype record' commands or compositor keybindings" @@ -838,7 +897,9 @@ impl Daemon { } crate::config::TranscriptionEngine::Parakeet => { // Parakeet uses its own model loading - transcriber_preloaded = Some(Arc::from(crate::transcribe::create_transcriber(&self.config)?)); + transcriber_preloaded = Some(Arc::from(crate::transcribe::create_transcriber( + &self.config, + )?)); } } tracing::info!("Model loaded, ready for voice input"); @@ -1635,7 +1696,7 @@ mod tests { // Should not panic }); } - + fn test_pidlock_acquisition_succeeds() { with_test_runtime_dir(|dir| { let lock_path = dir.join("voxtype.lock"); diff --git a/src/error.rs b/src/error.rs index 75be313..aa3f311 100644 --- a/src/error.rs +++ b/src/error.rs @@ -93,6 +93,19 @@ pub enum TranscribeError { RemoteError(String), } +/// Errors related to Voice Activity Detection +#[derive(Error, Debug)] +pub enum VadError { + #[error("VAD model not found: {0}\n Run 'voxtype setup vad' to download.")] + ModelNotFound(String), + + #[error("VAD initialization failed: {0}")] + InitFailed(String), + + #[error("VAD detection failed: {0}")] + DetectionFailed(String), +} + /// Errors related to text output #[derive(Error, Debug)] pub enum OutputError { @@ -120,7 +133,9 @@ pub enum OutputError { #[error("Ctrl+V simulation failed: {0}")] CtrlVFailed(String), - #[error("All output methods failed. Ensure wtype, dotool, ydotool, wl-copy, or xclip is available.")] + #[error( + "All output methods failed. Ensure wtype, dotool, ydotool, wl-copy, or xclip is available." + )] AllMethodsFailed, } diff --git a/src/lib.rs b/src/lib.rs index 474f31f..6c4536b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -81,6 +81,7 @@ pub mod setup; pub mod state; pub mod text; pub mod transcribe; +pub mod vad; pub use cli::{Cli, Commands, CompositorType, OutputModeOverride, RecordAction, SetupAction}; pub use config::Config; diff --git a/src/main.rs b/src/main.rs index 8bc595c..8ca3815 100644 --- a/src/main.rs +++ b/src/main.rs @@ -100,7 +100,10 @@ async fn main() -> anyhow::Result<()> { "whisper" => config.engine = config::TranscriptionEngine::Whisper, "parakeet" => config.engine = config::TranscriptionEngine::Parakeet, _ => { - eprintln!("Error: Invalid engine '{}'. Valid options: whisper, parakeet", engine); + eprintln!( + "Error: Invalid engine '{}'. Valid options: whisper, parakeet", + engine + ); std::process::exit(1); } } @@ -135,6 +138,12 @@ async fn main() -> anyhow::Result<()> { } } } + if cli.vad { + config.vad.enabled = true; + } + if let Some(threshold) = cli.vad_threshold { + config.vad.threshold = threshold.clamp(0.0, 1.0); + } // Run the appropriate command match cli.command.unwrap_or(Commands::Daemon) { @@ -251,7 +260,11 @@ async fn main() -> anyhow::Result<()> { setup::gpu::show_status(); } } - Some(SetupAction::Parakeet { enable, disable, status }) => { + Some(SetupAction::Parakeet { + enable, + disable, + status, + }) => { if status { setup::parakeet::show_status(); } else if enable { @@ -266,6 +279,9 @@ async fn main() -> anyhow::Result<()> { Some(SetupAction::Compositor { compositor_type }) => { setup::compositor::run(&compositor_type).await?; } + Some(SetupAction::Vad { force }) => { + setup::vad::download_vad_model(&config, force).await?; + } None => { // Default: run setup (non-blocking) setup::run_setup(&config, download, model.as_deref(), quiet, no_post_install) @@ -377,7 +393,14 @@ fn send_record_command(config: &config::Config, action: RecordAction) -> anyhow: } else { eprintln!("Error: Profile '{}' not found.", profile_name); eprintln!(); - eprintln!("Available profiles: {}", available.iter().map(|s| s.as_str()).collect::>().join(", ")); + eprintln!( + "Available profiles: {}", + available + .iter() + .map(|s| s.as_str()) + .collect::>() + .join(", ") + ); } std::process::exit(1); } @@ -781,7 +804,10 @@ async fn show_config(config: &config::Config) -> anyhow::Result<()> { if let Some(ref model_type) = parakeet_config.model_type { println!(" model_type = {:?}", model_type); } - println!(" on_demand_loading = {}", parakeet_config.on_demand_loading); + println!( + " on_demand_loading = {}", + parakeet_config.on_demand_loading + ); } else { println!(" (not configured)"); } diff --git a/src/model_manager.rs b/src/model_manager.rs index 31d8469..4446e8b 100644 --- a/src/model_manager.rs +++ b/src/model_manager.rs @@ -6,7 +6,7 @@ //! - Fresh subprocess per model (when gpu_isolation = true) //! - Remote backend model selection -use crate::config::{WhisperMode, WhisperConfig}; +use crate::config::{WhisperConfig, WhisperMode}; use crate::error::TranscribeError; use crate::transcribe::{self, Transcriber}; use std::collections::HashMap; @@ -115,10 +115,7 @@ impl ModelManager { } /// Create a CLI transcriber with model override - fn create_cli_transcriber( - &self, - model: &str, - ) -> Result, TranscribeError> { + fn create_cli_transcriber(&self, model: &str) -> Result, TranscribeError> { let mut config = self.config.clone(); config.model = model.to_string(); tracing::info!("Using whisper-cli subprocess backend"); diff --git a/src/output/mod.rs b/src/output/mod.rs index e1a4402..5bb10fc 100644 --- a/src/output/mod.rs +++ b/src/output/mod.rs @@ -22,8 +22,8 @@ pub mod ydotool; use crate::config::{OutputConfig, OutputDriver}; use crate::error::OutputError; use std::borrow::Cow; -use std::process::Stdio; use std::fs; +use std::process::Stdio; use tokio::process::Command; /// Normalize Unicode curly quotes to ASCII equivalents. @@ -44,7 +44,7 @@ fn normalize_quotes(text: &str) -> Cow<'_, str> { | '\u{201C}' // LEFT DOUBLE QUOTATION MARK | '\u{201D}' // RIGHT DOUBLE QUOTATION MARK | '\u{201F}' // DOUBLE HIGH-REVERSED-9 QUOTATION MARK - | '\u{2033}' // DOUBLE PRIME + | '\u{2033}' // DOUBLE PRIME ) }); @@ -98,7 +98,11 @@ pub fn engine_icon(engine: crate::config::TranscriptionEngine) -> &'static str { } /// Send a transcription notification with optional engine icon -pub async fn send_transcription_notification(text: &str, show_engine_icon: bool, engine: crate::config::TranscriptionEngine) { +pub async fn send_transcription_notification( + text: &str, + show_engine_icon: bool, + engine: crate::config::TranscriptionEngine, +) { // Truncate preview for notification (use chars() to handle multi-byte UTF-8) let preview = if text.chars().count() > 80 { format!("{}...", text.chars().take(80).collect::()) diff --git a/src/setup/dms.rs b/src/setup/dms.rs index 9aa3ace..10fb4e4 100644 --- a/src/setup/dms.rs +++ b/src/setup/dms.rs @@ -325,7 +325,10 @@ pub fn print_config() { println!(" with the following content:\n"); // Print the simplified QML with the actual path substituted - println!("{}", QML_SIMPLE_TEMPLATE.replace("VOXTYPE_PATH", &voxtype_path)); + println!( + "{}", + QML_SIMPLE_TEMPLATE.replace("VOXTYPE_PATH", &voxtype_path) + ); println!("\n\n3. Enable the widget in DankMaterialShell settings.\n"); diff --git a/src/setup/gpu.rs b/src/setup/gpu.rs index d0499d6..4aecb40 100644 --- a/src/setup/gpu.rs +++ b/src/setup/gpu.rs @@ -114,7 +114,12 @@ impl GpuVendor { /// Parse vendor from GPU name string fn from_name(name: &str) -> Self { let lower = name.to_lowercase(); - if lower.contains("nvidia") || lower.contains("geforce") || lower.contains("quadro") || lower.contains("rtx") || lower.contains("gtx") { + if lower.contains("nvidia") + || lower.contains("geforce") + || lower.contains("quadro") + || lower.contains("rtx") + || lower.contains("gtx") + { GpuVendor::Nvidia } else if lower.contains("amd") || lower.contains("radeon") || lower.contains("rx ") { GpuVendor::Amd @@ -286,14 +291,14 @@ pub fn detect_gpu() -> Option { /// Parse VOXTYPE_VULKAN_DEVICE environment variable and return the appropriate vendor pub fn get_selected_gpu_vendor() -> Option { - std::env::var("VOXTYPE_VULKAN_DEVICE").ok().and_then(|val| { - match val.to_lowercase().as_str() { + std::env::var("VOXTYPE_VULKAN_DEVICE") + .ok() + .and_then(|val| match val.to_lowercase().as_str() { "nvidia" | "nv" => Some(GpuVendor::Nvidia), "amd" | "radeon" => Some(GpuVendor::Amd), "intel" => Some(GpuVendor::Intel), _ => None, - } - }) + }) } /// Apply GPU selection environment variables based on VOXTYPE_VULKAN_DEVICE @@ -600,7 +605,10 @@ pub fn show_status() { if gpus.len() > 1 { println!(); if let Some(selected) = get_selected_gpu_vendor() { - println!("GPU selection: {} (via VOXTYPE_VULKAN_DEVICE)", selected.display_name()); + println!( + "GPU selection: {} (via VOXTYPE_VULKAN_DEVICE)", + selected.display_name() + ); } else { println!("GPU selection: auto (first available)"); println!(); @@ -708,7 +716,10 @@ pub fn enable() -> anyhow::Result<()> { // Regenerate systemd service if it exists if super::systemd::regenerate_service_file()? { - println!("Updated systemd service to use Parakeet {} backend.", backend_name); + println!( + "Updated systemd service to use Parakeet {} backend.", + backend_name + ); } println!("Switched to Parakeet ({}) backend.", backend_name); @@ -766,7 +777,10 @@ pub fn disable() -> anyhow::Result<()> { let best_backend = detect_best_parakeet_cpu_backend(); if let Some(backend_name) = best_backend { switch_backend_tiered_parakeet(backend_name)?; - println!("Switched to Parakeet ({}) backend.", backend_name.trim_start_matches("voxtype-parakeet-")); + println!( + "Switched to Parakeet ({}) backend.", + backend_name.trim_start_matches("voxtype-parakeet-") + ); } else { anyhow::bail!( "No Parakeet CPU backend found.\n\ diff --git a/src/setup/mod.rs b/src/setup/mod.rs index 4c21c04..1137049 100644 --- a/src/setup/mod.rs +++ b/src/setup/mod.rs @@ -15,6 +15,7 @@ pub mod gpu; pub mod model; pub mod parakeet; pub mod systemd; +pub mod vad; pub mod waybar; use crate::config::Config; @@ -460,7 +461,10 @@ pub async fn run_setup( // Check if parakeet feature is enabled #[cfg(not(feature = "parakeet"))] { - print_failure(&format!("Parakeet model '{}' requires the 'parakeet' feature", model_name)); + print_failure(&format!( + "Parakeet model '{}' requires the 'parakeet' feature", + model_name + )); println!(" Rebuild with: cargo build --features parakeet"); anyhow::bail!("Parakeet feature not enabled"); } @@ -468,8 +472,8 @@ pub async fn run_setup( #[cfg(feature = "parakeet")] { let model_path = models_dir.join(model_name); - let model_valid = model_path.exists() - && model::validate_parakeet_model(&model_path).is_ok(); + let model_valid = + model_path.exists() && model::validate_parakeet_model(&model_path).is_ok(); if model_valid { if !quiet { @@ -482,10 +486,7 @@ pub async fn run_setup( .sum::() }) .unwrap_or(0.0); - print_success(&format!( - "Model ready: {} ({:.0} MB)", - model_name, size - )); + print_success(&format!("Model ready: {} ({:.0} MB)", model_name, size)); } // Update config to use Parakeet model::set_parakeet_config(model_name)?; @@ -507,7 +508,10 @@ pub async fn run_setup( } } else if !quiet { print_info(&format!("Model '{}' not downloaded yet", model_name)); - println!(" Run: voxtype setup --download --model {}", model_name); + println!( + " Run: voxtype setup --download --model {}", + model_name + ); } } } else { @@ -543,10 +547,7 @@ pub async fn run_setup( let size = std::fs::metadata(&model_path) .map(|m| m.len() as f64 / 1024.0 / 1024.0) .unwrap_or(0.0); - print_success(&format!( - "Model ready: {} ({:.0} MB)", - model_name, size - )); + print_success(&format!("Model ready: {} ({:.0} MB)", model_name, size)); } // If user explicitly requested this model, update config even if already downloaded if model_override.is_some() { @@ -748,9 +749,14 @@ pub async fn run_checks(config: &Config) -> anyhow::Result<()> { if config.engine == crate::config::TranscriptionEngine::Parakeet { if let Some(ref parakeet_config) = config.parakeet { let configured_model = ¶keet_config.model; - let model_found = parakeet_models.iter().any(|(name, _)| name == configured_model); + let model_found = parakeet_models + .iter() + .any(|(name, _)| name == configured_model); if !model_found { - print_failure(&format!("Configured Parakeet model '{}' not found", configured_model)); + print_failure(&format!( + "Configured Parakeet model '{}' not found", + configured_model + )); println!(" Download the model or change config to use an available model"); all_ok = false; } diff --git a/src/setup/model.rs b/src/setup/model.rs index af58181..0931caa 100644 --- a/src/setup/model.rs +++ b/src/setup/model.rs @@ -156,7 +156,12 @@ pub async fn interactive_select() -> anyhow::Result<()> { let parakeet_available = cfg!(feature = "parakeet"); let whisper_count = MODELS.len(); let parakeet_count = PARAKEET_MODELS.len(); - let total_count = whisper_count + if parakeet_available { parakeet_count } else { 0 }; + let total_count = whisper_count + + if parakeet_available { + parakeet_count + } else { + 0 + }; // --- Whisper Section --- println!("--- Whisper (OpenAI, 99+ languages) ---\n"); @@ -704,7 +709,10 @@ fn download_parakeet_model_by_info(model: &ParakeetModelInfo) -> anyhow::Result< // Validate all files are present validate_parakeet_model(&model_path)?; - print_success(&format!("Model '{}' downloaded to {:?}", model.name, model_path)); + print_success(&format!( + "Model '{}' downloaded to {:?}", + model.name, model_path + )); Ok(()) } @@ -811,10 +819,7 @@ fn update_parakeet_in_config(config: &str, model_name: &str) -> String { // Add [parakeet] section if not present if !has_parakeet_section { - result.push_str(&format!( - "\n[parakeet]\nmodel = \"{}\"\n", - model_name - )); + result.push_str(&format!("\n[parakeet]\nmodel = \"{}\"\n", model_name)); } // Remove trailing newline if original didn't have one @@ -853,10 +858,7 @@ pub fn list_installed_parakeet() { }) .unwrap_or(0.0); - println!( - " {} ({:.0} MB) - {}", - model.name, size, model.description - ); + println!(" {} ({:.0} MB) - {}", model.name, size, model.description); found = true; } } @@ -1181,16 +1183,23 @@ translate = false use crate::config::TranscriptionEngine; // Simulate: engine=Whisper, current model="base.en" - let is_whisper_engine = matches!(TranscriptionEngine::Whisper, TranscriptionEngine::Whisper); + let is_whisper_engine = + matches!(TranscriptionEngine::Whisper, TranscriptionEngine::Whisper); let current_whisper_model = "base.en"; // "base.en" should have star let is_current = is_whisper_engine && "base.en" == current_whisper_model; - assert!(is_current, "base.en should show star when it's the current Whisper model"); + assert!( + is_current, + "base.en should show star when it's the current Whisper model" + ); // "small.en" should NOT have star let is_current = is_whisper_engine && "small.en" == current_whisper_model; - assert!(!is_current, "small.en should not show star when base.en is current"); + assert!( + !is_current, + "small.en should not show star when base.en is current" + ); } #[test] @@ -1198,16 +1207,25 @@ translate = false use crate::config::TranscriptionEngine; // Simulate: engine=Parakeet, current model="parakeet-tdt-0.6b-v3" - let is_parakeet_engine = matches!(TranscriptionEngine::Parakeet, TranscriptionEngine::Parakeet); + let is_parakeet_engine = + matches!(TranscriptionEngine::Parakeet, TranscriptionEngine::Parakeet); let current_parakeet_model: Option<&str> = Some("parakeet-tdt-0.6b-v3"); // "parakeet-tdt-0.6b-v3" should have star - let is_current = is_parakeet_engine && current_parakeet_model == Some("parakeet-tdt-0.6b-v3"); - assert!(is_current, "parakeet-tdt-0.6b-v3 should show star when it's the current Parakeet model"); + let is_current = + is_parakeet_engine && current_parakeet_model == Some("parakeet-tdt-0.6b-v3"); + assert!( + is_current, + "parakeet-tdt-0.6b-v3 should show star when it's the current Parakeet model" + ); // "parakeet-tdt-0.6b-v3-int8" should NOT have star - let is_current = is_parakeet_engine && current_parakeet_model == Some("parakeet-tdt-0.6b-v3-int8"); - assert!(!is_current, "parakeet-tdt-0.6b-v3-int8 should not show star when other model is current"); + let is_current = + is_parakeet_engine && current_parakeet_model == Some("parakeet-tdt-0.6b-v3-int8"); + assert!( + !is_current, + "parakeet-tdt-0.6b-v3-int8 should not show star when other model is current" + ); } #[test] @@ -1215,18 +1233,27 @@ translate = false use crate::config::TranscriptionEngine; // When engine is Parakeet, Whisper models should NOT show star - let is_whisper_engine = matches!(TranscriptionEngine::Parakeet, TranscriptionEngine::Whisper); + let is_whisper_engine = + matches!(TranscriptionEngine::Parakeet, TranscriptionEngine::Whisper); let current_whisper_model = "base.en"; let is_current = is_whisper_engine && "base.en" == current_whisper_model; - assert!(!is_current, "Whisper models should not show star when engine is Parakeet"); + assert!( + !is_current, + "Whisper models should not show star when engine is Parakeet" + ); // When engine is Whisper, Parakeet models should NOT show star - let is_parakeet_engine = matches!(TranscriptionEngine::Whisper, TranscriptionEngine::Parakeet); + let is_parakeet_engine = + matches!(TranscriptionEngine::Whisper, TranscriptionEngine::Parakeet); let current_parakeet_model: Option<&str> = Some("parakeet-tdt-0.6b-v3"); - let is_current = is_parakeet_engine && current_parakeet_model == Some("parakeet-tdt-0.6b-v3"); - assert!(!is_current, "Parakeet models should not show star when engine is Whisper"); + let is_current = + is_parakeet_engine && current_parakeet_model == Some("parakeet-tdt-0.6b-v3"); + assert!( + !is_current, + "Parakeet models should not show star when engine is Whisper" + ); } #[test] @@ -1234,11 +1261,16 @@ translate = false use crate::config::TranscriptionEngine; // When parakeet config is None (not configured) - let is_parakeet_engine = matches!(TranscriptionEngine::Parakeet, TranscriptionEngine::Parakeet); + let is_parakeet_engine = + matches!(TranscriptionEngine::Parakeet, TranscriptionEngine::Parakeet); let current_parakeet_model: Option<&str> = None; // No model should show star when no parakeet config exists - let is_current = is_parakeet_engine && current_parakeet_model == Some("parakeet-tdt-0.6b-v3"); - assert!(!is_current, "No star should show when parakeet config is not set"); + let is_current = + is_parakeet_engine && current_parakeet_model == Some("parakeet-tdt-0.6b-v3"); + assert!( + !is_current, + "No star should show when parakeet config is not set" + ); } } diff --git a/src/setup/parakeet.rs b/src/setup/parakeet.rs index 66e876d..a4686f6 100644 --- a/src/setup/parakeet.rs +++ b/src/setup/parakeet.rs @@ -149,7 +149,11 @@ fn detect_best_parakeet_backend() -> Option { /// Detect if NVIDIA GPU is present fn detect_nvidia_gpu() -> bool { // Check for nvidia-smi - if let Ok(output) = Command::new("nvidia-smi").arg("--query-gpu=name").arg("--format=csv,noheader").output() { + if let Ok(output) = Command::new("nvidia-smi") + .arg("--query-gpu=name") + .arg("--format=csv,noheader") + .output() + { return output.status.success() && !output.stdout.is_empty(); } @@ -176,7 +180,10 @@ fn detect_amd_gpu() -> bool { if name.starts_with("renderD") { // Check if it's an AMD device via sysfs let card_num = name.trim_start_matches("renderD"); - let vendor_path = format!("/sys/class/drm/card{}/device/vendor", card_num.parse::().unwrap_or(0) - 128); + let vendor_path = format!( + "/sys/class/drm/card{}/device/vendor", + card_num.parse::().unwrap_or(0) - 128 + ); if let Ok(vendor) = fs::read_to_string(&vendor_path) { // AMD vendor ID is 0x1002 if vendor.trim() == "0x1002" { @@ -240,13 +247,18 @@ pub fn show_status() { println!(" Backend: {}", backend.display_name()); println!( " Binary: {}", - Path::new(VOXTYPE_LIB_DIR).join(backend.binary_name()).display() + Path::new(VOXTYPE_LIB_DIR) + .join(backend.binary_name()) + .display() ); } } else { println!("Active engine: Whisper"); if let Some(backend) = detect_current_whisper_backend() { - println!(" Binary: {}", Path::new(VOXTYPE_LIB_DIR).join(backend).display()); + println!( + " Binary: {}", + Path::new(VOXTYPE_LIB_DIR).join(backend).display() + ); } } @@ -363,18 +375,31 @@ pub fn disable() -> anyhow::Result<()> { whisper_backend } else { // Try to find any available Whisper backend - for fallback in ["voxtype-avx512", "voxtype-avx2", "voxtype-vulkan", "voxtype-cpu"] { + for fallback in [ + "voxtype-avx512", + "voxtype-avx2", + "voxtype-vulkan", + "voxtype-cpu", + ] { if Path::new(VOXTYPE_LIB_DIR).join(fallback).exists() { - eprintln!("Note: {} not found, using {} instead", whisper_backend, fallback); + eprintln!( + "Note: {} not found, using {} instead", + whisper_backend, fallback + ); break; } } // Find first available - ["voxtype-avx512", "voxtype-avx2", "voxtype-vulkan", "voxtype-cpu"] - .iter() - .find(|b| Path::new(VOXTYPE_LIB_DIR).join(b).exists()) - .copied() - .ok_or_else(|| anyhow::anyhow!("No Whisper backend found to switch to"))? + [ + "voxtype-avx512", + "voxtype-avx2", + "voxtype-vulkan", + "voxtype-cpu", + ] + .iter() + .find(|b| Path::new(VOXTYPE_LIB_DIR).join(b).exists()) + .copied() + .ok_or_else(|| anyhow::anyhow!("No Whisper backend found to switch to"))? }; switch_binary(final_backend)?; @@ -384,7 +409,10 @@ pub fn disable() -> anyhow::Result<()> { println!("Updated systemd service to use Whisper backend."); } - println!("Switched to Whisper ({}) backend.", final_backend.trim_start_matches("voxtype-")); + println!( + "Switched to Whisper ({}) backend.", + final_backend.trim_start_matches("voxtype-") + ); println!(); println!("Restart voxtype to use Whisper:"); println!(" systemctl --user restart voxtype"); @@ -399,7 +427,10 @@ mod tests { #[test] fn test_parakeet_backend_binary_names() { assert_eq!(ParakeetBackend::Avx2.binary_name(), "voxtype-parakeet-avx2"); - assert_eq!(ParakeetBackend::Avx512.binary_name(), "voxtype-parakeet-avx512"); + assert_eq!( + ParakeetBackend::Avx512.binary_name(), + "voxtype-parakeet-avx512" + ); assert_eq!(ParakeetBackend::Cuda.binary_name(), "voxtype-parakeet-cuda"); assert_eq!(ParakeetBackend::Rocm.binary_name(), "voxtype-parakeet-rocm"); } @@ -415,7 +446,10 @@ mod tests { #[test] fn test_parakeet_whisper_equivalents() { assert_eq!(ParakeetBackend::Avx2.whisper_equivalent(), "voxtype-avx2"); - assert_eq!(ParakeetBackend::Avx512.whisper_equivalent(), "voxtype-avx512"); + assert_eq!( + ParakeetBackend::Avx512.whisper_equivalent(), + "voxtype-avx512" + ); assert_eq!(ParakeetBackend::Cuda.whisper_equivalent(), "voxtype-vulkan"); assert_eq!(ParakeetBackend::Rocm.whisper_equivalent(), "voxtype-vulkan"); } diff --git a/src/setup/vad.rs b/src/setup/vad.rs new file mode 100644 index 0000000..76950de --- /dev/null +++ b/src/setup/vad.rs @@ -0,0 +1,90 @@ +//! Voice Activity Detection model download + +use super::{print_failure, print_info, print_success}; +use crate::config::Config; +use crate::vad; +use std::process::Command; + +/// Download the VAD model for the current transcription engine +pub async fn download_vad_model(config: &Config, force: bool) -> anyhow::Result<()> { + let models_dir = Config::models_dir(); + + // Ensure models directory exists + std::fs::create_dir_all(&models_dir)?; + + let filename = vad::get_default_model_filename(config.engine); + let model_path = models_dir.join(filename); + let url = vad::get_vad_model_url(config.engine); + + println!("\nVoice Activity Detection Model Setup\n"); + println!("====================================\n"); + println!("Engine: {:?}", config.engine); + println!("Model: {}", filename); + println!("Path: {:?}\n", model_path); + + // Check if already downloaded + if model_path.exists() && !force { + print_success(&format!("VAD model already installed at {:?}", model_path)); + println!(); + print_info("To enable VAD, add to config.toml:"); + println!(" [vad]"); + println!(" enabled = true"); + println!(); + print_info("Or use CLI flag: voxtype --vad"); + println!(); + print_info("Use --force to re-download"); + return Ok(()); + } + + println!("Downloading VAD model..."); + println!("URL: {}", url); + + // Use curl for downloading + let status = Command::new("curl") + .args([ + "-L", // Follow redirects + "--progress-bar", // Show progress bar + "-o", + model_path.to_str().unwrap_or("vad_model"), + url, + ]) + .status(); + + match status { + Ok(exit_status) if exit_status.success() => { + print_success(&format!("VAD model saved to {:?}", model_path)); + println!(); + + // Show how to enable + print_info("To enable VAD, add to config.toml:"); + println!(" [vad]"); + println!(" enabled = true"); + println!(); + print_info("Or use CLI flag: voxtype --vad"); + + Ok(()) + } + Ok(exit_status) => { + print_failure(&format!( + "Download failed: curl exited with code {}", + exit_status.code().unwrap_or(-1) + )); + // Clean up partial download + let _ = std::fs::remove_file(&model_path); + anyhow::bail!("Download failed") + } + Err(e) => { + print_failure(&format!("Failed to run curl: {}", e)); + print_info("Please ensure curl is installed (e.g., 'sudo pacman -S curl')"); + anyhow::bail!("curl not available: {}", e) + } + } +} + +/// Check if VAD model is installed for the given engine +pub fn is_vad_model_installed(engine: crate::config::TranscriptionEngine) -> bool { + let models_dir = Config::models_dir(); + let filename = vad::get_default_model_filename(engine); + let model_path = models_dir.join(filename); + model_path.exists() +} diff --git a/src/transcribe/mod.rs b/src/transcribe/mod.rs index eea5c82..04e2004 100644 --- a/src/transcribe/mod.rs +++ b/src/transcribe/mod.rs @@ -50,7 +50,9 @@ pub fn create_transcriber(config: &Config) -> Result, Trans "Parakeet engine selected but [parakeet] config section is missing".to_string(), ) })?; - Ok(Box::new(parakeet::ParakeetTranscriber::new(parakeet_config)?)) + Ok(Box::new(parakeet::ParakeetTranscriber::new( + parakeet_config, + )?)) } #[cfg(not(feature = "parakeet"))] TranscriptionEngine::Parakeet => Err(TranscribeError::InitFailed( @@ -76,7 +78,10 @@ pub fn create_transcriber_with_config_path( // Apply GPU selection from VOXTYPE_VULKAN_DEVICE environment variable // This sets VK_LOADER_DRIVERS_SELECT to filter Vulkan drivers if let Some(vendor) = gpu::apply_gpu_selection() { - tracing::info!("GPU selection: {} (via VOXTYPE_VULKAN_DEVICE)", vendor.display_name()); + tracing::info!( + "GPU selection: {} (via VOXTYPE_VULKAN_DEVICE)", + vendor.display_name() + ); } match config.effective_mode() { diff --git a/src/transcribe/parakeet.rs b/src/transcribe/parakeet.rs index a742a17..7fa7ab8 100644 --- a/src/transcribe/parakeet.rs +++ b/src/transcribe/parakeet.rs @@ -10,9 +10,15 @@ use super::Transcriber; use crate::config::{ParakeetConfig, ParakeetModelType}; use crate::error::TranscribeError; -#[cfg(any(feature = "parakeet-cuda", feature = "parakeet-rocm", feature = "parakeet-tensorrt"))] +#[cfg(any( + feature = "parakeet-cuda", + feature = "parakeet-rocm", + feature = "parakeet-tensorrt" +))] use parakeet_rs::ExecutionProvider; -use parakeet_rs::{ExecutionConfig, Parakeet, ParakeetTDT, Transcriber as ParakeetTranscriberTrait}; +use parakeet_rs::{ + ExecutionConfig, Parakeet, ParakeetTDT, Transcriber as ParakeetTranscriberTrait, +}; use std::path::PathBuf; use std::sync::Mutex; @@ -54,15 +60,15 @@ impl ParakeetTranscriber { let model = match model_type { ParakeetModelType::Ctc => { - let parakeet = Parakeet::from_pretrained(&model_path, exec_config) - .map_err(|e| { + let parakeet = + Parakeet::from_pretrained(&model_path, exec_config).map_err(|e| { TranscribeError::InitFailed(format!("Parakeet CTC init failed: {}", e)) })?; ParakeetModel::Ctc(Mutex::new(parakeet)) } ParakeetModelType::Tdt => { - let parakeet = ParakeetTDT::from_pretrained(&model_path, exec_config) - .map_err(|e| { + let parakeet = + ParakeetTDT::from_pretrained(&model_path, exec_config).map_err(|e| { TranscribeError::InitFailed(format!("Parakeet TDT init failed: {}", e)) })?; ParakeetModel::Tdt(Mutex::new(parakeet)) @@ -82,7 +88,9 @@ impl ParakeetTranscriber { impl Transcriber for ParakeetTranscriber { fn transcribe(&self, samples: &[f32]) -> Result { if samples.is_empty() { - return Err(TranscribeError::AudioFormat("Empty audio buffer".to_string())); + return Err(TranscribeError::AudioFormat( + "Empty audio buffer".to_string(), + )); } let duration_secs = samples.len() as f32 / 16000.0; @@ -98,7 +106,10 @@ impl Transcriber for ParakeetTranscriber { let text = match &self.model { ParakeetModel::Ctc(parakeet) => { let mut parakeet = parakeet.lock().map_err(|e| { - TranscribeError::InferenceFailed(format!("Failed to lock Parakeet mutex: {}", e)) + TranscribeError::InferenceFailed(format!( + "Failed to lock Parakeet mutex: {}", + e + )) })?; let result = parakeet @@ -109,14 +120,20 @@ impl Transcriber for ParakeetTranscriber { None, // default timestamp mode ) .map_err(|e| { - TranscribeError::InferenceFailed(format!("Parakeet CTC inference failed: {}", e)) + TranscribeError::InferenceFailed(format!( + "Parakeet CTC inference failed: {}", + e + )) })?; result.text.trim().to_string() } ParakeetModel::Tdt(parakeet) => { let mut parakeet = parakeet.lock().map_err(|e| { - TranscribeError::InferenceFailed(format!("Failed to lock Parakeet mutex: {}", e)) + TranscribeError::InferenceFailed(format!( + "Failed to lock Parakeet mutex: {}", + e + )) })?; let result = parakeet @@ -127,7 +144,10 @@ impl Transcriber for ParakeetTranscriber { None, // default timestamp mode ) .map_err(|e| { - TranscribeError::InferenceFailed(format!("Parakeet TDT inference failed: {}", e)) + TranscribeError::InferenceFailed(format!( + "Parakeet TDT inference failed: {}", + e + )) })?; result.text.trim().to_string() @@ -169,7 +189,11 @@ fn build_execution_config() -> Option { return Some(ExecutionConfig::new().with_execution_provider(ExecutionProvider::ROCm)); } - #[cfg(not(any(feature = "parakeet-cuda", feature = "parakeet-tensorrt", feature = "parakeet-rocm")))] + #[cfg(not(any( + feature = "parakeet-cuda", + feature = "parakeet-tensorrt", + feature = "parakeet-rocm" + )))] { None } @@ -181,8 +205,8 @@ fn build_execution_config() -> Option { /// CTC models have: model.onnx (or model_int8.onnx), tokenizer.json fn detect_model_type(path: &PathBuf) -> ParakeetModelType { // Check for TDT model structure - let has_encoder = path.join("encoder-model.onnx").exists() - || path.join("encoder-model.onnx.data").exists(); + let has_encoder = + path.join("encoder-model.onnx").exists() || path.join("encoder-model.onnx.data").exists(); let has_decoder = path.join("decoder_joint-model.onnx").exists(); if has_encoder && has_decoder { @@ -191,8 +215,7 @@ fn detect_model_type(path: &PathBuf) -> ParakeetModelType { } // Check for CTC model structure - let has_ctc_model = path.join("model.onnx").exists() - || path.join("model_int8.onnx").exists(); + let has_ctc_model = path.join("model.onnx").exists() || path.join("model_int8.onnx").exists(); let has_tokenizer = path.join("tokenizer.json").exists(); if has_ctc_model && has_tokenizer { diff --git a/src/vad/mod.rs b/src/vad/mod.rs new file mode 100644 index 0000000..5233de6 --- /dev/null +++ b/src/vad/mod.rs @@ -0,0 +1,158 @@ +//! Voice Activity Detection (VAD) module +//! +//! Provides voice activity detection to filter silence-only recordings before transcription. +//! This prevents Whisper hallucinations when processing silence. +//! +//! Two backends are supported: +//! - Whisper VAD: Uses the built-in VAD from whisper-rs (for Whisper engine) +//! - Silero VAD: Uses the voice_activity_detector crate (for Parakeet engine) + +mod whisper_vad; + +#[cfg(feature = "parakeet")] +mod silero_vad; + +use crate::config::{Config, TranscriptionEngine, VadConfig}; +use crate::error::VadError; +use std::path::PathBuf; + +/// Result of voice activity detection +#[derive(Debug, Clone)] +pub struct VadResult { + /// Whether speech was detected in the audio + pub has_speech: bool, + + /// Total duration of detected speech in seconds + pub speech_duration_secs: f32, + + /// Ratio of speech to total audio duration (0.0 - 1.0) + pub speech_ratio: f32, +} + +/// Trait for voice activity detection implementations +pub trait VoiceActivityDetector: Send + Sync { + /// Detect voice activity in audio samples + /// + /// # Arguments + /// * `samples` - Audio samples at 16kHz mono (f32 normalized to [-1.0, 1.0]) + /// + /// # Returns + /// * `VadResult` containing speech detection results + fn detect(&self, samples: &[f32]) -> Result; +} + +/// Create a VAD instance based on configuration and transcription engine +pub fn create_vad(config: &Config) -> Result>, VadError> { + if !config.vad.enabled { + return Ok(None); + } + + let vad: Box = match config.engine { + TranscriptionEngine::Whisper => { + let model_path = resolve_whisper_vad_model_path(&config.vad)?; + Box::new(whisper_vad::WhisperVad::new(&model_path, &config.vad)?) + } + #[cfg(feature = "parakeet")] + TranscriptionEngine::Parakeet => { + let model_path = resolve_silero_vad_model_path(&config.vad)?; + Box::new(silero_vad::SileroVad::new(&model_path, &config.vad)?) + } + #[cfg(not(feature = "parakeet"))] + TranscriptionEngine::Parakeet => { + return Err(VadError::InitFailed( + "Parakeet VAD requires the 'parakeet' feature".to_string(), + )); + } + }; + + Ok(Some(vad)) +} + +/// Resolve the path to the Whisper VAD model +fn resolve_whisper_vad_model_path(config: &VadConfig) -> Result { + // If model path is explicitly configured, use it + if let Some(ref model) = config.model { + let path = PathBuf::from(model); + if path.exists() { + return Ok(path); + } + return Err(VadError::ModelNotFound(model.clone())); + } + + // Use default model location + let models_dir = crate::config::Config::models_dir(); + let model_path = models_dir.join("ggml-silero-vad.bin"); + + if model_path.exists() { + Ok(model_path) + } else { + Err(VadError::ModelNotFound(model_path.display().to_string())) + } +} + +/// Resolve the path to the Silero VAD ONNX model +/// Note: voice_activity_detector uses a bundled model, so this returns a dummy path +#[cfg(feature = "parakeet")] +fn resolve_silero_vad_model_path(_config: &VadConfig) -> Result { + // voice_activity_detector crate uses a bundled Silero model + // No external model file is needed + Ok(PathBuf::from("bundled")) +} + +/// Get the default VAD model filename for the given transcription engine +pub fn get_default_model_filename(engine: TranscriptionEngine) -> &'static str { + match engine { + TranscriptionEngine::Whisper => "ggml-silero-vad.bin", + TranscriptionEngine::Parakeet => "silero_vad.onnx", + } +} + +/// Get the download URL for the VAD model +pub fn get_vad_model_url(engine: TranscriptionEngine) -> &'static str { + match engine { + TranscriptionEngine::Whisper => { + "https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v6.2.0.bin" + } + TranscriptionEngine::Parakeet => { + "https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx" + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_vad_result_default() { + let result = VadResult { + has_speech: false, + speech_duration_secs: 0.0, + speech_ratio: 0.0, + }; + assert!(!result.has_speech); + assert_eq!(result.speech_duration_secs, 0.0); + assert_eq!(result.speech_ratio, 0.0); + } + + #[test] + fn test_get_default_model_filename() { + assert_eq!( + get_default_model_filename(TranscriptionEngine::Whisper), + "ggml-silero-vad.bin" + ); + assert_eq!( + get_default_model_filename(TranscriptionEngine::Parakeet), + "silero_vad.onnx" + ); + } + + #[test] + fn test_vad_disabled_returns_none() { + let config = Config::default(); + // VAD is disabled by default + assert!(!config.vad.enabled); + let vad = create_vad(&config).unwrap(); + assert!(vad.is_none()); + } +} diff --git a/src/vad/silero_vad.rs b/src/vad/silero_vad.rs new file mode 100644 index 0000000..4530778 --- /dev/null +++ b/src/vad/silero_vad.rs @@ -0,0 +1,123 @@ +//! Silero VAD implementation using voice_activity_detector crate +//! +//! Uses the voice_activity_detector crate which bundles the Silero VAD model. +//! This implementation is used with the Parakeet transcription engine. + +use super::{VadResult, VoiceActivityDetector}; +use crate::config::VadConfig; +use crate::error::VadError; +use std::path::Path; +use std::sync::Mutex; +use voice_activity_detector::VoiceActivityDetector as VadDetector; + +/// Silero VAD implementation using voice_activity_detector crate +pub struct SileroVad { + /// VAD detector instance (wrapped in Mutex for thread safety) + detector: Mutex, + /// Speech detection threshold (0.0 - 1.0) + threshold: f32, + /// Minimum speech duration in milliseconds + min_speech_duration_ms: u32, +} + +impl SileroVad { + /// Create a new Silero VAD instance + /// + /// # Arguments + /// * `_model_path` - Ignored, the crate uses a bundled model + /// * `config` - VAD configuration + pub fn new(_model_path: &Path, config: &VadConfig) -> Result { + tracing::debug!("Initializing Silero VAD (bundled model)"); + + // voice_activity_detector requires 512 samples per chunk at 16kHz + let detector = VadDetector::builder() + .sample_rate(16000) + .chunk_size(512usize) + .build() + .map_err(|e| VadError::InitFailed(format!("Failed to create VAD detector: {}", e)))?; + + tracing::info!("Silero VAD initialized successfully"); + + Ok(Self { + detector: Mutex::new(detector), + threshold: config.threshold.clamp(0.0, 1.0), + min_speech_duration_ms: config.min_speech_duration_ms, + }) + } +} + +impl VoiceActivityDetector for SileroVad { + fn detect(&self, samples: &[f32]) -> Result { + let mut detector = self + .detector + .lock() + .map_err(|e| VadError::DetectionFailed(format!("Failed to acquire VAD lock: {}", e)))?; + + // Process audio in chunks of 512 samples (required by voice_activity_detector) + const CHUNK_SIZE: usize = 512; + let chunk_duration_secs = CHUNK_SIZE as f32 / 16000.0; // 32ms per chunk + + let mut speech_chunks = 0; + let mut total_chunks = 0; + + // Process audio in chunks + for chunk in samples.chunks(CHUNK_SIZE) { + // Pad last chunk if needed + let prob = if chunk.len() < CHUNK_SIZE { + let mut padded: Vec = chunk.to_vec(); + padded.resize(CHUNK_SIZE, 0.0); + detector.predict(padded) + } else { + detector.predict(chunk.to_vec()) + }; + + if prob >= self.threshold { + speech_chunks += 1; + } + total_chunks += 1; + } + + // Reset detector for next use + detector.reset(); + + // Calculate speech duration and ratio + let speech_duration_secs = speech_chunks as f32 * chunk_duration_secs; + let total_duration_secs = samples.len() as f32 / 16000.0; + let speech_ratio = if total_chunks > 0 { + speech_chunks as f32 / total_chunks as f32 + } else { + 0.0 + }; + + // Determine if speech was detected + let min_speech_secs = self.min_speech_duration_ms as f32 / 1000.0; + let has_speech = speech_duration_secs >= min_speech_secs; + + tracing::debug!( + "VAD result: {}/{} chunks with speech ({:.2}s, {:.1}% of {:.2}s total)", + speech_chunks, + total_chunks, + speech_duration_secs, + speech_ratio * 100.0, + total_duration_secs + ); + + Ok(VadResult { + has_speech, + speech_duration_secs, + speech_ratio, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_threshold_clamping() { + assert_eq!(1.5f32.clamp(0.0, 1.0), 1.0); + assert_eq!((-0.5f32).clamp(0.0, 1.0), 0.0); + assert_eq!(0.5f32.clamp(0.0, 1.0), 0.5); + } +} diff --git a/src/vad/whisper_vad.rs b/src/vad/whisper_vad.rs new file mode 100644 index 0000000..43ebd03 --- /dev/null +++ b/src/vad/whisper_vad.rs @@ -0,0 +1,149 @@ +//! Whisper VAD implementation using whisper-rs built-in VAD +//! +//! Uses the WhisperVadContext from whisper-rs which wraps Silero VAD +//! in GGML format, optimized for use with whisper.cpp. + +use super::{VadResult, VoiceActivityDetector}; +use crate::config::VadConfig; +use crate::error::VadError; +use std::path::Path; +use std::sync::Mutex; +use whisper_rs::{WhisperVadContext, WhisperVadContextParams, WhisperVadParams}; + +/// Whisper VAD implementation using whisper-rs +pub struct WhisperVad { + /// VAD context (wrapped in Mutex because WhisperVadContext is not Send/Sync) + ctx: Mutex, + /// Speech detection threshold (0.0 - 1.0) + threshold: f32, + /// Minimum speech duration in milliseconds + min_speech_duration_ms: u32, +} + +impl WhisperVad { + /// Create a new Whisper VAD instance + /// + /// # Arguments + /// * `model_path` - Path to the GGML VAD model file (ggml-silero-vad.bin) + /// * `config` - VAD configuration + pub fn new(model_path: &Path, config: &VadConfig) -> Result { + let model_str = model_path + .to_str() + .ok_or_else(|| VadError::InitFailed("Invalid model path".to_string()))?; + + tracing::debug!("Loading Whisper VAD model from {:?}", model_path); + + let params = WhisperVadContextParams::default(); + + let ctx = WhisperVadContext::new(model_str, params) + .map_err(|e| VadError::InitFailed(format!("Failed to load VAD model: {}", e)))?; + + tracing::info!("Whisper VAD model loaded successfully"); + + Ok(Self { + ctx: Mutex::new(ctx), + threshold: config.threshold.clamp(0.0, 1.0), + min_speech_duration_ms: config.min_speech_duration_ms, + }) + } +} + +impl VoiceActivityDetector for WhisperVad { + fn detect(&self, samples: &[f32]) -> Result { + let mut ctx = self + .ctx + .lock() + .map_err(|e| VadError::DetectionFailed(format!("Failed to acquire VAD lock: {}", e)))?; + + // Configure VAD parameters + let mut params = WhisperVadParams::new(); + params.set_threshold(self.threshold); + params.set_min_speech_duration(self.min_speech_duration_ms as i32); + // Use defaults for silence duration (100ms) and padding (30ms) + + // Run VAD detection + let segments = ctx + .segments_from_samples(params, samples) + .map_err(|e| VadError::DetectionFailed(format!("VAD detection failed: {}", e)))?; + + // Calculate total speech duration from segments + // Timestamps are in centiseconds (10ms units) + let mut total_speech_centiseconds = 0.0f32; + let num_segments = segments.num_segments(); + + for i in 0..num_segments { + if let (Some(start), Some(end)) = ( + segments.get_segment_start_timestamp(i), + segments.get_segment_end_timestamp(i), + ) { + total_speech_centiseconds += end - start; + } + } + + // Convert centiseconds to seconds + let speech_duration_secs = total_speech_centiseconds / 100.0; + + // Calculate total audio duration (samples at 16kHz) + let total_duration_secs = samples.len() as f32 / 16000.0; + + // Calculate speech ratio + let speech_ratio = if total_duration_secs > 0.0 { + (speech_duration_secs / total_duration_secs).clamp(0.0, 1.0) + } else { + 0.0 + }; + + // Determine if speech was detected + // Has speech if any segments were found AND total speech meets minimum duration + let min_speech_secs = self.min_speech_duration_ms as f32 / 1000.0; + let has_speech = num_segments > 0 && speech_duration_secs >= min_speech_secs; + + tracing::debug!( + "VAD result: {} segments, {:.2}s speech ({:.1}% of {:.2}s total)", + num_segments, + speech_duration_secs, + speech_ratio * 100.0, + total_duration_secs + ); + + Ok(VadResult { + has_speech, + speech_duration_secs, + speech_ratio, + }) + } +} + +// WhisperVad is Send + Sync because the internal WhisperVadContext is wrapped in a Mutex +unsafe impl Send for WhisperVad {} +unsafe impl Sync for WhisperVad {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_threshold_clamping() { + // Test that threshold is clamped to valid range + let config = VadConfig { + enabled: true, + threshold: 1.5, // Above max + min_speech_duration_ms: 100, + model: None, + }; + + // Can't test actual VAD without a model, but we can verify the struct + // would clamp the threshold + let clamped = config.threshold.clamp(0.0, 1.0); + assert_eq!(clamped, 1.0); + + let config2 = VadConfig { + enabled: true, + threshold: -0.5, // Below min + min_speech_duration_ms: 100, + model: None, + }; + let clamped2 = config2.threshold.clamp(0.0, 1.0); + assert_eq!(clamped2, 0.0); + } +}