From b01b4efc2e20b9f4485c77a8519b40f6a2e9ea21 Mon Sep 17 00:00:00 2001 From: Rohan Marwaha <123789373+rohan-uiuc@users.noreply.github.com> Date: Sat, 6 Apr 2024 22:00:36 +0530 Subject: [PATCH] Add `doc_groups` filtering support in vector retrieval (#239) * Add support for filtering documents by doc_groups in retrieval service * Fix filter bug and handle doc_groups as JSON string * Increase Qdrant timeout from defualt 5s to 20s. Getting timeout err w/ doc groups. * Fix: Don't use mutable datastructures (lists) as default arguments --------- Co-authored-by: Kastan Day --- ai_ta_backend/database/vector.py | 23 +++++++++++++++------- ai_ta_backend/main.py | 9 ++++++++- ai_ta_backend/service/retrieval_service.py | 19 ++++++++++++++---- 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/ai_ta_backend/database/vector.py b/ai_ta_backend/database/vector.py index d22fc6ca..70a212e1 100644 --- a/ai_ta_backend/database/vector.py +++ b/ai_ta_backend/database/vector.py @@ -1,4 +1,5 @@ import os +from typing import List from injector import inject from langchain.embeddings.openai import OpenAIEmbeddings @@ -22,19 +23,27 @@ def __init__(self): self.qdrant_client = QdrantClient( url=os.environ['QDRANT_URL'], api_key=os.environ['QDRANT_API_KEY'], + timeout=20, # default is 5 seconds. Getting timeout errors w/ document groups. ) - self.vectorstore = Qdrant(client=self.qdrant_client, - collection_name=os.environ['QDRANT_COLLECTION_NAME'], - embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)) + self.vectorstore = Qdrant( + client=self.qdrant_client, + collection_name=os.environ['QDRANT_COLLECTION_NAME'], + embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE), + ) - def vector_search(self, search_query, course_name, user_query_embedding, top_n): + def vector_search(self, search_query, course_name, doc_groups: List[str], user_query_embedding, top_n): """ Search the vector database for a given query. """ - myfilter = models.Filter(must=[ - models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)), - ]) + # print(f"Searching for: {search_query} with doc_groups: {doc_groups}") + must_conditions: list[models.Condition] = [ + models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)) + ] + if doc_groups and doc_groups != []: + must_conditions.append(models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups))) + myfilter = models.Filter(must=must_conditions) + print(f"Filter: {myfilter}") search_results = self.qdrant_client.search( collection_name=os.environ['QDRANT_COLLECTION_NAME'], query_filter=myfilter, diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 44cb964e..6058557a 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -1,3 +1,4 @@ +import json import os import time from typing import List @@ -105,6 +106,7 @@ def getTopContexts(service: RetrievalService) -> Response: search_query: str = request.args.get('search_query', default='', type=str) course_name: str = request.args.get('course_name', default='', type=str) token_limit: int = request.args.get('token_limit', default=3000, type=int) + doc_groups_str: str = request.args.get('doc_groups', default='[]', type=str) if search_query == '' or course_name == '': # proper web error "400 Bad request" abort( @@ -113,7 +115,12 @@ def getTopContexts(service: RetrievalService) -> Response: f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. Search query: `{search_query}`, Course name: `{course_name}`" ) - found_documents = service.getTopContexts(search_query, course_name, token_limit) + doc_groups: List[str] = [] + + if doc_groups_str != '[]': + doc_groups = json.loads(doc_groups_str) + + found_documents = service.getTopContexts(search_query, course_name, token_limit, doc_groups) response = jsonify(found_documents) response.headers.add('Access-Control-Allow-Origin', '*') diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index af425218..158c7ce6 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -53,7 +53,11 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos openai_api_type=os.environ['OPENAI_API_TYPE'], ) - def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]: + def getTopContexts(self, + search_query: str, + course_name: str, + token_limit: int = 4_000, + doc_groups: List[str] | None = None) -> Union[List[Dict], str]: """Here's a summary of the work. /GET arguments @@ -64,10 +68,14 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int = or String: An error message with traceback. """ + if doc_groups is None: + doc_groups = [] try: start_time_overall = time.monotonic() - found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name) + found_docs: list[Document] = self.vector_search(search_query=search_query, + course_name=course_name, + doc_groups=doc_groups) pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n" # count tokens at start and end, then also count each context. @@ -339,7 +347,9 @@ def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, print(f"Supabase Error in delete. {identifier_key}: {identifier_value}", e) self.sentry.capture_exception(e) - def vector_search(self, search_query, course_name): + def vector_search(self, search_query, course_name, doc_groups: List[str] | None = None): + if doc_groups is None: + doc_groups = [] top_n = 80 # EMBED openai_start_time = time.monotonic() @@ -352,10 +362,11 @@ def vector_search(self, search_query, course_name): properties={ "user_query": search_query, "course_name": course_name, + "doc_groups": doc_groups, }, ) qdrant_start_time = time.monotonic() - search_results = self.vdb.vector_search(search_query, course_name, user_query_embedding, top_n) + search_results = self.vdb.vector_search(search_query, course_name, doc_groups, user_query_embedding, top_n) found_docs: list[Document] = [] for d in search_results: