diff --git a/ai_ta_backend/database/vector.py b/ai_ta_backend/database/vector.py index d22fc6ca..b45c78d6 100644 --- a/ai_ta_backend/database/vector.py +++ b/ai_ta_backend/database/vector.py @@ -1,4 +1,5 @@ import os +from typing import List from injector import inject from langchain.embeddings.openai import OpenAIEmbeddings @@ -28,13 +29,16 @@ def __init__(self): collection_name=os.environ['QDRANT_COLLECTION_NAME'], embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)) - def vector_search(self, search_query, course_name, user_query_embedding, top_n): + def vector_search(self, search_query, course_name, doc_groups: List[str], user_query_embedding, top_n): """ Search the vector database for a given query. """ - myfilter = models.Filter(must=[ - models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)), - ]) + must_conditions: List[models.Condition] = [ + models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)) + ] + if doc_groups: + must_conditions.append(models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups))) + myfilter = models.Filter(must=must_conditions) search_results = self.qdrant_client.search( collection_name=os.environ['QDRANT_COLLECTION_NAME'], query_filter=myfilter, diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 77bfeea5..28ab8f5d 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -103,6 +103,7 @@ def getTopContexts(service: RetrievalService) -> Response: search_query: str = request.args.get('search_query', default='', type=str) course_name: str = request.args.get('course_name', default='', type=str) token_limit: int = request.args.get('token_limit', default=3000, type=int) + doc_groups: List[str] = request.args.get('doc_groups', default=[], type=List[str]) if search_query == '' or course_name == '': # proper web error "400 Bad request" abort( @@ -111,7 +112,7 @@ def getTopContexts(service: RetrievalService) -> Response: f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. Search query: `{search_query}`, Course name: `{course_name}`" ) - found_documents = service.getTopContexts(search_query, course_name, token_limit) + found_documents = service.getTopContexts(search_query, course_name, token_limit, doc_groups) response = jsonify(found_documents) response.headers.add('Access-Control-Allow-Origin', '*') diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index af425218..eb9fb2ba 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -53,7 +53,11 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos openai_api_type=os.environ['OPENAI_API_TYPE'], ) - def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]: + def getTopContexts(self, + search_query: str, + course_name: str, + token_limit: int = 4_000, + doc_groups: List[str] = []) -> Union[List[Dict], str]: """Here's a summary of the work. /GET arguments @@ -67,7 +71,9 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int = try: start_time_overall = time.monotonic() - found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name) + found_docs: list[Document] = self.vector_search(search_query=search_query, + course_name=course_name, + doc_groups=doc_groups) pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n" # count tokens at start and end, then also count each context. @@ -339,7 +345,7 @@ def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, print(f"Supabase Error in delete. {identifier_key}: {identifier_value}", e) self.sentry.capture_exception(e) - def vector_search(self, search_query, course_name): + def vector_search(self, search_query, course_name, doc_groups: List[str] = []): top_n = 80 # EMBED openai_start_time = time.monotonic() @@ -352,10 +358,11 @@ def vector_search(self, search_query, course_name): properties={ "user_query": search_query, "course_name": course_name, + "doc_groups": doc_groups, }, ) qdrant_start_time = time.monotonic() - search_results = self.vdb.vector_search(search_query, course_name, user_query_embedding, top_n) + search_results = self.vdb.vector_search(search_query, course_name, doc_groups, user_query_embedding, top_n) found_docs: list[Document] = [] for d in search_results: