Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/cactus/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = "2024"
license = "MIT"

[dependencies]
cactus-sys = { git = "https://github.com/cactus-compute/cactus", package = "cactus-sys", rev = "f8b714cc9782eafd4895d15fd66e9ffe141810bc" }
cactus-sys = { git = "https://github.com/cactus-compute/cactus", package = "cactus-sys", rev = "a5acad3" }

hypr-language = { workspace = true }
hypr-llm-types = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion crates/cactus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use hypr_language::Language;
pub use llm::{CompleteOptions, CompletionResult, CompletionStream, Message, complete_stream};
pub use model::{Model, ModelBuilder, ModelKind};
pub use stt::{
CloudConfig, StreamResult, TranscribeEvent, TranscribeOptions, Transcriber,
CloudConfig, StreamResult, StreamSegment, TranscribeEvent, TranscribeOptions, Transcriber,
TranscriptionResult, TranscriptionSession, constrain_to, transcribe_stream,
};
pub use vad::{VadOptions, VadResult, VadSegment};
Expand Down
4 changes: 1 addition & 3 deletions crates/cactus/src/stt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod whisper;

pub use result::TranscriptionResult;
pub use stream::{TranscribeEvent, TranscriptionSession, transcribe_stream};
pub use transcriber::{CloudConfig, StreamResult, Transcriber};
pub use transcriber::{CloudConfig, StreamResult, StreamSegment, Transcriber};

use hypr_language::Language;

Expand Down Expand Up @@ -36,8 +36,6 @@ pub struct TranscribeOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub min_chunk_size: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub confirmation_threshold: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub custom_vocabulary: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub vocabulary_boost: Option<f32>,
Expand Down
12 changes: 12 additions & 0 deletions crates/cactus/src/stt/transcriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,25 @@ pub struct Transcriber<'a> {
// SAFETY: FFI calls are serialized through Model's inference_lock.
unsafe impl Send for Transcriber<'_> {}

#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct StreamSegment {
#[serde(default)]
pub start: f32,
#[serde(default)]
pub end: f32,
#[serde(default)]
pub text: String,
}

#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct StreamResult {
#[serde(default)]
pub confirmed: String,
#[serde(default)]
pub pending: String,
#[serde(default)]
pub segments: Vec<StreamSegment>,
#[serde(default)]
pub language: Option<String>,
#[serde(default)]
pub cloud_handoff: bool,
Expand Down
85 changes: 84 additions & 1 deletion crates/cactus/tests/stt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ fn test_stream_transcriber() {
let chunk_size = 32000; // 1 second at 16kHz 16-bit mono
let mut had_confirmed = false;

for chunk in pcm.chunks(chunk_size).take(10) {
for chunk in pcm.chunks(chunk_size).take(30) {
let r = transcriber.process(chunk).unwrap();
if !r.confirmed.is_empty() {
had_confirmed = true;
Expand All @@ -118,6 +118,89 @@ fn test_stream_transcriber() {
assert!(had_confirmed, "expected at least one confirmed segment");
}

// cargo test -p cactus --test stt test_stream_transcriber_segments -- --ignored --nocapture
#[ignore]
#[test]
fn test_stream_transcriber_segments() {
let model = stt_model();
let pcm = hypr_data::english_1::AUDIO;
let options = en_options();

let mut transcriber = Transcriber::new(&model, &options, CloudConfig::default()).unwrap();

let mut first_segmented = None;
let mut first_confirmed_with_segments = None;

for chunk in pcm.chunks(32000).take(30) {
let r = transcriber.process(chunk).unwrap();

if first_segmented.is_none() && !r.segments.is_empty() {
first_segmented = Some(r.clone());
}

if first_confirmed_with_segments.is_none()
&& !r.confirmed.trim().is_empty()
&& !r.segments.is_empty()
{
first_confirmed_with_segments = Some(r.clone());
break;
}
}

let segmented = first_segmented.expect("expected at least one streaming result with segments");
for segment in &segmented.segments {
assert!(
segment.end >= segment.start,
"segment end ({}) should be >= start ({})",
segment.end,
segment.start
);
assert!(!segment.text.is_empty(), "segment text should not be empty");
}

let confirmed =
first_confirmed_with_segments.expect("expected a confirmed streaming result with segments");
assert!(
!confirmed.confirmed.trim().is_empty(),
"confirmed text should not be empty"
);
assert!(
confirmed.buffer_duration_ms > 0.0,
"buffer_duration_ms should be positive, got {}",
confirmed.buffer_duration_ms
);

let first_segment = confirmed.segments.first().unwrap();
assert!(
first_segment.start >= 0.0,
"first segment start should be non-negative, got {}",
first_segment.start
);

let last_segment = confirmed.segments.last().unwrap();
let last_segment_end = last_segment.end as f64;
assert!(
last_segment_end > 0.0,
"expected positive segment end timestamp, got {last_segment_end}"
);
assert!(
(last_segment_end - confirmed.buffer_duration_ms / 1000.0).abs() < 0.25,
"last segment end ({last_segment_end}) should roughly match buffer duration ({}s)",
confirmed.buffer_duration_ms / 1000.0
);

for pair in confirmed.segments.windows(2) {
assert!(
pair[1].start >= pair[0].start,
"segments should be ordered: {} came after {}",
pair[0].start,
pair[1].start
);
}

let _ = transcriber.stop().unwrap();
}

// cargo test -p cactus --test stt test_stream_transcriber_drop -- --ignored --nocapture
#[ignore]
#[test]
Expand Down
22 changes: 19 additions & 3 deletions crates/transcribe-cactus/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ name = "transcribe-cactus"
version = "0.1.0"
edition = "2024"

[features]
live-example = []

[dependencies]
hypr-audio-utils = { workspace = true }
hypr-cactus = { workspace = true }
Expand All @@ -13,6 +16,7 @@ owhisper-interface = { workspace = true }
rodio = { workspace = true }

bytes = { workspace = true }
colored = "3"
serde_html_form = { workspace = true }
serde_json = { workspace = true }
tempfile = { workspace = true }
Expand All @@ -26,15 +30,27 @@ tower = { workspace = true }
tracing = { workspace = true }

[dev-dependencies]
dirs = { workspace = true }
hypr-audio = { workspace = true }
hypr-audio-utils = { workspace = true }
hypr-cactus = { workspace = true }
hypr-data = { workspace = true }

owhisper-client = { workspace = true }
owhisper-interface = { workspace = true }

axum = { workspace = true, features = ["ws"] }
futures-util = { workspace = true }
reqwest = { workspace = true, features = ["json"] }
sequential-test = "0.2"
serde_json = { workspace = true }
tokio = { workspace = true }
tokio-tungstenite = { workspace = true }

clap = { workspace = true, features = ["derive"] }
colored = "3"
dirs = { workspace = true }

sequential-test = "0.2"
serde_json = { workspace = true }

[[example]]
name = "live"
required-features = ["live-example"]
Loading
Loading