diff --git a/.claude/skills/build-release/SKILL.md b/.claude/skills/build-release/SKILL.md index 7e56ed49..5a799865 100644 --- a/.claude/skills/build-release/SKILL.md +++ b/.claude/skills/build-release/SKILL.md @@ -1,6 +1,6 @@ --- name: build-release -description: Build all voxtype binaries for release. Builds Whisper (AVX2, AVX-512, Vulkan) and Parakeet (AVX2, AVX-512, CUDA) binaries using Docker. Use when preparing a new release. +description: Build all voxtype binaries for release. Builds Whisper (AVX2, AVX-512, Vulkan) and ONNX (AVX2, AVX-512, CUDA) binaries using Docker. Use when preparing a new release. user-invocable: true allowed-tools: - Bash @@ -12,7 +12,7 @@ allowed-tools: Build all 6 voxtype binaries for a release: - **Whisper**: AVX2, AVX-512, Vulkan -- **Parakeet**: AVX2, AVX-512, CUDA +- **ONNX** (Parakeet + Moonshine): AVX2, AVX-512, CUDA ## Prerequisites @@ -26,20 +26,17 @@ Build all 6 voxtype binaries for a release: # Set version export VERSION=X.Y.Z -# Build remote binaries (AVX2, Vulkan, Parakeet-AVX2, Parakeet-CUDA) +# Build remote binaries (AVX2, Vulkan, ONNX-AVX2, ONNX-CUDA) docker context use truenas -docker compose -f docker-compose.build.yml build --no-cache avx2 vulkan parakeet-avx2 parakeet-cuda -docker run --rm -v $(pwd)/releases/${VERSION}:/output voxtype-parakeet-avx2 -docker run --rm -v $(pwd)/releases/${VERSION}:/output voxtype-parakeet-vulkan -docker run --rm -v $(pwd)/releases/${VERSION}:/output voxtype-parakeet-parakeet-avx2 -docker run --rm -v $(pwd)/releases/${VERSION}:/output voxtype-parakeet-parakeet-cuda +docker compose -f docker-compose.build.yml build --no-cache avx2 vulkan onnx-avx2 onnx-cuda +docker compose -f docker-compose.build.yml up avx2 vulkan onnx-avx2 onnx-cuda # Build local AVX-512 binaries docker context use default cargo clean && cargo build --release cp target/release/voxtype releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-avx512 -cargo clean && RUSTFLAGS="-C target-cpu=native" cargo build --release 
--features parakeet -cp target/release/voxtype releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-parakeet-avx512 +cargo clean && RUSTFLAGS="-C target-cpu=native" cargo build --release --features parakeet,moonshine +cp target/release/voxtype releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-onnx-avx512 # Verify versions for bin in releases/${VERSION}/voxtype-*; do echo "$(basename $bin): $($bin --version)"; done @@ -91,7 +88,7 @@ git commit -S -m "Bump to vX.Y.Z" git push ``` -### 2. Build Remote Binaries (AVX2, Vulkan, Parakeet) +### 2. Build Remote Binaries (AVX2, Vulkan, ONNX) These builds use Ubuntu 22.04 to avoid AVX-512 instruction contamination: @@ -101,14 +98,10 @@ docker context use truenas mkdir -p releases/${VERSION} # Build all Docker images (takes ~10-15 min) -docker compose -f docker-compose.build.yml build --no-cache avx2 vulkan parakeet-avx2 parakeet-cuda +docker compose -f docker-compose.build.yml build --no-cache avx2 vulkan onnx-avx2 onnx-cuda -# Extract binaries from images -for service in avx2 vulkan; do - docker run --rm -v $(pwd)/releases/${VERSION}:/output voxtype-parakeet-${service} -done -docker run --rm -v $(pwd)/releases/${VERSION}:/output voxtype-parakeet-parakeet-avx2 -docker run --rm -v $(pwd)/releases/${VERSION}:/output voxtype-parakeet-parakeet-cuda +# Extract binaries +docker compose -f docker-compose.build.yml up avx2 vulkan onnx-avx2 onnx-cuda ``` ### 3. 
Build Local AVX-512 Binaries @@ -122,9 +115,9 @@ docker context use default cargo clean && cargo build --release cp target/release/voxtype releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-avx512 -# Parakeet AVX-512 -cargo clean && RUSTFLAGS="-C target-cpu=native" cargo build --release --features parakeet -cp target/release/voxtype releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-parakeet-avx512 +# ONNX AVX-512 +cargo clean && RUSTFLAGS="-C target-cpu=native" cargo build --release --features parakeet,moonshine +cp target/release/voxtype releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-onnx-avx512 ``` ### 4. Verify All Binaries @@ -185,8 +178,8 @@ After successful build, `releases/${VERSION}/` should contain: voxtype-X.Y.Z-linux-x86_64-avx2 voxtype-X.Y.Z-linux-x86_64-avx512 voxtype-X.Y.Z-linux-x86_64-vulkan -voxtype-X.Y.Z-linux-x86_64-parakeet-avx2 -voxtype-X.Y.Z-linux-x86_64-parakeet-avx512 -voxtype-X.Y.Z-linux-x86_64-parakeet-cuda +voxtype-X.Y.Z-linux-x86_64-onnx-avx2 +voxtype-X.Y.Z-linux-x86_64-onnx-avx512 +voxtype-X.Y.Z-linux-x86_64-onnx-cuda SHA256SUMS ``` diff --git a/.dockerignore b/.dockerignore index 20419809..04f173d3 100644 --- a/.dockerignore +++ b/.dockerignore @@ -12,6 +12,10 @@ target/ releases/* !releases/parakeet-test/ +# Worktrees +.worktrees/ +worktrees/ + # Local test files *.wav *.mp3 diff --git a/.gitignore b/.gitignore index 543c55ba..7eff8248 100644 --- a/.gitignore +++ b/.gitignore @@ -22,8 +22,9 @@ # Audio files (test recordings) *.wav *.mp3 -# Exception: VAD test fixtures are intentionally committed +# Exception: test fixtures are intentionally committed !tests/fixtures/vad/*.wav +!tests/fixtures/sensevoice/*.wav # Claude Code .claude/settings.local.json diff --git a/CLAUDE.md b/CLAUDE.md index e58eca28..81f78225 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -423,15 +423,15 @@ Building on modern CPUs (Zen 4, etc.) 
can leak AVX-512/GFNI instructions into bi | Vulkan | Docker on remote pre-AVX-512 server | GPU build on CPU without AVX-512 | | AVX512 | Local machine | Requires AVX-512 capable host | -**Parakeet Binaries (Experimental):** +**ONNX Binaries (Parakeet + Moonshine):** | Binary | Build Location | Why | |--------|---------------|-----| -| parakeet-avx2 | Docker on remote pre-AVX-512 server | Wide CPU compatibility | -| parakeet-avx512 | Local machine | Best CPU performance | -| parakeet-cuda | Docker on remote server with NVIDIA GPU | GPU acceleration | +| onnx-avx2 | Docker on remote pre-AVX-512 server | Wide CPU compatibility | +| onnx-avx512 | Local machine | Best CPU performance | +| onnx-cuda | Docker on remote server with NVIDIA GPU | GPU acceleration | -Note: Parakeet binaries include bundled ONNX Runtime which contains AVX-512 instructions, but ONNX Runtime uses runtime CPU detection and falls back gracefully on older CPUs. +Note: ONNX binaries include bundled ONNX Runtime which contains AVX-512 instructions, but ONNX Runtime uses runtime CPU detection and falls back gracefully on older CPUs. ### GPU Feature Flags @@ -487,9 +487,9 @@ docker context use docker compose -f docker-compose.build.yml build --no-cache avx2 vulkan docker compose -f docker-compose.build.yml up avx2 vulkan -# 2. Build Parakeet binaries on remote server -docker compose -f docker-compose.build.yml build --no-cache parakeet-avx2 -docker compose -f docker-compose.build.yml up parakeet-avx2 +# 2. Build ONNX binaries on remote server +docker compose -f docker-compose.build.yml build --no-cache onnx-avx2 +docker compose -f docker-compose.build.yml up onnx-avx2 # 3. 
Copy binaries from remote Docker volumes to local mkdir -p releases/${VERSION} @@ -504,9 +504,9 @@ docker context use default cargo clean && cargo build --release cp target/release/voxtype releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-avx512 -# Parakeet AVX-512 -cargo clean && RUSTFLAGS="-C target-cpu=native" cargo build --release --features parakeet -cp target/release/voxtype releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-parakeet-avx512 +# ONNX AVX-512 +cargo clean && RUSTFLAGS="-C target-cpu=native" cargo build --release --features parakeet,moonshine +cp target/release/voxtype releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-onnx-avx512 # 5. VERIFY VERSIONS before uploading (critical!) for bin in releases/${VERSION}/voxtype-*; do @@ -527,9 +527,9 @@ releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-avx2 --version releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-avx512 --version releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-vulkan --version -# Parakeet binaries (experimental) -releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-parakeet-avx2 --version -releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-parakeet-avx512 --version +# ONNX binaries +releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-onnx-avx2 --version +releases/${VERSION}/voxtype-${VERSION}-linux-x86_64-onnx-avx512 --version ``` If versions don't match, the Docker cache is stale. Rebuild with `--no-cache`. @@ -556,15 +556,15 @@ What to look for: - `{1to4}`, `{1to8}`, `{1to16}` = AVX-512 broadcast syntax - `vgf2p8`, `gf2p8` = GFNI instructions (not on Zen 3) -### Parakeet Binary Instruction Leakage +### ONNX Binary Instruction Leakage -**IMPORTANT: Parakeet binaries also need AVX-512 instruction checks**, even when built on pre-AVX-512 hardware. +**IMPORTANT: ONNX binaries also need AVX-512 instruction checks**, even when built on pre-AVX-512 hardware. 
The `ort` crate downloads prebuilt ONNX Runtime binaries that may contain AVX-512 instructions regardless of the build host's CPU. This is different from Whisper builds where the leakage comes from system libraries. ```bash -# Check Parakeet binaries for AVX-512 leakage -objdump -d voxtype-*-parakeet-avx2 | grep -c zmm +# Check ONNX binaries for AVX-512 leakage +objdump -d voxtype-*-onnx-avx2 | grep -c zmm # If >0, the ONNX Runtime contains AVX-512 instructions ``` @@ -573,7 +573,7 @@ objdump -d voxtype-*-parakeet-avx2 | grep -c zmm 2. **Build ONNX Runtime from source** - Use `ORT_STRATEGY=build` to compile ONNX Runtime with specific CPU flags (significantly increases build time) 3. **Use `load-dynamic` feature** - Link against system ONNX Runtime instead of bundled (requires users to install ONNX Runtime separately) -For now, Parakeet binaries may contain AVX-512 instructions from ONNX Runtime but should still run on pre-AVX-512 CPUs via runtime fallback. Test on target hardware to verify. +For now, ONNX binaries may contain AVX-512 instructions from ONNX Runtime but should still run on pre-AVX-512 CPUs via runtime fallback. Test on target hardware to verify. 
### Packaging Deb and RPM diff --git a/Cargo.lock b/Cargo.lock index 73b721f9..5b225777 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -53,6 +53,15 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.21" @@ -272,6 +281,20 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chrono" +version = "0.4.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -664,6 +687,18 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.3.0" @@ -704,6 +739,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "foreign-types" version = "0.3.2" @@ -778,18 +819,58 @@ 
dependencies = [ "wasip2", ] +[[package]] +name = "getrandom" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", + "wasip3", +] + [[package]] name = "glob" version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + [[package]] name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.5.0" @@ -830,6 +911,30 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core 0.62.2", +] + 
+[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "2.1.1" @@ -911,6 +1016,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "ident_case" version = "1.0.1" @@ -951,7 +1062,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", + "serde", + "serde_core", ] [[package]] @@ -1103,6 +1216,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.177" @@ -1140,6 +1259,17 @@ dependencies = [ "redox_syscall", ] +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.11.0" @@ -1995,6 +2125,20 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88f8660c1ff60292143c98d08fc6e2f654d722db50410e3f3797d40baaf9d8f3" +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags 2.10.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustc-hash" version = "2.1.1" @@ -2122,6 +2266,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.228" @@ -2628,6 +2778,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "unicode_categories" version = "0.1.1" @@ -2718,6 +2874,18 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" +dependencies = [ + "getrandom 0.4.1", + "js-sys", + "serde_core", + "wasm-bindgen", +] + [[package]] name = "valuable" version = "0.1.1" @@ -2738,10 +2906,11 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "voxtype" -version = "0.5.6" +version = "0.6.0" dependencies = [ "anyhow", "async-trait", + "chrono", "clap", "clap_mangen", "cpal", @@ -2760,6 +2929,8 @@ dependencies = [ "pidlock", "regex", "rodio", + "rusqlite", + "rustfft", "serde", "serde_json", "tempfile", @@ -2770,6 +2941,7 @@ dependencies = [ "tracing", "tracing-subscriber", "ureq 2.12.1", + "uuid", 
"which", "whisper-rs", ] @@ -2796,7 +2968,16 @@ version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.46.0", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", ] [[package]] @@ -2857,6 +3038,40 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags 2.10.0", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "web-sys" version = "0.3.82" @@ -2965,7 +3180,7 @@ version = "0.54.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9252e5725dbed82865af151df558e754e4a3c2c30818359eb17465f1346a1b49" dependencies = [ - "windows-core", + "windows-core 0.54.0", "windows-targets 0.52.6", ] @@ -2975,10 +3190,45 @@ version = "0.54.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12661b9c89351d684a50a8a643ce5f608e20243b9fb84687800163429f161d65" dependencies = [ - "windows-result", + 
"windows-result 0.1.2", "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result 0.4.1", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" @@ -2994,6 +3244,24 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.45.0" @@ -3303,6 +3571,94 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + 
"wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags 2.10.0", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + [[package]] name = "writeable" version = "0.6.2" diff --git a/Cargo.toml b/Cargo.toml index 9b78bfbe..227bce88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["xtask"] [package] name = "voxtype" -version = "0.5.6" +version = 
"0.6.0" edition = "2021" authors = ["Peter Jackson", "Jean-Paul van Tillo", "Máté Rémiás", "Rob Zolkos", "Dan Heuckeroth", "Igor Warzocha", "Julian Kaiser", "Kevin Miller", "konnsim", "reisset", "Zubair", "Loki Coyote", "Umesh", "Barrett Ruth", "André Silva", "Chmouel Boudjnah", "Christopher Albert", "Phuoc Thinh Vu", "Alexander Bosu-Kellett", "ayoahha", "Toizi", "kakapt"] description = "Push-to-talk voice-to-text for Wayland" @@ -69,10 +69,11 @@ whisper-rs = "0.15.1" # Parakeet speech-to-text (optional, ONNX-based) parakeet-rs = { version = "0.3", optional = true } -# Moonshine speech-to-text (optional, ONNX-based encoder-decoder ASR) +# ONNX-based ASR engines (Moonshine, SenseVoice, Paraformer, Dolphin, Omnilingual) ort = { version = "2.0.0-rc.11", optional = true } ndarray = { version = "0.16", optional = true } tokenizers = { version = "0.20", optional = true, default-features = false, features = ["onig"] } +rustfft = { version = "6", optional = true } # CPU count for thread detection num_cpus = "1.16" @@ -83,12 +84,19 @@ notify = "6" # Single instance check pidlock = "0.1" +# Meeting mode (Pro feature) +uuid = { version = "1", features = ["v4", "serde"] } +chrono = { version = "0.4", features = ["serde"] } +rusqlite = { version = "0.32", features = ["bundled"] } + [features] default = [] gpu-vulkan = ["whisper-rs/vulkan"] gpu-cuda = ["whisper-rs/cuda"] gpu-metal = ["whisper-rs/metal"] gpu-hipblas = ["whisper-rs/hipblas"] +# ML-based speaker diarization (uses ONNX for embedding extraction) +ml-diarization = ["dep:ort", "dep:ndarray"] # Parakeet backend (ONNX-based, alternative to Whisper) parakeet = ["dep:parakeet-rs"] parakeet-cuda = ["parakeet", "parakeet-rs/cuda"] @@ -96,10 +104,28 @@ parakeet-tensorrt = ["parakeet", "parakeet-rs/tensorrt"] parakeet-rocm = ["parakeet", "parakeet-rs/rocm"] # Dynamic loading for system ONNX Runtime (used by Nix builds) parakeet-load-dynamic = ["parakeet", "parakeet-rs/load-dynamic"] +# Shared ONNX dependencies for engines 
using fbank/CTC preprocessing +onnx-common = ["dep:ort", "dep:ndarray", "dep:rustfft"] # Moonshine backend (ONNX-based, encoder-decoder ASR) -moonshine = ["dep:ort", "dep:ndarray", "dep:tokenizers"] +moonshine = ["onnx-common", "dep:tokenizers"] moonshine-cuda = ["moonshine", "ort/cuda"] moonshine-tensorrt = ["moonshine", "ort/tensorrt"] +# SenseVoice backend (ONNX-based, CTC encoder-only ASR) +sensevoice = ["onnx-common"] +sensevoice-cuda = ["sensevoice", "ort/cuda"] +sensevoice-tensorrt = ["sensevoice", "ort/tensorrt"] +# Paraformer backend (FunASR ONNX-based CTC encoder) +paraformer = ["onnx-common"] +paraformer-cuda = ["paraformer", "ort/cuda"] +paraformer-tensorrt = ["paraformer", "ort/tensorrt"] +# Dolphin backend (ONNX-based CTC encoder, dictation-optimized) +dolphin = ["onnx-common"] +dolphin-cuda = ["dolphin", "ort/cuda"] +dolphin-tensorrt = ["dolphin", "ort/tensorrt"] +# Omnilingual backend (FunASR ONNX-based, 50+ languages) +omnilingual = ["onnx-common"] +omnilingual-cuda = ["omnilingual", "ort/cuda"] +omnilingual-tensorrt = ["omnilingual", "ort/tensorrt"] [build-dependencies] clap = { version = "4", features = ["derive"] } diff --git a/Dockerfile.parakeet b/Dockerfile.onnx similarity index 66% rename from Dockerfile.parakeet rename to Dockerfile.onnx index fc368ed8..d41223ef 100644 --- a/Dockerfile.parakeet +++ b/Dockerfile.onnx @@ -1,16 +1,17 @@ -# Build environment for voxtype with Parakeet (ONNX) support - AVX2 compatible +# Build environment for voxtype with ONNX Runtime support - AVX2 compatible # -# Builds voxtype with --features parakeet for NVIDIA Parakeet ASR model support. -# Uses Ubuntu 24.04 for glibc 2.39+ (required by ONNX Runtime prebuilt binaries). +# Builds voxtype with all ONNX-based engines (Parakeet, Moonshine, SenseVoice, +# Paraformer, Dolphin, Omnilingual). Uses Ubuntu 24.04 for glibc 2.39+ +# (required by ONNX Runtime prebuilt binaries). # # NOTE: Building on i9-9900K (no AVX-512) ensures wide CPU compatibility. 
# Output binary uses AVX2 instructions for good performance on modern CPUs. # # Usage: -# docker build -f Dockerfile.parakeet -t voxtype-parakeet-avx2-builder . -# docker run --rm -v $(pwd)/releases:/output voxtype-parakeet-avx2-builder +# docker build -f Dockerfile.onnx -t voxtype-onnx-avx2-builder . +# docker run --rm -v $(pwd)/releases:/output voxtype-onnx-avx2-builder # -# For CUDA support, use Dockerfile.parakeet-cuda instead. +# For CUDA support, use Dockerfile.onnx-cuda instead. # For AVX-512 support, build on an AVX-512 capable host. # FROM ubuntu:24.04 @@ -60,17 +61,17 @@ ENV GGML_AVX512_VNNI=OFF ENV ORT_STRATEGY=download # Disable LTO for faster builds (can hang on TrueNAS) -# Build with parakeet feature -RUN cargo build --release --features parakeet,moonshine \ +# Build with all ONNX engines +RUN cargo build --release --features parakeet,moonshine,sensevoice,paraformer,dolphin,omnilingual \ --config 'profile.release.lto=false' \ --config 'profile.release.codegen-units=8' \ - && cp target/release/voxtype /tmp/voxtype-parakeet-avx2 + && cp target/release/voxtype /tmp/voxtype-onnx-avx2 # Verify binary -RUN echo "=== Verifying Parakeet AVX2 binary ===" \ - && /tmp/voxtype-parakeet-avx2 --version \ +RUN echo "=== Verifying ONNX AVX2 binary ===" \ + && /tmp/voxtype-onnx-avx2 --version \ && echo "=== Checking for AVX-512 instructions ===" \ - && zmm_count=$(objdump -d /tmp/voxtype-parakeet-avx2 | grep -c zmm || echo 0) \ + && zmm_count=$(objdump -d /tmp/voxtype-onnx-avx2 | grep -c zmm || echo 0) \ && echo " zmm registers: $zmm_count" \ && if [ "$zmm_count" -gt 0 ]; then \ echo "WARNING: Binary contains AVX-512 instructions"; \ @@ -78,10 +79,10 @@ RUN echo "=== Verifying Parakeet AVX2 binary ===" \ echo "✓ Binary is clean (no AVX-512)"; \ fi \ && echo "=== Binary size ===" \ - && ls -lh /tmp/voxtype-parakeet-avx2 + && ls -lh /tmp/voxtype-onnx-avx2 # Output stage CMD mkdir -p /output \ - && cp /tmp/voxtype-parakeet-avx2 
/output/voxtype-${VERSION}-linux-x86_64-parakeet-avx2 \ + && cp /tmp/voxtype-onnx-avx2 /output/voxtype-${VERSION}-linux-x86_64-onnx-avx2 \ && echo "Binary copied to /output:" \ && ls -la /output/voxtype-* diff --git a/Dockerfile.parakeet-cuda b/Dockerfile.onnx-cuda similarity index 60% rename from Dockerfile.parakeet-cuda rename to Dockerfile.onnx-cuda index 61946c2f..f0785f9b 100644 --- a/Dockerfile.parakeet-cuda +++ b/Dockerfile.onnx-cuda @@ -1,11 +1,11 @@ -# Build environment for voxtype with Parakeet CUDA support +# Build environment for voxtype with ONNX CUDA support # -# Builds voxtype with --features parakeet-cuda for NVIDIA GPU acceleration. +# Builds voxtype with all ONNX engines + CUDA for NVIDIA GPU acceleration. # Uses NVIDIA CUDA base image with cuDNN for optimal ONNX Runtime performance. # # Usage: -# docker build -f Dockerfile.parakeet-cuda -t voxtype-parakeet-cuda-builder . -# docker run --rm -v $(pwd)/releases:/output voxtype-parakeet-cuda-builder +# docker build -f Dockerfile.onnx-cuda -t voxtype-onnx-cuda-builder . +# docker run --rm -v $(pwd)/releases:/output voxtype-onnx-cuda-builder # # The resulting binary requires: # - NVIDIA GPU with CUDA support @@ -48,22 +48,22 @@ COPY . . 
ENV ORT_STRATEGY=download # Disable LTO for faster builds -# Build with parakeet-cuda feature for NVIDIA GPU support -RUN cargo build --release --features parakeet-cuda,moonshine \ +# Build with all ONNX engines + CUDA support for NVIDIA GPUs +RUN cargo build --release --features parakeet-cuda,moonshine-cuda,sensevoice-cuda,paraformer-cuda,dolphin-cuda,omnilingual-cuda \ --config 'profile.release.lto=false' \ --config 'profile.release.codegen-units=8' \ - && cp target/release/voxtype /tmp/voxtype-parakeet-cuda + && cp target/release/voxtype /tmp/voxtype-onnx-cuda # Verify binary -RUN echo "=== Verifying Parakeet CUDA binary ===" \ - && /tmp/voxtype-parakeet-cuda --version \ +RUN echo "=== Verifying ONNX CUDA binary ===" \ + && /tmp/voxtype-onnx-cuda --version \ && echo "=== Binary size ===" \ - && ls -lh /tmp/voxtype-parakeet-cuda \ + && ls -lh /tmp/voxtype-onnx-cuda \ && echo "=== Checking CUDA libraries linked ===" \ - && ldd /tmp/voxtype-parakeet-cuda | grep -iE 'cuda|cudnn|onnx' || echo "Note: CUDA libs may be loaded dynamically at runtime" + && ldd /tmp/voxtype-onnx-cuda | grep -iE 'cuda|cudnn|onnx' || echo "Note: CUDA libs may be loaded dynamically at runtime" # Output stage CMD mkdir -p /output \ - && cp /tmp/voxtype-parakeet-cuda /output/voxtype-${VERSION}-linux-x86_64-parakeet-cuda \ + && cp /tmp/voxtype-onnx-cuda /output/voxtype-${VERSION}-linux-x86_64-onnx-cuda \ && echo "Binary copied to /output:" \ && ls -la /output/voxtype-* diff --git a/docker-compose.build.yml b/docker-compose.build.yml index 4e247361..35ca4c1d 100644 --- a/docker-compose.build.yml +++ b/docker-compose.build.yml @@ -50,11 +50,11 @@ services: profiles: - avx512 # Only build if explicitly requested (requires AVX-512 capable host) - # Parakeet AVX2 build (ONNX Runtime bundled, wide compatibility) - parakeet-avx2: + # ONNX AVX2 build (Parakeet + Moonshine, ONNX Runtime bundled, wide compatibility) + onnx-avx2: build: context: . 
- dockerfile: Dockerfile.parakeet + dockerfile: Dockerfile.onnx args: VERSION: ${VERSION:-0.5.0-beta.3} volumes: @@ -62,11 +62,11 @@ services: environment: - VERSION=${VERSION:-0.5.0-beta.3} - # Parakeet CUDA build (NVIDIA GPU acceleration with CPU fallback) - parakeet-cuda: + # ONNX CUDA build (NVIDIA GPU acceleration with CPU fallback) + onnx-cuda: build: context: . - dockerfile: Dockerfile.parakeet-cuda + dockerfile: Dockerfile.onnx-cuda args: VERSION: ${VERSION:-0.5.0-beta.3} volumes: diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 9f023fa8..4d1cb5ac 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -39,7 +39,7 @@ voxtype --engine parakeet daemon ``` **Notes:** -- Parakeet requires a Parakeet-enabled binary (`voxtype-*-parakeet-*`) +- Parakeet requires an ONNX-enabled binary (`voxtype-*-onnx-*`) - When using Parakeet, you must also configure the `[parakeet]` section - When using Moonshine, you must also configure the `[moonshine]` section - See [PARAKEET.md](PARAKEET.md) for detailed Parakeet setup instructions diff --git a/docs/PARAKEET.md b/docs/PARAKEET.md index 025aba89..c1813a11 100644 --- a/docs/PARAKEET.md +++ b/docs/PARAKEET.md @@ -17,19 +17,19 @@ Parakeet is NVIDIA's FastConformer-based speech recognition model. The TDT (Toke ## Requirements -- A Parakeet-enabled voxtype binary (see below) +- An ONNX-enabled voxtype binary (see below) - ~600MB disk space for the model - CPU with AVX2 or AVX-512 (AVX-512 recommended for best performance) ## Getting a Parakeet Binary -Parakeet support requires a specially compiled binary. Download from the releases page: +Parakeet support requires an ONNX-enabled binary. 
Download from the releases page: | Binary | Use Case | |--------|----------| -| `voxtype-*-parakeet-avx2` | Most CPUs (Intel Haswell+, AMD Zen+) | -| `voxtype-*-parakeet-avx512` | Modern CPUs with AVX-512 (Intel Ice Lake+, AMD Zen 4+) | -| `voxtype-*-parakeet-cuda` | NVIDIA GPU acceleration with CPU fallback | +| `voxtype-*-onnx-avx2` | Most CPUs (Intel Haswell+, AMD Zen+) | +| `voxtype-*-onnx-avx512` | Modern CPUs with AVX-512 (Intel Ice Lake+, AMD Zen 4+) | +| `voxtype-*-onnx-cuda` | NVIDIA GPU acceleration with CPU fallback | The AVX2 binary works on most modern x86_64 CPUs. Use AVX-512 if your CPU supports it for better performance. @@ -61,19 +61,19 @@ cd ~/.local/share/voxtype/models/parakeet-tdt-0.6b-v3 ## Switching to a Parakeet Binary -The standard voxtype binary does not include Parakeet support. You must switch to a Parakeet-enabled binary. +The standard voxtype binary does not include Parakeet support. You must switch to an ONNX-enabled binary. **Manual switching (until `voxtype setup engine` is implemented):** ```bash # Download the Parakeet binary for your CPU # Example: AVX-512 capable CPU -curl -L https://github.com/peteonrails/voxtype/releases/download/v0.5.0/voxtype-0.5.0-linux-x86_64-parakeet-avx512 \ - -o /tmp/voxtype-parakeet +curl -L https://github.com/peteonrails/voxtype/releases/download/v0.5.0/voxtype-0.5.0-linux-x86_64-onnx-avx512 \ + -o /tmp/voxtype-onnx # Make executable and install -chmod +x /tmp/voxtype-parakeet -sudo mv /tmp/voxtype-parakeet /usr/local/bin/voxtype +chmod +x /tmp/voxtype-onnx +sudo mv /tmp/voxtype-onnx /usr/local/bin/voxtype # Restart the daemon systemctl --user restart voxtype @@ -173,7 +173,7 @@ Or simply remove the `engine` line (Whisper is the default). ### "Parakeet engine requested but voxtype was not compiled with --features parakeet" -You're using a standard voxtype binary without Parakeet support. Download a `parakeet-*` binary from the releases page. 
+You're using a standard voxtype binary without Parakeet support. Download an `onnx-*` binary from the releases page. ### "Parakeet engine selected but [parakeet] config section is missing" diff --git a/docs/SMOKE_TESTS.md b/docs/SMOKE_TESTS.md index eb180573..b1da6788 100644 --- a/docs/SMOKE_TESTS.md +++ b/docs/SMOKE_TESTS.md @@ -652,6 +652,94 @@ voxtype transcribe /path/to/audio.wav voxtype transcribe --model large-v3-turbo /path/to/audio.wav ``` +## Multi-Engine Transcription + +Tests each available transcription engine with a WAV file. Use `tests/fixtures/vad/speech_long.wav` (English) or `tests/fixtures/sensevoice/zh.wav` (Chinese) as test audio. Each engine must be compiled in (check `voxtype --version` or build features). + +### Engine Quick Test + +```bash +# Test audio paths +EN_AUDIO="tests/fixtures/vad/speech_long.wav" +ZH_AUDIO="tests/fixtures/sensevoice/zh.wav" + +# Whisper (always available) +voxtype transcribe --engine whisper "$EN_AUDIO" + +# Parakeet (requires --features parakeet) +voxtype transcribe --engine parakeet "$EN_AUDIO" + +# Moonshine (requires --features moonshine) +voxtype transcribe --engine moonshine "$EN_AUDIO" + +# SenseVoice (requires --features sensevoice) +voxtype transcribe --engine sensevoice "$EN_AUDIO" +voxtype transcribe --engine sensevoice "$ZH_AUDIO" + +# Paraformer (requires --features paraformer, English and Chinese models) +voxtype transcribe --engine paraformer "$EN_AUDIO" + +# Dolphin (requires --features dolphin, Eastern languages only, no English) +voxtype transcribe --engine dolphin "$ZH_AUDIO" + +# Omnilingual (requires --features omnilingual, 1600+ languages) +voxtype transcribe --engine omnilingual "$EN_AUDIO" +``` + +### Engine Daemon Integration + +Test each engine running as the daemon's active engine: + +```bash +# For each engine, update config.toml engine = "" and restart: + +# SenseVoice +# 1. Set engine = "sensevoice" in config.toml +# 2. Restart daemon +systemctl --user restart voxtype +# 3. 
Verify model loads +journalctl --user -u voxtype --since "10 seconds ago" | grep -iE "sensevoice|loading" +# 4. Record and transcribe +voxtype record start && sleep 3 && voxtype record stop +# 5. Check logs for correct engine +journalctl --user -u voxtype --since "30 seconds ago" | grep -i "transcri" + +# Repeat for: paraformer, dolphin, omnilingual, moonshine, parakeet +# Then restore engine = "whisper" when done +``` + +### Engine Error Handling + +```bash +# Request an engine that isn't compiled in (should give clear error) +# e.g., if built without --features dolphin: +voxtype transcribe --engine dolphin tests/fixtures/vad/speech_long.wav +# Expected: error about Dolphin not being compiled in + +# Request unknown engine +voxtype transcribe --engine nonexistent tests/fixtures/vad/speech_long.wav +# Expected: error listing valid engine names + +# Engine with missing model +# (temporarily rename model dir to simulate missing model) +mv ~/.local/share/voxtype/models/sensevoice-small{,.bak} +voxtype transcribe --engine sensevoice tests/fixtures/vad/speech_long.wav +# Expected: clear error with "Run: voxtype setup model" +mv ~/.local/share/voxtype/models/sensevoice-small{.bak,} +``` + +### Engine Performance Comparison + +```bash +# Compare transcription speed across engines for the same audio file +AUDIO="tests/fixtures/vad/speech_long.wav" + +for engine in whisper parakeet moonshine sensevoice paraformer omnilingual; do + echo -n "$engine: " + /usr/bin/time -f "%e seconds" voxtype transcribe --engine $engine "$AUDIO" 2>&1 | tail -1 +done +``` + ## Multilingual Model Verification Tests that non-.en models load correctly and detect language: @@ -708,13 +796,13 @@ ls -la /usr/bin/voxtype # Should point to voxtype-vulkan sudo voxtype setup gpu --disable # Switch to best CPU (avx512 or avx2) ls -la /usr/bin/voxtype # Should point to voxtype-avx512 or voxtype-avx2 -# Parakeet mode (symlink points to voxtype-parakeet-*) -# --enable switches to CUDA, --disable switches to 
best Parakeet CPU -sudo ln -sf /usr/lib/voxtype/voxtype-parakeet-avx512 /usr/bin/voxtype -sudo voxtype setup gpu --enable # Switch to Parakeet CUDA -ls -la /usr/bin/voxtype # Should point to voxtype-parakeet-cuda -sudo voxtype setup gpu --disable # Switch to best Parakeet CPU -ls -la /usr/bin/voxtype # Should point to voxtype-parakeet-avx512 +# ONNX mode (symlink points to voxtype-onnx-*) +# --enable switches to CUDA, --disable switches to best ONNX CPU +sudo ln -sf /usr/lib/voxtype/voxtype-onnx-avx512 /usr/bin/voxtype +sudo voxtype setup gpu --enable # Switch to ONNX CUDA +ls -la /usr/bin/voxtype # Should point to voxtype-onnx-cuda +sudo voxtype setup gpu --disable # Switch to best ONNX CPU +ls -la /usr/bin/voxtype # Should point to voxtype-onnx-avx512 # Restore to Whisper Vulkan for normal use sudo ln -sf /usr/lib/voxtype/voxtype-vulkan /usr/bin/voxtype @@ -824,7 +912,7 @@ voxtype setup parakeet # Enable Parakeet (switches symlink to best parakeet binary) sudo voxtype setup parakeet --enable -ls -la /usr/bin/voxtype # Should point to voxtype-parakeet-cuda or voxtype-parakeet-avx* +ls -la /usr/bin/voxtype # Should point to voxtype-onnx-cuda or voxtype-onnx-avx* # Disable Parakeet (switches back to equivalent Whisper binary) sudo voxtype setup parakeet --disable @@ -1026,3 +1114,298 @@ echo "" echo "Check logs:" journalctl --user -u voxtype --since "30 seconds ago" --no-pager | tail -10 ``` + +## Meeting Mode + +Meeting mode provides continuous transcription with speaker attribution, export, and AI summarization. These tests cover the CLI commands and daemon integration. 
+ +### Meeting Lifecycle + +```bash +# Start a meeting +voxtype meeting start --title "Test Meeting" +# Expected: "Meeting started: " in output + +# Check status +voxtype meeting status +# Expected: shows Active meeting with title, duration, chunk count + +# Pause the meeting +voxtype meeting pause +voxtype meeting status +# Expected: shows Paused status + +# Resume the meeting +voxtype meeting resume +voxtype meeting status +# Expected: shows Active status again + +# Stop the meeting +voxtype meeting stop +voxtype meeting status +# Expected: shows Completed status or "No active meeting" + +# Verify in logs +journalctl --user -u voxtype --since "2 minutes ago" | grep -i meeting +``` + +### Meeting List and Show + +```bash +# List meetings (should include the one just created) +voxtype meeting list +# Expected: table with ID, title, date, duration, status + +# Show details of the most recent meeting +voxtype meeting show latest +# Expected: full metadata and transcript + +# Show by UUID (copy from list output) +voxtype meeting show +``` + +### Meeting Export + +```bash +# Export as plain text +voxtype meeting export latest --format text +# Expected: plain text transcript output + +# Export as markdown +voxtype meeting export latest --format markdown +# Expected: markdown with headers and speaker labels + +# Export as JSON +voxtype meeting export latest --format json +# Expected: structured JSON with metadata and segments + +# Export to file +voxtype meeting export latest --format markdown --output /tmp/meeting-export.md +cat /tmp/meeting-export.md + +# Export with options +voxtype meeting export latest --format text --timestamps --speakers +``` + +### Meeting Delete + +```bash +# Delete a meeting (use UUID from list) +voxtype meeting delete +# Expected: "Meeting deleted" confirmation + +# Verify deletion +voxtype meeting list +# Expected: deleted meeting no longer appears +``` + +### Speaker Labels + +```bash +# Start a meeting and record some audio +voxtype meeting 
start --title "Label Test" +sleep 10 +voxtype meeting stop + +# Assign speaker labels +voxtype meeting label latest SPEAKER_00 "Alice" +voxtype meeting label latest SPEAKER_01 "Bob" + +# Verify labels appear in show output +voxtype meeting show latest +# Expected: speaker labels show as "Alice", "Bob" instead of SPEAKER_00/01 + +# Verify labels persist in export +voxtype meeting export latest --format text --speakers +``` + +### AI Summarization + +```bash +# Requires: Ollama running locally, or a remote summarization endpoint configured + +# Summarize the latest meeting +voxtype meeting summarize latest +# Expected: summary with key points, action items, and decisions + +# Check logs for summarization +journalctl --user -u voxtype --since "1 minute ago" | grep -i summar +``` + +### Meeting Without Title + +```bash +# Start without a title (should auto-generate one from the date) +voxtype meeting start +sleep 5 +voxtype meeting stop + +# Verify auto-generated title in list +voxtype meeting list +# Expected: title like "Meeting 2026-02-16 14:30" +``` + +### Rapid Start/Stop + +```bash +# Stress test: quick meeting cycles +for i in {1..3}; do + echo "Meeting cycle $i..." 
+ voxtype meeting start --title "Quick $i" + sleep 2 + voxtype meeting stop +done + +# Verify all meetings were saved +voxtype meeting list +# Expected: 3 new meetings in the list + +# Verify daemon is healthy +voxtype status +``` + +### Meeting During Active Recording + +```bash +# Verify meeting mode and push-to-talk don't conflict +voxtype meeting start --title "Conflict Test" +sleep 2 + +# Try a push-to-talk recording while meeting is active +voxtype record start +sleep 2 +voxtype record stop +# Expected: either clear error or both work independently + +voxtype meeting stop +``` + +### Meeting Config Validation + +```bash +# Verify meeting config is shown +voxtype config | grep -A20 "\[meeting\]" +# Expected: meeting section with audio, storage, diarization settings + +# Test with custom chunk duration (edit config.toml): +# [meeting.audio] +# chunk_duration_secs = 15 + +# Restart and verify +systemctl --user restart voxtype +voxtype meeting start --title "Custom Chunk" +sleep 20 +voxtype meeting stop +journalctl --user -u voxtype --since "1 minute ago" | grep -i chunk +# Expected: chunks processed at 15-second intervals +``` + +### Storage Verification + +```bash +# Check where meetings are stored +ls ~/.local/share/voxtype/meetings/ +# Expected: directories named like "2026-02-16-test-meeting" + +# Verify SQLite index +ls ~/.local/share/voxtype/meetings/index.db +# Expected: file exists + +# Verify transcript files +ls ~/.local/share/voxtype/meetings/*/transcript.json +# Expected: JSON files for completed meetings + +# Verify metadata files +cat ~/.local/share/voxtype/meetings/*/metadata.json | head -20 +# Expected: valid JSON with meeting metadata +``` + +### Error Handling + +```bash +# Double-start (meeting already in progress) +voxtype meeting start --title "First" +voxtype meeting start --title "Second" +# Expected: error "Meeting already in progress" +voxtype meeting stop + +# Pause when no meeting active +voxtype meeting pause +# Expected: error "No 
active meeting to pause" + +# Resume when no meeting paused +voxtype meeting resume +# Expected: error "No paused meeting to resume" + +# Stop when no meeting active +voxtype meeting stop +# Expected: error "No meeting in progress" + +# Show nonexistent meeting +voxtype meeting show 00000000-0000-0000-0000-000000000000 +# Expected: error "Meeting not found" + +# Export with invalid format +voxtype meeting export latest --format invalid +# Expected: error about unsupported format + +# Export with invalid meeting ID +voxtype meeting export not-a-uuid --format text +# Expected: error about invalid meeting ID + +# Label nonexistent meeting +voxtype meeting label 00000000-0000-0000-0000-000000000000 SPEAKER_00 "Alice" +# Expected: error "Meeting not found" +``` + +### Dual Audio Sources + +```bash +# Verify loopback detection +# 1. Configure loopback in config.toml: +# [meeting.audio] +# loopback_device = "auto" + +# 2. Start a meeting while in a video call (Zoom, Teams, etc.) +voxtype meeting start --title "Video Call Test" + +# 3. Speak into mic and wait for remote participants to speak +sleep 30 +voxtype meeting stop + +# 4. Check speaker attribution +voxtype meeting show latest +# Expected: segments attributed to "You" (mic) and "Remote" (loopback) + +# 5. Verify export includes speaker labels +voxtype meeting export latest --format text --speakers +# Expected: "You:" and "Remote:" labels in output + +# Disable loopback (mic-only mode) +# [meeting.audio] +# loopback_device = "disabled" +systemctl --user restart voxtype +voxtype meeting start --title "Mic Only Test" +sleep 10 +voxtype meeting stop +voxtype meeting show latest +# Expected: all segments attributed to "You" or "Unknown" +``` + +### Diarization Backend Selection + +```bash +# Simple diarization (default, source-based) +voxtype config | grep -A5 "diarization" +# Expected: backend = "simple" + +# ML diarization (requires ml-diarization feature) +# 1. 
Configure in config.toml: +# [meeting.diarization] +# backend = "ml" +# max_speakers = 4 +# 2. Restart and verify +systemctl --user restart voxtype +journalctl --user -u voxtype --since "10 seconds ago" | grep -i diariz +# Expected: "Using ML diarization" or "falling back to simple" if model missing +``` diff --git a/docs/USER_MANUAL.md b/docs/USER_MANUAL.md index 55ea6c4e..730d4550 100644 --- a/docs/USER_MANUAL.md +++ b/docs/USER_MANUAL.md @@ -593,7 +593,7 @@ Parakeet is NVIDIA's FastConformer-based ASR model. It offers: - No GPU required (though CUDA acceleration available) **Requirements:** -- A Parakeet-enabled binary (`voxtype-*-parakeet-*`) +- An ONNX-enabled binary (`voxtype-*-onnx-*`) - The Parakeet model downloaded (~600MB) - English-only use case diff --git a/scripts/package.sh b/scripts/package.sh index cc82fca8..1ea737c5 100755 --- a/scripts/package.sh +++ b/scripts/package.sh @@ -386,22 +386,22 @@ if [[ "$TARGET_ARCH" == "x86_64" ]]; then chmod 755 "$STAGING/usr/lib/voxtype/voxtype-avx512" chmod 755 "$STAGING/usr/lib/voxtype/voxtype-vulkan" - # x86_64: Parakeet binaries (ONNX-based alternative engine) - if [[ -f "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-parakeet-avx2" ]]; then - cp "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-parakeet-avx2" "$STAGING/usr/lib/voxtype/voxtype-parakeet-avx2" - chmod 755 "$STAGING/usr/lib/voxtype/voxtype-parakeet-avx2" + # x86_64: ONNX binaries (Parakeet + Moonshine engines via ONNX Runtime) + if [[ -f "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-onnx-avx2" ]]; then + cp "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-onnx-avx2" "$STAGING/usr/lib/voxtype/voxtype-onnx-avx2" + chmod 755 "$STAGING/usr/lib/voxtype/voxtype-onnx-avx2" fi - if [[ -f "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-parakeet-avx512" ]]; then - cp "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-parakeet-avx512" "$STAGING/usr/lib/voxtype/voxtype-parakeet-avx512" - chmod 755 "$STAGING/usr/lib/voxtype/voxtype-parakeet-avx512" + if [[ -f 
"${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-onnx-avx512" ]]; then + cp "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-onnx-avx512" "$STAGING/usr/lib/voxtype/voxtype-onnx-avx512" + chmod 755 "$STAGING/usr/lib/voxtype/voxtype-onnx-avx512" fi - if [[ -f "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-parakeet-cuda" ]]; then - cp "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-parakeet-cuda" "$STAGING/usr/lib/voxtype/voxtype-parakeet-cuda" - chmod 755 "$STAGING/usr/lib/voxtype/voxtype-parakeet-cuda" + if [[ -f "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-onnx-cuda" ]]; then + cp "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-onnx-cuda" "$STAGING/usr/lib/voxtype/voxtype-onnx-cuda" + chmod 755 "$STAGING/usr/lib/voxtype/voxtype-onnx-cuda" fi - if [[ -f "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-parakeet-rocm" ]]; then - cp "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-parakeet-rocm" "$STAGING/usr/lib/voxtype/voxtype-parakeet-rocm" - chmod 755 "$STAGING/usr/lib/voxtype/voxtype-parakeet-rocm" + if [[ -f "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-onnx-rocm" ]]; then + cp "${RELEASE_DIR}/voxtype-${VERSION}-linux-x86_64-onnx-rocm" "$STAGING/usr/lib/voxtype/voxtype-onnx-rocm" + chmod 755 "$STAGING/usr/lib/voxtype/voxtype-onnx-rocm" fi # Install wrapper script as /usr/bin/voxtype diff --git a/scripts/validate-release.sh b/scripts/validate-release.sh index 3d3defb6..388e98ab 100755 --- a/scripts/validate-release.sh +++ b/scripts/validate-release.sh @@ -33,10 +33,10 @@ WHISPER_BINARIES=( "voxtype-${VERSION}-linux-x86_64-vulkan" ) -PARAKEET_BINARIES=( - "voxtype-${VERSION}-linux-x86_64-parakeet-avx2" - "voxtype-${VERSION}-linux-x86_64-parakeet-avx512" - "voxtype-${VERSION}-linux-x86_64-parakeet-cuda" +ONNX_BINARIES=( + "voxtype-${VERSION}-linux-x86_64-onnx-avx2" + "voxtype-${VERSION}-linux-x86_64-onnx-avx512" + "voxtype-${VERSION}-linux-x86_64-onnx-cuda" ) # Binaries that must NOT have AVX-512 instructions @@ -48,14 +48,14 @@ MUST_BE_CLEAN=( # Binaries that 
MUST have AVX-512 instructions MUST_HAVE_AVX512=( "voxtype-${VERSION}-linux-x86_64-avx512" - "voxtype-${VERSION}-linux-x86_64-parakeet-avx512" + "voxtype-${VERSION}-linux-x86_64-onnx-avx512" ) FAILED=false # 1. Check all binaries exist echo "Checking binary existence..." -ALL_BINARIES=("${WHISPER_BINARIES[@]}" "${PARAKEET_BINARIES[@]}") +ALL_BINARIES=("${WHISPER_BINARIES[@]}" "${ONNX_BINARIES[@]}") FOUND_BINARIES=() for binary in "${ALL_BINARIES[@]}"; do diff --git a/src/audio/cpal_capture.rs b/src/audio/cpal_capture.rs index 4a96eae4..e6e838c3 100644 --- a/src/audio/cpal_capture.rs +++ b/src/audio/cpal_capture.rs @@ -16,6 +16,8 @@ use tokio::sync::{mpsc, oneshot}; /// Commands sent to the audio capture thread enum CaptureCommand { Stop(oneshot::Sender>), + /// Get current samples and clear the buffer (for continuous recording) + GetSamples(oneshot::Sender>), } /// Parameters for building an audio input stream @@ -238,19 +240,37 @@ impl AudioCapture for CpalCapture { tracing::debug!("Audio capture thread started"); - // Wait for stop command - if let Ok(CaptureCommand::Stop(response_tx)) = cmd_rx.recv() { - // Stop the stream (drop it) - drop(stream); - - // Get collected samples - let collected = { - let guard = samples_clone.lock().unwrap(); - guard.clone() - }; - - // Send samples back - let _ = response_tx.send(collected); + // Handle commands in a loop + loop { + match cmd_rx.recv() { + Ok(CaptureCommand::Stop(response_tx)) => { + // Stop the stream (drop it) + drop(stream); + + // Get collected samples + let collected = { + let guard = samples_clone.lock().unwrap(); + guard.clone() + }; + + // Send samples back + let _ = response_tx.send(collected); + break; + } + Ok(CaptureCommand::GetSamples(response_tx)) => { + // Get and clear current samples (for continuous recording) + let samples = { + let mut guard = samples_clone.lock().unwrap(); + std::mem::take(&mut *guard) + }; + let _ = response_tx.send(samples); + } + Err(_) => { + // Channel closed, exit thread 
+ tracing::debug!("Command channel closed"); + break; + } + } } tracing::debug!("Audio capture thread stopped"); @@ -301,6 +321,28 @@ impl AudioCapture for CpalCapture { Ok(samples) } + + async fn get_samples(&mut self) -> Vec { + // Get current samples without stopping + if let Some(ref cmd_tx) = self.cmd_tx { + let (response_tx, response_rx) = oneshot::channel(); + + if cmd_tx.send(CaptureCommand::GetSamples(response_tx)).is_ok() { + // Wait for response (with short timeout) + match tokio::time::timeout(std::time::Duration::from_millis(500), response_rx).await + { + Ok(Ok(samples)) => return samples, + Ok(Err(_)) => { + tracing::warn!("get_samples: channel closed"); + } + Err(_) => { + tracing::warn!("get_samples: timeout"); + } + } + } + } + Vec::new() + } } /// Build an input stream for a specific sample type diff --git a/src/audio/dual_capture.rs b/src/audio/dual_capture.rs new file mode 100644 index 00000000..7ade8460 --- /dev/null +++ b/src/audio/dual_capture.rs @@ -0,0 +1,294 @@ +//! Dual audio capture for meeting mode +//! +//! Captures both microphone input (user's voice) and system audio loopback +//! (remote participants) simultaneously for speaker attribution. 
+ +use super::cpal_capture::CpalCapture; +use super::AudioCapture; +use crate::config::AudioConfig; +use crate::error::AudioError; + +/// Audio source identifier +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AudioSourceType { + /// Microphone input (local user) + Microphone, + /// System audio loopback (remote participants) + Loopback, +} + +/// A sample with its source identified +#[derive(Debug, Clone)] +pub struct SourcedSample { + /// The audio sample value + pub sample: f32, + /// Which source this sample came from + pub source: AudioSourceType, + /// Timestamp in samples (at 16kHz) + pub timestamp: u64, +} + +/// Dual audio capture for mic + loopback +pub struct DualCapture { + /// Microphone capture + mic_capture: CpalCapture, + /// Loopback capture (system audio) + loopback_capture: Option, + /// Whether loopback is enabled + loopback_enabled: bool, + /// Sample counter for timestamps + sample_counter: u64, +} + +impl DualCapture { + /// Create a new dual capture instance + pub fn new( + mic_config: &AudioConfig, + loopback_device: Option<&str>, + ) -> Result { + let mic_capture = CpalCapture::new(mic_config)?; + + // Try to create loopback capture if device is specified + let (loopback_capture, loopback_enabled) = if let Some(device) = loopback_device { + if device == "auto" { + // Try to find a monitor/loopback device + match Self::find_loopback_device() { + Some(device_name) => { + let mut loopback_config = mic_config.clone(); + loopback_config.device = device_name; + match CpalCapture::new(&loopback_config) { + Ok(capture) => { + tracing::info!("Loopback capture enabled"); + (Some(capture), true) + } + Err(e) => { + tracing::warn!("Failed to create loopback capture: {}", e); + (None, false) + } + } + } + None => { + tracing::warn!("No loopback device found, using mic only"); + (None, false) + } + } + } else if device == "disabled" || device.is_empty() { + (None, false) + } else { + // Use specified device + let mut loopback_config = 
mic_config.clone(); + loopback_config.device = device.to_string(); + match CpalCapture::new(&loopback_config) { + Ok(capture) => { + tracing::info!("Loopback capture enabled: {}", device); + (Some(capture), true) + } + Err(e) => { + tracing::warn!("Failed to create loopback capture for '{}': {}", device, e); + (None, false) + } + } + } + } else { + (None, false) + }; + + Ok(Self { + mic_capture, + loopback_capture, + loopback_enabled, + sample_counter: 0, + }) + } + + /// Try to find a loopback/monitor device automatically + fn find_loopback_device() -> Option { + use cpal::traits::{DeviceTrait, HostTrait}; + + let host = cpal::default_host(); + let devices = host.input_devices().ok()?; + + for device in devices { + if let Ok(name) = device.name() { + let name_lower = name.to_lowercase(); + // Common loopback device name patterns + if name_lower.contains("monitor") + || name_lower.contains("loopback") + || name_lower.contains("stereo mix") + || name_lower.contains("what u hear") + { + tracing::debug!("Found loopback device: {}", name); + return Some(name); + } + } + } + + None + } + + /// Check if loopback capture is active + pub fn has_loopback(&self) -> bool { + self.loopback_enabled && self.loopback_capture.is_some() + } + + /// Start both captures + pub async fn start(&mut self) -> Result<(), AudioError> { + // Start mic capture + let _mic_rx = self.mic_capture.start().await?; + + // Start loopback capture if available + if let Some(ref mut loopback) = self.loopback_capture { + let _loopback_rx = loopback.start().await?; + } + + Ok(()) + } + + /// Stop both captures and return all samples + pub async fn stop(&mut self) -> Result { + let mic_samples = self.mic_capture.stop().await?; + + let loopback_samples = if let Some(ref mut loopback) = self.loopback_capture { + loopback.stop().await.unwrap_or_default() + } else { + Vec::new() + }; + + Ok(DualSamples { + mic: mic_samples, + loopback: loopback_samples, + }) + } + + /// Get current samples without stopping 
(for continuous recording) + pub async fn get_samples(&mut self) -> DualSamples { + let mic = self.mic_capture.get_samples().await; + + let loopback = if let Some(ref mut loopback) = self.loopback_capture { + loopback.get_samples().await + } else { + Vec::new() + }; + + DualSamples { mic, loopback } + } + + /// Get sourced samples with timestamps for diarization + pub async fn get_sourced_samples(&mut self) -> Vec { + let dual = self.get_samples().await; + let mut result = Vec::with_capacity(dual.mic.len() + dual.loopback.len()); + + // Add mic samples + for sample in dual.mic { + result.push(SourcedSample { + sample, + source: AudioSourceType::Microphone, + timestamp: self.sample_counter, + }); + self.sample_counter += 1; + } + + // Add loopback samples (interleaved timing approximation) + for sample in dual.loopback { + result.push(SourcedSample { + sample, + source: AudioSourceType::Loopback, + timestamp: self.sample_counter, + }); + self.sample_counter += 1; + } + + result + } +} + +/// Samples from both sources +#[derive(Debug, Clone, Default)] +pub struct DualSamples { + /// Microphone samples + pub mic: Vec, + /// Loopback samples + pub loopback: Vec, +} + +impl DualSamples { + /// Check if there are any samples + pub fn is_empty(&self) -> bool { + self.mic.is_empty() && self.loopback.is_empty() + } + + /// Total number of samples + pub fn len(&self) -> usize { + self.mic.len() + self.loopback.len() + } + + /// Merge samples into a single stream (for transcription) + /// Prioritizes mic when both have audio, mixes otherwise + pub fn merge(&self) -> Vec { + if self.loopback.is_empty() { + return self.mic.clone(); + } + if self.mic.is_empty() { + return self.loopback.clone(); + } + + // Mix both streams + let max_len = self.mic.len().max(self.loopback.len()); + let mut merged = Vec::with_capacity(max_len); + + for i in 0..max_len { + let mic_sample = self.mic.get(i).copied().unwrap_or(0.0); + let loopback_sample = self.loopback.get(i).copied().unwrap_or(0.0); 
+ // Simple mix with slight preference to mic + merged.push(mic_sample * 0.6 + loopback_sample * 0.4); + } + + merged + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dual_samples_merge_mic_only() { + let samples = DualSamples { + mic: vec![0.5, 0.6, 0.7], + loopback: vec![], + }; + assert_eq!(samples.merge(), vec![0.5, 0.6, 0.7]); + } + + #[test] + fn test_dual_samples_merge_loopback_only() { + let samples = DualSamples { + mic: vec![], + loopback: vec![0.3, 0.4], + }; + assert_eq!(samples.merge(), vec![0.3, 0.4]); + } + + #[test] + fn test_dual_samples_merge_both() { + let samples = DualSamples { + mic: vec![1.0, 1.0], + loopback: vec![1.0, 1.0], + }; + let merged = samples.merge(); + // 1.0 * 0.6 + 1.0 * 0.4 = 1.0 + assert!((merged[0] - 1.0).abs() < 0.001); + } + + #[test] + fn test_dual_samples_is_empty() { + let empty = DualSamples::default(); + assert!(empty.is_empty()); + + let with_mic = DualSamples { + mic: vec![0.1], + loopback: vec![], + }; + assert!(!with_mic.is_empty()); + } +} diff --git a/src/audio/mod.rs b/src/audio/mod.rs index 373d388c..988cd9ad 100644 --- a/src/audio/mod.rs +++ b/src/audio/mod.rs @@ -4,8 +4,11 @@ //! PipeWire, PulseAudio, and ALSA backends. pub mod cpal_capture; +pub mod dual_capture; pub mod feedback; +pub use dual_capture::{AudioSourceType, DualCapture, DualSamples, SourcedSample}; + use crate::config::AudioConfig; use crate::error::AudioError; use tokio::sync::mpsc; @@ -19,6 +22,11 @@ pub trait AudioCapture: Send + Sync { /// Stop capturing and return all recorded samples async fn stop(&mut self) -> Result, AudioError>; + + /// Get current samples without stopping (for continuous recording modes) + /// This drains the internal buffer and returns samples collected since the last call. + /// Returns an empty Vec if not yet started or already stopped. 
+ async fn get_samples(&mut self) -> Vec; } /// Factory function to create audio capture diff --git a/src/cli.rs b/src/cli.rs index c6f5e06c..c7db99a2 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -26,7 +26,7 @@ EXAMPLES: voxtype setup model Interactive model selection (Whisper, Parakeet, or Moonshine) voxtype setup waybar Show Waybar integration config voxtype setup gpu Manage GPU acceleration (Vulkan/CUDA/ROCm) - voxtype setup parakeet Switch between Whisper and Parakeet engines + voxtype setup onnx Switch between Whisper and ONNX engines voxtype status --follow --format json Waybar integration See 'voxtype --help' for more info on a command. @@ -68,7 +68,7 @@ pub struct Cli { #[arg(long, value_name = "PROMPT")] pub initial_prompt: Option, - /// Override transcription engine: "whisper" (default), "parakeet", or "moonshine" (EXPERIMENTAL) + /// Override transcription engine: whisper, parakeet, moonshine, sensevoice, paraformer, dolphin, omnilingual #[arg(long, value_name = "ENGINE")] pub engine: Option, @@ -129,7 +129,7 @@ pub enum Commands { /// Path to audio file file: std::path::PathBuf, - /// Override transcription engine: "whisper", "parakeet", or "moonshine" + /// Override transcription engine: whisper, parakeet, moonshine, sensevoice, paraformer, dolphin, omnilingual #[arg(long, value_name = "ENGINE")] engine: Option, }, @@ -206,6 +206,15 @@ pub enum Commands { #[command(subcommand)] action: RecordAction, }, + + /// Meeting transcription mode (Pro feature) + /// + /// Continuous meeting transcription with chunked processing, + /// speaker attribution, and export capabilities. 
+ Meeting { + #[command(subcommand)] + action: MeetingAction, + }, } /// Output mode override for record commands @@ -293,6 +302,100 @@ pub enum RecordAction { Cancel, } +/// Meeting mode actions +#[derive(Subcommand)] +pub enum MeetingAction { + /// Start a new meeting transcription + Start { + /// Meeting title (optional) + #[arg(long, short)] + title: Option, + }, + /// Stop the current meeting + Stop, + /// Pause the current meeting + Pause, + /// Resume a paused meeting + Resume, + /// Show meeting status + Status, + /// List past meetings + List { + /// Maximum number of meetings to show + #[arg(long, short, default_value = "10")] + limit: u32, + }, + /// Export a meeting transcript + Export { + /// Meeting ID (or "latest" for most recent) + meeting_id: String, + + /// Output format: text, markdown, json + #[arg(long, short, default_value = "markdown")] + format: String, + + /// Output file path (default: stdout) + #[arg(long, short)] + output: Option, + + /// Include timestamps in output + #[arg(long)] + timestamps: bool, + + /// Include speaker labels in output + #[arg(long)] + speakers: bool, + + /// Include metadata header in output + #[arg(long)] + metadata: bool, + }, + /// Show meeting details + Show { + /// Meeting ID (or "latest" for most recent) + meeting_id: String, + }, + /// Delete a meeting + Delete { + /// Meeting ID + meeting_id: String, + + /// Skip confirmation prompt + #[arg(long, short)] + force: bool, + }, + /// Label a speaker in a meeting transcript + /// + /// Assigns a human-readable name to an auto-generated speaker ID. + /// Use with ML diarization to replace "SPEAKER_00" with "Alice". 
+ Label { + /// Meeting ID (or "latest" for most recent) + meeting_id: String, + + /// Speaker ID to label (e.g., "SPEAKER_00" or just "0") + speaker_id: String, + + /// Human-readable label to assign + label: String, + }, + /// Generate an AI summary of a meeting + /// + /// Uses Ollama or a remote API to generate a summary with + /// key points, action items, and decisions. + Summarize { + /// Meeting ID (or "latest" for most recent) + meeting_id: String, + + /// Output format: text, json, or markdown + #[arg(long, short, default_value = "markdown")] + format: String, + + /// Output file path (default: stdout) + #[arg(long, short)] + output: Option, + }, +} + impl RecordAction { /// Extract the output mode override from the action flags /// Returns (mode_override, optional_file_path) @@ -445,17 +548,30 @@ pub enum SetupAction { status: bool, }, - /// Switch between Whisper and Parakeet transcription engines + /// Switch between Whisper and ONNX transcription engines + Onnx { + /// Enable ONNX engine (switch to ONNX binary) + #[arg(long)] + enable: bool, + + /// Disable ONNX engine (switch back to Whisper binary) + #[arg(long)] + disable: bool, + + /// Show current ONNX backend status + #[arg(long)] + status: bool, + }, + + /// Hidden alias for 'onnx' (backwards compatibility) + #[command(hide = true)] Parakeet { - /// Enable Parakeet engine (switch to Parakeet binary) #[arg(long)] enable: bool, - /// Disable Parakeet engine (switch back to Whisper binary) #[arg(long)] disable: bool, - /// Show current Parakeet backend status #[arg(long)] status: bool, }, diff --git a/src/config.rs b/src/config.rs index 7e976d97..c896bf28 100644 --- a/src/config.rs +++ b/src/config.rs @@ -295,6 +295,22 @@ pub struct Config { #[serde(default)] pub moonshine: Option, + /// SenseVoice configuration (optional, only used when engine = "sensevoice") + #[serde(default)] + pub sensevoice: Option, + + /// Paraformer configuration (optional, only used when engine = "paraformer") + 
#[serde(default)] + pub paraformer: Option, + + /// Dolphin configuration (optional, only used when engine = "dolphin") + #[serde(default)] + pub dolphin: Option, + + /// Omnilingual configuration (optional, only used when engine = "omnilingual") + #[serde(default)] + pub omnilingual: Option, + /// Text processing configuration (replacements, spoken punctuation) #[serde(default)] pub text: TextConfig, @@ -308,6 +324,10 @@ pub struct Config { #[serde(default)] pub status: StatusConfig, + /// Meeting transcription configuration + #[serde(default)] + pub meeting: MeetingConfig, + /// Optional path to state file for external integrations (e.g., Waybar) /// When set, the daemon writes current state ("idle", "recording", "transcribing") /// to this file whenever state changes. @@ -953,6 +973,127 @@ impl Default for MoonshineConfig { } } +/// SenseVoice speech-to-text configuration (ONNX-based, CTC encoder-only ASR) +/// Requires: cargo build --features sensevoice +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct SenseVoiceConfig { + /// Model name or path to directory containing ONNX model files + /// Expects: model.int8.onnx (or model.onnx), tokens.txt + /// Short name: "sensevoice-small" (default) + pub model: String, + + /// Language for transcription: "auto", "zh", "en", "ja", "ko", "yue" (default: "auto") + #[serde(default = "default_sensevoice_language")] + pub language: String, + + /// Enable inverse text normalization (adds punctuation) (default: true) + #[serde(default = "default_true")] + pub use_itn: bool, + + /// Number of CPU threads for ONNX Runtime inference + #[serde(default)] + pub threads: Option, + + /// Load model on-demand when recording starts (true) or keep loaded (false) + #[serde(default = "default_on_demand_loading")] + pub on_demand_loading: bool, +} + +fn default_sensevoice_language() -> String { + "auto".to_string() +} + +impl Default for SenseVoiceConfig { + fn default() -> Self { + Self { + model: "sensevoice-small".to_string(), 
+ language: "auto".to_string(), + use_itn: true, + threads: None, + on_demand_loading: false, + } + } +} + +/// Paraformer speech-to-text configuration (FunASR ONNX-based CTC encoder) +/// Requires: cargo build --features paraformer +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ParaformerConfig { + /// Model name or path to ONNX model directory + /// Expects: model.onnx (or model.int8.onnx), tokens.txt + pub model: String, + + /// Number of CPU threads for ONNX Runtime inference + #[serde(default)] + pub threads: Option, + + /// Load model on-demand when recording starts (true) or keep loaded (false) + #[serde(default = "default_on_demand_loading")] + pub on_demand_loading: bool, +} + +impl Default for ParaformerConfig { + fn default() -> Self { + Self { + model: "paraformer-zh".to_string(), + threads: None, + on_demand_loading: false, + } + } +} + +/// Dolphin speech-to-text configuration (ONNX-based CTC encoder, dictation-optimized) +/// Requires: cargo build --features dolphin +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct DolphinConfig { + /// Model name or path to ONNX model directory + pub model: String, + + /// Number of CPU threads for ONNX Runtime inference + #[serde(default)] + pub threads: Option, + + /// Load model on-demand when recording starts (true) or keep loaded (false) + #[serde(default = "default_on_demand_loading")] + pub on_demand_loading: bool, +} + +impl Default for DolphinConfig { + fn default() -> Self { + Self { + model: "dolphin-base".to_string(), + threads: None, + on_demand_loading: false, + } + } +} + +/// Omnilingual speech-to-text configuration (FunASR ONNX-based, 50+ languages) +/// Requires: cargo build --features omnilingual +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct OmnilingualConfig { + /// Model name or path to ONNX model directory + pub model: String, + + /// Number of CPU threads for ONNX Runtime inference + #[serde(default)] + pub threads: Option, + + /// Load model on-demand 
when recording starts (true) or keep loaded (false) + #[serde(default = "default_on_demand_loading")] + pub on_demand_loading: bool, +} + +impl Default for OmnilingualConfig { + fn default() -> Self { + Self { + model: "omnilingual-large".to_string(), + threads: None, + on_demand_loading: false, + } + } +} + /// Transcription engine selection (which ASR technology to use) #[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)] #[serde(rename_all = "lowercase")] @@ -966,6 +1107,18 @@ pub enum TranscriptionEngine { /// Use Moonshine (encoder-decoder ASR via ONNX Runtime) /// Requires: cargo build --features moonshine Moonshine, + /// Use SenseVoice (Alibaba FunAudioLLM CTC model via ONNX Runtime) + /// Requires: cargo build --features sensevoice + SenseVoice, + /// Use Paraformer (FunASR CTC encoder via ONNX Runtime) + /// Requires: cargo build --features paraformer + Paraformer, + /// Use Dolphin (dictation-optimized CTC encoder via ONNX Runtime) + /// Requires: cargo build --features dolphin + Dolphin, + /// Use Omnilingual (FunASR 50+ language CTC encoder via ONNX Runtime) + /// Requires: cargo build --features omnilingual + Omnilingual, } /// VAD backend selection @@ -1054,6 +1207,191 @@ pub struct TextConfig { pub replacements: HashMap, } +/// Meeting transcription configuration +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct MeetingConfig { + /// Enable meeting mode + #[serde(default)] + pub enabled: bool, + + /// Duration of each audio chunk in seconds + #[serde(default = "default_chunk_duration")] + pub chunk_duration_secs: u32, + + /// Storage path for meetings ("auto" for default location) + /// Default: ~/.local/share/voxtype/meetings/ + #[serde(default = "default_storage_path")] + pub storage_path: String, + + /// Retain raw audio files after transcription + #[serde(default)] + pub retain_audio: bool, + + /// Maximum meeting duration in minutes (0 = unlimited) + #[serde(default = "default_max_duration")] + pub 
max_duration_mins: u32, + + /// Meeting audio configuration + #[serde(default)] + pub audio: MeetingAudioConfig, + + /// Diarization configuration + #[serde(default)] + pub diarization: MeetingDiarizationConfig, + + /// Summarization configuration + #[serde(default)] + pub summary: MeetingSummaryConfig, +} + +/// Meeting audio configuration for dual capture +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct MeetingAudioConfig { + /// Microphone device (uses main audio.device if not specified) + #[serde(default = "default_mic_device")] + pub mic_device: String, + + /// Loopback device for capturing remote participants + /// Options: "auto" (detect), "disabled", or specific device name + #[serde(default = "default_loopback")] + pub loopback_device: String, +} + +fn default_mic_device() -> String { + "default".to_string() +} + +fn default_loopback() -> String { + "auto".to_string() +} + +impl Default for MeetingAudioConfig { + fn default() -> Self { + Self { + mic_device: default_mic_device(), + loopback_device: default_loopback(), + } + } +} + +/// Meeting diarization configuration +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct MeetingDiarizationConfig { + /// Enable speaker diarization + #[serde(default = "default_true")] + pub enabled: bool, + + /// Diarization backend: "simple", "ml", or "remote" + #[serde(default = "default_diarization_backend")] + pub backend: String, + + /// Maximum number of speakers to detect + #[serde(default = "default_max_speakers")] + pub max_speakers: u32, +} + +fn default_diarization_backend() -> String { + "simple".to_string() +} + +fn default_max_speakers() -> u32 { + 10 +} + +fn default_chunk_duration() -> u32 { + 30 +} + +fn default_storage_path() -> String { + "auto".to_string() +} + +fn default_max_duration() -> u32 { + 180 +} + +impl Default for MeetingDiarizationConfig { + fn default() -> Self { + Self { + enabled: true, + backend: default_diarization_backend(), + max_speakers: default_max_speakers(), + } 
+ } +} + +/// Meeting summary configuration (Phase 5) +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct MeetingSummaryConfig { + /// Summarization backend: "local", "remote", or "disabled" + #[serde(default = "default_summary_backend")] + pub backend: String, + + /// Ollama URL for local backend + #[serde(default = "default_ollama_url")] + pub ollama_url: String, + + /// Ollama model name + #[serde(default = "default_ollama_model")] + pub ollama_model: String, + + /// Remote API endpoint for remote backend + #[serde(default)] + pub remote_endpoint: Option, + + /// Remote API key + #[serde(default)] + pub remote_api_key: Option, + + /// Request timeout in seconds + #[serde(default = "default_summary_timeout")] + pub timeout_secs: u64, +} + +fn default_summary_backend() -> String { + "disabled".to_string() +} + +fn default_ollama_url() -> String { + "http://localhost:11434".to_string() +} + +fn default_ollama_model() -> String { + "llama3.2".to_string() +} + +fn default_summary_timeout() -> u64 { + 120 +} + +impl Default for MeetingSummaryConfig { + fn default() -> Self { + Self { + backend: default_summary_backend(), + ollama_url: default_ollama_url(), + ollama_model: default_ollama_model(), + remote_endpoint: None, + remote_api_key: None, + timeout_secs: default_summary_timeout(), + } + } +} + +impl Default for MeetingConfig { + fn default() -> Self { + Self { + enabled: false, + chunk_duration_secs: default_chunk_duration(), + storage_path: default_storage_path(), + retain_audio: false, + max_duration_mins: default_max_duration(), + audio: MeetingAudioConfig::default(), + diarization: MeetingDiarizationConfig::default(), + summary: MeetingSummaryConfig::default(), + } + } +} + /// Notification configuration #[derive(Debug, Clone, Deserialize, Serialize)] pub struct NotificationConfig { @@ -1399,9 +1737,14 @@ impl Default for Config { engine: TranscriptionEngine::default(), parakeet: None, moonshine: None, + sensevoice: None, + paraformer: None, + 
dolphin: None, + omnilingual: None, text: TextConfig::default(), vad: VadConfig::default(), status: StatusConfig::default(), + meeting: MeetingConfig::default(), state_file: Some("auto".to_string()), profiles: HashMap::new(), } @@ -1486,6 +1829,26 @@ impl Config { .as_ref() .map(|m| m.on_demand_loading) .unwrap_or(false), + TranscriptionEngine::SenseVoice => self + .sensevoice + .as_ref() + .map(|s| s.on_demand_loading) + .unwrap_or(false), + TranscriptionEngine::Paraformer => self + .paraformer + .as_ref() + .map(|p| p.on_demand_loading) + .unwrap_or(false), + TranscriptionEngine::Dolphin => self + .dolphin + .as_ref() + .map(|d| d.on_demand_loading) + .unwrap_or(false), + TranscriptionEngine::Omnilingual => self + .omnilingual + .as_ref() + .map(|o| o.on_demand_loading) + .unwrap_or(false), } } @@ -1503,6 +1866,26 @@ impl Config { .as_ref() .map(|m| m.model.as_str()) .unwrap_or("moonshine (not configured)"), + TranscriptionEngine::SenseVoice => self + .sensevoice + .as_ref() + .map(|s| s.model.as_str()) + .unwrap_or("sensevoice (not configured)"), + TranscriptionEngine::Paraformer => self + .paraformer + .as_ref() + .map(|p| p.model.as_str()) + .unwrap_or("paraformer (not configured)"), + TranscriptionEngine::Dolphin => self + .dolphin + .as_ref() + .map(|d| d.model.as_str()) + .unwrap_or("dolphin (not configured)"), + TranscriptionEngine::Omnilingual => self + .omnilingual + .as_ref() + .map(|o| o.model.as_str()) + .unwrap_or("omnilingual (not configured)"), } } @@ -2846,4 +3229,161 @@ mod tests { assert_eq!(driver_order.len(), 1); assert_eq!(driver_order[0], OutputDriver::Ydotool); } + + // ========================================================================= + // Meeting Config Tests + // ========================================================================= + + #[test] + fn test_meeting_config_default() { + let config = MeetingConfig::default(); + assert!(!config.enabled); + assert_eq!(config.chunk_duration_secs, 30); + assert_eq!(config.storage_path, 
"auto"); + assert!(!config.retain_audio); + assert_eq!(config.max_duration_mins, 180); + } + + #[test] + fn test_meeting_audio_config_default() { + let config = MeetingAudioConfig::default(); + assert_eq!(config.mic_device, "default"); + assert_eq!(config.loopback_device, "auto"); + } + + #[test] + fn test_meeting_diarization_config_default() { + let config = MeetingDiarizationConfig::default(); + assert!(config.enabled); + assert_eq!(config.backend, "simple"); + assert_eq!(config.max_speakers, 10); + } + + #[test] + fn test_meeting_summary_config_default() { + let config = MeetingSummaryConfig::default(); + assert_eq!(config.backend, "disabled"); + assert_eq!(config.ollama_url, "http://localhost:11434"); + assert_eq!(config.ollama_model, "llama3.2"); + assert!(config.remote_endpoint.is_none()); + assert!(config.remote_api_key.is_none()); + assert_eq!(config.timeout_secs, 120); + } + + #[test] + fn test_meeting_config_in_default_config() { + let config = Config::default(); + assert!(!config.meeting.enabled); + assert_eq!(config.meeting.chunk_duration_secs, 30); + assert_eq!(config.meeting.max_duration_mins, 180); + } + + #[test] + fn test_parse_meeting_config_from_toml() { + let toml_str = r#" + [hotkey] + key = "SCROLLLOCK" + + [audio] + device = "default" + sample_rate = 16000 + max_duration_secs = 60 + + [whisper] + model = "base.en" + language = "en" + + [output] + mode = "type" + + [meeting] + enabled = true + chunk_duration_secs = 45 + storage_path = "/tmp/meetings" + retain_audio = true + max_duration_mins = 60 + "#; + + let config: Config = toml::from_str(toml_str).unwrap(); + assert!(config.meeting.enabled); + assert_eq!(config.meeting.chunk_duration_secs, 45); + assert_eq!(config.meeting.storage_path, "/tmp/meetings"); + assert!(config.meeting.retain_audio); + assert_eq!(config.meeting.max_duration_mins, 60); + } + + #[test] + fn test_parse_meeting_config_with_nested_sections() { + let toml_str = r#" + [hotkey] + key = "SCROLLLOCK" + + [audio] + device = 
"default" + sample_rate = 16000 + max_duration_secs = 60 + + [whisper] + model = "base.en" + language = "en" + + [output] + mode = "type" + + [meeting] + enabled = true + + [meeting.audio] + mic_device = "hw:1" + loopback_device = "disabled" + + [meeting.diarization] + enabled = false + backend = "ml" + max_speakers = 5 + + [meeting.summary] + backend = "local" + ollama_model = "mistral" + timeout_secs = 60 + "#; + + let config: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(config.meeting.audio.mic_device, "hw:1"); + assert_eq!(config.meeting.audio.loopback_device, "disabled"); + assert!(!config.meeting.diarization.enabled); + assert_eq!(config.meeting.diarization.backend, "ml"); + assert_eq!(config.meeting.diarization.max_speakers, 5); + assert_eq!(config.meeting.summary.backend, "local"); + assert_eq!(config.meeting.summary.ollama_model, "mistral"); + assert_eq!(config.meeting.summary.timeout_secs, 60); + } + + #[test] + fn test_meeting_config_backward_compatible_omitted() { + // Config without [meeting] section should parse fine with defaults + let toml_str = r#" + [hotkey] + key = "SCROLLLOCK" + + [audio] + device = "default" + sample_rate = 16000 + max_duration_secs = 60 + + [whisper] + model = "base.en" + language = "en" + + [output] + mode = "type" + "#; + + let config: Config = toml::from_str(toml_str).unwrap(); + assert!(!config.meeting.enabled); + assert_eq!(config.meeting.chunk_duration_secs, 30); + assert_eq!(config.meeting.storage_path, "auto"); + assert_eq!(config.meeting.diarization.backend, "simple"); + assert_eq!(config.meeting.summary.backend, "disabled"); + } } diff --git a/src/cpu.rs b/src/cpu.rs index c061af29..a67ecd3f 100644 --- a/src/cpu.rs +++ b/src/cpu.rs @@ -36,7 +36,7 @@ pub fn install_sigill_handler() { } unsafe { - libc::signal(libc::SIGILL, sigill_handler as libc::sighandler_t); + libc::signal(libc::SIGILL, sigill_handler as *const () as libc::sighandler_t); } } diff --git a/src/daemon.rs b/src/daemon.rs index 
02c49776..1079365a 100644 --- a/src/daemon.rs +++ b/src/daemon.rs @@ -9,6 +9,7 @@ use crate::config::{ActivationMode, Config, FileMode, OutputMode}; use crate::eager::{self, EagerConfig}; use crate::error::Result; use crate::hotkey::{self, HotkeyEvent}; +use crate::meeting::{self, MeetingDaemon, MeetingEvent, StorageConfig}; use crate::model_manager::ModelManager; use crate::output; use crate::output::post_process::PostProcessor; @@ -230,6 +231,91 @@ fn cleanup_profile_override() { let _ = std::fs::remove_file(&profile_file); } +// === Meeting Mode IPC === + +/// Check for meeting start command (via file trigger) +fn check_meeting_start() -> Option> { + let start_file = Config::runtime_dir().join("meeting_start"); + if start_file.exists() { + // Read optional title from file content + let title = std::fs::read_to_string(&start_file).ok().and_then(|s| { + let trimmed = s.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } + }); + // Remove the file to acknowledge the command + let _ = std::fs::remove_file(&start_file); + Some(title) + } else { + None + } +} + +/// Check for meeting stop command (via file trigger) +fn check_meeting_stop() -> bool { + let stop_file = Config::runtime_dir().join("meeting_stop"); + if stop_file.exists() { + let _ = std::fs::remove_file(&stop_file); + true + } else { + false + } +} + +/// Check for meeting pause command (via file trigger) +fn check_meeting_pause() -> bool { + let pause_file = Config::runtime_dir().join("meeting_pause"); + if pause_file.exists() { + let _ = std::fs::remove_file(&pause_file); + true + } else { + false + } +} + +/// Check for meeting resume command (via file trigger) +fn check_meeting_resume() -> bool { + let resume_file = Config::runtime_dir().join("meeting_resume"); + if resume_file.exists() { + let _ = std::fs::remove_file(&resume_file); + true + } else { + false + } +} + +/// Clean up any stale meeting command files on startup +fn cleanup_meeting_files() { + let 
runtime_dir = Config::runtime_dir(); + for name in &[ + "meeting_start", + "meeting_stop", + "meeting_pause", + "meeting_resume", + ] { + let file = runtime_dir.join(name); + if file.exists() { + let _ = std::fs::remove_file(&file); + } + } +} + +/// Write meeting state file for external integrations +fn write_meeting_state_file(path: &PathBuf, state: &str, meeting_id: Option<&str>) { + let content = if let Some(id) = meeting_id { + format!("{}\n{}", state, id) + } else { + state.to_string() + }; + + if let Err(e) = std::fs::write(path, content) { + tracing::warn!("Failed to write meeting state file: {}", e); + } +} + /// Write transcription to a file, respecting file_mode (overwrite or append) async fn write_transcription_to_file( path: &std::path::Path, @@ -334,6 +420,16 @@ pub struct Daemon { )>, // Voice Activity Detection (filters silence-only recordings) vad: Option>, + // Meeting mode daemon (optional, created when meeting starts) + meeting_daemon: Option, + // Meeting state file path + meeting_state_file_path: Option, + // Audio capture for meeting mode (continuous recording) + meeting_audio_capture: Option>, + // Chunk buffer for meeting mode + meeting_chunk_buffer: Vec, + // Meeting event receiver + meeting_event_rx: Option>, } impl Daemon { @@ -401,6 +497,13 @@ impl Daemon { } }; + // Meeting state file path (separate from push-to-talk state) + let meeting_state_file_path = if state_file_path.is_some() { + Some(Config::runtime_dir().join("meeting_state")) + } else { + None + }; + Self { config, config_path, @@ -414,6 +517,11 @@ impl Daemon { transcription_task: None, eager_chunk_tasks: Vec::new(), vad, + meeting_daemon: None, + meeting_state_file_path, + meeting_audio_capture: None, + meeting_chunk_buffer: Vec::new(), + meeting_event_rx: None, } } @@ -470,7 +578,11 @@ impl Daemon { // Use preloaded transcriber based on engine type match self.config.engine { crate::config::TranscriptionEngine::Parakeet - | crate::config::TranscriptionEngine::Moonshine => 
{ + | crate::config::TranscriptionEngine::Moonshine + | crate::config::TranscriptionEngine::SenseVoice + | crate::config::TranscriptionEngine::Paraformer + | crate::config::TranscriptionEngine::Dolphin + | crate::config::TranscriptionEngine::Omnilingual => { if let Some(ref t) = transcriber_preloaded { Ok(t.clone()) } else { @@ -499,6 +611,191 @@ impl Daemon { } } + /// Update the meeting state file if configured + fn update_meeting_state(&self, state_name: &str, meeting_id: Option<&str>) { + if let Some(ref path) = self.meeting_state_file_path { + write_meeting_state_file(path, state_name, meeting_id); + } + } + + /// Start a new meeting + async fn start_meeting(&mut self, title: Option) -> Result<()> { + if self.meeting_daemon.is_some() { + tracing::warn!("Meeting already in progress"); + return Ok(()); + } + + // Create meeting config from main config + let meeting_config = meeting::MeetingConfig { + enabled: self.config.meeting.enabled, + chunk_duration_secs: self.config.meeting.chunk_duration_secs, + storage: StorageConfig { + storage_path: if self.config.meeting.storage_path == "auto" { + Config::data_dir().join("meetings") + } else { + PathBuf::from(&self.config.meeting.storage_path) + }, + retain_audio: self.config.meeting.retain_audio, + max_meetings: 0, + }, + retain_audio: self.config.meeting.retain_audio, + max_duration_mins: self.config.meeting.max_duration_mins, + }; + + // Create event channel + let (tx, rx) = tokio::sync::mpsc::channel(32); + self.meeting_event_rx = Some(rx); + + // Create meeting daemon + match MeetingDaemon::new(meeting_config, &self.config, tx) { + Ok(mut daemon) => { + match daemon.start(title).await { + Ok(meeting_id) => { + let id_str = meeting_id.to_string(); + self.update_meeting_state("recording", Some(&id_str)); + tracing::info!("Meeting started: {}", meeting_id); + + // Start audio capture for meeting + match audio::create_capture(&self.config.audio) { + Ok(mut capture) => { + if let Err(e) = capture.start().await { + 
tracing::error!("Failed to start meeting audio: {}", e); + let _ = daemon.stop().await; + return Err(crate::error::VoxtypeError::Audio(e)); + } + self.meeting_audio_capture = Some(capture); + } + Err(e) => { + tracing::error!("Failed to create meeting audio capture: {}", e); + let _ = daemon.stop().await; + return Err(crate::error::VoxtypeError::Audio(e)); + } + } + + self.meeting_daemon = Some(daemon); + self.meeting_chunk_buffer.clear(); + + // Play feedback + self.play_feedback(SoundEvent::RecordingStart); + + // Notification + if self.config.output.notification.on_recording_start { + send_notification( + "Meeting Started", + &format!("ID: {}", meeting_id), + false, + self.config.engine, + ) + .await; + } + } + Err(e) => { + tracing::error!("Failed to start meeting: {}", e); + return Err(e); + } + } + } + Err(e) => { + tracing::error!("Failed to create meeting daemon: {}", e); + return Err(e); + } + } + + Ok(()) + } + + /// Stop the current meeting + async fn stop_meeting(&mut self) -> Result<()> { + if let Some(mut daemon) = self.meeting_daemon.take() { + // Stop audio capture + if let Some(mut capture) = self.meeting_audio_capture.take() { + let _ = capture.stop().await; + } + + match daemon.stop().await { + Ok(meeting_id) => { + self.update_meeting_state("idle", None); + tracing::info!("Meeting stopped: {}", meeting_id); + + self.play_feedback(SoundEvent::RecordingStop); + + if self.config.output.notification.on_recording_stop { + send_notification( + "Meeting Ended", + &format!("ID: {}", meeting_id), + false, + self.config.engine, + ) + .await; + } + } + Err(e) => { + tracing::error!("Error stopping meeting: {}", e); + } + } + + self.meeting_chunk_buffer.clear(); + self.meeting_event_rx = None; + } + + Ok(()) + } + + /// Pause the current meeting + async fn pause_meeting(&mut self) -> Result<()> { + if let Some(ref mut daemon) = self.meeting_daemon { + daemon.pause().await?; + let meeting_id = daemon.current_meeting_id().map(|id| id.to_string()); + 
self.update_meeting_state("paused", meeting_id.as_deref()); + tracing::info!("Meeting paused"); + + if self.config.output.notification.on_recording_stop { + send_notification( + "Meeting Paused", + "Recording paused", + false, + self.config.engine, + ) + .await; + } + } + Ok(()) + } + + /// Resume the current meeting + async fn resume_meeting(&mut self) -> Result<()> { + if let Some(ref mut daemon) = self.meeting_daemon { + daemon.resume().await?; + let meeting_id = daemon.current_meeting_id().map(|id| id.to_string()); + self.update_meeting_state("recording", meeting_id.as_deref()); + tracing::info!("Meeting resumed"); + + if self.config.output.notification.on_recording_start { + send_notification( + "Meeting Resumed", + "Recording resumed", + false, + self.config.engine, + ) + .await; + } + } + Ok(()) + } + + /// Check if a meeting is in progress + fn meeting_active(&self) -> bool { + self.meeting_daemon + .as_ref() + .is_some_and(|d| d.state().is_active()) + } + + /// Number of audio samples per meeting chunk (sample rate * chunk duration) + fn meeting_chunk_samples(&self) -> usize { + // 16kHz sample rate * chunk duration in seconds + 16000 * self.config.meeting.chunk_duration_secs as usize + } + /// Reset state to idle and run post_output_command to reset compositor submap /// Call this when exiting from recording/transcribing without normal output flow async fn reset_to_idle(&self, state: &mut State) { @@ -791,23 +1088,23 @@ impl Daemon { if let Some(t) = transcriber { self.transcription_task = Some(tokio::task::spawn_blocking(move || t.transcribe(&samples))); - return true; + true } else { tracing::error!("No transcriber available"); self.play_feedback(SoundEvent::Error); self.reset_to_idle(state).await; - return false; + false } } Err(e) => { tracing::warn!("Recording error: {}", e); self.reset_to_idle(state).await; - return false; + false } } } else { self.reset_to_idle(state).await; - return false; + false } } @@ -1012,6 +1309,9 @@ impl Daemon { // Clean up any stale cancel file from
previous runs cleanup_cancel_file(); + // Clean up any stale meeting command files + cleanup_meeting_files(); + // Write PID file for external control via signals self.pid_file_path = write_pid_file(); @@ -1047,8 +1347,7 @@ impl Daemon { return Err(crate::error::VoxtypeError::Config(format!( "Another voxtype instance is already running (lock error: {:?})", e - )) - .into()); + ))); } } @@ -1102,7 +1401,11 @@ impl Daemon { } } crate::config::TranscriptionEngine::Parakeet - | crate::config::TranscriptionEngine::Moonshine => { + | crate::config::TranscriptionEngine::Moonshine + | crate::config::TranscriptionEngine::SenseVoice + | crate::config::TranscriptionEngine::Paraformer + | crate::config::TranscriptionEngine::Dolphin + | crate::config::TranscriptionEngine::Omnilingual => { // Parakeet/Moonshine uses its own model loading transcriber_preloaded = Some(Arc::from(crate::transcribe::create_transcriber( &self.config, @@ -1193,7 +1496,11 @@ impl Daemon { })); } crate::config::TranscriptionEngine::Parakeet - | crate::config::TranscriptionEngine::Moonshine => { + | crate::config::TranscriptionEngine::Moonshine + | crate::config::TranscriptionEngine::SenseVoice + | crate::config::TranscriptionEngine::Paraformer + | crate::config::TranscriptionEngine::Dolphin + | crate::config::TranscriptionEngine::Omnilingual => { let config = self.config.clone(); self.model_load_task = Some(tokio::task::spawn_blocking(move || { crate::transcribe::create_transcriber(&config).map(Arc::from) @@ -1212,7 +1519,11 @@ impl Daemon { } } crate::config::TranscriptionEngine::Parakeet - | crate::config::TranscriptionEngine::Moonshine => { + | crate::config::TranscriptionEngine::Moonshine + | crate::config::TranscriptionEngine::SenseVoice + | crate::config::TranscriptionEngine::Paraformer + | crate::config::TranscriptionEngine::Dolphin + | crate::config::TranscriptionEngine::Omnilingual => { if let Some(ref t) = transcriber_preloaded { let transcriber = t.clone(); tokio::task::spawn_blocking(move || 
{ @@ -1368,7 +1679,11 @@ impl Daemon { })); } crate::config::TranscriptionEngine::Parakeet - | crate::config::TranscriptionEngine::Moonshine => { + | crate::config::TranscriptionEngine::Moonshine + | crate::config::TranscriptionEngine::SenseVoice + | crate::config::TranscriptionEngine::Paraformer + | crate::config::TranscriptionEngine::Dolphin + | crate::config::TranscriptionEngine::Omnilingual => { let config = self.config.clone(); self.model_load_task = Some(tokio::task::spawn_blocking(move || { crate::transcribe::create_transcriber(&config).map(Arc::from) @@ -1387,7 +1702,11 @@ impl Daemon { } } crate::config::TranscriptionEngine::Parakeet - | crate::config::TranscriptionEngine::Moonshine => { + | crate::config::TranscriptionEngine::Moonshine + | crate::config::TranscriptionEngine::SenseVoice + | crate::config::TranscriptionEngine::Paraformer + | crate::config::TranscriptionEngine::Dolphin + | crate::config::TranscriptionEngine::Omnilingual => { if let Some(ref t) = transcriber_preloaded { let transcriber = t.clone(); tokio::task::spawn_blocking(move || { @@ -1698,7 +2017,11 @@ impl Daemon { })); } crate::config::TranscriptionEngine::Parakeet - | crate::config::TranscriptionEngine::Moonshine => { + | crate::config::TranscriptionEngine::Moonshine + | crate::config::TranscriptionEngine::SenseVoice + | crate::config::TranscriptionEngine::Paraformer + | crate::config::TranscriptionEngine::Dolphin + | crate::config::TranscriptionEngine::Omnilingual => { let config = self.config.clone(); self.model_load_task = Some(tokio::task::spawn_blocking(move || { crate::transcribe::create_transcriber(&config).map(Arc::from) @@ -1716,7 +2039,11 @@ impl Daemon { } } crate::config::TranscriptionEngine::Parakeet - | crate::config::TranscriptionEngine::Moonshine => { + | crate::config::TranscriptionEngine::Moonshine + | crate::config::TranscriptionEngine::SenseVoice + | crate::config::TranscriptionEngine::Paraformer + | crate::config::TranscriptionEngine::Dolphin + | 
crate::config::TranscriptionEngine::Omnilingual => { if let Some(ref t) = transcriber_preloaded { let transcriber = t.clone(); tokio::task::spawn_blocking(move || { @@ -1890,13 +2217,143 @@ impl Daemon { // The check interval is 500ms, so we use a counter to approximate 60s static EVICTION_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); let count = EVICTION_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if count % 120 == 0 { // 500ms * 120 = 60s + if count.is_multiple_of(120) { // 500ms * 120 = 60s if let Some(ref mut mm) = self.model_manager { mm.evict_idle_models(); } } } + // === MEETING MODE HANDLERS === + + // Poll for meeting commands (file-based IPC) + _ = tokio::time::sleep(Duration::from_millis(100)) => { + // Check for meeting start command + if let Some(title) = check_meeting_start() { + if self.config.meeting.enabled && self.meeting_daemon.is_none() { + tracing::debug!("Meeting start requested via file trigger"); + if let Err(e) = self.start_meeting(title).await { + tracing::error!("Failed to start meeting: {}", e); + } + } else if !self.config.meeting.enabled { + tracing::warn!("Meeting mode is disabled in config"); + } else { + tracing::warn!("Meeting already in progress"); + } + } + + // Check for meeting stop command + if check_meeting_stop() + && self.meeting_daemon.is_some() { + tracing::debug!("Meeting stop requested via file trigger"); + if let Err(e) = self.stop_meeting().await { + tracing::error!("Failed to stop meeting: {}", e); + } + } + + // Check for meeting pause command + if check_meeting_pause() + && self.meeting_active() { + tracing::debug!("Meeting pause requested via file trigger"); + if let Err(e) = self.pause_meeting().await { + tracing::error!("Failed to pause meeting: {}", e); + } + } + + // Check for meeting resume command + if check_meeting_resume() + && self.meeting_daemon.as_ref().is_some_and(|d| d.state().is_paused()) { + tracing::debug!("Meeting resume requested via file 
trigger"); + if let Err(e) = self.resume_meeting().await { + tracing::error!("Failed to resume meeting: {}", e); + } + } + } + + // Process meeting audio chunks + _ = tokio::time::sleep(Duration::from_millis(50)), if self.meeting_active() => { + // Get samples from the audio capture + if let Some(ref mut capture) = self.meeting_audio_capture { + // Get current samples without stopping + let samples = capture.get_samples().await; + self.meeting_chunk_buffer.extend(samples); + + // Check if we have enough samples for a chunk + let chunk_samples = self.meeting_chunk_samples(); + if self.meeting_chunk_buffer.len() >= chunk_samples { + // Extract chunk and process + let chunk: Vec<f32> = self.meeting_chunk_buffer.drain(..chunk_samples).collect(); + + if let Some(ref mut daemon) = self.meeting_daemon { + match daemon.process_chunk(chunk).await { + Ok(Some(segments)) => { + tracing::debug!("Processed meeting chunk with {} segments", segments.len()); + } + Ok(None) => { + // No segments (possibly VAD filtered) + } + Err(e) => { + tracing::error!("Error processing meeting chunk: {}", e); + } + } + } + } + } + + // Check meeting timeout + if self.config.meeting.max_duration_mins > 0 { + if let Some(ref daemon) = self.meeting_daemon { + if let Some(duration) = daemon.state().elapsed() { + let max_duration = Duration::from_secs( + self.config.meeting.max_duration_mins as u64 * 60 + ); + if duration > max_duration { + tracing::warn!("Meeting timeout ({} min limit), stopping", + self.config.meeting.max_duration_mins); + if let Err(e) = self.stop_meeting().await { + tracing::error!("Failed to stop meeting after timeout: {}", e); + } + } + } + } + } + } + + // Handle meeting events + event = async { + match self.meeting_event_rx.as_mut() { + Some(rx) => rx.recv().await, + None => std::future::pending().await, + } + }, if self.meeting_event_rx.is_some() => { + match event { + Some(MeetingEvent::Started { meeting_id }) => { + tracing::info!("Meeting event: started {}", meeting_id); + }
Some(MeetingEvent::ChunkProcessed { chunk_id, segments }) => { + tracing::debug!("Meeting event: chunk {} processed with {} segments", + chunk_id, segments.len()); + } + Some(MeetingEvent::Paused) => { + tracing::info!("Meeting event: paused"); + } + Some(MeetingEvent::Resumed) => { + tracing::info!("Meeting event: resumed"); + } + Some(MeetingEvent::Stopped { meeting_id }) => { + tracing::info!("Meeting event: stopped {}", meeting_id); + } + Some(MeetingEvent::Error(msg)) => { + tracing::error!("Meeting error: {}", msg); + } + None => { + // Channel closed + tracing::debug!("Meeting event channel closed"); + self.meeting_event_rx = None; + } + } + } + // Handle graceful shutdown (SIGINT from Ctrl+C) _ = tokio::signal::ctrl_c() => { tracing::info!("Received SIGINT, shutting down..."); @@ -1926,11 +2383,22 @@ impl Daemon { task.abort(); } + // Stop any active meeting + if self.meeting_daemon.is_some() { + tracing::info!("Stopping active meeting on shutdown"); + let _ = self.stop_meeting().await; + } + // Remove state file on shutdown if let Some(ref path) = self.state_file_path { cleanup_state_file(path); } + // Remove meeting state file on shutdown + if let Some(ref path) = self.meeting_state_file_path { + cleanup_state_file(path); + } + // Remove PID file on shutdown if let Some(ref path) = self.pid_file_path { cleanup_pid_file(path); diff --git a/src/error.rs b/src/error.rs index 139d5559..e928c579 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,6 +23,9 @@ pub enum VoxtypeError { #[error("Output error: {0}")] Output(#[from] OutputError), + #[error("Meeting error: {0}")] + Meeting(#[from] MeetingError), + #[error("IO error: {0}")] Io(#[from] std::io::Error), } @@ -142,6 +145,28 @@ pub enum OutputError { AllMethodsFailed, } +/// Errors related to meeting transcription +#[derive(Error, Debug)] +pub enum MeetingError { + #[error("Meeting already in progress")] + AlreadyInProgress, + + #[error("No meeting in progress")] + NotInProgress, + + #[error("No active 
meeting to pause")] + NotActive, + + #[error("No paused meeting to resume")] + NotPaused, + + #[error("Transcriber not initialized")] + TranscriberNotInitialized, + + #[error("Meeting storage error: {0}")] + Storage(String), +} + /// Result type alias using VoxtypeError pub type Result<T> = std::result::Result<T, VoxtypeError>; diff --git a/src/lib.rs b/src/lib.rs index 35853c2f..301291a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,6 +76,7 @@ pub mod daemon; pub mod eager; pub mod error; pub mod hotkey; +pub mod meeting; pub mod model_manager; pub mod output; pub mod setup; @@ -84,7 +85,9 @@ pub mod text; pub mod transcribe; pub mod vad; -pub use cli::{Cli, Commands, CompositorType, OutputModeOverride, RecordAction, SetupAction}; +pub use cli::{ + Cli, Commands, CompositorType, MeetingAction, OutputModeOverride, RecordAction, SetupAction, +}; pub use config::Config; pub use daemon::Daemon; pub use error::{Result, VoxtypeError}; diff --git a/src/main.rs b/src/main.rs index 7d13b953..539911f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,8 @@ use std::path::PathBuf; use std::process::Command; use tracing_subscriber::EnvFilter; use voxtype::{ - config, cpu, daemon, setup, transcribe, vad, Cli, Commands, RecordAction, SetupAction, + config, cpu, daemon, meeting, setup, transcribe, vad, Cli, Commands, MeetingAction, + RecordAction, SetupAction, }; /// Parse a comma-separated list of driver names into OutputDriver vec @@ -125,9 +126,13 @@ async fn main() -> anyhow::Result<()> { "whisper" => config.engine = config::TranscriptionEngine::Whisper, "parakeet" => config.engine = config::TranscriptionEngine::Parakeet, "moonshine" => config.engine = config::TranscriptionEngine::Moonshine, + "sensevoice" => config.engine = config::TranscriptionEngine::SenseVoice, + "paraformer" => config.engine = config::TranscriptionEngine::Paraformer, + "dolphin" => config.engine = config::TranscriptionEngine::Dolphin, + "omnilingual" => config.engine = config::TranscriptionEngine::Omnilingual, _ => {
eprintln!( - "Error: Invalid engine '{}'. Valid options: whisper, parakeet, moonshine", + "Error: Invalid engine '{}'. Valid options: whisper, parakeet, moonshine, sensevoice, paraformer, dolphin, omnilingual", engine ); std::process::exit(1); @@ -201,8 +206,12 @@ async fn main() -> anyhow::Result<()> { "whisper" => config.engine = config::TranscriptionEngine::Whisper, "parakeet" => config.engine = config::TranscriptionEngine::Parakeet, "moonshine" => config.engine = config::TranscriptionEngine::Moonshine, + "sensevoice" => config.engine = config::TranscriptionEngine::SenseVoice, + "paraformer" => config.engine = config::TranscriptionEngine::Paraformer, + "dolphin" => config.engine = config::TranscriptionEngine::Dolphin, + "omnilingual" => config.engine = config::TranscriptionEngine::Omnilingual, _ => { - eprintln!("Error: Invalid engine '{}'. Valid options: whisper, parakeet, moonshine", engine_name); + eprintln!("Error: Invalid engine '{}'. Valid options: whisper, parakeet, moonshine, sensevoice, paraformer, dolphin, omnilingual", engine_name); std::process::exit(1); } } @@ -319,12 +328,17 @@ async fn main() -> anyhow::Result<()> { setup::gpu::show_status(); } } - Some(SetupAction::Parakeet { + Some(SetupAction::Onnx { + enable, + disable, + status, + }) + | Some(SetupAction::Parakeet { enable, disable, status, }) => { - warn_if_root("parakeet"); + warn_if_root("onnx"); if status { setup::parakeet::show_status(); } else if enable { @@ -373,13 +387,54 @@ async fn main() -> anyhow::Result<()> { Commands::Record { action } => { send_record_command(&config, action, top_level_model.as_deref())?; } + + Commands::Meeting { action } => { + run_meeting_command(&config, action).await?; + } + } + + Ok(()) +} + +/// Check if the daemon is running, exit with error if not +fn check_daemon_running() -> anyhow::Result<()> { + use nix::sys::signal::kill; + use nix::unistd::Pid; + + let pid_file = config::Config::runtime_dir().join("pid"); + + if !pid_file.exists() { + 
eprintln!("Error: Voxtype daemon is not running."); + eprintln!("Start it with: voxtype daemon"); + std::process::exit(1); + } + + let pid_str = std::fs::read_to_string(&pid_file) + .map_err(|e| anyhow::anyhow!("Failed to read PID file: {}", e))?; + + let pid: i32 = pid_str + .trim() + .parse() + .map_err(|e| anyhow::anyhow!("Invalid PID in file: {}", e))?; + + // Check if the process is actually running + if kill(Pid::from_raw(pid), None).is_err() { + // Process doesn't exist, clean up stale PID file + let _ = std::fs::remove_file(&pid_file); + eprintln!("Error: Voxtype daemon is not running (stale PID file removed)."); + eprintln!("Start it with: voxtype daemon"); + std::process::exit(1); } Ok(()) } /// Send a record command to the running daemon via Unix signals or file triggers -fn send_record_command(config: &config::Config, action: RecordAction, top_level_model: Option<&str>) -> anyhow::Result<()> { +fn send_record_command( + config: &config::Config, + action: RecordAction, + top_level_model: Option<&str>, +) -> anyhow::Result<()> { use nix::sys::signal::{kill, Signal}; use nix::unistd::Pid; use voxtype::OutputModeOverride; @@ -969,6 +1024,47 @@ async fn show_config(config: &config::Config) -> anyhow::Result<()> { println!(" available models: {}", moonshine_models.join(", ")); } + // Show SenseVoice status (experimental) + println!("\n[sensevoice] (EXPERIMENTAL)"); + if let Some(ref sensevoice_config) = config.sensevoice { + println!(" model = {:?}", sensevoice_config.model); + println!(" language = {:?}", sensevoice_config.language); + println!(" use_itn = {}", sensevoice_config.use_itn); + if let Some(threads) = sensevoice_config.threads { + println!(" threads = {}", threads); + } + println!( + " on_demand_loading = {}", + sensevoice_config.on_demand_loading + ); + } else { + println!(" (not configured)"); + } + + // Check for available SenseVoice models + let mut sensevoice_models: Vec<String> = Vec::new(); + if let Ok(entries) = std::fs::read_dir(&models_dir) {
for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + let name = entry.file_name().to_string_lossy().to_string(); + if name.contains("sensevoice") { + let has_model = path.join("model.int8.onnx").exists() + || path.join("model.onnx").exists(); + let has_tokens = path.join("tokens.txt").exists(); + if has_model && has_tokens { + sensevoice_models.push(name); + } + } + } + } + } + if sensevoice_models.is_empty() { + println!(" available models: (none found)"); + } else { + println!(" available models: {}", sensevoice_models.join(", ")); + } + println!("\n[output]"); println!(" mode = {:?}", config.output.mode); println!( @@ -1048,3 +1144,439 @@ fn reset_sigpipe() { fn reset_sigpipe() { // No-op on non-Unix platforms } + +/// Run a meeting command +async fn run_meeting_command(config: &config::Config, action: MeetingAction) -> anyhow::Result<()> { + use meeting::{ExportFormat, ExportOptions, MeetingConfig, StorageConfig}; + + // Convert config to meeting config + let storage_path = if config.meeting.storage_path == "auto" { + StorageConfig::default_storage_path() + } else { + PathBuf::from(&config.meeting.storage_path) + }; + + let meeting_config = MeetingConfig { + enabled: config.meeting.enabled, + chunk_duration_secs: config.meeting.chunk_duration_secs, + storage: StorageConfig { + storage_path, + retain_audio: config.meeting.retain_audio, + max_meetings: 0, + }, + retain_audio: config.meeting.retain_audio, + max_duration_mins: config.meeting.max_duration_mins, + }; + + match action { + MeetingAction::Start { title } => { + // Check if meeting mode is enabled + if !config.meeting.enabled { + eprintln!("Error: Meeting mode is disabled in config."); + eprintln!(); + eprintln!("Enable it by adding to config.toml:"); + eprintln!(" [meeting]"); + eprintln!(" enabled = true"); + std::process::exit(1); + } + + // Check if daemon is running + check_daemon_running()?; + + // Check if meeting already in progress + let meeting_state_file = 
config::Config::runtime_dir().join("meeting_state"); + if meeting_state_file.exists() { + let state = std::fs::read_to_string(&meeting_state_file).unwrap_or_default(); + if state.starts_with("recording") || state.starts_with("paused") { + eprintln!("Error: A meeting is already in progress."); + eprintln!("Use 'voxtype meeting stop' to end it first."); + std::process::exit(1); + } + } + + // Write start trigger file (with optional title) + let start_file = config::Config::runtime_dir().join("meeting_start"); + let content = title.unwrap_or_default(); + std::fs::write(&start_file, content)?; + + println!("Meeting start requested. Check status with 'voxtype meeting status'."); + } + + MeetingAction::Stop => { + check_daemon_running()?; + + // Check if meeting is in progress + let meeting_state_file = config::Config::runtime_dir().join("meeting_state"); + if !meeting_state_file.exists() { + eprintln!("Error: No meeting in progress."); + std::process::exit(1); + } + + let state = std::fs::read_to_string(&meeting_state_file).unwrap_or_default(); + if state.starts_with("idle") || state.is_empty() { + eprintln!("Error: No meeting in progress."); + std::process::exit(1); + } + + // Write stop trigger file + let stop_file = config::Config::runtime_dir().join("meeting_stop"); + std::fs::write(&stop_file, "")?; + + println!("Meeting stop requested."); + } + + MeetingAction::Pause => { + check_daemon_running()?; + + // Check if meeting is active (not paused) + let meeting_state_file = config::Config::runtime_dir().join("meeting_state"); + if !meeting_state_file.exists() { + eprintln!("Error: No meeting in progress."); + std::process::exit(1); + } + + let state = std::fs::read_to_string(&meeting_state_file).unwrap_or_default(); + if !state.starts_with("recording") { + eprintln!("Error: No active meeting to pause."); + std::process::exit(1); + } + + // Write pause trigger file + let pause_file = config::Config::runtime_dir().join("meeting_pause"); + std::fs::write(&pause_file, 
"")?; + + println!("Meeting pause requested."); + } + + MeetingAction::Resume => { + check_daemon_running()?; + + // Check if meeting is paused + let meeting_state_file = config::Config::runtime_dir().join("meeting_state"); + if !meeting_state_file.exists() { + eprintln!("Error: No paused meeting to resume."); + std::process::exit(1); + } + + let state = std::fs::read_to_string(&meeting_state_file).unwrap_or_default(); + if !state.starts_with("paused") { + eprintln!("Error: No paused meeting to resume."); + std::process::exit(1); + } + + // Write resume trigger file + let resume_file = config::Config::runtime_dir().join("meeting_resume"); + std::fs::write(&resume_file, "")?; + + println!("Meeting resume requested."); + } + + MeetingAction::Status => { + // Read meeting state file + let meeting_state_file = config::Config::runtime_dir().join("meeting_state"); + if !meeting_state_file.exists() { + println!("No meeting currently in progress."); + println!(); + println!("Use 'voxtype meeting list' to see past meetings."); + return Ok(()); + } + + let state = std::fs::read_to_string(&meeting_state_file).unwrap_or_default(); + let lines: Vec<&str> = state.lines().collect(); + + if lines.is_empty() || lines[0] == "idle" { + println!("No meeting currently in progress."); + println!(); + println!("Use 'voxtype meeting list' to see past meetings."); + } else { + let status = lines[0]; + let meeting_id = lines.get(1).unwrap_or(&""); + + println!("Meeting Status: {}", status); + if !meeting_id.is_empty() { + println!("Meeting ID: {}", meeting_id); + } + } + } + + MeetingAction::List { limit } => { + match meeting::list_meetings(&meeting_config, Some(limit)) { + Ok(meetings) => { + if meetings.is_empty() { + println!("No meetings found."); + return Ok(()); + } + + println!("Recent Meetings"); + println!("===============\n"); + + for m in meetings { + let duration = m + .duration_secs + .map(|d| { + let mins = d / 60; + let secs = d % 60; + format!("{}m {}s", mins, secs) + }) + 
.unwrap_or_else(|| "in progress".to_string()); + + println!("{}", m.display_title()); + println!(" ID: {}", m.id); + println!(" Date: {}", m.started_at.format("%Y-%m-%d %H:%M")); + println!(" Duration: {}", duration); + println!(" Status: {:?}", m.status); + println!(); + } + } + Err(e) => { + eprintln!("Error listing meetings: {}", e); + std::process::exit(1); + } + } + } + + MeetingAction::Export { + meeting_id, + format, + output, + timestamps, + speakers, + metadata, + } => { + let export_format = ExportFormat::parse(&format).ok_or_else(|| { + anyhow::anyhow!( + "Unknown export format '{}'. Valid formats: text, markdown, json", + format + ) + })?; + + let options = ExportOptions { + include_timestamps: timestamps, + include_speakers: speakers, + include_metadata: metadata, + line_width: 0, + }; + + match meeting::export_meeting_by_id( + &meeting_config, + &meeting_id, + export_format, + &options, + ) { + Ok(content) => { + if let Some(path) = output { + std::fs::write(&path, &content)?; + println!("Exported to {:?}", path); + } else { + println!("{}", content); + } + } + Err(e) => { + eprintln!("Error exporting meeting: {}", e); + std::process::exit(1); + } + } + } + + MeetingAction::Show { meeting_id } => { + match meeting::get_meeting(&meeting_config, &meeting_id) { + Ok(meeting) => { + println!("{}", meeting.metadata.display_title()); + println!("{}", "=".repeat(meeting.metadata.display_title().len())); + println!(); + println!("ID: {}", meeting.metadata.id); + println!( + "Started: {}", + meeting.metadata.started_at.format("%Y-%m-%d %H:%M UTC") + ); + if let Some(ended) = meeting.metadata.ended_at { + println!("Ended: {}", ended.format("%Y-%m-%d %H:%M UTC")); + } + if let Some(duration) = meeting.metadata.duration_secs { + let hours = duration / 3600; + let mins = (duration % 3600) / 60; + let secs = duration % 60; + if hours > 0 { + println!("Duration: {}h {}m {}s", hours, mins, secs); + } else { + println!("Duration: {}m {}s", mins, secs); + } + } + 
println!("Status: {:?}", meeting.metadata.status); + println!("Chunks: {}", meeting.metadata.chunk_count); + println!(); + println!("Transcript:"); + println!("-----------"); + println!("Segments: {}", meeting.transcript.segments.len()); + println!("Words: {}", meeting.transcript.word_count()); + println!("Speakers: {}", meeting.transcript.speakers().join(", ")); + println!(); + println!( + "Use 'voxtype meeting export {}' to export the transcript.", + meeting_id + ); + } + Err(e) => { + eprintln!("Error loading meeting: {}", e); + std::process::exit(1); + } + } + } + + MeetingAction::Delete { meeting_id, force } => { + if !force { + eprintln!("This will permanently delete the meeting and all associated files."); + eprintln!("Use --force to confirm deletion."); + std::process::exit(1); + } + + let storage = meeting::MeetingStorage::open(meeting_config.storage.clone()) + .map_err(|e| anyhow::anyhow!("Failed to open storage: {}", e))?; + + let id = storage + .resolve_meeting_id(&meeting_id) + .map_err(|e| anyhow::anyhow!("Meeting not found: {}", e))?; + + storage + .delete_meeting(&id) + .map_err(|e| anyhow::anyhow!("Failed to delete meeting: {}", e))?; + + println!("Meeting {} deleted.", meeting_id); + } + + MeetingAction::Label { + meeting_id, + speaker_id, + label, + } => { + let storage = meeting::MeetingStorage::open(meeting_config.storage.clone()) + .map_err(|e| anyhow::anyhow!("Failed to open storage: {}", e))?; + + let id = storage + .resolve_meeting_id(&meeting_id) + .map_err(|e| anyhow::anyhow!("Meeting not found: {}", e))?; + + // Parse speaker_id - accept "SPEAKER_00", "0", "00", etc. + let speaker_num: u32 = if speaker_id.starts_with("SPEAKER_") { + speaker_id + .trim_start_matches("SPEAKER_") + .parse() + .map_err(|_| anyhow::anyhow!("Invalid speaker ID format: {}", speaker_id))? + } else { + speaker_id.parse().map_err(|_| { + anyhow::anyhow!( + "Invalid speaker ID: {}. Use SPEAKER_XX or a number.", + speaker_id + ) + })? 
+ }; + + storage + .set_speaker_label(&id, speaker_num, &label) + .map_err(|e| anyhow::anyhow!("Failed to set speaker label: {}", e))?; + + println!( + "Labeled SPEAKER_{:02} as '{}' in meeting {}", + speaker_num, label, meeting_id + ); + } + + MeetingAction::Summarize { + meeting_id, + format, + output, + } => { + // Load meeting + let meeting = meeting::get_meeting(&meeting_config, &meeting_id) + .map_err(|e| anyhow::anyhow!("Failed to load meeting: {}", e))?; + + // Create summary config from meeting config + let summary_config = meeting::summary::SummaryConfig { + backend: config.meeting.summary.backend.clone(), + ollama_url: config.meeting.summary.ollama_url.clone(), + ollama_model: config.meeting.summary.ollama_model.clone(), + remote_endpoint: config.meeting.summary.remote_endpoint.clone(), + remote_api_key: config.meeting.summary.remote_api_key.clone(), + timeout_secs: config.meeting.summary.timeout_secs, + }; + + // Create summarizer + let summarizer = meeting::summary::create_summarizer(&summary_config) + .ok_or_else(|| { + anyhow::anyhow!( + "Summarization not configured. Set [meeting.summary] backend in config.toml:\n\n\ + [meeting.summary]\n\ + backend = \"local\" # or \"remote\"\n\ + ollama_url = \"http://localhost:11434\"\n\ + ollama_model = \"llama3.2\"" + ) + })?; + + // Check availability + if !summarizer.is_available() { + return Err(anyhow::anyhow!( + "Summarizer '{}' is not available. 
Check that Ollama is running.", + summarizer.name() + )); + } + + eprintln!("Generating summary using {}...", summarizer.name()); + + // Generate summary + let summary = summarizer + .summarize(&meeting) + .map_err(|e| anyhow::anyhow!("Summarization failed: {}", e))?; + + // Format output + let content = match format.as_str() { + "json" => serde_json::to_string_pretty(&summary) + .map_err(|e| anyhow::anyhow!("Failed to serialize summary: {}", e))?, + "text" => { + let mut text = String::new(); + text.push_str(&format!("Summary: {}\n\n", summary.summary)); + + if !summary.key_points.is_empty() { + text.push_str("Key Points:\n"); + for point in &summary.key_points { + text.push_str(&format!(" - {}\n", point)); + } + text.push('\n'); + } + + if !summary.action_items.is_empty() { + text.push_str("Action Items:\n"); + for item in &summary.action_items { + let assignee = item + .assignee + .as_ref() + .map(|a| format!(" ({})", a)) + .unwrap_or_default(); + text.push_str(&format!(" - {}{}\n", item.description, assignee)); + } + text.push('\n'); + } + + if !summary.decisions.is_empty() { + text.push_str("Decisions:\n"); + for decision in &summary.decisions { + text.push_str(&format!(" - {}\n", decision)); + } + } + + text + } + _ => meeting::summary::summary_to_markdown(&summary), + }; + + // Output + if let Some(path) = output { + std::fs::write(&path, &content)?; + eprintln!("Summary saved to {:?}", path); + } else { + println!("{}", content); + } + } + } + + Ok(()) +} diff --git a/src/meeting/chunk.rs b/src/meeting/chunk.rs new file mode 100644 index 00000000..eb6809ee --- /dev/null +++ b/src/meeting/chunk.rs @@ -0,0 +1,519 @@ +//! Audio chunk processor for meeting transcription +//! +//! Handles splitting continuous audio into chunks, applying VAD, +//! and coordinating transcription. 
+ +use crate::error::TranscribeError; +use crate::meeting::data::{AudioSource, TranscriptSegment}; +use crate::transcribe::Transcriber; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +/// Configuration for chunk processing +#[derive(Debug, Clone)] +pub struct ChunkConfig { + /// Duration of each audio chunk in seconds + pub chunk_duration_secs: u32, + /// Minimum audio level to consider as speech (0.0 - 1.0) + pub vad_threshold: f32, + /// Sample rate (expected 16000 Hz) + pub sample_rate: u32, + /// Minimum chunk duration to process (in seconds) + pub min_chunk_duration_secs: f32, +} + +impl Default for ChunkConfig { + fn default() -> Self { + Self { + chunk_duration_secs: 30, + vad_threshold: 0.01, + sample_rate: 16000, + min_chunk_duration_secs: 0.5, + } + } +} + +/// Audio buffer for a chunk being recorded +#[derive(Debug)] +pub struct ChunkBuffer { + /// Audio samples (mono, f32, 16kHz) + samples: Vec<f32>, + /// Start time of this chunk + started_at: Instant, + /// Chunk ID + chunk_id: u32, + /// Audio source + source: AudioSource, + /// Start time offset in milliseconds (relative to meeting start) + start_offset_ms: u64, +} + +impl ChunkBuffer { + /// Create a new chunk buffer + pub fn new(chunk_id: u32, source: AudioSource, start_offset_ms: u64) -> Self { + Self { + samples: Vec::with_capacity(16000 * 30), // Pre-allocate for 30 seconds + started_at: Instant::now(), + chunk_id, + source, + start_offset_ms, + } + } + + /// Add audio samples to the buffer + pub fn add_samples(&mut self, samples: &[f32]) { + self.samples.extend_from_slice(samples); + } + + /// Get the duration of audio in seconds + pub fn duration_secs(&self) -> f32 { + self.samples.len() as f32 / 16000.0 + } + + /// Get the elapsed wall-clock time + pub fn elapsed(&self) -> Duration { + self.started_at.elapsed() + } + + /// Take ownership of the samples, leaving buffer empty + pub fn take_samples(&mut self) -> Vec<f32> { + std::mem::take(&mut self.samples) + } + + /// Check if buffer has
any audio + pub fn has_audio(&self) -> bool { + !self.samples.is_empty() + } +} + +/// Voice Activity Detection (VAD) +/// +/// Simple energy-based VAD for filtering silent chunks. +/// Phase 3+ will use more sophisticated ML-based VAD. +pub struct VoiceActivityDetector { + threshold: f32, + sample_rate: u32, + /// Window size for energy calculation in milliseconds + window_ms: u32, +} + +impl VoiceActivityDetector { + /// Create a new VAD with the given threshold + pub fn new(threshold: f32, sample_rate: u32) -> Self { + Self { + threshold, + sample_rate, + window_ms: 30, + } + } + + /// Check if the audio contains speech + pub fn contains_speech(&self, samples: &[f32]) -> bool { + if samples.is_empty() { + return false; + } + + // Calculate RMS energy over windows + let window_size = (self.sample_rate * self.window_ms / 1000) as usize; + if window_size == 0 { + return false; + } + + let mut speech_frames = 0; + let total_frames = samples.len() / window_size; + + for chunk in samples.chunks(window_size) { + let rms = Self::calculate_rms(chunk); + if rms > self.threshold { + speech_frames += 1; + } + } + + // Require at least 10% of frames to have speech + if total_frames > 0 { + let speech_ratio = speech_frames as f32 / total_frames as f32; + speech_ratio > 0.1 + } else { + false + } + } + + /// Calculate RMS energy of samples + fn calculate_rms(samples: &[f32]) -> f32 { + if samples.is_empty() { + return 0.0; + } + let sum_squares: f32 = samples.iter().map(|s| s * s).sum(); + (sum_squares / samples.len() as f32).sqrt() + } + + /// Get speech segments with their boundaries + /// + /// Returns a list of (start_sample, end_sample) tuples for speech regions. 
+ pub fn detect_speech_segments(&self, samples: &[f32]) -> Vec<(usize, usize)> { + let window_size = (self.sample_rate * self.window_ms / 1000) as usize; + if window_size == 0 || samples.is_empty() { + return vec![]; + } + + let mut segments = vec![]; + let mut in_speech = false; + let mut speech_start = 0; + let mut silence_count = 0; + let hangover = 5; // Number of silent frames to wait before ending segment + + for (i, chunk) in samples.chunks(window_size).enumerate() { + let rms = Self::calculate_rms(chunk); + let is_speech = rms > self.threshold; + + if is_speech { + silence_count = 0; + if !in_speech { + in_speech = true; + speech_start = i * window_size; + } + } else if in_speech { + silence_count += 1; + if silence_count >= hangover { + // End of speech segment + segments.push((speech_start, i * window_size)); + in_speech = false; + silence_count = 0; + } + } + } + + // Handle speech that extends to the end + if in_speech { + segments.push((speech_start, samples.len())); + } + + segments + } +} + +/// Processed chunk result +#[derive(Debug)] +pub struct ProcessedChunk { + /// Chunk ID + pub chunk_id: u32, + /// Transcript segments from this chunk + pub segments: Vec<TranscriptSegment>, + /// Original audio duration in milliseconds + pub audio_duration_ms: u64, + /// Processing time in milliseconds + pub processing_time_ms: u64, +} + +/// Chunk processor +/// +/// Coordinates audio buffering, VAD, and transcription for meeting mode. +pub struct ChunkProcessor { + config: ChunkConfig, + vad: VoiceActivityDetector, + transcriber: Arc<dyn Transcriber>, + next_segment_id: u32, +} + +impl ChunkProcessor { + /// Create a new chunk processor + pub fn new(config: ChunkConfig, transcriber: Arc<dyn Transcriber>) -> Self { + let vad = VoiceActivityDetector::new(config.vad_threshold, config.sample_rate); + Self { + config, + vad, + transcriber, + next_segment_id: 0, + } + } + + /// Process a completed chunk of audio + /// + /// Applies VAD, transcribes speech regions, and returns transcript segments.
+ pub fn process_chunk( + &mut self, + buffer: ChunkBuffer, + ) -> Result<ProcessedChunk, TranscribeError> { + let start_time = Instant::now(); + let chunk_id = buffer.chunk_id; + let source = buffer.source; + let start_offset_ms = buffer.start_offset_ms; + + let samples = buffer.samples; + let audio_duration_ms = (samples.len() as f64 / 16000.0 * 1000.0) as u64; + + // Skip if too short + let min_samples = (self.config.min_chunk_duration_secs * 16000.0) as usize; + if samples.len() < min_samples { + tracing::debug!( + "Chunk {} too short ({:.2}s), skipping", + chunk_id, + samples.len() as f32 / 16000.0 + ); + return Ok(ProcessedChunk { + chunk_id, + segments: vec![], + audio_duration_ms, + processing_time_ms: start_time.elapsed().as_millis() as u64, + }); + } + + // Check for speech + if !self.vad.contains_speech(&samples) { + tracing::debug!("Chunk {} has no speech, skipping", chunk_id); + return Ok(ProcessedChunk { + chunk_id, + segments: vec![], + audio_duration_ms, + processing_time_ms: start_time.elapsed().as_millis() as u64, + }); + } + + // Transcribe the chunk + tracing::info!( + "Transcribing chunk {} ({:.1}s of audio)", + chunk_id, + samples.len() as f32 / 16000.0 + ); + + let text = self.transcriber.transcribe(&samples)?; + + let mut segments = vec![]; + if !text.is_empty() && !text.trim().is_empty() { + // Create a single segment for the whole chunk + // Phase 3 will add proper sentence segmentation based on whisper timestamps + let segment_id = self.next_segment_id; + self.next_segment_id += 1; + + let mut segment = TranscriptSegment::new( + segment_id, + start_offset_ms, + start_offset_ms + audio_duration_ms, + text.trim().to_string(), + chunk_id, + ); + segment.source = source; + + segments.push(segment); + } + + let processing_time_ms = start_time.elapsed().as_millis() as u64; + tracing::debug!("Chunk {} processed in {}ms", chunk_id, processing_time_ms); + + Ok(ProcessedChunk { + chunk_id, + segments, + audio_duration_ms, + processing_time_ms, + }) + } + + /// Check if a chunk
buffer is ready for processing + pub fn is_chunk_ready(&self, buffer: &ChunkBuffer) -> bool { + buffer.duration_secs() >= self.config.chunk_duration_secs as f32 + } + + /// Create a new chunk buffer + pub fn new_buffer( + &self, + chunk_id: u32, + source: AudioSource, + start_offset_ms: u64, + ) -> ChunkBuffer { + ChunkBuffer::new(chunk_id, source, start_offset_ms) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_samples(duration_secs: f32, frequency_hz: f32, amplitude: f32) -> Vec { + let sample_rate = 16000.0; + let num_samples = (duration_secs * sample_rate) as usize; + (0..num_samples) + .map(|i| { + let t = i as f32 / sample_rate; + amplitude * (2.0 * std::f32::consts::PI * frequency_hz * t).sin() + }) + .collect() + } + + fn create_silent_samples(duration_secs: f32) -> Vec { + let num_samples = (duration_secs * 16000.0) as usize; + vec![0.0; num_samples] + } + + #[test] + fn test_chunk_buffer_duration() { + let mut buffer = ChunkBuffer::new(0, AudioSource::Microphone, 0); + buffer.add_samples(&vec![0.0; 16000]); // 1 second + assert!((buffer.duration_secs() - 1.0).abs() < 0.01); + } + + #[test] + fn test_vad_silent_audio() { + let vad = VoiceActivityDetector::new(0.01, 16000); + let silent = create_silent_samples(1.0); + assert!(!vad.contains_speech(&silent)); + } + + #[test] + fn test_vad_speech_audio() { + let vad = VoiceActivityDetector::new(0.01, 16000); + let speech = create_test_samples(1.0, 440.0, 0.5); + assert!(vad.contains_speech(&speech)); + } + + #[test] + fn test_vad_detect_segments() { + let vad = VoiceActivityDetector::new(0.01, 16000); + + // Create audio: silence, speech, silence + let mut samples = create_silent_samples(0.5); + samples.extend(create_test_samples(1.0, 440.0, 0.5)); + samples.extend(create_silent_samples(0.5)); + + let segments = vad.detect_speech_segments(&samples); + assert!(!segments.is_empty()); + + // The speech should be detected roughly in the middle + let (start, end) = segments[0]; + 
assert!(start > 0); + assert!(end < samples.len()); + } + + #[test] + fn test_chunk_config_default() { + let config = ChunkConfig::default(); + assert_eq!(config.chunk_duration_secs, 30); + assert_eq!(config.sample_rate, 16000); + } + + #[test] + fn test_rms_calculation() { + let samples = vec![0.5, -0.5, 0.5, -0.5]; + let rms = VoiceActivityDetector::calculate_rms(&samples); + assert!((rms - 0.5).abs() < 0.01); + } + + #[test] + fn test_rms_empty() { + let rms = VoiceActivityDetector::calculate_rms(&[]); + assert_eq!(rms, 0.0); + } + + #[test] + fn test_chunk_buffer_empty() { + let buffer = ChunkBuffer::new(0, AudioSource::Microphone, 0); + assert!(!buffer.has_audio()); + assert!((buffer.duration_secs() - 0.0).abs() < 0.01); + } + + #[test] + fn test_chunk_buffer_take_samples() { + let mut buffer = ChunkBuffer::new(0, AudioSource::Microphone, 0); + buffer.add_samples(&[0.1, 0.2, 0.3]); + assert!(buffer.has_audio()); + + let samples = buffer.take_samples(); + assert_eq!(samples.len(), 3); + assert!(!buffer.has_audio()); + } + + #[test] + fn test_chunk_buffer_multiple_adds() { + let mut buffer = ChunkBuffer::new(0, AudioSource::Microphone, 0); + buffer.add_samples(&vec![0.0; 8000]); // 0.5 seconds + buffer.add_samples(&vec![0.0; 8000]); // 0.5 seconds + assert!((buffer.duration_secs() - 1.0).abs() < 0.01); + } + + #[test] + fn test_chunk_buffer_elapsed() { + let buffer = ChunkBuffer::new(0, AudioSource::Microphone, 0); + std::thread::sleep(std::time::Duration::from_millis(10)); + assert!(buffer.elapsed() >= std::time::Duration::from_millis(10)); + } + + #[test] + fn test_vad_empty_samples() { + let vad = VoiceActivityDetector::new(0.01, 16000); + assert!(!vad.contains_speech(&[])); + } + + #[test] + fn test_vad_detect_segments_empty() { + let vad = VoiceActivityDetector::new(0.01, 16000); + let segments = vad.detect_speech_segments(&[]); + assert!(segments.is_empty()); + } + + #[test] + fn test_vad_detect_segments_all_silence() { + let vad = 
VoiceActivityDetector::new(0.01, 16000); + let silent = create_silent_samples(2.0); + let segments = vad.detect_speech_segments(&silent); + assert!(segments.is_empty()); + } + + #[test] + fn test_vad_detect_segments_all_speech() { + let vad = VoiceActivityDetector::new(0.01, 16000); + let speech = create_test_samples(1.0, 440.0, 0.5); + let segments = vad.detect_speech_segments(&speech); + assert!(!segments.is_empty()); + // Speech covers the entire buffer + let (start, end) = segments[0]; + assert_eq!(start, 0); + } + + #[test] + fn test_vad_threshold_boundary() { + // Amplitude exactly at threshold should not be detected as speech + // since RMS of a sine wave with amplitude A is A / sqrt(2) + let threshold = 0.5; + let vad = VoiceActivityDetector::new(threshold, 16000); + + // Very quiet audio (RMS below threshold) + let quiet = create_test_samples(1.0, 440.0, 0.001); + assert!(!vad.contains_speech(&quiet)); + + // Loud audio (RMS above threshold) + let loud = create_test_samples(1.0, 440.0, 1.0); + assert!(vad.contains_speech(&loud)); + } + + #[test] + fn test_vad_zero_sample_rate_no_panic() { + // Edge case: zero sample rate should not panic + let vad = VoiceActivityDetector::new(0.01, 0); + assert!(!vad.contains_speech(&[0.5, 0.5, 0.5])); + assert!(vad.detect_speech_segments(&[0.5, 0.5]).is_empty()); + } + + #[test] + fn test_chunk_config_custom() { + let config = ChunkConfig { + chunk_duration_secs: 60, + vad_threshold: 0.05, + sample_rate: 48000, + min_chunk_duration_secs: 1.0, + }; + assert_eq!(config.chunk_duration_secs, 60); + assert_eq!(config.sample_rate, 48000); + } + + #[test] + fn test_rms_uniform_value() { + let samples = vec![0.3; 100]; + let rms = VoiceActivityDetector::calculate_rms(&samples); + assert!((rms - 0.3).abs() < 0.01); + } + + #[test] + fn test_rms_single_sample() { + let rms = VoiceActivityDetector::calculate_rms(&[0.7]); + assert!((rms - 0.7).abs() < 0.01); + } +} diff --git a/src/meeting/data.rs b/src/meeting/data.rs new file mode 
100644 index 00000000..b8fe6062 --- /dev/null +++ b/src/meeting/data.rs @@ -0,0 +1,743 @@ +//! Data structures for meeting transcription +//! +//! Defines the core data types for meetings, transcripts, and segments. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use uuid::Uuid; + +/// Unique identifier for a meeting +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct MeetingId(pub Uuid); + +impl MeetingId { + /// Generate a new unique meeting ID + pub fn new() -> Self { + Self(Uuid::new_v4()) + } + + /// Parse from a string + pub fn parse(s: &str) -> Result<Self, uuid::Error> { + Ok(Self(Uuid::parse_str(s)?)) + } +} + +impl Default for MeetingId { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Display for MeetingId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::str::FromStr for MeetingId { + type Err = uuid::Error; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + Self::parse(s) + } +} + +/// Audio source for speaker attribution +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[derive(Default)] +pub enum AudioSource { + /// User's microphone (local speaker) + Microphone, + /// System audio loopback (remote participants) + Loopback, + /// Unknown source + #[default] + Unknown, +} + + +impl std::fmt::Display for AudioSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AudioSource::Microphone => write!(f, "You"), + AudioSource::Loopback => write!(f, "Remote"), + AudioSource::Unknown => write!(f, "Unknown"), + } + } +} + +/// A single transcript segment with timing and speaker info +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TranscriptSegment { + /// Unique ID for this segment + pub id: u32, + /// Start time in milliseconds from meeting start + pub start_ms: u64, + /// End time in milliseconds from meeting
start + pub end_ms: u64, + /// Transcribed text content + pub text: String, + /// Audio source (mic or loopback) + pub source: AudioSource, + /// Speaker ID (for diarization, Phase 3) + #[serde(skip_serializing_if = "Option::is_none")] + pub speaker_id: Option<String>, + /// Human-assigned speaker label (Phase 3) + #[serde(skip_serializing_if = "Option::is_none")] + pub speaker_label: Option<String>, + /// Confidence score (0.0 - 1.0) + #[serde(skip_serializing_if = "Option::is_none")] + pub confidence: Option<f32>, + /// Chunk number this segment belongs to + pub chunk_id: u32, +} + +impl TranscriptSegment { + /// Create a new transcript segment + pub fn new(id: u32, start_ms: u64, end_ms: u64, text: String, chunk_id: u32) -> Self { + Self { + id, + start_ms, + end_ms, + text, + source: AudioSource::Unknown, + speaker_id: None, + speaker_label: None, + confidence: None, + chunk_id, + } + } + + /// Duration of this segment in milliseconds + pub fn duration_ms(&self) -> u64 { + self.end_ms.saturating_sub(self.start_ms) + } + + /// Get the display name for the speaker + pub fn speaker_display(&self) -> String { + if let Some(ref label) = self.speaker_label { + label.clone() + } else if let Some(ref id) = self.speaker_id { + id.clone() + } else { + self.source.to_string() + } + } + + /// Format timestamp as HH:MM:SS + pub fn format_timestamp(&self) -> String { + let secs = self.start_ms / 1000; + let hours = secs / 3600; + let minutes = (secs % 3600) / 60; + let seconds = secs % 60; + if hours > 0 { + format!("{:02}:{:02}:{:02}", hours, minutes, seconds) + } else { + format!("{:02}:{:02}", minutes, seconds) + } + } +} + +/// Complete transcript for a meeting +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Transcript { + /// Ordered list of transcript segments + pub segments: Vec<TranscriptSegment>, + /// Total number of chunks processed + pub total_chunks: u32, +} + +impl Transcript { + /// Create a new empty transcript + pub fn new() -> Self { + Self::default() + } + + /// Add a
segment to the transcript + pub fn add_segment(&mut self, segment: TranscriptSegment) { + self.segments.push(segment); + } + + /// Get the full text without speaker labels + pub fn plain_text(&self) -> String { + self.segments + .iter() + .map(|s| s.text.as_str()) + .collect::<Vec<_>>() + .join(" ") + } + + /// Get the full text with speaker labels + pub fn text_with_speakers(&self) -> String { + let mut result = String::new(); + let mut last_speaker = String::new(); + + for segment in &self.segments { + let speaker = segment.speaker_display(); + if speaker != last_speaker { + if !result.is_empty() { + result.push_str("\n\n"); + } + result.push_str(&format!("**{}**: ", speaker)); + last_speaker = speaker; + } else { + result.push(' '); + } + result.push_str(&segment.text); + } + result + } + + /// Total duration in milliseconds + pub fn duration_ms(&self) -> u64 { + self.segments.iter().map(|s| s.end_ms).max().unwrap_or(0) + } + + /// Word count + pub fn word_count(&self) -> usize { + self.segments + .iter() + .map(|s| s.text.split_whitespace().count()) + .sum() + } + + /// Get segments for a specific speaker + pub fn segments_by_speaker(&self, speaker: &str) -> Vec<&TranscriptSegment> { + self.segments + .iter() + .filter(|s| s.speaker_display() == speaker) + .collect() + } + + /// Get unique speakers + pub fn speakers(&self) -> Vec<String> { + let mut speakers: Vec<String> = self.segments.iter().map(|s| s.speaker_display()).collect(); + speakers.sort(); + speakers.dedup(); + speakers + } +} + +/// Meeting status +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[derive(Default)] +pub enum MeetingStatus { + /// Meeting is in progress + #[default] + Active, + /// Meeting is paused + Paused, + /// Meeting has ended + Completed, + /// Meeting was cancelled/abandoned + Cancelled, +} + + +/// Metadata for a meeting +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MeetingMetadata { + /// Unique meeting ID + pub id:
MeetingId, + /// User-provided title + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option<String>, + /// When the meeting started + pub started_at: DateTime<Utc>, + /// When the meeting ended + #[serde(skip_serializing_if = "Option::is_none")] + pub ended_at: Option<DateTime<Utc>>, + /// Duration in seconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_secs: Option<u64>, + /// Meeting status + pub status: MeetingStatus, + /// Number of audio chunks + pub chunk_count: u32, + /// Storage path + #[serde(skip_serializing_if = "Option::is_none")] + pub storage_path: Option<PathBuf>, + /// Whether audio was retained + pub audio_retained: bool, + /// Whisper model used + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option<String>, + /// Summary (Phase 5) + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option<MeetingSummary>, + /// Remote sync status (Phase 4) + #[serde(skip_serializing_if = "Option::is_none")] + pub synced_at: Option<DateTime<Utc>>, +} + +impl MeetingMetadata { + /// Create new meeting metadata + pub fn new(title: Option<String>) -> Self { + Self { + id: MeetingId::new(), + title, + started_at: Utc::now(), + ended_at: None, + duration_secs: None, + status: MeetingStatus::Active, + chunk_count: 0, + storage_path: None, + audio_retained: false, + model: None, + summary: None, + synced_at: None, + } + } + + /// Mark the meeting as completed + pub fn complete(&mut self) { + self.ended_at = Some(Utc::now()); + self.status = MeetingStatus::Completed; + if let Some(ended) = self.ended_at { + self.duration_secs = Some((ended - self.started_at).num_seconds() as u64); + } + } + + /// Mark the meeting as cancelled + pub fn cancel(&mut self) { + self.ended_at = Some(Utc::now()); + self.status = MeetingStatus::Cancelled; + } + + /// Get a display title (or fallback to date) + pub fn display_title(&self) -> String { + self.title + .clone() + .unwrap_or_else(|| self.started_at.format("Meeting %Y-%m-%d %H:%M").to_string()) + } + + /// Generate the default storage directory name +
pub fn storage_dir_name(&self) -> String { + let date = self.started_at.format("%Y-%m-%d").to_string(); + if let Some(ref title) = self.title { + // Sanitize title for filesystem + let safe_title: String = title + .chars() + .map(|c| { + if c.is_alphanumeric() || c == '-' || c == '_' { + c + } else if c.is_whitespace() { + '-' + } else { + '_' + } + }) + .collect(); + format!("{}-{}", date, safe_title.to_lowercase()) + } else { + format!( + "{}-{}", + date, + self.id.0.to_string().split('-').next().unwrap_or("meeting") + ) + } + } +} + +/// AI-generated meeting summary (Phase 5) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MeetingSummary { + /// Brief summary of the meeting + pub summary: String, + /// Key discussion points + #[serde(default)] + pub key_points: Vec, + /// Action items extracted + #[serde(default)] + pub action_items: Vec, + /// Decisions made + #[serde(default)] + pub decisions: Vec, + /// When the summary was generated + pub generated_at: DateTime, + /// Model used to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, +} + +/// An action item from the meeting +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ActionItem { + /// Description of the action + pub description: String, + /// Assigned to (speaker name) + #[serde(skip_serializing_if = "Option::is_none")] + pub assignee: Option, + /// Due date (if mentioned) + #[serde(skip_serializing_if = "Option::is_none")] + pub due_date: Option, + /// Completed status + #[serde(default)] + pub completed: bool, +} + +/// Complete meeting data structure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MeetingData { + /// Meeting metadata + pub metadata: MeetingMetadata, + /// Transcript + pub transcript: Transcript, +} + +impl MeetingData { + /// Create a new meeting + pub fn new(title: Option) -> Self { + Self { + metadata: MeetingMetadata::new(title), + transcript: Transcript::new(), + } + } + + /// Add a transcript segment + pub fn 
add_segment(&mut self, segment: TranscriptSegment) { + self.transcript.add_segment(segment); + } + + /// Complete the meeting + pub fn complete(&mut self) { + self.metadata.complete(); + self.metadata.chunk_count = self.transcript.total_chunks; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_meeting_id_new() { + let id1 = MeetingId::new(); + let id2 = MeetingId::new(); + assert_ne!(id1, id2); + } + + #[test] + fn test_meeting_id_parse() { + let id = MeetingId::new(); + let parsed = MeetingId::parse(&id.to_string()).unwrap(); + assert_eq!(id, parsed); + } + + #[test] + fn test_transcript_segment() { + let segment = TranscriptSegment::new(1, 0, 5000, "Hello world".to_string(), 0); + assert_eq!(segment.duration_ms(), 5000); + assert_eq!(segment.format_timestamp(), "00:00"); + } + + #[test] + fn test_transcript_segment_timestamp_format() { + let segment = TranscriptSegment::new(1, 3661000, 3665000, "Test".to_string(), 0); + assert_eq!(segment.format_timestamp(), "01:01:01"); + } + + #[test] + fn test_transcript_plain_text() { + let mut transcript = Transcript::new(); + transcript.add_segment(TranscriptSegment::new(1, 0, 1000, "Hello".to_string(), 0)); + transcript.add_segment(TranscriptSegment::new( + 2, + 1000, + 2000, + "world".to_string(), + 0, + )); + assert_eq!(transcript.plain_text(), "Hello world"); + } + + #[test] + fn test_transcript_word_count() { + let mut transcript = Transcript::new(); + transcript.add_segment(TranscriptSegment::new( + 1, + 0, + 1000, + "Hello world foo bar".to_string(), + 0, + )); + assert_eq!(transcript.word_count(), 4); + } + + #[test] + fn test_meeting_metadata_display_title() { + let metadata = MeetingMetadata::new(Some("Team Standup".to_string())); + assert_eq!(metadata.display_title(), "Team Standup"); + + let metadata = MeetingMetadata::new(None); + assert!(metadata.display_title().starts_with("Meeting")); + } + + #[test] + fn test_meeting_metadata_storage_dir_name() { + let mut metadata = 
MeetingMetadata::new(Some("Team Standup!".to_string())); + metadata.started_at = DateTime::parse_from_rfc3339("2024-01-15T10:30:00Z") + .unwrap() + .into(); + let dir_name = metadata.storage_dir_name(); + assert!(dir_name.starts_with("2024-01-15-team-standup_")); + } + + #[test] + fn test_meeting_complete() { + let mut meeting = MeetingData::new(Some("Test".to_string())); + assert_eq!(meeting.metadata.status, MeetingStatus::Active); + + meeting.complete(); + assert_eq!(meeting.metadata.status, MeetingStatus::Completed); + assert!(meeting.metadata.ended_at.is_some()); + } + + #[test] + fn test_meeting_id_parse_invalid() { + let result = MeetingId::parse("not-a-uuid"); + assert!(result.is_err()); + } + + #[test] + fn test_meeting_id_parse_empty() { + let result = MeetingId::parse(""); + assert!(result.is_err()); + } + + #[test] + fn test_meeting_id_from_str() { + let id = MeetingId::new(); + let parsed: MeetingId = id.to_string().parse().unwrap(); + assert_eq!(id, parsed); + } + + #[test] + fn test_meeting_id_default_is_unique() { + let id1 = MeetingId::default(); + let id2 = MeetingId::default(); + assert_ne!(id1, id2); + } + + #[test] + fn test_audio_source_default() { + let source = AudioSource::default(); + assert_eq!(source, AudioSource::Unknown); + } + + #[test] + fn test_audio_source_display() { + assert_eq!(format!("{}", AudioSource::Microphone), "You"); + assert_eq!(format!("{}", AudioSource::Loopback), "Remote"); + assert_eq!(format!("{}", AudioSource::Unknown), "Unknown"); + } + + #[test] + fn test_transcript_empty() { + let transcript = Transcript::new(); + assert_eq!(transcript.plain_text(), ""); + assert_eq!(transcript.word_count(), 0); + assert_eq!(transcript.duration_ms(), 0); + assert!(transcript.speakers().is_empty()); + assert!(transcript.segments.is_empty()); + } + + #[test] + fn test_transcript_text_with_speakers() { + let mut transcript = Transcript::new(); + let mut seg1 = TranscriptSegment::new(0, 0, 1000, "Hello".to_string(), 0); + 
seg1.source = AudioSource::Microphone; + let mut seg2 = TranscriptSegment::new(1, 1000, 2000, "Hi there".to_string(), 0); + seg2.source = AudioSource::Loopback; + transcript.add_segment(seg1); + transcript.add_segment(seg2); + + let text = transcript.text_with_speakers(); + assert!(text.contains("**You**: Hello")); + assert!(text.contains("**Remote**: Hi there")); + } + + #[test] + fn test_transcript_text_with_speakers_merges_consecutive() { + let mut transcript = Transcript::new(); + let mut seg1 = TranscriptSegment::new(0, 0, 1000, "Hello".to_string(), 0); + seg1.source = AudioSource::Microphone; + let mut seg2 = TranscriptSegment::new(1, 1000, 2000, "world".to_string(), 0); + seg2.source = AudioSource::Microphone; + transcript.add_segment(seg1); + transcript.add_segment(seg2); + + let text = transcript.text_with_speakers(); + // Same speaker should not repeat the label + assert_eq!(text.matches("**You**").count(), 1); + } + + #[test] + fn test_transcript_segments_by_speaker() { + let mut transcript = Transcript::new(); + let mut seg1 = TranscriptSegment::new(0, 0, 1000, "Hello".to_string(), 0); + seg1.source = AudioSource::Microphone; + let mut seg2 = TranscriptSegment::new(1, 1000, 2000, "Hi".to_string(), 0); + seg2.source = AudioSource::Loopback; + let mut seg3 = TranscriptSegment::new(2, 2000, 3000, "Bye".to_string(), 0); + seg3.source = AudioSource::Microphone; + transcript.add_segment(seg1); + transcript.add_segment(seg2); + transcript.add_segment(seg3); + + let you_segments = transcript.segments_by_speaker("You"); + assert_eq!(you_segments.len(), 2); + let remote_segments = transcript.segments_by_speaker("Remote"); + assert_eq!(remote_segments.len(), 1); + } + + #[test] + fn test_transcript_speakers_unique_sorted() { + let mut transcript = Transcript::new(); + let mut seg1 = TranscriptSegment::new(0, 0, 1000, "A".to_string(), 0); + seg1.source = AudioSource::Loopback; + let mut seg2 = TranscriptSegment::new(1, 1000, 2000, "B".to_string(), 0); + seg2.source 
= AudioSource::Microphone; + let mut seg3 = TranscriptSegment::new(2, 2000, 3000, "C".to_string(), 0); + seg3.source = AudioSource::Loopback; + transcript.add_segment(seg1); + transcript.add_segment(seg2); + transcript.add_segment(seg3); + + let speakers = transcript.speakers(); + assert_eq!(speakers.len(), 2); + assert!(speakers.contains(&"You".to_string())); + assert!(speakers.contains(&"Remote".to_string())); + } + + #[test] + fn test_segment_speaker_display_with_label() { + let mut segment = TranscriptSegment::new(0, 0, 1000, "Test".to_string(), 0); + segment.speaker_label = Some("Alice".to_string()); + assert_eq!(segment.speaker_display(), "Alice"); + } + + #[test] + fn test_segment_speaker_display_with_id_no_label() { + let mut segment = TranscriptSegment::new(0, 0, 1000, "Test".to_string(), 0); + segment.speaker_id = Some("SPEAKER_00".to_string()); + assert_eq!(segment.speaker_display(), "SPEAKER_00"); + } + + #[test] + fn test_segment_speaker_display_label_overrides_id() { + let mut segment = TranscriptSegment::new(0, 0, 1000, "Test".to_string(), 0); + segment.speaker_id = Some("SPEAKER_00".to_string()); + segment.speaker_label = Some("Bob".to_string()); + assert_eq!(segment.speaker_display(), "Bob"); + } + + #[test] + fn test_segment_duration_zero() { + let segment = TranscriptSegment::new(0, 5000, 5000, "".to_string(), 0); + assert_eq!(segment.duration_ms(), 0); + } + + #[test] + fn test_segment_format_timestamp_zero() { + let segment = TranscriptSegment::new(0, 0, 1000, "Test".to_string(), 0); + assert_eq!(segment.format_timestamp(), "00:00"); + } + + #[test] + fn test_segment_format_timestamp_minutes_only() { + let segment = TranscriptSegment::new(0, 125000, 130000, "Test".to_string(), 0); + assert_eq!(segment.format_timestamp(), "02:05"); + } + + #[test] + fn test_meeting_metadata_cancel() { + let mut metadata = MeetingMetadata::new(Some("Cancelled".to_string())); + assert_eq!(metadata.status, MeetingStatus::Active); + metadata.cancel(); + 
assert_eq!(metadata.status, MeetingStatus::Cancelled); + assert!(metadata.ended_at.is_some()); + } + + #[test] + fn test_meeting_metadata_storage_dir_no_title() { + let mut metadata = MeetingMetadata::new(None); + metadata.started_at = DateTime::parse_from_rfc3339("2024-06-15T09:00:00Z") + .unwrap() + .into(); + let dir_name = metadata.storage_dir_name(); + assert!(dir_name.starts_with("2024-06-15-")); + } + + #[test] + fn test_meeting_metadata_complete_sets_duration() { + let mut metadata = MeetingMetadata::new(Some("Duration Test".to_string())); + std::thread::sleep(std::time::Duration::from_millis(10)); + metadata.complete(); + assert!(metadata.duration_secs.is_some()); + } + + #[test] + fn test_meeting_data_add_segment() { + let mut meeting = MeetingData::new(Some("Test".to_string())); + assert!(meeting.transcript.segments.is_empty()); + + meeting.add_segment(TranscriptSegment::new(0, 0, 1000, "Hello".to_string(), 0)); + assert_eq!(meeting.transcript.segments.len(), 1); + } + + #[test] + fn test_meeting_data_complete_sets_chunk_count() { + let mut meeting = MeetingData::new(Some("Test".to_string())); + meeting.transcript.total_chunks = 5; + meeting.complete(); + assert_eq!(meeting.metadata.chunk_count, 5); + } + + #[test] + fn test_meeting_status_default() { + let status = MeetingStatus::default(); + assert_eq!(status, MeetingStatus::Active); + } + + #[test] + fn test_meeting_metadata_new_defaults() { + let metadata = MeetingMetadata::new(None); + assert!(metadata.title.is_none()); + assert!(metadata.ended_at.is_none()); + assert!(metadata.duration_secs.is_none()); + assert_eq!(metadata.status, MeetingStatus::Active); + assert_eq!(metadata.chunk_count, 0); + assert!(!metadata.audio_retained); + assert!(metadata.model.is_none()); + assert!(metadata.summary.is_none()); + assert!(metadata.synced_at.is_none()); + } + + #[test] + fn test_segment_serialization_roundtrip() { + let mut segment = TranscriptSegment::new(0, 0, 5000, "Hello world".to_string(), 0); + 
segment.source = AudioSource::Microphone; + segment.speaker_id = Some("SPEAKER_00".to_string()); + segment.confidence = Some(0.95); + + let json = serde_json::to_string(&segment).unwrap(); + let deserialized: TranscriptSegment = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.id, 0); + assert_eq!(deserialized.text, "Hello world"); + assert_eq!(deserialized.source, AudioSource::Microphone); + assert_eq!(deserialized.speaker_id, Some("SPEAKER_00".to_string())); + } + + #[test] + fn test_transcript_duration_ms() { + let mut transcript = Transcript::new(); + transcript.add_segment(TranscriptSegment::new(0, 0, 5000, "A".to_string(), 0)); + transcript.add_segment(TranscriptSegment::new(1, 5000, 12000, "B".to_string(), 1)); + assert_eq!(transcript.duration_ms(), 12000); + } +} diff --git a/src/meeting/diarization/ml.rs b/src/meeting/diarization/ml.rs new file mode 100644 index 00000000..5011aa26 --- /dev/null +++ b/src/meeting/diarization/ml.rs @@ -0,0 +1,410 @@ +//! ML-based speaker diarization using ONNX Runtime +//! +//! Uses ECAPA-TDNN speaker embeddings for voice fingerprinting +//! and clustering to identify individual speakers. +//! +//! This module is only available with the `ml-diarization` feature. 
+ +use super::{DiarizationConfig, DiarizedSegment, Diarizer, SpeakerId}; +use crate::meeting::data::AudioSource; +use crate::meeting::TranscriptSegment; +use std::collections::HashMap; +use std::path::PathBuf; + +#[cfg(feature = "ml-diarization")] +use ndarray::{Array1, Array2}; +#[cfg(feature = "ml-diarization")] +use ort::{Session, Value}; + +/// Speaker embedding (voice fingerprint) +#[derive(Debug, Clone)] +pub struct SpeakerEmbedding { + /// Embedding vector (typically 192 or 256 dimensions) + pub vector: Vec<f32>, + /// Speaker ID this embedding belongs to + pub speaker_id: SpeakerId, +} + +impl SpeakerEmbedding { + /// Cosine similarity with another embedding + pub fn cosine_similarity(&self, other: &SpeakerEmbedding) -> f32 { + if self.vector.len() != other.vector.len() { + return 0.0; + } + + let dot: f32 = self + .vector + .iter() + .zip(other.vector.iter()) + .map(|(a, b)| a * b) + .sum(); + + let norm_a: f32 = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt(); + let norm_b: f32 = other.vector.iter().map(|x| x * x).sum::<f32>().sqrt(); + + if norm_a == 0.0 || norm_b == 0.0 { + return 0.0; + } + + dot / (norm_a * norm_b) + } +} + +/// ML-based speaker diarizer +#[allow(dead_code)] +pub struct MlDiarizer { + /// Path to the ONNX model file + model_path: Option<PathBuf>, + /// ONNX session (lazy loaded) + #[cfg(feature = "ml-diarization")] + session: Option<Arc<Session>>, + /// Known speaker embeddings + speaker_embeddings: Vec<SpeakerEmbedding>, + /// Speaker labels (auto ID -> human label) + speaker_labels: HashMap<u32, String>, + /// Next speaker ID + next_speaker_id: u32, + /// Similarity threshold for matching speakers + similarity_threshold: f32, + /// Maximum number of speakers to detect + max_speakers: u32, + /// Minimum segment duration for embedding (ms) + min_segment_ms: u64, + /// Sample rate for audio + sample_rate: u32, +} + +impl MlDiarizer { + /// Create a new ML diarizer + pub fn new(config: &DiarizationConfig) -> Self { + Self { + model_path: config.model_path.as_ref().map(PathBuf::from), +
#[cfg(feature = "ml-diarization")] + session: None, + speaker_embeddings: Vec::new(), + speaker_labels: HashMap::new(), + next_speaker_id: 0, + similarity_threshold: 0.75, + max_speakers: config.max_speakers, + min_segment_ms: config.min_segment_ms, + sample_rate: 16000, + } + } + + /// Get or create default model path + pub fn default_model_path() -> PathBuf { + let data_dir = crate::config::Config::data_dir(); + data_dir.join("models").join("ecapa_tdnn.onnx") + } + + /// Check if the model file exists + pub fn model_exists(&self) -> bool { + self.model_path + .as_ref() + .map(|p| p.exists()) + .unwrap_or_else(|| Self::default_model_path().exists()) + } + + /// Load the ONNX model + #[cfg(feature = "ml-diarization")] + pub fn load_model(&mut self) -> Result<(), String> { + let path = self + .model_path + .clone() + .unwrap_or_else(Self::default_model_path); + + if !path.exists() { + return Err(format!( + "Speaker embedding model not found: {:?}\n\ + Download from: https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb\n\ + Place in: {:?}", + path, path + )); + } + + match Session::builder() { + Ok(builder) => match builder.with_model_from_file(&path) { + Ok(session) => { + self.session = Some(Arc::new(session)); + tracing::info!("Loaded speaker embedding model: {:?}", path); + Ok(()) + } + Err(e) => Err(format!("Failed to load model: {}", e)), + }, + Err(e) => Err(format!("Failed to create ONNX session: {}", e)), + } + } + + /// Extract embedding from audio samples + #[cfg(feature = "ml-diarization")] + pub fn extract_embedding(&self, samples: &[f32]) -> Result, String> { + let session = self.session.as_ref().ok_or("Model not loaded")?; + + // Prepare input tensor: [batch=1, samples] + let input_array = Array2::from_shape_vec((1, samples.len()), samples.to_vec()) + .map_err(|e| format!("Failed to create input array: {}", e))?; + + let input_value = Value::from_array(input_array) + .map_err(|e| format!("Failed to create input value: {}", e))?; + + // Run 
inference + let outputs = session + .run(ort::inputs![input_value].map_err(|e| format!("Input error: {}", e))?) + .map_err(|e| format!("Inference failed: {}", e))?; + + // Extract embedding from output + let output = outputs + .get("embedding") + .or_else(|| outputs.values().next()) + .ok_or("No output from model")?; + + let embedding: Array1 = output + .try_extract_tensor() + .map_err(|e| format!("Failed to extract tensor: {}", e))? + .view() + .to_owned() + .into_dimensionality() + .map_err(|e| format!("Dimension error: {}", e))?; + + Ok(embedding.to_vec()) + } + + /// Find or create speaker ID for an embedding + #[allow(dead_code)] + fn find_or_create_speaker(&mut self, embedding: &[f32]) -> SpeakerId { + let new_embedding = SpeakerEmbedding { + vector: embedding.to_vec(), + speaker_id: SpeakerId::Auto(self.next_speaker_id), + }; + + // Find best matching existing speaker + let mut best_match: Option<(usize, f32)> = None; + for (i, existing) in self.speaker_embeddings.iter().enumerate() { + let similarity = new_embedding.cosine_similarity(existing); + if similarity > self.similarity_threshold { + match best_match { + None => best_match = Some((i, similarity)), + Some((_, best_sim)) if similarity > best_sim => { + best_match = Some((i, similarity)) + } + _ => {} + } + } + } + + if let Some((idx, _)) = best_match { + // Return existing speaker + self.speaker_embeddings[idx].speaker_id.clone() + } else if self.next_speaker_id < self.max_speakers { + // Create new speaker + let speaker_id = SpeakerId::Auto(self.next_speaker_id); + self.speaker_embeddings.push(SpeakerEmbedding { + vector: embedding.to_vec(), + speaker_id: speaker_id.clone(), + }); + self.next_speaker_id += 1; + speaker_id + } else { + // Too many speakers, return unknown + SpeakerId::Unknown + } + } + + /// Label a speaker + pub fn label_speaker(&mut self, auto_id: u32, label: String) { + self.speaker_labels.insert(auto_id, label); + } + + /// Get speaker label if set + pub fn get_label(&self, 
speaker_id: &SpeakerId) -> Option<String> { + match speaker_id { + SpeakerId::Auto(id) => self.speaker_labels.get(id).cloned(), + _ => None, + } + } + + /// Convert samples window to milliseconds + #[allow(dead_code)] + fn samples_to_ms(&self, samples: usize) -> u64 { + (samples as u64 * 1000) / self.sample_rate as u64 + } +} + +impl Default for MlDiarizer { + fn default() -> Self { + Self::new(&DiarizationConfig::default()) + } +} + +impl Diarizer for MlDiarizer { + // `samples` is only read when the ml-diarization feature is enabled. + #[cfg_attr(not(feature = "ml-diarization"), allow(unused_variables))] + fn diarize( + &self, + samples: &[f32], + _source: AudioSource, + transcript_segments: &[TranscriptSegment], + ) -> Vec<DiarizedSegment> { + // If model is not loaded or feature is disabled, fall back to simple attribution + #[cfg(not(feature = "ml-diarization"))] + { + transcript_segments + .iter() + .map(|seg| DiarizedSegment { + speaker: SpeakerId::Unknown, + start_ms: seg.start_ms, + end_ms: seg.end_ms, + text: seg.text.clone(), + confidence: 0.0, + }) + .collect() + } + + #[cfg(feature = "ml-diarization")] + { + if self.session.is_none() { + tracing::warn!("ML diarizer model not loaded, using unknown speaker"); + return transcript_segments + .iter() + .map(|seg| DiarizedSegment { + speaker: SpeakerId::Unknown, + start_ms: seg.start_ms, + end_ms: seg.end_ms, + text: seg.text.clone(), + confidence: 0.0, + }) + .collect(); + } + + let mut results = Vec::new(); + + for seg in transcript_segments { + // Skip segments that are too short for reliable embedding + if seg.duration_ms() < self.min_segment_ms { + results.push(DiarizedSegment { + speaker: SpeakerId::Unknown, + start_ms: seg.start_ms, + end_ms: seg.end_ms, + text: seg.text.clone(), + confidence: 0.0, + }); + continue; + } + + // Extract audio window for this segment + let start_sample = (seg.start_ms as usize * self.sample_rate as usize) / 1000; + let end_sample = (seg.end_ms as usize * self.sample_rate as usize) / 1000; + + if end_sample > samples.len() { + results.push(DiarizedSegment { + speaker: SpeakerId::Unknown, + start_ms: seg.start_ms, + end_ms: seg.end_ms, +
text: seg.text.clone(), + confidence: 0.0, + }); + continue; + } + + let segment_samples = &samples[start_sample..end_sample.min(samples.len())]; + + // Extract embedding + match self.extract_embedding(segment_samples) { + Ok(embedding) => { + // Note: find_or_create_speaker needs mutable self, but diarize takes &self + // In a real implementation, we'd need interior mutability or a different pattern + // For now, return with unknown speaker and let caller handle labeling + results.push(DiarizedSegment { + speaker: SpeakerId::Unknown, // Would be find_or_create_speaker result + start_ms: seg.start_ms, + end_ms: seg.end_ms, + text: seg.text.clone(), + confidence: 0.8, + }); + } + Err(e) => { + tracing::warn!("Failed to extract embedding: {}", e); + results.push(DiarizedSegment { + speaker: SpeakerId::Unknown, + start_ms: seg.start_ms, + end_ms: seg.end_ms, + text: seg.text.clone(), + confidence: 0.0, + }); + } + } + } + + results + } + } + + fn name(&self) -> &'static str { + "ml" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cosine_similarity_identical() { + let a = SpeakerEmbedding { + vector: vec![1.0, 0.0, 0.0], + speaker_id: SpeakerId::Auto(0), + }; + let b = SpeakerEmbedding { + vector: vec![1.0, 0.0, 0.0], + speaker_id: SpeakerId::Auto(1), + }; + assert!((a.cosine_similarity(&b) - 1.0).abs() < 0.001); + } + + #[test] + fn test_cosine_similarity_orthogonal() { + let a = SpeakerEmbedding { + vector: vec![1.0, 0.0, 0.0], + speaker_id: SpeakerId::Auto(0), + }; + let b = SpeakerEmbedding { + vector: vec![0.0, 1.0, 0.0], + speaker_id: SpeakerId::Auto(1), + }; + assert!(a.cosine_similarity(&b).abs() < 0.001); + } + + #[test] + fn test_cosine_similarity_opposite() { + let a = SpeakerEmbedding { + vector: vec![1.0, 0.0, 0.0], + speaker_id: SpeakerId::Auto(0), + }; + let b = SpeakerEmbedding { + vector: vec![-1.0, 0.0, 0.0], + speaker_id: SpeakerId::Auto(1), + }; + assert!((a.cosine_similarity(&b) + 1.0).abs() < 0.001); + } + + #[test] + fn 
test_speaker_labeling() { + let mut diarizer = MlDiarizer::default(); + diarizer.label_speaker(0, "Alice".to_string()); + diarizer.label_speaker(1, "Bob".to_string()); + + assert_eq!( + diarizer.get_label(&SpeakerId::Auto(0)), + Some("Alice".to_string()) + ); + assert_eq!( + diarizer.get_label(&SpeakerId::Auto(1)), + Some("Bob".to_string()) + ); + assert_eq!(diarizer.get_label(&SpeakerId::Auto(2)), None); + } + + #[test] + fn test_default_model_path() { + let path = MlDiarizer::default_model_path(); + assert!(path.ends_with("ecapa_tdnn.onnx")); + } +} diff --git a/src/meeting/diarization/mod.rs b/src/meeting/diarization/mod.rs new file mode 100644 index 00000000..551c8b30 --- /dev/null +++ b/src/meeting/diarization/mod.rs @@ -0,0 +1,177 @@ +//! Speaker diarization for meeting transcription +//! +//! Provides speaker identification and attribution for meeting transcripts. +//! +//! # Backends +//! +//! - **Simple**: Source-based attribution using mic vs loopback (Phase 2) +//! - **ML**: ONNX-based speaker embeddings with clustering (Phase 3) +//! 
- **Subprocess**: Memory-isolated ML diarization for resource-constrained systems + +pub mod ml; +pub mod simple; +pub mod subprocess; + +use crate::meeting::data::AudioSource; +use std::collections::HashMap; + +/// Speaker identifier +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum SpeakerId { + /// The local user (from microphone) + You, + /// Remote participant(s) (from loopback) + Remote, + /// Unknown speaker + Unknown, + /// Identified speaker with label + Named(String), + /// Auto-generated speaker ID (e.g., SPEAKER_00) + Auto(u32), +} + +impl SpeakerId { + /// Get display name for this speaker + pub fn display_name(&self) -> String { + match self { + SpeakerId::You => "You".to_string(), + SpeakerId::Remote => "Remote".to_string(), + SpeakerId::Unknown => "Unknown".to_string(), + SpeakerId::Named(name) => name.clone(), + SpeakerId::Auto(id) => format!("SPEAKER_{:02}", id), + } + } +} + +impl std::fmt::Display for SpeakerId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.display_name()) + } +} + +/// A segment with speaker attribution +#[derive(Debug, Clone)] +pub struct DiarizedSegment { + /// Speaker who said this + pub speaker: SpeakerId, + /// Start time in milliseconds + pub start_ms: u64, + /// End time in milliseconds + pub end_ms: u64, + /// Transcribed text + pub text: String, + /// Confidence score (0.0 - 1.0) + pub confidence: f32, +} + +/// Speaker labels mapping auto IDs to names +pub type SpeakerLabels = HashMap<u32, String>; + +/// Trait for diarization backends +pub trait Diarizer: Send + Sync { + /// Process audio samples and return diarized segments + fn diarize( + &self, + samples: &[f32], + source: AudioSource, + transcript_segments: &[crate::meeting::TranscriptSegment], + ) -> Vec<DiarizedSegment>; + + /// Get the backend name + fn name(&self) -> &'static str; +} + +/// Diarization configuration +#[derive(Debug, Clone)] +pub struct DiarizationConfig { + /// Enable diarization + pub enabled: bool, + /// Backend to
use: "simple", "ml", or "remote" + pub backend: String, + /// Maximum number of speakers to detect + pub max_speakers: u32, + /// Minimum segment duration in milliseconds + pub min_segment_ms: u64, + /// Path to ONNX model for ML backend + pub model_path: Option<std::path::PathBuf>, +} + +impl Default for DiarizationConfig { + fn default() -> Self { + Self { + enabled: true, + backend: "simple".to_string(), + max_speakers: 10, + min_segment_ms: 500, + model_path: None, + } + } +} + +/// Create a diarizer based on configuration +pub fn create_diarizer(config: &DiarizationConfig) -> Box<dyn Diarizer> { + match config.backend.as_str() { + "simple" => Box::new(simple::SimpleDiarizer::new()), + "ml" => { + #[cfg(feature = "ml-diarization")] + { + let mut diarizer = ml::MlDiarizer::new(config); + if diarizer.model_exists() { + if let Err(e) = diarizer.load_model() { + tracing::warn!("Failed to load ML diarization model: {}", e); + tracing::info!("Falling back to simple diarization"); + return Box::new(simple::SimpleDiarizer::new()); + } + tracing::info!("Using ML diarization with ONNX"); + return Box::new(diarizer); + } else { + tracing::warn!("ML diarization model not found, falling back to simple"); + return Box::new(simple::SimpleDiarizer::new()); + } + } + #[cfg(not(feature = "ml-diarization"))] + { + tracing::warn!( + "ML diarization requires the 'ml-diarization' feature, falling back to simple" + ); + Box::new(simple::SimpleDiarizer::new()) + } + } + "subprocess" => { + // Subprocess diarizer for memory-isolated ML diarization + Box::new(subprocess::SubprocessDiarizer::new(config.clone())) + } + _ => { + tracing::warn!( + "Unknown diarizer backend '{}', using simple", + config.backend + ); + Box::new(simple::SimpleDiarizer::new()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_speaker_id_display() { + assert_eq!(SpeakerId::You.display_name(), "You"); + assert_eq!(SpeakerId::Remote.display_name(), "Remote"); + assert_eq!(SpeakerId::Auto(0).display_name(),
"SPEAKER_00"); + assert_eq!(SpeakerId::Auto(5).display_name(), "SPEAKER_05"); + assert_eq!( + SpeakerId::Named("Alice".to_string()).display_name(), + "Alice" + ); + } + + #[test] + fn test_default_config() { + let config = DiarizationConfig::default(); + assert!(config.enabled); + assert_eq!(config.backend, "simple"); + assert_eq!(config.max_speakers, 10); + } +} diff --git a/src/meeting/diarization/simple.rs b/src/meeting/diarization/simple.rs new file mode 100644 index 00000000..dd4e317c --- /dev/null +++ b/src/meeting/diarization/simple.rs @@ -0,0 +1,175 @@ +//! Simple source-based diarization +//! +//! Attributes speakers based on audio source: +//! - Microphone input → "You" +//! - System loopback → "Remote" +//! +//! This provides basic speaker separation without ML models. + +use super::{DiarizedSegment, Diarizer, SpeakerId}; +use crate::meeting::data::AudioSource; +use crate::meeting::TranscriptSegment; + +/// Simple diarizer using audio source for attribution +pub struct SimpleDiarizer { + /// Minimum gap between segments to merge (ms) + merge_gap_ms: u64, +} + +impl SimpleDiarizer { + /// Create a new simple diarizer + pub fn new() -> Self { + Self { merge_gap_ms: 500 } + } + + /// Create with custom merge gap + pub fn with_merge_gap(merge_gap_ms: u64) -> Self { + Self { merge_gap_ms } + } + + /// Convert audio source to speaker ID + fn source_to_speaker(source: AudioSource) -> SpeakerId { + match source { + AudioSource::Microphone => SpeakerId::You, + AudioSource::Loopback => SpeakerId::Remote, + AudioSource::Unknown => SpeakerId::Unknown, + } + } +} + +impl Default for SimpleDiarizer { + fn default() -> Self { + Self::new() + } +} + +impl Diarizer for SimpleDiarizer { + fn diarize( + &self, + _samples: &[f32], + source: AudioSource, + transcript_segments: &[TranscriptSegment], + ) -> Vec<DiarizedSegment> { + let speaker = Self::source_to_speaker(source); + + // Convert transcript segments to diarized segments + let mut diarized: Vec<DiarizedSegment> = transcript_segments + .iter() +
.map(|seg| DiarizedSegment { + speaker: speaker.clone(), + start_ms: seg.start_ms, + end_ms: seg.end_ms, + text: seg.text.clone(), + confidence: 1.0, // High confidence for source-based attribution + }) + .collect(); + + // Merge consecutive segments from the same speaker + self.merge_consecutive(&mut diarized); + + diarized + } + + fn name(&self) -> &'static str { + "simple" + } +} + +impl SimpleDiarizer { + /// Merge consecutive segments from the same speaker if they're close together + fn merge_consecutive(&self, segments: &mut Vec) { + if segments.len() < 2 { + return; + } + + let mut i = 0; + while i < segments.len() - 1 { + let current_end = segments[i].end_ms; + let next_start = segments[i + 1].start_ms; + let same_speaker = segments[i].speaker == segments[i + 1].speaker; + let close_enough = next_start.saturating_sub(current_end) <= self.merge_gap_ms; + + if same_speaker && close_enough { + // Clone the text from next segment before modifying + let next_text = segments[i + 1].text.clone(); + let next_end = segments[i + 1].end_ms; + let next_confidence = segments[i + 1].confidence; + + // Merge next into current + segments[i].end_ms = next_end; + segments[i].text.push(' '); + segments[i].text.push_str(&next_text); + segments[i].confidence = (segments[i].confidence + next_confidence) / 2.0; + segments.remove(i + 1); + } else { + i += 1; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_source_to_speaker() { + assert_eq!( + SimpleDiarizer::source_to_speaker(AudioSource::Microphone), + SpeakerId::You + ); + assert_eq!( + SimpleDiarizer::source_to_speaker(AudioSource::Loopback), + SpeakerId::Remote + ); + assert_eq!( + SimpleDiarizer::source_to_speaker(AudioSource::Unknown), + SpeakerId::Unknown + ); + } + + #[test] + fn test_diarize_mic_segments() { + let diarizer = SimpleDiarizer::new(); + let mut seg1 = TranscriptSegment::new(1, 0, 1000, "Hello".to_string(), 0); + seg1.source = AudioSource::Microphone; + let mut seg2 = 
TranscriptSegment::new(2, 1000, 2000, "World".to_string(), 0); + seg2.source = AudioSource::Microphone; + let segments = vec![seg1, seg2]; + + let result = diarizer.diarize(&[], AudioSource::Microphone, &segments); + + // Should merge into one segment since same speaker and close together + assert_eq!(result.len(), 1); + assert_eq!(result[0].speaker, SpeakerId::You); + assert_eq!(result[0].text, "Hello World"); + } + + #[test] + fn test_diarize_preserves_separate_segments() { + let diarizer = SimpleDiarizer::new(); + let mut seg1 = TranscriptSegment::new(1, 0, 1000, "First".to_string(), 0); + seg1.source = AudioSource::Microphone; + let mut seg2 = TranscriptSegment::new(2, 5000, 6000, "Second".to_string(), 0); + seg2.source = AudioSource::Microphone; + let segments = vec![seg1, seg2]; + + let result = diarizer.diarize(&[], AudioSource::Microphone, &segments); + + // Should keep separate due to large gap + assert_eq!(result.len(), 2); + } + + #[test] + fn test_diarize_loopback() { + let diarizer = SimpleDiarizer::new(); + let mut seg = TranscriptSegment::new(1, 0, 1000, "Remote speech".to_string(), 0); + seg.source = AudioSource::Loopback; + let segments = vec![seg]; + + let result = diarizer.diarize(&[], AudioSource::Loopback, &segments); + + assert_eq!(result.len(), 1); + assert_eq!(result[0].speaker, SpeakerId::Remote); + } +} diff --git a/src/meeting/diarization/subprocess.rs b/src/meeting/diarization/subprocess.rs new file mode 100644 index 00000000..5a1e6d71 --- /dev/null +++ b/src/meeting/diarization/subprocess.rs @@ -0,0 +1,279 @@ +//! Subprocess-based diarization for memory isolation +//! +//! Runs speaker embedding extraction in a subprocess that exits after +//! processing, releasing memory. Useful on memory-constrained systems. 
+ +use super::{DiarizationConfig, DiarizedSegment, Diarizer, SpeakerId}; +use crate::meeting::data::AudioSource; +use crate::meeting::TranscriptSegment; +use std::io::{BufRead, BufReader, Write}; +use std::process::{Child, Command, Stdio}; + +/// Subprocess-based diarizer wrapper +#[allow(dead_code)] +pub struct SubprocessDiarizer { + /// Diarization configuration + config: DiarizationConfig, + /// Child process handle + child: Option<Child>, +} + +impl SubprocessDiarizer { + /// Create a new subprocess diarizer + pub fn new(config: DiarizationConfig) -> Self { + Self { + config, + child: None, + } + } + + /// Spawn the worker subprocess + #[allow(dead_code)] + fn spawn_worker(&mut self) -> Result<&mut Child, String> { + if self.child.is_some() { + return self + .child + .as_mut() + .ok_or("Child already exists".to_string()); + } + + let exe = std::env::current_exe().map_err(|e| format!("Failed to get exe path: {}", e))?; + + let mut cmd = Command::new(exe); + cmd.arg("--diarization-worker") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()); + + if let Some(ref model_path) = self.config.model_path { + cmd.arg("--model").arg(model_path); + } + + let child = cmd + .spawn() + .map_err(|e| format!("Failed to spawn worker: {}", e))?; + self.child = Some(child); + self.child.as_mut().ok_or("Failed to get child".to_string()) + } + + /// Send audio samples to worker and receive embeddings + #[allow(dead_code)] + fn process_in_worker( + &mut self, + samples: &[f32], + segments: &[TranscriptSegment], + ) -> Result<Vec<DiarizedSegment>, String> { + let child = self.spawn_worker()?; + + let stdin = child.stdin.as_mut().ok_or("No stdin")?; + let stdout = child.stdout.as_mut().ok_or("No stdout")?; + + // Send sample count + writeln!(stdin, "{}", samples.len()).map_err(|e| format!("Write error: {}", e))?; + + // Send samples (as space-separated floats, chunked) + for chunk in samples.chunks(1000) { + let line: String = chunk + .iter() + .map(|f| f.to_string()) + .collect::<Vec<String>>() +
.join(" "); + writeln!(stdin, "{}", line).map_err(|e| format!("Write error: {}", e))?; + } + + // Send segment count + writeln!(stdin, "{}", segments.len()).map_err(|e| format!("Write error: {}", e))?; + + // Send segments (start_ms end_ms text) + for seg in segments { + writeln!(stdin, "{} {} {}", seg.start_ms, seg.end_ms, seg.text) + .map_err(|e| format!("Write error: {}", e))?; + } + + stdin.flush().map_err(|e| format!("Flush error: {}", e))?; + + // Read results + let reader = BufReader::new(stdout); + let mut results = Vec::new(); + + for line in reader.lines() { + let line = line.map_err(|e| format!("Read error: {}", e))?; + if line.is_empty() || line == "END" { + break; + } + + // Parse: speaker_id start_ms end_ms confidence text + let parts: Vec<&str> = line.splitn(5, ' ').collect(); + if parts.len() < 5 { + continue; + } + + let speaker = parse_speaker_id(parts[0]); + let start_ms: u64 = parts[1].parse().unwrap_or(0); + let end_ms: u64 = parts[2].parse().unwrap_or(0); + let confidence: f32 = parts[3].parse().unwrap_or(0.0); + let text = parts[4].to_string(); + + results.push(DiarizedSegment { + speaker, + start_ms, + end_ms, + text, + confidence, + }); + } + + // Kill the subprocess to release memory + if let Some(ref mut child) = self.child { + let _ = child.kill(); + let _ = child.wait(); + } + self.child = None; + + Ok(results) + } +} + +#[allow(dead_code)] +fn parse_speaker_id(s: &str) -> SpeakerId { + match s { + "You" => SpeakerId::You, + "Remote" => SpeakerId::Remote, + "Unknown" => SpeakerId::Unknown, + s if s.starts_with("SPEAKER_") => { + if let Ok(id) = s.trim_start_matches("SPEAKER_").parse() { + SpeakerId::Auto(id) + } else { + SpeakerId::Unknown + } + } + s => SpeakerId::Named(s.to_string()), + } +} + +impl Diarizer for SubprocessDiarizer { + fn diarize( + &self, + _samples: &[f32], + _source: AudioSource, + transcript_segments: &[TranscriptSegment], + ) -> Vec { + // Note: Diarizer trait takes &self, but we need &mut self for subprocess + // 
This is a limitation - in practice, we'd use interior mutability + // For now, return simple attribution as fallback + transcript_segments + .iter() + .map(|seg| { + let speaker = match seg.source { + AudioSource::Microphone => SpeakerId::You, + AudioSource::Loopback => SpeakerId::Remote, + AudioSource::Unknown => SpeakerId::Unknown, + }; + DiarizedSegment { + speaker, + start_ms: seg.start_ms, + end_ms: seg.end_ms, + text: seg.text.clone(), + confidence: 0.5, + } + }) + .collect() + } + + fn name(&self) -> &'static str { + "subprocess" + } +} + +/// Worker entry point for subprocess diarization +/// Called when voxtype is run with --diarization-worker +pub fn run_worker(_model_path: Option<&str>) -> Result<(), String> { + use std::io::{stdin, stdout, BufRead}; + + let stdin = stdin(); + let mut stdout = stdout(); + let mut reader = stdin.lock(); + let mut line = String::new(); + + // Read sample count + line.clear(); + reader + .read_line(&mut line) + .map_err(|e| format!("Read error: {}", e))?; + let sample_count: usize = line + .trim() + .parse() + .map_err(|e| format!("Parse error: {}", e))?; + + // Read samples + let mut samples = Vec::with_capacity(sample_count); + let mut remaining = sample_count; + while remaining > 0 { + line.clear(); + reader + .read_line(&mut line) + .map_err(|e| format!("Read error: {}", e))?; + for s in line.split_whitespace() { + if let Ok(f) = s.parse::() { + samples.push(f); + remaining = remaining.saturating_sub(1); + } + } + } + + // Read segment count + line.clear(); + reader + .read_line(&mut line) + .map_err(|e| format!("Read error: {}", e))?; + let segment_count: usize = line + .trim() + .parse() + .map_err(|e| format!("Parse error: {}", e))?; + + // Read segments + let mut segments = Vec::with_capacity(segment_count); + for _ in 0..segment_count { + line.clear(); + reader + .read_line(&mut line) + .map_err(|e| format!("Read error: {}", e))?; + let parts: Vec<&str> = line.trim().splitn(3, ' ').collect(); + if parts.len() >= 3 
{ + let start_ms: u64 = parts[0].parse().unwrap_or(0); + let end_ms: u64 = parts[1].parse().unwrap_or(0); + let text = parts[2].to_string(); + segments.push((start_ms, end_ms, text)); + } + } + + // Process with ML diarizer (simplified - just return with unknown speaker for now) + // In a real implementation, we'd load the ONNX model and run inference + for (start_ms, end_ms, text) in segments { + writeln!(stdout, "Unknown {} {} 0.5 {}", start_ms, end_ms, text) + .map_err(|e| format!("Write error: {}", e))?; + } + + writeln!(stdout, "END").map_err(|e| format!("Write error: {}", e))?; + stdout.flush().map_err(|e| format!("Flush error: {}", e))?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_speaker_id() { + assert_eq!(parse_speaker_id("You"), SpeakerId::You); + assert_eq!(parse_speaker_id("Remote"), SpeakerId::Remote); + assert_eq!(parse_speaker_id("Unknown"), SpeakerId::Unknown); + assert_eq!(parse_speaker_id("SPEAKER_00"), SpeakerId::Auto(0)); + assert_eq!(parse_speaker_id("SPEAKER_05"), SpeakerId::Auto(5)); + assert_eq!( + parse_speaker_id("Alice"), + SpeakerId::Named("Alice".to_string()) + ); + } +} diff --git a/src/meeting/export/json.rs b/src/meeting/export/json.rs new file mode 100644 index 00000000..00dc7b40 --- /dev/null +++ b/src/meeting/export/json.rs @@ -0,0 +1,212 @@ +//! 
JSON exporter for meeting transcriptions + +use crate::meeting::data::MeetingData; +use crate::meeting::export::{ExportError, ExportFormat, ExportOptions, Exporter}; +use serde::Serialize; + +/// JSON exporter +pub struct JsonExporter; + +/// Exported meeting structure for JSON +#[derive(Serialize)] +struct ExportedMeeting { + metadata: ExportedMetadata, + transcript: ExportedTranscript, + #[serde(skip_serializing_if = "Option::is_none")] + summary: Option<ExportedSummary>, +} + +#[derive(Serialize)] +struct ExportedMetadata { + id: String, + title: Option<String>, + #[serde(rename = "startedAt")] + started_at: String, + #[serde(rename = "endedAt", skip_serializing_if = "Option::is_none")] + ended_at: Option<String>, + #[serde(rename = "durationSecs", skip_serializing_if = "Option::is_none")] + duration_secs: Option<u64>, + status: String, + #[serde(rename = "chunkCount")] + chunk_count: u32, +} + +#[derive(Serialize)] +struct ExportedTranscript { + segments: Vec<ExportedSegment>, + #[serde(rename = "totalChunks")] + total_chunks: u32, + #[serde(rename = "wordCount")] + word_count: usize, + #[serde(rename = "durationMs")] + duration_ms: u64, + speakers: Vec<String>, +} + +#[derive(Serialize)] +struct ExportedSegment { + id: u32, + #[serde(rename = "startMs")] + start_ms: u64, + #[serde(rename = "endMs")] + end_ms: u64, + text: String, + source: String, + #[serde(skip_serializing_if = "Option::is_none")] + speaker: Option<String>, + #[serde(rename = "chunkId")] + chunk_id: u32, +} + +#[derive(Serialize)] +struct ExportedSummary { + summary: String, + #[serde(rename = "keyPoints")] + key_points: Vec<String>, + #[serde(rename = "actionItems")] + action_items: Vec<ExportedActionItem>, + decisions: Vec<String>, + #[serde(rename = "generatedAt")] + generated_at: String, +} + +#[derive(Serialize)] +struct ExportedActionItem { + description: String, + #[serde(skip_serializing_if = "Option::is_none")] + assignee: Option<String>, + #[serde(rename = "dueDate", skip_serializing_if = "Option::is_none")] + due_date: Option<String>, + completed: bool, +} + +impl Exporter for JsonExporter { + fn
format(&self) -> ExportFormat { + ExportFormat::Json + } + + fn export( + &self, + meeting: &MeetingData, + _options: &ExportOptions, + ) -> Result { + let exported = ExportedMeeting { + metadata: ExportedMetadata { + id: meeting.metadata.id.to_string(), + title: meeting.metadata.title.clone(), + started_at: meeting.metadata.started_at.to_rfc3339(), + ended_at: meeting.metadata.ended_at.map(|dt| dt.to_rfc3339()), + duration_secs: meeting.metadata.duration_secs, + status: format!("{:?}", meeting.metadata.status).to_lowercase(), + chunk_count: meeting.metadata.chunk_count, + }, + transcript: ExportedTranscript { + segments: meeting + .transcript + .segments + .iter() + .map(|s| ExportedSegment { + id: s.id, + start_ms: s.start_ms, + end_ms: s.end_ms, + text: s.text.clone(), + source: format!("{:?}", s.source).to_lowercase(), + speaker: s.speaker_label.clone().or_else(|| s.speaker_id.clone()), + chunk_id: s.chunk_id, + }) + .collect(), + total_chunks: meeting.transcript.total_chunks, + word_count: meeting.transcript.word_count(), + duration_ms: meeting.transcript.duration_ms(), + speakers: meeting.transcript.speakers(), + }, + summary: meeting.metadata.summary.as_ref().map(|s| ExportedSummary { + summary: s.summary.clone(), + key_points: s.key_points.clone(), + action_items: s + .action_items + .iter() + .map(|a| ExportedActionItem { + description: a.description.clone(), + assignee: a.assignee.clone(), + due_date: a.due_date.clone(), + completed: a.completed, + }) + .collect(), + decisions: s.decisions.clone(), + generated_at: s.generated_at.to_rfc3339(), + }), + }; + + serde_json::to_string_pretty(&exported) + .map_err(|e| ExportError::Serialization(e.to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::meeting::data::TranscriptSegment; + + fn create_test_meeting() -> MeetingData { + let mut meeting = MeetingData::new(Some("Test Meeting".to_string())); + meeting.transcript.add_segment(TranscriptSegment::new( + 0, + 0, + 5000, + "Hello 
world.".to_string(), + 0, + )); + meeting + } + + #[test] + fn test_json_export() { + let meeting = create_test_meeting(); + let exporter = JsonExporter; + let options = ExportOptions::default(); + + let output = exporter.export(&meeting, &options).unwrap(); + + // Parse and verify structure + let parsed: serde_json::Value = serde_json::from_str(&output).unwrap(); + + assert!(parsed["metadata"]["id"].is_string()); + assert_eq!(parsed["metadata"]["title"].as_str(), Some("Test Meeting")); + assert!(parsed["transcript"]["segments"].is_array()); + assert_eq!( + parsed["transcript"]["segments"][0]["text"].as_str(), + Some("Hello world.") + ); + } + + #[test] + fn test_json_export_valid_json() { + let meeting = create_test_meeting(); + let exporter = JsonExporter; + let options = ExportOptions::default(); + + let output = exporter.export(&meeting, &options).unwrap(); + + // Should be valid JSON + let result: Result = serde_json::from_str(&output); + assert!(result.is_ok()); + } + + #[test] + fn test_json_export_roundtrip() { + let meeting = create_test_meeting(); + let exporter = JsonExporter; + let options = ExportOptions::default(); + + let output = exporter.export(&meeting, &options).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&output).unwrap(); + + // Verify key fields + assert_eq!(parsed["transcript"]["wordCount"].as_u64(), Some(2)); + assert_eq!( + parsed["transcript"]["segments"].as_array().unwrap().len(), + 1 + ); + } +} diff --git a/src/meeting/export/markdown.rs b/src/meeting/export/markdown.rs new file mode 100644 index 00000000..474e7cfe --- /dev/null +++ b/src/meeting/export/markdown.rs @@ -0,0 +1,199 @@ +//! 
Markdown exporter for meeting transcriptions + +use crate::meeting::data::MeetingData; +use crate::meeting::export::{ExportError, ExportFormat, ExportOptions, Exporter}; + +/// Markdown exporter +pub struct MarkdownExporter; + +impl Exporter for MarkdownExporter { + fn format(&self) -> ExportFormat { + ExportFormat::Markdown + } + + fn export( + &self, + meeting: &MeetingData, + options: &ExportOptions, + ) -> Result { + let mut output = String::new(); + + // Title + output.push_str(&format!("# {}\n\n", meeting.metadata.display_title())); + + // Metadata + if options.include_metadata { + output.push_str("## Meeting Info\n\n"); + output.push_str(&format!( + "- **Date:** {}\n", + meeting.metadata.started_at.format("%Y-%m-%d %H:%M UTC") + )); + if let Some(duration) = meeting.metadata.duration_secs { + let hours = duration / 3600; + let mins = (duration % 3600) / 60; + let secs = duration % 60; + if hours > 0 { + output.push_str(&format!("- **Duration:** {}h {}m {}s\n", hours, mins, secs)); + } else { + output.push_str(&format!("- **Duration:** {}m {}s\n", mins, secs)); + } + } + output.push_str(&format!( + "- **Word Count:** {}\n", + meeting.transcript.word_count() + )); + output.push_str(&format!( + "- **Segments:** {}\n", + meeting.transcript.segments.len() + )); + + let speakers = meeting.transcript.speakers(); + if !speakers.is_empty() { + output.push_str(&format!("- **Speakers:** {}\n", speakers.join(", "))); + } + + output.push('\n'); + } + + // Summary (if available, Phase 5) + if let Some(ref summary) = meeting.metadata.summary { + output.push_str("## Summary\n\n"); + output.push_str(&summary.summary); + output.push_str("\n\n"); + + if !summary.key_points.is_empty() { + output.push_str("### Key Points\n\n"); + for point in &summary.key_points { + output.push_str(&format!("- {}\n", point)); + } + output.push('\n'); + } + + if !summary.action_items.is_empty() { + output.push_str("### Action Items\n\n"); + for item in &summary.action_items { + let checkbox = if 
item.completed { "[x]" } else { "[ ]" }; + let assignee = item + .assignee + .as_ref() + .map(|a| format!(" (@{})", a)) + .unwrap_or_default(); + output.push_str(&format!( + "- {} {}{}\n", + checkbox, item.description, assignee + )); + } + output.push('\n'); + } + + if !summary.decisions.is_empty() { + output.push_str("### Decisions\n\n"); + for decision in &summary.decisions { + output.push_str(&format!("- {}\n", decision)); + } + output.push('\n'); + } + } + + // Transcript + output.push_str("## Transcript\n\n"); + + let mut last_speaker = String::new(); + + for segment in &meeting.transcript.segments { + if options.include_speakers { + let speaker = segment.speaker_display(); + if speaker != last_speaker { + if !last_speaker.is_empty() { + output.push('\n'); + } + output.push_str(&format!("### {}\n\n", speaker)); + last_speaker = speaker; + } + } + + if options.include_timestamps { + output.push_str(&format!("*[{}]* ", segment.format_timestamp())); + } + + output.push_str(&segment.text); + output.push_str("\n\n"); + } + + Ok(output) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::meeting::data::{AudioSource, TranscriptSegment}; + + fn create_test_meeting() -> MeetingData { + let mut meeting = MeetingData::new(Some("Weekly Standup".to_string())); + + let mut seg1 = TranscriptSegment::new(0, 0, 5000, "Good morning everyone.".to_string(), 0); + seg1.source = AudioSource::Microphone; + + let mut seg2 = TranscriptSegment::new(1, 5000, 10000, "Hey, good morning!".to_string(), 0); + seg2.source = AudioSource::Loopback; + + meeting.transcript.add_segment(seg1); + meeting.transcript.add_segment(seg2); + meeting + } + + #[test] + fn test_markdown_export_basic() { + let meeting = create_test_meeting(); + let exporter = MarkdownExporter; + let options = ExportOptions::default(); + + let output = exporter.export(&meeting, &options).unwrap(); + assert!(output.starts_with("# Weekly Standup")); + assert!(output.contains("## Transcript")); + } + + #[test] + fn 
test_markdown_export_with_metadata() { + let meeting = create_test_meeting(); + let exporter = MarkdownExporter; + let options = ExportOptions { + include_metadata: true, + ..Default::default() + }; + + let output = exporter.export(&meeting, &options).unwrap(); + assert!(output.contains("## Meeting Info")); + assert!(output.contains("**Date:**")); + assert!(output.contains("**Word Count:**")); + } + + #[test] + fn test_markdown_export_with_speakers() { + let meeting = create_test_meeting(); + let exporter = MarkdownExporter; + let options = ExportOptions { + include_speakers: true, + ..Default::default() + }; + + let output = exporter.export(&meeting, &options).unwrap(); + assert!(output.contains("### You")); + assert!(output.contains("### Remote")); + } + + #[test] + fn test_markdown_export_with_timestamps() { + let meeting = create_test_meeting(); + let exporter = MarkdownExporter; + let options = ExportOptions { + include_timestamps: true, + ..Default::default() + }; + + let output = exporter.export(&meeting, &options).unwrap(); + assert!(output.contains("*[00:00]*")); + assert!(output.contains("*[00:05]*")); + } +} diff --git a/src/meeting/export/mod.rs b/src/meeting/export/mod.rs new file mode 100644 index 00000000..df8ac123 --- /dev/null +++ b/src/meeting/export/mod.rs @@ -0,0 +1,235 @@ +//! Export functionality for meeting transcriptions +//! +//! Provides exporters for various output formats. 
+ +pub mod json; +pub mod markdown; +pub mod srt; +pub mod txt; +pub mod vtt; + +use crate::meeting::data::MeetingData; +use thiserror::Error; + +/// Export format types +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ExportFormat { + /// Plain text + Text, + /// Markdown + Markdown, + /// JSON + Json, + /// SRT subtitles (Phase 2) + Srt, + /// VTT subtitles (Phase 2) + Vtt, +} + +impl ExportFormat { + /// Parse format from string name + pub fn parse(s: &str) -> Option<Self> { + match s.to_lowercase().as_str() { + "text" | "txt" => Some(ExportFormat::Text), + "markdown" | "md" => Some(ExportFormat::Markdown), + "json" => Some(ExportFormat::Json), + "srt" => Some(ExportFormat::Srt), + "vtt" => Some(ExportFormat::Vtt), + _ => None, + } + } + + /// Get file extension for this format + pub fn extension(&self) -> &'static str { + match self { + ExportFormat::Text => "txt", + ExportFormat::Markdown => "md", + ExportFormat::Json => "json", + ExportFormat::Srt => "srt", + ExportFormat::Vtt => "vtt", + } + } + + /// Get all supported format names + pub fn all_names() -> &'static [&'static str] { + &["text", "txt", "markdown", "md", "json", "srt", "vtt"] + } +} + +impl std::fmt::Display for ExportFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ExportFormat::Text => write!(f, "text"), + ExportFormat::Markdown => write!(f, "markdown"), + ExportFormat::Json => write!(f, "json"), + ExportFormat::Srt => write!(f, "srt"), + ExportFormat::Vtt => write!(f, "vtt"), + } + } +} + +/// Export errors +#[derive(Error, Debug)] +pub enum ExportError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("Serialization error: {0}")] + Serialization(String), + + #[error("Format not supported: {0}")] + UnsupportedFormat(String), +} + +/// Export options +#[derive(Debug, Clone, Default)] +pub struct ExportOptions { + /// Include timestamps + pub include_timestamps: bool, + /// Include speaker labels + pub include_speakers: 
bool, + /// Include metadata header + pub include_metadata: bool, + /// Line width for wrapping (0 = no wrap) + pub line_width: usize, +} + +/// Trait for meeting exporters +pub trait Exporter: Send + Sync { + /// Export meeting data to a string + fn export(&self, meeting: &MeetingData, options: &ExportOptions) + -> Result<String, ExportError>; + + /// Get the format name + fn format(&self) -> ExportFormat; +} + +/// Export meeting data to string in the specified format +pub fn export_meeting( + meeting: &MeetingData, + format: ExportFormat, + options: &ExportOptions, +) -> Result<String, ExportError> { + let exporter: Box<dyn Exporter> = match format { + ExportFormat::Text => Box::new(txt::TextExporter), + ExportFormat::Markdown => Box::new(markdown::MarkdownExporter), + ExportFormat::Json => Box::new(json::JsonExporter), + ExportFormat::Srt => Box::new(srt::SrtExporter), + ExportFormat::Vtt => Box::new(vtt::VttExporter), + }; + + exporter.export(meeting, options) +} + +/// Export meeting data to a file +pub fn export_meeting_to_file( + meeting: &MeetingData, + format: ExportFormat, + options: &ExportOptions, + path: &std::path::Path, +) -> Result<(), ExportError> { + let content = export_meeting(meeting, format, options)?; + std::fs::write(path, content)?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_from_str() { + assert_eq!(ExportFormat::parse("text"), Some(ExportFormat::Text)); + assert_eq!(ExportFormat::parse("txt"), Some(ExportFormat::Text)); + assert_eq!( + ExportFormat::parse("markdown"), + Some(ExportFormat::Markdown) + ); + assert_eq!(ExportFormat::parse("md"), Some(ExportFormat::Markdown)); + assert_eq!(ExportFormat::parse("json"), Some(ExportFormat::Json)); + assert_eq!(ExportFormat::parse("invalid"), None); + } + + #[test] + fn test_format_extension() { + assert_eq!(ExportFormat::Text.extension(), "txt"); + assert_eq!(ExportFormat::Markdown.extension(), "md"); + assert_eq!(ExportFormat::Json.extension(), "json"); + assert_eq!(ExportFormat::Srt.extension(), "srt"); + 
assert_eq!(ExportFormat::Vtt.extension(), "vtt"); + } + + #[test] + fn test_format_display() { + assert_eq!(ExportFormat::Text.to_string(), "text"); + assert_eq!(ExportFormat::Markdown.to_string(), "markdown"); + assert_eq!(ExportFormat::Json.to_string(), "json"); + assert_eq!(ExportFormat::Srt.to_string(), "srt"); + assert_eq!(ExportFormat::Vtt.to_string(), "vtt"); + } + + #[test] + fn test_format_from_str_case_insensitive() { + assert_eq!(ExportFormat::parse("TEXT"), Some(ExportFormat::Text)); + assert_eq!( + ExportFormat::parse("Markdown"), + Some(ExportFormat::Markdown) + ); + assert_eq!(ExportFormat::parse("JSON"), Some(ExportFormat::Json)); + assert_eq!(ExportFormat::parse("SRT"), Some(ExportFormat::Srt)); + assert_eq!(ExportFormat::parse("VTT"), Some(ExportFormat::Vtt)); + } + + #[test] + fn test_all_names() { + let names = ExportFormat::all_names(); + assert!(names.contains(&"text")); + assert!(names.contains(&"txt")); + assert!(names.contains(&"markdown")); + assert!(names.contains(&"md")); + assert!(names.contains(&"json")); + assert!(names.contains(&"srt")); + assert!(names.contains(&"vtt")); + } + + #[test] + fn test_export_meeting_text() { + use crate::meeting::data::{MeetingData, TranscriptSegment}; + + let mut meeting = MeetingData::new(Some("Test".to_string())); + meeting + .transcript + .add_segment(TranscriptSegment::new(0, 0, 1000, "Hello".to_string(), 0)); + + let result = export_meeting(&meeting, ExportFormat::Text, &ExportOptions::default()); + assert!(result.is_ok()); + assert!(result.unwrap().contains("Hello")); + } + + #[test] + fn test_export_meeting_srt() { + use crate::meeting::data::MeetingData; + + let meeting = MeetingData::new(Some("Test".to_string())); + let result = export_meeting(&meeting, ExportFormat::Srt, &ExportOptions::default()); + assert!(result.is_ok()); + } + + #[test] + fn test_export_meeting_vtt() { + use crate::meeting::data::MeetingData; + + let meeting = MeetingData::new(Some("Test".to_string())); + let result = 
export_meeting(&meeting, ExportFormat::Vtt, &ExportOptions::default()); + assert!(result.is_ok()); + } + + #[test] + fn test_export_options_default() { + let opts = ExportOptions::default(); + assert!(!opts.include_timestamps); + assert!(!opts.include_speakers); + assert!(!opts.include_metadata); + assert_eq!(opts.line_width, 0); + } +} diff --git a/src/meeting/export/srt.rs b/src/meeting/export/srt.rs new file mode 100644 index 00000000..f61223ac --- /dev/null +++ b/src/meeting/export/srt.rs @@ -0,0 +1,112 @@ +//! SRT (SubRip) subtitle export format +//! +//! Generates standard SRT subtitle files with optional speaker labels. + +use super::{ExportError, ExportFormat, ExportOptions, Exporter}; +use crate::meeting::data::MeetingData; + +/// SRT exporter +pub struct SrtExporter; + +impl Exporter for SrtExporter { + fn export( + &self, + meeting: &MeetingData, + options: &ExportOptions, + ) -> Result<String, ExportError> { + let mut output = String::new(); + let mut index = 1; + + for segment in &meeting.transcript.segments { + // Sequence number + output.push_str(&format!("{}\n", index)); + + // Timestamps: 00:00:00,000 --> 00:00:00,000 + let start = format_srt_time(segment.start_ms); + let end = format_srt_time(segment.end_ms); + output.push_str(&format!("{} --> {}\n", start, end)); + + // Text with optional speaker + if options.include_speakers { + let speaker = segment.speaker_display(); + if !speaker.is_empty() && speaker != "Unknown" { + output.push_str(&format!("[{}] {}\n", speaker, segment.text)); + } else { + output.push_str(&format!("{}\n", segment.text)); + } + } else { + output.push_str(&format!("{}\n", segment.text)); + } + + // Blank line between entries + output.push('\n'); + index += 1; + } + + Ok(output) + } + + fn format(&self) -> ExportFormat { + ExportFormat::Srt + } +} + +/// Format milliseconds as SRT timestamp (HH:MM:SS,mmm) +fn format_srt_time(ms: u64) -> String { + let total_secs = ms / 1000; + let hours = total_secs / 3600; + let minutes = (total_secs % 3600) / 
60; + let seconds = total_secs % 60; + let millis = ms % 1000; + + format!("{:02}:{:02}:{:02},{:03}", hours, minutes, seconds, millis) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_srt_time() { + assert_eq!(format_srt_time(0), "00:00:00,000"); + assert_eq!(format_srt_time(1500), "00:00:01,500"); + assert_eq!(format_srt_time(65000), "00:01:05,000"); + assert_eq!(format_srt_time(3661500), "01:01:01,500"); + } + + #[test] + fn test_srt_export() { + use crate::meeting::data::{AudioSource, MeetingData, TranscriptSegment}; + + let mut meeting = MeetingData::new(Some("Test".to_string())); + meeting.transcript.add_segment(TranscriptSegment::new( + 1, + 0, + 2000, + "Hello world".to_string(), + 0, + )); + meeting.transcript.segments[0].source = AudioSource::Microphone; + + meeting.transcript.add_segment(TranscriptSegment::new( + 2, + 2500, + 5000, + "How are you".to_string(), + 0, + )); + meeting.transcript.segments[1].source = AudioSource::Loopback; + + let exporter = SrtExporter; + let options = ExportOptions { + include_speakers: true, + ..Default::default() + }; + let output = exporter.export(&meeting, &options).unwrap(); + + assert!(output.contains("1\n")); + assert!(output.contains("00:00:00,000 --> 00:00:02,000")); + assert!(output.contains("[You] Hello world")); + assert!(output.contains("[Remote] How are you")); + } +} diff --git a/src/meeting/export/txt.rs b/src/meeting/export/txt.rs new file mode 100644 index 00000000..8dde8f33 --- /dev/null +++ b/src/meeting/export/txt.rs @@ -0,0 +1,185 @@ +//! 
Plain text exporter for meeting transcriptions + +use crate::meeting::data::MeetingData; +use crate::meeting::export::{ExportError, ExportFormat, ExportOptions, Exporter}; + +/// Plain text exporter +pub struct TextExporter; + +impl Exporter for TextExporter { + fn format(&self) -> ExportFormat { + ExportFormat::Text + } + + fn export( + &self, + meeting: &MeetingData, + options: &ExportOptions, + ) -> Result<String, ExportError> { + let mut output = String::new(); + + // Metadata header + if options.include_metadata { + output.push_str(&meeting.metadata.display_title()); + output.push('\n'); + output.push_str(&format!( + "Date: {}\n", + meeting.metadata.started_at.format("%Y-%m-%d %H:%M") + )); + if let Some(duration) = meeting.metadata.duration_secs { + let mins = duration / 60; + let secs = duration % 60; + output.push_str(&format!("Duration: {}:{:02}\n", mins, secs)); + } + output.push_str(&format!("Words: {}\n", meeting.transcript.word_count())); + output.push('\n'); + output.push_str(&"=".repeat(60)); + output.push_str("\n\n"); + } + + // Transcript + let mut last_speaker = String::new(); + + for segment in &meeting.transcript.segments { + let mut line = String::new(); + + // Timestamp + if options.include_timestamps { + line.push_str(&format!("[{}] ", segment.format_timestamp())); + } + + // Speaker change + if options.include_speakers { + let speaker = segment.speaker_display(); + if speaker != last_speaker { + if !last_speaker.is_empty() { + output.push('\n'); + } + line.push_str(&format!("{}:\n", speaker)); + last_speaker = speaker; + } + } + + // Text + line.push_str(&segment.text); + + // Word wrap if configured + if options.line_width > 0 { + output.push_str(&wrap_text(&line, options.line_width)); + } else { + output.push_str(&line); + } + output.push('\n'); + } + + Ok(output) + } +} + +/// Simple word wrapping +fn wrap_text(text: &str, width: usize) -> String { + if width == 0 || text.len() <= width { + return text.to_string(); + } + + let mut result = String::new(); + 
let mut current_line = String::new(); + + for word in text.split_whitespace() { + if current_line.is_empty() { + current_line.push_str(word); + } else if current_line.len() + 1 + word.len() <= width { + current_line.push(' '); + current_line.push_str(word); + } else { + result.push_str(&current_line); + result.push('\n'); + current_line = word.to_string(); + } + } + + if !current_line.is_empty() { + result.push_str(&current_line); + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::meeting::data::{MeetingMetadata, Transcript, TranscriptSegment}; + + fn create_test_meeting() -> MeetingData { + let mut meeting = MeetingData::new(Some("Test Meeting".to_string())); + meeting.transcript.add_segment(TranscriptSegment::new( + 0, + 0, + 5000, + "Hello world, this is a test.".to_string(), + 0, + )); + meeting.transcript.add_segment(TranscriptSegment::new( + 1, + 5000, + 10000, + "This is the second segment.".to_string(), + 0, + )); + meeting + } + + #[test] + fn test_text_export_basic() { + let meeting = create_test_meeting(); + let exporter = TextExporter; + let options = ExportOptions::default(); + + let output = exporter.export(&meeting, &options).unwrap(); + assert!(output.contains("Hello world")); + assert!(output.contains("second segment")); + } + + #[test] + fn test_text_export_with_timestamps() { + let meeting = create_test_meeting(); + let exporter = TextExporter; + let options = ExportOptions { + include_timestamps: true, + ..Default::default() + }; + + let output = exporter.export(&meeting, &options).unwrap(); + assert!(output.contains("[00:00]")); + assert!(output.contains("[00:05]")); + } + + #[test] + fn test_text_export_with_metadata() { + let meeting = create_test_meeting(); + let exporter = TextExporter; + let options = ExportOptions { + include_metadata: true, + ..Default::default() + }; + + let output = exporter.export(&meeting, &options).unwrap(); + assert!(output.contains("Test Meeting")); + assert!(output.contains("Date:")); + } + + #[test] 
+ fn test_wrap_text() { + let text = "This is a long line that should be wrapped at a certain width."; + let wrapped = wrap_text(text, 20); + for line in wrapped.lines() { + assert!(line.len() <= 20 || !line.contains(' ')); + } + } + + #[test] + fn test_wrap_text_no_wrap() { + let text = "Short"; + assert_eq!(wrap_text(text, 80), "Short"); + } +} diff --git a/src/meeting/export/vtt.rs b/src/meeting/export/vtt.rs new file mode 100644 index 00000000..def4813e --- /dev/null +++ b/src/meeting/export/vtt.rs @@ -0,0 +1,136 @@ +//! WebVTT subtitle export format +//! +//! Generates WebVTT subtitle files with optional speaker labels and styling. + +use super::{ExportError, ExportFormat, ExportOptions, Exporter}; +use crate::meeting::data::MeetingData; + +/// VTT exporter +pub struct VttExporter; + +impl Exporter for VttExporter { + fn export( + &self, + meeting: &MeetingData, + options: &ExportOptions, + ) -> Result<String, ExportError> { + let mut output = String::from("WEBVTT\n"); + + // Add metadata if requested + if options.include_metadata { + output.push_str(&format!( + "NOTE\nMeeting: {}\nDate: {}\n", + meeting.metadata.display_title(), + meeting.metadata.started_at.format("%Y-%m-%d %H:%M:%S") + )); + if let Some(duration) = meeting.metadata.duration_secs { + output.push_str(&format!("Duration: {}s\n", duration)); + } + output.push('\n'); + } else { + output.push('\n'); + } + + for (i, segment) in meeting.transcript.segments.iter().enumerate() { + // Optional cue identifier + output.push_str(&format!("cue-{}\n", i + 1)); + + // Timestamps: 00:00:00.000 --> 00:00:00.000 + let start = format_vtt_time(segment.start_ms); + let end = format_vtt_time(segment.end_ms); + output.push_str(&format!("{} --> {}\n", start, end)); + + // Text with optional speaker (VTT supports voice spans) + if options.include_speakers { + let speaker = segment.speaker_display(); + if !speaker.is_empty() && speaker != "Unknown" { + output.push_str(&format!("<v {}>{}\n", speaker, segment.text)); + } else { + 
output.push_str(&format!("{}\n", segment.text)); + } + } else { + output.push_str(&format!("{}\n", segment.text)); + } + + // Blank line between cues + output.push('\n'); + } + + Ok(output) + } + + fn format(&self) -> ExportFormat { + ExportFormat::Vtt + } +} + +/// Format milliseconds as VTT timestamp (HH:MM:SS.mmm) +fn format_vtt_time(ms: u64) -> String { + let total_secs = ms / 1000; + let hours = total_secs / 3600; + let minutes = (total_secs % 3600) / 60; + let seconds = total_secs % 60; + let millis = ms % 1000; + + format!("{:02}:{:02}:{:02}.{:03}", hours, minutes, seconds, millis) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_vtt_time() { + assert_eq!(format_vtt_time(0), "00:00:00.000"); + assert_eq!(format_vtt_time(1500), "00:00:01.500"); + assert_eq!(format_vtt_time(65000), "00:01:05.000"); + assert_eq!(format_vtt_time(3661500), "01:01:01.500"); + } + + #[test] + fn test_vtt_export_header() { + use crate::meeting::data::MeetingData; + + let meeting = MeetingData::new(Some("Test".to_string())); + let exporter = VttExporter; + let options = ExportOptions::default(); + let output = exporter.export(&meeting, &options).unwrap(); + + assert!(output.starts_with("WEBVTT\n")); + } + + #[test] + fn test_vtt_export_with_speakers() { + use crate::meeting::data::{MeetingData, TranscriptSegment}; + + let mut meeting = MeetingData::new(Some("Test".to_string())); + let mut seg = TranscriptSegment::new(1, 0, 2000, "Hello world".to_string(), 0); + seg.speaker_label = Some("Alice".to_string()); + meeting.transcript.add_segment(seg); + + let exporter = VttExporter; + let options = ExportOptions { + include_speakers: true, + ..Default::default() + }; + let output = exporter.export(&meeting, &options).unwrap(); + + assert!(output.contains("Hello world")); + } + + #[test] + fn test_vtt_export_with_metadata() { + use crate::meeting::data::MeetingData; + + let meeting = MeetingData::new(Some("Important Meeting".to_string())); + let exporter = 
VttExporter; + let options = ExportOptions { + include_metadata: true, + ..Default::default() + }; + let output = exporter.export(&meeting, &options).unwrap(); + + assert!(output.contains("NOTE")); + assert!(output.contains("Meeting: Important Meeting")); + } +} diff --git a/src/meeting/mod.rs b/src/meeting/mod.rs new file mode 100644 index 00000000..ad6a11b8 --- /dev/null +++ b/src/meeting/mod.rs @@ -0,0 +1,365 @@ +//! Meeting transcription mode +//! +//! Provides continuous meeting transcription with chunked processing, +//! speaker attribution, and export capabilities. +//! +//! Enables transcription of longer meetings (up to 3 hours) with +//! automatic chunking and speaker separation. +//! +//! # Architecture +//! +//! ```text +//! Mic + Loopback → ChunkProcessor → VAD → Transcription → Storage +//! ↓ +//! Diarization (Phase 3) +//! ``` +//! +//! # Phases +//! +//! - **Phase 1 (v0.5.0):** Basic meeting mode with chunked processing +//! - **Phase 2 (v0.5.1):** Dual audio + simple You/Remote attribution +//! - **Phase 3 (v0.5.2):** ML-based speaker diarization +//! - **Phase 4 (v0.6.0):** Remote server sync for corporate deployments +//! 
- **Phase 5 (v0.6.1):** AI summarization with action items + +pub mod chunk; +pub mod data; +pub mod diarization; +pub mod export; +pub mod state; +pub mod storage; +pub mod summary; + +pub use chunk::{ChunkBuffer, ChunkConfig, ChunkProcessor, ProcessedChunk, VoiceActivityDetector}; +pub use data::{ + ActionItem, AudioSource, MeetingData, MeetingId, MeetingMetadata, MeetingStatus, + MeetingSummary, Transcript, TranscriptSegment, +}; +pub use export::{export_meeting, export_meeting_to_file, ExportFormat, ExportOptions}; +pub use state::{ChunkState, MeetingState}; +pub use storage::{MeetingStorage, StorageConfig, StorageError}; + +use crate::error::{MeetingError, Result}; +use crate::transcribe::{self, Transcriber}; +use std::sync::Arc; +use tokio::sync::mpsc; + +/// Meeting daemon configuration +#[derive(Debug, Clone)] +pub struct MeetingConfig { + /// Enable meeting mode + pub enabled: bool, + /// Duration of each audio chunk in seconds + pub chunk_duration_secs: u32, + /// Storage configuration + pub storage: StorageConfig, + /// Whether to retain raw audio files + pub retain_audio: bool, + /// Maximum meeting duration in minutes (0 = unlimited) + pub max_duration_mins: u32, +} + +impl Default for MeetingConfig { + fn default() -> Self { + Self { + enabled: false, + chunk_duration_secs: 30, + storage: StorageConfig::default(), + retain_audio: false, + max_duration_mins: 180, + } + } +} + +/// Events from the meeting daemon +#[derive(Debug)] +pub enum MeetingEvent { + /// Meeting started + Started { meeting_id: MeetingId }, + /// Chunk processed + ChunkProcessed { + chunk_id: u32, + segments: Vec<TranscriptSegment>, + }, + /// Meeting paused + Paused, + /// Meeting resumed + Resumed, + /// Meeting stopped + Stopped { meeting_id: MeetingId }, + /// Error occurred + Error(String), +} + +/// Meeting daemon for continuous transcription +pub struct MeetingDaemon { + config: MeetingConfig, + state: MeetingState, + storage: MeetingStorage, + current_meeting: Option<MeetingData>, + transcriber: Option<Arc<dyn Transcriber>>, 
+ engine_name: String, + event_tx: mpsc::Sender<MeetingEvent>, +} + +impl MeetingDaemon { + /// Create a new meeting daemon + pub fn new( + config: MeetingConfig, + app_config: &crate::config::Config, + event_tx: mpsc::Sender<MeetingEvent>, + ) -> Result<Self> { + let storage = MeetingStorage::open(config.storage.clone()) + .map_err(|e| MeetingError::Storage(e.to_string()))?; + + let transcriber: Arc<dyn Transcriber> = + Arc::from(transcribe::create_transcriber(app_config)?); + let engine_name = format!("{:?}", app_config.engine).to_lowercase(); + + Ok(Self { + config, + state: MeetingState::Idle, + storage, + current_meeting: None, + transcriber: Some(transcriber), + engine_name, + event_tx, + }) + } + + /// Start a new meeting + pub async fn start(&mut self, title: Option<String>) -> Result<MeetingId> { + if !self.state.is_idle() { + return Err(MeetingError::AlreadyInProgress.into()); + } + + // Create meeting + let mut meeting = MeetingData::new(title); + meeting.metadata.model = Some(self.engine_name.clone()); + + // Create storage directory + let storage_path = self + .storage + .create_meeting(&meeting.metadata) + .map_err(|e| MeetingError::Storage(e.to_string()))?; + meeting.metadata.storage_path = Some(storage_path); + + let meeting_id = meeting.metadata.id; + self.current_meeting = Some(meeting); + self.state = MeetingState::start(); + + let _ = self + .event_tx + .send(MeetingEvent::Started { meeting_id }) + .await; + tracing::info!("Meeting started: {}", meeting_id); + + Ok(meeting_id) + } + + /// Pause the current meeting + pub async fn pause(&mut self) -> Result<()> { + if !self.state.is_active() { + return Err(MeetingError::NotActive.into()); + } + + self.state = std::mem::take(&mut self.state).pause(); + let _ = self.event_tx.send(MeetingEvent::Paused).await; + tracing::info!("Meeting paused"); + + Ok(()) + } + + /// Resume a paused meeting + pub async fn resume(&mut self) -> Result<()> { + if !self.state.is_paused() { + return Err(MeetingError::NotPaused.into()); + } + + self.state = std::mem::take(&mut 
self.state).resume(); + let _ = self.event_tx.send(MeetingEvent::Resumed).await; + tracing::info!("Meeting resumed"); + + Ok(()) + } + + /// Stop the current meeting + pub async fn stop(&mut self) -> Result<MeetingId> { + if self.state.is_idle() { + return Err(MeetingError::NotInProgress.into()); + } + + self.state = std::mem::take(&mut self.state).stop(); + + // Finalize meeting + if let Some(ref mut meeting) = self.current_meeting { + meeting.complete(); + meeting.metadata.chunk_count = meeting.transcript.total_chunks; + + // Save transcript + self.storage + .save_transcript(&meeting.metadata.id, &meeting.transcript) + .map_err(|e| MeetingError::Storage(e.to_string()))?; + + // Update metadata + self.storage + .update_meeting(&meeting.metadata) + .map_err(|e| MeetingError::Storage(e.to_string()))?; + } + + let meeting_id = self + .current_meeting + .as_ref() + .map(|m| m.metadata.id) + .unwrap_or_default(); + + let _ = self + .event_tx + .send(MeetingEvent::Stopped { meeting_id }) + .await; + tracing::info!("Meeting stopped: {}", meeting_id); + + // Clean up + self.state = std::mem::take(&mut self.state).finalize(); + self.current_meeting = None; + + Ok(meeting_id) + } + + /// Get current meeting state + pub fn state(&self) -> &MeetingState { + &self.state + } + + /// Get current meeting ID if one is active + pub fn current_meeting_id(&self) -> Option<MeetingId> { + self.current_meeting.as_ref().map(|m| m.metadata.id) + } + + /// Process a chunk of audio + pub async fn process_chunk( + &mut self, + samples: Vec<f32>, + ) -> Result<Option<Vec<TranscriptSegment>>> { + if !self.state.is_active() { + return Ok(None); + } + + let Some(ref transcriber) = self.transcriber else { + return Err(MeetingError::TranscriberNotInitialized.into()); + }; + + let chunk_id = self.state.chunks_processed(); + let chunk_config = ChunkConfig { + chunk_duration_secs: self.config.chunk_duration_secs, + ..Default::default() + }; + + // Calculate start offset + let start_offset_ms = if let Some(ref meeting) = self.current_meeting { + 
meeting.transcript.duration_ms() + } else { + 0 + }; + + let mut processor = ChunkProcessor::new(chunk_config, transcriber.clone()); + let mut buffer = processor.new_buffer(chunk_id, AudioSource::Microphone, start_offset_ms); + buffer.add_samples(&samples); + + let result = processor + .process_chunk(buffer) + .map_err(crate::error::VoxtypeError::Transcribe)?; + + // Add segments to transcript + if let Some(ref mut meeting) = self.current_meeting { + for segment in &result.segments { + meeting.transcript.add_segment(segment.clone()); + } + meeting.transcript.total_chunks = chunk_id + 1; + } + + // Advance state + self.state = std::mem::take(&mut self.state).next_chunk(); + + // Send event + let _ = self + .event_tx + .send(MeetingEvent::ChunkProcessed { + chunk_id, + segments: result.segments.clone(), + }) + .await; + + Ok(Some(result.segments)) + } + + /// Get storage access + pub fn storage(&self) -> &MeetingStorage { + &self.storage + } +} + +/// List meetings from storage +pub fn list_meetings( + config: &MeetingConfig, + limit: Option<usize>, +) -> std::result::Result<Vec<MeetingMetadata>, StorageError> { + let storage = MeetingStorage::open(config.storage.clone())?; + storage.list_meetings(limit) +} + +/// Get a meeting by ID (or "latest") +pub fn get_meeting( + config: &MeetingConfig, + id_str: &str, +) -> std::result::Result<MeetingData, StorageError> { + let storage = MeetingStorage::open(config.storage.clone())?; + let id = storage.resolve_meeting_id(id_str)?; + storage.load_meeting_data(&id) +} + +/// Export a meeting +pub fn export_meeting_by_id( + config: &MeetingConfig, + id_str: &str, + format: ExportFormat, + options: &ExportOptions, +) -> std::result::Result<String, StorageError> { + let meeting = get_meeting(config, id_str)?; + export_meeting(&meeting, format, options) + .map_err(|e| StorageError::Io(std::io::Error::other(e.to_string()))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_meeting_config_default() { + let config = MeetingConfig::default(); + assert!(!config.enabled); + 
assert_eq!(config.chunk_duration_secs, 30); + assert_eq!(config.max_duration_mins, 180); + } + + #[test] + fn test_meeting_state_transitions() { + let state = MeetingState::Idle; + assert!(state.is_idle()); + + let state = MeetingState::start(); + assert!(state.is_active()); + + let state = state.pause(); + assert!(state.is_paused()); + + let state = state.resume(); + assert!(state.is_active()); + + let state = state.stop(); + assert!(state.is_finalizing()); + + let state = state.finalize(); + assert!(state.is_idle()); + } +} diff --git a/src/meeting/state.rs b/src/meeting/state.rs new file mode 100644 index 00000000..ee38b623 --- /dev/null +++ b/src/meeting/state.rs @@ -0,0 +1,587 @@ +//! State machine for meeting transcription mode +//! +//! Defines the states for continuous meeting recording: +//! Idle -> Active -> Paused -> Active -> Finalizing -> Idle + +use std::time::Instant; + +/// State of an individual audio chunk being processed +#[derive(Debug, Clone)] +pub enum ChunkState { + /// Recording audio for this chunk + Recording { + /// When this chunk started recording + started_at: Instant, + }, + /// Processing this chunk (transcribing) + Processing { + /// Sequential ID of this chunk + chunk_id: u32, + }, +} + +impl ChunkState { + /// Check if this chunk is currently recording + pub fn is_recording(&self) -> bool { + matches!(self, ChunkState::Recording { .. 
}) + } + + /// Get the recording duration if currently recording + pub fn recording_duration(&self) -> Option<std::time::Duration> { + match self { + ChunkState::Recording { started_at } => Some(started_at.elapsed()), + _ => None, + } + } +} + +/// Meeting transcription state +#[derive(Debug, Clone)] +#[derive(Default)] +pub enum MeetingState { + /// No meeting in progress + #[default] + Idle, + + /// Meeting is actively recording + Active { + /// When the meeting started + started_at: Instant, + /// Current chunk being processed + current_chunk: ChunkState, + /// Number of chunks processed so far + chunks_processed: u32, + }, + + /// Meeting is temporarily paused + Paused { + /// When the meeting started + started_at: Instant, + /// When the meeting was paused + paused_at: Instant, + /// Number of chunks processed before pause + chunks_processed: u32, + }, + + /// Meeting has ended, finalizing (processing last chunk, saving) + Finalizing { + /// When the meeting started + started_at: Instant, + /// When the meeting was stopped + ended_at: Instant, + /// Total chunks processed + total_chunks: u32, + }, +} + + +impl MeetingState { + /// Create a new idle state + pub fn new() -> Self { + MeetingState::Idle + } + + /// Check if in idle state + pub fn is_idle(&self) -> bool { + matches!(self, MeetingState::Idle) + } + + /// Check if actively recording + pub fn is_active(&self) -> bool { + matches!(self, MeetingState::Active { .. }) + } + + /// Check if paused + pub fn is_paused(&self) -> bool { + matches!(self, MeetingState::Paused { .. }) + } + + /// Check if finalizing + pub fn is_finalizing(&self) -> bool { + matches!(self, MeetingState::Finalizing { .. }) + } + + /// Get meeting duration (including paused time) + pub fn meeting_duration(&self) -> Option<std::time::Duration> { + match self { + MeetingState::Idle => None, + MeetingState::Active { started_at, .. } => Some(started_at.elapsed()), + MeetingState::Paused { started_at, .. 
} => Some(started_at.elapsed()), + MeetingState::Finalizing { + started_at, + ended_at, + .. + } => Some(ended_at.duration_since(*started_at)), + } + } + + /// Alias for meeting_duration - time elapsed since meeting started + pub fn elapsed(&self) -> Option<std::time::Duration> { + self.meeting_duration() + } + + /// Get number of chunks processed + pub fn chunks_processed(&self) -> u32 { + match self { + MeetingState::Idle => 0, + MeetingState::Active { + chunks_processed, .. + } => *chunks_processed, + MeetingState::Paused { + chunks_processed, .. + } => *chunks_processed, + MeetingState::Finalizing { total_chunks, .. } => *total_chunks, + } + } + + /// Start a new meeting + pub fn start() -> Self { + let now = Instant::now(); + MeetingState::Active { + started_at: now, + current_chunk: ChunkState::Recording { started_at: now }, + chunks_processed: 0, + } + } + + /// Pause the current meeting (only valid from Active state) + pub fn pause(self) -> Self { + match self { + MeetingState::Active { + started_at, + chunks_processed, + .. + } => MeetingState::Paused { + started_at, + paused_at: Instant::now(), + chunks_processed, + }, + other => other, // No-op for other states + } + } + + /// Resume a paused meeting (only valid from Paused state) + pub fn resume(self) -> Self { + match self { + MeetingState::Paused { + started_at, + chunks_processed, + .. + } => MeetingState::Active { + started_at, + current_chunk: ChunkState::Recording { + started_at: Instant::now(), + }, + chunks_processed, + }, + other => other, // No-op for other states + } + } + + /// Stop the meeting and begin finalization (valid from Active or Paused) + pub fn stop(self) -> Self { + match self { + MeetingState::Active { + started_at, + chunks_processed, + .. + } => MeetingState::Finalizing { + started_at, + ended_at: Instant::now(), + total_chunks: chunks_processed, + }, + MeetingState::Paused { + started_at, + chunks_processed, + .. 
+ } => MeetingState::Finalizing { + started_at, + ended_at: Instant::now(), + total_chunks: chunks_processed, + }, + other => other, // No-op for idle/finalizing + } + } + + /// Complete finalization and return to idle + pub fn finalize(self) -> Self { + match self { + MeetingState::Finalizing { .. } => MeetingState::Idle, + other => other, // No-op for other states + } + } + + /// Advance to the next chunk (only valid in Active state) + pub fn next_chunk(self) -> Self { + match self { + MeetingState::Active { + started_at, + chunks_processed, + .. + } => MeetingState::Active { + started_at, + current_chunk: ChunkState::Recording { + started_at: Instant::now(), + }, + chunks_processed: chunks_processed + 1, + }, + other => other, + } + } + + /// Mark current chunk as processing + pub fn processing_chunk(self, chunk_id: u32) -> Self { + match self { + MeetingState::Active { + started_at, + chunks_processed, + .. + } => MeetingState::Active { + started_at, + current_chunk: ChunkState::Processing { chunk_id }, + chunks_processed, + }, + other => other, + } + } +} + +impl std::fmt::Display for MeetingState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MeetingState::Idle => write!(f, "Idle"), + MeetingState::Active { + started_at, + chunks_processed, + .. + } => { + write!( + f, + "Active ({:.0}m, {} chunks)", + started_at.elapsed().as_secs_f32() / 60.0, + chunks_processed + ) + } + MeetingState::Paused { + paused_at, + chunks_processed, + .. + } => { + write!( + f, + "Paused ({:.0}s ago, {} chunks)", + paused_at.elapsed().as_secs_f32(), + chunks_processed + ) + } + MeetingState::Finalizing { total_chunks, .. 
} => { + write!(f, "Finalizing ({} chunks)", total_chunks) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_new_state_is_idle() { + let state = MeetingState::new(); + assert!(state.is_idle()); + } + + #[test] + fn test_start_meeting() { + let state = MeetingState::start(); + assert!(state.is_active()); + assert_eq!(state.chunks_processed(), 0); + } + + #[test] + fn test_pause_resume() { + let state = MeetingState::start(); + let paused = state.pause(); + assert!(paused.is_paused()); + + let resumed = paused.resume(); + assert!(resumed.is_active()); + } + + #[test] + fn test_stop_meeting() { + let state = MeetingState::start(); + let stopped = state.stop(); + assert!(stopped.is_finalizing()); + } + + #[test] + fn test_finalize_meeting() { + let state = MeetingState::start(); + let stopped = state.stop(); + let finalized = stopped.finalize(); + assert!(finalized.is_idle()); + } + + #[test] + fn test_next_chunk() { + let state = MeetingState::start(); + assert_eq!(state.chunks_processed(), 0); + + let state = state.next_chunk(); + assert_eq!(state.chunks_processed(), 1); + + let state = state.next_chunk(); + assert_eq!(state.chunks_processed(), 2); + } + + #[test] + fn test_meeting_duration() { + let state = MeetingState::start(); + std::thread::sleep(Duration::from_millis(10)); + let duration = state.meeting_duration().unwrap(); + assert!(duration >= Duration::from_millis(10)); + } + + #[test] + fn test_idle_has_no_duration() { + let state = MeetingState::Idle; + assert!(state.meeting_duration().is_none()); + } + + #[test] + fn test_stop_from_paused() { + let state = MeetingState::start(); + let state = state.next_chunk().next_chunk(); + let paused = state.pause(); + assert!(paused.is_paused()); + assert_eq!(paused.chunks_processed(), 2); + + let stopped = paused.stop(); + assert!(stopped.is_finalizing()); + assert_eq!(stopped.chunks_processed(), 2); + } + + #[test] + fn test_processing_chunk() { + let state 
= MeetingState::start(); + let state = state.processing_chunk(0); + assert!(state.is_active()); + if let MeetingState::Active { current_chunk, .. } = &state { + assert!(!current_chunk.is_recording()); + } else { + panic!("Expected Active state"); + } + } + + #[test] + fn test_pause_idle_is_noop() { + let state = MeetingState::Idle; + let state = state.pause(); + assert!(state.is_idle()); + } + + #[test] + fn test_resume_idle_is_noop() { + let state = MeetingState::Idle; + let state = state.resume(); + assert!(state.is_idle()); + } + + #[test] + fn test_stop_idle_is_noop() { + let state = MeetingState::Idle; + let state = state.stop(); + assert!(state.is_idle()); + } + + #[test] + fn test_finalize_active_is_noop() { + let state = MeetingState::start(); + let state = state.finalize(); + assert!(state.is_active()); + } + + #[test] + fn test_next_chunk_paused_is_noop() { + let state = MeetingState::start().pause(); + let state = state.next_chunk(); + assert!(state.is_paused()); + } + + #[test] + fn test_display_trait() { + let state = MeetingState::Idle; + assert_eq!(format!("{}", state), "Idle"); + + let state = MeetingState::start(); + let display = format!("{}", state); + assert!(display.starts_with("Active")); + assert!(display.contains("0 chunks")); + } + + #[test] + fn test_chunks_processed_in_paused() { + let state = MeetingState::start().next_chunk().next_chunk().next_chunk(); + assert_eq!(state.chunks_processed(), 3); + let paused = state.pause(); + assert_eq!(paused.chunks_processed(), 3); + } + + #[test] + fn test_meeting_duration_active() { + let state = MeetingState::start(); + assert!(state.meeting_duration().is_some()); + } + + #[test] + fn test_chunk_state_recording_duration() { + let chunk = ChunkState::Recording { + started_at: Instant::now(), + }; + assert!(chunk.is_recording()); + assert!(chunk.recording_duration().is_some()); + } + + #[test] + fn test_chunk_state_processing_no_duration() { + let chunk = ChunkState::Processing { chunk_id: 5 }; + 
assert!(!chunk.is_recording()); + assert!(chunk.recording_duration().is_none()); + } + + #[test] + fn test_default_is_idle() { + let state = MeetingState::default(); + assert!(state.is_idle()); + } + + #[test] + fn test_resume_active_is_noop() { + let state = MeetingState::start(); + assert!(state.is_active()); + let state = state.resume(); + assert!(state.is_active()); + } + + #[test] + fn test_pause_finalizing_is_noop() { + let state = MeetingState::start().stop(); + assert!(state.is_finalizing()); + let state = state.pause(); + assert!(state.is_finalizing()); + } + + #[test] + fn test_resume_finalizing_is_noop() { + let state = MeetingState::start().stop(); + assert!(state.is_finalizing()); + let state = state.resume(); + assert!(state.is_finalizing()); + } + + #[test] + fn test_stop_finalizing_is_noop() { + let state = MeetingState::start().stop(); + assert!(state.is_finalizing()); + let state = state.stop(); + assert!(state.is_finalizing()); + } + + #[test] + fn test_finalize_idle_is_noop() { + let state = MeetingState::Idle; + let state = state.finalize(); + assert!(state.is_idle()); + } + + #[test] + fn test_finalize_paused_is_noop() { + let state = MeetingState::start().pause(); + assert!(state.is_paused()); + let state = state.finalize(); + assert!(state.is_paused()); + } + + #[test] + fn test_next_chunk_idle_is_noop() { + let state = MeetingState::Idle; + let state = state.next_chunk(); + assert!(state.is_idle()); + } + + #[test] + fn test_next_chunk_finalizing_is_noop() { + let state = MeetingState::start().stop(); + let chunks_before = state.chunks_processed(); + let state = state.next_chunk(); + assert!(state.is_finalizing()); + assert_eq!(state.chunks_processed(), chunks_before); + } + + #[test] + fn test_processing_chunk_idle_is_noop() { + let state = MeetingState::Idle; + let state = state.processing_chunk(0); + assert!(state.is_idle()); + } + + #[test] + fn test_processing_chunk_paused_is_noop() { + let state = MeetingState::start().pause(); + let 
state = state.processing_chunk(0); + assert!(state.is_paused()); + } + + #[test] + fn test_full_lifecycle_with_chunks() { + let state = MeetingState::start(); + assert!(state.is_active()); + assert_eq!(state.chunks_processed(), 0); + + let state = state.next_chunk().next_chunk().next_chunk(); + assert_eq!(state.chunks_processed(), 3); + + let state = state.pause(); + assert_eq!(state.chunks_processed(), 3); + + let state = state.resume(); + assert_eq!(state.chunks_processed(), 3); + + let state = state.next_chunk(); + assert_eq!(state.chunks_processed(), 4); + + let state = state.stop(); + assert!(state.is_finalizing()); + assert_eq!(state.chunks_processed(), 4); + + let state = state.finalize(); + assert!(state.is_idle()); + assert_eq!(state.chunks_processed(), 0); + } + + #[test] + fn test_elapsed_alias() { + let state = MeetingState::Idle; + assert!(state.elapsed().is_none()); + + let state = MeetingState::start(); + assert!(state.elapsed().is_some()); + } + + #[test] + fn test_display_paused() { + let state = MeetingState::start().pause(); + let display = format!("{}", state); + assert!(display.starts_with("Paused")); + } + + #[test] + fn test_display_finalizing() { + let state = MeetingState::start().next_chunk().next_chunk().stop(); + let display = format!("{}", state); + assert!(display.contains("Finalizing")); + assert!(display.contains("2 chunks")); + } +} diff --git a/src/meeting/storage.rs b/src/meeting/storage.rs new file mode 100644 index 00000000..bd2fe71a --- /dev/null +++ b/src/meeting/storage.rs @@ -0,0 +1,920 @@ +//! Storage layer for meeting transcription +//! +//! Provides SQLite-based index for meeting metadata and filesystem +//! storage for transcripts and audio files. 
+ +use crate::meeting::data::{MeetingData, MeetingId, MeetingMetadata, MeetingStatus, Transcript}; +use chrono::{DateTime, TimeZone, Utc}; +use rusqlite::{params, Connection, OptionalExtension}; +use std::path::PathBuf; +use thiserror::Error; + +/// Storage-related errors +#[derive(Error, Debug)] +pub enum StorageError { + #[error("Database error: {0}")] + Database(#[from] rusqlite::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON serialization error: {0}")] + Json(#[from] serde_json::Error), + + #[error("Meeting not found: {0}")] + NotFound(String), + + #[error("Storage path not configured")] + PathNotConfigured, +} + +/// Meeting storage configuration +#[derive(Debug, Clone)] +pub struct StorageConfig { + /// Base path for meeting storage + /// "auto" will use ~/.local/share/voxtype/meetings/ + pub storage_path: PathBuf, + /// Whether to retain audio files + pub retain_audio: bool, + /// Maximum number of meetings to keep (0 = unlimited) + pub max_meetings: u32, +} + +impl Default for StorageConfig { + fn default() -> Self { + Self { + storage_path: Self::default_storage_path(), + retain_audio: false, + max_meetings: 0, + } + } +} + +impl StorageConfig { + /// Get the default storage path + pub fn default_storage_path() -> PathBuf { + directories::ProjectDirs::from("", "", "voxtype") + .map(|dirs| dirs.data_dir().join("meetings")) + .unwrap_or_else(|| PathBuf::from("~/.local/share/voxtype/meetings")) + } + + /// Get the database path + pub fn db_path(&self) -> PathBuf { + self.storage_path.join("index.db") + } +} + +/// Meeting storage manager +pub struct MeetingStorage { + config: StorageConfig, + conn: Connection, +} + +impl MeetingStorage { + /// Open or create meeting storage + pub fn open(config: StorageConfig) -> Result { + // Ensure storage directory exists + std::fs::create_dir_all(&config.storage_path)?; + + let db_path = config.db_path(); + let conn = Connection::open(&db_path)?; + + let storage = Self { config, conn 
}; + storage.init_schema()?; + + Ok(storage) + } + + /// Initialize database schema + fn init_schema(&self) -> Result<(), StorageError> { + self.conn.execute_batch( + r#" + CREATE TABLE IF NOT EXISTS meetings ( + id TEXT PRIMARY KEY, + title TEXT, + started_at INTEGER NOT NULL, + ended_at INTEGER, + duration_secs INTEGER, + status TEXT NOT NULL DEFAULT 'active', + chunk_count INTEGER NOT NULL DEFAULT 0, + storage_path TEXT, + audio_retained INTEGER NOT NULL DEFAULT 0, + model TEXT, + synced_at INTEGER, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')) + ); + + CREATE INDEX IF NOT EXISTS idx_meetings_started_at ON meetings(started_at DESC); + CREATE INDEX IF NOT EXISTS idx_meetings_status ON meetings(status); + + -- Speaker labels for ML diarization (Phase 3) + CREATE TABLE IF NOT EXISTS speaker_labels ( + meeting_id TEXT NOT NULL, + speaker_num INTEGER NOT NULL, + label TEXT NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + PRIMARY KEY (meeting_id, speaker_num), + FOREIGN KEY (meeting_id) REFERENCES meetings(id) ON DELETE CASCADE + ); + "#, + )?; + Ok(()) + } + + /// Create a new meeting + pub fn create_meeting(&self, metadata: &MeetingMetadata) -> Result { + // Create meeting directory + let meeting_dir = self.config.storage_path.join(metadata.storage_dir_name()); + std::fs::create_dir_all(&meeting_dir)?; + + // Insert into database + self.conn.execute( + r#" + INSERT INTO meetings (id, title, started_at, status, storage_path, audio_retained, model) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) + "#, + params![ + metadata.id.to_string(), + metadata.title, + metadata.started_at.timestamp(), + status_to_string(metadata.status), + meeting_dir.to_string_lossy().to_string(), + metadata.audio_retained as i32, + metadata.model, + ], + )?; + + // Write initial metadata file + let metadata_path = meeting_dir.join("metadata.json"); + let json = serde_json::to_string_pretty(metadata)?; + std::fs::write(&metadata_path, json)?; + + Ok(meeting_dir) 
+ } + + /// Update meeting metadata + pub fn update_meeting(&self, metadata: &MeetingMetadata) -> Result<(), StorageError> { + self.conn.execute( + r#" + UPDATE meetings SET + title = ?2, + ended_at = ?3, + duration_secs = ?4, + status = ?5, + chunk_count = ?6, + audio_retained = ?7, + model = ?8, + synced_at = ?9 + WHERE id = ?1 + "#, + params![ + metadata.id.to_string(), + metadata.title, + metadata.ended_at.map(|dt| dt.timestamp()), + metadata.duration_secs.map(|d| d as i64), + status_to_string(metadata.status), + metadata.chunk_count as i32, + metadata.audio_retained as i32, + metadata.model, + metadata.synced_at.map(|dt| dt.timestamp()), + ], + )?; + + // Update metadata file if storage path exists + if let Some(ref path) = metadata.storage_path { + let metadata_path = path.join("metadata.json"); + let json = serde_json::to_string_pretty(metadata)?; + std::fs::write(metadata_path, json)?; + } + + Ok(()) + } + + /// Get meeting by ID + pub fn get_meeting(&self, id: &MeetingId) -> Result, StorageError> { + let result = self + .conn + .query_row( + r#" + SELECT id, title, started_at, ended_at, duration_secs, status, + chunk_count, storage_path, audio_retained, model, synced_at + FROM meetings WHERE id = ?1 + "#, + params![id.to_string()], + |row| { + Ok(MeetingMetadata { + id: MeetingId::parse(&row.get::<_, String>(0)?).unwrap_or_default(), + title: row.get(1)?, + started_at: timestamp_to_datetime(row.get(2)?), + ended_at: row.get::<_, Option>(3)?.map(timestamp_to_datetime), + duration_secs: row.get::<_, Option>(4)?.map(|d| d as u64), + status: string_to_status(&row.get::<_, String>(5)?), + chunk_count: row.get::<_, i32>(6)? as u32, + storage_path: row.get::<_, Option>(7)?.map(PathBuf::from), + audio_retained: row.get::<_, i32>(8)? 
!= 0, + model: row.get(9)?, + summary: None, + synced_at: row.get::<_, Option>(10)?.map(timestamp_to_datetime), + }) + }, + ) + .optional()?; + + Ok(result) + } + + /// List meetings with optional limit + pub fn list_meetings(&self, limit: Option) -> Result, StorageError> { + let sql = if limit.is_some() { + r#" + SELECT id, title, started_at, ended_at, duration_secs, status, + chunk_count, storage_path, audio_retained, model, synced_at + FROM meetings + ORDER BY started_at DESC + LIMIT ?1 + "# + } else { + r#" + SELECT id, title, started_at, ended_at, duration_secs, status, + chunk_count, storage_path, audio_retained, model, synced_at + FROM meetings + ORDER BY started_at DESC + "# + }; + + let mut stmt = self.conn.prepare(sql)?; + let row_mapper = |row: &rusqlite::Row| { + Ok(MeetingMetadata { + id: MeetingId::parse(&row.get::<_, String>(0)?).unwrap_or_default(), + title: row.get(1)?, + started_at: timestamp_to_datetime(row.get(2)?), + ended_at: row.get::<_, Option>(3)?.map(timestamp_to_datetime), + duration_secs: row.get::<_, Option>(4)?.map(|d| d as u64), + status: string_to_status(&row.get::<_, String>(5)?), + chunk_count: row.get::<_, i32>(6)? as u32, + storage_path: row.get::<_, Option>(7)?.map(PathBuf::from), + audio_retained: row.get::<_, i32>(8)? != 0, + model: row.get(9)?, + summary: None, + synced_at: row.get::<_, Option>(10)?.map(timestamp_to_datetime), + }) + }; + + let meetings = if let Some(limit) = limit { + stmt.query_map(params![limit], row_mapper)? + .collect::, _>>()? + } else { + stmt.query_map([], row_mapper)? + .collect::, _>>()? + }; + + Ok(meetings) + } + + /// Get the most recent meeting + pub fn get_latest_meeting(&self) -> Result, StorageError> { + let meetings = self.list_meetings(Some(1))?; + Ok(meetings.into_iter().next()) + } + + /// Save transcript to filesystem + pub fn save_transcript( + &self, + meeting_id: &MeetingId, + transcript: &Transcript, + ) -> Result<(), StorageError> { + let metadata = self + .get_meeting(meeting_id)? 
+ .ok_or_else(|| StorageError::NotFound(meeting_id.to_string()))?; + + let storage_path = metadata + .storage_path + .ok_or(StorageError::PathNotConfigured)?; + + let transcript_path = storage_path.join("transcript.json"); + let json = serde_json::to_string_pretty(transcript)?; + std::fs::write(transcript_path, json)?; + + Ok(()) + } + + /// Load transcript from filesystem + pub fn load_transcript(&self, meeting_id: &MeetingId) -> Result { + let metadata = self + .get_meeting(meeting_id)? + .ok_or_else(|| StorageError::NotFound(meeting_id.to_string()))?; + + let storage_path = metadata + .storage_path + .ok_or(StorageError::PathNotConfigured)?; + + let transcript_path = storage_path.join("transcript.json"); + let json = std::fs::read_to_string(transcript_path)?; + let transcript: Transcript = serde_json::from_str(&json)?; + + Ok(transcript) + } + + /// Load complete meeting data (metadata + transcript) + pub fn load_meeting_data(&self, meeting_id: &MeetingId) -> Result { + let metadata = self + .get_meeting(meeting_id)? + .ok_or_else(|| StorageError::NotFound(meeting_id.to_string()))?; + + let transcript = self.load_transcript(meeting_id).unwrap_or_default(); + + Ok(MeetingData { + metadata, + transcript, + }) + } + + /// Delete a meeting and its files + pub fn delete_meeting(&self, meeting_id: &MeetingId) -> Result<(), StorageError> { + // Get storage path before deleting from DB + let metadata = self.get_meeting(meeting_id)?; + + // Delete from database + self.conn.execute( + "DELETE FROM meetings WHERE id = ?1", + params![meeting_id.to_string()], + )?; + + // Delete files if storage path exists + if let Some(metadata) = metadata { + if let Some(path) = metadata.storage_path { + if path.exists() { + std::fs::remove_dir_all(path)?; + } + } + } + + Ok(()) + } + + /// Get the storage path for a meeting + pub fn get_meeting_path(&self, meeting_id: &MeetingId) -> Result { + let metadata = self + .get_meeting(meeting_id)? 
+ .ok_or_else(|| StorageError::NotFound(meeting_id.to_string()))?; + + metadata.storage_path.ok_or(StorageError::PathNotConfigured) + } + + /// Resolve a meeting ID from a string (supports "latest" alias) + pub fn resolve_meeting_id(&self, id_str: &str) -> Result { + if id_str == "latest" { + let meeting = self + .get_latest_meeting()? + .ok_or_else(|| StorageError::NotFound("No meetings found".to_string()))?; + Ok(meeting.id) + } else { + MeetingId::parse(id_str) + .map_err(|_| StorageError::NotFound(format!("Invalid meeting ID: {}", id_str))) + } + } + + /// Set a speaker label for ML diarization + pub fn set_speaker_label( + &self, + meeting_id: &MeetingId, + speaker_num: u32, + label: &str, + ) -> Result<(), StorageError> { + // Verify meeting exists + self.get_meeting(meeting_id)? + .ok_or_else(|| StorageError::NotFound(meeting_id.to_string()))?; + + // Insert or update speaker label + self.conn.execute( + r#" + INSERT OR REPLACE INTO speaker_labels (meeting_id, speaker_num, label) + VALUES (?1, ?2, ?3) + "#, + params![meeting_id.to_string(), speaker_num as i32, label], + )?; + + // Also update the transcript file to apply labels + self.apply_speaker_labels_to_transcript(meeting_id)?; + + Ok(()) + } + + /// Get all speaker labels for a meeting + pub fn get_speaker_labels( + &self, + meeting_id: &MeetingId, + ) -> Result, StorageError> { + let mut stmt = self + .conn + .prepare("SELECT speaker_num, label FROM speaker_labels WHERE meeting_id = ?1")?; + + let labels = stmt + .query_map(params![meeting_id.to_string()], |row| { + Ok((row.get::<_, i32>(0)? as u32, row.get::<_, String>(1)?)) + })? 
+ .collect::, _>>()?; + + Ok(labels) + } + + /// Apply speaker labels to transcript segments + fn apply_speaker_labels_to_transcript( + &self, + meeting_id: &MeetingId, + ) -> Result<(), StorageError> { + let labels = self.get_speaker_labels(meeting_id)?; + if labels.is_empty() { + return Ok(()); + } + + // Load and update transcript + let mut transcript = match self.load_transcript(meeting_id) { + Ok(t) => t, + Err(_) => return Ok(()), // No transcript yet + }; + + for segment in &mut transcript.segments { + if let Some(ref speaker_id) = segment.speaker_id { + // Parse speaker ID - supports "SPEAKER_00" or just "0" + let speaker_num: Option = if speaker_id.starts_with("SPEAKER_") { + speaker_id.trim_start_matches("SPEAKER_").parse().ok() + } else { + speaker_id.parse().ok() + }; + + if let Some(num) = speaker_num { + if let Some(label) = labels.get(&num) { + segment.speaker_label = Some(label.clone()); + } + } + } + } + + // Save updated transcript + self.save_transcript(meeting_id, &transcript)?; + + Ok(()) + } +} + +// Helper functions for status serialization +fn status_to_string(status: MeetingStatus) -> &'static str { + match status { + MeetingStatus::Active => "active", + MeetingStatus::Paused => "paused", + MeetingStatus::Completed => "completed", + MeetingStatus::Cancelled => "cancelled", + } +} + +fn string_to_status(s: &str) -> MeetingStatus { + match s { + "active" => MeetingStatus::Active, + "paused" => MeetingStatus::Paused, + "completed" => MeetingStatus::Completed, + "cancelled" => MeetingStatus::Cancelled, + _ => MeetingStatus::Active, + } +} + +fn timestamp_to_datetime(ts: i64) -> DateTime { + Utc.timestamp_opt(ts, 0).single().unwrap_or_else(Utc::now) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn create_test_storage() -> (MeetingStorage, TempDir) { + let temp_dir = TempDir::new().unwrap(); + let config = StorageConfig { + storage_path: temp_dir.path().to_path_buf(), + retain_audio: false, + max_meetings: 0, + }; + 
let storage = MeetingStorage::open(config).unwrap(); + (storage, temp_dir) + } + + #[test] + fn test_create_and_get_meeting() { + let (storage, _temp) = create_test_storage(); + + let metadata = MeetingMetadata::new(Some("Test Meeting".to_string())); + let meeting_id = metadata.id; + + storage.create_meeting(&metadata).unwrap(); + + let loaded = storage.get_meeting(&meeting_id).unwrap().unwrap(); + assert_eq!(loaded.title, Some("Test Meeting".to_string())); + } + + #[test] + fn test_list_meetings() { + let (storage, _temp) = create_test_storage(); + + let metadata1 = MeetingMetadata::new(Some("Meeting 1".to_string())); + let metadata2 = MeetingMetadata::new(Some("Meeting 2".to_string())); + + storage.create_meeting(&metadata1).unwrap(); + storage.create_meeting(&metadata2).unwrap(); + + let meetings = storage.list_meetings(None).unwrap(); + assert_eq!(meetings.len(), 2); + } + + #[test] + fn test_list_meetings_with_limit() { + let (storage, _temp) = create_test_storage(); + + for i in 0..5 { + let metadata = MeetingMetadata::new(Some(format!("Meeting {}", i))); + storage.create_meeting(&metadata).unwrap(); + } + + let meetings = storage.list_meetings(Some(2)).unwrap(); + assert_eq!(meetings.len(), 2); + } + + #[test] + fn test_update_meeting() { + let (storage, _temp) = create_test_storage(); + + let mut metadata = MeetingMetadata::new(Some("Original Title".to_string())); + let meeting_id = metadata.id; + + storage.create_meeting(&metadata).unwrap(); + + metadata.title = Some("Updated Title".to_string()); + metadata.complete(); + storage.update_meeting(&metadata).unwrap(); + + let loaded = storage.get_meeting(&meeting_id).unwrap().unwrap(); + assert_eq!(loaded.title, Some("Updated Title".to_string())); + assert_eq!(loaded.status, MeetingStatus::Completed); + } + + #[test] + fn test_save_and_load_transcript() { + let (storage, _temp) = create_test_storage(); + + let mut metadata = MeetingMetadata::new(Some("Test".to_string())); + let meeting_id = metadata.id; + + 
let path = storage.create_meeting(&metadata).unwrap(); + metadata.storage_path = Some(path); + storage.update_meeting(&metadata).unwrap(); + + let mut transcript = Transcript::new(); + transcript.add_segment(crate::meeting::data::TranscriptSegment::new( + 0, + 0, + 1000, + "Hello world".to_string(), + 0, + )); + + storage.save_transcript(&meeting_id, &transcript).unwrap(); + + let loaded = storage.load_transcript(&meeting_id).unwrap(); + assert_eq!(loaded.segments.len(), 1); + assert_eq!(loaded.segments[0].text, "Hello world"); + } + + #[test] + fn test_delete_meeting() { + let (storage, _temp) = create_test_storage(); + + let metadata = MeetingMetadata::new(Some("Test".to_string())); + let meeting_id = metadata.id; + + storage.create_meeting(&metadata).unwrap(); + assert!(storage.get_meeting(&meeting_id).unwrap().is_some()); + + storage.delete_meeting(&meeting_id).unwrap(); + assert!(storage.get_meeting(&meeting_id).unwrap().is_none()); + } + + #[test] + fn test_resolve_latest() { + let (storage, _temp) = create_test_storage(); + + let metadata = MeetingMetadata::new(Some("Latest".to_string())); + let expected_id = metadata.id; + + storage.create_meeting(&metadata).unwrap(); + + let resolved = storage.resolve_meeting_id("latest").unwrap(); + assert_eq!(resolved, expected_id); + } + + #[test] + fn test_get_latest_empty() { + let (storage, _temp) = create_test_storage(); + assert!(storage.get_latest_meeting().unwrap().is_none()); + } + + #[test] + fn test_resolve_meeting_id_by_uuid() { + let (storage, _temp) = create_test_storage(); + let metadata = MeetingMetadata::new(Some("Test".to_string())); + let id = metadata.id; + storage.create_meeting(&metadata).unwrap(); + + let resolved = storage.resolve_meeting_id(&id.to_string()).unwrap(); + assert_eq!(resolved, id); + } + + #[test] + fn test_resolve_meeting_id_invalid() { + let (storage, _temp) = create_test_storage(); + let result = storage.resolve_meeting_id("not-a-uuid"); + assert!(result.is_err()); + } + + #[test] 
+ fn test_resolve_latest_no_meetings() { + let (storage, _temp) = create_test_storage(); + let result = storage.resolve_meeting_id("latest"); + assert!(result.is_err()); + } + + #[test] + fn test_get_meeting_path() { + let (storage, _temp) = create_test_storage(); + let mut metadata = MeetingMetadata::new(Some("Path Test".to_string())); + let id = metadata.id; + + let path = storage.create_meeting(&metadata).unwrap(); + metadata.storage_path = Some(path.clone()); + storage.update_meeting(&metadata).unwrap(); + + let retrieved_path = storage.get_meeting_path(&id).unwrap(); + assert_eq!(retrieved_path, path); + } + + #[test] + fn test_get_meeting_path_not_found() { + let (storage, _temp) = create_test_storage(); + let id = MeetingId::new(); + let result = storage.get_meeting_path(&id); + assert!(result.is_err()); + } + + #[test] + fn test_load_meeting_data() { + let (storage, _temp) = create_test_storage(); + let mut metadata = MeetingMetadata::new(Some("Data Test".to_string())); + let id = metadata.id; + + let path = storage.create_meeting(&metadata).unwrap(); + metadata.storage_path = Some(path); + storage.update_meeting(&metadata).unwrap(); + + let mut transcript = Transcript::new(); + transcript.add_segment(crate::meeting::data::TranscriptSegment::new( + 0, + 0, + 2000, + "Test segment".to_string(), + 0, + )); + storage.save_transcript(&id, &transcript).unwrap(); + + let data = storage.load_meeting_data(&id).unwrap(); + assert_eq!(data.metadata.title, Some("Data Test".to_string())); + assert_eq!(data.transcript.segments.len(), 1); + assert_eq!(data.transcript.segments[0].text, "Test segment"); + } + + #[test] + fn test_load_meeting_data_no_transcript() { + let (storage, _temp) = create_test_storage(); + let mut metadata = MeetingMetadata::new(Some("No Transcript".to_string())); + let id = metadata.id; + + let path = storage.create_meeting(&metadata).unwrap(); + metadata.storage_path = Some(path); + storage.update_meeting(&metadata).unwrap(); + + let data = 
storage.load_meeting_data(&id).unwrap(); + assert!(data.transcript.segments.is_empty()); + } + + #[test] + fn test_delete_meeting_removes_files() { + let (storage, _temp) = create_test_storage(); + let metadata = MeetingMetadata::new(Some("Delete Test".to_string())); + let id = metadata.id; + + let path = storage.create_meeting(&metadata).unwrap(); + assert!(path.exists()); + + storage.delete_meeting(&id).unwrap(); + assert!(!path.exists()); + assert!(storage.get_meeting(&id).unwrap().is_none()); + } + + #[test] + fn test_status_roundtrip() { + assert_eq!( + string_to_status(status_to_string(MeetingStatus::Active)), + MeetingStatus::Active + ); + assert_eq!( + string_to_status(status_to_string(MeetingStatus::Paused)), + MeetingStatus::Paused + ); + assert_eq!( + string_to_status(status_to_string(MeetingStatus::Completed)), + MeetingStatus::Completed + ); + assert_eq!( + string_to_status(status_to_string(MeetingStatus::Cancelled)), + MeetingStatus::Cancelled + ); + } + + #[test] + fn test_status_unknown_defaults_to_active() { + assert_eq!(string_to_status("unknown"), MeetingStatus::Active); + assert_eq!(string_to_status(""), MeetingStatus::Active); + } + + #[test] + fn test_storage_config_default() { + let config = StorageConfig::default(); + assert!(!config.retain_audio); + assert_eq!(config.max_meetings, 0); + } + + #[test] + fn test_storage_config_db_path() { + let config = StorageConfig { + storage_path: PathBuf::from("/tmp/test-meetings"), + retain_audio: false, + max_meetings: 0, + }; + assert_eq!( + config.db_path(), + PathBuf::from("/tmp/test-meetings/index.db") + ); + } + + #[test] + fn test_list_meetings_empty() { + let (storage, _temp) = create_test_storage(); + let meetings = storage.list_meetings(None).unwrap(); + assert!(meetings.is_empty()); + } + + #[test] + fn test_list_meetings_limit_zero() { + let (storage, _temp) = create_test_storage(); + for i in 0..3 { + let metadata = MeetingMetadata::new(Some(format!("Meeting {}", i))); + 
storage.create_meeting(&metadata).unwrap(); + } + let meetings = storage.list_meetings(Some(0)).unwrap(); + assert!(meetings.is_empty()); + } + + #[test] + fn test_get_meeting_not_found() { + let (storage, _temp) = create_test_storage(); + let id = MeetingId::new(); + let result = storage.get_meeting(&id).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_delete_nonexistent_meeting() { + let (storage, _temp) = create_test_storage(); + let id = MeetingId::new(); + // Should not error - just does nothing + let result = storage.delete_meeting(&id); + assert!(result.is_ok()); + } + + #[test] + fn test_save_transcript_meeting_not_found() { + let (storage, _temp) = create_test_storage(); + let id = MeetingId::new(); + let transcript = Transcript::new(); + let result = storage.save_transcript(&id, &transcript); + assert!(result.is_err()); + } + + #[test] + fn test_load_transcript_meeting_not_found() { + let (storage, _temp) = create_test_storage(); + let id = MeetingId::new(); + let result = storage.load_transcript(&id); + assert!(result.is_err()); + } + + #[test] + fn test_load_meeting_data_not_found() { + let (storage, _temp) = create_test_storage(); + let id = MeetingId::new(); + let result = storage.load_meeting_data(&id); + assert!(result.is_err()); + } + + #[test] + fn test_speaker_labels() { + let (storage, _temp) = create_test_storage(); + let mut metadata = MeetingMetadata::new(Some("Label Test".to_string())); + let id = metadata.id; + + let path = storage.create_meeting(&metadata).unwrap(); + metadata.storage_path = Some(path); + storage.update_meeting(&metadata).unwrap(); + + // Set labels + storage.set_speaker_label(&id, 0, "Alice").unwrap(); + storage.set_speaker_label(&id, 1, "Bob").unwrap(); + + // Get labels + let labels = storage.get_speaker_labels(&id).unwrap(); + assert_eq!(labels.len(), 2); + assert_eq!(labels.get(&0), Some(&"Alice".to_string())); + assert_eq!(labels.get(&1), Some(&"Bob".to_string())); + } + + #[test] + fn 
test_speaker_labels_overwrite() { + let (storage, _temp) = create_test_storage(); + let mut metadata = MeetingMetadata::new(Some("Overwrite Test".to_string())); + let id = metadata.id; + + let path = storage.create_meeting(&metadata).unwrap(); + metadata.storage_path = Some(path); + storage.update_meeting(&metadata).unwrap(); + + storage.set_speaker_label(&id, 0, "Alice").unwrap(); + storage.set_speaker_label(&id, 0, "Carol").unwrap(); + + let labels = storage.get_speaker_labels(&id).unwrap(); + assert_eq!(labels.get(&0), Some(&"Carol".to_string())); + } + + #[test] + fn test_speaker_labels_nonexistent_meeting() { + let (storage, _temp) = create_test_storage(); + let id = MeetingId::new(); + let result = storage.set_speaker_label(&id, 0, "Alice"); + assert!(result.is_err()); + } + + #[test] + fn test_get_speaker_labels_empty() { + let (storage, _temp) = create_test_storage(); + let metadata = MeetingMetadata::new(Some("No Labels".to_string())); + let id = metadata.id; + storage.create_meeting(&metadata).unwrap(); + + let labels = storage.get_speaker_labels(&id).unwrap(); + assert!(labels.is_empty()); + } + + #[test] + fn test_create_meeting_creates_directory() { + let (storage, temp) = create_test_storage(); + let metadata = MeetingMetadata::new(Some("Dir Test".to_string())); + let path = storage.create_meeting(&metadata).unwrap(); + assert!(path.exists()); + assert!(path.is_dir()); + // Should also write metadata.json + assert!(path.join("metadata.json").exists()); + } + + #[test] + fn test_list_meetings_ordered_by_start_time() { + let (storage, _temp) = create_test_storage(); + + // Create meetings with different started_at timestamps + let mut metadata1 = MeetingMetadata::new(Some("First".to_string())); + metadata1.started_at = chrono::Utc.timestamp_opt(1000000, 0).single().unwrap(); + storage.create_meeting(&metadata1).unwrap(); + + let mut metadata2 = MeetingMetadata::new(Some("Second".to_string())); + metadata2.started_at = chrono::Utc.timestamp_opt(2000000, 
0).single().unwrap(); + storage.create_meeting(&metadata2).unwrap(); + + let meetings = storage.list_meetings(None).unwrap(); + assert_eq!(meetings.len(), 2); + // Ordered by started_at DESC, so Second should be first + assert_eq!(meetings[0].title, Some("Second".to_string())); + assert_eq!(meetings[1].title, Some("First".to_string())); + } + + #[test] + fn test_timestamp_to_datetime_invalid() { + // A very old timestamp should still produce a DateTime + let dt = timestamp_to_datetime(0); + assert_eq!(dt.timestamp(), 0); + } +} diff --git a/src/meeting/summary/local.rs b/src/meeting/summary/local.rs new file mode 100644 index 00000000..6aa95a6e --- /dev/null +++ b/src/meeting/summary/local.rs @@ -0,0 +1,194 @@ +//! Local LLM summarization using Ollama +//! +//! Integrates with a locally running Ollama instance for meeting summarization. +//! Requires Ollama to be installed and running. + +use super::{generate_prompt, parse_summary_response, Summarizer, SummaryConfig, SummaryError}; +use crate::meeting::data::{MeetingData, MeetingSummary}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +/// Ollama-based summarizer +pub struct OllamaSummarizer { + /// Ollama API endpoint + url: String, + /// Model name + model: String, + /// Request timeout + timeout: Duration, +} + +impl OllamaSummarizer { + /// Create a new Ollama summarizer + pub fn new(config: &SummaryConfig) -> Self { + Self { + url: config.ollama_url.clone(), + model: config.ollama_model.clone(), + timeout: Duration::from_secs(config.timeout_secs), + } + } + + /// Check if Ollama is running and the model is available + pub fn check_availability(&self) -> Result<(), SummaryError> { + let client = ureq::AgentBuilder::new() + .timeout(Duration::from_secs(5)) + .build(); + + // Check Ollama is running + let tags_url = format!("{}/api/tags", self.url); + let response = client + .get(&tags_url) + .call() + .map_err(|e| SummaryError::OllamaUnavailable(format!("{}: {}", self.url, e)))?; + + // Parse 
available models + #[derive(Deserialize)] + struct TagsResponse { + models: Option>, + } + + #[derive(Deserialize)] + struct ModelInfo { + name: String, + } + + let tags: TagsResponse = response + .into_json() + .map_err(|e| SummaryError::Parse(format!("Failed to parse tags response: {}", e)))?; + + // Check if our model is available + let models = tags.models.unwrap_or_default(); + let model_base = self.model.split(':').next().unwrap_or(&self.model); + + let model_available = models.iter().any(|m| { + let m_base = m.name.split(':').next().unwrap_or(&m.name); + m_base == model_base || m.name == self.model + }); + + if !model_available { + tracing::warn!( + "Model '{}' not found in Ollama. Available models: {:?}", + self.model, + models.iter().map(|m| &m.name).collect::>() + ); + // Don't fail - Ollama might pull the model on first use + } + + Ok(()) + } + + /// Call Ollama generate API + fn generate(&self, prompt: &str) -> Result { + let client = ureq::AgentBuilder::new().timeout(self.timeout).build(); + + let generate_url = format!("{}/api/generate", self.url); + + #[derive(Serialize)] + struct GenerateRequest<'a> { + model: &'a str, + prompt: &'a str, + stream: bool, + format: &'a str, + } + + let request = GenerateRequest { + model: &self.model, + prompt, + stream: false, + format: "json", + }; + + tracing::debug!("Calling Ollama generate API with model: {}", self.model); + + let response = client + .post(&generate_url) + .send_json(&request) + .map_err(|e| match e { + ureq::Error::Transport(ref t) => { + let msg = t.to_string(); + if msg.contains("timed out") || msg.contains("timeout") { + SummaryError::Request("Request timed out - try a shorter transcript".into()) + } else if msg.contains("connection") { + SummaryError::OllamaUnavailable(format!("{}: connection failed", self.url)) + } else { + SummaryError::Request(e.to_string()) + } + } + _ => SummaryError::Request(e.to_string()), + })?; + + #[derive(Deserialize)] + struct GenerateResponse { + response: 
String, + #[allow(dead_code)] + done: bool, + } + + let gen_response: GenerateResponse = response.into_json().map_err(|e| { + SummaryError::Parse(format!("Failed to parse generate response: {}", e)) + })?; + + Ok(gen_response.response) + } +} + +impl Summarizer for OllamaSummarizer { + fn summarize(&self, meeting: &MeetingData) -> Result { + // Check transcript is not empty + if meeting.transcript.segments.is_empty() { + return Err(SummaryError::EmptyTranscript); + } + + // Generate prompt + let prompt = generate_prompt(meeting); + tracing::debug!( + "Generated summarization prompt ({} chars, {} segments)", + prompt.len(), + meeting.transcript.segments.len() + ); + + // Call Ollama + let response = self.generate(&prompt)?; + tracing::debug!("Received response ({} chars)", response.len()); + + // Parse response + let summary = parse_summary_response(&response, Some(self.model.clone()))?; + + Ok(summary) + } + + fn name(&self) -> &'static str { + "ollama" + } + + fn is_available(&self) -> bool { + self.check_availability().is_ok() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_from_config() { + let config = SummaryConfig { + ollama_url: "http://test:11434".to_string(), + ollama_model: "mistral".to_string(), + timeout_secs: 60, + ..Default::default() + }; + + let summarizer = OllamaSummarizer::new(&config); + assert_eq!(summarizer.url, "http://test:11434"); + assert_eq!(summarizer.model, "mistral"); + assert_eq!(summarizer.timeout, Duration::from_secs(60)); + } + + #[test] + fn test_name() { + let config = SummaryConfig::default(); + let summarizer = OllamaSummarizer::new(&config); + assert_eq!(summarizer.name(), "ollama"); + } +} diff --git a/src/meeting/summary/mod.rs b/src/meeting/summary/mod.rs new file mode 100644 index 00000000..5caf2044 --- /dev/null +++ b/src/meeting/summary/mod.rs @@ -0,0 +1,329 @@ +//! AI-powered meeting summarization +//! +//! Generates summaries, action items, key decisions, and other +//! 
structured insights from meeting transcripts. +//! +//! # Backends +//! +//! - **Local**: Uses Ollama for local LLM inference +//! - **Remote**: Uses a remote API endpoint for summarization +//! - **Disabled**: Summarization disabled + +pub mod local; +pub mod remote; + +use crate::meeting::data::{ActionItem, MeetingData, MeetingSummary}; +use chrono::Utc; +use serde::Deserialize; +use thiserror::Error; + +/// Summary-related errors +#[derive(Error, Debug)] +pub enum SummaryError { + #[error("Summarizer not configured")] + NotConfigured, + + #[error("LLM request failed: {0}")] + Request(String), + + #[error("Failed to parse LLM response: {0}")] + Parse(String), + + #[error("Transcript is empty")] + EmptyTranscript, + + #[error("Ollama not available at {0}")] + OllamaUnavailable(String), +} + +/// Format a MeetingSummary as markdown +pub fn summary_to_markdown(summary: &MeetingSummary) -> String { + let mut output = String::new(); + + if !summary.summary.is_empty() { + output.push_str("## Summary\n\n"); + output.push_str(&summary.summary); + output.push_str("\n\n"); + } + + if !summary.key_points.is_empty() { + output.push_str("## Key Points\n\n"); + for point in &summary.key_points { + output.push_str(&format!("- {}\n", point)); + } + output.push('\n'); + } + + if !summary.action_items.is_empty() { + output.push_str("## Action Items\n\n"); + for item in &summary.action_items { + let assignee = item + .assignee + .as_ref() + .map(|a| format!(" ({})", a)) + .unwrap_or_default(); + let checkbox = if item.completed { "[x]" } else { "[ ]" }; + output.push_str(&format!( + "- {} {}{}\n", + checkbox, item.description, assignee + )); + } + output.push('\n'); + } + + if !summary.decisions.is_empty() { + output.push_str("## Decisions\n\n"); + for decision in &summary.decisions { + output.push_str(&format!("- {}\n", decision)); + } + output.push('\n'); + } + + output +} + +/// Summarization configuration +#[derive(Debug, Clone)] +pub struct SummaryConfig { + /// Backend to 
use: "local", "remote", or "disabled" + pub backend: String, + + /// Ollama URL for local backend + pub ollama_url: String, + + /// Ollama model name + pub ollama_model: String, + + /// Remote API endpoint + pub remote_endpoint: Option, + + /// Remote API key + pub remote_api_key: Option, + + /// Request timeout in seconds + pub timeout_secs: u64, +} + +impl Default for SummaryConfig { + fn default() -> Self { + Self { + backend: "disabled".to_string(), + ollama_url: "http://localhost:11434".to_string(), + ollama_model: "llama3.2".to_string(), + remote_endpoint: None, + remote_api_key: None, + timeout_secs: 120, + } + } +} + +/// Trait for summarization backends +pub trait Summarizer: Send + Sync { + /// Generate a summary from meeting data + fn summarize(&self, meeting: &MeetingData) -> Result; + + /// Get the backend name + fn name(&self) -> &'static str; + + /// Check if the backend is available + fn is_available(&self) -> bool; +} + +/// Create a summarizer based on configuration +pub fn create_summarizer(config: &SummaryConfig) -> Option> { + match config.backend.as_str() { + "local" => Some(Box::new(local::OllamaSummarizer::new(config))), + "remote" => { + if config.remote_endpoint.is_some() { + Some(Box::new(remote::RemoteSummarizer::new(config))) + } else { + tracing::warn!("Remote summarizer requires remote_endpoint to be set"); + None + } + } + "disabled" | "" => None, + _ => { + tracing::warn!("Unknown summarizer backend '{}', disabling", config.backend); + None + } + } +} + +/// Generate the prompt for summarization +pub fn generate_prompt(meeting: &MeetingData) -> String { + let mut prompt = String::from( + r#"Analyze the following meeting transcript and provide a structured summary. 
+ +Format your response as JSON with this structure: +{ + "summary": "2-3 sentence summary of the meeting", + "key_points": ["point 1", "point 2"], + "action_items": [{"description": "task description", "assignee": "person or null", "due_date": "date or null"}], + "decisions": ["decision 1", "decision 2"] +} + +"#, + ); + + if let Some(ref title) = meeting.metadata.title { + prompt.push_str(&format!("Meeting Title: {}\n", title)); + } + + prompt.push_str(&format!( + "Date: {}\n\n", + meeting.metadata.started_at.format("%Y-%m-%d %H:%M") + )); + + prompt.push_str("## Transcript\n\n"); + + for segment in &meeting.transcript.segments { + let speaker = segment.speaker_display(); + if !speaker.is_empty() && speaker != "Unknown" { + prompt.push_str(&format!("{}: {}\n", speaker, segment.text)); + } else { + prompt.push_str(&format!("{}\n", segment.text)); + } + } + + prompt.push_str("\n## End of Transcript\n\nProvide the JSON summary:"); + + prompt +} + +/// Parse JSON summary from LLM response +pub fn parse_summary_response( + response: &str, + model: Option, +) -> Result { + // Try to extract JSON from the response + let json_str = extract_json(response).ok_or_else(|| { + SummaryError::Parse(format!( + "No valid JSON found in response: {}", + &response[..response.len().min(200)] + )) + })?; + + // Parse the JSON - use intermediate struct to match LLM output + #[derive(Deserialize)] + struct RawSummary { + summary: Option, + key_points: Option>, + action_items: Option>, + decisions: Option>, + } + + #[derive(Deserialize)] + struct RawActionItem { + description: Option, + task: Option, // Some LLMs use "task" instead of "description" + assignee: Option, + due_date: Option, + due: Option, // Alternative name + } + + let raw: RawSummary = + serde_json::from_str(json_str).map_err(|e| SummaryError::Parse(e.to_string()))?; + + Ok(MeetingSummary { + summary: raw.summary.unwrap_or_default(), + key_points: raw.key_points.unwrap_or_default(), + action_items: raw + .action_items + 
.unwrap_or_default() + .into_iter() + .map(|item| ActionItem { + description: item.description.or(item.task).unwrap_or_default(), + assignee: item.assignee, + due_date: item.due_date.or(item.due), + completed: false, + }) + .collect(), + decisions: raw.decisions.unwrap_or_default(), + generated_at: Utc::now(), + model, + }) +} + +/// Extract JSON object from a string that may contain other text +fn extract_json(s: &str) -> Option<&str> { + // Find the first { and last } + let start = s.find('{')?; + let end = s.rfind('}')?; + + if end > start { + Some(&s[start..=end]) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_json_simple() { + let input = r#"Here is the summary: {"summary": "Test meeting"}"#; + let json = extract_json(input).unwrap(); + assert_eq!(json, r#"{"summary": "Test meeting"}"#); + } + + #[test] + fn test_extract_json_nested() { + let input = r#"{"a": {"b": 1}}"#; + let json = extract_json(input).unwrap(); + assert_eq!(json, input); + } + + #[test] + fn test_parse_summary_response() { + let response = r#"{"summary": "Brief meeting about X", "key_points": ["Point 1"], "action_items": [{"description": "Do thing", "assignee": "Alice"}], "decisions": ["Agreed on Y"]}"#; + + let summary = parse_summary_response(response, None).unwrap(); + assert_eq!(summary.summary, "Brief meeting about X"); + assert_eq!(summary.key_points.len(), 1); + assert_eq!(summary.action_items.len(), 1); + assert_eq!(summary.action_items[0].assignee, Some("Alice".to_string())); + } + + #[test] + fn test_parse_summary_response_with_task_field() { + // Some LLMs use "task" instead of "description" + let response = r#"{"summary": "Meeting summary", "action_items": [{"task": "Do task", "assignee": "Bob"}]}"#; + + let summary = parse_summary_response(response, Some("llama3.2".to_string())).unwrap(); + assert_eq!(summary.action_items[0].description, "Do task"); + assert_eq!(summary.model, Some("llama3.2".to_string())); + } + + #[test] + 
fn test_summary_to_markdown() { + let summary = MeetingSummary { + summary: "Test meeting summary".to_string(), + key_points: vec!["Point 1".to_string(), "Point 2".to_string()], + action_items: vec![ActionItem { + description: "Do thing".to_string(), + assignee: Some("Alice".to_string()), + due_date: None, + completed: false, + }], + decisions: vec!["Decision 1".to_string()], + generated_at: Utc::now(), + model: None, + }; + + let md = summary_to_markdown(&summary); + assert!(md.contains("## Summary")); + assert!(md.contains("Test meeting summary")); + assert!(md.contains("## Action Items")); + assert!(md.contains("[ ] Do thing (Alice)")); + } + + #[test] + fn test_default_config() { + let config = SummaryConfig::default(); + assert_eq!(config.backend, "disabled"); + assert_eq!(config.ollama_url, "http://localhost:11434"); + assert_eq!(config.ollama_model, "llama3.2"); + } +} diff --git a/src/meeting/summary/remote.rs b/src/meeting/summary/remote.rs new file mode 100644 index 00000000..5935d51f --- /dev/null +++ b/src/meeting/summary/remote.rs @@ -0,0 +1,157 @@ +//! Remote API summarization backend +//! +//! Integrates with a remote summarization service for meetings. +//! Useful for corporate deployments with centralized AI infrastructure. 
+ +use super::{generate_prompt, parse_summary_response, Summarizer, SummaryConfig, SummaryError}; +use crate::meeting::data::{MeetingData, MeetingSummary}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +/// Remote API-based summarizer +pub struct RemoteSummarizer { + /// API endpoint URL + endpoint: String, + /// API key for authentication + api_key: Option, + /// Request timeout + timeout: Duration, +} + +impl RemoteSummarizer { + /// Create a new remote summarizer + pub fn new(config: &SummaryConfig) -> Self { + Self { + endpoint: config + .remote_endpoint + .clone() + .unwrap_or_else(|| "http://localhost:8080/api/summarize".to_string()), + api_key: config.remote_api_key.clone(), + timeout: Duration::from_secs(config.timeout_secs), + } + } + + /// Call the remote summarization API + fn call_api(&self, prompt: &str) -> Result { + let client = ureq::AgentBuilder::new().timeout(self.timeout).build(); + + #[derive(Serialize)] + struct SummarizeRequest<'a> { + prompt: &'a str, + } + + let mut request = client.post(&self.endpoint); + + // Add API key if configured + if let Some(ref api_key) = self.api_key { + request = request.set("Authorization", &format!("Bearer {}", api_key)); + } + + request = request.set("Content-Type", "application/json"); + + let body = SummarizeRequest { prompt }; + + tracing::debug!("Calling remote summarization API: {}", self.endpoint); + + let response = request.send_json(&body).map_err(|e| match e { + ureq::Error::Transport(ref t) => { + let msg = t.to_string(); + if msg.contains("timed out") || msg.contains("timeout") { + SummaryError::Request("Request timed out - try a shorter transcript".into()) + } else { + SummaryError::Request(e.to_string()) + } + } + ureq::Error::Status(status, _) => { + SummaryError::Request(format!("API returned status {}", status)) + } + })?; + + #[derive(Deserialize)] + struct ApiResponse { + summary: Option, + response: Option, + error: Option, + } + + let api_response: ApiResponse = response 
+ .into_json() + .map_err(|e| SummaryError::Parse(format!("Failed to parse API response: {}", e)))?; + + if let Some(error) = api_response.error { + return Err(SummaryError::Request(error)); + } + + api_response + .summary + .or(api_response.response) + .ok_or_else(|| SummaryError::Parse("API response missing summary field".into())) + } +} + +impl Summarizer for RemoteSummarizer { + fn summarize(&self, meeting: &MeetingData) -> Result { + // Check transcript is not empty + if meeting.transcript.segments.is_empty() { + return Err(SummaryError::EmptyTranscript); + } + + // Generate prompt + let prompt = generate_prompt(meeting); + tracing::debug!( + "Generated summarization prompt ({} chars, {} segments)", + prompt.len(), + meeting.transcript.segments.len() + ); + + // Call remote API + let response = self.call_api(&prompt)?; + tracing::debug!("Received response ({} chars)", response.len()); + + // Parse response + let summary = parse_summary_response(&response, Some("remote".to_string()))?; + + Ok(summary) + } + + fn name(&self) -> &'static str { + "remote" + } + + fn is_available(&self) -> bool { + // Try a simple health check + let client = ureq::AgentBuilder::new() + .timeout(Duration::from_secs(5)) + .build(); + + // Try to reach the endpoint + client.head(&self.endpoint).call().is_ok() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_from_config() { + let config = SummaryConfig { + remote_endpoint: Some("https://api.example.com/summarize".to_string()), + remote_api_key: Some("test-key".to_string()), + timeout_secs: 90, + ..Default::default() + }; + + let summarizer = RemoteSummarizer::new(&config); + assert_eq!(summarizer.endpoint, "https://api.example.com/summarize"); + assert_eq!(summarizer.api_key, Some("test-key".to_string())); + assert_eq!(summarizer.timeout, Duration::from_secs(90)); + } + + #[test] + fn test_name() { + let config = SummaryConfig::default(); + let summarizer = RemoteSummarizer::new(&config); + 
assert_eq!(summarizer.name(), "remote"); + } +} diff --git a/src/output/mod.rs b/src/output/mod.rs index d8af0f86..376cc98a 100644 --- a/src/output/mod.rs +++ b/src/output/mod.rs @@ -75,7 +75,7 @@ pub fn is_parakeet_binary_active() -> bool { if let Ok(link_target) = fs::read_link(VOXTYPE_BIN) { if let Some(target_name) = link_target.file_name() { if let Some(name) = target_name.to_str() { - return name.contains("parakeet"); + return name.contains("onnx") || name.contains("parakeet"); } } } @@ -91,12 +91,15 @@ pub fn is_parakeet_binary_active() -> bool { } /// Get the engine icon for notifications based on configured engine -/// Returns 🦜 for Parakeet, 🗣️ for Whisper, 🌙 for Moonshine pub fn engine_icon(engine: crate::config::TranscriptionEngine) -> &'static str { match engine { - crate::config::TranscriptionEngine::Parakeet => "🦜", - crate::config::TranscriptionEngine::Whisper => "🗣️", - crate::config::TranscriptionEngine::Moonshine => "\u{1F319}", + crate::config::TranscriptionEngine::Parakeet => "\u{1F99C}", // 🦜 + crate::config::TranscriptionEngine::Whisper => "\u{1F5E3}\u{FE0F}", // 🗣️ + crate::config::TranscriptionEngine::Moonshine => "\u{1F319}", // 🌙 + crate::config::TranscriptionEngine::SenseVoice => "\u{1F442}", // 👂 + crate::config::TranscriptionEngine::Paraformer => "\u{1F4AC}", // 💬 + crate::config::TranscriptionEngine::Dolphin => "\u{1F42C}", // 🐬 + crate::config::TranscriptionEngine::Omnilingual => "\u{1F30D}", // 🌍 } } diff --git a/src/setup/gpu.rs b/src/setup/gpu.rs index 627d35dd..3870c151 100644 --- a/src/setup/gpu.rs +++ b/src/setup/gpu.rs @@ -4,7 +4,7 @@ //! 1. Tiered mode (DEB/RPM pre-built): Multiple CPU binaries (avx2, avx512) + vulkan in /usr/lib/voxtype/ //! 2. Simple mode (AUR source build): Native CPU binary at /usr/bin/voxtype + vulkan in /usr/lib/voxtype/ //! -//! Engine-aware: In Parakeet mode, switches between parakeet-cuda and parakeet-avx*. +//! Engine-aware: In ONNX mode, switches between onnx-cuda and onnx-avx*. //! 
In Whisper mode, switches between vulkan and avx*. //! //! GPU Selection: @@ -50,7 +50,7 @@ fn is_parakeet_binary_active() -> bool { if let Ok(resolved) = fs::canonicalize(active_bin) { if let Some(target_name) = resolved.file_name() { if let Some(name) = target_name.to_str() { - return name.contains("parakeet"); + return name.contains("onnx") || name.contains("parakeet"); } } } @@ -64,7 +64,7 @@ fn detect_active_parakeet_backend() -> Option { if let Ok(resolved) = fs::canonicalize(active_bin) { if let Some(target_name) = resolved.file_name() { if let Some(name) = target_name.to_str() { - if name.contains("parakeet") { + if name.contains("onnx") || name.contains("parakeet") { return Some(name.to_string()); } } @@ -520,11 +520,11 @@ pub fn show_status() { // Detect active Parakeet backend from symlink if let Some(target) = detect_active_parakeet_backend() { let display_name = match target.as_str() { - "voxtype-parakeet-avx2" => "Parakeet CPU (AVX2)", - "voxtype-parakeet-avx512" => "Parakeet CPU (AVX-512)", - "voxtype-parakeet-cuda" => "Parakeet GPU (CUDA)", - "voxtype-parakeet-rocm" => "Parakeet GPU (ROCm)", - _ => "Parakeet (unknown variant)", + "voxtype-onnx-avx2" | "voxtype-parakeet-avx2" => "ONNX CPU (AVX2)", + "voxtype-onnx-avx512" | "voxtype-parakeet-avx512" => "ONNX CPU (AVX-512)", + "voxtype-onnx-cuda" | "voxtype-parakeet-cuda" => "ONNX GPU (CUDA)", + "voxtype-onnx-rocm" | "voxtype-parakeet-rocm" => "ONNX GPU (ROCm)", + _ => "ONNX (unknown variant)", }; println!("Active backend: {}", display_name); println!( @@ -571,12 +571,12 @@ pub fn show_status() { let current = detect_current_backend(); if is_parakeet { - // Show Parakeet backends - let parakeet_backends = [ - ("voxtype-parakeet-avx2", "Parakeet CPU (AVX2)"), - ("voxtype-parakeet-avx512", "Parakeet CPU (AVX-512)"), - ("voxtype-parakeet-cuda", "Parakeet GPU (CUDA)"), - ("voxtype-parakeet-rocm", "Parakeet GPU (ROCm)"), + // Show ONNX backends (check both new and legacy names) + let onnx_backends = [ + 
("voxtype-onnx-avx2", "voxtype-parakeet-avx2", "ONNX CPU (AVX2)"), + ("voxtype-onnx-avx512", "voxtype-parakeet-avx512", "ONNX CPU (AVX-512)"), + ("voxtype-onnx-cuda", "voxtype-parakeet-cuda", "ONNX GPU (CUDA)"), + ("voxtype-onnx-rocm", "voxtype-parakeet-rocm", "ONNX GPU (ROCm)"), ]; // Get current symlink target @@ -584,10 +584,12 @@ pub fn show_status() { .ok() .and_then(|p| p.file_name().map(|n| n.to_string_lossy().to_string())); - for (binary, display) in parakeet_backends { + for (binary, legacy_binary, display) in onnx_backends { let path = Path::new(VOXTYPE_LIB_DIR).join(binary); - let installed = path.exists(); - let active = current_target.as_deref() == Some(binary); + let legacy_path = Path::new(VOXTYPE_LIB_DIR).join(legacy_binary); + let installed = path.exists() || legacy_path.exists(); + let active = current_target.as_deref() == Some(binary) + || current_target.as_deref() == Some(legacy_binary); let status = if active { "active" @@ -693,41 +695,53 @@ pub fn show_status() { println!("To switch back to CPU:"); println!(" sudo voxtype setup gpu --disable"); } - } else { - if current != Some(Backend::Vulkan) && available.contains(&Backend::Vulkan) { - println!("To enable GPU acceleration:"); - println!(" sudo voxtype setup gpu --enable"); - } else if current == Some(Backend::Vulkan) { - println!("To switch back to CPU:"); - println!(" sudo voxtype setup gpu --disable"); - } + } else if current != Some(Backend::Vulkan) && available.contains(&Backend::Vulkan) { + println!("To enable GPU acceleration:"); + println!(" sudo voxtype setup gpu --enable"); + } else if current == Some(Backend::Vulkan) { + println!("To switch back to CPU:"); + println!(" sudo voxtype setup gpu --disable"); } } -/// Detect the best Parakeet GPU backend based on available hardware and installed binaries +/// Detect the best ONNX GPU backend based on available hardware and installed binaries fn detect_best_parakeet_gpu_backend() -> Option<(&'static str, &'static str)> { let gpus = 
detect_gpus(); + // Helper to find installed binary, preferring new name over legacy + let find_binary = + |new_name: &'static str, legacy_name: &'static str| -> Option<&'static str> { + if Path::new(VOXTYPE_LIB_DIR).join(new_name).exists() { + Some(new_name) + } else if Path::new(VOXTYPE_LIB_DIR).join(legacy_name).exists() { + Some(legacy_name) + } else { + None + } + }; + // Check for AMD GPU and ROCm binary let has_amd = gpus.iter().any(|g| g.vendor == GpuVendor::Amd); - let rocm_path = Path::new(VOXTYPE_LIB_DIR).join("voxtype-parakeet-rocm"); - if has_amd && rocm_path.exists() { - return Some(("voxtype-parakeet-rocm", "ROCm")); + if let Some(binary) = find_binary("voxtype-onnx-rocm", "voxtype-parakeet-rocm") { + if has_amd { + return Some((binary, "ROCm")); + } } // Check for NVIDIA GPU and CUDA binary let has_nvidia = gpus.iter().any(|g| g.vendor == GpuVendor::Nvidia); - let cuda_path = Path::new(VOXTYPE_LIB_DIR).join("voxtype-parakeet-cuda"); - if has_nvidia && cuda_path.exists() { - return Some(("voxtype-parakeet-cuda", "CUDA")); + if let Some(binary) = find_binary("voxtype-onnx-cuda", "voxtype-parakeet-cuda") { + if has_nvidia { + return Some((binary, "CUDA")); + } } // Fall back to whichever is installed (user may have external GPU) - if rocm_path.exists() { - return Some(("voxtype-parakeet-rocm", "ROCm")); + if let Some(binary) = find_binary("voxtype-onnx-rocm", "voxtype-parakeet-rocm") { + return Some((binary, "ROCm")); } - if cuda_path.exists() { - return Some(("voxtype-parakeet-cuda", "CUDA")); + if let Some(binary) = find_binary("voxtype-onnx-cuda", "voxtype-parakeet-cuda") { + return Some((binary, "CUDA")); } None @@ -746,16 +760,16 @@ pub fn enable() -> anyhow::Result<()> { let has_nvidia = gpus.iter().any(|g| g.vendor == GpuVendor::Nvidia); let hint = if has_amd { - "You have an AMD GPU. Install voxtype-parakeet-rocm for GPU acceleration." + "You have an AMD GPU. Install voxtype-onnx-rocm for GPU acceleration." 
} else if has_nvidia { - "You have an NVIDIA GPU. Install voxtype-parakeet-cuda for GPU acceleration." + "You have an NVIDIA GPU. Install voxtype-onnx-cuda for GPU acceleration." } else { - "No supported GPU detected. Parakeet GPU acceleration requires NVIDIA (CUDA) or AMD (ROCm)." + "No supported GPU detected. ONNX GPU acceleration requires NVIDIA (CUDA) or AMD (ROCm)." }; anyhow::anyhow!( - "No Parakeet GPU backend installed.\n\ - Neither voxtype-parakeet-cuda nor voxtype-parakeet-rocm found in {}\n\n\ + "No ONNX GPU backend installed.\n\ + Neither voxtype-onnx-cuda nor voxtype-onnx-rocm found in {}\n\n\ {}", VOXTYPE_LIB_DIR, hint @@ -767,12 +781,12 @@ pub fn enable() -> anyhow::Result<()> { // Regenerate systemd service if it exists if super::systemd::regenerate_service_file()? { println!( - "Updated systemd service to use Parakeet {} backend.", + "Updated systemd service to use ONNX {} backend.", backend_name ); } - println!("Switched to Parakeet ({}) backend.", backend_name); + println!("Switched to ONNX ({}) backend.", backend_name); println!(); println!("Restart voxtype to use GPU acceleration:"); println!(" systemctl --user restart voxtype"); @@ -823,24 +837,26 @@ pub fn disable() -> anyhow::Result<()> { let is_parakeet = is_parakeet_binary_active(); if is_parakeet { - // Parakeet mode: switch to best Parakeet CPU backend + // ONNX mode: switch to best ONNX CPU backend let best_backend = detect_best_parakeet_cpu_backend(); if let Some(backend_name) = best_backend { switch_backend_tiered_parakeet(backend_name)?; println!( - "Switched to Parakeet ({}) backend.", - backend_name.trim_start_matches("voxtype-parakeet-") + "Switched to ONNX ({}) backend.", + backend_name + .trim_start_matches("voxtype-onnx-") + .trim_start_matches("voxtype-parakeet-") ); } else { anyhow::bail!( - "No Parakeet CPU backend found.\n\ - Install voxtype-parakeet-avx2 or voxtype-parakeet-avx512." + "No ONNX CPU backend found.\n\ + Install voxtype-onnx-avx2 or voxtype-onnx-avx512." 
); } // Regenerate systemd service if it exists if super::systemd::regenerate_service_file()? { - println!("Updated systemd service to use Parakeet CPU backend."); + println!("Updated systemd service to use ONNX CPU backend."); } println!(); @@ -886,36 +902,44 @@ fn detect_best_cpu_backend() -> Backend { Backend::Avx2 } -/// Detect the best Parakeet CPU backend for this system +/// Detect the best ONNX CPU backend for this system fn detect_best_parakeet_cpu_backend() -> Option<&'static str> { + // Helper to find installed binary, preferring new name over legacy + let find_binary = + |new_name: &'static str, legacy_name: &'static str| -> Option<&'static str> { + if Path::new(VOXTYPE_LIB_DIR).join(new_name).exists() { + Some(new_name) + } else if Path::new(VOXTYPE_LIB_DIR).join(legacy_name).exists() { + Some(legacy_name) + } else { + None + } + }; + // Check for AVX-512 support if let Ok(cpuinfo) = fs::read_to_string("/proc/cpuinfo") { if cpuinfo.contains("avx512f") { - let avx512_path = Path::new(VOXTYPE_LIB_DIR).join("voxtype-parakeet-avx512"); - if avx512_path.exists() { - return Some("voxtype-parakeet-avx512"); + if let Some(binary) = + find_binary("voxtype-onnx-avx512", "voxtype-parakeet-avx512") + { + return Some(binary); } } } // Fall back to AVX2 - let avx2_path = Path::new(VOXTYPE_LIB_DIR).join("voxtype-parakeet-avx2"); - if avx2_path.exists() { - return Some("voxtype-parakeet-avx2"); - } - - None + find_binary("voxtype-onnx-avx2", "voxtype-parakeet-avx2") } -/// Switch to a Parakeet backend binary (tiered mode) +/// Switch to an ONNX backend binary (tiered mode) fn switch_backend_tiered_parakeet(binary_name: &str) -> anyhow::Result<()> { let binary_path = Path::new(VOXTYPE_LIB_DIR).join(binary_name); let active_bin = get_active_binary_path(); if !binary_path.exists() { anyhow::bail!( - "Parakeet backend not found: {}\n\ - Install the appropriate voxtype-parakeet package.", + "ONNX backend not found: {}\n\ + Install the appropriate voxtype-onnx package.", 
binary_path.display() ); } diff --git a/src/setup/mod.rs b/src/setup/mod.rs index 2cef4ce6..a80165a9 100644 --- a/src/setup/mod.rs +++ b/src/setup/mod.rs @@ -460,16 +460,22 @@ pub async fn run_setup( let models_dir = Config::models_dir(); - // Check if model_override is a Parakeet model + // Check if model_override is a Parakeet or SenseVoice model let is_parakeet = model_override - .map(|name| model::is_parakeet_model(name)) + .map(model::is_parakeet_model) + .unwrap_or(false); + let is_sensevoice = model_override + .map(model::is_sensevoice_model) .unwrap_or(false); // Use model_override if provided, otherwise use config default (for Whisper) - let model_name: &str = match model_override { + let _model_name: &str = match model_override { Some(name) => { - // Validate the model name (check both Whisper and Parakeet) - if !model::is_valid_model(name) && !model::is_parakeet_model(name) { + // Validate the model name (check Whisper, Parakeet, and SenseVoice) + if !model::is_valid_model(name) + && !model::is_parakeet_model(name) + && !model::is_sensevoice_model(name) + { let valid = model::valid_model_names().join(", "); anyhow::bail!("Unknown model '{}'. 
Valid models are: {}", name, valid); } @@ -478,7 +484,56 @@ pub async fn run_setup( None => &config.whisper.model, }; - if is_parakeet { + if is_sensevoice { + // Handle SenseVoice model + #[allow(unused_variables)] + let model_name = model_override.unwrap(); // Safe: is_sensevoice implies Some + + if !quiet { + println!("\nSenseVoice model..."); + } + + #[cfg(not(feature = "sensevoice"))] + { + print_failure(&format!( + "SenseVoice model '{}' requires the 'sensevoice' feature", + model_name + )); + println!(" Rebuild with: cargo build --features sensevoice"); + anyhow::bail!("SenseVoice feature not enabled"); + } + + #[cfg(feature = "sensevoice")] + { + let dir_name = model::sensevoice_dir_name(model_name).unwrap(); + let model_path = models_dir.join(dir_name); + let model_valid = + model_path.exists() && model::validate_sensevoice_model(&model_path).is_ok(); + + if model_valid { + if !quiet { + let size = std::fs::read_dir(&model_path) + .map(|entries| { + entries + .flatten() + .filter_map(|e| e.metadata().ok()) + .map(|m| m.len() as f64 / 1024.0 / 1024.0) + .sum::<f64>() + }) + .unwrap_or(0.0); + print_success(&format!("Model ready: {} ({:.0} MB)", model_name, size)); + } + } else if download { + model::download_sensevoice_model(model_name)?; + } else if !quiet { + print_info(&format!("Model '{}' not downloaded yet", model_name)); + println!( + " Run: voxtype setup --download --model {}", + model_name + ); + } + } + } else if is_parakeet { // Handle Parakeet model + #[allow(unused_variables)] + let model_name = model_override.unwrap(); // Safe: is_parakeet implies Some diff --git a/src/setup/model.rs b/src/setup/model.rs index 84496449..fe8aabf7 100644 --- a/src/setup/model.rs +++ b/src/setup/model.rs @@ -286,6 +286,148 @@ const MOONSHINE_MODELS: &[MoonshineModelInfo] = &[ }, ]; +// ============================================================================= +// SenseVoice Model Definitions +// 
============================================================================= + +/// SenseVoice model information for display and download +struct SenseVoiceModelInfo { + name: &'static str, + dir_name: &'static str, + size_mb: u32, + description: &'static str, + languages: &'static str, + files: &'static [(&'static str, &'static str)], // (repo_path, local_filename) + huggingface_repo: &'static str, +} + +const SENSEVOICE_MODELS: &[SenseVoiceModelInfo] = &[ + SenseVoiceModelInfo { + name: "small", + dir_name: "sensevoice-small", + size_mb: 239, + description: "Quantized int8 (recommended)", + languages: "zh/en/ja/ko/yue", + files: &[ + ("model.int8.onnx", "model.int8.onnx"), + ("tokens.txt", "tokens.txt"), + ], + huggingface_repo: "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17", + }, + SenseVoiceModelInfo { + name: "small-fp32", + dir_name: "sensevoice-small-fp32", + size_mb: 938, + description: "Full precision (larger, slightly better accuracy)", + languages: "zh/en/ja/ko/yue", + files: &[ + ("model.onnx", "model.onnx"), + ("tokens.txt", "tokens.txt"), + ], + huggingface_repo: "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17", + }, +]; + +// ============================================================================= +// Paraformer Model Definitions +// ============================================================================= + +/// Paraformer model info (same structure as SenseVoice: model.onnx + tokens.txt) +struct ParaformerModelInfo { + name: &'static str, + dir_name: &'static str, + size_mb: u32, + description: &'static str, + languages: &'static str, + files: &'static [(&'static str, &'static str)], + huggingface_repo: &'static str, +} + +const PARAFORMER_MODELS: &[ParaformerModelInfo] = &[ + ParaformerModelInfo { + name: "zh", + dir_name: "paraformer-zh", + size_mb: 487, + description: "Chinese + English offline (recommended)", + languages: "zh/en", + files: &[ + ("model.int8.onnx", "model.int8.onnx"), + ("tokens.txt", 
"tokens.txt"), + ], + huggingface_repo: "csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14", + }, + ParaformerModelInfo { + name: "en", + dir_name: "paraformer-en", + size_mb: 220, + description: "English offline", + languages: "en", + files: &[ + ("model.int8.onnx", "model.int8.onnx"), + ("tokens.txt", "tokens.txt"), + ], + huggingface_repo: "csukuangfj/sherpa-onnx-paraformer-en-2024-03-09", + }, +]; + +// ============================================================================= +// Dolphin Model Definitions +// ============================================================================= + +struct DolphinModelInfo { + name: &'static str, + dir_name: &'static str, + size_mb: u32, + description: &'static str, + languages: &'static str, + files: &'static [(&'static str, &'static str)], + huggingface_repo: &'static str, +} + +const DOLPHIN_MODELS: &[DolphinModelInfo] = &[ + DolphinModelInfo { + name: "base", + dir_name: "dolphin-base", + size_mb: 198, + description: "Dictation-optimized (recommended)", + languages: "en/zh", + files: &[ + ("model.int8.onnx", "model.int8.onnx"), + ("tokens.txt", "tokens.txt"), + ], + huggingface_repo: "csukuangfj/sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02", + }, +]; + +// ============================================================================= +// Omnilingual Model Definitions +// ============================================================================= + +struct OmnilingualModelInfo { + name: &'static str, + dir_name: &'static str, + size_mb: u32, + description: &'static str, + languages: &'static str, + files: &'static [(&'static str, &'static str)], + huggingface_repo: &'static str, +} + +const OMNILINGUAL_MODELS: &[OmnilingualModelInfo] = &[ + OmnilingualModelInfo { + name: "300m", + dir_name: "omnilingual-300m", + size_mb: 3900, + description: "1600+ languages, 300M params", + languages: "1600+ langs", + files: &[ + ("model.onnx", "model.onnx"), + ("tokens.txt", "tokens.txt"), + ], + huggingface_repo: 
"csukuangfj/sherpa-onnx-omnilingual-asr-1600-languages-300M-ctc-2025-11-12", + }, +]; + // ============================================================================= // Whisper Model Functions // ============================================================================= @@ -313,26 +455,39 @@ pub async fn interactive_select() -> anyhow::Result<()> { let is_whisper_engine = matches!(config.engine, TranscriptionEngine::Whisper); let is_parakeet_engine = matches!(config.engine, TranscriptionEngine::Parakeet); let is_moonshine_engine = matches!(config.engine, TranscriptionEngine::Moonshine); + let is_sensevoice_engine = matches!(config.engine, TranscriptionEngine::SenseVoice); + let is_paraformer_engine = matches!(config.engine, TranscriptionEngine::Paraformer); + let is_dolphin_engine = matches!(config.engine, TranscriptionEngine::Dolphin); + let is_omnilingual_engine = matches!(config.engine, TranscriptionEngine::Omnilingual); let current_whisper_model = &config.whisper.model; let current_parakeet_model = config.parakeet.as_ref().map(|p| p.model.as_str()); let current_moonshine_model = config.moonshine.as_ref().map(|m| m.model.as_str()); - + let current_sensevoice_model = config.sensevoice.as_ref().map(|s| s.model.as_str()); + let current_paraformer_model = config.paraformer.as_ref().map(|p| p.model.as_str()); + let current_dolphin_model = config.dolphin.as_ref().map(|d| d.model.as_str()); + let current_omnilingual_model = config.omnilingual.as_ref().map(|o| o.model.as_str()); let parakeet_available = cfg!(feature = "parakeet"); let moonshine_available = cfg!(feature = "moonshine"); + let sensevoice_available = cfg!(feature = "sensevoice"); + let paraformer_available = cfg!(feature = "paraformer"); + let dolphin_available = cfg!(feature = "dolphin"); + let omnilingual_available = cfg!(feature = "omnilingual"); let whisper_count = MODELS.len(); let parakeet_count = PARAKEET_MODELS.len(); let moonshine_count = MOONSHINE_MODELS.len(); + let sensevoice_count = 
SENSEVOICE_MODELS.len(); + let paraformer_count = PARAFORMER_MODELS.len(); + let dolphin_count = DOLPHIN_MODELS.len(); + let omnilingual_count = OMNILINGUAL_MODELS.len(); + + let available_count = |available: bool, count: usize| if available { count } else { 0 }; let total_count = whisper_count - + if parakeet_available { - parakeet_count - } else { - 0 - } - + if moonshine_available { - moonshine_count - } else { - 0 - }; + + available_count(parakeet_available, parakeet_count) + + available_count(moonshine_available, moonshine_count) + + available_count(sensevoice_available, sensevoice_count) + + available_count(paraformer_available, paraformer_count) + + available_count(dolphin_available, dolphin_count) + + available_count(omnilingual_available, omnilingual_count); // --- Whisper Section --- println!("--- Whisper (OpenAI, 99+ languages) ---\n"); @@ -441,6 +596,146 @@ pub async fn interactive_select() -> anyhow::Result<()> { println!(" \x1b[90m(not available - rebuild with --features moonshine)\x1b[0m"); } + // --- SenseVoice Section --- + let sensevoice_offset = moonshine_offset + + if moonshine_available { + moonshine_count + } else { + 0 + }; + println!("\n--- SenseVoice (Alibaba FunAudioLLM, CJK + English) ---\n"); + + if sensevoice_available { + for (i, model) in SENSEVOICE_MODELS.iter().enumerate() { + let model_path = models_dir.join(model.dir_name); + let installed = model_path.exists() && validate_sensevoice_model(&model_path).is_ok(); + + let is_current = is_sensevoice_engine && current_sensevoice_model == Some(model.name); + let star = if is_current { "*" } else { " " }; + + let status = if installed { + "\x1b[32m[installed]\x1b[0m" + } else { + "" + }; + + println!( + " {}[{:>2}] {:<20} ({:>4} MB) {} - {} {}", + star, + sensevoice_offset + i + 1, + model.dir_name, + model.size_mb, + model.languages, + model.description, + status + ); + } + } else { + println!(" \x1b[90m(not available - rebuild with --features sensevoice)\x1b[0m"); + } + + // --- 
Paraformer Section --- + let paraformer_offset = sensevoice_offset + + available_count(sensevoice_available, sensevoice_count); + println!("\n--- Paraformer (FunASR, Chinese + English) ---\n"); + + if paraformer_available { + for (i, model) in PARAFORMER_MODELS.iter().enumerate() { + let model_path = models_dir.join(model.dir_name); + let installed = model_path.exists() && validate_onnx_ctc_model(&model_path).is_ok(); + + let is_current = is_paraformer_engine && current_paraformer_model == Some(model.name); + let star = if is_current { "*" } else { " " }; + + let status = if installed { + "\x1b[32m[installed]\x1b[0m" + } else { + "" + }; + + println!( + " {}[{:>2}] {:<20} ({:>4} MB) {} - {} {}", + star, + paraformer_offset + i + 1, + model.dir_name, + model.size_mb, + model.languages, + model.description, + status + ); + } + } else { + println!(" \x1b[90m(not available - rebuild with --features paraformer)\x1b[0m"); + } + + // --- Dolphin Section --- + let dolphin_offset = paraformer_offset + + available_count(paraformer_available, paraformer_count); + println!("\n--- Dolphin (dictation-optimized CTC) ---\n"); + + if dolphin_available { + for (i, model) in DOLPHIN_MODELS.iter().enumerate() { + let model_path = models_dir.join(model.dir_name); + let installed = model_path.exists() && validate_onnx_ctc_model(&model_path).is_ok(); + + let is_current = is_dolphin_engine && current_dolphin_model == Some(model.name); + let star = if is_current { "*" } else { " " }; + + let status = if installed { + "\x1b[32m[installed]\x1b[0m" + } else { + "" + }; + + println!( + " {}[{:>2}] {:<20} ({:>4} MB) {} - {} {}", + star, + dolphin_offset + i + 1, + model.dir_name, + model.size_mb, + model.languages, + model.description, + status + ); + } + } else { + println!(" \x1b[90m(not available - rebuild with --features dolphin)\x1b[0m"); + } + + // --- Omnilingual Section --- + let omnilingual_offset = dolphin_offset + + available_count(dolphin_available, dolphin_count); + println!("\n--- 
Omnilingual (FunASR, 1600+ languages) ---\n\n + + if omnilingual_available { + for (i, model) in OMNILINGUAL_MODELS.iter().enumerate() { + let model_path = models_dir.join(model.dir_name); + let installed = model_path.exists() && validate_onnx_ctc_model(&model_path).is_ok(); + + let is_current = is_omnilingual_engine && current_omnilingual_model == Some(model.name); + let star = if is_current { "*" } else { " " }; + + let status = if installed { + "\x1b[32m[installed]\x1b[0m" + } else { + "" + }; + + println!( + " {}[{:>2}] {:<20} ({:>4} MB) {} - {} {}", + star, + omnilingual_offset + i + 1, + model.dir_name, + model.size_mb, + model.languages, + model.description, + status + ); + } + } else { + println!(" \x1b[90m(not available - rebuild with --features omnilingual)\x1b[0m"); + } + println!("\n [ 0] Cancel\n"); // Get user selection @@ -459,16 +754,25 @@ pub async fn interactive_select() -> anyhow::Result<()> { // Route to appropriate handler based on selection if selection <= whisper_count { - // Whisper model selected handle_whisper_selection(selection).await } else if parakeet_available && selection <= whisper_count + parakeet_count { - // Parakeet model selected let parakeet_index = selection - whisper_count; handle_parakeet_selection(parakeet_index).await - } else if moonshine_available && selection <= total_count { - // Moonshine model selected + } else if moonshine_available && selection <= moonshine_offset + moonshine_count { + let moonshine_index = selection - moonshine_offset; + handle_moonshine_selection(moonshine_index).await + } else if sensevoice_available && selection <= sensevoice_offset + sensevoice_count { + let sensevoice_index = selection - sensevoice_offset; + handle_sensevoice_selection(sensevoice_index).await + } else if paraformer_available && selection <= paraformer_offset + paraformer_count { + let idx = selection - paraformer_offset; + handle_onnx_engine_selection("paraformer", PARAFORMER_MODELS.iter().map(|m| (m.name, m.dir_name, m.size_mb, 
m.files, m.huggingface_repo)).collect(), idx, validate_onnx_ctc_model).await + } else if dolphin_available && selection <= dolphin_offset + dolphin_count { + let idx = selection - dolphin_offset; + handle_onnx_engine_selection("dolphin", DOLPHIN_MODELS.iter().map(|m| (m.name, m.dir_name, m.size_mb, m.files, m.huggingface_repo)).collect(), idx, validate_onnx_ctc_model).await + } else if omnilingual_available && selection <= omnilingual_offset + omnilingual_count { + let idx = selection - omnilingual_offset; + handle_onnx_engine_selection("omnilingual", OMNILINGUAL_MODELS.iter().map(|m| (m.name, m.dir_name, m.size_mb, m.files, m.huggingface_repo)).collect(), idx, validate_onnx_ctc_model).await } else { println!("\nInvalid selection."); Ok(()) @@ -1381,6 +1685,58 @@ fn update_moonshine_in_config(config: &str, model_name: &str) -> String { result } +/// Handle SenseVoice model selection (download/config) +async fn handle_sensevoice_selection(selection: usize) -> anyhow::Result<()> { + let models_dir = Config::models_dir(); + + if selection == 0 || selection > SENSEVOICE_MODELS.len() { + println!("\nCancelled."); + return Ok(()); + } + + let model = &SENSEVOICE_MODELS[selection - 1]; + let model_path = models_dir.join(model.dir_name); + + // Check if already installed + if model_path.exists() && validate_sensevoice_model(&model_path).is_ok() { + println!("\nModel '{}' is already installed.\n", model.dir_name); + println!(" [1] Set as default model (update config)"); + println!(" [2] Re-download"); + println!(" [0] Cancel\n"); + + print!("Select option [1]: "); + io::stdout().flush()?; + + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + let choice = choice.trim(); + + match choice { + "" | "1" => { + update_config_sensevoice(model.name)?; + restart_daemon_if_running().await; + return Ok(()); + } + "2" => { + // Continue to download below + } + _ => { + println!("Cancelled."); + return Ok(()); + } + } + } + + // Download the model + 
download_sensevoice_model_by_info(model)?; + + // Update config and restart daemon + update_config_sensevoice(model.name)?; + restart_daemon_if_running().await; + + Ok(()) +} + /// List installed Moonshine models pub fn list_installed_moonshine() { println!("\nInstalled Moonshine Models\n"); @@ -1429,6 +1785,510 @@ pub fn list_installed_moonshine() { } } +// ============================================================================= +// SenseVoice Model Functions +// ============================================================================= + +/// Check if a model name is a SenseVoice model +pub fn is_sensevoice_model(name: &str) -> bool { + SENSEVOICE_MODELS.iter().any(|m| m.name == name) +} + +/// Get the directory name for a SenseVoice model +pub fn sensevoice_dir_name(name: &str) -> Option<&'static str> { + SENSEVOICE_MODELS + .iter() + .find(|m| m.name == name) + .map(|m| m.dir_name) +} + +/// Get list of valid SenseVoice model names +pub fn valid_sensevoice_model_names() -> Vec<&'static str> { + SENSEVOICE_MODELS.iter().map(|m| m.name).collect() +} + +/// Validate that a SenseVoice model directory has the required files +pub fn validate_sensevoice_model(path: &Path) -> anyhow::Result<()> { + if !path.exists() { + anyhow::bail!("Model directory does not exist: {:?}", path); + } + + let has_model = + path.join("model.int8.onnx").exists() || path.join("model.onnx").exists(); + let has_tokens = path.join("tokens.txt").exists(); + + if has_model && has_tokens { + Ok(()) + } else { + let mut missing = Vec::new(); + if !has_model { + missing.push("model.int8.onnx or model.onnx"); + } + if !has_tokens { + missing.push("tokens.txt"); + } + anyhow::bail!( + "Incomplete SenseVoice model, missing: {}", + missing.join(", ") + ) + } +} + +/// Download a SenseVoice model by name (public API for run_setup) +pub fn download_sensevoice_model(model_name: &str) -> anyhow::Result<()> { + let model = SENSEVOICE_MODELS + .iter() + .find(|m| m.name == model_name) + 
.ok_or_else(|| anyhow::anyhow!("Unknown SenseVoice model: {}", model_name))?; + + download_sensevoice_model_by_info(model) +} + +/// Download a SenseVoice model using its info struct +fn download_sensevoice_model_by_info(model: &SenseVoiceModelInfo) -> anyhow::Result<()> { + let models_dir = Config::models_dir(); + let model_path = models_dir.join(model.dir_name); + + // Create model directory + std::fs::create_dir_all(&model_path)?; + + println!( + "\nDownloading {} ({} MB)...\n", + model.dir_name, model.size_mb + ); + + for (repo_path, local_filename) in model.files { + let file_path = model_path.join(local_filename); + + if file_path.exists() { + println!(" {} already exists, skipping", local_filename); + continue; + } + + let url = format!( + "https://huggingface.co/{}/resolve/main/{}", + model.huggingface_repo, repo_path + ); + + println!("Downloading {}...", local_filename); + + let status = Command::new("curl") + .args([ + "-L", + "--progress-bar", + "-o", + file_path.to_str().unwrap_or("file"), + &url, + ]) + .status(); + + match status { + Ok(exit_status) if exit_status.success() => { + // Success, continue + } + Ok(exit_status) => { + print_failure(&format!( + "Download failed: curl exited with code {}", + exit_status.code().unwrap_or(-1) + )); + let _ = std::fs::remove_file(&file_path); + anyhow::bail!("Download failed for {}", local_filename) + } + Err(e) => { + print_failure(&format!("Failed to run curl: {}", e)); + print_info("Please ensure curl is installed (e.g., 'sudo pacman -S curl')"); + anyhow::bail!("curl not available: {}", e) + } + } + } + + // Validate all files are present + validate_sensevoice_model(&model_path)?; + print_success(&format!( + "Model '{}' downloaded to {:?}", + model.dir_name, model_path + )); + + Ok(()) +} + +/// Update config to use SenseVoice engine and a specific model (with status messages) +fn update_config_sensevoice(model_name: &str) -> anyhow::Result<()> { + if let Some(config_path) = Config::default_path() { + if 
config_path.exists() { + let content = std::fs::read_to_string(&config_path)?; + let updated = update_sensevoice_in_config(&content, model_name); + std::fs::write(&config_path, updated)?; + print_success(&format!( + "Config updated: engine = \"sensevoice\", model = \"{}\"", + model_name + )); + Ok(()) + } else { + print_info("No config file found. Run 'voxtype setup' first."); + Ok(()) + } + } else { + anyhow::bail!("Could not determine config path") + } +} + +/// Update the config to use SenseVoice engine with a specific model +fn update_sensevoice_in_config(config: &str, model_name: &str) -> String { + let mut result = String::new(); + let mut has_engine_line = false; + let mut has_sensevoice_section = false; + let mut in_sensevoice_section = false; + let mut sensevoice_model_updated = false; + + for line in config.lines() { + let trimmed = line.trim(); + + // Track sections + if trimmed.starts_with('[') { + if in_sensevoice_section && !sensevoice_model_updated { + result.push_str(&format!("model = \"{}\"\n", model_name)); + sensevoice_model_updated = true; + } + in_sensevoice_section = trimmed == "[sensevoice]"; + if in_sensevoice_section { + has_sensevoice_section = true; + } + } + + // Update or add engine line at the top level + if trimmed.starts_with("engine") && !trimmed.starts_with('[') { + result.push_str("engine = \"sensevoice\"\n"); + has_engine_line = true; + } + // Update model line in sensevoice section + else if in_sensevoice_section && trimmed.starts_with("model") { + result.push_str(&format!("model = \"{}\"\n", model_name)); + sensevoice_model_updated = true; + } else { + result.push_str(line); + result.push('\n'); + } + } + + // If we were in sensevoice section at EOF and didn't update model, add it + if in_sensevoice_section && !sensevoice_model_updated { + result.push_str(&format!("model = \"{}\"\n", model_name)); + } + + // Add engine line if not present + if !has_engine_line { + let mut new_result = String::new(); + let mut engine_added = 
false; + for line in result.lines() { + let trimmed = line.trim(); + if !engine_added + && !trimmed.is_empty() + && !trimmed.starts_with('#') + && !trimmed.starts_with("engine") + { + new_result.push_str("engine = \"sensevoice\"\n\n"); + engine_added = true; + } + new_result.push_str(line); + new_result.push('\n'); + } + result = new_result; + } + + // Add [sensevoice] section if not present + if !has_sensevoice_section { + result.push_str(&format!("\n[sensevoice]\nmodel = \"{}\"\n", model_name)); + } + + // Remove trailing newline if original didn't have one + if !config.ends_with('\n') && result.ends_with('\n') { + result.pop(); + } + + result +} + +/// List installed SenseVoice models +pub fn list_installed_sensevoice() { + println!("\nInstalled SenseVoice Models\n"); + println!("===========================\n"); + + let models_dir = Config::models_dir(); + + if !models_dir.exists() { + println!("No models directory found: {:?}", models_dir); + return; + } + + let mut found = false; + + for model in SENSEVOICE_MODELS { + let model_path = models_dir.join(model.dir_name); + + if model_path.exists() && validate_sensevoice_model(&model_path).is_ok() { + let size = std::fs::read_dir(&model_path) + .map(|entries| { + entries + .flatten() + .filter_map(|e| e.metadata().ok()) + .map(|m| m.len() as f64 / 1024.0 / 1024.0) + .sum::<f64>() + }) + .unwrap_or(0.0); + + println!( + " {} ({:.0} MB) - {} ({})", + model.dir_name, size, model.description, model.languages + ); + found = true; + } + } + + if !found { + println!(" No SenseVoice models installed."); + println!("\n Run 'voxtype setup model' and select SenseVoice to download."); + } +} + +// ============================================================================= +// Generic ONNX Engine Functions (Paraformer, Dolphin, Omnilingual) +// ============================================================================= + +/// Validate a CTC-based ONNX model directory (model.int8.onnx or model.onnx + tokens.txt) +fn 
validate_onnx_ctc_model(path: &Path) -> anyhow::Result<()> { + if !path.exists() { + anyhow::bail!("Model directory does not exist: {:?}", path); + } + + let has_model = path.join("model.int8.onnx").exists() || path.join("model.onnx").exists(); + let has_tokens = path.join("tokens.txt").exists(); + + if has_model && has_tokens { + Ok(()) + } else { + let mut missing = Vec::new(); + if !has_model { + missing.push("model.int8.onnx or model.onnx"); + } + if !has_tokens { + missing.push("tokens.txt"); + } + anyhow::bail!("Incomplete model, missing: {}", missing.join(", ")) + } +} + +/// Generic handler for ONNX engine model selection (download/config/restart) +async fn handle_onnx_engine_selection( + engine_name: &str, + models: Vec<(&str, &str, u32, &[(&str, &str)], &str)>, + selection: usize, + validate_fn: fn(&Path) -> anyhow::Result<()>, +) -> anyhow::Result<()> { + let models_dir = Config::models_dir(); + + if selection == 0 || selection > models.len() { + println!("\nCancelled."); + return Ok(()); + } + + let (name, dir_name, size_mb, files, repo) = &models[selection - 1]; + let model_path = models_dir.join(dir_name); + + // Check if already installed + if model_path.exists() && validate_fn(&model_path).is_ok() { + println!("\nModel '{}' is already installed.\n", dir_name); + println!(" [1] Set as default model (update config)"); + println!(" [2] Re-download"); + println!(" [0] Cancel\n"); + + print!("Select option [1]: "); + io::stdout().flush()?; + + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + let choice = choice.trim(); + + match choice { + "" | "1" => { + update_config_engine(engine_name, name)?; + restart_daemon_if_running().await; + return Ok(()); + } + "2" => { + // Continue to download below + } + _ => { + println!("Cancelled."); + return Ok(()); + } + } + } + + // Download the model + download_onnx_model(dir_name, *size_mb, files, repo)?; + + // Validate + validate_fn(&model_path)?; + print_success(&format!("Model '{}' 
downloaded to {:?}", dir_name, model_path)); + + // Update config and restart daemon + update_config_engine(engine_name, name)?; + restart_daemon_if_running().await; + + Ok(()) +} + +/// Download an ONNX model from HuggingFace +fn download_onnx_model( + dir_name: &str, + size_mb: u32, + files: &[(&str, &str)], + repo: &str, +) -> anyhow::Result<()> { + let models_dir = Config::models_dir(); + let model_path = models_dir.join(dir_name); + + std::fs::create_dir_all(&model_path)?; + + println!("\nDownloading {} ({} MB)...\n", dir_name, size_mb); + + for (repo_path, local_filename) in files { + let file_path = model_path.join(local_filename); + + if file_path.exists() { + println!(" {} already exists, skipping", local_filename); + continue; + } + + let url = format!( + "https://huggingface.co/{}/resolve/main/{}", + repo, repo_path + ); + + println!("Downloading {}...", local_filename); + + let status = Command::new("curl") + .args([ + "-L", + "--progress-bar", + "-o", + file_path.to_str().unwrap_or("file"), + &url, + ]) + .status(); + + match status { + Ok(exit_status) if exit_status.success() => {} + Ok(exit_status) => { + print_failure(&format!( + "Download failed: curl exited with code {}", + exit_status.code().unwrap_or(-1) + )); + let _ = std::fs::remove_file(&file_path); + anyhow::bail!("Download failed for {}", local_filename) + } + Err(e) => { + print_failure(&format!("Failed to run curl: {}", e)); + print_info("Please ensure curl is installed (e.g., 'sudo pacman -S curl')"); + anyhow::bail!("curl not available: {}", e) + } + } + } + + Ok(()) +} + +/// Update config to use a specific engine and model +fn update_config_engine(engine_name: &str, model_name: &str) -> anyhow::Result<()> { + if let Some(config_path) = Config::default_path() { + if config_path.exists() { + let content = std::fs::read_to_string(&config_path)?; + let updated = update_engine_in_config(&content, engine_name, model_name); + std::fs::write(&config_path, updated)?; + print_success(&format!( 
+ "Config updated: engine = \"{}\", model = \"{}\"", + engine_name, model_name + )); + Ok(()) + } else { + print_info("No config file found. Run 'voxtype setup' first."); + Ok(()) + } + } else { + anyhow::bail!("Could not determine config path") + } +} + +/// Update a config string to use a specific engine and model +fn update_engine_in_config(config: &str, engine_name: &str, model_name: &str) -> String { + let section_name = format!("[{}]", engine_name); + let mut result = String::new(); + let mut has_engine_line = false; + let mut has_section = false; + let mut in_section = false; + let mut model_updated = false; + + for line in config.lines() { + let trimmed = line.trim(); + + if trimmed.starts_with('[') { + if in_section && !model_updated { + result.push_str(&format!("model = \"{}\"\n", model_name)); + model_updated = true; + } + in_section = trimmed == section_name; + if in_section { + has_section = true; + } + } + + if trimmed.starts_with("engine") && !trimmed.starts_with('[') { + result.push_str(&format!("engine = \"{}\"\n", engine_name)); + has_engine_line = true; + } else if in_section && trimmed.starts_with("model") { + result.push_str(&format!("model = \"{}\"\n", model_name)); + model_updated = true; + } else { + result.push_str(line); + result.push('\n'); + } + } + + if in_section && !model_updated { + result.push_str(&format!("model = \"{}\"\n", model_name)); + } + + if !has_engine_line { + let mut new_result = String::new(); + let mut engine_added = false; + for line in result.lines() { + let trimmed = line.trim(); + if !engine_added + && !trimmed.is_empty() + && !trimmed.starts_with('#') + && !trimmed.starts_with("engine") + { + new_result.push_str(&format!("engine = \"{}\"\n\n", engine_name)); + engine_added = true; + } + new_result.push_str(line); + new_result.push('\n'); + } + result = new_result; + } + + if !has_section { + result.push_str(&format!("\n[{}]\nmodel = \"{}\"\n", engine_name, model_name)); + } + + if !config.ends_with('\n') && 
result.ends_with('\n') { + result.pop(); + } + + result +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/setup/parakeet.rs b/src/setup/parakeet.rs index e9af0d59..ec05889c 100644 --- a/src/setup/parakeet.rs +++ b/src/setup/parakeet.rs @@ -25,21 +25,21 @@ pub enum ParakeetBackend { impl ParakeetBackend { fn binary_name(&self) -> &'static str { match self { - ParakeetBackend::Avx2 => "voxtype-parakeet-avx2", - ParakeetBackend::Avx512 => "voxtype-parakeet-avx512", - ParakeetBackend::Cuda => "voxtype-parakeet-cuda", - ParakeetBackend::Rocm => "voxtype-parakeet-rocm", - ParakeetBackend::Custom => "voxtype-parakeet", + ParakeetBackend::Avx2 => "voxtype-onnx-avx2", + ParakeetBackend::Avx512 => "voxtype-onnx-avx512", + ParakeetBackend::Cuda => "voxtype-onnx-cuda", + ParakeetBackend::Rocm => "voxtype-onnx-rocm", + ParakeetBackend::Custom => "voxtype-onnx", } } fn display_name(&self) -> &'static str { match self { - ParakeetBackend::Avx2 => "Parakeet (AVX2)", - ParakeetBackend::Avx512 => "Parakeet (AVX-512)", - ParakeetBackend::Cuda => "Parakeet (CUDA)", - ParakeetBackend::Rocm => "Parakeet (ROCm)", - ParakeetBackend::Custom => "Parakeet (Custom)", + ParakeetBackend::Avx2 => "ONNX (AVX2)", + ParakeetBackend::Avx512 => "ONNX (AVX-512)", + ParakeetBackend::Cuda => "ONNX (CUDA)", + ParakeetBackend::Rocm => "ONNX (ROCm)", + ParakeetBackend::Custom => "ONNX (Custom)", } } @@ -59,7 +59,7 @@ pub fn is_parakeet_active() -> bool { if let Ok(link_target) = fs::read_link(VOXTYPE_BIN) { if let Some(target_name) = link_target.file_name() { if let Some(name) = target_name.to_str() { - return name.contains("parakeet"); + return name.contains("onnx") || name.contains("parakeet"); } } } @@ -71,6 +71,13 @@ pub fn detect_current_parakeet_backend() -> Option { if let Ok(link_target) = fs::read_link(VOXTYPE_BIN) { let target_name = link_target.file_name()?.to_str()?; return match target_name { + // New ONNX names + "voxtype-onnx-avx2" => Some(ParakeetBackend::Avx2), + 
"voxtype-onnx-avx512" => Some(ParakeetBackend::Avx512), + "voxtype-onnx-cuda" => Some(ParakeetBackend::Cuda), + "voxtype-onnx-rocm" => Some(ParakeetBackend::Rocm), + "voxtype-onnx" => Some(ParakeetBackend::Custom), + // Legacy parakeet names (backward compat) "voxtype-parakeet-avx2" => Some(ParakeetBackend::Avx2), "voxtype-parakeet-avx512" => Some(ParakeetBackend::Avx512), "voxtype-parakeet-cuda" => Some(ParakeetBackend::Cuda), @@ -227,7 +234,7 @@ fn switch_binary(binary_name: &str) -> anyhow::Result<()> { fs::remove_file(VOXTYPE_BIN).map_err(|e| { anyhow::anyhow!( "Failed to remove existing symlink (need sudo?): {}\n\ - Try: sudo voxtype setup parakeet --enable", + Try: sudo voxtype setup onnx --enable", e ) })?; @@ -237,7 +244,7 @@ fn switch_binary(binary_name: &str) -> anyhow::Result<()> { symlink(&binary_path, VOXTYPE_BIN).map_err(|e| { anyhow::anyhow!( "Failed to create symlink (need sudo?): {}\n\ - Try: sudo voxtype setup parakeet --enable", + Try: sudo voxtype setup onnx --enable", e ) })?; @@ -250,7 +257,7 @@ fn switch_binary(binary_name: &str) -> anyhow::Result<()> { /// Show Parakeet backend status pub fn show_status() { - println!("=== Voxtype Parakeet Status ===\n"); + println!("=== Voxtype ONNX Engine Status ===\n"); // Current engine if is_parakeet_active() { @@ -274,14 +281,14 @@ pub fn show_status() { } } - // Available Parakeet backends - println!("\nAvailable Parakeet backends:"); + // Available ONNX backends + println!("\nAvailable ONNX backends:"); let available = detect_available_backends(); let current = detect_current_parakeet_backend(); if available.is_empty() { - println!(" No Parakeet binaries installed."); - println!("\n Install a Parakeet-enabled package to use this feature."); + println!(" No ONNX binaries installed."); + println!("\n Install an ONNX-enabled voxtype package to use this feature."); } else { for backend in [ ParakeetBackend::Avx2, @@ -322,11 +329,11 @@ pub fn show_status() { // Usage hints println!(); if 
!is_parakeet_active() && !available.is_empty() { - println!("To enable Parakeet:"); - println!(" sudo voxtype setup parakeet --enable"); + println!("To enable ONNX engines:"); + println!(" sudo voxtype setup onnx --enable"); } else if is_parakeet_active() { println!("To switch back to Whisper:"); - println!(" sudo voxtype setup parakeet --disable"); + println!(" sudo voxtype setup onnx --disable"); } } @@ -336,33 +343,33 @@ pub fn enable() -> anyhow::Result<()> { if available.is_empty() { anyhow::bail!( - "No Parakeet binaries installed.\n\ - Install a Parakeet-enabled voxtype package first." + "No ONNX binaries installed.\n\ + Install an ONNX-enabled voxtype package first." ); } if is_parakeet_active() { - println!("Parakeet is already enabled."); + println!("ONNX engine is already enabled."); if let Some(backend) = detect_current_parakeet_backend() { println!(" Current backend: {}", backend.display_name()); } return Ok(()); } - // Find best Parakeet backend + // Find best ONNX backend let backend = detect_best_parakeet_backend() - .ok_or_else(|| anyhow::anyhow!("No suitable Parakeet backend found"))?; + .ok_or_else(|| anyhow::anyhow!("No suitable ONNX backend found"))?; switch_binary(backend.binary_name())?; // Regenerate systemd service if it exists if super::systemd::regenerate_service_file()? 
{ - println!("Updated systemd service to use Parakeet backend."); + println!("Updated systemd service to use ONNX backend."); } println!("Switched to {} backend.", backend.display_name()); println!(); - println!("Restart voxtype to use Parakeet:"); + println!("Restart voxtype to use ONNX engines:"); println!(" systemctl --user restart voxtype"); Ok(()) @@ -371,7 +378,7 @@ pub fn enable() -> anyhow::Result<()> { /// Disable Parakeet backend (switch back to Whisper) pub fn disable() -> anyhow::Result<()> { if !is_parakeet_active() { - println!("Parakeet is not currently enabled (already using Whisper)."); + println!("ONNX engine is not currently enabled (already using Whisper)."); return Ok(()); } @@ -439,23 +446,23 @@ mod tests { #[test] fn test_parakeet_backend_binary_names() { - assert_eq!(ParakeetBackend::Avx2.binary_name(), "voxtype-parakeet-avx2"); + assert_eq!(ParakeetBackend::Avx2.binary_name(), "voxtype-onnx-avx2"); assert_eq!( ParakeetBackend::Avx512.binary_name(), - "voxtype-parakeet-avx512" + "voxtype-onnx-avx512" ); - assert_eq!(ParakeetBackend::Cuda.binary_name(), "voxtype-parakeet-cuda"); - assert_eq!(ParakeetBackend::Rocm.binary_name(), "voxtype-parakeet-rocm"); - assert_eq!(ParakeetBackend::Custom.binary_name(), "voxtype-parakeet"); + assert_eq!(ParakeetBackend::Cuda.binary_name(), "voxtype-onnx-cuda"); + assert_eq!(ParakeetBackend::Rocm.binary_name(), "voxtype-onnx-rocm"); + assert_eq!(ParakeetBackend::Custom.binary_name(), "voxtype-onnx"); } #[test] fn test_parakeet_backend_display_names() { - assert_eq!(ParakeetBackend::Avx2.display_name(), "Parakeet (AVX2)"); - assert_eq!(ParakeetBackend::Avx512.display_name(), "Parakeet (AVX-512)"); - assert_eq!(ParakeetBackend::Cuda.display_name(), "Parakeet (CUDA)"); - assert_eq!(ParakeetBackend::Rocm.display_name(), "Parakeet (ROCm)"); - assert_eq!(ParakeetBackend::Custom.display_name(), "Parakeet (Custom)"); + assert_eq!(ParakeetBackend::Avx2.display_name(), "ONNX (AVX2)"); + 
assert_eq!(ParakeetBackend::Avx512.display_name(), "ONNX (AVX-512)"); + assert_eq!(ParakeetBackend::Cuda.display_name(), "ONNX (CUDA)"); + assert_eq!(ParakeetBackend::Rocm.display_name(), "ONNX (ROCm)"); + assert_eq!(ParakeetBackend::Custom.display_name(), "ONNX (Custom)"); } #[test] diff --git a/src/setup/vad.rs b/src/setup/vad.rs index 7152251a..a27fc86f 100644 --- a/src/setup/vad.rs +++ b/src/setup/vad.rs @@ -67,9 +67,7 @@ pub fn show_status() { println!("VAD Model Status\n"); if model_path.exists() { - let size = std::fs::metadata(&model_path) - .map(|m| m.len()) - .unwrap_or(0); + let size = std::fs::metadata(&model_path).map(|m| m.len()).unwrap_or(0); print_success(&format!( "Silero VAD model installed: {:?} ({:.1} MB)", model_path, diff --git a/src/transcribe/ctc.rs b/src/transcribe/ctc.rs new file mode 100644 index 00000000..16906f00 --- /dev/null +++ b/src/transcribe/ctc.rs @@ -0,0 +1,288 @@ +//! Shared CTC (Connectionist Temporal Classification) greedy decoding +//! +//! Used by SenseVoice, Dolphin, and Omnilingual backends. These models +//! all use CTC output and share the same decoding logic: argmax per frame, +//! collapse consecutive duplicates, remove blank tokens. +//! +//! SenseVoice additionally skips metadata tokens (language, emotion, event, ITN) +//! at the start of the sequence. 
+ +use crate::error::TranscribeError; +use std::collections::HashMap; +use std::path::Path; + +/// Configuration for CTC greedy decoding +pub struct CtcConfig { + /// Token ID used for CTC blank (usually 0) + pub blank_id: u32, + /// Number of metadata tokens to skip at start of decoded sequence + /// (SenseVoice: 4 for language/emotion/event/ITN, others: 0) + pub num_metadata_tokens: usize, + /// Replace SentencePiece word boundary markers (U+2581) with spaces + pub sentencepiece_cleanup: bool, +} + +impl Default for CtcConfig { + fn default() -> Self { + Self { + blank_id: 0, + num_metadata_tokens: 0, + sentencepiece_cleanup: false, + } + } +} + +impl CtcConfig { + /// Config for SenseVoice: skip 4 metadata tokens, clean SentencePiece markers + pub fn sensevoice() -> Self { + Self { + blank_id: 0, + num_metadata_tokens: 4, + sentencepiece_cleanup: true, + } + } +} + +/// CTC greedy decoding: argmax per frame, collapse duplicates, remove blanks +/// +/// Input: raw logits of shape (time_steps, vocab_size) flattened to a 1D slice +/// Output: decoded text string +pub fn ctc_greedy_decode( + logits: &[f32], + time_steps: usize, + vocab_size: usize, + tokens: &HashMap, + config: &CtcConfig, +) -> String { + let mut token_ids: Vec = Vec::new(); + let mut prev_id: Option = None; + + for t in 0..time_steps { + let offset = t * vocab_size; + let frame_logits = &logits[offset..offset + vocab_size]; + + // Argmax + let best_id = frame_logits + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(idx, _)| idx as u32) + .unwrap_or(config.blank_id); + + // Collapse consecutive duplicates and skip blanks + if best_id != config.blank_id && Some(best_id) != prev_id { + token_ids.push(best_id); + } + prev_id = Some(best_id); + } + + tokens_to_string(&token_ids, tokens, config) +} + +/// Decode pre-argmaxed output where values are already token IDs (as f32) +/// +/// Some ONNX models output 2D logits where each value is 
already the best +/// token ID rather than a probability distribution over the vocabulary. +pub fn decode_pre_argmax( + token_ids_f32: &[f32], + tokens: &HashMap, + config: &CtcConfig, +) -> String { + let mut token_ids: Vec = Vec::new(); + let mut prev_id: Option = None; + + for &val in token_ids_f32 { + let id = val as u32; + if id != config.blank_id && Some(id) != prev_id { + token_ids.push(id); + } + prev_id = Some(id); + } + + tokens_to_string(&token_ids, tokens, config) +} + +/// Convert token IDs to string, applying metadata skipping and SentencePiece cleanup +fn tokens_to_string( + token_ids: &[u32], + tokens: &HashMap, + config: &CtcConfig, +) -> String { + let content_tokens = if token_ids.len() > config.num_metadata_tokens { + &token_ids[config.num_metadata_tokens..] + } else if config.num_metadata_tokens > 0 { + &[] + } else { + token_ids + }; + + let mut result = String::new(); + for &id in content_tokens { + if let Some(token_str) = tokens.get(&id) { + if config.sentencepiece_cleanup { + result.push_str(&token_str.replace('\u{2581}', " ")); + } else { + result.push_str(token_str); + } + } + } + + result.trim().to_string() +} + +/// Load tokens.txt into a HashMap +/// +/// Format: each line is "token_string token_id" (space-separated). +/// The token string may contain spaces, so we split from the right. 
+pub fn load_tokens(path: &Path) -> Result, TranscribeError> { + let content = std::fs::read_to_string(path).map_err(|e| { + TranscribeError::InitFailed(format!("Failed to read tokens.txt: {}", e)) + })?; + + let mut tokens = HashMap::new(); + for line in content.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + // Split from the right to handle tokens containing spaces + if let Some(last_space) = line.rfind(' ') { + let token_str = &line[..last_space]; + let id_str = &line[last_space + 1..]; + if let Ok(id) = id_str.parse::() { + tokens.insert(id, token_str.to_string()); + } + } + } + + if tokens.is_empty() { + return Err(TranscribeError::InitFailed( + "tokens.txt appears empty or malformed".to_string(), + )); + } + + Ok(tokens) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[test] + fn test_load_tokens() { + let temp_dir = TempDir::new().unwrap(); + let tokens_path = temp_dir.path().join("tokens.txt"); + fs::write( + &tokens_path, + " 0\n 1\nhello 2\nworld 3\n", + ) + .unwrap(); + + let tokens = load_tokens(&tokens_path).unwrap(); + assert_eq!(tokens.get(&0), Some(&"".to_string())); + assert_eq!(tokens.get(&2), Some(&"hello".to_string())); + assert_eq!(tokens.get(&3), Some(&"world".to_string())); + } + + #[test] + fn test_load_tokens_empty() { + let temp_dir = TempDir::new().unwrap(); + let tokens_path = temp_dir.path().join("tokens.txt"); + fs::write(&tokens_path, "").unwrap(); + + let result = load_tokens(&tokens_path); + assert!(result.is_err()); + } + + #[test] + fn test_ctc_decode_basic() { + let mut tokens = HashMap::new(); + tokens.insert(1, "h".to_string()); + tokens.insert(2, "i".to_string()); + + // Simulate: blank, h, h, blank, i + let vocab_size = 3; + let time_steps = 5; + let mut logits = vec![0.0f32; time_steps * vocab_size]; + + let set_max = |logits: &mut Vec, t: usize, id: usize| { + logits[t * vocab_size + id] = 10.0; + }; + + set_max(&mut logits, 0, 0); // blank + 
set_max(&mut logits, 1, 1); // h + set_max(&mut logits, 2, 1); // h (duplicate) + set_max(&mut logits, 3, 0); // blank + set_max(&mut logits, 4, 2); // i + + let config = CtcConfig::default(); + let result = ctc_greedy_decode(&logits, time_steps, vocab_size, &tokens, &config); + assert_eq!(result, "hi"); + } + + #[test] + fn test_ctc_decode_with_metadata_skip() { + let mut tokens = HashMap::new(); + tokens.insert(1, "lang".to_string()); + tokens.insert(2, "emo".to_string()); + tokens.insert(3, "event".to_string()); + tokens.insert(4, "itn".to_string()); + tokens.insert(5, "h".to_string()); + tokens.insert(6, "i".to_string()); + + let vocab_size = 7; + let time_steps = 6; + let mut logits = vec![0.0f32; time_steps * vocab_size]; + + let set_max = |logits: &mut Vec, t: usize, id: usize| { + logits[t * vocab_size + id] = 10.0; + }; + + set_max(&mut logits, 0, 1); // lang (metadata) + set_max(&mut logits, 1, 2); // emo (metadata) + set_max(&mut logits, 2, 3); // event (metadata) + set_max(&mut logits, 3, 4); // itn (metadata) + set_max(&mut logits, 4, 5); // h + set_max(&mut logits, 5, 6); // i + + let config = CtcConfig::sensevoice(); + let result = ctc_greedy_decode(&logits, time_steps, vocab_size, &tokens, &config); + assert_eq!(result, "hi"); + } + + #[test] + fn test_ctc_decode_sentencepiece_cleanup() { + let mut tokens = HashMap::new(); + tokens.insert(1, "\u{2581}hello".to_string()); + tokens.insert(2, "\u{2581}world".to_string()); + + let vocab_size = 3; + let time_steps = 2; + let mut logits = vec![0.0f32; time_steps * vocab_size]; + + logits[0 * vocab_size + 1] = 10.0; // hello + logits[1 * vocab_size + 2] = 10.0; // world + + let config = CtcConfig { + sentencepiece_cleanup: true, + ..CtcConfig::default() + }; + let result = ctc_greedy_decode(&logits, time_steps, vocab_size, &tokens, &config); + assert_eq!(result, "hello world"); + } + + #[test] + fn test_decode_pre_argmax() { + let mut tokens = HashMap::new(); + tokens.insert(1, "a".to_string()); + 
tokens.insert(2, "b".to_string()); + + // Pre-argmaxed: blank, a, a, blank, b + let token_ids: Vec = vec![0.0, 1.0, 1.0, 0.0, 2.0]; + let config = CtcConfig::default(); + let result = decode_pre_argmax(&token_ids, &tokens, &config); + assert_eq!(result, "ab"); + } +} diff --git a/src/transcribe/dolphin.rs b/src/transcribe/dolphin.rs new file mode 100644 index 00000000..c6674cf9 --- /dev/null +++ b/src/transcribe/dolphin.rs @@ -0,0 +1,440 @@ +//! Dolphin-based speech-to-text transcription +//! +//! Uses DataoceanAI's Dolphin model via ONNX Runtime for local transcription. +//! Dolphin is a CTC-based E-Branchformer model optimized for Eastern languages +//! (40 languages + 22 Chinese dialects). No English support. +//! +//! The ONNX model expects 80-dim Fbank features as input, preprocessed with +//! the shared Fbank pipeline (same as SenseVoice/Paraformer) and normalized +//! with CMVN stats from model metadata. +//! +//! Pipeline: Audio (f32, 16kHz) -> Fbank (80-dim) -> CMVN -> ONNX model -> CTC decode +//! +//! Languages: zh, ja, ko, th, vi, id, ms, ar, hi, ur, bn, ta, and 28 more +//! 
Model files: model.int8.onnx (or model.onnx), tokens.txt + +use super::ctc; +use super::fbank::{self, FbankExtractor}; +use super::Transcriber; +use crate::config::DolphinConfig; +use crate::error::TranscribeError; +use ort::session::Session; +use ort::value::Tensor; +use std::collections::HashMap; +use std::path::PathBuf; + +/// Sample rate expected by Dolphin +const SAMPLE_RATE: usize = 16000; + +/// Dolphin-based transcriber using ONNX Runtime +pub struct DolphinTranscriber { + session: std::sync::Mutex, + tokens: HashMap, + neg_mean: Vec, + inv_stddev: Vec, + fbank_extractor: FbankExtractor, +} + +impl DolphinTranscriber { + pub fn new(config: &DolphinConfig) -> Result { + let model_dir = resolve_model_path(&config.model)?; + + tracing::info!("Loading Dolphin model from {:?}", model_dir); + let start = std::time::Instant::now(); + + let threads = config.threads.unwrap_or_else(|| num_cpus::get().min(4)); + + // Find model file (prefer int8 quantized) + let model_file = { + let int8 = model_dir.join("model.int8.onnx"); + let full = model_dir.join("model.onnx"); + if int8.exists() { + int8 + } else if full.exists() { + tracing::info!("Using full-precision model (model.int8.onnx not found)"); + full + } else { + return Err(TranscribeError::ModelNotFound(format!( + "Dolphin model not found in {:?}\n \ + Expected model.int8.onnx or model.onnx\n \ + Run: voxtype setup model", + model_dir + ))); + } + }; + + // Load tokens.txt + let tokens_path = model_dir.join("tokens.txt"); + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(format!( + "Dolphin tokens.txt not found: {}\n \ + Ensure tokens.txt is in the model directory.", + tokens_path.display() + ))); + } + let tokens = ctc::load_tokens(&tokens_path)?; + tracing::debug!("Loaded {} tokens", tokens.len()); + + // Create ONNX session + let session = Session::builder() + .map_err(|e| { + TranscribeError::InitFailed(format!("ONNX session builder failed: {}", e)) + })? 
+ .with_intra_threads(threads) + .map_err(|e| { + TranscribeError::InitFailed(format!("Failed to set threads: {}", e)) + })? + .commit_from_file(&model_file) + .map_err(|e| { + TranscribeError::InitFailed(format!( + "Failed to load Dolphin model from {:?}: {}", + model_file, e + )) + })?; + + // Read CMVN stats from model metadata + // Dolphin uses "mean"/"invstd" naming (mean is positive, needs negation) + let (neg_mean, inv_stddev) = read_cmvn_from_metadata(&session)?; + + let fbank_extractor = FbankExtractor::new_default(); + + tracing::info!( + "Dolphin model loaded in {:.2}s", + start.elapsed().as_secs_f32(), + ); + + Ok(Self { + session: std::sync::Mutex::new(session), + tokens, + neg_mean, + inv_stddev, + fbank_extractor, + }) + } +} + +impl Transcriber for DolphinTranscriber { + fn transcribe(&self, samples: &[f32]) -> Result { + if samples.is_empty() { + return Err(TranscribeError::AudioFormat( + "Empty audio buffer".to_string(), + )); + } + + let duration_secs = samples.len() as f32 / SAMPLE_RATE as f32; + tracing::debug!( + "Transcribing {:.2}s of audio ({} samples) with Dolphin", + duration_secs, + samples.len(), + ); + + let start = std::time::Instant::now(); + + // 1. Extract Fbank features (80-dim, same pipeline as SenseVoice) + let fbank_start = std::time::Instant::now(); + let fbank_features = self.fbank_extractor.extract(samples); + tracing::debug!( + "Fbank extraction: {:.2}s ({} frames x {})", + fbank_start.elapsed().as_secs_f32(), + fbank_features.nrows(), + fbank_features.ncols(), + ); + + if fbank_features.nrows() == 0 { + return Err(TranscribeError::AudioFormat( + "Audio too short for feature extraction".to_string(), + )); + } + + // 2. 
CMVN normalization (no LFR stacking - Dolphin takes 80-dim directly) + let mut features = fbank_features; + fbank::apply_cmvn(&mut features, &self.neg_mean, &self.inv_stddev); + + let num_frames = features.nrows(); + let feat_dim = features.ncols(); + + // x: shape [1, T, 80] + let (x_data, _offset) = features.into_raw_vec_and_offset(); + let x_tensor = + Tensor::::from_array(([1usize, num_frames, feat_dim], x_data)).map_err(|e| { + TranscribeError::InferenceFailed(format!( + "Failed to create input tensor: {}", + e + )) + })?; + + // x_len: shape [1] (i64) + let x_len_tensor = Tensor::::from_array(([1usize], vec![num_frames as i64])) + .map_err(|e| { + TranscribeError::InferenceFailed(format!( + "Failed to create length tensor: {}", + e + )) + })?; + + // Run inference + let inference_start = std::time::Instant::now(); + let mut session = self.session.lock().map_err(|e| { + TranscribeError::InferenceFailed(format!("Failed to lock session: {}", e)) + })?; + + let inputs: Vec<(std::borrow::Cow, ort::session::SessionInputValue)> = vec![ + (std::borrow::Cow::Borrowed("x"), x_tensor.into()), + ( + std::borrow::Cow::Borrowed("x_len"), + x_len_tensor.into(), + ), + ]; + + let outputs = session.run(inputs).map_err(|e| { + TranscribeError::InferenceFailed(format!("Dolphin inference failed: {}", e)) + })?; + + tracing::debug!( + "ONNX inference: {:.2}s", + inference_start.elapsed().as_secs_f32(), + ); + + // Extract CTC log-probs and decode + let logits_val = outputs + .get("lob_probs") + .or_else(|| outputs.get("logits")) + .or_else(|| outputs.get("output")) + .ok_or_else(|| { + TranscribeError::InferenceFailed( + "Dolphin output not found (expected 'lob_probs', 'logits', or 'output')" + .to_string(), + ) + })?; + + let (shape, logits_data) = logits_val.try_extract_tensor::().map_err(|e| { + TranscribeError::InferenceFailed(format!("Failed to extract logits: {}", e)) + })?; + + let shape_dims: &[i64] = shape; + tracing::debug!("Dolphin output shape: {:?}", shape_dims); + + 
// Dolphin CTC output: [batch, time_steps, vocab_size] + let raw_text = if shape_dims.len() == 3 { + let time_steps = shape_dims[1] as usize; + let vocab_size = shape_dims[2] as usize; + let config = ctc::CtcConfig { + blank_id: 0, + num_metadata_tokens: 0, + sentencepiece_cleanup: true, + }; + ctc::ctc_greedy_decode(logits_data, time_steps, vocab_size, &self.tokens, &config) + } else if shape_dims.len() == 2 { + // Pre-argmaxed output + let time_steps = shape_dims[1] as usize; + let config = ctc::CtcConfig { + blank_id: 0, + num_metadata_tokens: 0, + sentencepiece_cleanup: true, + }; + ctc::decode_pre_argmax(&logits_data[..time_steps], &self.tokens, &config) + } else { + return Err(TranscribeError::InferenceFailed(format!( + "Unexpected Dolphin output shape: {:?}", + shape_dims + ))); + }; + + // Filter language/region tokens from output (e.g., , , , ) + let result = filter_language_tokens(&raw_text); + + tracing::info!( + "Dolphin transcription completed in {:.2}s: {:?}", + start.elapsed().as_secs_f32(), + if result.chars().count() > 50 { + format!("{}...", result.chars().take(50).collect::()) + } else { + result.clone() + } + ); + + Ok(result) + } +} + +/// Remove language and region tokens from CTC output +/// +/// Dolphin prepends tokens like , , , to its output. +/// These are useful for language identification but should not appear in +/// the final transcription text. 
+fn filter_language_tokens(text: &str) -> String { + let mut result = String::with_capacity(text.len()); + let mut chars = text.chars().peekable(); + + while let Some(&c) = chars.peek() { + if c == '<' { + // Consume everything up to and including '>' + let mut found_close = false; + for inner in chars.by_ref() { + if inner == '>' { + found_close = true; + break; + } + } + if !found_close { + // Malformed tag, just skip the '<' + result.push(c); + } + } else { + result.push(c); + chars.next(); + } + } + + result.trim().to_string() +} + +/// Read CMVN stats from ONNX model metadata +/// +/// Dolphin uses "mean"/"invstd" keys where mean is positive (needs negation). +/// Falls back to "neg_mean"/"inv_stddev" if those aren't found. +fn read_cmvn_from_metadata(session: &Session) -> Result<(Vec, Vec), TranscribeError> { + let metadata = session.metadata().map_err(|e| { + TranscribeError::InitFailed(format!("Failed to read model metadata: {}", e)) + })?; + + // Try Dolphin naming first: "mean" and "invstd" + // Despite the key name "mean", the values are already negated (same as SenseVoice's + // "neg_mean"), so we use them directly without negation. 
+ let (neg_mean, inv_stddev) = if let Some(mean_str) = metadata.custom("mean") { + let invstd_str = metadata.custom("invstd").ok_or_else(|| { + TranscribeError::InitFailed("Model metadata has 'mean' but no 'invstd'".to_string()) + })?; + + let neg_mean: Vec = mean_str + .split(',') + .filter_map(|s: &str| s.trim().parse::().ok()) + .collect(); + let inv_stddev: Vec = invstd_str + .split(',') + .filter_map(|s: &str| s.trim().parse::().ok()) + .collect(); + + (neg_mean, inv_stddev) + } else if let Some(neg_mean_str) = metadata.custom("neg_mean") { + // SenseVoice-style naming (already negated) + let inv_stddev_str = metadata.custom("inv_stddev").ok_or_else(|| { + TranscribeError::InitFailed( + "Model metadata has 'neg_mean' but no 'inv_stddev'".to_string(), + ) + })?; + + let neg_mean: Vec = neg_mean_str + .split(',') + .filter_map(|s: &str| s.trim().parse::().ok()) + .collect(); + let inv_stddev: Vec = inv_stddev_str + .split(',') + .filter_map(|s: &str| s.trim().parse::().ok()) + .collect(); + (neg_mean, inv_stddev) + } else { + return Err(TranscribeError::InitFailed( + "Dolphin model metadata missing CMVN stats. \ + Expected 'mean'/'invstd' or 'neg_mean'/'inv_stddev' keys." 
+ .to_string(), + )); + }; + + if neg_mean.is_empty() || inv_stddev.is_empty() { + return Err(TranscribeError::InitFailed(format!( + "CMVN stats malformed (neg_mean: {} values, inv_stddev: {} values)", + neg_mean.len(), + inv_stddev.len() + ))); + } + + tracing::debug!( + "Loaded CMVN stats: {} dimensions", + neg_mean.len() + ); + + Ok((neg_mean, inv_stddev)) +} + +/// Resolve model name to directory path +fn resolve_model_path(model: &str) -> Result { + let path = PathBuf::from(model); + if path.is_absolute() && path.exists() { + return Ok(path); + } + + let model_dir_name = if model.starts_with("dolphin-") { + model.to_string() + } else { + format!("dolphin-{}", model) + }; + + let models_dir = crate::config::Config::models_dir(); + let model_path = models_dir.join(&model_dir_name); + if model_path.exists() { + return Ok(model_path); + } + + let alt_path = models_dir.join(model); + if alt_path.exists() { + return Ok(alt_path); + } + + // Check sherpa-onnx naming convention + let sherpa_name = format!( + "sherpa-onnx-{}-ctc-multi-lang", + model_dir_name + ); + let sherpa_path = models_dir.join(&sherpa_name); + if sherpa_path.exists() { + return Ok(sherpa_path); + } + + Err(TranscribeError::ModelNotFound(format!( + "Dolphin model '{}' not found. 
Looked in:\n \ + - {}\n \ + - {}\n \ + - {}\n\n\ + Run: voxtype setup model", + model, + model_path.display(), + alt_path.display(), + sherpa_path.display(), + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_filter_language_tokens() { + assert_eq!(filter_language_tokens("你好世界"), "你好世界"); + assert_eq!(filter_language_tokens("こんにちは"), "こんにちは"); + assert_eq!(filter_language_tokens("no tags here"), "no tags here"); + assert_eq!(filter_language_tokens("你好世界"), "你好世界"); + assert_eq!(filter_language_tokens(""), ""); + } + + #[test] + fn test_resolve_model_path_not_found() { + let result = resolve_model_path("/nonexistent/path"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + TranscribeError::ModelNotFound(_) + )); + } + + #[test] + fn test_resolve_model_path_absolute() { + let temp_dir = tempfile::TempDir::new().unwrap(); + let model_path = temp_dir.path().to_path_buf(); + std::fs::write(model_path.join("model.int8.onnx"), b"dummy").unwrap(); + + let resolved = resolve_model_path(model_path.to_str().unwrap()); + assert!(resolved.is_ok()); + assert_eq!(resolved.unwrap(), model_path); + } +} diff --git a/src/transcribe/fbank.rs b/src/transcribe/fbank.rs new file mode 100644 index 00000000..625226fc --- /dev/null +++ b/src/transcribe/fbank.rs @@ -0,0 +1,373 @@ +//! Shared Fbank (log-mel filterbank) feature extraction +//! +//! Used by SenseVoice, Paraformer, and FireRedASR backends. These models share +//! identical preprocessing: 80-dim Fbank features, LFR stacking (m=7, n=6), +//! and CMVN normalization with the same constants (16kHz, 25ms/10ms frames, +//! Hamming window, 0.97 pre-emphasis). +//! +//! 
Pipeline: Audio (f32, 16kHz) -> Fbank (80-dim) -> LFR (560-dim) -> CMVN + +use ndarray::Array2; +use rustfft::num_complex::Complex; +use rustfft::FftPlanner; + +/// Default sample rate for Fbank extraction +const DEFAULT_SAMPLE_RATE: usize = 16000; + +/// Default FFT size +const DEFAULT_FFT_SIZE: usize = 512; + +/// Default number of mel filterbank channels +const DEFAULT_NUM_MELS: usize = 80; + +/// Default frame length in samples (25ms at 16kHz) +const DEFAULT_FRAME_LENGTH: usize = 400; + +/// Default frame shift in samples (10ms at 16kHz) +const DEFAULT_FRAME_SHIFT: usize = 160; + +/// Default pre-emphasis coefficient +const DEFAULT_PREEMPH_COEFF: f32 = 0.97; + +/// Default LFR window size (stack 7 consecutive frames) +const DEFAULT_LFR_M: usize = 7; + +/// Default LFR stride (advance by 6 frames) +const DEFAULT_LFR_N: usize = 6; + +/// Configuration for Fbank feature extraction +pub struct FbankConfig { + pub sample_rate: usize, + pub fft_size: usize, + pub num_mels: usize, + pub frame_length: usize, + pub frame_shift: usize, + pub preemph_coeff: f32, +} + +impl Default for FbankConfig { + fn default() -> Self { + Self { + sample_rate: DEFAULT_SAMPLE_RATE, + fft_size: DEFAULT_FFT_SIZE, + num_mels: DEFAULT_NUM_MELS, + frame_length: DEFAULT_FRAME_LENGTH, + frame_shift: DEFAULT_FRAME_SHIFT, + preemph_coeff: DEFAULT_PREEMPH_COEFF, + } + } +} + +/// Configuration for LFR (Low Frame Rate) stacking +pub struct LfrConfig { + pub m: usize, + pub n: usize, +} + +impl Default for LfrConfig { + fn default() -> Self { + Self { + m: DEFAULT_LFR_M, + n: DEFAULT_LFR_N, + } + } +} + +/// Fbank feature extractor with pre-computed mel filterbank matrix +pub struct FbankExtractor { + config: FbankConfig, + mel_filterbank: Vec>, +} + +impl FbankExtractor { + /// Create a new FbankExtractor with the given configuration + pub fn new(config: FbankConfig) -> Self { + let mel_filterbank = + compute_mel_filterbank(config.num_mels, config.fft_size, config.sample_rate as f32); + Self { + 
config, + mel_filterbank, + } + } + + /// Create a new FbankExtractor with default SenseVoice/Paraformer settings + pub fn new_default() -> Self { + Self::new(FbankConfig::default()) + } + + /// Number of mel channels in the output + pub fn num_mels(&self) -> usize { + self.config.num_mels + } + + /// Extract 80-dim log-mel filterbank features from audio samples + /// + /// Input: f32 samples at the configured sample rate (default 16kHz) + /// Output: Array2 of shape (num_frames, num_mels) + pub fn extract(&self, samples: &[f32]) -> Array2 { + let num_mels = self.config.num_mels; + let frame_length = self.config.frame_length; + let frame_shift = self.config.frame_shift; + let fft_size = self.config.fft_size; + + // Scale to int16 range (kaldi convention) + let scaled: Vec = samples.iter().map(|&s| s * 32768.0).collect(); + + // Pre-emphasis + let mut emphasized = Vec::with_capacity(scaled.len()); + emphasized.push(scaled[0]); + for i in 1..scaled.len() { + emphasized.push(scaled[i] - self.config.preemph_coeff * scaled[i - 1]); + } + + // Compute number of frames + let num_frames = if emphasized.len() >= frame_length { + (emphasized.len() - frame_length) / frame_shift + 1 + } else { + 0 + }; + + if num_frames == 0 { + return Array2::zeros((0, num_mels)); + } + + // Pre-compute Hamming window + let hamming: Vec = (0..frame_length) + .map(|n| { + 0.54 - 0.46 + * (2.0 * std::f32::consts::PI * n as f32 / (frame_length as f32 - 1.0)).cos() + }) + .collect(); + + // Set up FFT + let mut planner = FftPlanner::::new(); + let fft = planner.plan_fft_forward(fft_size); + + let mut fbank = Array2::zeros((num_frames, num_mels)); + + for frame_idx in 0..num_frames { + let start = frame_idx * frame_shift; + + // Window the frame + let mut fft_input: Vec> = Vec::with_capacity(fft_size); + for i in 0..frame_length { + fft_input.push(Complex::new(emphasized[start + i] * hamming[i], 0.0)); + } + // Zero-pad to fft_size + fft_input.resize(fft_size, Complex::new(0.0, 0.0)); + + // FFT + 
fft.process(&mut fft_input); + + // Power spectrum (only need first fft_size/2 + 1 bins) + let num_bins = fft_size / 2 + 1; + let power: Vec = fft_input[..num_bins].iter().map(|c| c.norm_sqr()).collect(); + + // Apply mel filterbank and take log + for mel_idx in 0..num_mels { + let energy: f32 = self.mel_filterbank[mel_idx] + .iter() + .zip(power.iter()) + .map(|(&w, &p)| w * p) + .sum(); + fbank[[frame_idx, mel_idx]] = energy.max(1e-10).ln(); + } + } + + fbank + } +} + +/// Apply LFR (Low Frame Rate) stacking: concatenate m frames with stride n +/// +/// Left-pads with copies of the first frame. Output dimension is num_mels * m. +/// Default: m=7 consecutive frames, stride n=6, producing 560-dim features from 80-dim Fbank. +pub fn apply_lfr(fbank: &Array2, config: &LfrConfig) -> Array2 { + let num_mels = fbank.ncols(); + let num_frames = fbank.nrows(); + if num_frames == 0 { + return Array2::zeros((0, num_mels * config.m)); + } + + // Left-pad with copies of the first frame + let pad = (config.m - 1) / 2; + let padded_len = pad + num_frames; + let output_frames = padded_len.div_ceil(config.n); + + let mut output = Array2::zeros((output_frames, num_mels * config.m)); + + for out_idx in 0..output_frames { + let center = out_idx * config.n; + for j in 0..config.m { + let padded_idx = center + j; + let frame_idx = if padded_idx < pad { + 0 + } else { + (padded_idx - pad).min(num_frames - 1) + }; + + let col_start = j * num_mels; + for k in 0..num_mels { + output[[out_idx, col_start + k]] = fbank[[frame_idx, k]]; + } + } + } + + output +} + +/// Apply CMVN (Cepstral Mean and Variance Normalization) +/// +/// Formula: normalized = (features + neg_mean) * inv_stddev +/// Applied element-wise per feature dimension. 
+pub fn apply_cmvn(features: &mut Array2, neg_mean: &[f32], inv_stddev: &[f32]) { + let feat_dim = features.ncols(); + for row in features.rows_mut() { + for (j, val) in row.into_iter().enumerate() { + if j < feat_dim && j < neg_mean.len() { + *val = (*val + neg_mean[j]) * inv_stddev[j]; + } + } + } +} + +/// Compute mel filterbank matrix +/// +/// Returns num_mels triangular filters, each with fft_size/2+1 coefficients. +/// Uses the standard mel scale: mel = 1127 * ln(1 + f/700) +pub fn compute_mel_filterbank( + num_mels: usize, + fft_size: usize, + sample_rate: f32, +) -> Vec> { + let num_bins = fft_size / 2 + 1; + let max_freq = sample_rate / 2.0; + + let hz_to_mel = |f: f32| -> f32 { 1127.0 * (1.0 + f / 700.0).ln() }; + let mel_to_hz = |m: f32| -> f32 { 700.0 * ((m / 1127.0).exp() - 1.0) }; + + let mel_low = hz_to_mel(0.0); + let mel_high = hz_to_mel(max_freq); + + // Mel center frequencies (num_mels + 2 points for triangular filters) + let mel_points: Vec = (0..num_mels + 2) + .map(|i| mel_low + (mel_high - mel_low) * i as f32 / (num_mels + 1) as f32) + .collect(); + + // Convert back to Hz and then to FFT bin indices + let bin_points: Vec = mel_points + .iter() + .map(|&m| mel_to_hz(m) * fft_size as f32 / sample_rate) + .collect(); + + // Build triangular filters + let mut filterbank = Vec::with_capacity(num_mels); + for i in 0..num_mels { + let mut filter = vec![0.0f32; num_bins]; + let left = bin_points[i]; + let center = bin_points[i + 1]; + let right = bin_points[i + 2]; + + for (j, val) in filter.iter_mut().enumerate() { + let freq = j as f32; + if freq >= left && freq < center && center > left { + *val = (freq - left) / (center - left); + } else if freq >= center && freq <= right && right > center { + *val = (right - freq) / (right - center); + } + } + filterbank.push(filter); + } + + filterbank +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mel_filterbank_shape() { + let fb = compute_mel_filterbank(80, 512, 16000.0); + 
assert_eq!(fb.len(), 80); + assert_eq!(fb[0].len(), 257); // FFT_SIZE/2 + 1 + } + + #[test] + fn test_mel_filterbank_triangular() { + let fb = compute_mel_filterbank(80, 512, 16000.0); + for filter in &fb { + for &val in filter { + assert!(val >= 0.0, "Filter values should be non-negative"); + } + } + for (i, filter) in fb.iter().enumerate() { + let sum: f32 = filter.iter().sum(); + assert!(sum > 0.0, "Filter {} should have non-zero area", i); + } + } + + #[test] + fn test_fbank_extractor_default() { + let extractor = FbankExtractor::new_default(); + assert_eq!(extractor.num_mels(), 80); + } + + #[test] + fn test_fbank_empty_audio() { + let extractor = FbankExtractor::new_default(); + // Audio shorter than one frame (400 samples at 16kHz = 25ms) + let short_audio = vec![0.0f32; 100]; + let result = extractor.extract(&short_audio); + assert_eq!(result.nrows(), 0); + } + + #[test] + fn test_fbank_one_second() { + let extractor = FbankExtractor::new_default(); + // 1 second of silence at 16kHz + let audio = vec![0.0f32; 16000]; + let result = extractor.extract(&audio); + // Expected frames: (16000 - 400) / 160 + 1 = 98 + assert_eq!(result.nrows(), 98); + assert_eq!(result.ncols(), 80); + } + + #[test] + fn test_lfr_default() { + let config = LfrConfig::default(); + assert_eq!(config.m, 7); + assert_eq!(config.n, 6); + } + + #[test] + fn test_lfr_stacking() { + let fbank = Array2::ones((100, 80)); + let config = LfrConfig::default(); + let result = apply_lfr(&fbank, &config); + // Output dim should be 80 * 7 = 560 + assert_eq!(result.ncols(), 560); + // Output frames: ceil((3 + 100) / 6) = ceil(103/6) = 18 + assert_eq!(result.nrows(), 18); + } + + #[test] + fn test_lfr_empty() { + let fbank = Array2::zeros((0, 80)); + let config = LfrConfig::default(); + let result = apply_lfr(&fbank, &config); + assert_eq!(result.nrows(), 0); + assert_eq!(result.ncols(), 560); + } + + #[test] + fn test_cmvn() { + let mut features = Array2::from_elem((2, 3), 1.0f32); + let neg_mean = 
vec![-1.0, -1.0, -1.0]; // (1.0 + (-1.0)) = 0.0 + let inv_stddev = vec![2.0, 2.0, 2.0]; // 0.0 * 2.0 = 0.0 + apply_cmvn(&mut features, &neg_mean, &inv_stddev); + for val in features.iter() { + assert!((val - 0.0).abs() < 1e-6); + } + } +} diff --git a/src/transcribe/mod.rs b/src/transcribe/mod.rs index 02977e72..445025fc 100644 --- a/src/transcribe/mod.rs +++ b/src/transcribe/mod.rs @@ -7,6 +7,10 @@ //! - Subprocess isolation for GPU memory release //! - Optionally NVIDIA Parakeet via ONNX Runtime (when `parakeet` feature is enabled) //! - Optionally Moonshine via ONNX Runtime (when `moonshine` feature is enabled) +//! - Optionally SenseVoice via ONNX Runtime (when `sensevoice` feature is enabled) +//! - Optionally Paraformer via ONNX Runtime (when `paraformer` feature is enabled) +//! - Optionally Dolphin via ONNX Runtime (when `dolphin` feature is enabled) +//! - Optionally Omnilingual via ONNX Runtime (when `omnilingual` feature is enabled) pub mod cli; pub mod remote; @@ -14,12 +18,42 @@ pub mod subprocess; pub mod whisper; pub mod worker; +/// Shared log-mel filterbank feature extraction for ONNX-based ASR engines +#[cfg(any( + feature = "sensevoice", + feature = "paraformer", + feature = "dolphin", + feature = "omnilingual", +))] +pub mod fbank; + +/// Shared CTC greedy decoder for CTC-based ASR engines +#[cfg(any( + feature = "sensevoice", + feature = "paraformer", + feature = "dolphin", + feature = "omnilingual", +))] +pub mod ctc; + #[cfg(feature = "parakeet")] pub mod parakeet; #[cfg(feature = "moonshine")] pub mod moonshine; +#[cfg(feature = "sensevoice")] +pub mod sensevoice; + +#[cfg(feature = "paraformer")] +pub mod paraformer; + +#[cfg(feature = "dolphin")] +pub mod dolphin; + +#[cfg(feature = "omnilingual")] +pub mod omnilingual; + use crate::config::{Config, TranscriptionEngine, WhisperConfig, WhisperMode}; use crate::error::TranscribeError; use crate::setup::gpu; @@ -80,6 +114,67 @@ pub fn create_transcriber(config: &Config) -> Result, Trans 
"Moonshine engine requested but voxtype was not compiled with --features moonshine" .to_string(), )), + #[cfg(feature = "sensevoice")] + TranscriptionEngine::SenseVoice => { + let sensevoice_config = config.sensevoice.as_ref().ok_or_else(|| { + TranscribeError::InitFailed( + "SenseVoice engine selected but [sensevoice] config section is missing" + .to_string(), + ) + })?; + Ok(Box::new(sensevoice::SenseVoiceTranscriber::new( + sensevoice_config, + )?)) + } + #[cfg(not(feature = "sensevoice"))] + TranscriptionEngine::SenseVoice => Err(TranscribeError::InitFailed( + "SenseVoice engine requested but voxtype was not compiled with --features sensevoice" + .to_string(), + )), + #[cfg(feature = "paraformer")] + TranscriptionEngine::Paraformer => { + let cfg = config.paraformer.as_ref().ok_or_else(|| { + TranscribeError::InitFailed( + "Paraformer engine selected but [paraformer] config section is missing" + .to_string(), + ) + })?; + Ok(Box::new(paraformer::ParaformerTranscriber::new(cfg)?)) + } + #[cfg(not(feature = "paraformer"))] + TranscriptionEngine::Paraformer => Err(TranscribeError::InitFailed( + "Paraformer engine requested but voxtype was not compiled with --features paraformer" + .to_string(), + )), + #[cfg(feature = "dolphin")] + TranscriptionEngine::Dolphin => { + let cfg = config.dolphin.as_ref().ok_or_else(|| { + TranscribeError::InitFailed( + "Dolphin engine selected but [dolphin] config section is missing".to_string(), + ) + })?; + Ok(Box::new(dolphin::DolphinTranscriber::new(cfg)?)) + } + #[cfg(not(feature = "dolphin"))] + TranscriptionEngine::Dolphin => Err(TranscribeError::InitFailed( + "Dolphin engine requested but voxtype was not compiled with --features dolphin" + .to_string(), + )), + #[cfg(feature = "omnilingual")] + TranscriptionEngine::Omnilingual => { + let cfg = config.omnilingual.as_ref().ok_or_else(|| { + TranscribeError::InitFailed( + "Omnilingual engine selected but [omnilingual] config section is missing" + .to_string(), + ) + })?; + 
Ok(Box::new(omnilingual::OmnilingualTranscriber::new(cfg)?)) + } + #[cfg(not(feature = "omnilingual"))] + TranscriptionEngine::Omnilingual => Err(TranscribeError::InitFailed( + "Omnilingual engine requested but voxtype was not compiled with --features omnilingual" + .to_string(), + )), } } diff --git a/src/transcribe/omnilingual.rs b/src/transcribe/omnilingual.rs new file mode 100644 index 00000000..142e4819 --- /dev/null +++ b/src/transcribe/omnilingual.rs @@ -0,0 +1,342 @@ +//! Omnilingual ASR transcription (Meta MMS wav2vec2) +//! +//! Uses Meta's Massively Multilingual Speech model via ONNX Runtime for local +//! transcription. Supports 1600+ languages with a single model. CTC-based, +//! character-level tokenizer with 9812 symbols. +//! +//! The model takes raw audio waveform as input (no Fbank preprocessing). +//! Audio is mean-variance normalized before inference. +//! +//! Pipeline: Audio (f32, 16kHz) -> Normalize -> ONNX model -> CTC decode +//! +//! Languages: 1600+ (language-agnostic, no language selection) +//! 
Model files: model.int8.onnx (or model.onnx), tokens.txt + +use super::ctc; +use super::Transcriber; +use crate::config::OmnilingualConfig; +use crate::error::TranscribeError; +use ort::session::Session; +use ort::value::Tensor; +use std::collections::HashMap; +use std::path::PathBuf; + +/// Sample rate expected by Omnilingual +const SAMPLE_RATE: usize = 16000; + +/// Omnilingual ASR transcriber using ONNX Runtime +pub struct OmnilingualTranscriber { + session: std::sync::Mutex, + tokens: HashMap, +} + +impl OmnilingualTranscriber { + pub fn new(config: &OmnilingualConfig) -> Result { + let model_dir = resolve_model_path(&config.model)?; + + tracing::info!("Loading Omnilingual model from {:?}", model_dir); + let start = std::time::Instant::now(); + + let threads = config.threads.unwrap_or_else(|| num_cpus::get().min(4)); + + // Find model file (prefer int8 quantized) + let model_file = { + let int8 = model_dir.join("model.int8.onnx"); + let full = model_dir.join("model.onnx"); + if int8.exists() { + int8 + } else if full.exists() { + tracing::info!("Using full-precision model (model.int8.onnx not found)"); + full + } else { + return Err(TranscribeError::ModelNotFound(format!( + "Omnilingual model not found in {:?}\n \ + Expected model.int8.onnx or model.onnx\n \ + Run: voxtype setup model", + model_dir + ))); + } + }; + + // Load tokens.txt + let tokens_path = model_dir.join("tokens.txt"); + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(format!( + "Omnilingual tokens.txt not found: {}\n \ + Ensure tokens.txt is in the model directory.", + tokens_path.display() + ))); + } + let tokens = ctc::load_tokens(&tokens_path)?; + tracing::debug!("Loaded {} tokens", tokens.len()); + + // Create ONNX session + let session = Session::builder() + .map_err(|e| { + TranscribeError::InitFailed(format!("ONNX session builder failed: {}", e)) + })? 
+ .with_intra_threads(threads) + .map_err(|e| { + TranscribeError::InitFailed(format!("Failed to set threads: {}", e)) + })? + .commit_from_file(&model_file) + .map_err(|e| { + TranscribeError::InitFailed(format!( + "Failed to load Omnilingual model from {:?}: {}", + model_file, e + )) + })?; + + tracing::info!( + "Omnilingual model loaded in {:.2}s", + start.elapsed().as_secs_f32(), + ); + + Ok(Self { + session: std::sync::Mutex::new(session), + tokens, + }) + } +} + +impl Transcriber for OmnilingualTranscriber { + fn transcribe(&self, samples: &[f32]) -> Result { + if samples.is_empty() { + return Err(TranscribeError::AudioFormat( + "Empty audio buffer".to_string(), + )); + } + + let duration_secs = samples.len() as f32 / SAMPLE_RATE as f32; + tracing::debug!( + "Transcribing {:.2}s of audio ({} samples) with Omnilingual", + duration_secs, + samples.len(), + ); + + let start = std::time::Instant::now(); + + // Apply mean-variance normalization (instance normalization) + let normalized = normalize_audio(samples); + + let num_samples = normalized.len(); + + // x: shape [1, num_samples] + let x_tensor = + Tensor::::from_array(([1usize, num_samples], normalized)).map_err(|e| { + TranscribeError::InferenceFailed(format!( + "Failed to create input tensor: {}", + e + )) + })?; + + // Run inference + let inference_start = std::time::Instant::now(); + let mut session = self.session.lock().map_err(|e| { + TranscribeError::InferenceFailed(format!("Failed to lock session: {}", e)) + })?; + + let inputs: Vec<(std::borrow::Cow, ort::session::SessionInputValue)> = + vec![(std::borrow::Cow::Borrowed("x"), x_tensor.into())]; + + let outputs = session.run(inputs).map_err(|e| { + TranscribeError::InferenceFailed(format!("Omnilingual inference failed: {}", e)) + })?; + + tracing::debug!( + "ONNX inference: {:.2}s", + inference_start.elapsed().as_secs_f32(), + ); + + // Extract CTC logits and decode + let logits_val = outputs + .get("logits") + .or_else(|| outputs.get("output")) + 
/// Instance-normalize a waveform: `(x - mean) / sqrt(variance + 1e-5)`.
///
/// The wav2vec2-based model expects zero-mean, unit-variance input. The epsilon
/// keeps the scale finite for constant (zero-variance) signals. An empty slice
/// maps to an empty vector.
fn normalize_audio(samples: &[f32]) -> Vec<f32> {
    let n = samples.len() as f32;
    let mean = samples.iter().sum::<f32>() / n;
    let variance = samples.iter().map(|&s| (s - mean).powi(2)).sum::<f32>() / n;
    // Multiply by the reciprocal once instead of dividing per sample.
    let scale = (variance + 1e-5_f32).sqrt().recip();
    samples.iter().map(|&s| (s - mean) * scale).collect()
}
Looked in:\n \ + - {}\n \ + - {}\n \ + - {}\n\n\ + Run: voxtype setup model", + model, + model_path.display(), + alt_path.display(), + sherpa_path.display(), + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_normalize_audio() { + let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let normalized = normalize_audio(&samples); + + // Mean should be ~0 after normalization + let mean: f32 = normalized.iter().sum::() / normalized.len() as f32; + assert!(mean.abs() < 1e-5); + + // Standard deviation should be ~1 + let variance: f32 = normalized + .iter() + .map(|&s| (s - mean) * (s - mean)) + .sum::() + / normalized.len() as f32; + assert!((variance - 1.0).abs() < 0.01); + } + + #[test] + fn test_normalize_audio_constant() { + // Constant signal: all same value + let samples = vec![5.0; 100]; + let normalized = normalize_audio(&samples); + // Should not produce NaN/Inf thanks to epsilon + assert!(normalized.iter().all(|x| x.is_finite())); + } + + #[test] + fn test_resolve_model_path_not_found() { + let result = resolve_model_path("/nonexistent/path"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + TranscribeError::ModelNotFound(_) + )); + } + + #[test] + fn test_resolve_model_path_absolute() { + let temp_dir = tempfile::TempDir::new().unwrap(); + let model_path = temp_dir.path().to_path_buf(); + std::fs::write(model_path.join("model.int8.onnx"), b"dummy").unwrap(); + + let resolved = resolve_model_path(model_path.to_str().unwrap()); + assert!(resolved.is_ok()); + assert_eq!(resolved.unwrap(), model_path); + } +} diff --git a/src/transcribe/paraformer.rs b/src/transcribe/paraformer.rs new file mode 100644 index 00000000..24016192 --- /dev/null +++ b/src/transcribe/paraformer.rs @@ -0,0 +1,701 @@ +//! Paraformer-based speech-to-text transcription +//! +//! Uses Alibaba's Paraformer model via ONNX Runtime for local transcription. +//! Paraformer is a non-autoregressive encoder-predictor-decoder model that +//! 
generates all output tokens in a single pass (no autoregressive loop). +//! +//! Preprocessing reuses the shared Fbank pipeline (fbank.rs) with identical +//! parameters to SenseVoice: 80-dim Fbank, LFR m=7/n=6, CMVN normalization. +//! The key difference is CMVN stats come from an am.mvn file (Kaldi binary +//! matrix format) rather than ONNX model metadata. +//! +//! Pipeline: Audio (f32, 16kHz) -> Fbank (80-dim) -> LFR (560-dim) -> CMVN -> ONNX -> token decode +//! +//! Languages: zh+en (bilingual), zh+yue+en (trilingual) +//! Model files: model.int8.onnx (or model.onnx), tokens.txt, am.mvn + +use super::fbank::{self, FbankExtractor, LfrConfig}; +use super::ctc; +use super::Transcriber; +use crate::config::ParaformerConfig; +use crate::error::TranscribeError; +use ort::session::Session; +use ort::value::Tensor; +use std::collections::HashMap; +use std::path::PathBuf; + +/// Sample rate expected by Paraformer +const SAMPLE_RATE: usize = 16000; + +/// Paraformer-based transcriber using ONNX Runtime +pub struct ParaformerTranscriber { + session: std::sync::Mutex, + tokens: HashMap, + neg_mean: Vec, + inv_stddev: Vec, + fbank_extractor: FbankExtractor, +} + +impl ParaformerTranscriber { + pub fn new(config: &ParaformerConfig) -> Result { + let model_dir = resolve_model_path(&config.model)?; + + tracing::info!("Loading Paraformer model from {:?}", model_dir); + let start = std::time::Instant::now(); + + let threads = config.threads.unwrap_or_else(|| num_cpus::get().min(4)); + + // Find model file (prefer int8 quantized) + let model_file = { + let int8 = model_dir.join("model.int8.onnx"); + let full = model_dir.join("model.onnx"); + if int8.exists() { + int8 + } else if full.exists() { + tracing::info!("Using full-precision model (model.int8.onnx not found)"); + full + } else { + return Err(TranscribeError::ModelNotFound(format!( + "Paraformer model not found in {:?}\n \ + Expected model.int8.onnx or model.onnx\n \ + Run: voxtype setup model", + model_dir + ))); + } + 
}; + + // Load tokens.txt + let tokens_path = model_dir.join("tokens.txt"); + if !tokens_path.exists() { + return Err(TranscribeError::ModelNotFound(format!( + "Paraformer tokens.txt not found: {}\n \ + Ensure tokens.txt is in the model directory.", + tokens_path.display() + ))); + } + let tokens = ctc::load_tokens(&tokens_path)?; + tracing::debug!("Loaded {} tokens", tokens.len()); + + // Create ONNX session + let session = Session::builder() + .map_err(|e| { + TranscribeError::InitFailed(format!("ONNX session builder failed: {}", e)) + })? + .with_intra_threads(threads) + .map_err(|e| { + TranscribeError::InitFailed(format!("Failed to set threads: {}", e)) + })? + .commit_from_file(&model_file) + .map_err(|e| { + TranscribeError::InitFailed(format!( + "Failed to load Paraformer model from {:?}: {}", + model_file, e + )) + })?; + + // Read CMVN stats from am.mvn (Kaldi binary matrix) + let mvn_path = model_dir.join("am.mvn"); + let (neg_mean, inv_stddev) = if mvn_path.exists() { + read_cmvn_from_kaldi_mvn(&mvn_path)? + } else { + // Fall back to ONNX model metadata (like SenseVoice) + tracing::info!("am.mvn not found, trying ONNX model metadata for CMVN"); + read_cmvn_from_metadata(&session)? + }; + + let fbank_extractor = FbankExtractor::new_default(); + + tracing::info!( + "Paraformer model loaded in {:.2}s", + start.elapsed().as_secs_f32(), + ); + + Ok(Self { + session: std::sync::Mutex::new(session), + tokens, + neg_mean, + inv_stddev, + fbank_extractor, + }) + } +} + +impl Transcriber for ParaformerTranscriber { + fn transcribe(&self, samples: &[f32]) -> Result { + if samples.is_empty() { + return Err(TranscribeError::AudioFormat( + "Empty audio buffer".to_string(), + )); + } + + let duration_secs = samples.len() as f32 / SAMPLE_RATE as f32; + tracing::debug!( + "Transcribing {:.2}s of audio ({} samples) with Paraformer", + duration_secs, + samples.len(), + ); + + let start = std::time::Instant::now(); + + // 1. 
Extract Fbank features (shared pipeline, identical to SenseVoice) + let fbank_start = std::time::Instant::now(); + let fbank_features = self.fbank_extractor.extract(samples); + tracing::debug!( + "Fbank extraction: {:.2}s ({} frames x {})", + fbank_start.elapsed().as_secs_f32(), + fbank_features.nrows(), + fbank_features.ncols(), + ); + + if fbank_features.nrows() == 0 { + return Err(TranscribeError::AudioFormat( + "Audio too short for feature extraction".to_string(), + )); + } + + // 2. LFR stacking (shared, same m=7/n=6 as SenseVoice) + let lfr = fbank::apply_lfr(&fbank_features, &LfrConfig::default()); + tracing::debug!("LFR output: {} frames x {}", lfr.nrows(), lfr.ncols()); + + // 3. CMVN normalization (shared, stats from am.mvn) + let mut features = lfr; + fbank::apply_cmvn(&mut features, &self.neg_mean, &self.inv_stddev); + + // 4. Build ONNX inputs + let num_frames = features.nrows(); + let feat_dim = features.ncols(); + + // speech: shape [1, T, 560] + let (x_data, _offset) = features.into_raw_vec_and_offset(); + let speech_tensor = Tensor::::from_array(([1usize, num_frames, feat_dim], x_data)) + .map_err(|e| { + TranscribeError::InferenceFailed(format!( + "Failed to create speech tensor: {}", + e + )) + })?; + + // speech_lengths: shape [1] + let lengths_tensor = Tensor::::from_array(([1usize], vec![num_frames as i32])) + .map_err(|e| { + TranscribeError::InferenceFailed(format!( + "Failed to create lengths tensor: {}", + e + )) + })?; + + // 5. 
Run inference + let inference_start = std::time::Instant::now(); + let mut session = self.session.lock().map_err(|e| { + TranscribeError::InferenceFailed(format!("Failed to lock session: {}", e)) + })?; + + let inputs: Vec<(std::borrow::Cow, ort::session::SessionInputValue)> = vec![ + (std::borrow::Cow::Borrowed("speech"), speech_tensor.into()), + ( + std::borrow::Cow::Borrowed("speech_lengths"), + lengths_tensor.into(), + ), + ]; + + let outputs = session.run(inputs).map_err(|e| { + TranscribeError::InferenceFailed(format!("Paraformer inference failed: {}", e)) + })?; + + tracing::debug!( + "ONNX inference: {:.2}s", + inference_start.elapsed().as_secs_f32(), + ); + + // 6. Extract output and decode tokens + // Paraformer outputs token IDs directly (not CTC logits) + let result = decode_paraformer_output(&outputs, &self.tokens)?; + + tracing::info!( + "Paraformer transcription completed in {:.2}s: {:?}", + start.elapsed().as_secs_f32(), + if result.chars().count() > 50 { + format!("{}...", result.chars().take(50).collect::()) + } else { + result.clone() + } + ); + + Ok(result) + } +} + +/// Decode Paraformer ONNX output to text +/// +/// Paraformer outputs token IDs directly from its CIF+decoder pipeline. 
+/// The output may be named "logits" and shaped as either: +/// - [batch, seq_len] with i64 token IDs +/// - [batch, seq_len] with f32 token IDs (pre-argmaxed) +/// - [batch, seq_len, vocab_size] with f32 logits (needs argmax) +fn decode_paraformer_output( + outputs: &ort::session::SessionOutputs, + tokens: &HashMap, +) -> Result { + // Try to find the output tensor - name varies across model exports + let output_val = outputs + .get("logits") + .or_else(|| outputs.get("output")) + .ok_or_else(|| { + TranscribeError::InferenceFailed( + "Paraformer model output not found (expected 'logits' or 'output')".to_string(), + ) + })?; + + // Try extracting as i64 first (direct token IDs) + if let Ok((shape, data)) = output_val.try_extract_tensor::() { + tracing::debug!("Paraformer output shape (i64): {:?}", &*shape); + let token_ids: Vec = data.iter().map(|&id| id as u32).collect(); + return Ok(tokens_to_text(&token_ids, tokens)); + } + + // Try extracting as f32 + let (shape, data) = output_val.try_extract_tensor::().map_err(|e| { + TranscribeError::InferenceFailed(format!("Failed to extract Paraformer output: {}", e)) + })?; + + let shape_dims: &[i64] = shape; + tracing::debug!("Paraformer output shape (f32): {:?}", shape_dims); + + if shape_dims.len() == 3 { + // [batch, seq_len, vocab_size] - needs argmax then BPE-aware decoding + let seq_len = shape_dims[1] as usize; + let vocab_size = shape_dims[2] as usize; + // Do CTC argmax + dedup + blank removal to get token IDs, + // then pass through tokens_to_text which handles @@ BPE markers + let token_ids = ctc_decode_to_ids(data, seq_len, vocab_size); + Ok(tokens_to_text(&token_ids, tokens)) + } else if shape_dims.len() == 2 { + // [batch, seq_len] - pre-argmaxed token IDs as f32 + let seq_len = shape_dims[1] as usize; + let token_ids: Vec = data[..seq_len] + .iter() + .map(|&v| v as u32) + .collect(); + Ok(tokens_to_text(&token_ids, tokens)) + } else { + Err(TranscribeError::InferenceFailed(format!( + "Unexpected 
/// Greedy CTC collapse: per-frame argmax, merge consecutive repeats, drop blanks.
///
/// `logits` is row-major `[time_steps, vocab_size]`; token id 0 is the CTC blank.
fn ctc_decode_to_ids(logits: &[f32], time_steps: usize, vocab_size: usize) -> Vec<u32> {
    const BLANK_ID: u32 = 0;

    let mut decoded = Vec::new();
    // Winner of the previous frame (blank included) — CTC only emits a token
    // when it differs from the previous frame's winner.
    let mut prev_winner: Option<u32> = None;

    for frame in logits.chunks_exact(vocab_size).take(time_steps) {
        // Argmax over the vocabulary; `>=` keeps the last of tied maxima.
        let mut best = (BLANK_ID, f32::NEG_INFINITY);
        for (id, &score) in frame.iter().enumerate() {
            if score >= best.1 {
                best = (id as u32, score);
            }
        }
        let winner = best.0;
        if winner != BLANK_ID && prev_winner != Some(winner) {
            decoded.push(winner);
        }
        prev_winner = Some(winner);
    }

    decoded
}
/// Render decoded token ids as text.
///
/// Conventions handled: tokens wrapped in `<...>` are control symbols and are
/// dropped; a trailing `@@` marks a BPE continuation piece and is stripped;
/// SentencePiece `▁` (U+2581) markers become spaces. Unknown ids are skipped.
///
/// NOTE(review): pieces are joined without inserting a space after a non-`@@`
/// token, so "hel@@ lo wor@@ ld" decodes to "helloworld" — this matches the
/// existing unit test, but confirm against upstream FunASR decoding if English
/// output looks unspaced.
fn tokens_to_text(token_ids: &[u32], tokens: &HashMap<u32, String>) -> String {
    let mut text = String::new();

    for piece in token_ids.iter().filter_map(|id| tokens.get(id)) {
        // Special tokens like <s>, </s>, <blank>, <unk> carry no text.
        if piece.starts_with('<') && piece.ends_with('>') {
            continue;
        }

        match piece.strip_suffix("@@") {
            // BPE continuation: glue the stem to whatever came before.
            Some(stem) => text.push_str(stem),
            // Plain piece: expand SentencePiece word markers to spaces.
            None => text.push_str(&piece.replace('\u{2581}', " ")),
        }
    }

    text.trim().to_string()
}
dimensions: \4 \4 + if pos >= data.len() || data[pos] != 4 { + return Err(TranscribeError::InitFailed( + "am.mvn: expected \\4 before rows".to_string(), + )); + } + pos += 1; + if pos + 4 > data.len() { + return Err(TranscribeError::InitFailed( + "am.mvn: truncated rows value".to_string(), + )); + } + let rows = i32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize; + pos += 4; + + if pos >= data.len() || data[pos] != 4 { + return Err(TranscribeError::InitFailed( + "am.mvn: expected \\4 before cols".to_string(), + )); + } + pos += 1; + if pos + 4 > data.len() { + return Err(TranscribeError::InitFailed( + "am.mvn: truncated cols value".to_string(), + )); + } + let cols = i32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize; + pos += 4; + + if rows != 2 { + return Err(TranscribeError::InitFailed(format!( + "am.mvn: expected 2 rows, got {}", + rows + ))); + } + + tracing::debug!("am.mvn: {} rows x {} cols, double={}", rows, cols, is_double); + + // Read matrix data + let feat_dim = cols - 1; // last column is the count + let matrix: Vec> = if is_double { + let elem_size = 8; + let total = rows * cols * elem_size; + if pos + total > data.len() { + return Err(TranscribeError::InitFailed( + "am.mvn: truncated matrix data".to_string(), + )); + } + (0..rows) + .map(|r| { + (0..cols) + .map(|c| { + let offset = pos + (r * cols + c) * elem_size; + f64::from_le_bytes(data[offset..offset + elem_size].try_into().unwrap()) + }) + .collect() + }) + .collect() + } else { + let elem_size = 4; + let total = rows * cols * elem_size; + if pos + total > data.len() { + return Err(TranscribeError::InitFailed( + "am.mvn: truncated matrix data".to_string(), + )); + } + (0..rows) + .map(|r| { + (0..cols) + .map(|c| { + let offset = pos + (r * cols + c) * elem_size; + f32::from_le_bytes(data[offset..offset + elem_size].try_into().unwrap()) + as f64 + }) + .collect() + }) + .collect() + }; + + // Extract mean and variance from accumulated stats + let count = 
matrix[0][feat_dim]; // frame count is last element of row 0 + if count <= 0.0 { + return Err(TranscribeError::InitFailed( + "am.mvn: zero frame count".to_string(), + )); + } + + let mut neg_mean = Vec::with_capacity(feat_dim); + let mut inv_stddev = Vec::with_capacity(feat_dim); + + for i in 0..feat_dim { + let mean = matrix[0][i] / count; + let variance = (matrix[1][i] / count) - (mean * mean); + let stddev = variance.max(1e-20).sqrt(); + neg_mean.push(-mean as f32); + inv_stddev.push((1.0 / stddev) as f32); + } + + tracing::debug!( + "CMVN stats loaded from am.mvn: {} dimensions, {:.0} frames", + feat_dim, + count + ); + + Ok((neg_mean, inv_stddev)) +} + +/// Read CMVN stats from ONNX model metadata (fallback if no am.mvn) +fn read_cmvn_from_metadata(session: &Session) -> Result<(Vec, Vec), TranscribeError> { + let metadata = session.metadata().map_err(|e| { + TranscribeError::InitFailed(format!("Failed to read model metadata: {}", e)) + })?; + + let neg_mean_str = metadata.custom("neg_mean").ok_or_else(|| { + TranscribeError::InitFailed( + "Model has no am.mvn file and no CMVN metadata. \ + Ensure am.mvn is in the model directory." 
+ .to_string(), + ) + })?; + + let inv_stddev_str = metadata.custom("inv_stddev").ok_or_else(|| { + TranscribeError::InitFailed( + "Model metadata missing 'inv_stddev' key".to_string(), + ) + })?; + + let neg_mean: Vec = neg_mean_str + .split(',') + .filter_map(|s: &str| s.trim().parse::().ok()) + .collect(); + + let inv_stddev: Vec = inv_stddev_str + .split(',') + .filter_map(|s: &str| s.trim().parse::().ok()) + .collect(); + + if neg_mean.is_empty() || inv_stddev.is_empty() { + return Err(TranscribeError::InitFailed(format!( + "CMVN stats malformed (neg_mean: {} values, inv_stddev: {} values)", + neg_mean.len(), + inv_stddev.len() + ))); + } + + Ok((neg_mean, inv_stddev)) +} + +/// Resolve model name to directory path +fn resolve_model_path(model: &str) -> Result { + let path = PathBuf::from(model); + if path.is_absolute() && path.exists() { + return Ok(path); + } + + // Map short names to directory names + let model_dir_name = if model.starts_with("paraformer-") { + model.to_string() + } else { + format!("paraformer-{}", model) + }; + + let models_dir = crate::config::Config::models_dir(); + let model_path = models_dir.join(&model_dir_name); + + if model_path.exists() { + return Ok(model_path); + } + + // Check without prefix + let alt_path = models_dir.join(model); + if alt_path.exists() { + return Ok(alt_path); + } + + // Check sherpa-onnx naming convention + let sherpa_name = format!("sherpa-onnx-{}", model_dir_name); + let sherpa_path = models_dir.join(&sherpa_name); + if sherpa_path.exists() { + return Ok(sherpa_path); + } + + Err(TranscribeError::ModelNotFound(format!( + "Paraformer model '{}' not found. 
Looked in:\n \ + - {}\n \ + - {}\n \ + - {}\n\n\ + Run: voxtype setup model", + model, + model_path.display(), + alt_path.display(), + sherpa_path.display(), + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[test] + fn test_resolve_model_path_not_found() { + let result = resolve_model_path("/nonexistent/path/to/model"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + TranscribeError::ModelNotFound(_) + )); + } + + #[test] + fn test_resolve_model_path_absolute() { + let temp_dir = TempDir::new().unwrap(); + let model_path = temp_dir.path().to_path_buf(); + fs::write(model_path.join("model.int8.onnx"), b"dummy").unwrap(); + + let resolved = resolve_model_path(model_path.to_str().unwrap()); + assert!(resolved.is_ok()); + assert_eq!(resolved.unwrap(), model_path); + } + + #[test] + fn test_tokens_to_text_chinese() { + let mut tokens = HashMap::new(); + tokens.insert(0, "".to_string()); + tokens.insert(1, "".to_string()); + tokens.insert(2, "".to_string()); + tokens.insert(10, "你".to_string()); + tokens.insert(11, "好".to_string()); + tokens.insert(12, "世".to_string()); + tokens.insert(13, "界".to_string()); + + let ids = vec![1, 10, 11, 12, 13, 2]; + let result = tokens_to_text(&ids, &tokens); + assert_eq!(result, "你好世界"); + } + + #[test] + fn test_tokens_to_text_bpe() { + let mut tokens = HashMap::new(); + tokens.insert(0, "".to_string()); + tokens.insert(1, "".to_string()); + tokens.insert(2, "".to_string()); + tokens.insert(10, "hel@@".to_string()); + tokens.insert(11, "lo".to_string()); + tokens.insert(12, "wor@@".to_string()); + tokens.insert(13, "ld".to_string()); + + let ids = vec![1, 10, 11, 12, 13, 2]; + let result = tokens_to_text(&ids, &tokens); + assert_eq!(result, "helloworld"); + } + + #[test] + fn test_tokens_to_text_sentencepiece() { + let mut tokens = HashMap::new(); + tokens.insert(10, "\u{2581}hello".to_string()); + tokens.insert(11, "\u{2581}world".to_string()); + + let ids = 
vec![10, 11]; + let result = tokens_to_text(&ids, &tokens); + assert_eq!(result, "hello world"); + } + + #[test] + fn test_read_cmvn_kaldi_float() { + let temp_dir = TempDir::new().unwrap(); + let mvn_path = temp_dir.path().join("am.mvn"); + + // Build a minimal Kaldi binary float matrix: 2 rows x 4 cols + // Row 0: [10.0, 20.0, 30.0, 5.0] (sums + count=5) + // Row 1: [30.0, 100.0, 200.0, 0.0] (sum of squares) + let mut data: Vec = Vec::new(); + data.push(0x00); // binary marker + data.push(b'B'); + data.push(b'F'); // float matrix + data.push(b'M'); + data.push(b' '); + data.push(4); // \4 before rows + data.extend_from_slice(&2i32.to_le_bytes()); + data.push(4); // \4 before cols + data.extend_from_slice(&4i32.to_le_bytes()); + // Row 0: sums + for v in &[10.0f32, 20.0, 30.0, 5.0] { + data.extend_from_slice(&v.to_le_bytes()); + } + // Row 1: sum of squares + for v in &[30.0f32, 100.0, 200.0, 0.0] { + data.extend_from_slice(&v.to_le_bytes()); + } + + fs::write(&mvn_path, &data).unwrap(); + + let (neg_mean, inv_stddev) = read_cmvn_from_kaldi_mvn(&mvn_path).unwrap(); + assert_eq!(neg_mean.len(), 3); + assert_eq!(inv_stddev.len(), 3); + + // mean[0] = 10/5 = 2.0, neg_mean[0] = -2.0 + assert!((neg_mean[0] - (-2.0)).abs() < 1e-5); + // mean[1] = 20/5 = 4.0, neg_mean[1] = -4.0 + assert!((neg_mean[1] - (-4.0)).abs() < 1e-5); + // variance[0] = 30/5 - 4 = 2.0, stddev = sqrt(2), inv = 1/sqrt(2) + assert!((inv_stddev[0] - (1.0 / 2.0f32.sqrt())).abs() < 1e-5); + } +} diff --git a/src/transcribe/sensevoice.rs b/src/transcribe/sensevoice.rs new file mode 100644 index 00000000..60c76e7a --- /dev/null +++ b/src/transcribe/sensevoice.rs @@ -0,0 +1,443 @@ +//! SenseVoice-based speech-to-text transcription +//! +//! Uses Alibaba's SenseVoice model via ONNX Runtime for local transcription. +//! SenseVoice is an encoder-only CTC model (no autoregressive decoder loop), +//! making inference a single forward pass. Preprocessing uses the shared Fbank +//! 
pipeline (fbank.rs) and CTC decoding uses the shared decoder (ctc.rs). +//! +//! Pipeline: Audio (f32, 16kHz) -> Fbank (80-dim) -> LFR (560-dim) -> CMVN -> ONNX -> CTC decode +//! +//! Supports languages: auto, zh, en, ja, ko, yue +//! Model files: model.int8.onnx (or model.onnx), tokens.txt + +use super::fbank::{self, FbankExtractor, LfrConfig}; +use super::ctc::{self, CtcConfig}; +use super::Transcriber; +use crate::config::SenseVoiceConfig; +use crate::error::TranscribeError; +use ort::session::Session; +use ort::value::Tensor; +use std::collections::HashMap; +use std::path::PathBuf; + +/// Sample rate expected by SenseVoice +const SAMPLE_RATE: usize = 16000; + +/// SenseVoice-based transcriber using ONNX Runtime +pub struct SenseVoiceTranscriber { + session: std::sync::Mutex, + tokens: HashMap, + neg_mean: Vec, + inv_stddev: Vec, + language_id: i32, + text_norm_id: i32, + fbank_extractor: FbankExtractor, + ctc_config: CtcConfig, +} + +impl SenseVoiceTranscriber { + pub fn new(config: &SenseVoiceConfig) -> Result { + let model_dir = resolve_model_path(&config.model)?; + + tracing::info!("Loading SenseVoice model from {:?}", model_dir); + let start = std::time::Instant::now(); + + let threads = config.threads.unwrap_or_else(|| num_cpus::get().min(4)); + + // Find model file (prefer int8 quantized) + let model_file = { + let int8 = model_dir.join("model.int8.onnx"); + let full = model_dir.join("model.onnx"); + if int8.exists() { + int8 + } else if full.exists() { + tracing::info!("Using full-precision model (model.int8.onnx not found)"); + full + } else { + return Err(TranscribeError::ModelNotFound(format!( + "SenseVoice model not found in {:?}\n \ + Expected model.int8.onnx or model.onnx\n \ + Download from: https://huggingface.co/csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17", + model_dir + ))); + } + }; + + // Load tokens.txt + let tokens_path = model_dir.join("tokens.txt"); + if !tokens_path.exists() { + return 
Err(TranscribeError::ModelNotFound(format!( + "SenseVoice tokens.txt not found: {}\n \ + Ensure tokens.txt is in the model directory.", + tokens_path.display() + ))); + } + let tokens = ctc::load_tokens(&tokens_path)?; + tracing::debug!("Loaded {} tokens", tokens.len()); + + // Create ONNX session + let session = Session::builder() + .map_err(|e| { + TranscribeError::InitFailed(format!("ONNX session builder failed: {}", e)) + })? + .with_intra_threads(threads) + .map_err(|e| { + TranscribeError::InitFailed(format!("Failed to set threads: {}", e)) + })? + .commit_from_file(&model_file) + .map_err(|e| { + TranscribeError::InitFailed(format!( + "Failed to load SenseVoice model from {:?}: {}", + model_file, e + )) + })?; + + // Read CMVN stats from model metadata + let (neg_mean, inv_stddev) = read_cmvn_from_metadata(&session)?; + + // Map language config to ID + let language_id = language_to_id(&config.language); + let text_norm_id = if config.use_itn { 14 } else { 15 }; + + // Create shared Fbank extractor with default settings + let fbank_extractor = FbankExtractor::new_default(); + + tracing::info!( + "SenseVoice model loaded in {:.2}s (language={}, use_itn={})", + start.elapsed().as_secs_f32(), + config.language, + config.use_itn, + ); + + Ok(Self { + session: std::sync::Mutex::new(session), + tokens, + neg_mean, + inv_stddev, + language_id, + text_norm_id, + fbank_extractor, + ctc_config: CtcConfig::sensevoice(), + }) + } +} + +impl Transcriber for SenseVoiceTranscriber { + fn transcribe(&self, samples: &[f32]) -> Result { + if samples.is_empty() { + return Err(TranscribeError::AudioFormat( + "Empty audio buffer".to_string(), + )); + } + + let duration_secs = samples.len() as f32 / SAMPLE_RATE as f32; + tracing::debug!( + "Transcribing {:.2}s of audio ({} samples) with SenseVoice", + duration_secs, + samples.len(), + ); + + let start = std::time::Instant::now(); + + // 1. 
Extract Fbank features (shared pipeline) + let fbank_start = std::time::Instant::now(); + let fbank_features = self.fbank_extractor.extract(samples); + tracing::debug!( + "Fbank extraction: {:.2}s ({} frames x {})", + fbank_start.elapsed().as_secs_f32(), + fbank_features.nrows(), + fbank_features.ncols(), + ); + + if fbank_features.nrows() == 0 { + return Err(TranscribeError::AudioFormat( + "Audio too short for feature extraction".to_string(), + )); + } + + // 2. LFR stacking (shared) + let lfr = fbank::apply_lfr(&fbank_features, &LfrConfig::default()); + tracing::debug!("LFR output: {} frames x {}", lfr.nrows(), lfr.ncols()); + + // 3. CMVN normalization (shared) + let mut features = lfr; + fbank::apply_cmvn(&mut features, &self.neg_mean, &self.inv_stddev); + + // 4. Build ONNX inputs + let num_frames = features.nrows(); + let feat_dim = features.ncols(); + + // x: shape [1, T, 560] + let (x_data, _offset) = features.into_raw_vec_and_offset(); + let x_tensor = Tensor::<f32>::from_array(([1usize, num_frames, feat_dim], x_data)) + .map_err(|e| { + TranscribeError::InferenceFailed(format!( + "Failed to create input tensor: {}", + e + )) + })?; + + // x_length: shape [1] + let x_length_tensor = Tensor::<i32>::from_array(([1usize], vec![num_frames as i32])) + .map_err(|e| { + TranscribeError::InferenceFailed(format!( + "Failed to create length tensor: {}", + e + )) + })?; + + // language: shape [1] + let language_tensor = Tensor::<i32>::from_array(([1usize], vec![self.language_id])) + .map_err(|e| { + TranscribeError::InferenceFailed(format!( + "Failed to create language tensor: {}", + e + )) + })?; + + // text_norm: shape [1] + let text_norm_tensor = Tensor::<i32>::from_array(([1usize], vec![self.text_norm_id])) + .map_err(|e| { + TranscribeError::InferenceFailed(format!( + "Failed to create text_norm tensor: {}", + e + )) + })?; + + // 5. 
Run inference + let inference_start = std::time::Instant::now(); + let mut session = self.session.lock().map_err(|e| { + TranscribeError::InferenceFailed(format!("Failed to lock session: {}", e)) + })?; + + let inputs: Vec<(std::borrow::Cow, ort::session::SessionInputValue)> = vec![ + (std::borrow::Cow::Borrowed("x"), x_tensor.into()), + (std::borrow::Cow::Borrowed("x_length"), x_length_tensor.into()), + (std::borrow::Cow::Borrowed("language"), language_tensor.into()), + (std::borrow::Cow::Borrowed("text_norm"), text_norm_tensor.into()), + ]; + + let outputs = session.run(inputs).map_err(|e| { + TranscribeError::InferenceFailed(format!("SenseVoice inference failed: {}", e)) + })?; + + tracing::debug!( + "ONNX inference: {:.2}s", + inference_start.elapsed().as_secs_f32(), + ); + + // 6. Extract logits and decode + let logits_val = &outputs["logits"]; + let (shape, logits_data) = logits_val.try_extract_tensor::().map_err(|e| { + TranscribeError::InferenceFailed(format!("Failed to extract logits: {}", e)) + })?; + + let shape_dims: &[i64] = shape; + tracing::debug!("Logits shape: {:?}", shape_dims); + + // logits shape: [batch=1, time_steps] or [batch=1, time_steps, vocab_size] + let result = if shape_dims.len() == 3 { + let time_steps = shape_dims[1] as usize; + let vocab_size = shape_dims[2] as usize; + ctc::ctc_greedy_decode( + logits_data, + time_steps, + vocab_size, + &self.tokens, + &self.ctc_config, + ) + } else if shape_dims.len() == 2 { + // Pre-argmaxed output: each value is already a token ID + let time_steps = shape_dims[1] as usize; + ctc::decode_pre_argmax( + &logits_data[..time_steps], + &self.tokens, + &self.ctc_config, + ) + } else { + return Err(TranscribeError::InferenceFailed(format!( + "Unexpected logits shape: {:?}", + shape_dims + ))); + }; + + tracing::info!( + "SenseVoice transcription completed in {:.2}s: {:?}", + start.elapsed().as_secs_f32(), + if result.chars().count() > 50 { + format!("{}...", result.chars().take(50).collect::()) + } else 
{ + result.clone() + } + ); + + Ok(result) + } +} + +/// Map language string to SenseVoice language ID +fn language_to_id(language: &str) -> i32 { + match language.to_lowercase().as_str() { + "auto" => 0, + "zh" | "chinese" => 3, + "en" | "english" => 4, + "yue" | "cantonese" => 7, + "ja" | "japanese" => 11, + "ko" | "korean" => 12, + _ => { + tracing::warn!( + "Unknown SenseVoice language '{}', falling back to auto-detect", + language + ); + 0 + } + } +} + +/// Read CMVN stats (neg_mean and inv_stddev) from ONNX model metadata +/// +/// The sherpa-onnx SenseVoice model stores these as comma-separated floats +/// in metadata keys "neg_mean" and "inv_stddev". This is SenseVoice-specific; +/// Paraformer reads CMVN from a separate am.mvn file. +fn read_cmvn_from_metadata(session: &Session) -> Result<(Vec, Vec), TranscribeError> { + let metadata = session.metadata().map_err(|e| { + TranscribeError::InitFailed(format!("Failed to read model metadata: {}", e)) + })?; + + let neg_mean_str = metadata.custom("neg_mean").ok_or_else(|| { + TranscribeError::InitFailed( + "Model metadata missing 'neg_mean' key. Is this a sherpa-onnx SenseVoice model?" + .to_string(), + ) + })?; + + let inv_stddev_str = metadata.custom("inv_stddev").ok_or_else(|| { + TranscribeError::InitFailed( + "Model metadata missing 'inv_stddev' key. Is this a sherpa-onnx SenseVoice model?" 
+ .to_string(), + ) + })?; + + let neg_mean: Vec<f32> = neg_mean_str + .split(',') + .filter_map(|s: &str| s.trim().parse::<f32>().ok()) + .collect(); + + let inv_stddev: Vec<f32> = inv_stddev_str + .split(',') + .filter_map(|s: &str| s.trim().parse::<f32>().ok()) + .collect(); + + if neg_mean.is_empty() || inv_stddev.is_empty() { + return Err(TranscribeError::InitFailed(format!( + "CMVN stats appear malformed (neg_mean: {} values, inv_stddev: {} values)", + neg_mean.len(), + inv_stddev.len() + ))); + } + + tracing::debug!( + "CMVN stats loaded: neg_mean[{}], inv_stddev[{}]", + neg_mean.len(), + inv_stddev.len() + ); + + Ok((neg_mean, inv_stddev)) +} + +/// Resolve model name to directory path +fn resolve_model_path(model: &str) -> Result<PathBuf, TranscribeError> { + // If it's already an absolute path, use it directly + let path = PathBuf::from(model); + if path.is_absolute() && path.exists() { + return Ok(path); + } + + // Map short names to directory names + let model_dir_name = if model.starts_with("sensevoice-") { + model.to_string() + } else { + format!("sensevoice-{}", model) + }; + + // Check models directory + let models_dir = crate::config::Config::models_dir(); + let model_path = models_dir.join(&model_dir_name); + + if model_path.exists() { + return Ok(model_path); + } + + // Also check without prefix (user might pass "sensevoice-small" or just "small") + let alt_path = models_dir.join(model); + if alt_path.exists() { + return Ok(alt_path); + } + + // Check current directory + let cwd_path = PathBuf::from(&model_dir_name); + if cwd_path.exists() { + return Ok(cwd_path); + } + + // Check ./models/ + let local_models_path = PathBuf::from("models").join(&model_dir_name); + if local_models_path.exists() { + return Ok(local_models_path); + } + + Err(TranscribeError::ModelNotFound(format!( + "SenseVoice model '{}' not found. 
Looked in:\n \ + - {}\n \ + - {}\n \ + - {}\n \ + - {}\n\n\ + Manual download:\n \ + mkdir -p {}\n \ + cd {} && wget https://huggingface.co/csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/resolve/main/model.int8.onnx\n \ + cd {} && wget https://huggingface.co/csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/resolve/main/tokens.txt", + model, + model_path.display(), + alt_path.display(), + cwd_path.display(), + local_models_path.display(), + model_path.display(), + model_path.display(), + model_path.display(), + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_language_to_id() { + assert_eq!(language_to_id("auto"), 0); + assert_eq!(language_to_id("zh"), 3); + assert_eq!(language_to_id("en"), 4); + assert_eq!(language_to_id("yue"), 7); + assert_eq!(language_to_id("ja"), 11); + assert_eq!(language_to_id("ko"), 12); + assert_eq!(language_to_id("unknown"), 0); // falls back to auto + } + + #[test] + fn test_resolve_model_path_not_found() { + let result = resolve_model_path("/nonexistent/path/to/model"); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, TranscribeError::ModelNotFound(_))); + } + + #[test] + fn test_resolve_model_path_absolute() { + let temp_dir = tempfile::TempDir::new().unwrap(); + let model_path = temp_dir.path().to_path_buf(); + std::fs::write(model_path.join("model.int8.onnx"), b"dummy").unwrap(); + + let resolved = resolve_model_path(model_path.to_str().unwrap()); + assert!(resolved.is_ok()); + assert_eq!(resolved.unwrap(), model_path); + } +} diff --git a/src/transcribe/subprocess.rs b/src/transcribe/subprocess.rs index ebf92d44..5edbd529 100644 --- a/src/transcribe/subprocess.rs +++ b/src/transcribe/subprocess.rs @@ -175,7 +175,7 @@ impl SubprocessTranscriber { let samples_bytes = unsafe { std::slice::from_raw_parts( samples.as_ptr() as *const u8, - samples.len() * std::mem::size_of::(), + std::mem::size_of_val(samples), ) }; 
stdin.write_all(samples_bytes).map_err(|e| { diff --git a/src/vad/mod.rs b/src/vad/mod.rs index d06be276..0e016bcc 100644 --- a/src/vad/mod.rs +++ b/src/vad/mod.rs @@ -57,7 +57,12 @@ pub fn create_vad(config: &Config) -> Result VadBackend::Whisper, - TranscriptionEngine::Parakeet | TranscriptionEngine::Moonshine => VadBackend::Energy, + TranscriptionEngine::Parakeet + | TranscriptionEngine::Moonshine + | TranscriptionEngine::SenseVoice + | TranscriptionEngine::Paraformer + | TranscriptionEngine::Dolphin + | TranscriptionEngine::Omnilingual => VadBackend::Energy, } } explicit => explicit, diff --git a/tests/fixtures/sensevoice/README.md b/tests/fixtures/sensevoice/README.md new file mode 100644 index 00000000..833ca3c1 --- /dev/null +++ b/tests/fixtures/sensevoice/README.md @@ -0,0 +1,29 @@ +# SenseVoice Test Audio Files + +Test WAV files for validating SenseVoice (and Sherpa) CJK transcription. + +All files are 16-bit PCM, mono, 16kHz. Source: +[sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17](https://huggingface.co/csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tree/main/test_wavs) + +## Files + +| File | Language | Duration | Reference transcription | +|------|----------|----------|------------------------| +| zh.wav | Chinese (Mandarin) | 5.5s | 开放时间早上9点至下午5点。 | +| ja.wav | Japanese | 7.2s | (verify against model) | +| ko.wav | Korean | 4.6s | (verify against model) | +| yue.wav | Cantonese | 5.1s | (verify against model) | + +The Chinese reference comes from the sherpa-onnx documentation. Japanese, Korean, +and Cantonese references are not documented upstream; run through the model and +compare against sherpa-onnx output to establish baselines. 
+ +## Usage + +```bash +# Test with SenseVoice +voxtype transcribe tests/fixtures/sensevoice/zh.wav --engine sensevoice + +# Compare against Whisper for the same file +voxtype transcribe tests/fixtures/sensevoice/zh.wav --engine whisper +``` diff --git a/tests/fixtures/sensevoice/ja.wav b/tests/fixtures/sensevoice/ja.wav new file mode 100644 index 00000000..2ca0e04b Binary files /dev/null and b/tests/fixtures/sensevoice/ja.wav differ diff --git a/tests/fixtures/sensevoice/ko.wav b/tests/fixtures/sensevoice/ko.wav new file mode 100644 index 00000000..fbdb3c03 Binary files /dev/null and b/tests/fixtures/sensevoice/ko.wav differ diff --git a/tests/fixtures/sensevoice/yue.wav b/tests/fixtures/sensevoice/yue.wav new file mode 100644 index 00000000..0c012226 Binary files /dev/null and b/tests/fixtures/sensevoice/yue.wav differ diff --git a/tests/fixtures/sensevoice/zh.wav b/tests/fixtures/sensevoice/zh.wav new file mode 100644 index 00000000..24a2bfc5 Binary files /dev/null and b/tests/fixtures/sensevoice/zh.wav differ diff --git a/tests/vad_integration.rs b/tests/vad_integration.rs index 33cd886c..8ee8bb0c 100644 --- a/tests/vad_integration.rs +++ b/tests/vad_integration.rs @@ -58,7 +58,10 @@ fn energy_vad_rejects_pure_silence() { let vad = energy_vad(); let result = vad.detect(&samples).unwrap(); - assert!(!result.has_speech, "Pure silence should not be detected as speech"); + assert!( + !result.has_speech, + "Pure silence should not be detected as speech" + ); assert_eq!(result.speech_duration_secs, 0.0); assert_eq!(result.speech_ratio, 0.0); // RMS may be slightly above 0 due to WAV quantization noise @@ -71,7 +74,10 @@ fn energy_vad_rejects_short_silence() { let vad = energy_vad(); let result = vad.detect(&samples).unwrap(); - assert!(!result.has_speech, "Short silence should not be detected as speech"); + assert!( + !result.has_speech, + "Short silence should not be detected as speech" + ); } #[test] @@ -80,7 +86,10 @@ fn energy_vad_rejects_low_noise() { let vad = 
energy_vad(); let result = vad.detect(&samples).unwrap(); - assert!(!result.has_speech, "Very low noise should not be detected as speech"); + assert!( + !result.has_speech, + "Very low noise should not be detected as speech" + ); assert!(result.rms_energy < 0.01, "RMS energy should be very low"); } @@ -95,8 +104,14 @@ fn energy_vad_accepts_tone() { let result = vad.detect(&samples).unwrap(); // Energy VAD detects any audio with sufficient energy, not just speech - assert!(result.has_speech, "Loud tone should be detected as 'audio present'"); - assert!(result.speech_ratio > 0.9, "Tone should fill most of the audio"); + assert!( + result.has_speech, + "Loud tone should be detected as 'audio present'" + ); + assert!( + result.speech_ratio > 0.9, + "Tone should fill most of the audio" + ); } #[test] @@ -107,8 +122,14 @@ fn energy_vad_accepts_white_noise() { // White noise has high energy, so Energy VAD will detect it assert!(result.has_speech, "Loud white noise should be detected"); - assert!(result.speech_ratio > 0.9, "White noise should fill most of the audio"); - assert!(result.rms_energy > 0.01, "White noise should have measurable energy"); + assert!( + result.speech_ratio > 0.9, + "White noise should fill most of the audio" + ); + assert!( + result.rms_energy > 0.01, + "White noise should have measurable energy" + ); } #[test] @@ -131,7 +152,10 @@ fn energy_vad_accepts_speech_hello() { let result = vad.detect(&samples).unwrap(); assert!(result.has_speech, "Speech should be detected"); - assert!(result.speech_ratio > 0.5, "Most of the clip should contain speech"); + assert!( + result.speech_ratio > 0.5, + "Most of the clip should contain speech" + ); } #[test] @@ -141,7 +165,10 @@ fn energy_vad_accepts_speech_long() { let result = vad.detect(&samples).unwrap(); assert!(result.has_speech, "Long speech should be detected"); - assert!(result.speech_duration_secs > 1.0, "Should detect significant speech duration"); + assert!( + result.speech_duration_secs > 1.0, + 
"Should detect significant speech duration" + ); } #[test] @@ -150,10 +177,19 @@ fn energy_vad_accepts_speech_padded() { let vad = energy_vad(); let result = vad.detect(&samples).unwrap(); - assert!(result.has_speech, "Speech with silence padding should still be detected"); + assert!( + result.has_speech, + "Speech with silence padding should still be detected" + ); // The speech ratio should be lower due to silence padding - assert!(result.speech_ratio < 0.8, "Speech ratio should reflect silence padding"); - assert!(result.speech_ratio > 0.1, "But should still detect the speech portion"); + assert!( + result.speech_ratio < 0.8, + "Speech ratio should reflect silence padding" + ); + assert!( + result.speech_ratio > 0.1, + "But should still detect the speech portion" + ); } #[test] @@ -276,7 +312,10 @@ fn energy_vad_min_speech_duration_filtering() { let vad = EnergyVad::new(&config); let result = vad.detect(&samples).unwrap(); - assert!(!result.has_speech, "Speech shorter than min_duration should be rejected"); + assert!( + !result.has_speech, + "Speech shorter than min_duration should be rejected" + ); // But speech_duration_secs should still report the actual detected duration assert!(result.speech_duration_secs > 0.0); } @@ -340,7 +379,10 @@ fn whisper_vad_rejects_tone() { let result = vad.detect(&samples).unwrap(); // Whisper VAD (Silero) is trained on speech, should reject pure tones - assert!(!result.has_speech, "Whisper VAD should reject non-speech tones"); + assert!( + !result.has_speech, + "Whisper VAD should reject non-speech tones" + ); } #[test] @@ -354,7 +396,10 @@ fn whisper_vad_rejects_white_noise() { let result = vad.detect(&samples).unwrap(); // Whisper VAD should reject white noise as non-speech - assert!(!result.has_speech, "Whisper VAD should reject white noise as non-speech"); + assert!( + !result.has_speech, + "Whisper VAD should reject white noise as non-speech" + ); } #[test] @@ -368,7 +413,10 @@ fn whisper_vad_accepts_speech() { let result = 
vad.detect(&samples).unwrap(); assert!(result.has_speech, "Whisper VAD should detect TTS speech"); - assert!(result.speech_ratio > 0.5, "Most of the speech clip should be detected"); + assert!( + result.speech_ratio > 0.5, + "Most of the speech clip should be detected" + ); } #[test] @@ -382,7 +430,10 @@ fn whisper_vad_accepts_long_speech() { let result = vad.detect(&samples).unwrap(); assert!(result.has_speech, "Whisper VAD should detect longer speech"); - assert!(result.speech_duration_secs > 1.0, "Should detect multiple seconds of speech"); + assert!( + result.speech_duration_secs > 1.0, + "Should detect multiple seconds of speech" + ); } #[test] @@ -395,9 +446,15 @@ fn whisper_vad_handles_padded_speech() { let samples = load_wav("speech_padded.wav"); let result = vad.detect(&samples).unwrap(); - assert!(result.has_speech, "Whisper VAD should detect speech even with silence padding"); + assert!( + result.has_speech, + "Whisper VAD should detect speech even with silence padding" + ); // Speech ratio should be lower due to silence padding - assert!(result.speech_ratio < 0.7, "Should account for silence padding"); + assert!( + result.speech_ratio < 0.7, + "Should account for silence padding" + ); } // ============================================================================ @@ -418,7 +475,10 @@ fn compare_vad_backends_on_tone() { // If Whisper VAD is available, it should NOT detect tone as speech if let Some(whisper) = try_create_whisper_vad() { let whisper_result = whisper.detect(&samples).unwrap(); - assert!(!whisper_result.has_speech, "Whisper VAD should not detect tone as speech"); + assert!( + !whisper_result.has_speech, + "Whisper VAD should not detect tone as speech" + ); } } @@ -433,6 +493,9 @@ fn compare_vad_backends_on_speech() { if let Some(whisper) = try_create_whisper_vad() { let whisper_result = whisper.detect(&samples).unwrap(); - assert!(whisper_result.has_speech, "Whisper VAD should detect speech"); + assert!( + whisper_result.has_speech, + 
"Whisper VAD should detect speech" + ); } } diff --git a/website/index.html b/website/index.html index ce606750..22aa99ed 100644 --- a/website/index.html +++ b/website/index.html @@ -717,7 +717,7 @@

Parakeet Models (English, experimental)

- .en models are English-only but faster and more accurate for English. Parakeet requires the parakeet binary variant. + .en models are English-only but faster and more accurate for English. Parakeet and Moonshine require the ONNX binary variant.