diff --git a/Cargo.lock b/Cargo.lock index 3b7e0508fd..35f739bece 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2836,7 +2836,7 @@ dependencies = [ [[package]] name = "cactus-sys" version = "0.1.0" -source = "git+https://github.com/cactus-compute/cactus?rev=f8b714cc9782eafd4895d15fd66e9ffe141810bc#f8b714cc9782eafd4895d15fd66e9ffe141810bc" +source = "git+https://github.com/cactus-compute/cactus?rev=a5acad3#a5acad3238643da779dabcd4597f1552348543b5" dependencies = [ "bindgen 0.72.1", "cmake", @@ -20135,14 +20135,18 @@ dependencies = [ name = "transcribe-cactus" version = "0.1.0" dependencies = [ + "audio", "audio-utils", "axum 0.8.8", "bytes", "cactus", + "clap", + "colored", "data", "dirs 6.0.0", "futures-util", "language", + "owhisper-client", "owhisper-interface", "reqwest 0.13.2", "rodio", @@ -20160,6 +20164,27 @@ dependencies = [ "ws-utils", ] +[[package]] +name = "transcribe-cli" +version = "0.1.0" +dependencies = [ + "api-env", + "audio", + "audio-utils", + "axum 0.8.8", + "bytes", + "clap", + "colored", + "crossterm", + "futures-util", + "language", + "owhisper-client", + "owhisper-interface", + "tokio", + "transcribe-cactus", + "transcribe-proxy", +] + [[package]] name = "transcribe-proxy" version = "0.1.0" diff --git a/crates/cactus/Cargo.toml b/crates/cactus/Cargo.toml index 59d427e623..762080ffee 100644 --- a/crates/cactus/Cargo.toml +++ b/crates/cactus/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" license = "MIT" [dependencies] -cactus-sys = { git = "https://github.com/cactus-compute/cactus", package = "cactus-sys", rev = "f8b714cc9782eafd4895d15fd66e9ffe141810bc" } +cactus-sys = { git = "https://github.com/cactus-compute/cactus", package = "cactus-sys", rev = "a5acad3" } hypr-language = { workspace = true } hypr-llm-types = { workspace = true } diff --git a/crates/cactus/src/lib.rs b/crates/cactus/src/lib.rs index 36d28add4f..20061391e9 100644 --- a/crates/cactus/src/lib.rs +++ b/crates/cactus/src/lib.rs @@ -11,7 +11,7 @@ pub use hypr_language::Language; pub use llm::{CompleteOptions, CompletionResult, CompletionStream, Message, complete_stream}; pub use model::{Model, ModelBuilder, ModelKind}; pub use stt::{ - CloudConfig, StreamResult, TranscribeEvent, TranscribeOptions, Transcriber, + CloudConfig, StreamResult, StreamSegment, TranscribeEvent, TranscribeOptions, Transcriber, TranscriptionResult, TranscriptionSession, constrain_to, transcribe_stream, }; pub use vad::{VadOptions, VadResult, VadSegment}; diff --git a/crates/cactus/src/stt/mod.rs b/crates/cactus/src/stt/mod.rs index 08de5e6513..aa6a583dc2 100644 --- a/crates/cactus/src/stt/mod.rs +++ b/crates/cactus/src/stt/mod.rs @@ -6,7 +6,7 @@ mod whisper; pub use result::TranscriptionResult; pub use stream::{TranscribeEvent, TranscriptionSession, transcribe_stream}; -pub use transcriber::{CloudConfig, StreamResult, Transcriber}; +pub use transcriber::{CloudConfig, StreamResult, StreamSegment, Transcriber}; use hypr_language::Language; @@ -36,8 +36,6 @@ pub struct TranscribeOptions { #[serde(skip_serializing_if = "Option::is_none")] pub min_chunk_size: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub confirmation_threshold: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub custom_vocabulary: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub vocabulary_boost: Option, diff --git a/crates/cactus/src/stt/transcriber.rs b/crates/cactus/src/stt/transcriber.rs index 6632b84a24..14d7454094 100644 --- a/crates/cactus/src/stt/transcriber.rs +++ b/crates/cactus/src/stt/transcriber.rs @@ -55,6 +55,16 @@ pub struct Transcriber<'a> { // SAFETY: FFI calls are serialized through Model's inference_lock. unsafe impl Send for Transcriber<'_> {} +#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] +pub struct StreamSegment { + #[serde(default)] + pub start: f32, + #[serde(default)] + pub end: f32, + #[serde(default)] + pub text: String, +} + #[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] pub struct StreamResult { #[serde(default)] @@ -62,6 +72,8 @@ pub struct StreamResult { #[serde(default)] pub pending: String, #[serde(default)] + pub segments: Vec, + #[serde(default)] pub language: Option, #[serde(default)] pub cloud_handoff: bool, diff --git a/crates/cactus/tests/stt.rs b/crates/cactus/tests/stt.rs index 1fcaa1150f..51d980b4be 100644 --- a/crates/cactus/tests/stt.rs +++ b/crates/cactus/tests/stt.rs @@ -104,7 +104,7 @@ fn test_stream_transcriber() { let chunk_size = 32000; // 1 second at 16kHz 16-bit mono let mut had_confirmed = false; - for chunk in pcm.chunks(chunk_size).take(10) { + for chunk in pcm.chunks(chunk_size).take(30) { let r = transcriber.process(chunk).unwrap(); if !r.confirmed.is_empty() { had_confirmed = true; @@ -118,6 +118,89 @@ fn test_stream_transcriber() { assert!(had_confirmed, "expected at least one confirmed segment"); } +// cargo test -p cactus --test stt test_stream_transcriber_segments -- --ignored --nocapture +#[ignore] +#[test] +fn test_stream_transcriber_segments() { + let model = stt_model(); + let pcm = hypr_data::english_1::AUDIO; + let options = en_options(); + + let mut transcriber = Transcriber::new(&model, &options, CloudConfig::default()).unwrap(); + + let mut first_segmented = None; + let mut first_confirmed_with_segments = None; + + for chunk in pcm.chunks(32000).take(30) { + let r = transcriber.process(chunk).unwrap(); + + if first_segmented.is_none() && !r.segments.is_empty() { + first_segmented = Some(r.clone()); + } + + if first_confirmed_with_segments.is_none() + && !r.confirmed.trim().is_empty() + && !r.segments.is_empty() + { + first_confirmed_with_segments = Some(r.clone()); + break; + } + } + + let segmented = first_segmented.expect("expected at least one streaming result with segments"); + for segment in &segmented.segments { + assert!( + segment.end >= segment.start, + "segment end ({}) should be >= start ({})", + segment.end, + segment.start + ); + assert!(!segment.text.is_empty(), "segment text should not be empty"); + } + + let confirmed = + first_confirmed_with_segments.expect("expected a confirmed streaming result with segments"); + assert!( + !confirmed.confirmed.trim().is_empty(), + "confirmed text should not be empty" + ); + assert!( + confirmed.buffer_duration_ms > 0.0, + "buffer_duration_ms should be positive, got {}", + confirmed.buffer_duration_ms + ); + + let first_segment = confirmed.segments.first().unwrap(); + assert!( + first_segment.start >= 0.0, + "first segment start should be non-negative, got {}", + first_segment.start + ); + + let last_segment = confirmed.segments.last().unwrap(); + let last_segment_end = last_segment.end as f64; + assert!( + last_segment_end > 0.0, + "expected positive segment end timestamp, got {last_segment_end}" + ); + assert!( + (last_segment_end - confirmed.buffer_duration_ms / 1000.0).abs() < 0.25, + "last segment end ({last_segment_end}) should roughly match buffer duration ({}s)", + confirmed.buffer_duration_ms / 1000.0 + ); + + for pair in confirmed.segments.windows(2) { + assert!( + pair[1].start >= pair[0].start, + "segments should be ordered: {} came after {}", + pair[0].start, + pair[1].start + ); + } + + let _ = transcriber.stop().unwrap(); +} + // cargo test -p cactus --test stt test_stream_transcriber_drop -- --ignored --nocapture #[ignore] #[test] diff --git a/crates/transcribe-cactus/Cargo.toml b/crates/transcribe-cactus/Cargo.toml index 6253c58fca..baf0969cee 100644 --- a/crates/transcribe-cactus/Cargo.toml +++ b/crates/transcribe-cactus/Cargo.toml @@ -3,6 +3,9 @@ name = "transcribe-cactus" version = "0.1.0" edition = "2024" +[features] +live-example = [] + [dependencies] hypr-audio-utils = { workspace = true } hypr-cactus = { workspace = true } @@ -13,6 +16,7 @@ owhisper-interface = { workspace = true } rodio = { workspace = true } bytes = { workspace = true } +colored = "3" serde_html_form = { workspace = true } serde_json = { workspace = true } tempfile = { workspace = true } @@ -26,15 +30,27 @@ tower = { workspace = true } tracing = { workspace = true } [dev-dependencies] -dirs = { workspace = true } +hypr-audio = { workspace = true } hypr-audio-utils = { workspace = true } hypr-cactus = { workspace = true } hypr-data = { workspace = true } +owhisper-client = { workspace = true } +owhisper-interface = { workspace = true } + axum = { workspace = true, features = ["ws"] } futures-util = { workspace = true } reqwest = { workspace = true, features = ["json"] } -sequential-test = "0.2" -serde_json = { workspace = true } tokio = { workspace = true } tokio-tungstenite = { workspace = true } + +clap = { workspace = true, features = ["derive"] } +colored = "3" +dirs = { workspace = true } + +sequential-test = "0.2" +serde_json = { workspace = true } + +[[example]] +name = "live" +required-features = ["live-example"] diff --git a/crates/transcribe-cactus/examples/live.rs b/crates/transcribe-cactus/examples/live.rs new file mode 100644 index 0000000000..c818d289d2 --- /dev/null +++ b/crates/transcribe-cactus/examples/live.rs @@ -0,0 +1,238 @@ +use std::path::PathBuf; +use std::time::{Duration, Instant}; + +use axum::Router; +use axum::error_handling::HandleError; +use axum::http::StatusCode; +use clap::{Parser, ValueEnum}; +use colored::Colorize; +use futures_util::StreamExt; +use owhisper_client::{CactusAdapter, FinalizeHandle, ListenClient}; +use owhisper_interface::MixedMessage; +use owhisper_interface::stream::StreamResponse; + +use hypr_audio::AudioInput; +use hypr_audio_utils::{AudioFormatExt, chunk_size_for_stt}; +use transcribe_cactus::TranscribeService; + +#[derive(Clone, ValueEnum)] +enum AudioSource { + Input, + Output, +} + +#[derive(Parser)] +struct Args { + #[arg(long, default_value = "input")] + audio: AudioSource, + + #[arg(long)] + model: PathBuf, +} + +fn fmt_ts(secs: f64) -> String { + let m = (secs / 60.0) as u32; + let s = secs % 60.0; + format!("{:02}:{:04.1}", m, s) +} + +struct Segment { + time: f64, + text: String, +} + +struct Transcript { + segments: Vec, + partial: String, + t0: Instant, +} + +impl Transcript { + fn new(t0: Instant) -> Self { + Self { + segments: Vec::new(), + partial: String::new(), + t0, + } + } + + fn elapsed(&self) -> f64 { + self.t0.elapsed().as_secs_f64() + } + + fn set_partial(&mut self, text: &str) { + self.partial = text.to_string(); + self.render(); + } + + fn confirm(&mut self, text: &str) { + self.segments.push(Segment { + time: self.elapsed(), + text: text.to_string(), + }); + self.partial.clear(); + self.trim(); + self.render(); + } + + fn trim(&mut self) { + let total_len: usize = self.segments.iter().map(|s| s.text.len() + 1).sum(); + if total_len > 180 { + let drain_count = self.segments.len() * 2 / 3; + if drain_count > 0 { + self.segments.drain(..drain_count); + } + } + } + + fn render(&self) { + let confirmed: String = self + .segments + .iter() + .map(|s| s.text.as_str()) + .collect::>() + .join(" "); + + if confirmed.is_empty() && self.partial.is_empty() { + return; + } + + let from = self.segments.first().map(|s| s.time).unwrap_or(0.0); + let to = self.elapsed(); + let prefix = format!("[{} / {}]", fmt_ts(from), fmt_ts(to)).dimmed(); + + if self.partial.is_empty() { + eprintln!("{} {}", prefix, confirmed.bold().white()); + } else { + eprintln!( + "{} {} {}", + prefix, + confirmed.bold().white(), + self.partial.dimmed() + ); + } + } +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + + assert!( + args.model.exists(), + "model not found: {}", + args.model.display() + ); + + let app = Router::new().route_service( + "/v1/listen", + HandleError::new( + TranscribeService::builder().model_path(args.model).build(), + |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + + tokio::spawn(async move { + axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + let mut audio_input = match args.audio { + AudioSource::Output => AudioInput::from_speaker(), + AudioSource::Input => AudioInput::from_mic(None).expect("failed to open mic"), + }; + + let sample_rate: u32 = 16000; + let chunk_size = chunk_size_for_stt(sample_rate); + let source_name = match args.audio { + AudioSource::Input => "input", + AudioSource::Output => "output", + }; + + eprintln!("source: {} ({})", source_name, audio_input.device_name()); + eprintln!( + "sample rate: {} Hz -> {} Hz, chunk size: {} samples", + audio_input.sample_rate(), + sample_rate, + chunk_size + ); + eprintln!("(set CACTUS_DEBUG=1 for raw engine output)"); + eprintln!(); + + let api_base = format!("http://{}/v1", addr); + let client = ListenClient::builder() + .adapter::() + .api_base(&api_base) + .params(owhisper_interface::ListenParams { + sample_rate, + languages: vec![hypr_language::ISO639::En.into()], + ..Default::default() + }) + .build_single() + .await; + + let stream = audio_input.stream(); + let audio_stream = stream + .to_i16_le_chunks(sample_rate, chunk_size) + .map(|bytes| MixedMessage::Audio(bytes)); + + let (response_stream, handle) = client + .from_realtime_audio(Box::pin(audio_stream)) + .await + .expect("failed to connect"); + futures_util::pin_mut!(response_stream); + + let mut transcript = Transcript::new(Instant::now()); + let mut last_confirmed: Option = None; + + let read_loop = async { + while let Some(result) = response_stream.next().await { + match result { + Ok(StreamResponse::TranscriptResponse { + is_final, channel, .. + }) => { + let text = channel + .alternatives + .first() + .map(|a| a.transcript.as_str()) + .unwrap_or(""); + + if is_final { + if last_confirmed.as_deref() == Some(text) { + continue; + } + last_confirmed = Some(text.to_string()); + transcript.confirm(text); + } else { + transcript.set_partial(text); + } + } + Ok(StreamResponse::TerminalResponse { .. }) => break, + Ok(StreamResponse::ErrorResponse { error_message, .. }) => { + eprintln!("\nerror: {}", error_message); + break; + } + Ok(_) => {} + Err(e) => { + eprintln!("\nws error: {:?}", e); + break; + } + } + } + }; + + let _ = tokio::time::timeout(Duration::from_secs(600), read_loop).await; + handle.finalize().await; + + eprintln!(); + + let _ = shutdown_tx.send(()); +} diff --git a/crates/transcribe-cactus/src/service/streaming/debug.rs b/crates/transcribe-cactus/src/service/streaming/debug.rs new file mode 100644 index 0000000000..919234a475 --- /dev/null +++ b/crates/transcribe-cactus/src/service/streaming/debug.rs @@ -0,0 +1,88 @@ +use colored::Colorize; + +fn enabled() -> bool { + static ENABLED: std::sync::OnceLock = std::sync::OnceLock::new(); + *ENABLED.get_or_init(|| std::env::var("CACTUS_DEBUG").as_deref() == Ok("1")) +} + +pub(super) enum Kind { + Partial, + Confirmed, + Cloud, +} + +struct Event<'a> { + ch: usize, + audio_offset: f64, + kind: Kind, + text: &'a str, + seg_start: f64, + seg_dur: f64, + confidence: f64, + decode_tps: f64, + buffer_duration_ms: f64, +} + +impl std::fmt::Display for Event<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let header = format!("[{:>7.2}s ch{}]", self.audio_offset, self.ch).dimmed(); + + let label = match self.kind { + Kind::Partial => " partial".yellow(), + Kind::Confirmed => " CONFIRM".green().bold(), + Kind::Cloud => " cloud".cyan(), + }; + + let text = truncate(self.text, 60); + + let timing = format!( + "seg:{:.2}\u{2192}{:.2}s conf:{:.2} dec:{:.0}tps buf:{:.0}ms", + self.seg_start, + self.seg_start + self.seg_dur, + self.confidence, + self.decode_tps, + self.buffer_duration_ms, + ) + .dimmed(); + + write!(f, "{header} {label} \"{text}\" {timing}") + } +} + +pub(super) fn log( + ch: usize, + audio_offset: f64, + kind: Kind, + text: &str, + seg_start: f64, + seg_dur: f64, + confidence: f64, + result: &hypr_cactus::StreamResult, +) { + if !enabled() { + return; + } + let event = Event { + ch, + audio_offset, + kind, + text, + seg_start, + seg_dur, + confidence, + decode_tps: result.decode_tps, + buffer_duration_ms: result.buffer_duration_ms, + }; + eprintln!("{event}"); +} + +fn truncate(text: &str, max_chars: usize) -> String { + if text.chars().count() <= max_chars { + return text.to_string(); + } + let end = text + .char_indices() + .nth(max_chars - 1) + .map_or(text.len(), |(i, _)| i); + format!("{}\u{2026}", &text[..end]) +} diff --git a/crates/transcribe-cactus/src/service/streaming/mod.rs b/crates/transcribe-cactus/src/service/streaming/mod.rs index 300432d3d9..61df7b5621 100644 --- a/crates/transcribe-cactus/src/service/streaming/mod.rs +++ b/crates/transcribe-cactus/src/service/streaming/mod.rs @@ -1,3 +1,4 @@ +mod debug; mod message; pub(crate) mod response; mod service; diff --git a/crates/transcribe-cactus/src/service/streaming/session.rs b/crates/transcribe-cactus/src/service/streaming/session.rs index 6f551850c6..809a225228 100644 --- a/crates/transcribe-cactus/src/service/streaming/session.rs +++ b/crates/transcribe-cactus/src/service/streaming/session.rs @@ -8,6 +8,7 @@ use owhisper_interface::{ControlMessage, ListenParams}; use hypr_ws_utils::ConnectionGuard; +use super::debug; use super::message::{AudioExtract, IncomingMessage, process_incoming_message}; use super::response::{ WsSender, build_transcript_response, format_timestamp_now, send_ws, send_ws_best_effort, @@ -15,6 +16,14 @@ use super::response::{ pub(super) const SAMPLE_RATE: u32 = 16_000; +macro_rules! try_send { + ($ws:expr, $msg:expr) => { + if !send_ws($ws, $msg).await { + return LoopAction::Break(SessionExit::TransportClosed); + } + }; +} + #[derive(Default)] struct ChannelState { last_confirmed_sent: String, @@ -180,205 +189,211 @@ async fn handle_transcribe_event( result, chunk_duration_secs, }) => { - let channel_index = vec![ch_idx as i32, total_channels as i32]; - let channel_u8 = vec![ch_idx as u8]; - let state = &mut channel_states[ch_idx]; + process_result( + ws_sender, + ch_idx, + result, + chunk_duration_secs, + channel_states, + total_channels, + metadata, + ) + .await + } + } +} - state.audio_offset += chunk_duration_secs; +async fn process_result( + ws_sender: &mut WsSender, + ch_idx: usize, + result: hypr_cactus::StreamResult, + chunk_duration_secs: f64, + channel_states: &mut [ChannelState], + total_channels: usize, + metadata: &Metadata, +) -> LoopAction { + let channel_index = vec![ch_idx as i32, total_channels as i32]; + let channel_u8 = vec![ch_idx as u8]; + let state = &mut channel_states[ch_idx]; - let seg_dur = if result.buffer_duration_ms > 0.0 { - result.buffer_duration_ms / 1000.0 - } else { - state.audio_offset - state.segment_start - }; - let seg_start = (state.audio_offset - seg_dur).max(state.segment_start); + state.audio_offset += chunk_duration_secs; - let confidence = result.confidence as f64; - let confirmed_text = result.confirmed.trim(); + let (seg_start, seg_dur) = + segment_timing_from_result(&result, state.audio_offset, state.segment_start); + let confidence = result.confidence as f64; + let confirmed_text = result.confirmed.trim(); + let metrics = stream_result_metrics(&result); - let metrics = stream_result_metrics(&result); + state.pending_text = result.pending.clone(); + state.pending_language = result.language.clone(); + state.pending_confidence = confidence; - state.pending_text = result.pending.clone(); - state.pending_language = result.language.clone(); - state.pending_confidence = confidence; + if result.cloud_handoff && result.cloud_job_id != 0 { + state.pending_cloud_job_id = result.cloud_job_id; + state.cloud_handoff_segment_start = state.segment_start; + } - if result.cloud_handoff && result.cloud_job_id != 0 { - state.pending_cloud_job_id = result.cloud_job_id; - state.cloud_handoff_segment_start = state.segment_start; - } + if result.cloud_result_job_id != 0 && !result.cloud_result.is_empty() { + let cloud_text = result.cloud_result.trim(); + let job_id = result.cloud_result_job_id; + let cloud_seg_start = state.cloud_handoff_segment_start; + let cloud_seg_dur = state.audio_offset - cloud_seg_start; + + debug::log( + ch_idx, + state.audio_offset, + debug::Kind::Cloud, + cloud_text, + cloud_seg_start, + cloud_seg_dur, + confidence, + &result, + ); - if result.cloud_result_job_id != 0 && !result.cloud_result.is_empty() { - let cloud_text = result.cloud_result.trim(); - let job_id = result.cloud_result_job_id; - let seg_start = state.cloud_handoff_segment_start; - let seg_duration = state.audio_offset - seg_start; - let mut keys = metrics.clone(); - keys.insert("cloud_corrected".to_string(), serde_json::Value::Bool(true)); - keys.insert( - "cloud_job_id".to_string(), - serde_json::Value::Number(job_id.into()), - ); - tracing::info!( - hyprnote.transcript.char_count = cloud_text.chars().count() as u64, - hyprnote.stt.job.id = job_id, - hyprnote.audio.channel_index = ch_idx, - "cactus_cloud_correction" - ); - if !send_ws( - ws_sender, - &build_transcript_response( - cloud_text, - seg_start, - seg_duration, - confidence, - result.language.as_deref(), - true, - true, - false, - metadata, - &channel_index, - Some(keys), - ), - ) - .await - { - return LoopAction::Break(SessionExit::TransportClosed); - } - state.pending_cloud_job_id = 0; - } + let mut keys = metrics.clone(); + keys.insert("cloud_corrected".to_string(), serde_json::Value::Bool(true)); + keys.insert( + "cloud_job_id".to_string(), + serde_json::Value::Number(job_id.into()), + ); - if !confirmed_text.is_empty() && confirmed_text != state.last_confirmed_sent { - if !state.speech_started { - if !send_ws( - ws_sender, - &StreamResponse::SpeechStartedResponse { - channel: channel_u8.clone(), - timestamp: seg_start, - }, - ) - .await - { - return LoopAction::Break(SessionExit::TransportClosed); - } - } + tracing::info!( + hyprnote.transcript.char_count = cloud_text.chars().count() as u64, + hyprnote.stt.job.id = job_id, + hyprnote.audio.channel_index = ch_idx, + "cactus_cloud_correction" + ); - let handoff_extra = { - let mut keys = metrics.clone(); - if result.cloud_handoff && result.cloud_job_id != 0 { - keys.insert("cloud_handoff".to_string(), serde_json::Value::Bool(true)); - keys.insert( - "cloud_job_id".to_string(), - serde_json::Value::Number(result.cloud_job_id.into()), - ); - } - Some(keys) - }; - - tracing::info!( - hyprnote.transcript.char_count = confirmed_text.chars().count() as u64, - hyprnote.audio.channel_index = ch_idx, - "cactus_confirmed_text" - ); - if !send_ws( - ws_sender, - &build_transcript_response( - confirmed_text, - seg_start, - seg_dur, - confidence, - result.language.as_deref(), - true, - true, - false, - metadata, - &channel_index, - handoff_extra, - ), - ) - .await - { - return LoopAction::Break(SessionExit::TransportClosed); - } - if !send_ws( - ws_sender, - &StreamResponse::UtteranceEndResponse { - channel: channel_u8, - last_word_end: state.audio_offset, - }, - ) - .await - { - return LoopAction::Break(SessionExit::TransportClosed); + try_send!( + ws_sender, + &build_transcript_response( + cloud_text, + cloud_seg_start, + cloud_seg_dur, + confidence, + result.language.as_deref(), + true, + true, + false, + metadata, + &channel_index, + Some(keys), + ) + ); + state.pending_cloud_job_id = 0; + } + + if !confirmed_text.is_empty() && confirmed_text != state.last_confirmed_sent { + debug::log( + ch_idx, + state.audio_offset, + debug::Kind::Confirmed, + confirmed_text, + seg_start, + seg_dur, + confidence, + &result, + ); + + if !state.speech_started { + try_send!( + ws_sender, + &StreamResponse::SpeechStartedResponse { + channel: channel_u8.clone(), + timestamp: seg_start, } + ); + } - state.last_confirmed_sent.clear(); - state.last_confirmed_sent.push_str(confirmed_text); - state.last_pending_sent.clear(); - state.segment_start = state.audio_offset; - state.speech_started = false; - return LoopAction::Continue; - } + tracing::info!( + hyprnote.transcript.char_count = confirmed_text.chars().count() as u64, + hyprnote.audio.channel_index = ch_idx, + "cactus_confirmed_text" + ); - let pending_text = result.pending.trim(); - if pending_text.is_empty() - || pending_text == state.last_pending_sent - || pending_text == state.last_confirmed_sent - { - return LoopAction::Continue; - } + try_send!( + ws_sender, + &build_transcript_response( + confirmed_text, + seg_start, + seg_dur, + confidence, + result.language.as_deref(), + true, + true, + false, + metadata, + &channel_index, + build_extra_keys(&metrics, &result), + ) + ); - if !state.speech_started { - state.speech_started = true; - if !send_ws( - ws_sender, - &StreamResponse::SpeechStartedResponse { - channel: channel_u8.clone(), - timestamp: seg_start, - }, - ) - .await - { - return LoopAction::Break(SessionExit::TransportClosed); - } + try_send!( + ws_sender, + &StreamResponse::UtteranceEndResponse { + channel: channel_u8, + last_word_end: state.audio_offset, } + ); - let pending_handoff_extra = { - let mut keys = metrics; - if result.cloud_handoff && result.cloud_job_id != 0 { - keys.insert("cloud_handoff".to_string(), serde_json::Value::Bool(true)); - keys.insert( - "cloud_job_id".to_string(), - serde_json::Value::Number(result.cloud_job_id.into()), - ); - } - Some(keys) - }; + state.last_confirmed_sent.clear(); + state.last_confirmed_sent.push_str(confirmed_text); + state.last_pending_sent.clear(); + state.segment_start = state.audio_offset; + state.speech_started = false; + return LoopAction::Continue; + } - if !send_ws( - ws_sender, - &build_transcript_response( - pending_text, - seg_start, - seg_dur, - confidence, - result.language.as_deref(), - false, - false, - false, - metadata, - &channel_index, - pending_handoff_extra, - ), - ) - .await - { - return LoopAction::Break(SessionExit::TransportClosed); + let pending_text = result.pending.trim(); + if pending_text.is_empty() + || pending_text == state.last_pending_sent + || pending_text == state.last_confirmed_sent + { + return LoopAction::Continue; + } + + debug::log( + ch_idx, + state.audio_offset, + debug::Kind::Partial, + pending_text, + seg_start, + seg_dur, + confidence, + &result, + ); + + if !state.speech_started { + state.speech_started = true; + try_send!( + ws_sender, + &StreamResponse::SpeechStartedResponse { + channel: channel_u8, + timestamp: seg_start, } - state.last_pending_sent.clear(); - state.last_pending_sent.push_str(pending_text); - LoopAction::Continue - } + ); } + + try_send!( + ws_sender, + &build_transcript_response( + pending_text, + seg_start, + seg_dur, + confidence, + result.language.as_deref(), + false, + false, + false, + metadata, + &channel_index, + build_extra_keys(&metrics, &result), + ) + ); + + state.last_pending_sent.clear(); + state.last_pending_sent.push_str(pending_text); + LoopAction::Continue } async fn handle_ws_message( @@ -486,43 +501,58 @@ async fn handle_ws_message( LoopAction::Continue } +fn segment_timing_from_result( + result: &hypr_cactus::StreamResult, + audio_offset: f64, + segment_start: f64, +) -> (f64, f64) { + if let (Some(first), Some(last)) = (result.segments.first(), result.segments.last()) { + let start = first.start as f64; + let end = last.end as f64; + if end > start { + return (start, end - start); + } + } + (segment_start, audio_offset - segment_start) +} + fn stream_result_metrics( result: &hypr_cactus::StreamResult, ) -> std::collections::HashMap { - let mut m = std::collections::HashMap::new(); - m.insert( - "decode_tps".to_string(), - serde_json::json!(result.decode_tps), - ); - m.insert( - "prefill_tps".to_string(), - serde_json::json!(result.prefill_tps), - ); - m.insert( - "time_to_first_token_ms".to_string(), - serde_json::json!(result.time_to_first_token_ms), - ); - m.insert( - "total_time_ms".to_string(), - serde_json::json!(result.total_time_ms), - ); - m.insert( - "decode_tokens".to_string(), - serde_json::json!(result.decode_tokens), - ); - m.insert( - "prefill_tokens".to_string(), - serde_json::json!(result.prefill_tokens), - ); - m.insert( - "total_tokens".to_string(), - serde_json::json!(result.total_tokens), - ); - m.insert( - "buffer_duration_ms".to_string(), - serde_json::json!(result.buffer_duration_ms), - ); - m + [ + ("decode_tps", serde_json::json!(result.decode_tps)), + ("prefill_tps", serde_json::json!(result.prefill_tps)), + ( + "time_to_first_token_ms", + serde_json::json!(result.time_to_first_token_ms), + ), + ("total_time_ms", serde_json::json!(result.total_time_ms)), + ("decode_tokens", serde_json::json!(result.decode_tokens)), + ("prefill_tokens", serde_json::json!(result.prefill_tokens)), + ("total_tokens", serde_json::json!(result.total_tokens)), + ( + "buffer_duration_ms", + serde_json::json!(result.buffer_duration_ms), + ), + ] + .into_iter() + .map(|(k, v)| (k.to_string(), v)) + .collect() +} + +fn build_extra_keys( + metrics: &std::collections::HashMap, + result: &hypr_cactus::StreamResult, +) -> Option> { + let mut keys = metrics.clone(); + if result.cloud_handoff && result.cloud_job_id != 0 { + keys.insert("cloud_handoff".to_string(), serde_json::Value::Bool(true)); + keys.insert( + "cloud_job_id".to_string(), + serde_json::Value::Number(result.cloud_job_id.into()), + ); + } + Some(keys) } async fn handle_finalize( diff --git a/crates/transcribe-cli/Cargo.toml b/crates/transcribe-cli/Cargo.toml new file mode 100644 index 0000000000..2c0de78879 --- /dev/null +++ b/crates/transcribe-cli/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "transcribe-cli" +version = "0.1.0" +edition = "2024" + +[dependencies] +hypr-audio = { workspace = true } +hypr-audio-utils = { workspace = true } +hypr-language = { workspace = true } +owhisper-client = { workspace = true } +owhisper-interface = { workspace = true } + +axum = { workspace = true, features = ["ws"] } +bytes = { workspace = true } +colored = "3" +futures-util = { workspace = true } +tokio = { workspace = true } + +clap = { workspace = true, features = ["derive"] } +crossterm = { workspace = true } + +[dev-dependencies] +hypr-api-env = { workspace = true } +hypr-transcribe-cactus = { workspace = true } +hypr-transcribe-proxy = { workspace = true } diff --git a/crates/transcribe-cli/README.md b/crates/transcribe-cli/README.md new file mode 100644 index 0000000000..0b0c9f5484 --- /dev/null +++ b/crates/transcribe-cli/README.md @@ -0,0 +1,41 @@ +Direct Deepgram: + +```bash +DEEPGRAM_API_KEY=... cargo run -p transcribe-cli --example deepgram -- --audio input +``` + +```bash +DEEPGRAM_API_KEY=... cargo run -p transcribe-cli --example deepgram -- --audio aec-dual +``` + +Direct Soniox: + +```bash +SONIOX_API_KEY=... cargo run -p transcribe-cli --example soniox -- --audio input +``` + +```bash +SONIOX_API_KEY=... cargo run -p transcribe-cli --example soniox -- --audio raw-dual +``` + +Local Cactus: + +```bash +cargo run -p transcribe-cli --example cactus -- --model /path/to/model.bin --audio input +``` + +```bash +cargo run -p transcribe-cli --example cactus -- --model /path/to/model.bin --audio aec-dual +``` + +Proxy testing: + +```bash +DEEPGRAM_API_KEY=... cargo run -p transcribe-cli --example hyprnote -- --provider deepgram --audio input +``` + +```bash +DEEPGRAM_API_KEY=... SONIOX_API_KEY=... cargo run -p transcribe-cli --example hyprnote -- --provider hyprnote --audio input +``` + +Use `--audio output` to transcribe speaker output, `--audio raw-dual` for raw mic + speaker, and `--audio aec-dual` for AEC mic + speaker. diff --git a/crates/transcribe-cli/examples/cactus.rs b/crates/transcribe-cli/examples/cactus.rs new file mode 100644 index 0000000000..1403aee0ce --- /dev/null +++ b/crates/transcribe-cli/examples/cactus.rs @@ -0,0 +1,71 @@ +use std::path::PathBuf; + +use axum::Router; +use axum::error_handling::HandleError; +use axum::http::StatusCode; +use clap::Parser; +use hypr_transcribe_cactus::TranscribeService; +use transcribe_cli::{ + AudioArgs, DEFAULT_SAMPLE_RATE, DEFAULT_TIMEOUT_SECS, build_dual_client, build_single_client, + default_listen_params, run_dual_client, run_single_client, spawn_router, +}; + +#[derive(Parser)] +struct Args { + #[command(flatten)] + audio: AudioArgs, + + #[arg(long)] + model: PathBuf, +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + + assert!( + args.model.exists(), + "model not found: {}", + args.model.display() + ); + + let app = Router::new().route_service( + "/v1/listen", + HandleError::new( + TranscribeService::builder().model_path(args.model).build(), + |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ), + ); + let server = spawn_router(app).await; + if args.audio.audio.is_dual() { + let client = build_dual_client::( + server.api_base("/v1"), + None, + default_listen_params(), + ) + .await; + + run_dual_client( + args.audio.audio, + client, + DEFAULT_SAMPLE_RATE, + DEFAULT_TIMEOUT_SECS, + ) + .await; + } else { + let client = build_single_client::( + server.api_base("/v1"), + None, + default_listen_params(), + ) + .await; + + run_single_client( + args.audio.audio, + client, + DEFAULT_SAMPLE_RATE, + DEFAULT_TIMEOUT_SECS, + ) + .await; + } +} diff --git a/crates/transcribe-cli/examples/deepgram.rs b/crates/transcribe-cli/examples/deepgram.rs new file mode 100644 index 0000000000..1463184b1a --- /dev/null +++ b/crates/transcribe-cli/examples/deepgram.rs @@ -0,0 +1,10 @@ +transcribe_cli::simple_provider_example! { + adapter: owhisper_client::DeepgramAdapter, + api_base: "https://api.deepgram.com/v1", + api_key_env: "DEEPGRAM_API_KEY", + params: { + let mut params = transcribe_cli::default_listen_params(); + params.model = Some("nova-3".to_string()); + params + }, +} diff --git a/crates/transcribe-cli/examples/hyprnote.rs b/crates/transcribe-cli/examples/hyprnote.rs new file mode 100644 index 0000000000..8ef0f24381 --- /dev/null +++ b/crates/transcribe-cli/examples/hyprnote.rs @@ -0,0 +1,109 @@ +use clap::{Parser, ValueEnum}; +use hypr_transcribe_proxy::{HyprnoteRoutingConfig, SttProxyConfig}; +use transcribe_cli::{ + AudioArgs, DEFAULT_SAMPLE_RATE, DEFAULT_TIMEOUT_SECS, build_dual_client, build_single_client, + default_listen_params, run_dual_client, run_single_client, spawn_router, +}; + +#[derive(Clone, ValueEnum)] +enum ProviderArg { + Hyprnote, + Deepgram, + Soniox, +} + +#[derive(Parser)] +struct Args { + #[command(flatten)] + audio: AudioArgs, + + #[arg(long)] + provider: ProviderArg, +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + + let mut env = hypr_transcribe_proxy::Env::default(); + let provider_name = match args.provider { + ProviderArg::Hyprnote => { + env.stt.deepgram_api_key = std::env::var("DEEPGRAM_API_KEY").ok(); + env.stt.soniox_api_key = std::env::var("SONIOX_API_KEY").ok(); + "hyprnote" + } + ProviderArg::Deepgram => { + env.stt.deepgram_api_key = + Some(std::env::var("DEEPGRAM_API_KEY").expect("DEEPGRAM_API_KEY not set")); + "deepgram" + } + ProviderArg::Soniox => { + env.stt.soniox_api_key = + Some(std::env::var("SONIOX_API_KEY").expect("SONIOX_API_KEY not set")); + "soniox" + } + }; + + let supabase_env = hypr_api_env::SupabaseEnv { + supabase_url: String::new(), + supabase_anon_key: String::new(), + supabase_service_role_key: String::new(), + }; + + let config = SttProxyConfig::new(&env, &supabase_env) + .with_hyprnote_routing(HyprnoteRoutingConfig::default()); + let app = hypr_transcribe_proxy::router(config); + let server = spawn_router(app).await; + + eprintln!("proxy: {} -> {}", server.addr(), provider_name); + eprintln!(); + + match args.provider { + ProviderArg::Hyprnote => { + run_with_adapter::( + &args.audio.audio, + server.api_base(""), + ) + .await; + } + ProviderArg::Deepgram => { + run_with_adapter::( + &args.audio.audio, + server.api_base(""), + ) + .await; + } + ProviderArg::Soniox => { + run_with_adapter::( + &args.audio.audio, + server.api_base(""), + ) + .await; + } + } +} + +async fn run_with_adapter( + source: &transcribe_cli::AudioSource, + api_base: String, +) { + if source.is_dual() { + let client = build_dual_client::(api_base, None, default_listen_params()).await; + run_dual_client( + source.clone(), + client, + DEFAULT_SAMPLE_RATE, + DEFAULT_TIMEOUT_SECS, + ) + .await; + } else { + let client = build_single_client::(api_base, None, default_listen_params()).await; + run_single_client( + source.clone(), + client, + DEFAULT_SAMPLE_RATE, + DEFAULT_TIMEOUT_SECS, + ) + .await; + } +} diff --git a/crates/transcribe-cli/examples/soniox.rs b/crates/transcribe-cli/examples/soniox.rs new file mode 100644 index 0000000000..c5810f5b72 --- /dev/null +++ b/crates/transcribe-cli/examples/soniox.rs @@ -0,0 +1,6 @@ +transcribe_cli::simple_provider_example! { + adapter: owhisper_client::SonioxAdapter, + api_base: "https://api.soniox.com", + api_key_env: "SONIOX_API_KEY", + params: transcribe_cli::default_listen_params(), +} diff --git a/crates/transcribe-cli/src/lib.rs b/crates/transcribe-cli/src/lib.rs new file mode 100644 index 0000000000..8ab7e3bf2b --- /dev/null +++ b/crates/transcribe-cli/src/lib.rs @@ -0,0 +1,565 @@ +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +use axum::Router; +use clap::{Args, ValueEnum}; +use colored::Colorize; +use futures_util::StreamExt; +use owhisper_client::{FinalizeHandle, ListenClient, ListenClientDual, RealtimeSttAdapter}; +use owhisper_interface::MixedMessage; +use owhisper_interface::stream::StreamResponse; + +use hypr_audio::{AudioInput, CaptureConfig, CaptureFrame}; +use hypr_audio_utils::{AudioFormatExt, chunk_size_for_stt, f32_to_i16_bytes}; + +pub const DEFAULT_SAMPLE_RATE: u32 = 16_000; +pub const DEFAULT_TIMEOUT_SECS: u64 = 600; + +#[derive(Clone, Copy)] +pub enum ChannelKind { + Mic, + Speaker, +} + +pub enum DisplayMode { + Single(ChannelKind), + Dual, +} + +#[derive(Clone, ValueEnum)] +pub enum AudioSource { + Input, + Output, + RawDual, + AecDual, +} + +impl AudioSource { + pub fn is_dual(&self) -> bool { + matches!(self, Self::RawDual | Self::AecDual) + } + + fn uses_aec(&self) -> bool { + matches!(self, Self::AecDual) + } +} + +#[derive(Args)] +pub struct AudioArgs { + #[arg(long, default_value = "input")] + pub audio: AudioSource, +} + +pub fn open_audio(source: &AudioSource) -> AudioInput { + match source { + AudioSource::Output => AudioInput::from_speaker(), + AudioSource::Input => AudioInput::from_mic(None).expect("failed to open mic"), + AudioSource::RawDual | AudioSource::AecDual => { + panic!("dual audio modes use the realtime capture pipeline") + } + } +} + +pub fn create_audio_stream( + audio_input: &mut AudioInput, + sample_rate: u32, +) -> std::pin::Pin< + Box< + dyn futures_util::Stream< + Item = MixedMessage, + > + Send, + >, +> { + let chunk_size = chunk_size_for_stt(sample_rate); + let stream = audio_input.stream(); + Box::pin( + stream + .to_i16_le_chunks(sample_rate, chunk_size) + .map(MixedMessage::Audio), + ) +} + +pub fn create_dual_audio_stream( + source: &AudioSource, + sample_rate: u32, +) -> std::pin::Pin< + Box< + dyn futures_util::Stream< + Item = MixedMessage< + (bytes::Bytes, bytes::Bytes), + owhisper_interface::ControlMessage, + >, + > + Send, + >, +> { + let chunk_size = chunk_size_for_stt(sample_rate); + let capture_stream = AudioInput::from_mic_and_speaker(CaptureConfig { + sample_rate, + chunk_size, + mic_device: None, + enable_aec: source.uses_aec(), + }) + .expect("failed to open realtime capture"); + let source = source.clone(); + + Box::pin(capture_stream.map(move |result| { + let frame = result.unwrap_or_else(|error| panic!("capture failed: {error}")); + MixedMessage::Audio(capture_frame_to_bytes(&source, frame)) + })) +} + +pub fn print_audio_info(audio_input: &AudioInput, source: &AudioSource, sample_rate: u32) { + let source_name = match source { + AudioSource::Input => "input", + AudioSource::Output => "output", + AudioSource::RawDual | AudioSource::AecDual => unreachable!(), + }; + let chunk_size = chunk_size_for_stt(sample_rate); + + eprintln!("source: {} ({})", source_name, audio_input.device_name()); + eprintln!( + "sample rate: {} Hz -> {} Hz, chunk size: {} samples", + audio_input.sample_rate(), + sample_rate, + chunk_size + ); + eprintln!(); +} + +pub fn print_dual_audio_info(source: &AudioSource, sample_rate: u32) { + let chunk_size = chunk_size_for_stt(sample_rate); + let source_name = match source { + AudioSource::RawDual => "raw-dual", + AudioSource::AecDual => "aec-dual", + AudioSource::Input | AudioSource::Output => unreachable!(), + }; + + eprintln!( + "source: {} (input: {}, output: RealtimeSpeaker)", + source_name, + AudioInput::get_default_device_name() + ); + eprintln!( + "sample rate: {} Hz, chunk size: {} samples, AEC: {}", + sample_rate, + chunk_size, + if source.uses_aec() { + "enabled" + } else { + "disabled" + } + ); + eprintln!(); +} + +pub fn default_listen_params() -> owhisper_interface::ListenParams { + owhisper_interface::ListenParams { + sample_rate: DEFAULT_SAMPLE_RATE, + languages: vec![hypr_language::ISO639::En.into()], + ..Default::default() + } +} + +pub async fn build_single_client( + api_base: impl Into, + api_key: Option, + params: owhisper_interface::ListenParams, +) -> ListenClient { + let mut builder = ListenClient::builder() + .adapter::() + .api_base(api_base.into()) + .params(params); + + if let Some(api_key) = api_key { + builder = builder.api_key(api_key); + } + + builder.build_single().await +} + +pub async fn build_dual_client( + api_base: impl Into, + api_key: Option, + params: owhisper_interface::ListenParams, +) -> ListenClientDual { + let mut builder = ListenClient::builder() + .adapter::() + .api_base(api_base.into()) + .params(params); + + if let Some(api_key) = api_key { + builder = builder.api_key(api_key); + } + + builder.build_dual().await +} + +pub async fn run_single_client( + source: AudioSource, + client: ListenClient, + sample_rate: u32, + timeout_secs: u64, +) { + let kind = match source { + AudioSource::Input => ChannelKind::Mic, + AudioSource::Output => ChannelKind::Speaker, + _ => unreachable!(), + }; + + let mut audio_input = open_audio(&source); + print_audio_info(&audio_input, &source, sample_rate); + + let audio_stream = create_audio_stream(&mut audio_input, sample_rate); + let (response_stream, handle) = client + .from_realtime_audio(audio_stream) + .await + .expect("failed to connect"); + + process_stream( + response_stream, + handle, + timeout_secs, + DisplayMode::Single(kind), + ) + .await; +} + +pub async fn run_dual_client( + source: AudioSource, + client: ListenClientDual, + sample_rate: u32, + timeout_secs: u64, +) { + print_dual_audio_info(&source, sample_rate); + + let audio_stream = create_dual_audio_stream(&source, sample_rate); + let (response_stream, handle) = client + .from_realtime_audio(audio_stream) + .await + .expect("failed to connect"); + + process_stream(response_stream, handle, timeout_secs, DisplayMode::Dual).await; +} + +pub struct LocalServer { + addr: SocketAddr, + shutdown_tx: Option>, +} + +impl LocalServer { + pub fn addr(&self) -> SocketAddr { + self.addr + } + + pub fn api_base(&self, suffix: &str) -> String { + format!("http://{}{}", self.addr, suffix) + } +} + +impl Drop for LocalServer { + fn drop(&mut self) { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + } +} + +pub async fn spawn_router(app: Router) -> LocalServer { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + + tokio::spawn(async move { + axum::serve(listener, app) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + LocalServer { + addr, + shutdown_tx: Some(shutdown_tx), + } +} + +pub async fn process_stream( + response_stream: S, + handle: H, + timeout_secs: u64, + mode: DisplayMode, +) where + S: futures_util::Stream>, + H: FinalizeHandle, +{ + futures_util::pin_mut!(response_stream); + + let t0 = Instant::now(); + let mut channels: Vec<(Transcript, Option)> = match &mode { + DisplayMode::Single(kind) => vec![(Transcript::new(t0, *kind), None)], + DisplayMode::Dual => vec![ + (Transcript::new(t0, ChannelKind::Mic), None), + (Transcript::new(t0, ChannelKind::Speaker), None), + ], + }; + + let read_loop = async { + while let Some(result) = response_stream.next().await { + match result { + Ok(StreamResponse::TranscriptResponse { + is_final, + channel, + channel_index, + .. + }) => { + let text = channel + .alternatives + .first() + .map(|a| a.transcript.as_str()) + .unwrap_or(""); + + let ch = match &mode { + DisplayMode::Single(_) => 0, + DisplayMode::Dual => { + channel_index.first().copied().unwrap_or(0).clamp(0, 1) as usize + } + }; + + let (transcript, last_confirmed) = &mut channels[ch]; + if is_final { + if last_confirmed.as_deref() == Some(text) { + continue; + } + *last_confirmed = Some(text.to_string()); + transcript.confirm(text); + } else { + transcript.set_partial(text); + } + } + Ok(StreamResponse::TerminalResponse { .. }) => break, + Ok(StreamResponse::ErrorResponse { error_message, .. }) => { + eprintln!("\nerror: {}", error_message); + break; + } + Ok(_) => {} + Err(e) => { + eprintln!("\nws error: {:?}", e); + break; + } + } + } + }; + + let _ = tokio::time::timeout(Duration::from_secs(timeout_secs), read_loop).await; + handle.finalize().await; + eprintln!(); +} + +#[macro_export] +macro_rules! simple_provider_example { + ( + adapter: $adapter:path, + api_base: $api_base:expr, + api_key_env: $api_key_env:literal, + params: $params:expr $(,)? + ) => { + #[derive(::clap::Parser)] + struct Args { + #[command(flatten)] + audio: $crate::AudioArgs, + } + + #[::tokio::main] + async fn main() { + let args = ::parse(); + if args.audio.audio.is_dual() { + let client = $crate::build_dual_client::<$adapter>( + $api_base, + Some(::std::env::var($api_key_env).expect(concat!($api_key_env, " not set"))), + $params, + ) + .await; + + $crate::run_dual_client( + args.audio.audio, + client, + $crate::DEFAULT_SAMPLE_RATE, + $crate::DEFAULT_TIMEOUT_SECS, + ) + .await; + } else { + let client = $crate::build_single_client::<$adapter>( + $api_base, + Some(::std::env::var($api_key_env).expect(concat!($api_key_env, " not set"))), + $params, + ) + .await; + + $crate::run_single_client( + args.audio.audio, + client, + $crate::DEFAULT_SAMPLE_RATE, + $crate::DEFAULT_TIMEOUT_SECS, + ) + .await; + } + } + }; +} + +fn capture_frame_to_bytes( + source: &AudioSource, + frame: CaptureFrame, +) -> (bytes::Bytes, bytes::Bytes) { + let (mic, speaker) = match source { + AudioSource::RawDual => frame.raw_dual(), + AudioSource::AecDual => frame.aec_dual(), + AudioSource::Input | AudioSource::Output => unreachable!(), + }; + + ( + f32_to_i16_bytes(mic.iter().copied()), + f32_to_i16_bytes(speaker.iter().copied()), + ) +} + +fn fmt_ts(secs: f64) -> String { + let m = (secs / 60.0) as u32; + let s = secs % 60.0; + format!("{:02}:{:02}", m, s as u32) +} + +struct Segment { + time: f64, + text: String, +} + +struct Transcript { + segments: Vec, + partial: String, + t0: Instant, + kind: ChannelKind, +} + +impl Transcript { + fn new(t0: Instant, kind: ChannelKind) -> Self { + Self { + segments: Vec::new(), + partial: String::new(), + t0, + kind, + } + } + + fn elapsed(&self) -> f64 { + self.t0.elapsed().as_secs_f64() + } + + fn set_partial(&mut self, text: &str) { + self.partial = text.to_string(); + self.render(); + } + + fn confirm(&mut self, text: &str) { + self.segments.push(Segment { + time: self.elapsed(), + text: text.to_string(), + }); + self.partial.clear(); + self.trim(); + self.render(); + } + + fn trim(&mut self) { + const OVERHEAD: usize = 70; + let max_chars = crossterm::terminal::size() + .map(|(cols, _)| (cols as usize).saturating_sub(OVERHEAD)) + .unwrap_or(120); + + let partial_len = if self.partial.is_empty() { + 0 + } else { + self.partial.len() + 1 + }; + let total_len: usize = self + .segments + .iter() + .map(|s| s.text.len() + 1) + .sum::() + + partial_len; + if total_len > max_chars { + let drain_count = self.segments.len() * 2 / 3; + if drain_count > 0 { + self.segments.drain(..drain_count); + } + } + } + + fn render(&self) { + let confirmed: String = self + .segments + .iter() + .map(|s| s.text.as_str()) + .collect::>() + .join(" "); + + if confirmed.is_empty() && self.partial.is_empty() { + return; + } + + let to = self.elapsed(); + let from = self.segments.first().map(|s| fmt_ts(s.time)); + let prefix = format!("[{} / {}]", from.as_deref().unwrap_or("--:--"), fmt_ts(to)).dimmed(); + + let colored_confirmed = match self.kind { + ChannelKind::Mic => confirmed.truecolor(255, 190, 190).bold(), + ChannelKind::Speaker => confirmed.truecolor(190, 200, 255).bold(), + }; + + let colored_partial = if self.partial.is_empty() { + None + } else { + Some(match self.kind { + ChannelKind::Mic => self.partial.truecolor(128, 95, 95), + ChannelKind::Speaker => self.partial.truecolor(95, 100, 128), + }) + }; + + if let Some(partial) = colored_partial { + eprintln!("{} {} {}", prefix, colored_confirmed, partial); + } else { + eprintln!("{} {}", prefix, colored_confirmed); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn audio_source_reports_dual_modes() { + assert!(!AudioSource::Input.is_dual()); + assert!(!AudioSource::Output.is_dual()); + assert!(AudioSource::RawDual.is_dual()); + assert!(AudioSource::AecDual.is_dual()); + } + + #[test] + fn capture_frame_to_bytes_preserves_channel_order() { + let frame = CaptureFrame { + raw_mic: std::sync::Arc::from([0.25_f32, -0.25]), + raw_speaker: std::sync::Arc::from([0.75_f32, -0.75]), + aec_mic: Some(std::sync::Arc::from([0.1_f32, -0.1])), + }; + + let (raw_mic, raw_speaker) = capture_frame_to_bytes(&AudioSource::RawDual, frame.clone()); + assert_eq!(&raw_mic[..], &[0x00, 0x20, 0x00, 0xe0]); + assert_eq!(&raw_speaker[..], &[0x00, 0x60, 0x00, 0xa0]); + + let (aec_mic, aec_speaker) = capture_frame_to_bytes(&AudioSource::AecDual, frame); + assert_eq!(&aec_mic[..], &[0xcc, 0x0c, 0x34, 0xf3]); + assert_eq!(&aec_speaker[..], &[0x00, 0x60, 0x00, 0xa0]); + } +}