From cf4ae80039cb48b8f371ca7ae91c6eb6a35d3c6c Mon Sep 17 00:00:00 2001 From: yym68686 Date: Fri, 15 Dec 2023 15:35:42 +0800 Subject: [PATCH] Add JSON validation and error handling in agent.py and chatgpt2api.py --- test/test_json.py | 17 +++++++++++++++++ utils/agent.py | 17 +++++++++++++++++ utils/chatgpt2api.py | 7 ++++++- 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 test/test_json.py diff --git a/test/test_json.py b/test/test_json.py new file mode 100644 index 00000000..719285c4 --- /dev/null +++ b/test/test_json.py @@ -0,0 +1,17 @@ +import json + +json_data = '{"prompt":"\n\n' +def check_json(json_data): + while True: + try: + json.loads(json_data) + break + except json.decoder.JSONDecodeError as e: + print("JSON error:", e) + print("JSON body", repr(json_data)) + if "Invalid control character" in str(e): + json_data = json_data.replace("\n", "\\n") + if "Unterminated string starting" in str(e): + json_data += '"}' + return json_data +print(json.loads(check_json(json_data))) diff --git a/utils/agent.py b/utils/agent.py index 6dda88aa..c12fdb18 100644 --- a/utils/agent.py +++ b/utils/agent.py @@ -1,5 +1,6 @@ import os import re +import json import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -305,6 +306,8 @@ def getddgsearchurl(result, numresults=3): try: search = DuckDuckGoSearchResults(num_results=numresults) webresult = search.run(result) + if webresult == None: + return [] urls = re.findall(r"(https?://\S+)\]", webresult, re.MULTILINE) except Exception as e: print('\033[31m') @@ -500,6 +503,20 @@ def get_search_results(prompt: str, context_max_tokens: int): return useful_source_text +def check_json(json_data): + while True: + try: + json.loads(json_data) + break + except json.decoder.JSONDecodeError as e: + print("JSON error:", e) + print("JSON body", repr(json_data)) + if "Invalid control character" in str(e): + json_data = json_data.replace("\n", "\\n") + if "Unterminated string starting" in str(e): + json_data += '"}' + return json_data + if __name__ == "__main__": os.system("clear") diff --git a/utils/chatgpt2api.py b/utils/chatgpt2api.py index defcc1d1..ff47b162 100644 --- a/utils/chatgpt2api.py +++ b/utils/chatgpt2api.py @@ -12,7 +12,7 @@ from typing import Set import config -from utils.agent import Web_crawler, get_search_results, cut_message, get_url_text_list, get_text_token_len +from utils.agent import Web_crawler, get_search_results, cut_message, get_url_text_list, get_text_token_len, check_json from utils.function_call import function_call_list def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]: @@ -539,6 +539,7 @@ def ask_stream( self.reset(convo_id=convo_id, system_prompt=self.system_prompt) self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name) json_post, message_token = self.truncate_conversation(prompt, role, convo_id, model, pass_history, **kwargs) + print(json_post) print(self.conversation[convo_id]) if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106": @@ -592,7 +593,11 @@ def ask_stream( if "name" in delta["function_call"]: function_call_name = delta["function_call"]["name"] full_response += function_call_content + if full_response.count("\\n") > 2: + break if need_function_call: + full_response = check_json(full_response) + print("full_response", full_response) if not self.function_calls_counter.get(function_call_name): self.function_calls_counter[function_call_name] = 1 else: