forked from addy999/onequery
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
232 lines (194 loc) · 7.64 KB
/
agent.py
File metadata and controls
232 lines (194 loc) · 7.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import re
import json
import ast
from pydantic import BaseModel
import litellm
from litellm import completion
import os
from functools import lru_cache
from datetime import datetime
# Configure LiteLLM with OpenAI-compatible API.
# Both settings come from the environment; os.getenv yields None when unset.
LITELLM_BASE_URL = os.getenv("LITELLM_BASE_URL")
LITELLM_API_KEY = os.getenv("LITELLM_API_KEY")
# Only override LiteLLM's module-level settings when a non-empty value was
# provided, so an unset variable leaves the library defaults alone.
if LITELLM_BASE_URL:
    litellm.api_base = LITELLM_BASE_URL
if LITELLM_API_KEY:
    litellm.api_key = LITELLM_API_KEY

# The system prompt template lives next to this file.
base_path = os.path.dirname(os.path.abspath(__file__))
# Read it through a context manager so the file handle is closed right away
# instead of being left for the garbage collector.
# NOTE(review): assumes agent_prompt.txt is UTF-8 encoded; the explicit
# encoding keeps the result independent of the platform's locale default.
with open(os.path.join(base_path, "agent_prompt.txt"), encoding="utf-8") as _prompt_file:
    RAW_SYSTEM_PROMPT = _prompt_file.read()

# litellm.set_verbose = True
litellm.modify_params = True
def parse_text(text):
    """Pull the tagged sections out of an LLM reply.

    Every section is looked up as an opening tag, a newline, the content,
    a newline and the closing tag. The content between the two newlines is
    captured and may itself span several lines.

    Args:
        text: The raw reply text to search.

    Returns:
        A dict with the keys ``next_action``, ``next_action_2``,
        ``explanation`` and ``next_task``. A key maps to the captured
        content, or to ``None`` when its section was not found.
    """
    # Result key -> regex locating the section that fills it.
    section_patterns = {
        "next_action": r"<next_action-1>\n(.*?)\n</next_action-1>",
        "next_action_2": r"<next_action-2>\n(.*?)\n</next_action-2>",
        "explanation": r"<explanation>\n(.*?)\n</explanation>",
        "next_task": r"<next_task>\n(.*?)\n</next_task>",
    }

    sections = {}
    for key, pattern in section_patterns.items():
        # DOTALL lets the non-greedy group run across line breaks.
        found = re.search(pattern, text, re.DOTALL)
        sections[key] = None if found is None else found.group(1)
    return sections
def is_valid_json(string: str) -> bool:
    """Return True when *string* parses as JSON, False on a JSON decode error."""
    try:
        json.loads(string)
    except json.JSONDecodeError:
        return False
    return True


def clean_up_json(string: str) -> str:
    """Extract a JSON object from *string* and return it as valid JSON text.

    The text between the first "{" and the last "}" is taken as the
    candidate, stripped, and has its newlines removed. If that candidate is
    already valid JSON it is returned untouched. Only otherwise are the
    repair heuristics applied, in this order:

    1. unescape backslash-escaped double quotes, drop triple-backtick fences
       and the word "json";
    2. append one "}" for every unmatched "{";
    3. if the result is still not JSON, evaluate it as a Python literal
       (e.g. a single-quoted dict) and re-serialise it with ``json.dumps``.

    Args:
        string: Raw text that should contain a JSON object.

    Returns:
        A string that ``json.loads`` accepts.

    Raises:
        ValueError: If no JSON object could be recovered. The arguments are
            the message "String not valid" and the text that failed.
    """

    def extract_json_from_string(raw):
        # Slice from the first "{" to the last "}" (inclusive); "" if either
        # brace is missing.
        start_index = raw.find("{")
        end_index = raw.rfind("}")
        if start_index != -1 and end_index != -1:
            return raw[start_index : end_index + 1]
        return ""

    candidate = extract_json_from_string(string).strip().replace("\n", "")

    # Valid JSON must not go through the heuristics below: they would
    # unescape quotes inside string values and delete "json" from values.
    if is_valid_json(candidate):
        return candidate

    cleaned = (
        candidate.replace('\\"', '"')
        .replace("```", "")
        .replace("json", "")
    )

    # Close every object that was left open, not just one of them.
    missing_braces = cleaned.count("{") - cleaned.count("}")
    if missing_braces > 0:
        cleaned += "}" * missing_braces

    if not is_valid_json(cleaned):
        try:
            # literal_eval only evaluates literals, so it is safe on model
            # output; it handles Python-style dicts with single quotes.
            cleaned = json.dumps(ast.literal_eval(cleaned))
        except (ValueError, SyntaxError, TypeError, RecursionError) as err:
            # TypeError covers e.g. "{{...}}" (a set of dicts) and values
            # json.dumps cannot serialise.
            raise ValueError("String not valid", cleaned) from err
    return cleaned
