diff --git a/ai_ta_backend/database/vector.py b/ai_ta_backend/database/vector.py
index d22fc6ca..70a212e1 100644
--- a/ai_ta_backend/database/vector.py
+++ b/ai_ta_backend/database/vector.py
@@ -1,4 +1,5 @@
 import os
+from typing import List

 from injector import inject
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -22,19 +23,27 @@ def __init__(self):
     self.qdrant_client = QdrantClient(
         url=os.environ['QDRANT_URL'],
         api_key=os.environ['QDRANT_API_KEY'],
+        timeout=20,  # default is 5 seconds. Getting timeout errors w/ document groups.
     )
-    self.vectorstore = Qdrant(client=self.qdrant_client,
-                              collection_name=os.environ['QDRANT_COLLECTION_NAME'],
-                              embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE))
+    self.vectorstore = Qdrant(
+        client=self.qdrant_client,
+        collection_name=os.environ['QDRANT_COLLECTION_NAME'],
+        embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE),
+    )

-  def vector_search(self, search_query, course_name, user_query_embedding, top_n):
+  def vector_search(self, search_query, course_name, doc_groups: List[str], user_query_embedding, top_n):
     """
     Search the vector database for a given query.
     """
-    myfilter = models.Filter(must=[
-        models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)),
-    ])
+    # print(f"Searching for: {search_query} with doc_groups: {doc_groups}")
+    must_conditions: list[models.Condition] = [
+        models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name))
+    ]
+    if doc_groups and doc_groups != []:
+      must_conditions.append(models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups)))
+    myfilter = models.Filter(must=must_conditions)
+    print(f"Filter: {myfilter}")

     search_results = self.qdrant_client.search(
         collection_name=os.environ['QDRANT_COLLECTION_NAME'],
         query_filter=myfilter,
diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py
index 77bfeea5..c11049ca 100644
--- a/ai_ta_backend/main.py
+++ b/ai_ta_backend/main.py
@@ -1,3 +1,4 @@
+import json
 import os
 import time
 from typing import List
@@ -103,6 +104,7 @@ def getTopContexts(service: RetrievalService) -> Response:
   search_query: str = request.args.get('search_query', default='', type=str)
   course_name: str = request.args.get('course_name', default='', type=str)
   token_limit: int = request.args.get('token_limit', default=3000, type=int)
+  doc_groups_str: str = request.args.get('doc_groups', default='[]', type=str)
   if search_query == '' or course_name == '':
     # proper web error "400 Bad request"
     abort(
@@ -111,7 +113,12 @@
         f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. Search query: `{search_query}`, Course name: `{course_name}`"
     )

-  found_documents = service.getTopContexts(search_query, course_name, token_limit)
+  doc_groups: List[str] = []
+
+  if doc_groups_str != '[]':
+    doc_groups = json.loads(doc_groups_str)
+
+  found_documents = service.getTopContexts(search_query, course_name, token_limit, doc_groups)

   response = jsonify(found_documents)
   response.headers.add('Access-Control-Allow-Origin', '*')
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index af425218..158c7ce6 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -53,7 +53,11 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos
         openai_api_type=os.environ['OPENAI_API_TYPE'],
     )

-  def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]:
+  def getTopContexts(self,
+                     search_query: str,
+                     course_name: str,
+                     token_limit: int = 4_000,
+                     doc_groups: List[str] | None = None) -> Union[List[Dict], str]:
     """Here's a summary of the work.

     /GET arguments
@@ -64,10 +68,14 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
       or
       String: An error message with traceback.
     """
+    if doc_groups is None:
+      doc_groups = []
     try:
       start_time_overall = time.monotonic()

-      found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name)
+      found_docs: list[Document] = self.vector_search(search_query=search_query,
+                                                      course_name=course_name,
+                                                      doc_groups=doc_groups)

       pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n"
       # count tokens at start and end, then also count each context.
@@ -339,7 +347,9 @@ def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str,
       print(f"Supabase Error in delete. {identifier_key}: {identifier_value}", e)
       self.sentry.capture_exception(e)

-  def vector_search(self, search_query, course_name):
+  def vector_search(self, search_query, course_name, doc_groups: List[str] | None = None):
+    if doc_groups is None:
+      doc_groups = []
     top_n = 80
     # EMBED
     openai_start_time = time.monotonic()
@@ -352,10 +362,11 @@ def vector_search(self, search_query, course_name):
         properties={
             "user_query": search_query,
             "course_name": course_name,
+            "doc_groups": doc_groups,
         },
     )
     qdrant_start_time = time.monotonic()
-    search_results = self.vdb.vector_search(search_query, course_name, user_query_embedding, top_n)
+    search_results = self.vdb.vector_search(search_query, course_name, doc_groups, user_query_embedding, top_n)

     found_docs: list[Document] = []
     for d in search_results:
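
Note on the filter change in `vector.py`: below is a minimal, self-contained sketch of what the new `must_conditions` logic builds, using the same `qdrant_client` models as the diff. The concrete values (`CS101`, the group names) are hypothetical stand-ins; in the diff they come from request arguments.

```python
from qdrant_client.http import models

# Hypothetical inputs; the diff reads these from the HTTP request.
course_name = "CS101"
doc_groups = ["Lecture Slides", "Homework"]

must_conditions: list[models.Condition] = [
    models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)),
]
if doc_groups:  # an empty list keeps the search course-wide
  must_conditions.append(models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups)))

# A point matches only if its course_name payload equals "CS101" AND its
# doc_groups payload shares at least one value with the requested groups.
myfilter = models.Filter(must=must_conditions)
```

Skipping the condition when the list is empty is what preserves the old behavior: `MatchAny(any=[])` would match nothing, so the guard is load-bearing rather than an optimization. (The diff's `doc_groups and doc_groups != []` is equivalent to the plain truthiness check used here.)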
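For end-to-end testing, here is a sketch of how a client might exercise the updated endpoint. The route path `/getTopContexts` and the base URL are assumptions (the route decorator is outside this diff); what the diff does confirm is that `doc_groups` must arrive as a JSON-encoded list, since `main.py` parses it with `json.loads(doc_groups_str)`.

```python
import json

import requests  # any HTTP client works; requests is assumed here

BASE_URL = "http://localhost:8000"  # hypothetical deployment URL

params = {
    "search_query": "What is a Moore machine?",  # hypothetical query
    "course_name": "ECE120",  # hypothetical course
    "token_limit": 3000,
    # JSON-encode the list; the server parses it with json.loads().
    "doc_groups": json.dumps(["Lecture Slides", "Textbook"]),
}

resp = requests.get(f"{BASE_URL}/getTopContexts", params=params)
resp.raise_for_status()
print(resp.json())  # a list of context dicts on success, or an error string
```

Omitting `doc_groups` (or sending `[]`, the server-side default) falls through to the unfiltered, course-wide search path.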