Skip to content

Commit

Permalink
优化:默认不启用回复模板
Browse files Browse the repository at this point in the history
  • Loading branch information
Ikaros-521 committed Sep 30, 2024
2 parents ea21bdc + c02512d commit 0fa3bfb
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 21 deletions.
11 changes: 10 additions & 1 deletion config.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
"enable": false,
"copywriting": "{username}说:{comment}"
},
"reply_template": {
"enable": false,
"username_max_len": 10,
"copywriting": [
"回复{username}{data}",
"{username}{data}"
],
"username_max_le": 10
},
"comment_log_type": "回答",
"visual_body": "其他",
"xuniren": {
Expand Down Expand Up @@ -570,7 +579,7 @@
},
"local_qa": {
"periodic_trigger": {
"enable": true,
"enable": false,
"periodic_time_min": 10,
"periodic_time_max": 30,
"trigger_num_min": 0,
Expand Down
11 changes: 10 additions & 1 deletion config.json.bak
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
"enable": false,
"copywriting": "{username}说:{comment}"
},
"reply_template": {
"enable": false,
"username_max_len": 10,
"copywriting": [
"回复{username}{data}",
"{username}{data}"
],
"username_max_le": 10
},
"comment_log_type": "回答",
"visual_body": "其他",
"xuniren": {
Expand Down Expand Up @@ -570,7 +579,7 @@
},
"local_qa": {
"periodic_trigger": {
"enable": true,
"enable": false,
"periodic_time_min": 10,
"periodic_time_max": 30,
"trigger_num_min": 0,
Expand Down
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def send(msg: SendMessage):

try:
tmp_json = msg.dict()
logger.info(f"API收到数据{tmp_json}")
logger.info(f"内部HTTP API send接口收到数据{tmp_json}")
data_json = tmp_json["data"]
if "type" not in data_json:
data_json["type"] = tmp_json["type"]
Expand Down Expand Up @@ -228,7 +228,7 @@ async def callback(msg: CallbackMessage):

try:
data_json = msg.dict()
logger.info(f"API收到数据{data_json}")
logger.info(f"内部HTTP API callback接口收到数据{data_json}")

# 音频播放完成
if data_json["type"] in ["audio_playback_completed"]:
Expand Down
2 changes: 1 addition & 1 deletion utils/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ async def voice_change_and_put_to_queue(message, voice_tmp_path):

return False

logger.info(f"{message['tts_type']}合成成功,合成内容:【{message['content']}】,输出到={voice_tmp_path}")
logger.info(f"[{message['tts_type']}]合成成功,合成内容:【{message['content']}】,音频存储在 {voice_tmp_path}")

await voice_change_and_put_to_queue(message, voice_tmp_path)

Expand Down
17 changes: 17 additions & 0 deletions utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,23 @@ def get_random_str_in_list_and_format(self, ori_content: str = None, ori_list: l

return {"ret": 0, "content": content}

def get_list_random_or_default(self, strings: list, default_value):
"""
从列表中随机选择一个字符串,如果列表为空,则返回默认值。
参数:
strings (list of str): 字符串列表。
default_value (str): 默认值。
返回:
str: 随机选择的字符串或默认值。
"""
if not strings: # 如果列表是空的
return default_value
else:
return random.choice(strings)

"""
.@@@ @@@ @@^ =@@@@@@@@ /@@ /@@ =@@@@@*,@@\]]]] ,@@@@@@@@@@@@* .@@@ @@/.\]`@@@ =@@\]]]]]]] =@@..@@@@@@@@@ =@@\ /@@^
Expand Down
77 changes: 65 additions & 12 deletions utils/my_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ def audio_synthesis_handle(self, data_json):
reread_top_priority 最高优先级-复读
talk 聊天(语音输入)
comment 弹幕
local_qa_text 本地问答文本
local_qa_audio 本地问答音频
song 歌曲
reread 复读
Expand Down Expand Up @@ -699,16 +700,16 @@ def audio_synthesis_handle(self, data_json):
if data_json["type"] not in My_handle.config.get("assistant_anchor", "type"):
return

# 1、匹配本地问答库 触发后不执行后面的其他功能
if My_handle.config.get("assistant_anchor", "local_qa", "text", "enable") == True:
# 1、匹配助播本地问答库 触发后不执行后面的其他功能
if My_handle.config.get("assistant_anchor", "local_qa", "text", "enable"):
# 根据类型,执行不同的问答匹配算法
if My_handle.config.get("assistant_anchor", "local_qa", "text", "format") == "text":
tmp = self.find_answer(data_json["content"], My_handle.config.get("assistant_anchor", "local_qa", "text", "file_path"), My_handle.config.get("assistant_anchor", "local_qa", "text", "similarity"))
else:
tmp = self.find_similar_answer(data_json["content"], My_handle.config.get("assistant_anchor", "local_qa", "text", "file_path"), My_handle.config.get("assistant_anchor", "local_qa", "text", "similarity"))

if tmp != None:
logger.info(f'触发本地问答库-文本 [{My_handle.config.get("assistant_anchor", "username")}]: {data_json["content"]}')
if tmp is not None:
logger.info(f'触发助播 本地问答库-文本 [{My_handle.config.get("assistant_anchor", "username")}]: {data_json["content"]}')
# 将问答库中设定的参数替换为指定内容,开发者可以自定义替换内容
# 假设有多个未知变量,用户可以在此处定义动态变量
variables = {
Expand Down Expand Up @@ -747,9 +748,9 @@ def audio_synthesis_handle(self, data_json):
return True

# 如果开启了助播功能,则根据当前播放内容的文本信息,进行助播音频的播放
if My_handle.config.get("assistant_anchor", "enable") == True:
if My_handle.config.get("assistant_anchor", "enable"):
# 2、匹配本地问答音频库 触发后不执行后面的其他功能
if My_handle.config.get("assistant_anchor", "local_qa", "audio", "enable") == True:
if My_handle.config.get("assistant_anchor", "local_qa", "audio", "enable"):
# 输出当前用户发送的弹幕消息
# logger.info(f"[{username}]: {content}")
# 获取本地问答音频库文件夹内所有的音频文件名
Expand Down Expand Up @@ -907,14 +908,14 @@ def local_qa_handle(self, data):
username = username[:self.config.get("local_qa", "text", "username_max_len")]

# 1、匹配本地问答库 触发后不执行后面的其他功能
if My_handle.config.get("local_qa", "text", "enable") == True:
if My_handle.config.get("local_qa", "text", "enable"):
# 根据类型,执行不同的问答匹配算法
if My_handle.config.get("local_qa", "text", "type") == "text":
tmp = self.find_answer(content, My_handle.config.get("local_qa", "text", "file_path"), My_handle.config.get("local_qa", "text", "similarity"))
else:
tmp = self.find_similar_answer(content, My_handle.config.get("local_qa", "text", "file_path"), My_handle.config.get("local_qa", "text", "similarity"))

if tmp != None:
if tmp is not None:
logger.info(f"触发本地问答库-文本 [{username}]: {content}")
# 将问答库中设定的参数替换为指定内容,开发者可以自定义替换内容
# 假设有多个未知变量,用户可以在此处定义动态变量
Expand All @@ -929,14 +930,32 @@ def local_qa_handle(self, data):

# [1|2]括号语法随机获取一个值,返回取值完成后的字符串
tmp = My_handle.common.brackets_text_randomize(tmp)

logger.info(f"本地问答库-文本回答为: {tmp}")

"""
# 判断 回复模板 是否启用
if My_handle.config.get("reply_template", "enable"):
# 根据模板变量关系进行回复内容的替换
# 假设有多个未知变量,用户可以在此处定义动态变量
variables = {
'username': data["username"][:self.config.get("reply_template", "username_max_len")],
'data': tmp,
'cur_time': My_handle.common.get_bj_time(5),
}
reply_template_copywriting = My_handle.common.get_list_random_or_default(self.config.get("reply_template", "copywriting"), "{data}")
# 使用字典进行字符串替换
if any(var in reply_template_copywriting for var in variables):
tmp = reply_template_copywriting.format(**{var: value for var, value in variables.items() if var in reply_template_copywriting})
logger.debug(f"回复模板转换后: {tmp}")
"""

resp_content = tmp
# 将 AI 回复记录到日志文件中
self.write_to_comment_log(resp_content, {"username": username, "content": content})


message = {
"type": "comment",
"tts_type": My_handle.config.get("audio_synthesis_type"),
Expand All @@ -957,7 +976,7 @@ def local_qa_handle(self, data):
return True

# 2、匹配本地问答音频库 触发后不执行后面的其他功能
if My_handle.config.get("local_qa")["audio"]["enable"] == True:
if My_handle.config.get("local_qa")["audio"]["enable"]:
# 输出当前用户发送的弹幕消息
# logger.info(f"[{username}]: {content}")
# 获取本地问答音频库文件夹内所有的音频文件名
Expand Down Expand Up @@ -1546,6 +1565,22 @@ def llm_handle(self, chat_type, data, type="chat", webui_show=True):
# 替换 \n换行符 \n字符串为空
resp_content = re.sub(r'\\n|\n', '', resp_content)

# 判断 回复模板 是否启用
if My_handle.config.get("reply_template", "enable"):
# 根据模板变量关系进行回复内容的替换
# 假设有多个未知变量,用户可以在此处定义动态变量
variables = {
'username': data["username"][:self.config.get("reply_template", "username_max_len")],
'data': resp_content,
'cur_time': My_handle.common.get_bj_time(5),
}

reply_template_copywriting = My_handle.common.get_list_random_or_default(self.config.get("reply_template", "copywriting"), "{data}")
# 使用字典进行字符串替换
if any(var in reply_template_copywriting for var in variables):
resp_content = reply_template_copywriting.format(**{var: value for var, value in variables.items() if var in reply_template_copywriting})


logger.debug(f"resp_content={resp_content}")

# 返回为空,触发异常报警
Expand All @@ -1554,7 +1589,7 @@ def llm_handle(self, chat_type, data, type="chat", webui_show=True):
logger.warning("LLM没有正确返回数据,请排查配置、网络等是否正常。如果排查后都没有问题,可能是接口改动导致的兼容性问题,可以前往官方仓库提交issue,传送门:https://github.com/Ikaros-521/AI-Vtuber/issues")

# 是否启用webui回显
if webui_show:
if webui_show and resp_content:
self.webui_show_chat_log_callback(chat_type, data, resp_content)

return resp_content
Expand Down Expand Up @@ -1634,7 +1669,25 @@ def split_by_chinese_punctuation(s):
return {"ret": False, "content1": s, "content2": ""}

if resp is not None:
# 流式开始拼接文本内容时,初始的临时文本存储变量
tmp = ""

# 判断 回复模板 是否启用
if My_handle.config.get("reply_template", "enable"):
# 根据模板变量关系进行回复内容的替换
# 假设有多个未知变量,用户可以在此处定义动态变量
variables = {
'username': data["username"][:self.config.get("reply_template", "username_max_len")],
'data': "",
'cur_time': My_handle.common.get_bj_time(5),
}

reply_template_copywriting = My_handle.common.get_list_random_or_default(self.config.get("reply_template", "copywriting"), "")
# 使用字典进行字符串替换
if any(var in reply_template_copywriting for var in variables):
tmp = reply_template_copywriting.format(**{var: value for var, value in variables.items() if var in reply_template_copywriting})


# 已经切掉的字符长度,针对一些特殊llm的流式输出,需要去掉前面的字符
cut_len = 0
for chunk in resp:
Expand Down
18 changes: 14 additions & 4 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ async def send(msg: SendMessage):

try:
data_json = msg.dict()
logger.info(f'send接口 收到数据:{data_json}')
logger.info(f'WEBUI API send接口收到数据:{data_json}')

main_api_ip = "127.0.0.1" if config.get("api_ip") == "0.0.0.0" else config.get("api_ip")
resp_json = await common.send_async_request(f'http://{main_api_ip}:{config.get("api_port")}/send', "POST", data_json)
Expand Down Expand Up @@ -815,7 +815,7 @@ async def send(msg: SendMessage):
async def callback(request: Request):
try:
data_json = await request.json()
logger.info(f'callback接口 收到数据:{data_json}')
logger.info(f'WEBUI API callback接口收到数据:{data_json}')

data_handle_show_chat_log(data_json)

Expand Down Expand Up @@ -898,7 +898,7 @@ async def callback(request: Request):
async def tts(request: Request):
try:
data_json = await request.json()
logger.info(f'tts接口 收到数据:{data_json}')
logger.info(f'WEBUI API tts接口收到数据:{data_json}')

resp_json = await audio.tts_handle(data_json)

Expand Down Expand Up @@ -1604,6 +1604,9 @@ def common_textarea_handle(content):
config_data["after_prompt"] = input_after_prompt.value
config_data["comment_template"]["enable"] = switch_comment_template_enable.value
config_data["comment_template"]["copywriting"] = input_comment_template_copywriting.value
config_data["reply_template"]["enable"] = switch_reply_template_enable.value
config_data["reply_template"]["username_max_le"] = int(input_reply_template_username_max_len.value)
config_data["reply_template"]["copywriting"] = common_textarea_handle(textarea_reply_template_copywriting.value)
config_data["audio_synthesis_type"] = select_audio_synthesis_type.value

# 哔哩哔哩
Expand Down Expand Up @@ -3212,7 +3215,14 @@ def save_config():
input_after_prompt = ui.input(label='提示词后缀', placeholder='此配置会追加在弹幕后,再发送给LLM处理', value=config.get("after_prompt")).style("width:200px;").tooltip('此配置会追加在弹幕后,再发送给LLM处理')
switch_comment_template_enable = ui.switch('启用弹幕模板', value=config.get("comment_template", "enable")).style(switch_internal_css).tooltip('此配置会追加在弹幕后,再发送给LLM处理')
input_comment_template_copywriting = ui.input(label='弹幕模板', value=config.get("comment_template", "copywriting"), placeholder='此配置会对弹幕内容进行修改,{}内为变量,会被替换为指定内容,请勿随意删除变量').style("width:200px;").tooltip('此配置会对弹幕内容进行修改,{}内为变量,会被替换为指定内容,请勿随意删除变量')

switch_reply_template_enable = ui.switch('启用回复模板', value=config.get("reply_template", "enable")).style(switch_internal_css).tooltip('此配置会在LLM输出的答案中进行回复内容的重新构建')
input_reply_template_username_max_len = ui.input(label='回复用户名的最大长度', value=config.get("reply_template", "username_max_len"), placeholder='回复用户名的最大长度').style("width:200px;").tooltip('回复用户名的最大长度')
textarea_reply_template_copywriting = ui.textarea(
label='回复模板',
placeholder='此配置会对LLM回复内容进行修改,{}内为变量,会被替换为指定内容,请勿随意删除变量',
value=textarea_data_change(config.get("reply_template", "copywriting"))
).style("width:500px;").tooltip('此配置会对LLM回复内容进行修改,{}内为变量,会被替换为指定内容,请勿随意删除变量')

with ui.card().style(card_css):
ui.label('平台相关')
with ui.card().style(card_css):
Expand Down

0 comments on commit 0fa3bfb

Please sign in to comment.