diff --git a/ai_ta_backend/database/vector.py b/ai_ta_backend/database/vector.py index b45c78d6..e7f2d5e8 100644 --- a/ai_ta_backend/database/vector.py +++ b/ai_ta_backend/database/vector.py @@ -33,12 +33,14 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q """ Search the vector database for a given query. """ - must_conditions: List[models.Condition] = [ + # print(f"Searching for: {search_query} with doc_groups: {doc_groups}") + must_conditions: list[models.Condition] = [ models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)) ] - if doc_groups: + if doc_groups and doc_groups != []: must_conditions.append(models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups))) myfilter = models.Filter(must=must_conditions) + print(f"Filter: {myfilter}") search_results = self.qdrant_client.search( collection_name=os.environ['QDRANT_COLLECTION_NAME'], query_filter=myfilter, diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 28ab8f5d..c11049ca 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -1,3 +1,4 @@ +import json import os import time from typing import List @@ -103,7 +104,7 @@ def getTopContexts(service: RetrievalService) -> Response: search_query: str = request.args.get('search_query', default='', type=str) course_name: str = request.args.get('course_name', default='', type=str) token_limit: int = request.args.get('token_limit', default=3000, type=int) - doc_groups: List[str] = request.args.get('doc_groups', default=[], type=List[str]) + doc_groups_str: str = request.args.get('doc_groups', default='[]', type=str) if search_query == '' or course_name == '': # proper web error "400 Bad request" abort( @@ -112,6 +113,11 @@ def getTopContexts(service: RetrievalService) -> Response: f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. Search query: `{search_query}`, Course name: `{course_name}`" ) + doc_groups: List[str] = [] + + if doc_groups_str != '[]': + doc_groups = json.loads(doc_groups_str) + found_documents = service.getTopContexts(search_query, course_name, token_limit, doc_groups) response = jsonify(found_documents)