-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathquery_router.py
More file actions
437 lines (362 loc) · 17 KB
/
query_router.py
File metadata and controls
437 lines (362 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
"""
Query Router module for Max Discord Bot.
This module classifies incoming queries to determine whether they need web research
or can be answered with standard model knowledge.
"""
import asyncio
import json
import re
from typing import Dict, Any, Tuple, List, Optional, Callable

from logger import router_logger
class QueryRouter:
    """
    Routes queries to the appropriate LLM based on content analysis.

    Uses Perplexity for web research queries, Gemini for standard knowledge.
    Every LLM-backed decision (web research, AI relevance, clarification,
    coreference, thread replies) prompts a classifier client for a small JSON
    object of booleans and parses it with ``_process_llm_response``.
    """

    # Common greetings list used across methods
    COMMON_GREETINGS = [
        "hey",
        "hello",
        "hi",
        "sup",
        "yo",
        "greetings",
        "hiya",
        "howdy",
        "hey max",
        "hello max",
        "hi max",
        "yo max",
        "howdy max",
        "hey!",
        "hello!",
        "hi!",
        "sup!",
        "yo!",
        "howdy!",
        "good morning",
        "good afternoon",
        "good evening",
        "morning",
        "afternoon",
        "evening",
        "what's up",
        "whats up",
        "what up",
        "hey there",
        "hello there",
        "hi there",
        "heya",
        "heyy",
        "hiii",
        "hiiii",
        "heyyy",
        "hellooo",
        "wassup",
        "what is up",
        "what's happening",
        "whats happening",
    ]

    # Single-word greeting shapes, including elongated spellings such as
    # "heyyy", "hiii", "hellooo" or "yooo". Applied with fullmatch to whole
    # words only, so "this", "your" or "support" are NOT treated as greetings.
    _GREETING_WORD_RE = re.compile(r"h+e+y+a*|h+i+(?:y+a+)?|h+e+l+o+|s+u+p+|y+o+")

    # Characters ignored at the end of a message when comparing it to
    # COMMON_GREETINGS, so "what's up?" or "good morning!" still match.
    _TRAILING_PUNCTUATION = " \t\n!?.,~"

    def __init__(self,
                 perplexity_model="sonar-pro",
                 gemini_model="gemini-2.0-flash-001",
                 classifier_model="claude-3-7-sonnet-20250219"):
        """
        Initialize the QueryRouter with model preferences.

        Args:
            perplexity_model: The Perplexity model to use for web research
            gemini_model: The Gemini model to use for standard responses
            classifier_model: The model to use for query classification
        """
        self.perplexity_model = perplexity_model
        self.gemini_model = gemini_model
        self.classifier_model = classifier_model
        router_logger.debug(f"QueryRouter initialized with models: perplexity={perplexity_model}, gemini={gemini_model}, classifier={classifier_model}")

    def _is_greeting(self, query: str) -> bool:
        """
        Check if a query is a common greeting.

        A query counts as a greeting when it (ignoring case and trailing
        punctuation) equals an entry of COMMON_GREETINGS, when it starts with
        such an entry followed by a space, or when it has at most three words
        and one of them is a greeting word (e.g. "hey!!", "hiii :)").

        Args:
            query: The user's message

        Returns:
            Boolean indicating if the query is a greeting
        """
        query_lower = query.lower().strip()
        trimmed = query_lower.rstrip(self._TRAILING_PUNCTUATION)

        # Exact match, with or without trailing punctuation
        if query_lower in self.COMMON_GREETINGS or trimmed in self.COMMON_GREETINGS:
            return True

        # Message opens with a greeting ("hey max how's it going").
        # NOTE(review): this also short-circuits substantive questions that
        # merely start with a greeting, e.g. "hey what's new in RAG?".
        if query_lower.startswith(tuple(g + " " for g in self.COMMON_GREETINGS)):
            return True

        # Short conversational starters with emojis or punctuation. Compare
        # whole words: a plain substring test would match "hi" inside "this"
        # or "yo" inside "your".
        if len(query_lower.split()) <= 3:
            words = re.findall(r"[a-z]+", query_lower)
            return any(self._GREETING_WORD_RE.fullmatch(word) for word in words)

        return False

    @staticmethod
    async def _invoke_llm(client, prompt: str):
        """
        Call ``client.invoke(prompt)`` without blocking the event loop.

        ``invoke`` is a synchronous call, so it is run in the running loop's
        default executor. Any exception it raises propagates to the awaiting
        caller unchanged.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, client.invoke, prompt)

    @staticmethod
    def _extract_json_object(text: str) -> Optional[Dict[str, Any]]:
        """
        Return the JSON object contained in ``text``, or None if none parses.

        Accepts a bare object as well as one embedded in surrounding prose or a
        Markdown code fence (```json ... ```), by retrying on the span between
        the first "{" and the last "}".
        """
        candidates = [text.strip()]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            candidates.append(text[start:end + 1])
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    @staticmethod
    def _coerce_bool(value: Any, default: Any) -> Any:
        """
        Interpret a parsed JSON value as a boolean.

        Accepts real booleans and the strings "true"/"false" (any case).
        Anything else, including a missing key (None), yields ``default``.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in ("true", "false"):
                return lowered == "true"
        return default

    @staticmethod
    def _find_bool_in_text(text: str, key: str, default: Any) -> Any:
        """
        Find ``key: true/false`` in unparseable text, returning ``default`` if absent.

        Each key is matched against its OWN value, so a "true" belonging to
        one key can no longer leak into another key's result.
        """
        match = re.search(
            r"[\"']?" + re.escape(key) + r"[\"']?\s*[:=]\s*[\"']?(true|false)\b",
            text,
            re.IGNORECASE,
        )
        if match is None:
            return default
        return match.group(1).lower() == "true"

    async def _process_llm_response(
        self, response, expected_keys: List[str], default_values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Process LLM response and handle JSON parsing with fallback logic.

        The reply text (``response.content``) is parsed as JSON, also when the
        object is wrapped in a code fence or surrounded by prose. If no object
        can be parsed, each key's boolean is searched for individually in the
        raw text. Values that cannot be interpreted as booleans fall back to
        ``default_values``.

        Args:
            response: The response from the LLM (must expose ``.content``)
            expected_keys: List of keys expected in the JSON response
            default_values: Dictionary of default values for keys

        Returns:
            Dictionary with parsed values or defaults (a fresh dict; the
            caller's ``default_values`` is never returned or mutated)
        """
        try:
            response_text = response.content
            parsed = self._extract_json_object(response_text)
            if parsed is not None:
                return {
                    key: self._coerce_bool(parsed.get(key), default_values[key])
                    for key in expected_keys
                }
            # Fallback if LLM doesn't return valid JSON
            router_logger.warning(
                "Failed to parse JSON response from LLM, using fallback"
            )
            return {
                key: self._find_bool_in_text(response_text, key, default_values[key])
                for key in expected_keys
            }
        except Exception as e:
            router_logger.error(f"Error processing LLM response: {e}", exc_info=True)
            return dict(default_values)

    async def classify_with_llm(self, query: str, client) -> Dict:
        """
        Use an LLM to determine if a query requires web research and if it's AI-related.

        Greetings are answered locally without calling the LLM.

        Args:
            query: The user's question or message
            client: The LLM client to use for classification

        Returns:
            Dictionary with classification results:
            - needs_web_research: Boolean indicating if web research is needed
            - is_ai_related: Boolean indicating if the query is related to AI/technology
        """
        router_logger.debug(f"Classifying query: '{query[:50]}...'")

        # Check for common greetings first - these don't need web research
        if self._is_greeting(query):
            router_logger.debug("Query is a common greeting, no web research needed")
            return {"needs_web_research": False, "is_ai_related": True}

        prompt = f"""
You are Max, an AI assistant for the Maxpool Discord server focused on generative AI topics. You have been created by the Maxpool community.
Analyze the following user query to determine:
1. If it requires up-to-date information from the web
2. If it is related to LLMs, AI, machine learning, coding, or technology topics
Query: "{query}"
For determining if web research is needed:
- Does it request specific data, statistics, or factual information that changes frequently?
- Does it ask about comparisons, prices, or reviews of AI tools/models that need current data?
- Does it request technical information about recent AI software, products, or services?
- Does it ask about recent research papers, model releases, or AI developments?
- Does it involve recent product releases, updates, or version comparisons of AI tools?
For determining if it's LLM/AI/technology related, it should match any of the following criteria:
- Is it about LLMs, AI models, tools, techniques, or concepts?
- Is it about specialized AI components like embeddings, retrievers, rerankers, RAG systems, vector databases?
- Is it about model training, fine-tuning, inference, or optimization techniques?
- Is it about programming, coding, or software development?
IMPORTANT: Respond with ONLY a JSON object and NOTHING ELSE. Your JSON response MUST be in exactly this format:
{{
"needs_web_research": true/false,
"is_ai_related": true/false
}}
"""
        try:
            router_logger.debug("Invoking LLM for query classification")
            response = await self._invoke_llm(client, prompt)
            expected_keys = ["needs_web_research", "is_ai_related"]
            default_values = {"needs_web_research": False, "is_ai_related": True}
            return await self._process_llm_response(
                response, expected_keys, default_values
            )
        except Exception as e:
            router_logger.error(f"Error classifying query: {e}", exc_info=True)
            return {
                "needs_web_research": False,
                "is_ai_related": True,  # Default to True in case of errors
            }

    async def route_query(self, query: str, llm_client) -> Tuple[str, str, Dict[str, Any], bool]:
        """
        Determine which LLM should handle the given query.

        Non-AI queries go to Google; AI queries go to Perplexity when they need
        web research and to Google otherwise.

        Args:
            query: The user's question or message
            llm_client: The client to use for query classification

        Returns:
            Tuple containing:
            - provider name ("perplexity" or "google")
            - model name
            - additional parameters for the LLM
            - boolean indicating if the query is AI/technology related
        """
        router_logger.debug(f"Routing query: '{query[:50]}...'")
        classification = await self.classify_with_llm(query, llm_client)

        # Check if query is AI/technology related
        is_ai_related = classification.get("is_ai_related", True)

        # If non-AI query, return Google with default parameters and is_ai_related = False
        if not is_ai_related:
            router_logger.info("Query classified as non-AI related, routing to Google")
            return "google", self.gemini_model, {"temperature": 0.2}, False

        # For AI-related queries, route based on web research needs
        if classification.get("needs_web_research", False):
            router_logger.info("Query needs web research, routing to Perplexity")
            return "perplexity", self.perplexity_model, {"temperature": 0.0}, True

        router_logger.info("Query does not need web research, routing to Google")
        return "google", self.gemini_model, {"temperature": 0.2}, True

    async def should_ask_clarification(self, query: str, llm_client) -> bool:
        """
        Uses LLM to determine if the query is too vague and requires clarification.

        Greetings and queries of three words or fewer never trigger
        clarification and do not call the LLM.

        Args:
            query: The user's question or message
            llm_client: The client to use for classification

        Returns:
            Boolean indicating if clarification is needed (False on any error)
        """
        router_logger.debug(f"Checking if query needs clarification: '{query[:50]}...'")

        # Check for common greetings first - never ask for clarification for these
        if self._is_greeting(query):
            router_logger.debug("Query is a greeting, no clarification needed")
            return False

        # Check for short queries that might be conversational starters
        if len(query.strip().split()) <= 3:
            router_logger.debug(
                "Short query, potential greeting, no clarification needed"
            )
            return False

        prompt = f"""
Determine if the following user query is too vague and requires clarification before providing a helpful response.
Query: "{query}"
Consider:
- Is it extremely short (less than 3 words)?
- Is it ambiguous with multiple possible interpretations?
- Does it lack necessary context or specificity?
- Is it overly broad or general?
- IMPORTANT: Is it just a casual greeting like "hey", "hello", "hi", etc.? If so, do NOT ask for clarification.
- IMPORTANT: If it's a simple greeting or conversation starter, do NOT ask for clarification.
IMPORTANT: Respond with ONLY a JSON object and NOTHING ELSE. Your JSON response MUST be in exactly this format:
{{
"needs_clarification": true/false
}}
"""
        try:
            router_logger.debug("Invoking LLM to determine if clarification is needed")
            response = await self._invoke_llm(llm_client, prompt)
            expected_keys = ["needs_clarification"]
            default_values = {"needs_clarification": False}
            result = await self._process_llm_response(
                response, expected_keys, default_values
            )
            return result.get("needs_clarification", False)
        except Exception as e:
            router_logger.error(f"Error checking for clarification need: {e}", exc_info=True)
            return False

    async def detect_coreference(self, message: str, llm_client) -> bool:
        """
        Use an LLM to determine if a message contains coreferences (e.g., "this paper", "this question")
        that would benefit from additional context.

        Messages shorter than 5 characters (after stripping) are treated as
        coreferences without calling the LLM.

        Args:
            message: The user's message
            llm_client: The LLM client to use for detection

        Returns:
            Boolean indicating if the message contains coreferences that need context
            (False on any error)
        """
        router_logger.debug(f"Checking for coreferences in message: '{message[:50]}...'")

        # If the message is very short (just mentioning the bot), it's likely a reference
        if len(message.strip()) < 5:
            router_logger.debug("Message too short, treating as coreference")
            return True

        prompt = f"""
Determine if the following message contains coreferences that would benefit from additional context.
Message: "{message}"
Consider the following:
- Does it include demonstrative pronouns like "this", "that", "these", "those" without clear antecedents?
- Does it refer to "this paper", "this article", "this topic", "this code", etc. without specifying which one?
- Does it contain phrases like "thoughts?", "your take?", "what do you think?" without sufficient context?
- Is it asking for an opinion, analysis, or evaluation of something that isn't fully specified?
- Is the message very short but seems to expect knowledge of a previous context?
- Does it use pronouns (it, they, them) without clear references?
IMPORTANT: Respond with ONLY a JSON object and NOTHING ELSE. Your JSON response MUST be in exactly this format:
{{
"is_coreference": true/false
}}
"""
        try:
            router_logger.debug("Invoking LLM for coreference detection")
            response = await self._invoke_llm(llm_client, prompt)
            expected_keys = ["is_coreference"]
            default_values = {"is_coreference": False}
            result = await self._process_llm_response(
                response, expected_keys, default_values
            )
            return result.get("is_coreference", False)
        except Exception as e:
            router_logger.error(f"Error detecting coreference: {e}", exc_info=True)
            return False

    async def should_reply_in_thread(
        self, message: str, context_messages: list, llm_client
    ) -> bool:
        """
        Determine if a message in a thread is addressed to the bot and requires a response.

        Args:
            message: The user's message
            context_messages: Recent messages in the thread for context. The
                last 5 are used; falsy entries are skipped. Each entry is read
                with ``.get("author")`` / ``.get("content")`` (assumed to be
                dicts — confirm against the caller).
            llm_client: The LLM client to use for classification

        Returns:
            Boolean indicating if the bot should respond to this message
            (True when unsure or on any error)
        """
        router_logger.debug(f"Checking if bot should reply to: '{message[:50]}...'")

        # A message that explicitly starts with "@max" is always addressed to the bot
        if message.lstrip().lower().startswith("@max"):
            router_logger.debug(
                "Message is a question or direct address to Max, should reply"
            )
            return True

        # Tolerate missing keys / None content so a malformed context entry
        # cannot raise out of this method before the error handling below.
        recent_context = []
        for msg in (context_messages or [])[-5:]:
            if not msg:
                continue
            author = msg.get("author", "unknown")
            content = str(msg.get("content") or "")
            snippet = content[:100] + "..." if len(content) > 100 else content
            recent_context.append(f"{author}: {snippet}")
        # ensure_ascii=False keeps emoji/non-Latin text readable for the LLM
        context_json = json.dumps(recent_context, ensure_ascii=False)

        prompt = f"""
Analyze this message in a thread to determine if it requires a response from Max (an AI assistant).
Message: "{message}"
Recent thread context (newest last):
{context_json}
Determine if:
- The message is asking Max a follow-up question
- The message is directly addressing Max
- The message contains a command or request for Max
- The message is expecting a response from Max
DO NOT respond if:
- The message is clearly addressed to another person
- The message is a thank you or acknowledgment of Max's previous response
- The message is users talking to each other (not to Max)
- The message appears to be part of a conversation between users
- The message is just sharing information without asking anything
IMPORTANT: Respond with ONLY a JSON object and NOTHING ELSE:
{{
"should_reply": true/false
}}
"""
        try:
            router_logger.debug("Invoking LLM to check if message requires a response")
            response = await self._invoke_llm(llm_client, prompt)
            expected_keys = ["should_reply"]
            default_values = {"should_reply": True}  # Default to replying if unsure
            result = await self._process_llm_response(
                response, expected_keys, default_values
            )
            router_logger.debug(
                f"Should reply determination: {result.get('should_reply', True)}"
            )
            return result.get("should_reply", True)
        except Exception as e:
            router_logger.error(
                f"Error determining if bot should reply: {e}", exc_info=True
            )
            # Default to replying if there's an error
            return True