Skip to content

Commit

Permalink
Add JSON validation and error handling in agent.py and chatgpt2api.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Dec 15, 2023
1 parent b62e628 commit cf4ae80
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
17 changes: 17 additions & 0 deletions test/test_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import json

json_data = '{"prompt":"\n\n'
def check_json(json_data):
while True:
try:
json.loads(json_data)
break
except json.decoder.JSONDecodeError as e:
print("JSON error:", e)
print("JSON body", repr(json_data))
if "Invalid control character" in str(e):
json_data = json_data.replace("\n", "\\n")
if "Unterminated string starting" in str(e):
json_data += '"}'
return json_data
print(json.loads(check_json(json_data)))
17 changes: 17 additions & 0 deletions utils/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
import json

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
Expand Down Expand Up @@ -305,6 +306,8 @@ def getddgsearchurl(result, numresults=3):
try:
search = DuckDuckGoSearchResults(num_results=numresults)
webresult = search.run(result)
if webresult == None:
return []
urls = re.findall(r"(https?://\S+)\]", webresult, re.MULTILINE)
except Exception as e:
print('\033[31m')
Expand Down Expand Up @@ -500,6 +503,20 @@ def get_search_results(prompt: str, context_max_tokens: int):

return useful_source_text

def check_json(json_data):
while True:
try:
json.loads(json_data)
break
except json.decoder.JSONDecodeError as e:
print("JSON error:", e)
print("JSON body", repr(json_data))
if "Invalid control character" in str(e):
json_data = json_data.replace("\n", "\\n")
if "Unterminated string starting" in str(e):
json_data += '"}'
return json_data

if __name__ == "__main__":
os.system("clear")

Expand Down
7 changes: 6 additions & 1 deletion utils/chatgpt2api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Set

import config
from utils.agent import Web_crawler, get_search_results, cut_message, get_url_text_list, get_text_token_len
from utils.agent import Web_crawler, get_search_results, cut_message, get_url_text_list, get_text_token_len, check_json
from utils.function_call import function_call_list

def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
Expand Down Expand Up @@ -539,6 +539,7 @@ def ask_stream(
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name)
json_post, message_token = self.truncate_conversation(prompt, role, convo_id, model, pass_history, **kwargs)
print(json_post)
print(self.conversation[convo_id])

if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106":
Expand Down Expand Up @@ -592,7 +593,11 @@ def ask_stream(
if "name" in delta["function_call"]:
function_call_name = delta["function_call"]["name"]
full_response += function_call_content
if full_response.count("\\n") > 2:
break
if need_function_call:
full_response = check_json(full_response)
print("full_response", full_response)
if not self.function_calls_counter.get(function_call_name):
self.function_calls_counter[function_call_name] = 1
else:
Expand Down

0 comments on commit cf4ae80

Please sign in to comment.