diff --git a/crates/cactus/src/stt/batch.rs b/crates/cactus/src/stt/batch.rs index 77f7340e39..37e1430481 100644 --- a/crates/cactus/src/stt/batch.rs +++ b/crates/cactus/src/stt/batch.rs @@ -35,9 +35,6 @@ unsafe extern "C" fn token_trampoline bool>( let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { let chunk = unsafe { CStr::from_ptr(token) }.to_string_lossy(); - if chunk.starts_with("<|") && chunk.ends_with("|>") { - return; - } let on_token = unsafe { &mut *state.on_token.get() }; if !on_token(&chunk) { state.stopped.set(true); diff --git a/crates/listener2-core/src/batch.rs b/crates/listener2-core/src/batch.rs index 4f952286ed..674c7c6fe0 100644 --- a/crates/listener2-core/src/batch.rs +++ b/crates/listener2-core/src/batch.rs @@ -393,7 +393,7 @@ async fn spawn_batch_task( AdapterKind::Hyprnote => { spawn_batch_task_with_adapter::(args, myself).await } - AdapterKind::Cactus => spawn_batch_task_with_adapter::(args, myself).await, + AdapterKind::Cactus => spawn_cactus_batch_task(args, myself).await, } } @@ -509,6 +509,119 @@ async fn spawn_argmax_streaming_batch_task( Ok((rx_task, shutdown_tx)) } +async fn spawn_cactus_batch_task( + args: BatchArgs, + myself: ActorRef, +) -> Result< + ( + tokio::task::JoinHandle<()>, + tokio::sync::oneshot::Sender<()>, + ), + ActorProcessingErr, +> { + let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + + let rx_task = tokio::spawn(async move { + tracing::info!("cactus batch task: starting direct HTTP batch"); + let start_notifier = args.start_notifier.clone(); + + let stream_result = CactusAdapter::transcribe_file_streaming( + &args.base_url, + &args.api_key, + &args.listen_params, + &args.file_path, + ) + .await; + + let mut stream = match stream_result { + Ok(s) => { + notify_start_result(&start_notifier, Ok(())); + s + } + Err(e) => { + let raw_error = format!("{:?}", e); + let error = format_user_friendly_error(&raw_error); + tracing::error!("cactus batch task: failed to start: {:?}", e); + notify_start_result(&start_notifier, Err(error.clone())); + let _ = myself.send_message(BatchMsg::StreamStartFailed(error)); + return; + } + }; + + let response_timeout = Duration::from_secs(BATCH_STREAM_TIMEOUT_SECS); + let mut response_count = 0; + let mut ended_cleanly = false; + + loop { + tokio::select! { + _ = &mut shutdown_rx => { + tracing::info!("cactus batch task: shutdown"); + ended_cleanly = true; + break; + } + result = tokio::time::timeout(response_timeout, StreamExt::next(&mut stream)) => { + match result { + Ok(Some(Ok(event))) => { + response_count += 1; + + let is_from_finalize = matches!( + &event.response, + StreamResponse::TranscriptResponse { from_finalize, .. } if *from_finalize + ); + + tracing::info!( + "cactus batch: response #{}{}", + response_count, + if is_from_finalize { " (from_finalize)" } else { "" } + ); + + if let Err(e) = myself.send_message(BatchMsg::StreamResponse { + response: Box::new(event.response), + percentage: event.percentage, + }) { + tracing::error!("failed to send cactus batch response: {:?}", e); + } + + if is_from_finalize { + ended_cleanly = true; + break; + } + } + Ok(Some(Err(e))) => { + let raw_error = format!("{:?}", e); + let error = format_user_friendly_error(&raw_error); + tracing::error!("cactus batch error: {:?}", e); + let _ = myself.send_message(BatchMsg::StreamError(error)); + break; + } + Ok(None) => { + tracing::info!("cactus batch completed (total: {})", response_count); + ended_cleanly = true; + break; + } + Err(elapsed) => { + tracing::warn!(timeout = ?elapsed, "cactus batch timeout"); + let _ = myself.send_message(BatchMsg::StreamError( + format_user_friendly_error("timeout waiting for response"), + )); + break; + } + } + } + } + } + + if ended_cleanly { + if let Err(e) = myself.send_message(BatchMsg::StreamEnded) { + tracing::error!("failed to send cactus batch ended message: {:?}", e); + } + } + tracing::info!("cactus batch task exited"); + }); + + Ok((rx_task, shutdown_tx)) +} + async fn spawn_batch_task_with_adapter( args: BatchArgs, myself: ActorRef, diff --git a/crates/owhisper-client/src/adapter/argmax/batch.rs b/crates/owhisper-client/src/adapter/argmax/batch.rs index 566ea1056d..f962317736 100644 --- a/crates/owhisper-client/src/adapter/argmax/batch.rs +++ b/crates/owhisper-client/src/adapter/argmax/batch.rs @@ -1,8 +1,7 @@ use std::path::{Path, PathBuf}; -use std::pin::Pin; use std::time::Duration; -use futures_util::{Stream, StreamExt}; +use futures_util::StreamExt; use hypr_audio_utils::{Source, f32_to_i16_bytes, resample_audio, source_from_path}; use owhisper_interface::batch::Response as BatchResponse; use owhisper_interface::stream::StreamResponse; @@ -11,7 +10,9 @@ use tokio_stream::StreamExt as TokioStreamExt; use crate::ListenClientBuilder; use crate::adapter::deepgram_compat::build_batch_url; -use crate::adapter::{BatchFuture, BatchSttAdapter, ClientWithMiddleware}; +use crate::adapter::{ + BatchFuture, BatchSttAdapter, ClientWithMiddleware, StreamingBatchEvent, StreamingBatchStream, +}; use crate::error::Error; use super::{ArgmaxAdapter, keywords::ArgmaxKeywordStrategy, language::ArgmaxLanguageStrategy}; @@ -151,15 +152,6 @@ impl StreamingBatchConfig { } } -#[derive(Debug, Clone)] -pub struct StreamingBatchEvent { - pub response: StreamResponse, - pub percentage: f64, -} - -pub type StreamingBatchStream = - Pin> + Send>>; - impl ArgmaxAdapter { pub async fn transcribe_file_streaming>( api_base: &str, diff --git a/crates/owhisper-client/src/adapter/argmax/mod.rs b/crates/owhisper-client/src/adapter/argmax/mod.rs index 4790a7622b..422a30a74b 100644 --- a/crates/owhisper-client/src/adapter/argmax/mod.rs +++ b/crates/owhisper-client/src/adapter/argmax/mod.rs @@ -5,7 +5,7 @@ pub(crate) mod language; mod live; #[cfg(feature = "argmax")] -pub use batch::{StreamingBatchConfig, StreamingBatchEvent, StreamingBatchStream}; +pub use batch::StreamingBatchConfig; pub use language::PARAKEET_V3_LANGS; diff --git a/crates/owhisper-client/src/adapter/cactus/batch.rs b/crates/owhisper-client/src/adapter/cactus/batch.rs new file mode 100644 index 0000000000..50c0c0ffce --- /dev/null +++ b/crates/owhisper-client/src/adapter/cactus/batch.rs @@ -0,0 +1,243 @@ +use std::path::Path; + +use futures_util::StreamExt; +use owhisper_interface::{InferencePhase, InferenceProgress, ListenParams, batch, stream}; +use tokio_stream::wrappers::UnboundedReceiverStream; + +use serde_json::Value; + +use super::CactusAdapter; +use crate::adapter::deepgram_compat::listen_endpoint_url; +use crate::adapter::{StreamingBatchEvent, StreamingBatchStream, is_local_host}; +use crate::error::Error; + +impl CactusAdapter { + pub async fn transcribe_file_streaming>( + api_base: &str, + _api_key: &str, + params: &ListenParams, + file_path: P, + ) -> Result { + let path = file_path.as_ref().to_path_buf(); + + let (audio_bytes, content_type) = tokio::task::spawn_blocking(move || { + let bytes = std::fs::read(&path).map_err(|e| Error::AudioProcessing(e.to_string()))?; + let ct = crate::adapter::http::mime_type_from_extension(&path).to_string(); + Ok::<_, Error>((bytes::Bytes::from(bytes), ct)) + }) + .await??; + + let url = build_http_url(api_base, params); + + let response = reqwest::Client::new() + .post(url) + .header("Content-Type", content_type) + .header("Accept", "text/event-stream") + .body(audio_bytes) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + return Err(Error::UnexpectedStatus { status, body }); + } + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::>(); + + tokio::spawn(async move { + let mut byte_stream = response.bytes_stream(); + let mut buf = String::new(); + let mut current_event = String::new(); + let mut current_data = String::new(); + let mut accumulated = String::new(); + + loop { + match byte_stream.next().await { + None => break, + Some(Err(e)) => { + let _ = tx.send(Err(Error::AudioProcessing(e.to_string()))); + break; + } + Some(Ok(chunk)) => { + buf.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some(pos) = buf.find('\n') { + let line = buf[..pos].trim_end_matches('\r').to_string(); + buf = buf[pos + 1..].to_string(); + + if let Some(ev) = line.strip_prefix("event:") { + current_event = ev.trim().to_string(); + } else if let Some(data) = line.strip_prefix("data:") { + current_data = data.trim().to_string(); + } else if line.is_empty() && !current_event.is_empty() { + let event_type = std::mem::take(&mut current_event); + let data = std::mem::take(&mut current_data); + + match event_type.as_str() { + "progress" => { + let progress = parse_progress(&data); + if let Some(fragment) = progress.partial_text { + accumulated.push_str(&fragment); + } + let event = + in_progress_event(&accumulated, progress.percentage); + let _ = tx.send(Ok(event)); + } + "result" => { + match serde_json::from_str::(&data) { + Ok(r) => { + let _ = tx.send(Ok(batch_to_event(r))); + } + Err(e) => { + let _ = tx.send(Err(Error::AudioProcessing( + format!("result parse error: {e}"), + ))); + } + } + } + "error" => { + let _ = tx.send(Err(Error::AudioProcessing(data))); + } + _ => {} + } + } + } + } + } + } + }); + + Ok(Box::pin(UnboundedReceiverStream::new(rx))) + } +} + +fn build_http_url(api_base: &str, params: &ListenParams) -> url::Url { + let (mut url, existing_params) = listen_endpoint_url(api_base); + + let host = url.host_str().unwrap_or("localhost").to_string(); + let _ = url.set_scheme(if is_local_host(&host) { + "http" + } else { + "https" + }); + + { + let mut q = url.query_pairs_mut(); + for (k, v) in &existing_params { + q.append_pair(k, v); + } + q.append_pair("channels", ¶ms.channels.max(1).to_string()); + q.append_pair("sample_rate", ¶ms.sample_rate.to_string()); + if let Some(lang) = params.languages.first() { + q.append_pair("language", lang.iso639_code()); + } + for kw in ¶ms.keywords { + q.append_pair("keywords", kw); + } + } + + url +} + +fn parse_progress(data: &str) -> InferenceProgress { + if let Ok(p) = serde_json::from_str::(data) { + return p; + } + + // Backward-compat / best-effort parsing + if let Ok(v) = serde_json::from_str::(data) { + let percentage = v["percentage"].as_f64().unwrap_or(0.0); + let partial_text = v["partial_text"] + .as_str() + .or_else(|| v["token"].as_str()) + .map(|s| s.to_string()); + return InferenceProgress { + percentage, + partial_text, + phase: InferencePhase::Transcribing, + }; + } + + InferenceProgress { + percentage: 0.0, + partial_text: Some(data.to_string()), + phase: InferencePhase::Transcribing, + } +} + +fn in_progress_event(accumulated: &str, percentage: f64) -> StreamingBatchEvent { + StreamingBatchEvent { + response: stream::StreamResponse::TranscriptResponse { + start: 0.0, + duration: 0.0, + is_final: false, + speech_final: false, + from_finalize: false, + channel: stream::Channel { + alternatives: vec![stream::Alternatives { + transcript: accumulated.to_string(), + words: vec![], + confidence: 0.0, + languages: vec![], + }], + }, + metadata: stream::Metadata::default(), + channel_index: vec![0, 1], + }, + percentage, + } +} + +fn batch_to_event(response: batch::Response) -> StreamingBatchEvent { + let duration = response + .metadata + .get("duration") + .and_then(|v| v.as_f64()) + .unwrap_or(0.0); + + let (transcript, words, confidence) = response + .results + .channels + .into_iter() + .next() + .and_then(|c| c.alternatives.into_iter().next()) + .map(|a| { + let words = a + .words + .iter() + .map(|w| stream::Word { + word: w.word.clone(), + start: w.start, + end: w.end, + confidence: w.confidence, + speaker: w.speaker.map(|s| s as i32), + punctuated_word: w.punctuated_word.clone(), + language: None, + }) + .collect::>(); + (a.transcript, words, a.confidence) + }) + .unwrap_or_default(); + + StreamingBatchEvent { + response: stream::StreamResponse::TranscriptResponse { + start: 0.0, + duration, + is_final: true, + speech_final: true, + from_finalize: true, + channel: stream::Channel { + alternatives: vec![stream::Alternatives { + transcript, + words, + confidence, + languages: vec![], + }], + }, + metadata: stream::Metadata::default(), + channel_index: vec![0, 1], + }, + percentage: 1.0, + } +} diff --git a/crates/owhisper-client/src/adapter/cactus/mod.rs b/crates/owhisper-client/src/adapter/cactus/mod.rs index ede6536e64..f701a33336 100644 --- a/crates/owhisper-client/src/adapter/cactus/mod.rs +++ b/crates/owhisper-client/src/adapter/cactus/mod.rs @@ -1,3 +1,4 @@ +mod batch; mod live; #[derive(Clone, Default)] diff --git a/crates/owhisper-client/src/adapter/mod.rs b/crates/owhisper-client/src/adapter/mod.rs index 3eaf1cc694..cffe155095 100644 --- a/crates/owhisper-client/src/adapter/mod.rs +++ b/crates/owhisper-client/src/adapter/mod.rs @@ -41,12 +41,23 @@ use owhisper_interface::ListenParams; use owhisper_interface::batch::Response as BatchResponse; use owhisper_interface::stream::StreamResponse; +use futures_util::Stream; + use crate::error::Error; pub use reqwest_middleware::ClientWithMiddleware; pub type BatchFuture<'a> = Pin> + Send + 'a>>; +#[derive(Debug, Clone)] +pub struct StreamingBatchEvent { + pub response: StreamResponse, + pub percentage: f64, +} + +pub type StreamingBatchStream = + Pin> + Send>>; + pub fn documented_language_codes_live() -> Vec { let mut set: BTreeSet<&'static str> = BTreeSet::new(); diff --git a/crates/owhisper-client/src/lib.rs b/crates/owhisper-client/src/lib.rs index 34b8e0fba3..88c38b9e6e 100644 --- a/crates/owhisper-client/src/lib.rs +++ b/crates/owhisper-client/src/lib.rs @@ -15,6 +15,8 @@ pub use providers::{Auth, Provider, is_meta_model}; use std::marker::PhantomData; +#[cfg(feature = "argmax")] +pub use adapter::StreamingBatchConfig; pub use adapter::deepgram::DeepgramModel; pub use adapter::{ AdapterKind, ArgmaxAdapter, AssemblyAIAdapter, BatchSttAdapter, CactusAdapter, CallbackResult, @@ -24,8 +26,7 @@ pub use adapter::{ documented_language_codes_batch, documented_language_codes_live, is_hyprnote_proxy, is_local_host, normalize_languages, }; -#[cfg(feature = "argmax")] -pub use adapter::{StreamingBatchConfig, StreamingBatchEvent, StreamingBatchStream}; +pub use adapter::{StreamingBatchEvent, StreamingBatchStream}; pub use batch::{BatchClient, BatchClientBuilder}; pub use error::Error; diff --git a/crates/owhisper-interface/src/lib.rs b/crates/owhisper-interface/src/lib.rs index e95a7adadf..19bffb9ccb 100644 --- a/crates/owhisper-interface/src/lib.rs +++ b/crates/owhisper-interface/src/lib.rs @@ -1,11 +1,14 @@ pub mod batch; #[cfg(feature = "openapi")] pub mod openapi; +pub mod progress; pub mod stream; #[cfg(feature = "openapi")] pub use openapi::openapi; +pub use progress::{InferencePhase, InferenceProgress}; + #[macro_export] macro_rules! common_derives { ($item:item) => { diff --git a/crates/owhisper-interface/src/progress.rs b/crates/owhisper-interface/src/progress.rs new file mode 100644 index 0000000000..2ea4ea51b8 --- /dev/null +++ b/crates/owhisper-interface/src/progress.rs @@ -0,0 +1,20 @@ +crate::common_derives! { + #[serde(rename_all = "snake_case")] + pub enum InferencePhase { + Transcribing, + Prefill, + Decoding, + } +} + +crate::common_derives! { + pub struct InferenceProgress { + /// Fraction of work completed, in range 0.0..=1.0. + pub percentage: f64, + + /// Optional text fragment produced so far. + pub partial_text: Option, + + pub phase: InferencePhase, + } +} diff --git a/crates/transcribe-cactus/src/service/batch/mod.rs b/crates/transcribe-cactus/src/service/batch/mod.rs index ef5eb90771..d9f82f0eae 100644 --- a/crates/transcribe-cactus/src/service/batch/mod.rs +++ b/crates/transcribe-cactus/src/service/batch/mod.rs @@ -2,19 +2,38 @@ mod audio; mod response; mod transcribe; +use std::convert::Infallible; use std::path::Path; use axum::{ Json, http::StatusCode, - response::{IntoResponse, Response}, + response::{IntoResponse, Response, sse::Event, sse::Sse}, }; use bytes::Bytes; +use futures_util::stream; use owhisper_interface::ListenParams; +use tokio::sync::mpsc; + +use owhisper_interface::InferenceProgress; use transcribe::transcribe_batch; pub async fn handle_batch( + body: Bytes, + content_type: &str, + accept: &str, + params: &ListenParams, + model_path: &Path, +) -> Response { + if accept.contains("text/event-stream") { + handle_batch_sse(body, content_type, params, model_path).await + } else { + handle_batch_json(body, content_type, params, model_path).await + } +} + +async fn handle_batch_json( body: Bytes, content_type: &str, params: &ListenParams, @@ -25,7 +44,7 @@ pub async fn handle_batch( let params = params.clone(); let result = tokio::task::spawn_blocking(move || { - transcribe_batch(&body, &content_type, ¶ms, &model_path) + transcribe_batch(&body, &content_type, ¶ms, &model_path, None) }) .await; @@ -48,3 +67,91 @@ pub async fn handle_batch( } } } + +async fn handle_batch_sse( + body: Bytes, + content_type: &str, + params: &ListenParams, + model_path: &Path, +) -> Response { + let model_path = model_path.to_path_buf(); + let content_type = content_type.to_string(); + let params = params.clone(); + + let (progress_tx, progress_rx) = mpsc::unbounded_channel::(); + let (result_tx, result_rx) = mpsc::unbounded_channel::>(); + + tokio::task::spawn_blocking(move || { + let outcome = transcribe_batch( + &body, + &content_type, + ¶ms, + &model_path, + Some(progress_tx), + ); + match outcome { + Ok(response) => match serde_json::to_string(&response) { + Ok(json) => { + let _ = result_tx.send(Ok(json)); + } + Err(e) => { + let _ = result_tx.send(Err(e.to_string())); + } + }, + Err(e) => { + let _ = result_tx.send(Err(e.to_string())); + } + } + }); + + let sse_stream = stream::unfold( + (Some(progress_rx), result_rx), + |(mut progress_rx, mut result_rx)| async move { + if let Some(ref mut prx) = progress_rx { + tokio::select! { + biased; + result = result_rx.recv() => { + match result { + Some(Ok(json)) => { + let event = Event::default().event("result").data(json); + Some((Ok::(event), (progress_rx, result_rx))) + } + Some(Err(e)) => { + let event = Event::default().event("error").data(e); + Some((Ok(event), (progress_rx, result_rx))) + } + None => None, + } + } + progress = prx.recv() => { + match progress { + Some(p) => { + let json = serde_json::to_string(&p).unwrap_or_else(|_| "{}".to_string()); + let event = Event::default().event("progress").data(json); + Some((Ok(event), (progress_rx, result_rx))) + } + None => { + // Progress channel closed — wait for the final result. + match result_rx.recv().await { + Some(Ok(json)) => { + let event = Event::default().event("result").data(json); + Some((Ok(event), (None, result_rx))) + } + Some(Err(e)) => { + let event = Event::default().event("error").data(e); + Some((Ok(event), (None, result_rx))) + } + None => None, + } + } + } + } + } + } else { + None + } + }, + ); + + Sse::new(sse_stream).into_response() +} diff --git a/crates/transcribe-cactus/src/service/batch/transcribe.rs b/crates/transcribe-cactus/src/service/batch/transcribe.rs index 97e609e872..dac2dc49cd 100644 --- a/crates/transcribe-cactus/src/service/batch/transcribe.rs +++ b/crates/transcribe-cactus/src/service/batch/transcribe.rs @@ -3,16 +3,23 @@ use std::path::Path; use owhisper_interface::ListenParams; use owhisper_interface::batch; -use owhisper_interface::stream::{Extra, Metadata, ModelInfo}; +use owhisper_interface::progress::{InferencePhase, InferenceProgress}; +use tokio::sync::mpsc::UnboundedSender; use super::audio::{audio_duration_secs, content_type_to_extension}; use super::response::build_batch_words; +fn parse_timestamp_token(token: &str) -> Option { + let inner = token.strip_prefix("<|")?.strip_suffix("|>")?; + inner.parse::().ok() +} + pub(super) fn transcribe_batch( audio_data: &[u8], content_type: &str, params: &ListenParams, model_path: &Path, + progress_tx: Option>, ) -> Result { let extension = content_type_to_extension(content_type); let mut temp_file = tempfile::Builder::new() @@ -39,25 +46,49 @@ pub(super) fn transcribe_batch( let total_duration = audio_duration_secs(temp_file.path()); - let cactus_response = model.transcribe_file(temp_file.path(), &options)?; + let cactus_response = match progress_tx { + Some(tx) => { + let mut last_audio_pos = 0.0f64; + let mut last_percentage = 0.0f64; + let total_duration = total_duration.max(0.0); + + model.transcribe_file_with_callback(temp_file.path(), &options, move |token| { + if let Some(ts) = parse_timestamp_token(token) { + last_audio_pos = last_audio_pos.max(ts); + if total_duration > 0.0 { + let pct = (last_audio_pos / total_duration).clamp(0.0, 0.95); + last_percentage = last_percentage.max(pct); + } + + let _ = tx.send(InferenceProgress { + percentage: last_percentage, + partial_text: None, + phase: InferencePhase::Transcribing, + }); + return true; + } + + // Ignore other special tokens (language tags, control markers, etc.) + if token.starts_with("<|") && token.ends_with("|>") { + return true; + } + + let _ = tx.send(InferenceProgress { + percentage: last_percentage, + partial_text: Some(token.to_string()), + phase: InferencePhase::Transcribing, + }); + true + })? + } + None => model.transcribe_file(temp_file.path(), &options)?, + }; + let transcript = cactus_response.text.trim().to_string(); let confidence = cactus_response.confidence as f64; let words = build_batch_words(&transcript, total_duration, confidence); - let model_name = model_path - .file_stem() - .and_then(|s| s.to_str()) - .unwrap_or("cactus"); - - let meta = Metadata { - model_info: ModelInfo { - name: model_name.to_string(), - version: "1.0".to_string(), - arch: "cactus".to_string(), - }, - extra: Some(Extra::default().into()), - ..Default::default() - }; + let meta = crate::service::build_metadata(model_path); let mut metadata = serde_json::to_value(&meta).unwrap_or_default(); if let Some(obj) = metadata.as_object_mut() { @@ -108,7 +139,7 @@ mod tests { ..Default::default() }; - let response = transcribe_batch(&wav_bytes, "audio/wav", ¶ms, model_path) + let response = transcribe_batch(&wav_bytes, "audio/wav", ¶ms, model_path, None) .unwrap_or_else(|e| panic!("real-model batch transcription failed: {e}")); let Some(channel) = response.results.channels.first() else { diff --git a/crates/transcribe-cactus/src/service/mod.rs b/crates/transcribe-cactus/src/service/mod.rs index c924e10692..7d94282ce7 100644 --- a/crates/transcribe-cactus/src/service/mod.rs +++ b/crates/transcribe-cactus/src/service/mod.rs @@ -1,3 +1,26 @@ mod batch; mod streaming; + pub use streaming::*; + +use std::path::Path; + +use owhisper_interface::stream::{Extra, Metadata, ModelInfo}; + +pub(crate) fn build_metadata(model_path: &Path) -> Metadata { + let model_name = model_path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("cactus") + .to_string(); + + Metadata { + model_info: ModelInfo { + name: model_name, + version: "1.0".to_string(), + arch: "cactus".to_string(), + }, + extra: Some(Extra::default().into()), + ..Default::default() + } +} diff --git a/crates/transcribe-cactus/src/service/streaming/mod.rs b/crates/transcribe-cactus/src/service/streaming/mod.rs index e6131389de..ef154ae021 100644 --- a/crates/transcribe-cactus/src/service/streaming/mod.rs +++ b/crates/transcribe-cactus/src/service/streaming/mod.rs @@ -1,139 +1,6 @@ mod message; mod response; +mod service; mod session; -use std::{ - future::Future, - path::PathBuf, - pin::Pin, - task::{Context, Poll}, -}; - -use axum::{ - body::Body, - extract::{FromRequestParts, ws::WebSocketUpgrade}, - http::{Request, StatusCode}, - response::{IntoResponse, Response}, -}; -use tower::Service; - -use hypr_ws_utils::ConnectionManager; -use owhisper_interface::ListenParams; - -use super::batch; -use crate::CactusConfig; - -#[derive(Clone)] -pub struct TranscribeService { - model_path: PathBuf, - cactus_config: CactusConfig, - connection_manager: ConnectionManager, -} - -impl TranscribeService { - pub fn builder() -> TranscribeServiceBuilder { - TranscribeServiceBuilder::default() - } -} - -#[derive(Default)] -pub struct TranscribeServiceBuilder { - model_path: Option, - cactus_config: CactusConfig, - connection_manager: Option, -} - -impl TranscribeServiceBuilder { - pub fn model_path(mut self, model_path: PathBuf) -> Self { - self.model_path = Some(model_path); - self - } - - pub fn cactus_config(mut self, config: CactusConfig) -> Self { - self.cactus_config = config; - self - } - - pub fn build(self) -> TranscribeService { - TranscribeService { - model_path: self - .model_path - .expect("TranscribeServiceBuilder requires model_path"), - cactus_config: self.cactus_config, - connection_manager: self.connection_manager.unwrap_or_default(), - } - } -} - -impl Service> for TranscribeService { - type Response = Response; - type Error = String; - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - let model_path = self.model_path.clone(); - let cactus_config = self.cactus_config.clone(); - let connection_manager = self.connection_manager.clone(); - - Box::pin(async move { - let is_ws = req - .headers() - .get("upgrade") - .and_then(|v| v.to_str().ok()) - .map(|v| v.eq_ignore_ascii_case("websocket")) - .unwrap_or(false); - - let query_string = req.uri().query().unwrap_or("").to_string(); - let params: ListenParams = match serde_qs::from_str(&query_string) { - Ok(p) => p, - Err(e) => { - return Ok((StatusCode::BAD_REQUEST, e.to_string()).into_response()); - } - }; - - if is_ws { - let (mut parts, _body) = req.into_parts(); - let ws_upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await { - Ok(ws) => ws, - Err(e) => { - return Ok((StatusCode::BAD_REQUEST, e.to_string()).into_response()); - } - }; - - let guard = connection_manager.acquire_connection(); - - Ok(ws_upgrade - .on_upgrade(move |socket| async move { - session::handle_websocket(socket, params, model_path, cactus_config, guard) - .await; - }) - .into_response()) - } else { - let content_type = req - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or("application/octet-stream") - .to_string(); - - let body_bytes = - match axum::body::to_bytes(req.into_body(), 100 * 1024 * 1024).await { - Ok(b) => b, - Err(e) => { - return Ok((StatusCode::BAD_REQUEST, e.to_string()).into_response()); - } - }; - - if body_bytes.is_empty() { - return Ok((StatusCode::BAD_REQUEST, "request body is empty").into_response()); - } - - Ok(batch::handle_batch(body_bytes, &content_type, ¶ms, &model_path).await) - } - }) - } -} +pub use service::{TranscribeService, TranscribeServiceBuilder}; diff --git a/crates/transcribe-cactus/src/service/streaming/response.rs b/crates/transcribe-cactus/src/service/streaming/response.rs index 56709270c1..b5aad114b3 100644 --- a/crates/transcribe-cactus/src/service/streaming/response.rs +++ b/crates/transcribe-cactus/src/service/streaming/response.rs @@ -2,9 +2,7 @@ use std::path::Path; use axum::extract::ws::{Message, WebSocket}; use futures_util::{SinkExt, stream::SplitSink}; -use owhisper_interface::stream::{ - Alternatives, Channel, Extra, Metadata, ModelInfo, StreamResponse, Word, -}; +use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse, Word}; pub(super) type WsSender = SplitSink; @@ -33,21 +31,7 @@ pub(super) async fn send_ws_best_effort(sender: &mut WsSender, value: &StreamRes } pub(super) fn build_session_metadata(model_path: &Path) -> Metadata { - let model_name = model_path - .file_stem() - .and_then(|s| s.to_str()) - .unwrap_or("cactus") - .to_string(); - - Metadata { - model_info: ModelInfo { - name: model_name, - version: "1.0".to_string(), - arch: "cactus".to_string(), - }, - extra: Some(Extra::default().into()), - ..Default::default() - } + crate::service::build_metadata(model_path) } pub(super) fn build_transcript_response( diff --git a/crates/transcribe-cactus/src/service/streaming/service.rs b/crates/transcribe-cactus/src/service/streaming/service.rs new file mode 100644 index 0000000000..72aeb8ef10 --- /dev/null +++ b/crates/transcribe-cactus/src/service/streaming/service.rs @@ -0,0 +1,146 @@ +use std::{ + future::Future, + path::PathBuf, + pin::Pin, + task::{Context, Poll}, +}; + +use axum::{ + body::Body, + extract::{FromRequestParts, ws::WebSocketUpgrade}, + http::{Request, StatusCode}, + response::{IntoResponse, Response}, +}; +use tower::Service; + +use hypr_ws_utils::ConnectionManager; +use owhisper_interface::ListenParams; + +use super::session; +use crate::CactusConfig; +use crate::service::batch; + +#[derive(Clone)] +pub struct TranscribeService { + pub(super) model_path: PathBuf, + pub(super) cactus_config: CactusConfig, + pub(super) connection_manager: ConnectionManager, +} + +impl TranscribeService { + pub fn builder() -> TranscribeServiceBuilder { + TranscribeServiceBuilder::default() + } +} + +#[derive(Default)] +pub struct TranscribeServiceBuilder { + model_path: Option, + cactus_config: CactusConfig, + connection_manager: Option, +} + +impl TranscribeServiceBuilder { + pub fn model_path(mut self, model_path: PathBuf) -> Self { + self.model_path = Some(model_path); + self + } + + pub fn cactus_config(mut self, config: CactusConfig) -> Self { + self.cactus_config = config; + self + } + + pub fn build(self) -> TranscribeService { + TranscribeService { + model_path: self + .model_path + .expect("TranscribeServiceBuilder requires model_path"), + cactus_config: self.cactus_config, + connection_manager: self.connection_manager.unwrap_or_default(), + } + } +} + +impl Service> for TranscribeService { + type Response = Response; + type Error = String; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let model_path = self.model_path.clone(); + let cactus_config = self.cactus_config.clone(); + let connection_manager = self.connection_manager.clone(); + + Box::pin(async move { + let is_ws = req + .headers() + .get("upgrade") + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + + let query_string = req.uri().query().unwrap_or("").to_string(); + let params: ListenParams = match serde_qs::from_str(&query_string) { + Ok(p) => p, + Err(e) => { + return Ok((StatusCode::BAD_REQUEST, e.to_string()).into_response()); + } + }; + + if is_ws { + let (mut parts, _body) = req.into_parts(); + let ws_upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await { + Ok(ws) => ws, + Err(e) => { + return Ok((StatusCode::BAD_REQUEST, e.to_string()).into_response()); + } + }; + + let guard = connection_manager.acquire_connection(); + + Ok(ws_upgrade + .on_upgrade(move |socket| async move { + session::handle_websocket(socket, params, model_path, cactus_config, guard) + .await; + }) + .into_response()) + } else { + let accept = req + .headers() + .get("accept") + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + + let content_type = req + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or("application/octet-stream") + .to_string(); + + let body_bytes = + match axum::body::to_bytes(req.into_body(), 100 * 1024 * 1024).await { + Ok(b) => b, + Err(e) => { + return Ok((StatusCode::BAD_REQUEST, e.to_string()).into_response()); + } + }; + + if body_bytes.is_empty() { + return Ok((StatusCode::BAD_REQUEST, "request body is empty").into_response()); + } + + Ok( + batch::handle_batch(body_bytes, &content_type, &accept, ¶ms, &model_path) + .await, + ) + } + }) + } +} diff --git a/crates/transcribe-cactus/tests/common/mod.rs b/crates/transcribe-cactus/tests/common/mod.rs index 608b2594f3..b03bce08a1 100644 --- a/crates/transcribe-cactus/tests/common/mod.rs +++ b/crates/transcribe-cactus/tests/common/mod.rs @@ -1,8 +1,12 @@ use std::path::PathBuf; pub fn model_path() -> PathBuf { - let path = std::env::var("CACTUS_STT_MODEL") - .unwrap_or_else(|_| "/tmp/cactus-model/moonshine-base-cactus".to_string()); + let home = std::env::var("HOME").unwrap_or_default(); + let default = format!( + "{}/Library/Application Support/com.hyprnote.dev/models/cactus/whisper-small-int8-apple", + home + ); + let path = std::env::var("CACTUS_STT_MODEL").unwrap_or(default); let path = PathBuf::from(path); assert!(path.exists(), "model not found: {}", path.display()); path