357 changes: 357 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -2,7 +2,7 @@
 version = "0.0.1"
 authors = ["Project AIRI"]
 edition = "2024"
-rust-version = "1.80"
+rust-version = "1.85"
 readme = "README.md"
 homepage = "https://github.com/proj-airi/candle-examples"
 repository = "https://github.com/proj-airi/candle-examples"
@@ -17,6 +17,7 @@ members = [
   "apps/silero-vad-realtime",
   "apps/silero-vad-realtime-minimum",
   "apps/silero-vad-whisper-realtime",
+  "apps/silero-vad-whisper-realtime-api",
   "apps/whisper-realtime",
 ]
63 changes: 32 additions & 31 deletions apps/orpheus-tts/src/main.rs
@@ -16,7 +16,7 @@ use tracing_chrome::ChromeLayerBuilder;
 use tracing_subscriber::prelude::*;

 // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43
-const STOP_TOKEN_ID: u32 = 128258;
+const STOP_TOKEN_ID: u32 = 128_258;

 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 enum Voice {
@@ -39,21 +39,22 @@ enum Voice {
 }

 impl Voice {
-  fn as_str(&self) -> &'static str {
+  const fn as_str(self) -> &'static str {
     match self {
-      Voice::Tara => "tara",
-      Voice::Leah => "leah",
-      Voice::Jess => "jess",
-      Voice::Leo => "leo",
-      Voice::Dan => "dan",
-      Voice::Mia => "mia",
-      Voice::Zac => "zac",
-      Voice::Zoe => "zoe",
+      Self::Tara => "tara",
+      Self::Leah => "leah",
+      Self::Jess => "jess",
+      Self::Leo => "leo",
+      Self::Dan => "dan",
+      Self::Mia => "mia",
+      Self::Zac => "zac",
+      Self::Zoe => "zoe",
     }
   }
 }

 #[derive(Parser)]
+#[allow(clippy::struct_excessive_bools)]
 struct Args {
   #[arg(long)]
   cpu: bool,
@@ -154,12 +155,14 @@ pub trait Sample {
 }

 impl Sample for f32 {
+  #[allow(clippy::cast_possible_truncation)]
   fn to_i16(&self) -> i16 {
     (self.clamp(-1.0, 1.0) * 32767.0) as i16
   }
 }

 impl Sample for f64 {
+  #[allow(clippy::cast_possible_truncation)]
   fn to_i16(&self) -> i16 {
     (self.clamp(-1.0, 1.0) * 32767.0) as i16
   }
@@ -171,16 +174,17 @@ impl Sample for i16 {
   }
 }

+#[allow(clippy::missing_panics_doc)]
 pub fn write_pcm_as_wav<W: Write, S: Sample>(
   w: &mut W,
   samples: &[S],
   sample_rate: u32,
 ) -> std::io::Result<()> {
   let len = 12u32; // header
   let len = len + 24u32; // fmt
-  let len = len + samples.len() as u32 * 2 + 8; // data
+  let len = len + u32::try_from(samples.len()).unwrap() * 2 + 8; // data
   let n_channels = 1u16;
-  let bytes_per_second = sample_rate * 2 * n_channels as u32;
+  let bytes_per_second = sample_rate * 2 * u32::from(n_channels);
   w.write_all(b"RIFF")?;
   w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
   w.write_all(b"WAVE")?;
@@ -197,21 +201,21 @@ pub fn write_pcm_as_wav<W: Write, S: Sample>(

   // Data block
   w.write_all(b"data")?;
-  w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
-  for sample in samples.iter() {
-    w.write_all(&sample.to_i16().to_le_bytes())?
+  w.write_all(&(u32::try_from(samples.len()).unwrap() * 2).to_le_bytes())?;
+  for sample in samples {
+    w.write_all(&sample.to_i16().to_le_bytes())?;
   }
   Ok(())
 }

 struct Model {
-  model: LlamaModel,
+  llama: LlamaModel,
   tokenizer: Tokenizer,
   logits_processor: LogitsProcessor,
   cache: Cache,
   device: Device,
   verbose_prompt: bool,
-  snac_model: SnacModel,
+  snac: SnacModel,
   out_file: String,
   voice: Voice,
 }
@@ -226,15 +230,14 @@ impl Model {
       .build()?;

     let model_id = match args.orpheus_model_id {
-      Some(model_id) => model_id.to_string(),
+      Some(model_id) => model_id,
       None => match args.which_orpheus_model {
         WhichOrpheusModel::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(),
       },
     };
-    let revision = match args.revision {
-      Some(r) => r,
-      None => "main".to_string(),
-    };
+    let revision = args
+      .revision
+      .map_or_else(|| "main".to_string(), |r| r);
     let repo = api.repo(hf_hub::Repo::with_revision(model_id, hf_hub::RepoType::Model, revision));
     let model_file = match args.orpheus_model_file {
       Some(m) => vec![m.into()],
@@ -292,16 +295,16 @@ impl Model {
     println!("load the model in {:?}", start_time.elapsed());

     let cache = Cache::new(true, dtype, &config, &device)?;
-    let snac_model = load_snac_model(hf_token.clone().as_str(), &device)?;
+    let snac_model = load_snac_model(hf_token.as_str(), &device)?;

     Ok(Self {
-      model,
+      llama: model,
       tokenizer,
       logits_processor,
       cache,
       device,
       verbose_prompt: args.verbose_prompt,
-      snac_model,
+      snac: snac_model,
       out_file: args.out_file,
       voice: args.voice,
     })
@@ -319,7 +322,7 @@ impl Model {
       .map_err(E::msg)?;

     // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82
-    let mut tokens = [&[128259], tokens.get_ids(), &[128009, 128260, 128261, 128257]].concat();
+    let mut tokens = [&[128_259], tokens.get_ids(), &[128_009, 128_260, 128_261, 128_257]].concat();
     if self.verbose_prompt {
       println!("prompt tokens: {tokens:?}");
     }
@@ -342,7 +345,7 @@ impl Model {
       let context = &tokens[tokens.len().saturating_sub(context_size)..];
       let input = Tensor::new(context, device)?.unsqueeze(0)?;
       let logits = self
-        .model
+        .llama
         .forward(&input, context_index, &mut cache)?;
       let logits = logits.squeeze(0)?;
       index_pos += context.len();
@@ -354,7 +357,7 @@ impl Model {
         Some(tok) => {
           let tok = tok.parse::<u32>()?;
           // https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63
-          let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096);
+          let tok = tok - 10 - ((u32::try_from(audio_tokens.len()).unwrap() % 7) * 4096);
           audio_tokens.push(tok);
         },
         None => {
@@ -393,9 +396,7 @@ impl Model {
     let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?;
     let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?;
     let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?;
-    let pcm = self
-      .snac_model
-      .decode(&[&codes0, &codes1, &codes2])?;
+    let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?;

     println!("decoded to pcm {pcm:?}");
     let mut output = std::fs::File::create(&self.out_file)?;
39 changes: 39 additions & 0 deletions apps/silero-vad-whisper-realtime-api/Cargo.toml
@@ -0,0 +1,39 @@
[package]
name = "silero-vad-whisper-realtime-api"
version = "0.1.0"
edition = "2024"

[dependencies]
anyhow = "1.0.98"
byteorder = "1.5.0"
candle-core = { version = "0.9.1" }
candle-nn = { version = "0.9.1" }
candle-transformers = { version = "0.9.1" }
candle-onnx = { version = "0.9.1" }
clap = { version = "4.5.38", features = ["derive"] }
cpal = "0.15.3"
hf-hub = "0.4.2"
rand = "0.9.1"
rubato = "0.16.2"
serde_json = "1.0.140"
symphonia = "0.5.4"
tokenizers = "0.21.1"
tracing-chrome = "0.7.2"
tracing-subscriber = "0.3.19"
tracing = "0.1.41"
tokio = "1.45.1"
crossbeam-channel = "0.5.15"
axum = { version = "0.8.4", features = ["multipart"] }
serde = { version = "1.0.219", features = ["derive"] }

# Server-sent events and HTTP utilities
futures = "0.3.31"
tokio-stream = "0.1.17"
axum-extra = { version = "0.9.5", features = ["typed-header"] }
tower = "0.5.1"
tower-http = { version = "0.6.2", features = ["fs", "cors"] }

[features]
default = []
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
100 changes: 100 additions & 0 deletions apps/silero-vad-whisper-realtime-api/README.md
@@ -0,0 +1,100 @@
# silero-vad-whisper-realtime-api

> An OpenAI-compatible speech transcription API service with real-time streaming responses (SSE), built on the Silero VAD and Whisper models.
>
> This service provides a `/v1/audio/transcriptions` endpoint that is fully compatible with OpenAI's API format, supporting both standard JSON responses and streaming Server-Sent Events.

## Getting started

```shell
git clone https://github.com/proj-airi/candle-examples.git
cd candle-examples/apps/silero-vad-whisper-realtime-api
```

## Build

```shell
cargo fetch --locked
cargo clean
```

### NVIDIA CUDA

```shell
cargo build --package silero-vad-whisper-realtime-api --features cuda
```

### macOS Metal

```shell
cargo build --package silero-vad-whisper-realtime-api --features metal
```

### CPU Only

```shell
cargo build --package silero-vad-whisper-realtime-api
```

## Run

### All platforms

```shell
cargo run --package silero-vad-whisper-realtime-api --release
```

The server will start at `http://localhost:3000`.

## Usage

### Basic transcription

```bash
curl -X POST http://localhost:3000/v1/audio/transcriptions \
-F "file=@your_audio.wav" \
-F "model=whisper-1"
```
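
A successful request returns a JSON body matching the `TranscriptionResponse` struct in `src/api.rs`; the transcript text below is illustrative:

```json
{
  "text": "Hello, this is a test recording."
}
```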

### Streaming transcription

```bash
curl -X POST "http://localhost:3000/v1/audio/transcriptions?stream=true" \
-F "file=@your_audio.wav" \
-F "model=whisper-1" \
--no-buffer
```
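
Each streamed event carries a JSON payload shaped like the `StreamChunk` struct in `src/api.rs` (`timestamp` is omitted when the server has none for a chunk, per its `skip_serializing_if` attribute). The `data:` framing is standard SSE; the values below are illustrative, not captured output:

```text
data: {"text":"Hello, this is","timestamp":0.96}

data: {"text":" a test recording.","timestamp":2.4}
```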

## API Parameters

| Parameter | Type | Required | Description |
|-----------|------|----------|-------------|
| `file` | File | ✅ | Audio file to transcribe |
| `model` | String | ❌ | Model name (default: "whisper-1") |
| `language` | String | ❌ | Audio language |
| `prompt` | String | ❌ | Prompt text |
| `response_format` | String | ❌ | Response format (default: "json") |
| `temperature` | Float | ❌ | Sampling temperature (default: 0.0) |
| `stream` | Boolean | ❌ | Enable streaming response (query parameter) |
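
As a sketch, a request exercising the optional parameters could look like this (all values are illustrative; note that `stream` goes in the query string rather than the form body):

```bash
curl -X POST "http://localhost:3000/v1/audio/transcriptions?stream=true" \
  -F "file=@your_audio.wav" \
  -F "model=whisper-1" \
  -F "language=en" \
  -F "prompt=technical vocabulary" \
  -F "response_format=json" \
  -F "temperature=0.2" \
  --no-buffer
```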

## Supported Audio Formats

- WAV, MP3, FLAC, M4A
- Any format supported by Symphonia

## Environment Variables

```bash
# Set log level
export RUST_LOG=debug

# Force CPU usage
export CANDLE_FORCE_CPU=1
```

## Acknowledgements

- [candle](https://github.com/huggingface/candle) - High-performance ML framework
- [axum](https://github.com/tokio-rs/axum) - Modern web framework
- [OpenAI](https://openai.com/) - API design reference
- [Silero VAD](https://github.com/snakers4/silero-vad) - Voice activity detection model
Binary file not shown.
Binary file not shown.
27 changes: 27 additions & 0 deletions apps/silero-vad-whisper-realtime-api/src/api.rs
@@ -0,0 +1,27 @@
use serde::Serialize;

#[derive(Debug, Serialize)]
pub struct TranscriptionResponse {
pub text: String,
}

#[derive(Debug, Serialize)]
pub struct StreamChunk {
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub timestamp: Option<f64>,
}

#[derive(Debug, Serialize)]
pub struct ErrorResponse {
pub error: ErrorDetail,
}

#[derive(Debug, Serialize)]
pub struct ErrorDetail {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
pub param: Option<String>,
pub code: Option<String>,
}
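
For reference, an `ErrorResponse` serializes with `error_type` renamed to `type` (per the `#[serde(rename = "type")]` attribute), and unset optional fields appear as `null`. The message and values below are illustrative, not actual server output:

```json
{
  "error": {
    "message": "missing `file` field in multipart form",
    "type": "invalid_request_error",
    "param": "file",
    "code": null
  }
}
```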