
Enable mps support
SSardorf committed Dec 18, 2024
1 parent 8a2563e commit bf79761
Showing 2 changed files with 22 additions and 15 deletions.
29 changes: 18 additions & 11 deletions tortoise/api.py
@@ -243,7 +243,7 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR,
         self.rlg_auto = None
         self.rlg_diffusion = None
     @contextmanager
-    def temporary_cuda(self, model):
+    def temporary_device(self, model):
         m = model.to(self.device)
         yield m
         m = model.cpu()
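
The renamed helper is the core mechanism behind the change: each submodel lives on the CPU and is moved to whatever device the instance targets only while it is in use. A standalone sketch of the same idea, written as a free function with an explicit device argument since the method above reads self.device:

import torch
from contextlib import contextmanager

@contextmanager
def temporary_device(model, device):
    # Move the model onto the target device (cuda, mps, or cpu) for the
    # duration of the with-block...
    m = model.to(device)
    yield m
    # ...then park it back on the CPU so it releases device memory.
    m.cpu()

# Usage sketch:
model = torch.nn.Linear(4, 4)
with temporary_device(model, torch.device("cpu")) as m:
    out = m(torch.randn(2, 4))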
@@ -410,8 +410,9 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
             if verbose:
                 print("Generating autoregressive samples..")
             if not torch.backends.mps.is_available():
-                with self.temporary_cuda(self.autoregressive
-                ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half):
+                with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast(
+                    device_type="cuda", dtype=torch.float16, enabled=self.half
+                ):
                     for b in tqdm(range(num_batches), disable=not verbose):
                         codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
                                                                 do_sample=True,
@@ -426,7 +427,9 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
                         codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                         samples.append(codes)
             else:
-                with self.temporary_cuda(self.autoregressive) as autoregressive:
+                with self.temporary_device(self.autoregressive) as autoregressive, torch.autocast(
+                    device_type="mps", dtype=torch.float16, enabled=self.half
+                ):
                     for b in tqdm(range(num_batches), disable=not verbose):
                         codes = autoregressive.inference_speech(auto_conditioning, text_tokens,
                                                                 do_sample=True,
@@ -444,8 +447,10 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
             clip_results = []

             if not torch.backends.mps.is_available():
-                with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
-                    device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
+                with self.temporary_device(self.clvp) as clvp, torch.autocast(
+                    device_type=self.device.type,
+                    dtype=torch.float16,
+                    enabled=self.half
                 ):
                     if cvvp_amount > 0:
                         if self.cvvp is None:
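
Switching to device_type=self.device.type drops the repeated availability conditional, because the .type attribute of a torch.device is exactly the backend string that torch.autocast's device_type parameter expects:

import torch

# torch.device.type is the plain backend name, independent of any index:
assert torch.device("cuda:0").type == "cuda"
assert torch.device("mps").type == "mps"
assert torch.device("cpu").type == "cpu"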
@@ -476,7 +481,7 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
                 samples = torch.cat(samples, dim=0)
                 best_results = samples[torch.topk(clip_results, k=k).indices]
             else:
-                with self.temporary_cuda(self.clvp) as clvp:
+                with self.temporary_device(self.clvp) as clvp:
                     if cvvp_amount > 0:
                         if self.cvvp is None:
                             self.load_cvvp()
@@ -513,18 +518,20 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
             # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
             # results, but will increase memory usage.
             if not torch.backends.mps.is_available():
-                with self.temporary_cuda(
+                with self.temporary_device(
                     self.autoregressive
                 ) as autoregressive, torch.autocast(
-                    device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half
+                    device_type=self.device.type,
+                    dtype=torch.float16,
+                    enabled=self.half
                 ):
                     best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
                                                   torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
                                                   torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
                                                   return_latent=True, clip_inputs=False)
                 del auto_conditioning
             else:
-                with self.temporary_cuda(
+                with self.temporary_device(
                     self.autoregressive
                 ) as autoregressive:
                     best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
@@ -537,7 +544,7 @@ def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=
                 print("Transforming autoregressive outputs into audio..")
             wav_candidates = []
             if not torch.backends.mps.is_available():
-                with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda(
+                with self.temporary_device(self.diffusion) as diffusion, self.temporary_device(
                     self.vocoder
                 ) as vocoder:
                     for b in range(best_results.shape[0]):
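All of the branches above assume self.device already names the right backend. A minimal, hypothetical sketch of that selection (the actual constructor logic is outside this diff; the preference order is inferred from the mps-first branches above):

import torch

def pick_device() -> torch.device:
    # Hypothetical helper: prefer Apple's Metal backend, then CUDA,
    # then fall back to the CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
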
8 changes: 4 additions & 4 deletions tortoise/api_fast.py
@@ -371,7 +371,7 @@ def tts_stream(self, text, voice_samples=None, conditioning_latents=None, k=1, v
         if verbose:
             print("Generating autoregressive samples..")
         with torch.autocast(
-            device_type="cuda" , dtype=torch.float16, enabled=self.half
+            device_type="cuda" if not torch.backends.mps.is_available() else "mps" , dtype=torch.float16, enabled=self.half
         ):
             fake_inputs = self.autoregressive.compute_embeddings(
                 auto_conditioning,
@@ -400,7 +400,7 @@ def tts_stream(self, text, voice_samples=None, conditioning_latents=None, k=1, v
             while not is_end:
                 try:
                     with torch.autocast(
-                        device_type="cuda", dtype=torch.float16, enabled=self.half
+                        device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half
                     ):
                         codes, latent = next(gpt_generator)
                         all_latents += [latent]
@@ -477,9 +477,9 @@ def tts(self, text, voice_samples=None, k=1, verbose=True, use_deterministic_see
         with torch.no_grad():
             calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
             if verbose:
-                print("Generating autoregressive samples..")
+                print("Generating autoregressive samples..")
             with torch.autocast(
-                device_type="cuda" , dtype=torch.float16, enabled=self.half
+                device_type="cuda" if not torch.backends.mps.is_available() else "mps", dtype=torch.float16, enabled=self.half
             ):
                 codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens,
                                                              top_k=50,
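
Unlike api.py, api_fast.py keeps the inline conditional rather than self.device.type. In isolation the pattern looks like this (a sketch; note that float16 autocast on "mps" requires a PyTorch build recent enough to support MPS autocast, and on a machine with neither backend PyTorch warns and disables the cuda autocast context rather than failing):

import torch

# Choose the autocast backend the way the diff does: "mps" when Metal is
# available, otherwise "cuda". `half` stands in for the self.half flag.
half = True
device_type = "cuda" if not torch.backends.mps.is_available() else "mps"
with torch.autocast(device_type=device_type, dtype=torch.float16, enabled=half):
    pass  # the autoregressive / GPT-generator calls run here in float16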
