Skip to content

Commit

Permalink
Fix filter bug and handle doc_groups as JSON string
Browse files Browse the repository at this point in the history
  • Loading branch information
rohan-uiuc committed Mar 28, 2024
1 parent 9570a84 commit 5da717c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 4 additions & 2 deletions ai_ta_backend/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q
"""
Search the vector database for a given query.
"""
must_conditions: List[models.Condition] = [
# print(f"Searching for: {search_query} with doc_groups: {doc_groups}")
must_conditions: list[models.Condition] = [
models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name))
]
if doc_groups:
if doc_groups and doc_groups != []:
must_conditions.append(models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups)))
myfilter = models.Filter(must=must_conditions)
print(f"Filter: {myfilter}")
search_results = self.qdrant_client.search(
collection_name=os.environ['QDRANT_COLLECTION_NAME'],
query_filter=myfilter,
Expand Down
8 changes: 7 additions & 1 deletion ai_ta_backend/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import time
from typing import List
Expand Down Expand Up @@ -103,7 +104,7 @@ def getTopContexts(service: RetrievalService) -> Response:
search_query: str = request.args.get('search_query', default='', type=str)
course_name: str = request.args.get('course_name', default='', type=str)
token_limit: int = request.args.get('token_limit', default=3000, type=int)
doc_groups: List[str] = request.args.get('doc_groups', default=[], type=List[str])
doc_groups_str: str = request.args.get('doc_groups', default='[]', type=str)
if search_query == '' or course_name == '':
# proper web error "400 Bad request"
abort(
Expand All @@ -112,6 +113,11 @@ def getTopContexts(service: RetrievalService) -> Response:
f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. Search query: `{search_query}`, Course name: `{course_name}`"
)

doc_groups: List[str] = []

if doc_groups_str != '[]':
doc_groups = json.loads(doc_groups_str)

found_documents = service.getTopContexts(search_query, course_name, token_limit, doc_groups)

response = jsonify(found_documents)
Expand Down

0 comments on commit 5da717c

Please sign in to comment.