Skip to content

Commit 29c1d9e

Browse files
authored
Merge pull request neonbjb#97 from jnordberg/cpu-support
CPU support
2 parents a9e64e2 + de7c5dd commit 29c1d9e

File tree

4 files changed

+44
-33
lines changed

4 files changed

+44
-33
lines changed

scripts/tortoise_tts.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@
7979
help='Normally text enclosed in brackets are automatically redacted from the spoken output '
8080
'(but are still rendered by the model), this can be used for prompt engineering. '
8181
'Set this to disable this behavior.')
82+
advanced_group.add_argument(
83+
'--device', type=str, default=None,
84+
help='Device to use for inference.')
85+
advanced_group.add_argument(
86+
'--batch-size', type=int, default=None,
87+
help='Batch size to use for inference. If omitted, the batch size is set based on available GPU memory.')
8288

8389
tuning_group = parser.add_argument_group('tuning options (overrides preset settings)')
8490
tuning_group.add_argument(
@@ -200,10 +206,11 @@
200206
seed = int(time.time()) if args.seed is None else args.seed
201207
if not args.quiet:
202208
print('Loading tts...')
203-
tts = TextToSpeech(models_dir=args.models_dir, enable_redaction=not args.disable_redaction)
209+
tts = TextToSpeech(models_dir=args.models_dir, enable_redaction=not args.disable_redaction,
210+
device=args.device, autoregressive_batch_size=args.batch_size)
204211
gen_settings = {
205212
'use_deterministic_seed': seed,
206-
'varbose': not args.quiet,
213+
'verbose': not args.quiet,
207214
'k': args.candidates,
208215
'preset': args.preset,
209216
}

tortoise/api.py

+29-26
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusi
101101
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
102102

103103

104-
def format_conditioning(clip, cond_length=132300):
104+
def format_conditioning(clip, cond_length=132300, device='cuda'):
105105
"""
106106
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
107107
"""
@@ -112,7 +112,7 @@ def format_conditioning(clip, cond_length=132300):
112112
rand_start = random.randint(0, gap)
113113
clip = clip[:, rand_start:rand_start + cond_length]
114114
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
115-
return mel_clip.unsqueeze(0).cuda()
115+
return mel_clip.unsqueeze(0).to(device)
116116

117117

118118
def fix_autoregressive_output(codes, stop_token, complain=True):
@@ -181,14 +181,15 @@ def pick_best_batch_size_for_gpu():
181181
Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
182182
you a good shot.
183183
"""
184-
free, available = torch.cuda.mem_get_info()
185-
availableGb = available / (1024 ** 3)
186-
if availableGb > 14:
187-
return 16
188-
elif availableGb > 10:
189-
return 8
190-
elif availableGb > 7:
191-
return 4
184+
if torch.cuda.is_available():
185+
_, available = torch.cuda.mem_get_info()
186+
availableGb = available / (1024 ** 3)
187+
if availableGb > 14:
188+
return 16
189+
elif availableGb > 10:
190+
return 8
191+
elif availableGb > 7:
192+
return 4
192193
return 1
193194

194195

@@ -197,7 +198,7 @@ class TextToSpeech:
197198
Main entry point into Tortoise.
198199
"""
199200

200-
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True):
201+
def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None):
201202
"""
202203
Constructor
203204
:param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -207,10 +208,12 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable
207208
:param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output
208209
(but are still rendered by the model). This can be used for prompt engineering.
209210
Default is true.
211+
:param device: Device to use when running the model. If omitted, the device will be automatically chosen.
210212
"""
211213
self.models_dir = models_dir
212214
self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
213215
self.enable_redaction = enable_redaction
216+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
214217
if self.enable_redaction:
215218
self.aligner = Wav2VecAlignment()
216219

@@ -240,7 +243,7 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable
240243
self.cvvp = None # CVVP model is only loaded if used.
241244

242245
self.vocoder = UnivNetGenerator().cpu()
243-
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir))['model_g'])
246+
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
244247
self.vocoder.eval(inference=True)
245248

