From 3a78e5d5a5b25af0260074a1886aa23b2565cee5 Mon Sep 17 00:00:00 2001 From: Simon Sardorf Date: Wed, 18 Dec 2024 15:44:50 +0100 Subject: [PATCH] Enable mps support --- tortoise/api.py | 29 ++++++++++++++++++----------- tortoise/api_fast.py | 7 ++++--- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/tortoise/api.py b/tortoise/api.py index 8a010c2a..0e61ffbb 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -243,7 +243,7 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, self.rlg_auto = None self.rlg_diffusion = None @contextmanager - def temporary_cuda(self, model): + def temporary_device(self, model): m = model.to(self.device) yield m m = model.cpu() @@ -410,8 +410,9 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= if verbose: print("Generating autoregressive samples..") if not torch.backends.mps.is_available(): - with self.temporary_cuda(self.autoregressive - ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half): + with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=self.half + ): for b in tqdm(range(num_batches), disable=not verbose): codes = autoregressive.inference_speech(auto_conditioning, text_tokens, do_sample=True, @@ -426,7 +427,9 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) samples.append(codes) else: - with self.temporary_cuda(self.autoregressive) as autoregressive: + with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast( + device_type="mps", dtype=torch.float16, enabled=self.half + ): for b in tqdm(range(num_batches), disable=not verbose): codes = autoregressive.inference_speech(auto_conditioning, text_tokens, do_sample=True, @@ -444,8 +447,10 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= clip_results = [] if not torch.backends.mps.is_available(): - with self.temporary_cuda(self.clvp) as clvp, torch.autocast( - device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half + with self.temporary_device(self.clvp) as clvp, torch.autocast( + device_type=self.device.type, + dtype=torch.float16, + enabled=self.half ): if cvvp_amount > 0: if self.cvvp is None: @@ -476,7 +481,7 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= samples = torch.cat(samples, dim=0) best_results = samples[torch.topk(clip_results, k=k).indices] else: - with self.temporary_cuda(self.clvp) as clvp: + with self.temporary_device(self.clvp) as clvp: if cvvp_amount > 0: if self.cvvp is None: self.load_cvvp() @@ -513,10 +518,12 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # results, but will increase memory usage. if not torch.backends.mps.is_available(): - with self.temporary_cuda( + with self.temporary_device( self.autoregressive ) as autoregressive, torch.autocast( - device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half + device_type=self.device.type, + dtype=torch.float16, + enabled=self.half ): best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, @@ -524,7 +531,7 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= return_latent=True, clip_inputs=False) del auto_conditioning else: - with self.temporary_cuda( + with self.temporary_device( self.autoregressive ) as autoregressive: best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), @@ -537,7 +544,7 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose= print("Transforming autoregressive outputs into audio..") wav_candidates = [] if not torch.backends.mps.is_available(): - with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda( + with self.temporary_device(self.diffusion) as diffusion, self.temporary_device( self.vocoder ) as vocoder: for b in range(best_results.shape[0]): diff --git a/tortoise/api_fast.py b/tortoise/api_fast.py index fd7c5904..7b936bfb 100644 --- a/tortoise/api_fast.py +++ b/tortoise/api_fast.py @@ -371,7 +371,7 @@ def tts_stream(self, text, voice_samples=None, conditioning_latents=None, k=1, v if verbose: print("Generating autoregressive samples..") with torch.autocast( - device_type="cuda" , dtype=torch.float16, enabled=self.half + device_type="cuda" if not torch.backends.mps.is_available() else "mps" , dtype=torch.float16, enabled=self.half ): fake_inputs = self.autoregressive.compute_embeddings( auto_conditioning, @@ -400,7 +400,7 @@ def tts_stream(self, text, voice_samples=None, conditioning_latents=None, k=1, v while not is_end: try: with torch.autocast( - device_type="cuda", dtype=torch.float16, enabled=self.half + device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half ): codes, latent = next(gpt_generator) all_latents += [latent] @@ -478,8 +478,9 @@ def tts(self, text, voice_samples=None, k=1, verbose=True, use_deterministic_see calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" if verbose: print("Generating autoregressive samples..") + print("Using device MPS: ", torch.backends.mps.is_available()) with torch.autocast( - device_type="cuda" , dtype=torch.float16, enabled=self.half + device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half ): codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens, top_k=50,