Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ pip install -r requirements.txt
from zipvoice.luxvoice import LuxTTS

# load model on GPU
lux_tts = LuxTTS('YatharthS/LuxTTS', device='cuda')
# float32 (default)
lux_tts = LuxTTS('YatharthS/LuxTTS', device='cuda', threads=2)

# float16 (~2x faster, recommended for GPU)
lux_tts = LuxTTS('YatharthS/LuxTTS', device='cuda', dtype='float16')

# load model on CPU
# lux_tts = LuxTTS('YatharthS/LuxTTS', device='cpu', threads=2)
Expand Down Expand Up @@ -111,6 +115,7 @@ if display is not None:
- Please use at minimum a 3 second audio file for voice cloning.
- You can use return_smooth = True if you hear metallic sounds.
- Lower t_shift for less possible pronunciation errors but worse quality and vice versa.
- Use `dtype='float16'` on GPU for ~2x faster inference with no perceptible quality difference.

## Community
- [Lux-TTS-Gradio](https://github.com/NidAll/LuxTTS-Gradio): A gradio app to use LuxTTS.
Expand All @@ -127,15 +132,19 @@ A: LuxTTS uses the same architecture but distilled to 4 steps with an improved s

Q: Can it be even faster?

A: Yes, currently it uses float32. Float16 should be significantly faster(almost 2x).
A: Yes, pass `dtype='float16'` when loading the model. This runs inference in half precision and is approximately 2x faster on GPU with no perceptible quality loss.

Q: Does float16 work on CPU?

A: No, PyTorch does not support float16 on CPU. LuxTTS will automatically fall back to float32 with a printed warning if you try.

## Roadmap

- [x] Release model and code
- [x] Huggingface spaces demo
- [x] Release MPS support (thanks to @builtbybasit)
- [ ] Release LuxTTS v1.5
- [ ] Release code for float16 inference
- [x] Release code for float16 inference

## Acknowledgments

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ jieba
piper_phonemize
pypinyin
setuptools<81
pytest>=7.0
Empty file added tests/__init__.py
Empty file.
Binary file added tests/assets/reference.wav
Binary file not shown.
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
from zipvoice.luxvoice import LuxTTS


@pytest.fixture(scope="session")
def lux32():
return LuxTTS("YatharthS/LuxTTS", device="cuda", dtype="float32")


@pytest.fixture(scope="session")
def lux16():
return LuxTTS("YatharthS/LuxTTS", device="cuda", dtype="float16")


@pytest.fixture(scope="session")
def reference_audio():
return "tests/assets/reference.wav"


@pytest.fixture(scope="session")
def test_text():
return "The quick brown fox jumps over the lazy dog."
66 changes: 66 additions & 0 deletions tests/test_fp16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest
import torch
import numpy as np
import soundfile as sf
from zipvoice.luxvoice import LuxTTS

def test_fp16_loads_without_error(lux16):
assert lux16 is not None

def test_fp16_model_dtype(lux16):
if lux16.device == "cpu":
assert lux16.dtype == torch.float32
else:
param = next(lux16.model.parameters())
assert param.dtype == torch.float16, f"Expected float16, got {param.dtype}"

def test_fp32_model_dtype(lux32):
if lux32.device == "cpu":
assert lux32.dtype == torch.float32
else:
param = next(lux32.model.parameters())
assert param.dtype == torch.float32, f"Expected float32, got {param.dtype}"

def test_fp16_output_is_float32(lux16, reference_audio, test_text):
enc = lux16.encode_prompt(reference_audio, rms=0.01)
wav = lux16.generate_speech(test_text, enc, num_steps=4)
assert wav.dtype == torch.float32

def test_fp16_no_nan_in_output(lux16, reference_audio, test_text):
enc = lux16.encode_prompt(reference_audio, rms=0.01)
wav = lux16.generate_speech(test_text, enc, num_steps=4)
assert not wav.isnan().any(), "NaN detected in fp16 output"

def test_fp16_no_inf_in_output(lux16, reference_audio, test_text):
enc = lux16.encode_prompt(reference_audio, rms=0.01)
wav = lux16.generate_speech(test_text, enc, num_steps=4)
assert not wav.isinf().any(), "Inf detected in fp16 output"

def test_fp16_output_in_valid_range(lux16, reference_audio, test_text):
enc = lux16.encode_prompt(reference_audio, rms=0.01)
wav = lux16.generate_speech(test_text, enc, num_steps=4).numpy()
assert wav.max() <= 1.0 and wav.min() >= -1.0, "Waveform outside [-1, 1]"

def test_fp16_output_is_not_silent(lux16, reference_audio, test_text):
enc = lux16.encode_prompt(reference_audio, rms=0.01)
wav = lux16.generate_speech(test_text, enc, num_steps=4).numpy()
assert np.abs(wav).mean() > 1e-4, "Output waveform is silent"

def test_default_dtype_is_float32(lux32):
assert lux32.dtype == torch.float32

def test_fp32_no_nan_in_output(lux32, reference_audio, test_text):
enc = lux32.encode_prompt(reference_audio, rms=0.01)
wav = lux32.generate_speech(test_text, enc, num_steps=4)
assert not wav.isnan().any()

def test_fp32_output_is_float32(lux32, reference_audio, test_text):
enc = lux32.encode_prompt(reference_audio, rms=0.01)
wav = lux32.generate_speech(test_text, enc, num_steps=4)
assert wav.dtype == torch.float32

def test_fp16_falls_back_on_cpu(capsys):
lux = LuxTTS("YatharthS/LuxTTS", device="cpu", dtype="float16")
assert lux.dtype == torch.float32
captured = capsys.readouterr()
assert "float32" in captured.out
33 changes: 26 additions & 7 deletions zipvoice/luxvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ class LuxTTS:
LuxTTS class for encoding prompt and generating speech on cpu/cuda/mps.
"""

def __init__(self, model_path='YatharthS/LuxTTS', device='cuda', threads=4):
if model_path == 'YatharthS/LuxTTS':
def __init__(self, model_path="YatharthS/LuxTTS", device="cuda", threads=4, dtype="float32"):
if model_path == "YatharthS/LuxTTS":
model_path = None

# Auto-detect better device if cuda is requested but not available
Expand All @@ -27,6 +27,17 @@ def __init__(self, model_path='YatharthS/LuxTTS', device='cuda', threads=4):
model, feature_extractor, vocos, tokenizer, transcriber = load_models_gpu(model_path, device=device)
print("Loading model on GPU")

if dtype == "float16" or dtype == torch.float16:
if device == "cpu":
print(
"Warning: float16 is not supported on CPU, falling back to float32"
)
self.dtype = torch.float32
else:
self.dtype = torch.float16
else:
self.dtype = torch.float32

self.model = model
self.feature_extractor = feature_extractor
self.vocos = vocos
Expand All @@ -35,12 +46,19 @@ def __init__(self, model_path='YatharthS/LuxTTS', device='cuda', threads=4):
self.device = device
self.vocos.freq_range = 12000


if self.dtype == torch.float16:
self.model = self.model.to(dtype=torch.float16)
self.vocos = self.vocos.to(dtype=torch.float16)

def encode_prompt(self, prompt_audio, duration=5, rms=0.001):
"""encodes audio prompt according to duration and rms(volume control)"""
prompt_tokens, prompt_features_lens, prompt_features, prompt_rms = process_audio(prompt_audio, self.transcriber, self.tokenizer, self.feature_extractor, self.device, target_rms=rms, duration=duration)
encode_dict = {"prompt_tokens": prompt_tokens, 'prompt_features_lens': prompt_features_lens, 'prompt_features': prompt_features, 'prompt_rms': prompt_rms}

if self.dtype == torch.float16 and self.device != "cpu":
if prompt_features is not None:
prompt_features = prompt_features.to(dtype=torch.float16)

encode_dict = {"prompt_tokens": prompt_tokens, "prompt_features_lens": prompt_features_lens, "prompt_features": prompt_features, "prompt_rms": prompt_rms}

return encode_dict

Expand All @@ -54,9 +72,10 @@ def generate_speech(self, text, encode_dict, num_steps=4, guidance_scale=3.0, t_
else:
self.vocos.return_48k = True

if self.device == 'cpu':
if self.device == "cpu":
final_wav = generate_cpu(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, text, self.model, self.vocos, self.tokenizer, num_step=num_steps, guidance_scale=guidance_scale, t_shift=t_shift, speed=speed)
else:
final_wav = generate(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, text, self.model, self.vocos, self.tokenizer, num_step=num_steps, guidance_scale=guidance_scale, t_shift=t_shift, speed=speed)
with torch.autocast(device_type=self.device, dtype=self.dtype, enabled=(self.dtype == torch.float16)):
final_wav = generate(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, text, self.model, self.vocos, self.tokenizer, num_step=num_steps, guidance_scale=guidance_scale, t_shift=t_shift, speed=speed)

return final_wav.cpu()
return final_wav.cpu().float()
1 change: 1 addition & 0 deletions zipvoice/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def generate(prompt_tokens, prompt_features_lens, prompt_features, prompt_rms, t

# Convert to waveform
pred_features = pred_features.permute(0, 2, 1) / 0.1
pred_features = pred_features.float() # cast to float32 for vocoder stability
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

# Volume matching
Expand Down