From 122df8c42043001d858ad07056b9a7190a61519d Mon Sep 17 00:00:00 2001
From: "lyuxiang.lx"
Date: Fri, 6 Sep 2024 17:10:54 +0800
Subject: [PATCH] set onnx to false as last chunk rtf unstable

---
 .github/workflows/lint.yml |  1 +
 cosyvoice/cli/cosyvoice.py |  2 +-
 cosyvoice/cli/model.py     | 93 ++++++++++++++++++--------------
 3 files changed, 45 insertions(+), 51 deletions(-)

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 58d7fb7..fff7290 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -2,6 +2,7 @@ name: Lint
 
 on:
   pull_request:
+  push:
 
 jobs:
   quick-checks:
diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
index af09c12..6e8d2d3 100644
--- a/cosyvoice/cli/cosyvoice.py
+++ b/cosyvoice/cli/cosyvoice.py
@@ -23,7 +23,7 @@
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True, load_onnx=True):
+    def __init__(self, model_dir, load_jit=True, load_onnx=False):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py
index da88dd6..5efd30c 100644
--- a/cosyvoice/cli/model.py
+++ b/cosyvoice/cli/model.py
@@ -43,7 +43,6 @@ def __init__(self,
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -93,32 +92,31 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uui
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
-        with self.flow_hift_context:
-            tts_mel = self.flow.inference(token=token.to(self.device),
-                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                          prompt_token=prompt_token.to(self.device),
-                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                          prompt_feat=prompt_feat.to(self.device),
-                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                          embedding=embedding.to(self.device))
-            # mel overlap fade in out
-            if self.mel_overlap_dict[uuid] is not None:
-                tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
-            # append hift cache
-            if self.hift_cache_dict[uuid] is not None:
-                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
-                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
-            else:
-                hift_cache_source = torch.zeros(1, 1, 0)
-            # keep overlap mel and hift cache
-            if finalize is False:
-                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
-                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
-                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
-                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
-                tts_speech = tts_speech[:, :-self.source_cache_len]
-            else:
-                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+        tts_mel = self.flow.inference(token=token.to(self.device),
+                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                      prompt_token=prompt_token.to(self.device),
+                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                      prompt_feat=prompt_feat.to(self.device),
+                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                      embedding=embedding.to(self.device))
+        # mel overlap fade in out
+        if self.mel_overlap_dict[uuid] is not None:
+            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+            self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
         return tts_speech
 
     def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
@@ -139,13 +137,12 @@ def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                     this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
-                    with self.flow_hift_context:
-                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                         prompt_token=flow_prompt_speech_token,
-                                                         prompt_feat=prompt_speech_feat,
-                                                         embedding=flow_embedding,
-                                                         uuid=this_uuid,
-                                                         finalize=False)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
                     yield {'tts_speech': this_tts_speech.cpu()}
                     with self.lock:
                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
@@ -156,30 +153,26 @@ def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
-            with self.flow_hift_context:
-                this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                 prompt_token=flow_prompt_speech_token,
-                                                 prompt_feat=prompt_speech_feat,
-                                                 embedding=flow_embedding,
-                                                 uuid=this_uuid,
-                                                 finalize=True)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
             p.join()
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
-            with self.flow_hift_context:
-                this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                 prompt_token=flow_prompt_speech_token,
-                                                 prompt_feat=prompt_speech_feat,
-                                                 embedding=flow_embedding,
-                                                 uuid=this_uuid,
-                                                 finalize=True)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
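
Note on the cosyvoice.py hunk: only the constructor default changes, making ONNX loading opt-in because the last streaming chunk's RTF was unstable with it. A minimal usage sketch, assuming the usual pretrained model directory name (the path below is illustrative, not mandated by the patch):

from cosyvoice.cli.cosyvoice import CosyVoice

# new default after this patch: load_jit=True, load_onnx=False
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')

# pass load_onnx=True explicitly to opt back in to the ONNX path
cosyvoice_onnx = CosyVoice('pretrained_models/CosyVoice-300M', load_onnx=True)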
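
Note on the token2wav hunk: before vocoding, each streaming chunk's mel is crossfaded with the cached tail of the previous chunk via fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window), and the trimmed tail is cached for the next call. The sketch below illustrates that overlap crossfade in isolation; the function name, window construction, and tensor shapes are assumptions for illustration, not the project's fade_in_out implementation.

import torch

def crossfade_overlap(cur_mel, cached_overlap, window):
    """Blend the cached tail of the previous chunk into the head of the current chunk.

    cur_mel:        (B, n_mels, T)  mel of the current chunk
    cached_overlap: (B, n_mels, L)  last L frames kept from the previous chunk
    window:         (2 * L,)        first half fades the new chunk in, second half fades the cache out
    """
    overlap_len = cached_overlap.shape[2]
    fade_in, fade_out = window[:overlap_len], window[overlap_len:]
    blended = cached_overlap * fade_out + cur_mel[:, :, :overlap_len] * fade_in
    return torch.concat([blended, cur_mel[:, :, overlap_len:]], dim=2)

# toy usage: 80-bin mels, 34-frame overlap (numbers are illustrative assumptions)
overlap_len = 34
window = torch.hamming_window(2 * overlap_len)   # rises then falls: fade-in half, fade-out half
prev_tail = torch.randn(1, 80, overlap_len)       # stands in for mel_overlap_dict[uuid]
cur_mel = torch.randn(1, 80, 200)                 # stands in for the newly decoded chunk
out = crossfade_overlap(cur_mel, prev_tail, window)   # -> (1, 80, 200)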