diff --git a/clipped_audio.wav b/clipped_audio.wav new file mode 100644 index 00000000..bdfa7cc4 Binary files /dev/null and b/clipped_audio.wav differ diff --git a/test_sig.npy b/test_sig.npy new file mode 100644 index 00000000..798d6f96 Binary files /dev/null and b/test_sig.npy differ diff --git a/tortoise/api.py b/tortoise/api.py index a5b95dda..0612c8a9 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F +import numpy as np import progressbar import torchaudio @@ -21,9 +22,11 @@ from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel, TacotronSTFT from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule from tortoise.utils.tokenizer import VoiceBpeTokenizer +from tortoise.utils.misc_helpers import Timer from tortoise.utils.wav2vec_alignment import Wav2VecAlignment from contextlib import contextmanager from huggingface_hub import hf_hub_download + pbar = None DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models') @@ -39,13 +42,19 @@ 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', } + def get_model_path(model_name, models_dir=MODELS_DIR): """ Get path to given model, download it if it doesn't exist. """ if model_name not in MODELS: raise ValueError(f'Model {model_name} not found in available models.') - model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir) + try: + model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir) + except: + # CVVP not found in Manmay tortoise-tts. + model_path = hf_hub_download(repo_id="jbetker/tortoise-tts-v2", subfolder=".models", filename=model_name, cache_dir=models_dir) + return model_path @@ -56,17 +65,20 @@ def pad_or_truncate(t, length): if t.shape[-1] == length: return t elif t.shape[-1] < length: - return F.pad(t, (0, length-t.shape[-1])) + return F.pad(t, (0, length - t.shape[-1])) else: return t[..., :length] -def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1): +def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, + cond_free_k=1): """ Helper function to load a GaussianDiffusion instance configured for use as a vocoder. """ - return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon', - model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps), + return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), + model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', + betas=get_named_beta_schedule('linear', trained_diffusion_steps), conditioning_free=cond_free, conditioning_free_k=cond_free_k) @@ -119,15 +131,17 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_la Uses the specified diffusion model to convert discrete codes into a spectrogram. """ with torch.no_grad(): - output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_seq_len = latents.shape[ + 1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. 
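The output_seq_len arithmetic above can be sanity-checked with a tiny standalone sketch (the code count below is made up): each autoregressive code covers 4 mel frames at 22,050 Hz, and the frame count is rescaled to the 24 kHz rate the diffusion decoder produces.

    num_codes = 250                                  # e.g. latents.shape[1], illustrative value
    frames_22k = num_codes * 4                       # 4 mel frames per autoregressive code at 22,050 Hz
    output_seq_len = frames_22k * 24000 // 22050     # rescale frame count to the 24 kHz decoder
    print(output_seq_len)                            # -> 1088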
output_shape = (latents.shape[0], 100, output_seq_len) - precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False) + precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, + False) noise = torch.randn(output_shape, device=latents.device) * temperature mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise, - model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, + model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, progress=verbose) - return denormalize_tacotron_mel(mel)[:,:,:output_seq_len] + return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] def classify_audio_clip(clip): @@ -171,12 +185,13 @@ def pick_best_batch_size_for_gpu(): return 4 return 1 + class TextToSpeech: """ Main entry point into Tortoise. """ - def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, + def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, kv_cache=False, use_deepspeed=False, half=False, device=None, tokenizer_vocab_file=None, tokenizer_basic=False): @@ -194,7 +209,7 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, self.models_dir = models_dir self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size self.enable_redaction = enable_redaction - self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if torch.backends.mps.is_available(): self.device = torch.device('mps') if self.enable_redaction: @@ -210,15 +225,19 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt') self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt') else: - self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, - model_dim=1024, - heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, - train_solo_embeddings=False).cpu().eval() - self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False) + self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, + layers=30, + model_dim=1024, + heads=16, number_text_tokens=255, start_text_token=255, + checkpointing=False, + train_solo_embeddings=False).cpu().eval() + self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), + strict=False) self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half) - + self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, - in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, + in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, + num_heads=16, layer_drop=0, unconditioned_percentage=0).cpu().eval() self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir))) @@ -227,31 +246,37 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430, use_xformers=True).cpu().eval() self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir))) - 
self.cvvp = None # CVVP model is only loaded if used. + self.cvvp = None # CVVP model is only loaded if used. self.vocoder = UnivNetGenerator().cpu() - self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g']) + self.vocoder.load_state_dict( + torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g']) self.vocoder.eval(inference=True) - - self.stft = None # TacotronSTFT is only loaded if used. + + self.stft = None # TacotronSTFT is only loaded if used. # Random latent generators (RLGs) are loaded lazily. self.rlg_auto = None self.rlg_diffusion = None + @contextmanager def temporary_cuda(self, model): m = model.to(self.device) yield m m = model.cpu() - def load_cvvp(self): """Load CVVP model.""" - self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0, + self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, + cond_mask_percentage=0, speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval() + #self.cvvp.to(self.device).eval() + + self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir))) - def get_conditioning_latents(self, voice_samples, return_mels=False): + + def get_conditioning_latents(self, voice_samples, return_mels=False, return_average=True): """ Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic @@ -268,7 +293,7 @@ def get_conditioning_latents(self, voice_samples, return_mels=False): auto_conds.append(format_conditioning(vs, device=self.device)) auto_conds = torch.stack(auto_conds, dim=1) self.autoregressive = self.autoregressive.to(self.device) - auto_latent = self.autoregressive.get_conditioning(auto_conds) + auto_latent = self.autoregressive.get_conditioning(auto_conds, return_average=return_average) self.autoregressive = self.autoregressive.cpu() if self.stft is None: @@ -283,10 +308,11 @@ def get_conditioning_latents(self, voice_samples, return_mels=False): cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device, stft=self.stft) diffusion_conds.append(cond_mel) - diffusion_conds = torch.stack(diffusion_conds, dim=1) + + diffusion_conds = torch.stack(diffusion_conds, dim=1) self.diffusion = self.diffusion.to(self.device) - diffusion_latent = self.diffusion.get_conditioning(diffusion_conds) + diffusion_latent = self.diffusion.get_conditioning(diffusion_conds, return_average=return_average) self.diffusion = self.diffusion.cpu() if return_mels: @@ -298,9 +324,11 @@ def get_random_conditioning_latents(self): # Lazy-load the RLG models. 
if self.rlg_auto is None: self.rlg_auto = RandomLatentConverter(1024).eval() - self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu'))) + self.rlg_auto.load_state_dict( + torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu'))) self.rlg_diffusion = RandomLatentConverter(2048).eval() - self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu'))) + self.rlg_diffusion.load_state_dict( + torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu'))) with torch.no_grad(): return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) @@ -324,17 +352,20 @@ def tts_with_preset(self, text, preset='fast', **kwargs): 'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400}, } settings.update(presets[preset]) - settings.update(kwargs) # allow overriding of preset settings with kwargs + settings.update(kwargs) # allow overriding of preset settings with kwargs return self.tts(text, **settings) def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None, return_deterministic_state=False, # autoregressive generation parameters follow - num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, + num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, + max_mel_tokens=500, # CVVP parameters follow cvvp_amount=.0, # diffusion generation parameters follow diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, + use_averaged_latents=True, + auto_conds=None, **hf_generate_kwargs): """ Produces an audio clip of the given text being spoken with the given reference voice. @@ -381,22 +412,38 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. Sample rate is 24kHz. """ + + if not use_averaged_latents: + assert k==1, "Non-averaged latents currently only support single sample generation" + deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. - assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' - auto_conds = None + assert text_tokens.shape[ + -1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' 
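A hedged usage sketch of the new options wired in above; the reference-clip paths are hypothetical, and load_audio is the existing helper from tortoise.utils.audio:

    from tortoise.api import TextToSpeech
    from tortoise.utils.audio import load_audio

    tts = TextToSpeech()
    clips = [load_audio(p, 22050) for p in ("ref_a.wav", "ref_b.wav")]   # hypothetical paths

    # Per-clip latents: the clip axis is kept when return_average=False.
    auto_latent, diff_latent = tts.get_conditioning_latents(clips, return_average=False)

    # Non-averaged latents currently require k=1 (asserted above).
    wav = tts.tts("Hello there.", voice_samples=clips, use_averaged_latents=False, k=1)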
if voice_samples is not None: - auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True) + auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, + return_mels=True, + return_average=use_averaged_latents) elif conditioning_latents is not None: auto_conditioning, diffusion_conditioning = conditioning_latents + if use_averaged_latents: + # Average across second axis + if auto_conditioning.dim() > 2: + auto_conditioning = torch.mean(auto_conditioning,axis=1) + if diffusion_conditioning.dim() > 2: + diffusion_conditioning = torch.mean(diffusion_conditioning,axis=1) else: auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents() + + + auto_conditioning = auto_conditioning.to(self.device) diffusion_conditioning = diffusion_conditioning.to(self.device) - diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) + diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, + cond_free_k=cond_free_k) with torch.no_grad(): samples = [] @@ -407,41 +454,94 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= print("Generating autoregressive samples..") if not torch.backends.mps.is_available(): with self.temporary_cuda(self.autoregressive - ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half): + ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, + enabled=self.half): + # Store the latent indices for alignment with the diffusion conditions + batched_latent_indices = [] for b in tqdm(range(num_batches), disable=not verbose): - codes = autoregressive.inference_speech(auto_conditioning, text_tokens, - do_sample=True, - top_p=top_p, - temperature=temperature, - num_return_sequences=self.autoregressive_batch_size, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - max_generate_length=max_mel_tokens, - **hf_generate_kwargs) + # Case where we're returning non-averaged latents + if auto_conditioning.dim() > 2: + # If the number of candidate speech conditioning latents are not equal to (greater) + # num_return_sequences (batch), randomly select an equal number of candidate latents. + if auto_conditioning.shape[1] >= self.autoregressive_batch_size: + latent_indices = torch.randperm(auto_conditioning.shape[1])[ + :self.autoregressive_batch_size] + batched_latent_indices.append(latent_indices) + else: + # If there are less candidate speech conditioning latents, replicate the + # latents to meet the autoregressive batch size + replications = np.ceil(self.autoregressive_batch_size / + auto_conditioning.shape[1]).astype(int) + latent_indices = (torch.arange(0, auto_conditioning.shape[1], dtype=torch.int32). 
+ repeat(replications))[:self.autoregressive_batch_size] + batched_latent_indices.append(latent_indices) + auto_conditioning_ = auto_conditioning[0, latent_indices].unsqueeze(0) + + else: + auto_conditioning_ = auto_conditioning + + codes = autoregressive.inference_speech(auto_conditioning_, text_tokens, + do_sample=True, + top_p=top_p, + temperature=temperature, + num_return_sequences=self.autoregressive_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **hf_generate_kwargs) padding_needed = max_mel_tokens - codes.shape[1] codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) samples.append(codes) else: with self.temporary_cuda(self.autoregressive) as autoregressive: + batched_latent_indices = [] for b in tqdm(range(num_batches), disable=not verbose): - codes = autoregressive.inference_speech(auto_conditioning, text_tokens, - do_sample=True, - top_p=top_p, - temperature=temperature, - num_return_sequences=self.autoregressive_batch_size, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - max_generate_length=max_mel_tokens, - **hf_generate_kwargs) + + # Case where we're returning non-averaged latents + if auto_conditioning.dim() > 2: + # If the number of candidate speech conditioning latents are not equal to (greater) + # num_return_sequences (batch), randomly select an equal number of candidate latents. + if auto_conditioning.shape[1] >= self.autoregressive_batch_size: + latent_indices = torch.randperm(auto_conditioning.shape[1])[ + :self.autoregressive_batch_size] + batched_latent_indices.append(latent_indices) + else: + # If there are less candidate speech conditioning latents, replicate the + # latents to meet the autoregressive batch size + replications = np.ceil(self.autoregressive_batch_size / + auto_conditioning.shape[1]).astype(int) + latent_indices = (torch.arange(0, auto_conditioning.shape[1], dtype=torch.int32). 
+ repeat(replications))[:self.autoregressive_batch_size] + batched_latent_indices.append(latent_indices) + auto_conditioning_ = auto_conditioning[0, latent_indices].unsqueeze(0) + else: + auto_conditioning_ = auto_conditioning + + + codes = autoregressive.inference_speech(auto_conditioning_, text_tokens, + do_sample=True, + top_p=top_p, + temperature=temperature, + num_return_sequences=self.autoregressive_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **hf_generate_kwargs) padding_needed = max_mel_tokens - codes.shape[1] codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) samples.append(codes) + + # Flatten batched latents (if not using averaged latents) + if len(batched_latent_indices): + batched_latent_indices_flattened = torch.cat(batched_latent_indices, dim=0) clip_results = [] - + cvvp_results = [] + clvp_results = [] if not torch.backends.mps.is_available(): with self.temporary_cuda(self.clvp) as clvp, torch.autocast( - device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half + device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, + enabled=self.half ): if cvvp_amount > 0: if self.cvvp is None: @@ -451,7 +551,8 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= if self.cvvp is None: print("Computing best candidates using CLVP") else: - print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%") + print( + f"Computing best candidates using CLVP {((1 - cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%") for batch in tqdm(samples, disable=not verbose): for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) @@ -460,17 +561,44 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= if auto_conds is not None and cvvp_amount > 0: cvvp_accumulator = 0 for cl in range(auto_conds.shape[1]): - cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) + cvvp_accumulator = cvvp_accumulator + self.cvvp( + auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) cvvp = cvvp_accumulator / auto_conds.shape[1] if cvvp_amount == 1: + # Voice only based selection (CVVP - how well do the VQ mels align with the voice prompt(s)?) clip_results.append(cvvp) else: - clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount)) + # Hybrid Voice-Text based selection + # We append to clvp and cvvp lists separately such that norm equalization occurs + # on total set. + clvp_results.append(clvp_out) + cvvp_results.append(cvvp) else: + # Text based selection (CLVP - how well do the VQ mels align with the text prompt?) clip_results.append(clvp_out) - clip_results = torch.cat(clip_results, dim=0) + + if len(clvp_results): + # cvvp and clvp_out have dramatically different scales. Equalize the norms such that + # weighting value has more intuitive, linear meaning. 
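The norm equalization described in the comment above, shown on made-up score tensors so the effect of cvvp_amount is easier to see:

    import torch

    clvp_scores = torch.tensor([0.02, 0.05, 0.03, 0.04])   # CLVP scores: small scale
    cvvp_scores = torch.tensor([12.0, 9.0, 15.0, 11.0])    # CVVP scores: much larger scale
    cvvp_amount = 0.5

    # Rescale CVVP so both score sets share the same overall norm; the weighting
    # between them then behaves roughly linearly in cvvp_amount.
    scale = torch.linalg.norm(clvp_scores) / torch.linalg.norm(cvvp_scores)
    combined = cvvp_scores * scale * cvvp_amount + clvp_scores * (1 - cvvp_amount)
    best = torch.topk(combined, k=1).indices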
+ clvp_results = torch.cat(clvp_results,dim=0) + cvvp_results = torch.cat(cvvp_results,dim=0) + norm_clvp = torch.linalg.norm(clvp_results) + norm_cvvp = torch.linalg.norm(cvvp_results) + norm_scale_cvvp = norm_clvp/norm_cvvp + cvvp_results *= norm_scale_cvvp + + # Calculate weighted clip results + clip_results = cvvp_results * cvvp_amount + clvp_results * (1 - cvvp_amount) + + + else: + clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) - best_results = samples[torch.topk(clip_results, k=k).indices] + top_k_ = torch.topk(clip_results, k=k).indices + if len(batched_latent_indices): + # map top_k_ back to samples to reference proper diffusion conditions + mapped_top_k = batched_latent_indices_flattened[top_k_.cpu()] + best_results = samples[top_k_] else: with self.temporary_cuda(self.clvp) as clvp: if cvvp_amount > 0: @@ -481,7 +609,8 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= if self.cvvp is None: print("Computing best candidates using CLVP") else: - print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%") + print( + f"Computing best candidates using CLVP {((1 - cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%") for batch in tqdm(samples, disable=not verbose): for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) @@ -490,13 +619,17 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= if auto_conds is not None and cvvp_amount > 0: cvvp_accumulator = 0 for cl in range(auto_conds.shape[1]): - cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) + cvvp_accumulator = cvvp_accumulator + self.cvvp( + auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) cvvp = cvvp_accumulator / auto_conds.shape[1] if cvvp_amount == 1: + # Voice only based selection (CVVP - how well do the VQ mels align with the voice prompt(s)?) clip_results.append(cvvp) else: - clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount)) + # Hybrid Voice-Text based selection + clip_results.append(cvvp * cvvp_amount + clvp_out * (1 - cvvp_amount)) else: + # Text based selection (CLVP - how well do the VQ mels align with the text prompt?) clip_results.append(clvp_out) clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) @@ -510,23 +643,44 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= # results, but will increase memory usage.
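A small standalone example of the index bookkeeping above (batch sizes and scores are invented): each generated sample records which conditioning latent produced it, so the top-k winner can later be paired with the matching diffusion condition.

    import torch

    # Two autoregressive batches of 4 samples, each drawn with its own latent indices.
    batched_latent_indices = [torch.tensor([2, 0, 1, 3]), torch.tensor([1, 3, 0, 2])]
    flattened = torch.cat(batched_latent_indices, dim=0)    # one latent index per generated sample

    clip_results = torch.rand(8)                             # one score per generated sample
    top_k_ = torch.topk(clip_results, k=1).indices           # index of the best sample
    mapped_top_k = flattened[top_k_.cpu()]                   # latent that produced that sample
    # auto_conditioning[0, mapped_top_k] / diffusion_conditioning[0, mapped_top_k]
    # would then select the matching per-clip conditions.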
if not torch.backends.mps.is_available(): with self.temporary_cuda( - self.autoregressive + self.autoregressive ) as autoregressive, torch.autocast( - device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half + device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, + enabled=self.half ): - best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), - torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, - torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), - return_latent=True, clip_inputs=False) + # Allow for non-averaged latents + if auto_conditioning.dim() == 2: + auto_conditioning = auto_conditioning.repeat(k, 1) + else: + # Select the best condition + auto_conditioning = (auto_conditioning[0,mapped_top_k]).repeat(k,1) + + best_latents = autoregressive(auto_conditioning, text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), + best_results, + torch.tensor([best_results.shape[ + -1] * self.autoregressive.mel_length_compression], + device=text_tokens.device), + return_latent=True, clip_inputs=False) del auto_conditioning else: with self.temporary_cuda( - self.autoregressive + self.autoregressive ) as autoregressive: - best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), - torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, - torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), - return_latent=True, clip_inputs=False) + # Allow for non-averaged latents + if auto_conditioning.dim() == 2: + auto_conditioning = auto_conditioning.repeat(k, 1) + else: + # Select the best condition + auto_conditioning = (auto_conditioning[0,mapped_top_k]).repeat(k,1) + + best_latents = autoregressive(auto_conditioning, text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), + best_results, + torch.tensor([best_results.shape[ + -1] * self.autoregressive.mel_length_compression], + device=text_tokens.device), + return_latent=True, clip_inputs=False) del auto_conditioning if verbose: @@ -534,7 +688,7 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= wav_candidates = [] if not torch.backends.mps.is_available(): with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda( - self.vocoder + self.vocoder ) as vocoder: for b in range(best_results.shape[0]): codes = best_results[b].unsqueeze(0) @@ -550,8 +704,15 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. 
latents = latents[:, :k] break - mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature, - verbose=verbose) + + # Get top selection for diffusion conditioning + if diffusion_conditioning.dim() > 2: + diffusion_conditioning_ = diffusion_conditioning[0,mapped_top_k] + else: + diffusion_conditioning_ = diffusion_conditioning + mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning_, + temperature=diffusion_temperature, + verbose=verbose) wav = vocoder.inference(mel) wav_candidates.append(wav.cpu()) else: @@ -571,15 +732,27 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. latents = latents[:, :k] break - mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature, - verbose=verbose) + + # Get top selection for diffusion conditioning + if diffusion_conditioning.dim() > 2: + diffusion_conditioning_ = diffusion_conditioning[0, mapped_top_k] + else: + diffusion_conditioning_ = diffusion_conditioning + + mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning_, + temperature=diffusion_temperature, + verbose=verbose) wav = vocoder.inference(mel) wav_candidates.append(wav.cpu()) + if verbose: + print('Finished getting wav candidates') + def potentially_redact(clip, text): if self.enable_redaction: return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1) return clip + wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates] if len(wav_candidates) > 1: @@ -587,10 +760,14 @@ def potentially_redact(clip, text): else: res = wav_candidates[0] + if verbose: + print("Returning result") + if return_deterministic_state: return res, (deterministic_seed, text, voice_samples, conditioning_latents) else: return res + def deterministic_state(self, seed=None): """ Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be diff --git a/tortoise/models/autoregressive.py b/tortoise/models/autoregressive.py index fcd1a94f..2937b557 100644 --- a/tortoise/models/autoregressive.py +++ b/tortoise/models/autoregressive.py @@ -129,8 +129,10 @@ def forward( return_dict if return_dict is not None else self.config.use_return_dict ) + # Create embedding mel_len = self.cached_mel_emb.shape[1] + if input_ids.shape[1] != 1: text_inputs = input_ids[:, mel_len:] text_emb = self.embeddings(text_inputs) @@ -147,6 +149,8 @@ def forward( emb = emb + self.text_pos_embedding.get_fixed_embedding( attention_mask.shape[1] - mel_len, attention_mask.device ) + + transformer_outputs = self.transformer( inputs_embeds=emb, past_key_values=past_key_values, @@ -159,8 +163,9 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=return_dict ) + hidden_states = transformer_outputs[0] # Set device for model parallelism @@ -441,16 +446,18 @@ def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, secon else: return first_logits - def get_conditioning(self, speech_conditioning_input): + def get_conditioning(self, speech_conditioning_input, return_average=True): speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] for j in 
range(speech_conditioning_input.shape[1]): conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) conds = torch.stack(conds, dim=1) - conds = conds.mean(dim=1) + if return_average: + conds = conds.mean(dim=1) return conds + def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False, return_latent=False, clip_inputs=True): """ @@ -485,7 +492,11 @@ def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_cod text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) - conds = speech_conditioning_latent.unsqueeze(1) + if speech_conditioning_latent.dim() == 2: + conds = speech_conditioning_latent.unsqueeze(1) + else: + conds = speech_conditioning_latent + text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) @@ -532,18 +543,38 @@ def compute_embeddings( ) gpt_inputs[:, -1] = self.start_mel_token return gpt_inputs - def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1, + def inference_speech(self, speech_conditioning_latents, text_inputs, input_tokens=None, num_return_sequences=1, max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - conds = speech_conditioning_latent.unsqueeze(1) - emb = torch.cat([conds, text_emb], dim=1) + # Optionally expand the speech conditioning latent for concatenation to text embedding. + # Allow for different speech conditioning latents to be passed per sample. + if speech_conditioning_latents.dim() == 2: + conds = speech_conditioning_latents.unsqueeze(1) + emb = torch.cat([conds, text_emb], dim=1) + + else: + assert speech_conditioning_latents.shape[1] == num_return_sequences, \ + ("If the number of speech conditioning latents passed is > 1, they must be equal to the " + "autoregressive batch size") + conds = speech_conditioning_latents + + # Here, we have num_return_sequences unique VQ mels (different conditional VQ mel for each) + emb = torch.cat([torch.swapaxes(conds,0,1), + text_emb.repeat(num_return_sequences,1,1)], dim=1) + + self.inference_model.store_mel_emb(emb) - fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long, + + # TODO: Resolve below adjustment. When might emb.shape != 1? 
+ # fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long, + # device=text_inputs.device) + + fake_inputs = torch.full((1, 1 + emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device) fake_inputs[:, -1] = self.start_mel_token trunc_index = fake_inputs.shape[1] @@ -554,12 +585,19 @@ def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens fake_inputs = fake_inputs.repeat(num_return_sequences, 1) input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) inputs = torch.cat([fake_inputs, input_tokens], dim=1) - logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length + + # Pre-expansion of inputs - this simply expands inputs into the number of input sequences + inputs = self.inference_model._expand_inputs_for_generation(expand_size=num_return_sequences, + input_ids=inputs, + **hf_generate_kwargs)[0] + + #print(f'GENERATE KWARGS: {hf_generate_kwargs}') gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, max_length=max_length, logits_processor=logits_processor, - num_return_sequences=num_return_sequences, **hf_generate_kwargs) + num_return_sequences=1, **hf_generate_kwargs) + return gen[:, trunc_index:] def get_generator(self, fake_inputs, **hf_generate_kwargs): diff --git a/tortoise/models/cvvp.py b/tortoise/models/cvvp.py index 544ca47b..64058720 100644 --- a/tortoise/models/cvvp.py +++ b/tortoise/models/cvvp.py @@ -108,6 +108,9 @@ def forward( mel_input, return_loss=False ): + + #print(f'MEL COND SHAPE:{mel_cond.shape}') + #print(f'MEL IN SHAPE: {mel_input.shape}') cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1) enc_cond = self.conditioning_transformer(cond_emb) cond_latents = self.to_conditioning_latent(enc_cond) diff --git a/tortoise/models/diffusion_decoder.py b/tortoise/models/diffusion_decoder.py index e969129c..5ebbc7d7 100644 --- a/tortoise/models/diffusion_decoder.py +++ b/tortoise/models/diffusion_decoder.py @@ -219,14 +219,21 @@ def get_grad_norm_parameter_groups(self): } return groups - def get_conditioning(self, conditioning_input): + def get_conditioning(self, conditioning_input, return_average=True): speech_conditioning_input = conditioning_input.unsqueeze(1) if len( conditioning_input.shape) == 3 else conditioning_input conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) - conds = torch.cat(conds, dim=-1) - conds = conds.mean(dim=-1) + diff_context = self.contextual_embedder(speech_conditioning_input[:, j]) + if not return_average: + # We must still average across the last dim per sample (we don't average cross all samples) + diff_context = diff_context.mean(dim=-1) + conds.append(diff_context) + if return_average: + conds = torch.cat(conds, dim=-1) + conds = conds.mean(dim=-1) + else: + conds = torch.stack(conds,dim=1) return conds def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred): diff --git a/tortoise/utils/audio.py b/tortoise/utils/audio.py index 98783eff..582f5c1c 100644 --- a/tortoise/utils/audio.py +++ b/tortoise/utils/audio.py @@ -8,7 +8,7 @@ from scipy.io.wavfile import read from tortoise.utils.stft import STFT - +from 
tortoise.utils.misc_helpers import Timer BUILTIN_VOICES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../voices')
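For reference, a hedged shape sketch of the two get_conditioning modes added in diffusion_decoder.py; the channel and frame counts below are illustrative, not taken from the model config:

    import torch

    n_clips, channels, frames = 3, 2048, 400
    per_clip = [torch.randn(1, channels, frames) for _ in range(n_clips)]  # stand-ins for contextual_embedder outputs

    # return_average=True: concatenate along time, then average -> one (1, channels) latent for the voice.
    averaged = torch.cat(per_clip, dim=-1).mean(dim=-1)

    # return_average=False: average each clip over time, keep the clip axis -> (1, n_clips, channels).
    stacked = torch.stack([c.mean(dim=-1) for c in per_clip], dim=1)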