diff --git a/crates/transcribe-proxy/src/query_params.rs b/crates/transcribe-proxy/src/query_params.rs index 76b472a675..085c2aa0a2 100644 --- a/crates/transcribe-proxy/src/query_params.rs +++ b/crates/transcribe-proxy/src/query_params.rs @@ -77,6 +77,18 @@ impl QueryParams { }) .unwrap_or_default() } + + pub fn parse_keywords(&self) -> Vec { + self.get("keyword") + .or_else(|| self.get("keywords")) + .map(|v| { + v.iter() + .flat_map(|s| s.split(',')) + .map(|k| k.trim().to_string()) + .collect() + }) + .unwrap_or_default() + } } impl Deref for QueryParams { diff --git a/crates/transcribe-proxy/src/routes/batch.rs b/crates/transcribe-proxy/src/routes/batch.rs index 5dc2306b8c..4bb91cd591 100644 --- a/crates/transcribe-proxy/src/routes/batch.rs +++ b/crates/transcribe-proxy/src/routes/batch.rs @@ -19,7 +19,7 @@ use owhisper_interface::batch::Response as BatchResponse; use crate::hyprnote_routing::{RetryConfig, is_retryable_error, should_use_hyprnote_routing}; use crate::provider_selector::SelectedProvider; -use crate::query_params::{QueryParams, QueryValue}; +use crate::query_params::QueryParams; use super::AppState; @@ -204,22 +204,10 @@ async fn transcribe_with_retry( } fn build_listen_params(params: &QueryParams) -> ListenParams { - let model = params.get_first("model").map(|s| s.to_string()); - let languages = params.get_languages(); - - let keywords: Vec = params - .get("keyword") - .or_else(|| params.get("keywords")) - .map(|v| match v { - QueryValue::Single(s) => s.split(',').map(|k| k.trim().to_string()).collect(), - QueryValue::Multi(vec) => vec.iter().map(|k| k.trim().to_string()).collect(), - }) - .unwrap_or_default(); - ListenParams { - model, - languages, - keywords, + model: params.get_first("model").map(|s| s.to_string()), + languages: params.get_languages(), + keywords: params.parse_keywords(), ..Default::default() } } diff --git a/crates/transcribe-proxy/src/routes/callback.rs b/crates/transcribe-proxy/src/routes/callback.rs index b795491d73..199090bb9d 100644 --- a/crates/transcribe-proxy/src/routes/callback.rs +++ b/crates/transcribe-proxy/src/routes/callback.rs @@ -8,7 +8,7 @@ use serde::Deserialize; use hypr_supabase_storage::SupabaseStorage; use super::{AppState, RouteError, parse_async_provider}; -use crate::supabase::SupabaseClient; +use crate::supabase::{PipelineStatus, SupabaseClient}; #[derive(Deserialize)] pub(crate) struct CallbackQuery { @@ -69,12 +69,12 @@ pub async fn handler( let update = match &outcome { CallbackResult::Done(raw_result) => serde_json::json!({ - "status": "done", + "status": PipelineStatus::Done, "raw_result": raw_result, "updated_at": chrono::Utc::now().to_rfc3339(), }), CallbackResult::ProviderError(message) => serde_json::json!({ - "status": "error", + "status": PipelineStatus::Error, "error": message, "updated_at": chrono::Utc::now().to_rfc3339(), }), diff --git a/crates/transcribe-proxy/src/routes/start.rs b/crates/transcribe-proxy/src/routes/start.rs index 914d2edad6..c79c334957 100644 --- a/crates/transcribe-proxy/src/routes/start.rs +++ b/crates/transcribe-proxy/src/routes/start.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use hypr_supabase_storage::SupabaseStorage; use super::{AppState, RouteError, parse_async_provider}; -use crate::supabase::{SupabaseClient, TranscriptionJob}; +use crate::supabase::{PipelineStatus, SupabaseClient, TranscriptionJob}; #[derive(Deserialize, utoipa::ToSchema)] #[serde(rename_all = "camelCase")] @@ -110,7 +110,7 @@ pub async fn handler( user_id, file_id: body.file_id, provider: provider_str.to_string(), - status: "processing".to_string(), + status: PipelineStatus::Processing, provider_request_id: Some(provider_request_id), raw_result: None, error: None, diff --git a/crates/transcribe-proxy/src/routes/status.rs b/crates/transcribe-proxy/src/routes/status.rs index 36d8cfadd3..664b36dcb8 100644 --- a/crates/transcribe-proxy/src/routes/status.rs +++ b/crates/transcribe-proxy/src/routes/status.rs @@ -2,12 +2,12 @@ use axum::{Json, extract::Path}; use serde::Serialize; use super::RouteError; -use crate::supabase::SupabaseClient; +use crate::supabase::{PipelineStatus, SupabaseClient}; #[derive(Debug, Clone, Serialize, utoipa::ToSchema)] #[serde(rename_all = "camelCase")] pub struct SttStatusResponse { - pub status: String, + pub status: PipelineStatus, #[serde(skip_serializing_if = "Option::is_none")] pub provider: Option, #[serde(skip_serializing_if = "Option::is_none")] diff --git a/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs b/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs index a199878384..7bfb70043b 100644 --- a/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs +++ b/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs @@ -9,7 +9,7 @@ use owhisper_interface::ListenParams; use crate::config::SttProxyConfig; use crate::provider_selector::SelectedProvider; -use crate::query_params::{QueryParams, QueryValue}; +use crate::query_params::QueryParams; use crate::relay::WebSocketProxy; use crate::routes::AppState; @@ -18,26 +18,12 @@ use super::common::{ProxyBuildError, build_proxy_with_url, finalize_proxy_builde use super::session::init_session; fn build_listen_params(params: &QueryParams) -> ListenParams { - let model = params.get_first("model").map(|s| s.to_string()); - let languages = params.get_languages(); - let sample_rate: u32 = parse_param(params, "sample_rate", 16000); - let channels: u8 = parse_param(params, "channels", 1); - - let keywords: Vec = params - .get("keyword") - .or_else(|| params.get("keywords")) - .map(|v| match v { - QueryValue::Single(s) => s.split(',').map(|k| k.trim().to_string()).collect(), - QueryValue::Multi(vec) => vec.iter().map(|k| k.trim().to_string()).collect(), - }) - .unwrap_or_default(); - ListenParams { - model, - languages, - sample_rate, - channels, - keywords, + model: params.get_first("model").map(|s| s.to_string()), + languages: params.get_languages(), + sample_rate: parse_param(params, "sample_rate", 16000), + channels: parse_param(params, "channels", 1), + keywords: params.parse_keywords(), ..Default::default() } } diff --git a/crates/transcribe-proxy/src/supabase.rs b/crates/transcribe-proxy/src/supabase.rs index 433899de51..d848ff0afb 100644 --- a/crates/transcribe-proxy/src/supabase.rs +++ b/crates/transcribe-proxy/src/supabase.rs @@ -2,13 +2,21 @@ use serde::{Deserialize, Serialize}; type BoxError = Box; +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, utoipa::ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum PipelineStatus { + Processing, + Done, + Error, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TranscriptionJob { pub id: String, pub user_id: String, pub file_id: String, pub provider: String, - pub status: String, + pub status: PipelineStatus, #[serde(skip_serializing_if = "Option::is_none")] pub provider_request_id: Option, #[serde(skip_serializing_if = "Option::is_none")]