246249
# Random latent generators (RLGs) are loaded lazily.
@@ -261,15 +264,15 @@ def get_conditioning_latents(self, voice_samples, return_mels=False):
261264
:param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data.
262265
"""
263266
with torch.no_grad():
264-
voice_samples = [v.to('cuda') for v in voice_samples]
267+
voice_samples = [v.to(self.device) for v in voice_samples]
265268

266269
auto_conds = []
267270
if not isinstance(voice_samples, list):
268271
voice_samples = [voice_samples]
269272
for vs in voice_samples:
270-
auto_conds.append(format_conditioning(vs))
273+
auto_conds.append(format_conditioning(vs, device=self.device))
271274
auto_conds = torch.stack(auto_conds, dim=1)
272-
self.autoregressive = self.autoregressive.cuda()
275+
self.autoregressive = self.autoregressive.to(self.device)
273276
auto_latent = self.autoregressive.get_conditioning(auto_conds)
274277
self.autoregressive = self.autoregressive.cpu()
275278

@@ -278,11 +281,11 @@ def get_conditioning_latents(self, voice_samples, return_mels=False):
278281
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
279282
sample = torchaudio.functional.resample(sample, 22050, 24000)
280283
sample = pad_or_truncate(sample, 102400)
281-
cond_mel = wav_to_univnet_mel(sample.to('cuda'), do_normalization=False)
284+
cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device)
282285
diffusion_conds.append(cond_mel)
283286
diffusion_conds = torch.stack(diffusion_conds, dim=1)
284287

285-
self.diffusion = self.diffusion.cuda()
288+
self.diffusion = self.diffusion.to(self.device)
286289
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
287290
self.diffusion = self.diffusion.cpu()
288291

@@ -380,7 +383,7 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
380383
"""
381384
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
382385

383-
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
386+
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
384387
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
385388
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
386389

@@ -391,8 +394,8 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
391394
auto_conditioning, diffusion_conditioning = conditioning_latents
392395
else:
393396
auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents()
394-
auto_conditioning = auto_conditioning.cuda()
395-
diffusion_conditioning = diffusion_conditioning.cuda()
397+
auto_conditioning = auto_conditioning.to(self.device)
398+
diffusion_conditioning = diffusion_conditioning.to(self.device)
396399

397400
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
398401

@@ -401,7 +404,7 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
401404
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
402405
stop_mel_token = self.autoregressive.stop_mel_token
403406
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
404-
self.autoregressive = self.autoregressive.cuda()
407+
self.autoregressive = self.autoregressive.to(self.device)
405408
if verbose:
406409
print("Generating autoregressive samples..")
407410
for b in tqdm(range(num_batches), disable=not verbose):
@@ -420,11 +423,11 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
420423
self.autoregressive = self.autoregressive.cpu()
421424

422425
clip_results = []
423-
self.clvp = self.clvp.cuda()
426+
self.clvp = self.clvp.to(self.device)
424427
if cvvp_amount > 0:
425428
if self.cvvp is None:
426429
self.load_cvvp()
427-
self.cvvp = self.cvvp.cuda()
430+
self.cvvp = self.cvvp.to(self.device)
428431
if verbose:
429432
if self.cvvp is None:
430433
print("Computing best candidates using CLVP")
@@ -457,7 +460,7 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
457460
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
458461
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
459462
# results, but will increase memory usage.
460-
self.autoregressive = self.autoregressive.cuda()
463+
self.autoregressive = self.autoregressive.to(self.device)
461464
best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
462465
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
463466
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
@@ -468,8 +471,8 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
468471
if verbose:
469472
print("Transforming autoregressive outputs into audio..")
470473
wav_candidates = []
471-
self.diffusion = self.diffusion.cuda()
472-
self.vocoder = self.vocoder.cuda()
474+
self.diffusion = self.diffusion.to(self.device)
475+
self.vocoder = self.vocoder.to(self.device)
473476
for b in range(best_results.shape[0]):
474477
codes = best_results[b].unsqueeze(0)
475478
latents = best_latents[b].unsqueeze(0)

tortoise/utils/audio.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ def mel_spectrogram(self, y):
180180
return mel_output
181181

182182

183-
def wav_to_univnet_mel(wav, do_normalization=False):
183+
def wav_to_univnet_mel(wav, do_normalization=False, device='cuda'):
184184
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
185-
stft = stft.cuda()
185+
stft = stft.to(device)
186186
mel = stft.mel_spectrogram(wav)
187187
if do_normalization:
188188
mel = normalize_tacotron_mel(mel)

tortoise/utils/wav2vec_alignment.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,18 @@ class Wav2VecAlignment:
4949
"""
5050
Uses wav2vec2 to perform audio<->text alignment.
5151
"""
52-
def __init__(self):
52+
def __init__(self, device='cuda'):
5353
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
5454
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
5555
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
56+
self.device = device
5657

5758
def align(self, audio, expected_text, audio_sample_rate=24000):
5859
orig_len = audio.shape[-1]
5960

6061
with torch.no_grad():
61-
self.model = self.model.cuda()
62-
audio = audio.to('cuda')
62+
self.model = self.model.to(self.device)
63+
audio = audio.to(self.device)
6364
audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
6465
clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
6566
logits = self.model(clip_norm).logits

0 commit comments

Comments
 (0)