|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Extract voice data for on-device CosyVoice 3 using ONNX Runtime (CPU, no GPU). |
| 4 | +
|
| 5 | +Extracts from a reference audio clip: |
| 6 | + - speaker_embedding.bin (192 x float32) |
| 7 | + - prompt_tokens.bin (N x int64) |
| 8 | + - prompt_mel.bin (int32 header + frames*80 x float32) |
| 9 | + - prompt_text.txt (UTF-8 transcript) |
| 10 | +
|
| 11 | +Usage: |
| 12 | + python extract_voice_onnx.py \ |
| 13 | + --audio prompts/en_female_nova_greeting.wav \ |
| 14 | + --text "Hello, my name is Sarah." \ |
| 15 | + --voice-id test_female \ |
| 16 | + --models-dir models \ |
| 17 | + --output-dir ../App/Resources/voices/ |
| 18 | +""" |
| 19 | + |
| 20 | +import argparse |
| 21 | +import struct |
| 22 | +import sys |
| 23 | +from pathlib import Path |
| 24 | + |
| 25 | +import numpy as np |
| 26 | +import librosa |
| 27 | +import onnxruntime as ort |
| 28 | + |
| 29 | + |
def extract_speaker_embedding(campplus_session: ort.InferenceSession, audio_path: str) -> np.ndarray:
    """Run the CAMPPlus ONNX model on a reference clip and return its 192-dim speaker embedding."""
    # CAMPPlus operates on 16 kHz audio.
    waveform, _ = librosa.load(audio_path, sr=16000)
    waveform = waveform.astype(np.float32)

    # 80-band mel spectrogram approximating Kaldi fbank framing
    # (25 ms window / 10 ms hop at 16 kHz -> n_fft=400, hop_length=160).
    # NOTE(review): librosa's Slaney mel scale and power spectrum differ
    # slightly from Kaldi fbank -- confirm this matches the model's training frontend.
    spec = librosa.feature.melspectrogram(
        y=waveform, sr=16000, n_fft=400, hop_length=160,
        n_mels=80, fmin=20, fmax=7600
    )
    features = np.log(np.maximum(spec, 1e-10)).T  # -> [frames, 80]

    # Mean normalization across time (Kaldi-style CMN).
    features -= features.mean(axis=0, keepdims=True)

    # The model expects a batch dimension: [1, frames, 80].
    batch = features[np.newaxis].astype(np.float32)

    feed_name = campplus_session.get_inputs()[0].name
    outputs = campplus_session.run(None, {feed_name: batch})

    # First output holds the embedding; flatten e.g. [1, 192] -> [192].
    return outputs[0].flatten().astype(np.float32)
| 55 | + |
| 56 | + |
def extract_speech_tokens(tokenizer_session: ort.InferenceSession, audio_path: str) -> np.ndarray:
    """Run the Speech Tokenizer v3 ONNX model and return the prompt's speech tokens as int64."""
    # The tokenizer front-end works on 16 kHz audio.
    waveform, _ = librosa.load(audio_path, sr=16000)
    waveform = waveform.astype(np.float32)

    # 128-band mel spectrogram with Whisper-style framing
    # (25 ms window / 10 ms hop at 16 kHz).
    spec = librosa.feature.melspectrogram(
        y=waveform, sr=16000, n_fft=400, hop_length=160,
        n_mels=128, fmin=0, fmax=8000
    )

    # Whisper-style normalization: log10, clamp the dynamic range to
    # 8 units below the global peak, then rescale toward [-1, 1].
    log_spec = np.log10(np.maximum(spec, 1e-10))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0

    # Model inputs: features as [1, 128, frames] plus the frame count.
    batch = log_spec[np.newaxis].astype(np.float32)
    batch_len = np.array([batch.shape[2]], dtype=np.int32)

    names = [inp.name for inp in tokenizer_session.get_inputs()]
    feeds = {names[0]: batch, names[1]: batch_len}
    tokens = tokenizer_session.run(None, feeds)[0]

    # Flatten whatever batch shape comes back into a 1-D int64 id sequence.
    return tokens.flatten().astype(np.int64)
| 84 | + |
| 85 | + |
def extract_prompt_mel(audio_path: str) -> tuple[np.ndarray, int]:
    """Compute the flow-conditioning mel spectrogram of the prompt audio.

    Returns a (flat_mel, frame_count) pair, where flat_mel is the row-major
    [frames, 80] float32 matrix flattened to frames*80 values.
    """
    # The flow model runs at CosyVoice's native 24 kHz rate.
    waveform, _ = librosa.load(audio_path, sr=24000)
    waveform = waveform.astype(np.float32)

    # CosyVoice flow front-end parameters: 80 mels, 1024-point FFT,
    # 256-sample hop, full 0-12 kHz band.
    spec = librosa.feature.melspectrogram(
        y=waveform, sr=24000, n_fft=1024, hop_length=256,
        n_mels=80, fmin=0, fmax=12000
    )
    log_spec = np.log(np.maximum(spec, 1e-10))

    # Transpose to [frames, 80] so flattening yields frame-major order.
    frames = log_spec.T.astype(np.float32)
    return frames.flatten(), frames.shape[0]
