Skip to content

Commit

Permalink
1. Display an error message when the dalle3 account balance is insufficient.
Browse files Browse the repository at this point in the history

2. Fixed bug: potential timeout

3. Fixed bug: an error occurred when switching to the claude-2 model.

4. Fixed bug: claude-2 could receive empty message chunks during string concatenation.

5. Fixed bug: the function-call output was truncated incorrectly in search.
  • Loading branch information
yym68686 committed Dec 16, 2023
1 parent cf4ae80 commit ebdc93a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
9 changes: 7 additions & 2 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ async def image(update, context):
result += "当前 prompt 未能成功生成图片,可能因为版权,政治,色情,暴力,种族歧视等违反 OpenAI 的内容政策😣,换句话试试吧~"
elif "server is busy" in str(e):
result += "服务器繁忙,请稍后再试~"
elif "billing_hard_limit_reached" in str(e):
result += "当前账号余额不足~"
else:
result += f"`{e}`"
await context.bot.edit_message_text(chat_id=chatid, message_id=start_messageid, text=result, parse_mode='MarkdownV2', disable_web_page_preview=True)
Expand Down Expand Up @@ -225,7 +227,7 @@ async def delete_message(update, context, messageid, delay=10):
# ],
[
InlineKeyboardButton("claude-2", callback_data="claude-2"),
InlineKeyboardButton("claude-2-web", callback_data="claude-2-web"),
# InlineKeyboardButton("claude-2-web", callback_data="claude-2-web"),
],
[
InlineKeyboardButton("返回上一级", callback_data="返回上一级"),
Expand Down Expand Up @@ -571,4 +573,7 @@ async def post_init(application: Application) -> None:
print("WEB_HOOK:", WEB_HOOK)
application.run_webhook("127.0.0.1", PORT, webhook_url=WEB_HOOK)
else:
application.run_polling()
# application.run_polling()
time_out = 600
application.run_polling(read_timeout=time_out, write_timeout=time_out)
# application.run_polling(read_timeout=time_out, write_timeout=time_out, pool_timeout=time_out, connect_timeout=time_out, timeout=time_out)
3 changes: 1 addition & 2 deletions test/test_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ def get_token_count(self, convo_id: str = "default") -> int:
raise NotImplementedError(
f"Engine {self.engine} is not supported. Select from {ENGINES}",
)
tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
tiktoken.model.MODEL_TO_ENCODING["claude-2-web"] = "cl100k_base"
tiktoken.get_encoding("cl100k_base")
tiktoken.model.MODEL_TO_ENCODING["claude-2"] = "cl100k_base"

encoding = tiktoken.encoding_for_model(self.engine)
Expand Down
4 changes: 4 additions & 0 deletions test/test_tikitoken.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import tiktoken

# Register cl100k_base as the tokenizer for the Claude model name so that
# tiktoken.encoding_for_model() can resolve it (Claude is not a model
# tiktoken knows about natively).
model_name = "claude-2.1"
tiktoken.model.MODEL_TO_ENCODING[model_name] = "cl100k_base"
encoding = tiktoken.encoding_for_model(model_name)
17 changes: 7 additions & 10 deletions utils/chatgpt2api.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,7 @@ def get_token_count(self, convo_id: str = "default") -> int:
raise NotImplementedError(
f"Engine {self.engine} is not supported. Select from {ENGINES}",
)
tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
tiktoken.model.MODEL_TO_ENCODING["claude-2-web"] = "cl100k_base"
tiktoken.model.MODEL_TO_ENCODING["claude-2"] = "cl100k_base"

encoding = tiktoken.encoding_for_model(self.engine)

num_tokens = 0
Expand Down Expand Up @@ -195,8 +192,9 @@ def ask_stream(
# print(line)
resp: dict = json.loads(line)
content = resp.get("completion")
full_response += content
yield content
if content:
full_response += content
yield content
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
# print(repr(self.conversation.Conversation(convo_id)))
# print("total tokens:", self.get_token_count(convo_id))
Expand Down Expand Up @@ -395,7 +393,7 @@ def truncate_conversation(
while True:
json_post = self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)
url = config.bot_api_url.chat_url
if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106":
if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106" or self.engine == "claude-2":
message_token = {
"total": self.get_token_count(convo_id),
}
Expand Down Expand Up @@ -430,10 +428,9 @@ def get_token_count(self, convo_id: str = "default") -> int:
raise NotImplementedError(
f"Engine {self.engine} is not supported. Select from {ENGINES}",
)
# tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
# tiktoken.model.MODEL_TO_ENCODING["claude-2-web"] = "cl100k_base"
# tiktoken.model.MODEL_TO_ENCODING["claude-2"] = "cl100k_base"
tiktoken.get_encoding("cl100k_base")
tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"
tiktoken.model.MODEL_TO_ENCODING["claude-2"] = "cl100k_base"

encoding = tiktoken.encoding_for_model(self.engine)

Expand Down Expand Up @@ -593,7 +590,7 @@ def ask_stream(
if "name" in delta["function_call"]:
function_call_name = delta["function_call"]["name"]
full_response += function_call_content
if full_response.count("\\n") > 2:
if full_response.count("\\n") > 2 or "}" in full_response:
break
if need_function_call:
full_response = check_json(full_response)
Expand Down

0 comments on commit ebdc93a

Please sign in to comment.