Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update webui & fastapi for cosy2.0 #765

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions runtime/python/fastapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,23 @@ def main():
}
files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
response = requests.request("GET", url, data=payload, files=files, stream=True)
else:
elif args.mode == 'instruct':
payload = {
'tts_text': args.tts_text,
'spk_id': args.spk_id,
'instruct_text': args.instruct_text
}
response = requests.request("GET", url, data=payload, stream=True)
else:
# instruct2
url = url + "_v2"
payload = {
'tts_text': args.tts_text,
'instruct_text': args.instruct_text,
'format': 'pcm' # option
}
files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
response = requests.request("GET", url, data=payload, files=files, stream=True)
tts_audio = b''
for r in response.iter_content(chunk_size=16000):
tts_audio += r
Expand All @@ -66,7 +76,7 @@ def main():
default='50000')
parser.add_argument('--mode',
default='sft',
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct', 'instruct2'],
help='request mode')
parser.add_argument('--tts_text',
type=str,
Expand Down
82 changes: 60 additions & 22 deletions runtime/python/fastapi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,27 @@
# limitations under the License.
import os
import sys
import io
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from fastapi import FastAPI, UploadFile, Form, File
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, Response
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import numpy as np
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice
import torch
import torchaudio
CURR_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = f'{CURR_DIR}/../../..'
sys.path.append(f'{ROOT_DIR}')
sys.path.append(f'{ROOT_DIR}/third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav

model_dir = f"{ROOT_DIR}/pretrained_models/CosyVoice2-0.5B"
cosyvoice = CosyVoice2(model_dir) if 'CosyVoice2' in model_dir else CosyVoice(model_dir)

app = FastAPI()
# set cross region allowance
app.add_middleware(
Expand All @@ -37,47 +44,78 @@
allow_headers=["*"])


# Non-streaming synthesis: collect the full model output into one WAV payload.
def build_data(model_output, sample_rate=None):
    """Concatenate all generated speech chunks and encode them as a WAV file.

    Args:
        model_output: iterable of dicts, each with a 'tts_speech' tensor of
            shape (channels, samples).
        sample_rate: rate written into the WAV header. Defaults to the loaded
            model's native rate when available (24000 for CosyVoice2), falling
            back to 22050; the previous hard-coded 22050 mislabelled
            CosyVoice2 audio and made it play back pitch-shifted.

    Returns:
        bytes: complete WAV file content.
    """
    if sample_rate is None:
        # webui.py relies on `cosyvoice.sample_rate`, so the attribute exists
        # on both model classes; keep the old constant as a safety net.
        sample_rate = getattr(cosyvoice, 'sample_rate', 22050)
    tts_speeches = [chunk['tts_speech'] for chunk in model_output]
    # Chunks are (channels, samples); join along the time axis.
    output = torch.concat(tts_speeches, dim=1)

    buffer = io.BytesIO()
    torchaudio.save(buffer, output, sample_rate, format="wav")
    buffer.seek(0)
    return buffer.getvalue()


# Streaming synthesis: emit raw signed 16-bit PCM, chunk by chunk.
def generate_data(model_output):
    """Yield each synthesized speech chunk as little-endian int16 PCM bytes."""
    for item in model_output:
        samples = item['tts_speech'].numpy()
        # Scale the float waveform in [-1, 1) to the int16 range and serialize.
        yield (samples * 32768).astype(np.int16).tobytes()


@app.get("/inference_sft")
async def inference_sft(tts_text: str = Form(), spk_id: str = Form(), stream: bool = Form(default=False), format: str = Form(default="pcm")):
    """Synthesize speech for a pretrained speaker id.

    Returns a streaming raw-PCM response when ``format`` is "pcm", otherwise
    a complete WAV file. (The stale pre-diff signature duplicated in this hunk
    has been removed; only the expanded version remains.)
    """
    # NOTE(review): Form() parameters on a GET route require the client to send
    # a body with GET (the bundled client does) — consider POST; confirm.
    model_output = cosyvoice.inference_sft(tts_text, spk_id, stream=stream)
    if format == "pcm":
        return StreamingResponse(generate_data(model_output))
    else:
        return Response(build_data(model_output), media_type="audio/wav")


@app.get("/inference_zero_shot")
async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File(), stream: bool = Form(default=False), format: str = Form(default="pcm")):
    """Zero-shot voice cloning from an uploaded prompt recording.

    The prompt audio is resampled/loaded at 16 kHz. Returns streaming raw PCM
    when ``format`` is "pcm", otherwise a complete WAV file. (Stale pre-diff
    signature duplicated in this hunk removed; only the expanded version kept.)
    """
    prompt_speech_16k = load_wav(prompt_wav.file, 16000)
    model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream)
    if format == "pcm":
        return StreamingResponse(generate_data(model_output))
    else:
        return Response(build_data(model_output), media_type="audio/wav")


