Merge pull request #682 from Ikaros-521/owner
Change qwen-alice to qwen, adapting it to the openai_api interface of Qwen's official library
Ikaros-521 committed Mar 3, 2024
2 parents c81093d + 196531b commit d84870e
Showing 7 changed files with 51 additions and 96 deletions.
7 changes: 3 additions & 4 deletions config.json
@@ -152,8 +152,7 @@
"api": "https://api.openai.com/v1",
"api_key": [
"替换为你的api-key"
- ],
- "model": "gpt-3.5-turbo"
+ ]
},
"chatgpt": {
"model": "gpt-3.5-turbo-0613",
@@ -186,7 +185,7 @@
"history_enable": true,
"history_max_len": 500
},
"alice": {
"qwen": {
"api_ip_port": "http://localhost:8000/v1/chat/completions",
"max_length": 4096,
"top_p": 0.5,
@@ -1290,7 +1289,7 @@
"claude": true,
"claude2": true,
"chatglm": true,
"alice": true,
"qwen": true,
"zhipu": true,
"chat_with_file": true,
"langchain_chatglm": true,
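For reference, a minimal sketch of consuming the renamed block; the key names come from the diff above, and the path assumes the repo's default config.json:

```python
import json

# Read the project config and pull the section renamed from "alice" to "qwen".
# Key names are taken from the diff above; the path assumes the repo root.
with open("config.json", "r", encoding="utf-8") as f:
    config_data = json.load(f)

qwen_cfg = config_data["qwen"]
print(qwen_cfg["api_ip_port"])  # http://localhost:8000/v1/chat/completions
print(qwen_cfg["max_length"], qwen_cfg["top_p"])
```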
7 changes: 3 additions & 4 deletions config.json.bak
@@ -152,8 +152,7 @@
"api": "https://api.openai.com/v1",
"api_key": [
"替换为你的api-key"
- ],
- "model": "gpt-3.5-turbo"
+ ]
},
"chatgpt": {
"model": "gpt-3.5-turbo-0613",
@@ -186,7 +185,7 @@
"history_enable": true,
"history_max_len": 500
},
"alice": {
"qwen": {
"api_ip_port": "http://localhost:8000/v1/chat/completions",
"max_length": 4096,
"top_p": 0.5,
@@ -1290,7 +1289,7 @@
"claude": true,
"claude2": true,
"chatglm": true,
"alice": true,
"qwen": true,
"zhipu": true,
"chat_with_file": true,
"langchain_chatglm": true,
3 changes: 3 additions & 0 deletions utils/gpt_model/chatgpt.py
@@ -49,6 +49,9 @@ def chat(self, msg, sessionid):
# Call the ChatGPT API to generate the reply message
message = self.chat_with_gpt(session['msg'])

+ if message is None:
+     return None

# If the returned message mentions the model's maximum context length limit, delete the oldest context and retry
if message.__contains__("This model's maximum context length is 409"):
del session['msg'][0:3]
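The three added lines guard the very next statement, message.__contains__(...), which would raise AttributeError whenever chat_with_gpt() fails and returns None. A self-contained sketch of the same guard; handle_reply is an illustrative name, not the repo's:

```python
def handle_reply(message):
    # Mirrors the guard added above: a failed API call is assumed to
    # surface as None, so bail out before calling string methods on it.
    if message is None:
        return None
    # Without the guard, this membership test would crash on None.
    if "This model's maximum context length is" in message:
        return "context-overflow"
    return message

print(handle_reply(None))     # -> None instead of an AttributeError
print(handle_reply("hello"))  # -> hello
```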
4 changes: 2 additions & 2 deletions utils/gpt_model/gpt.py
@@ -10,7 +10,7 @@
import logging

from utils.gpt_model.chatglm import Chatglm
- from utils.gpt_model.alice import Qwen_alice
+ from utils.gpt_model.qwen import Qwen
from utils.gpt_model.chatgpt import Chatgpt
from utils.gpt_model.claude import Claude
from utils.gpt_model.claude2 import Claude2
@@ -37,7 +37,7 @@ def set_model_config(self, model_name, config):
"claude": Claude,
"claude2": Claude2,
"chatglm": Chatglm,
"alice": Qwen_alice,
"qwen": Qwen,
"text_generation_webui": TEXT_GENERATION_WEBUI,
"sparkdesk": SPARKDESK,
"langchain_chatglm": Langchain_ChatGLM,
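The edited dict is a name-to-class registry: set_model_config() resolves the configured backend string to an adapter class and instantiates it. A self-contained sketch of that pattern, with a stub standing in for utils.gpt_model.qwen.Qwen:

```python
# Stub standing in for utils.gpt_model.qwen.Qwen so the sketch runs alone.
class Qwen:
    def __init__(self, config):
        self.api_ip_port = config["api_ip_port"]

# The key was renamed from "alice" (class Qwen_alice) to "qwen" (class Qwen).
model_map = {
    "qwen": Qwen,
}

def set_model_config(model_name, config):
    # Dict dispatch: an unknown model_name raises KeyError.
    return model_map[model_name](config)

adapter = set_model_config("qwen", {"api_ip_port": "http://localhost:8000/v1/chat/completions"})
print(adapter.api_ip_port)
```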
64 changes: 9 additions & 55 deletions utils/gpt_model/alice.py → utils/gpt_model/qwen.py
@@ -7,35 +7,6 @@
from utils.common import Common
from utils.logger import Configure_logger

- tools = [
-     {
-         'name_for_human': '光之剑',
-         'name_for_model': 'sword_of_light',
-         'description_for_model': '光之剑是爱丽丝的武器,可以用光之剑发射电磁炮攻击敌人,当遭遇敌人时使用。',
-         'parameters': [
-             {
-                 'name': 'target',
-                 'description': '攻击目标的名字',
-                 'required': True,
-                 'schema': {'type': 'string'},
-             }
-         ],
-     },
-     {
-         'name_for_human': '移动到其他地点',
-         'name_for_model': 'move',
-         'description_for_model': '离开当前场景,去往其他地点。',
-         'parameters': [
-             {
-                 'name': 'to',
-                 'description': '接下来要前往的场景或地点的名称',
-                 'required': False,
-                 'schema': {'type': 'string'},
-             }
-         ],
-     },
- ]


def remove_emotion(message: str) -> str:
"""
@@ -71,7 +42,7 @@ def remove_action(line: str) -> str:
return line


- class Qwen_alice:
+ class Qwen:

def __init__(self, data):
self.common = Common()
@@ -85,40 +56,24 @@ def __init__(self, data):
self.temperature = data["temperature"]
self.history_enable = data["history_enable"]
self.history_max_len = data["history_max_len"]
- self.functions = tools
self.preset = data["preset"]
self.history = []


def construct_query(self, user_name, prompt: str, **kwargs) -> Dict:
"""构造请求体
"""
embedding = ""
for key, value in kwargs.items():
if key == "embedding":
embedding = value
if user_name == "悪魔sama":
user_name = "老师"
else:
user_name = "观众“" + user_name + "”"
messages = self.history + [{"role": "user", "content": f"({user_name}说){prompt}"}]

messages = self.history + [{"role": "user", "content": f"{prompt}"}]
query = {
"functions": self.functions,
"model": "gpt-3.5-turbo",
"messages": messages,
"embeddings": embedding,
"temperature": self.temperature,
"top_p": self.top_p,
"stream": False, # 不启用流式API
}
- # Find where the tip text starts so it is not added to history
- tip_p = prompt.rfind("\n(当前时间:")
- if tip_p >= 0:
-     raw_prompt = prompt[:tip_p]
- else:
-     raw_prompt = prompt
-
- self.history = self.history + [{"role": "user", "content": raw_prompt}]
+ self.history = self.history + [{"role": "user", "content": prompt}]
return query


@@ -146,7 +101,7 @@ def construct_observation(self, prompt: str, **kwargs) -> Dict:
# Call the chatglm API and get the returned content
def get_resp(self, user_name, prompt):
# construct query
- query = self.construct_query(user_name, prompt, embedding=f"{self.preset}\n爱丽丝的状态栏:职业:勇者;经验值:0/100;生命值:1000;攻击力:100;持有的财富:100点信用积分;装备:“光之剑”(电磁炮);持有的道具:['光之剑']。")
+ query = self.construct_query(user_name, prompt)

try:
response = requests.post(url=self.api_ip_port, json=query)
@@ -160,9 +115,8 @@ def get_resp(self, user_name, prompt):
finish_reason = ret['choices'][0]['finish_reason']
if finish_reason != "":
predictions = ret['choices'][0]['message']['content'].strip()
- thought = ret['choices'][0]['thought'].strip()
self.history = self.history + [
{"role": "assistant", "content": f"Thought: {thought}\nFinal Answer: {predictions}"}]
{"role": "assistant", "content": f"{predictions}"}]

# If history is enabled, remember it!
if self.history_enable:
@@ -180,7 +134,7 @@


if __name__ == "__main__":
- llm = Qwen_alice
+ llm = Qwen
llm.__init__(llm,
{"api_ip_port": "http://localhost:8000/v1/chat/completions",
"max_length": 4096,
@@ -189,5 +143,5 @@
"max_new_tokens": 250,
"history_enable": True,
"history_max_len": 20})
resp = llm.get_resp(self=llm, prompt="(老师说)邦邦咔邦")
resp = llm.get_resp(self=llm, prompt="你好")
print(resp)
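After this rewrite the adapter sends a plain OpenAI-style chat payload: the custom functions, embeddings, and thought fields are gone, and the reply is read from choices[0].message.content. A minimal sketch of the round trip, assuming Qwen's official openai_api server is listening on the local endpoint shown in the config:

```python
import requests

# Endpoint, payload fields, and response shape are taken from the diff above;
# this assumes Qwen's official openai_api-compatible server runs locally.
API_URL = "http://localhost:8000/v1/chat/completions"

payload = {
    "model": "gpt-3.5-turbo",  # placeholder name the adapter sends as-is
    "messages": [{"role": "user", "content": "你好"}],
    "temperature": 0.8,        # example value; the adapter reads it from config
    "top_p": 0.5,
    "stream": False,           # streaming stays disabled, as in construct_query()
}

resp = requests.post(API_URL, json=payload, timeout=60)
ret = resp.json()

# get_resp() reads the reply the same way: check finish_reason, then content.
if ret["choices"][0]["finish_reason"] != "":
    print(ret["choices"][0]["message"]["content"].strip())
```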
6 changes: 3 additions & 3 deletions utils/my_handle.py
@@ -116,7 +116,7 @@ def __init__(self, config_path):
self.claude = None
self.claude2 = None
self.chatglm = None
- self.alice = None
+ self.qwen = None
self.chat_with_file = None
self.text_generation_webui = None
self.sparkdesk = None
@@ -133,7 +133,7 @@
self.qanything = None
self.koboldcpp = None

- self.chat_type_list = ["chatgpt", "claude", "claude2", "chatglm", "chat_with_file", "text_generation_webui", \
+ self.chat_type_list = ["chatgpt", "claude", "claude2", "chatglm", "qwen", "chat_with_file", "text_generation_webui", \
"sparkdesk", "langchain_chatglm", "langchain_chatchat", "zhipu", "bard", "yiyan", "tongyi", \
"tongyixingchen", "my_qianfan", "my_wenxinworkshop", "gemini", "qanything", "koboldcpp"]

@@ -1130,7 +1130,7 @@ def llm_handle(self, chat_type, data):
"claude2": lambda: self.claude2.get_resp(data["content"]),
"chatterbot": lambda: self.bot.get_response(data["content"]).text,
"chatglm": lambda: self.chatglm.get_resp(data["content"]),
"alice": lambda: self.alice.get_resp(data["user_name"], data["content"]),
"qwen": lambda: self.qwen.get_resp(data["user_name"], data["content"]),
"chat_with_file": lambda: self.chat_with_file.get_model_resp(data["content"]),
"text_generation_webui": lambda: self.text_generation_webui.get_resp(data["content"]),
"sparkdesk": lambda: self.sparkdesk.get_resp(data["content"]),
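llm_handle() dispatches through a dict of zero-argument lambdas, so only the backend named by chat_type is actually invoked; note the qwen entry forwards user_name as well as content. A self-contained sketch with FakeQwen standing in for the real adapter:

```python
# FakeQwen stands in for the real adapter so the sketch runs standalone.
class FakeQwen:
    def get_resp(self, user_name, content):
        return f"[qwen -> {user_name}] {content}"

qwen = FakeQwen()
data = {"user_name": "viewer", "content": "hello"}

# Each value is a zero-argument lambda, so nothing runs until selection.
handlers = {
    "qwen": lambda: qwen.get_resp(data["user_name"], data["content"]),
}

chat_type = "qwen"
print(handlers[chat_type]())  # only the chosen backend is invoked
```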
56 changes: 28 additions & 28 deletions webui.py
@@ -1056,14 +1056,14 @@ def common_textarea_handle(content):
config_data["chatglm"]["history_enable"] = switch_chatglm_history_enable.value
config_data["chatglm"]["history_max_len"] = int(input_chatglm_history_max_len.value)

if config.get("webui", "show_card", "llm", "alice"):
config_data["alice"]["api_ip_port"] = input_alice_api_ip_port.value
config_data["alice"]["max_length"] = int(input_alice_max_length.value)
config_data["alice"]["top_p"] = round(float(input_alice_top_p.value), 1)
config_data["alice"]["temperature"] = round(float(input_alice_temperature.value), 2)
config_data["alice"]["history_enable"] = switch_alice_history_enable.value
config_data["alice"]["history_max_len"] = int(input_alice_history_max_len.value)
config_data["alice"]["preset"] = input_alice_preset.value
if config.get("webui", "show_card", "llm", "qwen"):
config_data["qwen"]["api_ip_port"] = input_qwen_api_ip_port.value
config_data["qwen"]["max_length"] = int(input_qwen_max_length.value)
config_data["qwen"]["top_p"] = round(float(input_qwen_top_p.value), 1)
config_data["qwen"]["temperature"] = round(float(input_qwen_temperature.value), 2)
config_data["qwen"]["history_enable"] = switch_qwen_history_enable.value
config_data["qwen"]["history_max_len"] = int(input_qwen_history_max_len.value)
config_data["qwen"]["preset"] = input_qwen_preset.value

if config.get("webui", "show_card", "llm", "chat_with_file"):
config_data["chat_with_file"]["chat_mode"] = select_chat_with_file_chat_mode.value
@@ -1622,7 +1622,7 @@ def common_textarea_handle(content):
config_data["webui"]["show_card"]["llm"]["claude"] = switch_webui_show_card_llm_claude.value
config_data["webui"]["show_card"]["llm"]["claude2"] = switch_webui_show_card_llm_claude2.value
config_data["webui"]["show_card"]["llm"]["chatglm"] = switch_webui_show_card_llm_chatglm.value
config_data["webui"]["show_card"]["llm"]["alice"] = switch_webui_show_card_llm_alice.value
config_data["webui"]["show_card"]["llm"]["qwen"] = switch_webui_show_card_llm_qwen.value
config_data["webui"]["show_card"]["llm"]["zhipu"] = switch_webui_show_card_llm_zhipu.value
config_data["webui"]["show_card"]["llm"]["chat_with_file"] = switch_webui_show_card_llm_chat_with_file.value
config_data["webui"]["show_card"]["llm"]["langchain_chatglm"] = switch_webui_show_card_llm_langchain_chatglm.value
@@ -1752,7 +1752,7 @@ def common_textarea_handle(content):
'claude': 'Claude',
'claude2': 'Claude2',
'chatglm': 'ChatGLM',
- 'alice': 'Qwen-Alice',
+ 'qwen': 'Qwen',
'chat_with_file': 'chat_with_file',
'chatterbot': 'Chatterbot',
'text_generation_webui': 'text_generation_webui',
@@ -2323,23 +2323,23 @@ def common_textarea_handle(content):
input_chatglm_history_max_len = ui.input(label='最大记忆长度', placeholder='最大记忆的上下文字符数量,不建议设置过大,容易爆显存,自行根据情况配置', value=config.get("chatglm", "history_max_len"))
input_chatglm_history_max_len.style("width:200px")

if config.get("webui", "show_card", "llm", "alice"):
with ui.card().style(card_css):
ui.label("Qwen-Alice")
with ui.row():
input_alice_api_ip_port = ui.input(label='API地址', placeholder='ChatGLM的API版本运行后的服务链接(需要完整的URL)', value=config.get("alice", "api_ip_port"))
input_alice_api_ip_port.style("width:400px")
input_alice_max_length = ui.input(label='最大长度限制', placeholder='生成回答的最大长度限制,以令牌数或字符数为单位。', value=config.get("alice", "max_length"))
input_alice_max_length.style("width:200px")
input_alice_top_p = ui.input(label='前p个选择', placeholder='也称为 Nucleus采样。控制模型生成时选择概率的阈值范围。', value=config.get("alice", "top_p"))
input_alice_top_p.style("width:200px")
input_alice_temperature = ui.input(label='温度', placeholder='温度参数,控制生成文本的随机性。较高的温度值会产生更多的随机性和多样性。', value=config.get("alice", "temperature"))
input_alice_temperature.style("width:200px")
with ui.row():
switch_alice_history_enable = ui.switch('上下文记忆', value=config.get("alice", "history_enable")).style(switch_internal_css)
input_alice_history_max_len = ui.input(label='最大记忆轮数', placeholder='最大记忆的上下文轮次数量,不建议设置过大,容易爆显存,自行根据情况配置', value=config.get("alice", "history_max_len"))
input_alice_history_max_len.style("width:200px")
input_alice_preset = ui.input(label='预设',
if config.get("webui", "show_card", "llm", "qwen"):
with ui.card().style(card_css):
ui.label("Qwen")
with ui.row():
input_qwen_api_ip_port = ui.input(label='API地址', placeholder='ChatGLM的API版本运行后的服务链接(需要完整的URL)', value=config.get("qwen", "api_ip_port"))
input_qwen_api_ip_port.style("width:400px")
input_qwen_max_length = ui.input(label='最大长度限制', placeholder='生成回答的最大长度限制,以令牌数或字符数为单位。', value=config.get("qwen", "max_length"))
input_qwen_max_length.style("width:200px")
input_qwen_top_p = ui.input(label='前p个选择', placeholder='也称为 Nucleus采样。控制模型生成时选择概率的阈值范围。', value=config.get("qwen", "top_p"))
input_qwen_top_p.style("width:200px")
input_qwen_temperature = ui.input(label='温度', placeholder='温度参数,控制生成文本的随机性。较高的温度值会产生更多的随机性和多样性。', value=config.get("qwen", "temperature"))
input_qwen_temperature.style("width:200px")
with ui.row():
switch_qwen_history_enable = ui.switch('上下文记忆', value=config.get("qwen", "history_enable")).style(switch_internal_css)
input_qwen_history_max_len = ui.input(label='最大记忆轮数', placeholder='最大记忆的上下文轮次数量,不建议设置过大,容易爆显存,自行根据情况配置', value=config.get("qwen", "history_max_len"))
input_qwen_history_max_len.style("width:200px")
input_qwen_preset = ui.input(label='预设',
placeholder='用于指定一组预定义的设置,以便模型更好地适应特定的对话场景。',
value=config.get("chatgpt", "preset")).style("width:500px")

@@ -3739,7 +3739,7 @@ def update_echart_gift():
switch_webui_show_card_llm_claude = ui.switch('claude', value=config.get("webui", "show_card", "llm", "claude")).style(switch_internal_css)
switch_webui_show_card_llm_claude2 = ui.switch('claude2', value=config.get("webui", "show_card", "llm", "claude2")).style(switch_internal_css)
switch_webui_show_card_llm_chatglm = ui.switch('chatglm', value=config.get("webui", "show_card", "llm", "chatglm")).style(switch_internal_css)
- switch_webui_show_card_llm_alice = ui.switch('Qwen-Alice', value=config.get("webui", "show_card", "llm", "alice")).style(switch_internal_css)
+ switch_webui_show_card_llm_qwen = ui.switch('Qwen', value=config.get("webui", "show_card", "llm", "qwen")).style(switch_internal_css)
switch_webui_show_card_llm_zhipu = ui.switch('智谱AI', value=config.get("webui", "show_card", "llm", "zhipu")).style(switch_internal_css)
switch_webui_show_card_llm_chat_with_file = ui.switch('chat_with_file', value=config.get("webui", "show_card", "llm", "chat_with_file")).style(switch_internal_css)
switch_webui_show_card_llm_langchain_chatglm = ui.switch('langchain_chatglm', value=config.get("webui", "show_card", "llm", "langchain_chatglm")).style(switch_internal_css)
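Throughout this file the gating call config.get("webui", "show_card", "llm", "qwen") walks nested config keys. The Config class itself is not part of this diff, so the following is only an assumed, rough equivalent of that lookup:

```python
# Rough stand-in for the project's variadic config.get(...); the real Config
# class is not shown in this diff, so this only illustrates the assumed lookup.
def config_get(data, *keys):
    for key in keys:
        data = data[key]
    return data

config_data = {"webui": {"show_card": {"llm": {"qwen": True}}}}
print(config_get(config_data, "webui", "show_card", "llm", "qwen"))  # True
```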
