From 9570a84042bb965efb00793f0371a7294054d964 Mon Sep 17 00:00:00 2001 From: rohanmarwaha Date: Wed, 27 Mar 2024 18:44:35 -0500 Subject: [PATCH 1/4] Add support for filtering documents by doc_groups in retrieval service --- ai_ta_backend/database/vector.py | 12 ++++++++---- ai_ta_backend/main.py | 3 ++- ai_ta_backend/service/retrieval_service.py | 15 +++++++++++---- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/ai_ta_backend/database/vector.py b/ai_ta_backend/database/vector.py index d22fc6ca..b45c78d6 100644 --- a/ai_ta_backend/database/vector.py +++ b/ai_ta_backend/database/vector.py @@ -1,4 +1,5 @@ import os +from typing import List from injector import inject from langchain.embeddings.openai import OpenAIEmbeddings @@ -28,13 +29,16 @@ def __init__(self): collection_name=os.environ['QDRANT_COLLECTION_NAME'], embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)) - def vector_search(self, search_query, course_name, user_query_embedding, top_n): + def vector_search(self, search_query, course_name, doc_groups: List[str], user_query_embedding, top_n): """ Search the vector database for a given query. 
""" - myfilter = models.Filter(must=[ - models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)), - ]) + must_conditions: List[models.Condition] = [ + models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)) + ] + if doc_groups: + must_conditions.append(models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups))) + myfilter = models.Filter(must=must_conditions) search_results = self.qdrant_client.search( collection_name=os.environ['QDRANT_COLLECTION_NAME'], query_filter=myfilter, diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 77bfeea5..28ab8f5d 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -103,6 +103,7 @@ def getTopContexts(service: RetrievalService) -> Response: search_query: str = request.args.get('search_query', default='', type=str) course_name: str = request.args.get('course_name', default='', type=str) token_limit: int = request.args.get('token_limit', default=3000, type=int) + doc_groups: List[str] = request.args.get('doc_groups', default=[], type=List[str]) if search_query == '' or course_name == '': # proper web error "400 Bad request" abort( @@ -111,7 +112,7 @@ def getTopContexts(service: RetrievalService) -> Response: f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. 
Search query: `{search_query}`, Course name: `{course_name}`" ) - found_documents = service.getTopContexts(search_query, course_name, token_limit) + found_documents = service.getTopContexts(search_query, course_name, token_limit, doc_groups) response = jsonify(found_documents) response.headers.add('Access-Control-Allow-Origin', '*') diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index af425218..eb9fb2ba 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -53,7 +53,11 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos openai_api_type=os.environ['OPENAI_API_TYPE'], ) - def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]: + def getTopContexts(self, + search_query: str, + course_name: str, + token_limit: int = 4_000, + doc_groups: List[str] = []) -> Union[List[Dict], str]: """Here's a summary of the work. /GET arguments @@ -67,7 +71,9 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int = try: start_time_overall = time.monotonic() - found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name) + found_docs: list[Document] = self.vector_search(search_query=search_query, + course_name=course_name, + doc_groups=doc_groups) pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. 
\nHere's a few passages of the high quality documents:\n" # count tokens at start and end, then also count each context. @@ -339,7 +345,7 @@ def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, print(f"Supabase Error in delete. {identifier_key}: {identifier_value}", e) self.sentry.capture_exception(e) - def vector_search(self, search_query, course_name): + def vector_search(self, search_query, course_name, doc_groups: List[str] = []): top_n = 80 # EMBED openai_start_time = time.monotonic() @@ -352,10 +358,11 @@ def vector_search(self, search_query, course_name): properties={ "user_query": search_query, "course_name": course_name, + "doc_groups": doc_groups, }, ) qdrant_start_time = time.monotonic() - search_results = self.vdb.vector_search(search_query, course_name, user_query_embedding, top_n) + search_results = self.vdb.vector_search(search_query, course_name, doc_groups, user_query_embedding, top_n) found_docs: list[Document] = [] for d in search_results: From 5da717cceec1ad5ab098bd261bdb618cd56b921a Mon Sep 17 00:00:00 2001 From: rohanmarwaha Date: Thu, 28 Mar 2024 16:00:00 -0500 Subject: [PATCH 2/4] Fix filter bug and handle doc_groups as JSON string --- ai_ta_backend/database/vector.py | 6 ++++-- ai_ta_backend/main.py | 8 +++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/ai_ta_backend/database/vector.py b/ai_ta_backend/database/vector.py index b45c78d6..e7f2d5e8 100644 --- a/ai_ta_backend/database/vector.py +++ b/ai_ta_backend/database/vector.py @@ -33,12 +33,14 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q """ Search the vector database for a given query. 
""" - must_conditions: List[models.Condition] = [ + # print(f"Searching for: {search_query} with doc_groups: {doc_groups}") + must_conditions: list[models.Condition] = [ models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)) ] - if doc_groups: + if doc_groups and doc_groups != []: must_conditions.append(models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups))) myfilter = models.Filter(must=must_conditions) + print(f"Filter: {myfilter}") search_results = self.qdrant_client.search( collection_name=os.environ['QDRANT_COLLECTION_NAME'], query_filter=myfilter, diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 28ab8f5d..c11049ca 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -1,3 +1,4 @@ +import json import os import time from typing import List @@ -103,7 +104,7 @@ def getTopContexts(service: RetrievalService) -> Response: search_query: str = request.args.get('search_query', default='', type=str) course_name: str = request.args.get('course_name', default='', type=str) token_limit: int = request.args.get('token_limit', default=3000, type=int) - doc_groups: List[str] = request.args.get('doc_groups', default=[], type=List[str]) + doc_groups_str: str = request.args.get('doc_groups', default='[]', type=str) if search_query == '' or course_name == '': # proper web error "400 Bad request" abort( @@ -112,6 +113,11 @@ def getTopContexts(service: RetrievalService) -> Response: f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. 
Search query: `{search_query}`, Course name: `{course_name}`" ) + doc_groups: List[str] = [] + if doc_groups_str != '[]': + doc_groups = json.loads(doc_groups_str) + found_documents = service.getTopContexts(search_query, course_name, token_limit, doc_groups) response = jsonify(found_documents) From 65b642e2ef968aa545d78256b068ea0e71bf18a4 Mon Sep 17 00:00:00 2001 From: Kastan Day Date: Mon, 1 Apr 2024 13:54:18 -0700 Subject: [PATCH 3/4] Increase Qdrant timeout from default 5s to 20s. Getting timeout err w/ doc groups. --- ai_ta_backend/database/vector.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ai_ta_backend/database/vector.py b/ai_ta_backend/database/vector.py index e7f2d5e8..70a212e1 100644 --- a/ai_ta_backend/database/vector.py +++ b/ai_ta_backend/database/vector.py @@ -23,11 +23,14 @@ def __init__(self): self.qdrant_client = QdrantClient( url=os.environ['QDRANT_URL'], api_key=os.environ['QDRANT_API_KEY'], + timeout=20, # default is 5 seconds. Getting timeout errors w/ document groups. 
) - self.vectorstore = Qdrant(client=self.qdrant_client, - collection_name=os.environ['QDRANT_COLLECTION_NAME'], - embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)) + self.vectorstore = Qdrant( + client=self.qdrant_client, + collection_name=os.environ['QDRANT_COLLECTION_NAME'], + embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE), + ) def vector_search(self, search_query, course_name, doc_groups: List[str], user_query_embedding, top_n): """ From 217157d5e55837099ff2e6f9c6f49418f9f95aba Mon Sep 17 00:00:00 2001 From: Kastan Day Date: Mon, 1 Apr 2024 13:57:23 -0700 Subject: [PATCH 4/4] Fix: Don't use mutable data structures (lists) as default arguments --- ai_ta_backend/service/retrieval_service.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index eb9fb2ba..158c7ce6 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -57,7 +57,7 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000, - doc_groups: List[str] = []) -> Union[List[Dict], str]: + doc_groups: List[str] | None = None) -> Union[List[Dict], str]: """Here's a summary of the work. /GET arguments @@ -68,6 +68,8 @@ def getTopContexts(self, or String: An error message with traceback. """ + if doc_groups is None: + doc_groups = [] try: start_time_overall = time.monotonic() @@ -345,7 +347,9 @@ def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, print(f"Supabase Error in delete. {identifier_key}: {identifier_value}", e) self.sentry.capture_exception(e) - def vector_search(self, search_query, course_name, doc_groups: List[str] = []): + def vector_search(self, search_query, course_name, doc_groups: List[str] | None = None): + if doc_groups is None: + doc_groups = [] top_n = 80 # EMBED openai_start_time = time.monotonic()