diff --git a/tortoise/api.py b/tortoise/api.py index 57d01594..39d029a9 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -208,7 +208,7 @@ class TextToSpeech: def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, kv_cache=False, use_deepspeed=False, half=False, device=None, - tokenizer_vocab_file=None, tokenizer_basic=False): + tokenizer_vocab_file=None, tokenizer_basic=False, device_only=False): """ Constructor @@ -263,16 +263,27 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g']) self.vocoder.eval(inference=True) + self.device_only = device_only + + if self.device_only: + self.autoregressive = self.autoregressive.to(self.device) + self.diffusion = self.diffusion.to(self.device) + self.clvp = self.clvp.to(self.device) + self.vocoder = self.vocoder.to(self.device) + # Random latent generators (RLGs) are loaded lazily. self.rlg_auto = None self.rlg_diffusion = None + @contextmanager def temporary_cuda(self, model): - m = model.to(self.device) - yield m - m = model.cpu() + if not self.device_only: + m = model.to(self.device) + yield m + m = model.cpu() + else: + yield model - def load_cvvp(self): """Load CVVP model.""" self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0, @@ -295,22 +306,31 @@ def get_conditioning_latents(self, voice_samples, return_mels=False): for vs in voice_samples: auto_conds.append(format_conditioning(vs, device=self.device)) auto_conds = torch.stack(auto_conds, dim=1) - self.autoregressive = self.autoregressive.to(self.device) - auto_latent = self.autoregressive.get_conditioning(auto_conds) - self.autoregressive = self.autoregressive.cpu() + + if not self.device_only: + self.autoregressive = self.autoregressive.to(self.device) + auto_latent = self.autoregressive.get_conditioning(auto_conds) + self.autoregressive = self.autoregressive.cpu() + else: + auto_latent = self.autoregressive.get_conditioning(auto_conds) diffusion_conds = [] + for sample in voice_samples: # The diffuser operates at a sample rate of 24000 (except for the latent inputs) sample = torchaudio.functional.resample(sample, 22050, 24000) sample = pad_or_truncate(sample, 102400) cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device) diffusion_conds.append(cond_mel) + diffusion_conds = torch.stack(diffusion_conds, dim=1) - self.diffusion = self.diffusion.to(self.device) - diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) - self.diffusion = self.diffusion.cpu() + if not self.device_only: + self.diffusion = self.diffusion.to(self.device) + diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) + self.diffusion = self.diffusion.cpu() + else: + diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) if return_mels: return auto_latent, diffusion_latent, auto_conds, diffusion_conds