diff --git a/utils/chatgpt2api.py b/utils/chatgpt2api.py index d02b2760..edf745b8 100644 --- a/utils/chatgpt2api.py +++ b/utils/chatgpt2api.py @@ -269,13 +269,13 @@ def __init__( self.system_prompt: str = system_prompt self.max_tokens: int = max_tokens or ( 4096 - if "gpt-4-1106-preview" in engine + if "gpt-4-1106-preview" in engine or "gpt-3.5-turbo-1106" in engine else 31000 if "gpt-4-32k" in engine else 7000 if "gpt-4" in engine else 16385 - if "gpt-3.5-turbo-1106" in engine or "gpt-3.5-turbo-16k" in engine + if "gpt-3.5-turbo-16k" in engine else 99000 if "claude-2-web" in engine or "claude-2" in engine else 4000 @@ -392,13 +392,13 @@ def truncate_conversation( while True: json_post = self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs) url = config.bot_api_url.chat_url - if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106" or self.engine == "claude-2": + if self.engine == "gpt-4-1106-preview" or self.engine == "claude-2": message_token = { "total": self.get_token_count(convo_id), } else: message_token = self.get_message_token(url, json_post) - print("message_token", message_token, self.truncate_limit) + print("message_token", message_token, "truncate_limit", self.truncate_limit) if ( message_token["total"] > self.truncate_limit and len(self.conversation[convo_id]) > 1 @@ -454,7 +454,7 @@ def get_message_token(self, url, json_post): json=json_post, timeout=None, ) - # print(response.text) + # print("response.text", response.text) if response.status_code != 200: json_response = json.loads(response.text) string = json_response["error"]["message"] @@ -541,8 +541,10 @@ def ask_stream( print(json.dumps(json_post, indent=4, ensure_ascii=False)) # print(self.conversation[convo_id]) - if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106": + if self.engine == "gpt-4-1106-preview": model_max_tokens = kwargs.get("max_tokens", self.max_tokens) + elif self.engine == "gpt-3.5-turbo-1106": + model_max_tokens = min(kwargs.get("max_tokens", self.max_tokens), 16385 - message_token["total"]) else: model_max_tokens = min(kwargs.get("max_tokens", self.max_tokens), self.max_tokens - message_token["total"]) print("model_max_tokens", model_max_tokens)