| 104 | + |
| 105 | + |
def save_voice_data(
    output_dir: Path,
    speaker_embedding: np.ndarray,
    prompt_tokens: np.ndarray,
    prompt_mel_flat: np.ndarray,
    prompt_mel_frames: int,
    prompt_text: str,
):
    """Save extracted voice data in the binary format expected by OnDeviceTTSEngine.

    On-disk layout:
      - speaker_embedding.bin: raw float32 values
      - prompt_tokens.bin:     raw int64 token ids
      - prompt_mel.bin:        little-endian int32 frame count, then frames*80 float32
      - prompt_text.txt:       UTF-8 transcript
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Coerce dtypes before writing: ndarray.tofile dumps raw memory, so a
    # caller passing e.g. float64 or int32 arrays would silently corrupt the
    # binary format the reader expects. (No-op when dtypes already match.)
    # NOTE(review): tofile writes native byte order; the int32 header is
    # explicitly little-endian -- confirm target devices are little-endian.
    speaker_embedding = np.ascontiguousarray(speaker_embedding, dtype=np.float32)
    prompt_tokens = np.ascontiguousarray(prompt_tokens, dtype=np.int64)
    prompt_mel_flat = np.ascontiguousarray(prompt_mel_flat, dtype=np.float32)

    # speaker_embedding.bin: raw float32, 192 elements
    emb_path = output_dir / "speaker_embedding.bin"
    speaker_embedding.tofile(str(emb_path))
    print(f" speaker_embedding.bin: shape=({speaker_embedding.shape[0]},) size={emb_path.stat().st_size} bytes")

    # prompt_tokens.bin: raw int64
    tok_path = output_dir / "prompt_tokens.bin"
    prompt_tokens.tofile(str(tok_path))
    print(f" prompt_tokens.bin: shape=({prompt_tokens.shape[0]},) size={tok_path.stat().st_size} bytes")

    # prompt_mel.bin: int32 frame count header + float32 data
    mel_path = output_dir / "prompt_mel.bin"
    with open(mel_path, "wb") as f:
        f.write(struct.pack("<i", prompt_mel_frames))
        prompt_mel_flat.tofile(f)
    print(f" prompt_mel.bin: frames={prompt_mel_frames} size={mel_path.stat().st_size} bytes")

    # prompt_text.txt: UTF-8
    text_path = output_dir / "prompt_text.txt"
    text_path.write_text(prompt_text, encoding="utf-8")
    print(f" prompt_text.txt: {len(prompt_text)} chars")
| 138 | + |
| 139 | + |
def main():
    """CLI entry point: parse arguments, run the three extraction stages, write outputs."""
    parser = argparse.ArgumentParser(description="Extract voice data using ONNX Runtime (CPU)")
    parser.add_argument("--audio", required=True, help="Path to prompt audio WAV file")
    parser.add_argument("--text", required=True, help="Transcript of the prompt audio")
    parser.add_argument("--voice-id", required=True, help="Voice ID (e.g., longjiaxin_v3)")
    parser.add_argument("--models-dir", default="models", help="Directory with campplus.onnx and speech_tokenizer_v3.onnx")
    parser.add_argument("--output-dir", default="../App/Resources/voices", help="Output base directory")
    args = parser.parse_args()

    model_root = Path(args.models_dir)
    campplus_path = model_root / "campplus.onnx"
    tokenizer_path = model_root / "speech_tokenizer_v3.onnx"

    # Both ONNX models must exist before any work starts.
    for model_path in (campplus_path, tokenizer_path):
        if not model_path.exists():
            print(f"ERROR: {model_path} not found. Download from HuggingFace first.")
            sys.exit(1)

    print("Loading ONNX models...")
    campplus = ort.InferenceSession(str(campplus_path), providers=["CPUExecutionProvider"])
    tokenizer = ort.InferenceSession(str(tokenizer_path), providers=["CPUExecutionProvider"])

    print(f"\nExtracting voice data for '{args.voice_id}' from: {args.audio}")

    # Stage 1: global voice identity vector.
    print("\n1. Extracting speaker embedding (CAMPPlus)...")
    embedding = extract_speaker_embedding(campplus, args.audio)
    print(f" Shape: ({embedding.shape[0]},) | Min: {embedding.min():.4f} Max: {embedding.max():.4f}")

    # Stage 2: discrete speech tokens of the prompt.
    print("\n2. Extracting speech tokens (Speech Tokenizer v3)...")
    tokens = extract_speech_tokens(tokenizer, args.audio)
    print(f" Shape: ({tokens.shape[0]},) | Token range: [{tokens.min()}, {tokens.max()}]")

    # Stage 3: mel conditioning for the flow model.
    print("\n3. Extracting prompt mel spectrogram...")
    mel_flat, mel_frames = extract_prompt_mel(args.audio)
    print(f" Frames: {mel_frames} | Total floats: {mel_flat.shape[0]}")

    voice_dir = Path(args.output_dir) / args.voice_id
    print(f"\nSaving to: {voice_dir}")
    save_voice_data(voice_dir, embedding, tokens, mel_flat, mel_frames, args.text)

    print(f"\nDone! Voice data for '{args.voice_id}' saved.")
| 183 | + |
| 184 | + |
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
0 commit comments