6 changes: 6 additions & 0 deletions apps/whisper-api/README.md
@@ -90,6 +90,12 @@ export RUST_LOG=debug

# Force CPU usage
export CANDLE_FORCE_CPU=1

# Disable VAD (Voice Activity Detection) - process entire audio directly
export DISABLE_VAD=true

# Set VAD threshold (0.0-1.0, default: 0.15)
export VAD_THRESHOLD=0.2
```

## Acknowledgements
17 changes: 14 additions & 3 deletions apps/whisper-api/src/main.rs
@@ -29,6 +29,7 @@ mod whisper;
struct AppState {
vad: Arc<Mutex<VADProcessor>>,
device: Device,
vad_enabled: bool,
// Use RwLock for read-heavy workload (checking cache)
whisper_models: Arc<RwLock<HashMap<String, Arc<Mutex<WhisperProcessor>>>>>,
}
@@ -49,20 +50,30 @@ impl AppState {

println!("🚀 Using device: {device:?}");

// Get VAD threshold from environment or use default
// Check if VAD is enabled
let vad_enabled = std::env::var("DISABLE_VAD")
.map(|s| s.to_lowercase() != "true" && s != "1")
.unwrap_or(true);
Comment on lines +54 to +56

medium

The logic to determine if VAD is enabled is a bit hard to follow due to the negative condition (!=) combined with the DISABLE_VAD environment variable. It can be made more readable by checking for the "disabled" case first and then negating the result.

Suggested change
let vad_enabled = std::env::var("DISABLE_VAD")
.map(|s| s.to_lowercase() != "true" && s != "1")
.unwrap_or(true);
let vad_enabled = !std::env::var("DISABLE_VAD")
.map(|s| matches!(s.to_lowercase().as_str(), "true" | "1"))
.unwrap_or(false);


println!("🎯 VAD enabled: {vad_enabled}");

// Get VAD threshold from environment or use default (lowered for better detection)
let vad_threshold = std::env::var("VAD_THRESHOLD")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(0.3);
.unwrap_or(0.15);

println!("🎯 VAD threshold: {vad_threshold}");
if vad_enabled {
println!("🎯 VAD threshold: {vad_threshold}");
}

// Initialize VAD processor (always use CPU for VAD)
let vad = VADProcessor::new(candle_core::Device::Cpu, vad_threshold)?;

Ok(Self {
vad: Arc::new(Mutex::new(vad)),
device,
vad_enabled,
whisper_models: Arc::new(RwLock::new(HashMap::new())),
})
}
183 changes: 127 additions & 56 deletions apps/whisper-api/src/router.rs
@@ -1,5 +1,6 @@
use std::{
collections::HashMap,
pin::Pin,
sync::Arc,
time::{Duration, Instant},
};
@@ -15,7 +16,7 @@ use axum::{
sse::{Event, KeepAlive, Sse},
},
};
use futures::stream::{self, Stream};
use futures::stream::{self, Stream, StreamExt};
use symphonia::{
core::{
audio::{AudioBufferRef, Signal},
@@ -266,16 +267,31 @@ async fn transcribe_audio_complete(
let whisper_processor = state.get_whisper_processor(&model_name).await?;
processing_stats.model_loading_duration = model_loading_start.elapsed();

// Check if VAD is enabled
if !state.vad_enabled {
println!("🔇 VAD disabled, processing entire audio directly");
let whisper_start = Instant::now();
let mut whisper = whisper_processor.lock().await;
let transcript = whisper.transcribe(&audio_data)?;
processing_stats.whisper_transcription_duration = whisper_start.elapsed();
processing_stats.vad_processing_duration = Duration::ZERO;

return Ok(transcript);
}

// Process audio through VAD and Whisper
let mut vad = state.vad.lock().await;
let mut whisper = whisper_processor.lock().await;
let mut audio_buffer = AudioBuffer::new(10000, 100, 500, sample_rate);
// Use more lenient parameters: max_duration=10s, min_speech=50ms, min_silence=300ms
let mut audio_buffer = AudioBuffer::new(10000, 50, 300, sample_rate);

let mut transcripts = Vec::new();
let mut frame_buffer = Vec::<f32>::new();

let vad_start = Instant::now();
let mut whisper_total_time = Duration::ZERO;
let mut speech_frame_count = 0;
let mut total_frame_count = 0;

// Process in chunks
for chunk in audio_data.chunks(1024) {
@@ -287,6 +303,11 @@
let speech_prob = vad.process_chunk(&frame)?;
let is_speech = vad.is_speech(speech_prob);

total_frame_count += 1;
if is_speech {
speech_frame_count += 1;
}

if let Some(complete_audio) = audio_buffer.add_chunk(&frame, is_speech) {
// Measure Whisper transcription time
let whisper_start = Instant::now();
@@ -300,9 +321,24 @@
}
}

println!("🎯 VAD Stats: {speech_frame_count}/{total_frame_count} frames detected as speech ({:.1}%)", speech_frame_count as f32 / total_frame_count as f32 * 100.0);

medium

If audio_data is empty, total_frame_count will be 0, leading to a division by zero, which results in NaN for floating-point numbers. While this doesn't crash the program, it will print "NaN%", which is not ideal. It would be better to guard this print statement to avoid this.

Suggested change
println!("🎯 VAD Stats: {speech_frame_count}/{total_frame_count} frames detected as speech ({:.1}%)", speech_frame_count as f32 / total_frame_count as f32 * 100.0);
if total_frame_count > 0 {
println!("🎯 VAD Stats: {speech_frame_count}/{total_frame_count} frames detected as speech ({:.1}%)", speech_frame_count as f32 / total_frame_count as f32 * 100.0);
}


processing_stats.vad_processing_duration = vad_start.elapsed() - whisper_total_time;
processing_stats.whisper_transcription_duration = whisper_total_time;

// Fallback: If no segments were detected by VAD, process the entire audio
if transcripts.is_empty() && !audio_data.is_empty() {
println!("⚠️ No VAD segments detected, processing entire audio as fallback");
let whisper_start = Instant::now();
let transcript = whisper.transcribe(&audio_data)?;
let whisper_fallback_time = whisper_start.elapsed();
processing_stats.whisper_transcription_duration += whisper_fallback_time;

if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") {
transcripts.push(transcript.trim().to_string());
}
}
Comment on lines +329 to +340

high

The streaming function create_transcription_stream is missing a similar fallback. If no speech segments are detected during streaming, the stream will just complete without sending any transcription, which can be confusing for the user. This is an inconsistency between the two transcription modes.

Please consider adding a similar fallback mechanism to the streaming implementation. For example, you could check at the end of the stream if any segments were transcribed, and if not, transcribe the whole audio and send it as a final event.
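
A minimal sketch of how that fallback could slot into the end-of-audio branch of the unfold closure, assuming the state tuple is extended with a `segments_emitted` counter (incremented wherever a real transcript event is produced) and a `finished` flag; both names are illustrative and not part of this PR:

```rust
// Hypothetical sketch: the unfold state gains `segments_emitted: usize` and `finished: bool`.
if processed >= audio_data.len() {
    // Nothing was transcribed via VAD segmentation, so fall back to the whole buffer once.
    if !finished && segments_emitted == 0 && !audio_data.is_empty() {
        println!("⚠️ No VAD segments detected in stream, processing entire audio as fallback");
        let whisper_start = Instant::now();
        let mut whisper = whisper_processor.lock().await;
        if let Ok(transcript) = whisper.transcribe(&audio_data) {
            stats.whisper_transcription_duration += whisper_start.elapsed();
            if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") {
                #[allow(clippy::cast_precision_loss)]
                let event_data = StreamChunk {
                    text: transcript.trim().to_string(),
                    timestamp: Some(audio_data.len() as f64 / f64::from(sample_rate)),
                };
                let event = Event::default().json_data(event_data).unwrap();
                // Yield the fallback event once; `finished = true` ends the stream on the next poll.
                return Some((Ok(event), (state, whisper_processor, audio_data, processed, audio_buffer, stats, stream_start, 1, true)));
            }
        }
    }
    stats.total_duration = stream_start.elapsed();
    stats.print_summary();
    return None;
}
```

The non-streaming path's blank-audio filtering and timing bookkeeping carry over unchanged, so both modes would behave the same when VAD finds no speech.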


Ok(transcripts.join(" "))
}

@@ -312,7 +348,7 @@ async fn create_transcription_stream(
model_name: String, // Change to owned String
audio_data: Vec<f32>,
mut processing_stats: ProcessingStats,
) -> Result<impl Stream<Item = Result<Event, anyhow::Error>>, (StatusCode, Json<ErrorResponse>)> {
) -> Result<Pin<Box<dyn Stream<Item = Result<Event, anyhow::Error>> + Send>>, (StatusCode, Json<ErrorResponse>)> {
let stream_start = Instant::now();

// Get the appropriate Whisper processor for this model with timing
@@ -337,68 +373,103 @@

let sample_rate = 16000;

Ok(stream::unfold((state, whisper_processor, audio_data, 0, AudioBuffer::new(10000, 100, 500, sample_rate), processing_stats, stream_start), move |(state, whisper_processor, audio_data, mut processed, mut audio_buffer, mut stats, stream_start)| async move {
if processed >= audio_data.len() {
// Print final statistics for streaming
stats.total_duration = stream_start.elapsed();
stats.print_summary();
return None;
}
// If VAD is disabled, process entire audio directly and return as single stream event
if !state.vad_enabled {
println!("🔇 VAD disabled for streaming, processing entire audio directly");
let whisper_start = Instant::now();
let mut whisper = whisper_processor.lock().await;
let transcript = match whisper.transcribe(&audio_data) {
Ok(text) => text,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: ErrorDetail {
message: format!("Transcription failed: {e}"),
error_type: "server_error".to_string(),
param: None,
code: None,
},
}),
));
},
};
processing_stats.whisper_transcription_duration = whisper_start.elapsed();
processing_stats.vad_processing_duration = Duration::ZERO;
processing_stats.total_duration = stream_start.elapsed();
processing_stats.print_summary();

// Process audio in chunks suitable for VAD (512 samples at a time)
let chunk_size = 512.min(audio_data.len() - processed);
let chunk = &audio_data[processed..processed + chunk_size];
processed += chunk_size;
let event_data = StreamChunk { text: transcript, timestamp: Some(audio_data.len() as f64 / f64::from(sample_rate)) };
let event = Event::default().json_data(event_data).unwrap();

// Process through VAD and Whisper processors
let mut whisper_result = None;
return Ok(stream::once(async move { Ok(event) }).boxed());
}

// Process through VAD
let vad_chunk_start = Instant::now();
let mut vad = state.vad.lock().await;
if let Ok(speech_prob) = vad.process_chunk(chunk) {
let is_speech = vad.is_speech(speech_prob);
Ok(
stream::unfold((state, whisper_processor, audio_data, 0, AudioBuffer::new(10000, 50, 300, sample_rate), processing_stats, stream_start), move |(state, whisper_processor, audio_data, mut processed, mut audio_buffer, mut stats, stream_start)| async move {
if processed >= audio_data.len() {
// Print final statistics for streaming
stats.total_duration = stream_start.elapsed();
stats.print_summary();
return None;
}

// Add to audio buffer and check if we have complete audio
if let Some(complete_audio) = audio_buffer.add_chunk(chunk, is_speech) {
// Release VAD lock before acquiring Whisper lock
drop(vad);
let vad_chunk_time = vad_chunk_start.elapsed();
stats.vad_processing_duration += vad_chunk_time;

// Process complete audio through Whisper
let whisper_chunk_start = Instant::now();
let mut whisper = whisper_processor.lock().await;
if let Ok(transcript) = whisper.transcribe(&complete_audio) {
let whisper_chunk_time = whisper_chunk_start.elapsed();
stats.whisper_transcription_duration += whisper_chunk_time;

if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") {
whisper_result = Some(transcript.trim().to_string());
println!("🎯 Chunk transcribed in {:.2}ms: \"{}\"", whisper_chunk_time.as_secs_f64() * 1000.0, transcript.trim());
// Process audio in chunks suitable for VAD (512 samples at a time)
let chunk_size = 512.min(audio_data.len() - processed);
let chunk = &audio_data[processed..processed + chunk_size];
processed += chunk_size;

// Process through VAD and Whisper processors
let mut whisper_result = None;

// Process through VAD
let vad_chunk_start = Instant::now();
let mut vad = state.vad.lock().await;
if let Ok(speech_prob) = vad.process_chunk(chunk) {
let is_speech = vad.is_speech(speech_prob);

// Add to audio buffer and check if we have complete audio
if let Some(complete_audio) = audio_buffer.add_chunk(chunk, is_speech) {
// Release VAD lock before acquiring Whisper lock
drop(vad);
let vad_chunk_time = vad_chunk_start.elapsed();
stats.vad_processing_duration += vad_chunk_time;

// Process complete audio through Whisper
let whisper_chunk_start = Instant::now();
let mut whisper = whisper_processor.lock().await;
if let Ok(transcript) = whisper.transcribe(&complete_audio) {
let whisper_chunk_time = whisper_chunk_start.elapsed();
stats.whisper_transcription_duration += whisper_chunk_time;

if !transcript.trim().is_empty() && !transcript.contains("[BLANK_AUDIO]") {
whisper_result = Some(transcript.trim().to_string());
println!("🎯 Chunk transcribed in {:.2}ms: \"{}\"", whisper_chunk_time.as_secs_f64() * 1000.0, transcript.trim());
}
}
}
} else {
stats.vad_processing_duration += vad_chunk_start.elapsed();
}
} else {
stats.vad_processing_duration += vad_chunk_start.elapsed();
}

// Create event with actual transcription or progress update
#[allow(clippy::option_if_let_else)]
let event_data = if let Some(transcript) = whisper_result {
#[allow(clippy::cast_precision_loss)]
StreamChunk { text: transcript, timestamp: Some(processed as f64 / f64::from(sample_rate)) }
} else {
StreamChunk {
#[allow(clippy::cast_precision_loss)]
text: format!("Processing... ({:.1}%)", (processed as f64 / audio_data.len() as f64) * 100.0),
// Create event with actual transcription or progress update
#[allow(clippy::option_if_let_else)]
let event_data = if let Some(transcript) = whisper_result {
#[allow(clippy::cast_precision_loss)]
timestamp: Some(processed as f64 / f64::from(sample_rate)),
}
};
StreamChunk { text: transcript, timestamp: Some(processed as f64 / f64::from(sample_rate)) }
} else {
StreamChunk {
#[allow(clippy::cast_precision_loss)]
text: format!("Processing... ({:.1}%)", (processed as f64 / audio_data.len() as f64) * 100.0),
#[allow(clippy::cast_precision_loss)]
timestamp: Some(processed as f64 / f64::from(sample_rate)),
}
};

let event = Event::default().json_data(event_data).unwrap();
let event = Event::default().json_data(event_data).unwrap();

Some((Ok(event), (state.clone(), whisper_processor.clone(), audio_data, processed, audio_buffer, stats, stream_start)))
}))
Some((Ok(event), (state.clone(), whisper_processor.clone(), audio_data, processed, audio_buffer, stats, stream_start)))
})
.boxed(),
)
}