diff --git a/config.json b/config.json index 9eb7e1cc..9c767fc6 100644 --- a/config.json +++ b/config.json @@ -569,6 +569,37 @@ "region": "japanwest", "voice_name": "zh-CN-XiaoyanNeural" }, + "fish_speech": { + "api_ip_port": "http://127.0.0.1:8000", + "model_name": "default", + "model_config": { + "device": "cuda", + "llama": { + "config_name": "text2semantic_finetune", + "checkpoint_path": "checkpoints/text2semantic-400m-v0.2-4k.pth", + "precision": "bfloat16", + "tokenizer": "fishaudio/speech-lm-v1", + "compile": true + }, + "vqgan": { + "config_name": "vqgan_pretrain", + "checkpoint_path": "checkpoints/vqgan-v1.pth" + } + }, + "tts_config": { + "prompt_text": "", + "prompt_tokens": "", + "max_new_tokens": 0, + "top_k": 3, + "top_p": 0.5, + "repetition_penalty": 1.5, + "temperature": 0.7, + "order": "zh,jp,en", + "use_g2p": true, + "seed": 1, + "speaker": "" + } + }, "choose_song": { "enable": true, "similarity": 0.5, @@ -1324,7 +1355,8 @@ "gradio_tts": true, "gpt_sovits": true, "clone_voice": true, - "azure_tts": true + "azure_tts": true, + "fish_speech": true }, "svc": { "ddsp_svc": true, diff --git a/config.json.bak b/config.json.bak index 9eb7e1cc..9c767fc6 100644 --- a/config.json.bak +++ b/config.json.bak @@ -569,6 +569,37 @@ "region": "japanwest", "voice_name": "zh-CN-XiaoyanNeural" }, + "fish_speech": { + "api_ip_port": "http://127.0.0.1:8000", + "model_name": "default", + "model_config": { + "device": "cuda", + "llama": { + "config_name": "text2semantic_finetune", + "checkpoint_path": "checkpoints/text2semantic-400m-v0.2-4k.pth", + "precision": "bfloat16", + "tokenizer": "fishaudio/speech-lm-v1", + "compile": true + }, + "vqgan": { + "config_name": "vqgan_pretrain", + "checkpoint_path": "checkpoints/vqgan-v1.pth" + } + }, + "tts_config": { + "prompt_text": "", + "prompt_tokens": "", + "max_new_tokens": 0, + "top_k": 3, + "top_p": 0.5, + "repetition_penalty": 1.5, + "temperature": 0.7, + "order": "zh,jp,en", + "use_g2p": true, + "seed": 1, + "speaker": "" + } + }, "choose_song": { "enable": true, "similarity": 0.5, @@ -1324,7 +1355,8 @@ "gradio_tts": true, "gpt_sovits": true, "clone_voice": true, - "azure_tts": true + "azure_tts": true, + "fish_speech": true }, "svc": { "ddsp_svc": true, diff --git a/tests/test_fish_speech/1.wav b/tests/test_fish_speech/1.wav new file mode 100644 index 00000000..ee4eeb0f Binary files /dev/null and b/tests/test_fish_speech/1.wav differ diff --git a/tests/test_fish_speech/api.py b/tests/test_fish_speech/api.py new file mode 100644 index 00000000..b166ca9c --- /dev/null +++ b/tests/test_fish_speech/api.py @@ -0,0 +1,106 @@ +import json, logging +import aiohttp, asyncio +from urllib.parse import urljoin + +async def fish_speech_load_model(data): + API_URL = urljoin(data["api_ip_port"], f'/v1/models/{data["model_name"]}') + + try: + async with aiohttp.ClientSession() as session: + async with session.put(API_URL, json=data["model_config"]) as response: + if response.status == 200: + ret = await response.json() + print(ret) + + if ret["name"] == data["model_name"]: + print(f'fish_speech模型加载成功: {ret["name"]}') + return ret + else: + return None + + except aiohttp.ClientError as e: + print(f'fish_speech请求失败: {e}') + except Exception as e: + print(f'fish_speech未知错误: {e}') + + return None + +async def fish_speech_api(data): + API_URL = urljoin(data["api_ip_port"], f'/v1/models/{data["model_name"]}/invoke') + + print(f"data={data}") + + def replace_empty_strings_with_none(input_dict): + for key, value in input_dict.items(): + if value == "": + input_dict[key] = None + return input_dict + + data["tts_config"] = replace_empty_strings_with_none(data["tts_config"]) + + print(f"data={data}") + + try: + async with aiohttp.ClientSession() as session: + async with session.post(API_URL, json=data["tts_config"]) as response: + if response.status == 200: + content = await response.read() + + # voice_tmp_path = os.path.join(self.audio_out_path, 'reecho_ai_' + self.common.get_bj_time(4) + '.wav') + # file_name = 'fish_speech_' + self.common.get_bj_time(4) + '.wav' + + # voice_tmp_path = self.common.get_new_audio_path(self.audio_out_path, file_name) + voice_tmp_path = "1.wav" + with open(voice_tmp_path, 'wb') as file: + file.write(content) + + return voice_tmp_path + else: + print(f'fish_speech下载音频失败: {response.status}') + return None + except aiohttp.ClientError as e: + print(f'fish_speech请求失败: {e}') + except Exception as e: + print(f'fish_speech未知错误: {e}') + + return None + + +data = { + "fish_speech": { + "api_ip_port": "http://127.0.0.1:8000", + "model_name": "default", + "model_config": { + "device": "cuda", + "llama": { + "config_name": "text2semantic_finetune", + "checkpoint_path": "checkpoints/text2semantic-400m-v0.2-4k.pth", + "precision": "bfloat16", + "tokenizer": "fishaudio/speech-lm-v1", + "compile": True + }, + "vqgan": { + "config_name": "vqgan_pretrain", + "checkpoint_path": "checkpoints/vqgan-v1.pth" + } + }, + "tts_config": { + "prompt_text": "", + "prompt_tokens": "", + "max_new_tokens": 0, + "top_k": 3, + "top_p": 0.5, + "repetition_penalty": 1.5, + "temperature": 0.7, + "order": "zh,jp,en", + "use_g2p": True, + "seed": 1, + "speaker": "" + } + } +} + +asyncio.run(fish_speech_load_model(data["fish_speech"])) + +data["fish_speech"]["tts_config"]["text"] = "你好" +asyncio.run(fish_speech_api(data["fish_speech"])) \ No newline at end of file diff --git a/utils/audio.py b/utils/audio.py index 26089565..c7fc47f2 100644 --- a/utils/audio.py +++ b/utils/audio.py @@ -863,6 +863,11 @@ async def voice_change_and_put_to_queue(message, voice_tmp_path): } voice_tmp_path = self.my_tts.azure_tts_api(data) + elif message["tts_type"] == "fish_speech": + data = message["data"] + data["tts_config"]["text"] = message["content"] + + voice_tmp_path = await self.my_tts.fish_speech_api(data) elif message["tts_type"] == "none": pass except Exception as e: @@ -1596,6 +1601,14 @@ async def audio_synthesis_use_local_config(self, content, audio_synthesis_type=" logging.debug(f"data={data}") voice_tmp_path = self.my_tts.azure_tts_api(data) + elif audio_synthesis_type == "fish_speech": + data = self.config.get("fish_speech") + data["tts_config"]["text"] = content + + logging.debug(f"data={data}") + + voice_tmp_path = await self.my_tts.fish_speech_api(data) + return voice_tmp_path diff --git a/utils/audio_handle/my_tts.py b/utils/audio_handle/my_tts.py index 0ba0f14a..5c88ba9c 100644 --- a/utils/audio_handle/my_tts.py +++ b/utils/audio_handle/my_tts.py @@ -746,4 +746,67 @@ def azure_tts_api(self, data): logging.error(traceback.format_exc()) logging.error(f'azure_tts未知错误: {e}') - return None \ No newline at end of file + return None + + + async def fish_speech_load_model(self, data): + API_URL = urljoin(data["api_ip_port"], f'/v1/models/{data["model_name"]}') + + try: + async with aiohttp.ClientSession() as session: + async with session.put(API_URL, json=data["model_config"]) as response: + if response.status == 200: + ret = await response.json() + logging.debug(ret) + + if ret["name"] == data["model_name"]: + logging.info(f'fish_speech模型加载成功: {ret["name"]}') + return ret + else: + return None + + except aiohttp.ClientError as e: + logging.error(f'fish_speech请求失败: {e}') + except Exception as e: + logging.error(f'fish_speech未知错误: {e}') + + return None + + async def fish_speech_api(self, data): + API_URL = urljoin(data["api_ip_port"], f'/v1/models/{data["model_name"]}/invoke') + + def replace_empty_strings_with_none(input_dict): + for key, value in input_dict.items(): + if value == "": + input_dict[key] = None + return input_dict + + data["tts_config"] = replace_empty_strings_with_none(data["tts_config"]) + + logging.debug(f"data={data}") + + try: + async with aiohttp.ClientSession() as session: + async with session.post(API_URL, json=data["tts_config"]) as response: + if response.status == 200: + content = await response.read() + + voice_tmp_path = os.path.join(self.audio_out_path, 'fish_speech_' + self.common.get_bj_time(4) + '.wav') + file_name = 'fish_speech_' + self.common.get_bj_time(4) + '.wav' + + voice_tmp_path = self.common.get_new_audio_path(self.audio_out_path, file_name) + + with open(voice_tmp_path, 'wb') as file: + file.write(content) + + return voice_tmp_path + else: + logging.error(f'fish_speech下载音频失败: {response.status}') + return None + except aiohttp.ClientError as e: + logging.error(f'fish_speech请求失败: {e}') + except Exception as e: + logging.error(f'fish_speech未知错误: {e}') + + return None + diff --git a/webui.py b/webui.py index 7b3dd81b..2d47a20f 100644 --- a/webui.py +++ b/webui.py @@ -1376,6 +1376,29 @@ def common_textarea_handle(content): config_data["azure_tts"]["region"] = input_azure_tts_region.value config_data["azure_tts"]["voice_name"] = input_azure_tts_voice_name.value + if config.get("webui", "show_card", "tts", "fish_speech"): + config_data["fish_speech"]["api_ip_port"] = input_fish_speech_api_ip_port.value + config_data["fish_speech"]["model_name"] = input_fish_speech_model_name.value + config_data["fish_speech"]["model_config"]["device"] = input_fish_speech_model_config_device.value + config_data["fish_speech"]["model_config"]["llama"]["config_name"] = input_fish_speech_model_config_llama_config_name.value + config_data["fish_speech"]["model_config"]["llama"]["checkpoint_path"] = input_fish_speech_model_config_llama_checkpoint_path.value + config_data["fish_speech"]["model_config"]["llama"]["precision"] = input_fish_speech_model_config_llama_precision.value + config_data["fish_speech"]["model_config"]["llama"]["tokenizer"] = input_fish_speech_model_config_llama_tokenizer.value + config_data["fish_speech"]["model_config"]["llama"]["compile"] = switch_fish_speech_model_config_llama_compile.value + config_data["fish_speech"]["model_config"]["vqgan"]["config_name"] = input_fish_speech_model_config_vqgan_config_name.value + config_data["fish_speech"]["model_config"]["vqgan"]["checkpoint_path"] = input_fish_speech_model_config_vqgan_checkpoint_path.value + config_data["fish_speech"]["tts_config"]["prompt_text"] = input_fish_speech_tts_config_prompt_text.value + config_data["fish_speech"]["tts_config"]["prompt_tokens"] = input_fish_speech_tts_config_prompt_tokens.value + config_data["fish_speech"]["tts_config"]["max_new_tokens"] = int(input_fish_speech_tts_config_max_new_tokens.value) + config_data["fish_speech"]["tts_config"]["top_k"] = int(input_fish_speech_tts_config_top_k.value) + config_data["fish_speech"]["tts_config"]["top_p"] = round(float(input_fish_speech_tts_config_top_p.value), 2) + config_data["fish_speech"]["tts_config"]["repetition_penalty"] = round(float(input_fish_speech_tts_config_repetition_penalty.value), 2) + config_data["fish_speech"]["tts_config"]["temperature"] = round(float(input_fish_speech_tts_config_temperature.value), 2) + config_data["fish_speech"]["tts_config"]["order"] = input_fish_speech_tts_config_order.value + config_data["fish_speech"]["tts_config"]["seed"] = int(input_fish_speech_tts_config_seed.value) + config_data["fish_speech"]["tts_config"]["speaker"] = input_fish_speech_tts_config_speaker.value + config_data["fish_speech"]["tts_config"]["use_g2p"] = switch_fish_speech_tts_config_use_g2p.value + """ SVC """ @@ -1657,6 +1680,7 @@ def common_textarea_handle(content): config_data["webui"]["show_card"]["tts"]["gpt_sovits"] = switch_webui_show_card_tts_gpt_sovits.value config_data["webui"]["show_card"]["tts"]["clone_voice"] = switch_webui_show_card_tts_clone_voice.value config_data["webui"]["show_card"]["tts"]["azure_tts"] = switch_webui_show_card_tts_azure_tts.value + config_data["webui"]["show_card"]["tts"]["fish_speech"] = switch_webui_show_card_tts_fish_speech.value config_data["webui"]["show_card"]["svc"]["ddsp_svc"] = switch_webui_show_card_svc_ddsp_svc.value config_data["webui"]["show_card"]["svc"]["so_vits_svc"] = switch_webui_show_card_svc_so_vits_svc.value @@ -1745,7 +1769,8 @@ def common_textarea_handle(content): 'gradio_tts': 'Gradio', 'gpt_sovits': 'GPT_SoVITS', 'clone_voice': 'clone-voice', - 'azure_tts': 'azure_tts' + 'azure_tts': 'azure_tts', + 'fish_speech': 'fish_speech' } # 聊天类型所有配置项 @@ -3191,7 +3216,43 @@ def clear_tts_common_audio_card(file_path): input_azure_tts_subscription_key = ui.input(label='密钥', value=config.get("azure_tts", "subscription_key"), placeholder='申请开通服务后,自然就看见了').style("width:200px;") input_azure_tts_region = ui.input(label='区域', value=config.get("azure_tts", "region"), placeholder='申请开通服务后,自然就看见了').style("width:200px;") input_azure_tts_voice_name = ui.input(label='说话人名', value=config.get("azure_tts", "voice_name"), placeholder='Speech Studio平台试听获取说话人名').style("width:200px;") - + + if config.get("webui", "show_card", "tts", "fish_speech"): + with ui.card().style(card_css): + ui.label("fish_speech") + with ui.row(): + input_fish_speech_api_ip_port = ui.input(label='API地址', value=config.get("fish_speech", "api_ip_port"), placeholder='程序启动后监听的地址').style("width:200px;") + input_fish_speech_model_name = ui.input(label='模型名', value=config.get("fish_speech", "model_name"), placeholder='需要加载的模型名').style("width:200px;") + + with ui.card().style(card_css): + ui.label("模型配置") + with ui.row(): + input_fish_speech_model_config_device = ui.input(label='device', value=config.get("fish_speech", "model_config", "device"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_model_config_llama_config_name = ui.input(label='config_name', value=config.get("fish_speech", "model_config", "llama", "config_name"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_model_config_llama_checkpoint_path = ui.input(label='checkpoint_path', value=config.get("fish_speech", "model_config", "llama", "checkpoint_path"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_model_config_llama_precision = ui.input(label='precision', value=config.get("fish_speech", "model_config", "llama", "precision"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_model_config_llama_tokenizer = ui.input(label='tokenizer', value=config.get("fish_speech", "model_config", "llama", "tokenizer"), placeholder='自行查阅').style("width:200px;") + switch_fish_speech_model_config_llama_compile = ui.switch('compile', value=config.get("fish_speech", "model_config", "llama", "compile")).style(switch_internal_css) + + input_fish_speech_model_config_vqgan_config_name = ui.input(label='config_name', value=config.get("fish_speech", "model_config", "vqgan", "config_name"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_model_config_vqgan_checkpoint_path = ui.input(label='checkpoint_path', value=config.get("fish_speech", "model_config", "vqgan", "checkpoint_path"), placeholder='自行查阅').style("width:200px;") + + with ui.card().style(card_css): + ui.label("TTS配置") + with ui.row(): + input_fish_speech_tts_config_prompt_text = ui.input(label='prompt_text', value=config.get("fish_speech", "tts_config", "prompt_text"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_tts_config_prompt_tokens = ui.input(label='prompt_tokens', value=config.get("fish_speech", "tts_config", "prompt_tokens"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_tts_config_max_new_tokens = ui.input(label='max_new_tokens', value=config.get("fish_speech", "tts_config", "max_new_tokens"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_tts_config_top_k = ui.input(label='top_k', value=config.get("fish_speech", "tts_config", "top_k"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_tts_config_top_p = ui.input(label='top_p', value=config.get("fish_speech", "tts_config", "top_p"), placeholder='自行查阅').style("width:200px;") + with ui.row(): + input_fish_speech_tts_config_repetition_penalty = ui.input(label='repetition_penalty', value=config.get("fish_speech", "tts_config", "repetition_penalty"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_tts_config_temperature = ui.input(label='temperature', value=config.get("fish_speech", "tts_config", "temperature"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_tts_config_order = ui.input(label='order', value=config.get("fish_speech", "tts_config", "order"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_tts_config_seed = ui.input(label='seed', value=config.get("fish_speech", "tts_config", "seed"), placeholder='自行查阅').style("width:200px;") + input_fish_speech_tts_config_speaker = ui.input(label='speaker', value=config.get("fish_speech", "tts_config", "speaker"), placeholder='自行查阅').style("width:200px;") + switch_fish_speech_tts_config_use_g2p = ui.switch('use_g2p', value=config.get("fish_speech", "tts_config", "use_g2p")).style(switch_internal_css) + with ui.tab_panel(svc_page).style(tab_panel_css): if config.get("webui", "show_card", "svc", "ddsp_svc"): with ui.card().style(card_css): @@ -3791,7 +3852,7 @@ def update_echart_gift(): switch_webui_show_card_tts_gpt_sovits = ui.switch('gpt_sovits', value=config.get("webui", "show_card", "tts", "gpt_sovits")).style(switch_internal_css) switch_webui_show_card_tts_clone_voice = ui.switch('clone_voice', value=config.get("webui", "show_card", "tts", "clone_voice")).style(switch_internal_css) switch_webui_show_card_tts_azure_tts = ui.switch('azure_tts', value=config.get("webui", "show_card", "tts", "azure_tts")).style(switch_internal_css) - + switch_webui_show_card_tts_fish_speech = ui.switch('fish_speech', value=config.get("webui", "show_card", "tts", "fish_speech")).style(switch_internal_css) with ui.card().style(card_css): ui.label("变声") with ui.row():