@app.get("/inference_cross_lingual")
async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File(), stream: bool = Form(default=False), format: str = Form(default="pcm")):
    """Cross-lingual synthesis: clone the prompt voice, speak ``tts_text``.

    The prompt audio is loaded at 16 kHz. Returns streaming raw PCM when
    ``format`` is "pcm", otherwise a complete WAV file. (Stale pre-diff
    signature duplicated in this hunk removed; only the expanded version kept.)
    """
    prompt_speech_16k = load_wav(prompt_wav.file, 16000)
    model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream)
    if format == "pcm":
        return StreamingResponse(generate_data(model_output))
    else:
        return Response(build_data(model_output), media_type="audio/wav")


@app.get("/inference_instruct")
async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form(), stream: bool = Form(default=False), format: str = Form(default="pcm")):
    """Instruction-controlled synthesis with a pretrained speaker (CosyVoice v1).

    Returns streaming raw PCM when ``format`` is "pcm", otherwise a complete
    WAV file. (Stale pre-diff signature duplicated in this hunk removed; only
    the expanded version kept.)
    """
    model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text, stream=stream)
    if format == "pcm":
        return StreamingResponse(generate_data(model_output))
    else:
        return Response(build_data(model_output), media_type="audio/wav")


@app.get("/inference_instruct_v2")
async def inference_instruct_v2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File(), stream: bool = Form(default=False), format: str = Form(default="pcm")):
    """Instruction-controlled synthesis (CosyVoice2) using a prompt recording.

    The prompt audio is loaded at 16 kHz. Returns a complete WAV file when
    ``format`` is not "pcm"; otherwise streams raw PCM chunks.
    """
    reference_speech = load_wav(prompt_wav.file, 16000)
    model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, reference_speech, stream=stream)
    if format != "pcm":
        return Response(build_data(model_output), media_type="audio/wav")
    return StreamingResponse(generate_data(model_output))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--port',
                        type=int,
                        default=50000)
    parser.add_argument('--model_dir',
                        type=str,
                        default='iic/CosyVoice-300M',
                        help='local path or modelscope repo id')
    args = parser.parse_args()
    # Pick the model class that matches the checkpoint, mirroring the
    # module-level selection; the old unconditional CosyVoice(...) would load
    # the wrong class for a CosyVoice2 model_dir.
    cosyvoice = CosyVoice2(args.model_dir) if 'CosyVoice2' in args.model_dir else CosyVoice(args.model_dir)
    uvicorn.run(app, host="0.0.0.0", port=args.port)
14 changes: 10 additions & 4 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
'自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
stream_mode_list = [('否', False), ('是', True)]
max_val = 0.8
v2 = True


def generate_seed():
Expand Down Expand Up @@ -128,8 +129,13 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
else:
logging.info('get instruct inference request')
set_all_random_seed(seed)
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
if v2:
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
for i in cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream, speed=speed):
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
else:
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())


def main():
Expand Down Expand Up @@ -159,7 +165,7 @@ def main():

generate_button = gr.Button("生成音频")

audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=False)

seed_button.click(generate_seed, inputs=[], outputs=seed)
generate_button.click(generate_audio,
Expand All @@ -181,7 +187,7 @@ def main():
default='pretrained_models/CosyVoice2-0.5B',
help='local path or modelscope repo id')
args = parser.parse_args()
cosyvoice = CosyVoice2(args.model_dir) if 'CosyVoice2' in args.model_dir else CosyVoice(args.model_dir)
cosyvoice, v2 = (CosyVoice2(args.model_dir),True) if 'CosyVoice2' in args.model_dir else (CosyVoice(args.model_dir),False)
sft_spk = cosyvoice.list_avaliable_spks()
prompt_sr = 16000
default_data = np.zeros(cosyvoice.sample_rate)
Expand Down