def get_reply(state, model: str = "gemini/gemini-2.5-pro") -> dict:
    """
    Get a reply from the LLM using OpenAI-compatible API and parse it.

    The system prompt is the module's raw prompt followed by today's date
    (local time, YYYY-MM-DD).

    Args:
        state: The conversation state/history. It is concatenated after the
            system message with ``+``, so it must be a list of messages.
        model: The model name (will be prefixed with "openai/")

    Returns:
        The dict built by ``parse_text`` from the reply text.
    """
    today_date = datetime.now().strftime("%Y-%m-%d")
    system_prompt = f"{RAW_SYSTEM_PROMPT}\n\nToday's date: {today_date}"

    # Prefix model with "openai/" for LiteLLM
    full_model_name = f"openai/{model}"

    response = completion(
        model=full_model_name,
        messages=[{"role": "system", "content": system_prompt}] + state,
        temperature=0.5,
    )
    reply = response.choices[0].message.content

    # The message content can be None (no text returned). Treat that as an
    # empty reply so parsing yields "nothing found" instead of a TypeError
    # from the regex search.
    return parse_text(reply or "")
def summarize_text(
    prompt: str, documents: list, schema: str, model: str = "gpt-4"
) -> dict:
    """
    Summarize documents using OpenAI-compatible API.

    The conversation ends with an assistant message containing only "{" so
    the model continues a JSON object; that "{" is put back in front of the
    reply before it is cleaned up and parsed.

    Args:
        prompt: The prompt for summarization
        documents: List of documents to summarize; each one is sent as a
            separate text part of the user message
        schema: The JSON schema to use for output
        model: The model name (will be prefixed with "openai/")

    Returns:
        The parsed JSON object (the result of ``json.loads``), not a string.

    Raises:
        ValueError: Propagated from ``clean_up_json`` when the reply cannot
            be turned into valid JSON.
    """
    full_model_name = f"openai/{model}"

    system_message = {
        "role": "system",
        "content": f"""Summarize the following documents for this prompt in JSON format.
Prompt: {prompt}
Return using this schema: {schema}""",
    }
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": text} for text in documents],
    }
    # Prefill: the assistant turn already opens the JSON object.
    prefill_message = {"role": "assistant", "content": "{"}

    response = completion(
        model=full_model_name,
        max_tokens=1000,
        temperature=0.3,
        messages=[system_message, user_message, prefill_message],
    )

    # Re-attach the prefilled "{" that precedes the model's continuation.
    raw_json = "{" + response.choices[0].message.content
    return json.loads(clean_up_json(raw_json))
def fetch_query_for_rag(task: str, model: str = "gemini/gemini-2.5-flash") -> str:
    """
    Ask the LLM for a keyword/phrase RAG query matching a task.

    The assistant turn is prefilled with "{" so the model continues a JSON
    object; the same "{" is prepended to the reply before it is cleaned up
    and parsed.

    Args:
        task: The task description, appended to the instruction text
        model: The model name (will be prefixed with "openai/")

    Returns:
        The value stored under the "query" key of the parsed reply.
    """
    full_model_name = f"openai/{model}"

    instruction = (
        "Generate a simple keyword/phrase query for a RAG system based on the following task. Return the query as JSON with 'query' key. The query should help fetch documents relevant to the task: "
        + task
    )
    conversation = [
        {"role": "user", "content": instruction},
        {"role": "assistant", "content": "{"},
    ]

    llm_response = completion(
        model=full_model_name,
        max_tokens=256,
        temperature=0.3,
        messages=conversation,
    )

    # Put back the "{" that was supplied as the assistant prefill.
    raw_json = "{" + llm_response.choices[0].message.content
    parsed = json.loads(clean_up_json(raw_json))
    return parsed["query"]
@lru_cache(maxsize=128, typed=True)
def find_schema_for_query(query: str, model: str = "gemini/gemini-2.5-flash") -> str:
    """
    Ask the LLM for a JSON schema that fits a query.

    Results are memoised by ``lru_cache`` (up to 128 entries), so calling
    again with the same arguments returns the earlier result without a new
    request.

    Args:
        query: The query to generate schema for
        model: The model name (will be prefixed with "openai/")

    Returns:
        The reply text after it has been passed through ``clean_up_json``.
    """
    # Few-shot instructions: two example queries with their schemas.
    system_message = {
        "role": "system",
        "content": """You're an expert in data science. You're helping a colleague form JSON schemas for their data. You're given a query and asked to find the schema for it.
Example:
Query: Find 2 recent issues from PyTorch repository.
Schema: {'properties': {'date': {'title': 'Date', 'type': 'string'}, 'title': {'title': 'Title', 'type': 'string'}, 'author': {'title': 'Author', 'type': 'string'}, 'description': {'title': 'Description', 'type': 'string'}}, 'required': ['date', 'title', 'author', 'description'], 'title': 'IssueModel', 'type': 'object'}
Example:
Query: Find 5 events happening in Bangalore this week.
Schema: {'properties': {'name': {'title': 'Name', 'type': 'string'}, 'date': {'title': 'Date', 'type': 'string'}, 'location': {'title': 'Location', 'type': 'string'}}, 'required': ['name', 'date', 'location'], 'title': 'EventsModel', 'type': 'object'}""",
    }
    user_message = {
        "role": "user",
        "content": f"""Find the schema for the following query: {query}.""",
    }

    full_model_name = f"openai/{model}"
    llm_response = completion(
        model=full_model_name,
        temperature=0.5,
        max_tokens=512,
        messages=[system_message, user_message],
    )
    schema_text = llm_response.choices[0].message.content
    return clean_up_json(schema_text)