
Commit

created a separate endpoint for multi-query-retrieval
star-nox committed Nov 29, 2023
1 parent 5d19aa9 commit 9aa030c
Showing 2 changed files with 84 additions and 14 deletions.
25 changes: 24 additions & 1 deletion ai_ta_backend/main.py
@@ -130,7 +130,7 @@ def getTopContexts() -> Response:
Exception
Testing how exceptions are handled.
"""
print("In getRopContexts in Main()")
print("In getTopContexts in Main()")
search_query: str = request.args.get('search_query', default='', type=str)
course_name: str = request.args.get('course_name', default='', type=str)
token_limit: int = request.args.get('token_limit', default=3000, type=int)
@@ -150,6 +150,29 @@ def getTopContexts() -> Response:
response.headers.add('Access-Control-Allow-Origin', '*')
return response

@app.route('/getTopContextsWithMQR', methods=['GET'])
def getTopContextsWithMQR() -> Response:
"""
Get relevant contexts for a given search query, using Multi-query retrieval + filtering method.
"""
search_query: str = request.args.get('search_query', default='', type=str)
course_name: str = request.args.get('course_name', default='', type=str)
token_limit: int = request.args.get('token_limit', default=3000, type=int)
if search_query == '' or course_name == '':
# proper web error "400 Bad request"
abort(
400,
description=
f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. Search query: `{search_query}`, Course name: `{course_name}`"
)

ingester = Ingest()
found_documents = ingester.getTopContextsWithMQR(search_query, course_name, token_limit)
del ingester

response = jsonify(found_documents)
response.headers.add('Access-Control-Allow-Origin', '*')
return response


@app.route('/get_stuffed_prompt', methods=['GET'])
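For context, the new /getTopContextsWithMQR route added above is a plain Flask GET endpoint that mirrors the existing /getTopContexts signature. Below is a minimal sketch of calling it from a client; the base URL is an assumption for illustration and is not defined anywhere in this commit.

import requests

# Hypothetical base URL; substitute wherever the Flask app in main.py is actually served.
BASE_URL = "http://localhost:5000"

params = {
    "search_query": "What is reciprocal rank fusion?",
    "course_name": "example-course",   # placeholder course name
    "token_limit": 3000,               # optional; the route defaults to 3000
}

resp = requests.get(f"{BASE_URL}/getTopContextsWithMQR", params=params)
resp.raise_for_status()
contexts = resp.json()  # list of context dicts from format_for_json, or an error string on failure
print(len(contexts))

Omitting search_query or course_name returns a 400, matching the abort() call in the route.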
73 changes: 60 additions & 13 deletions ai_ta_backend/vector_database.py
@@ -81,15 +81,16 @@ def __init__(self):
supabase_url=os.environ['SUPABASE_URL'],
supabase_key=os.environ['SUPABASE_API_KEY'])

# self.llm = AzureChatOpenAI(
# temperature=0,
# deployment_name=os.getenv('AZURE_OPENAI_ENGINE'), #type:ignore
# openai_api_base=os.getenv('AZURE_OPENAI_ENDPOINT'), #type:ignore
# openai_api_key=os.getenv('AZURE_OPENAI_KEY'), #type:ignore
# openai_api_version=os.getenv('AZURE_OPENAI_API_VERSION'), #type:ignore
# )
self.llm = AzureChatOpenAI(
temperature=0,
deployment_name=os.getenv('AZURE_OPENAI_ENGINE'), #type:ignore
openai_api_base=os.getenv('AZURE_OPENAI_ENDPOINT'), #type:ignore
openai_api_key=os.getenv('AZURE_OPENAI_KEY'), #type:ignore
#openai_api_version=os.getenv('AZURE_OPENAI_API_VERSION'), #type:ignore
openai_api_version=os.getenv('OPENAI_API_VERSION'), #type:ignore
)
# self.llm = OpenAI(temperature=0, openai_api_base='https://api.kastan.ai/v1')
self.llm = ChatOpenAI(temperature=0, model='gpt-3.5-turbo')
#self.llm = ChatOpenAI(temperature=0, model='gpt-3.5-turbo')

def bulk_ingest(self, s3_paths: Union[List[str], str], course_name: str, **kwargs) -> Dict[str, List[str]]:
def _ingest_single(ingest_method: Callable, s3_path, *args, **kwargs):
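The hunk above switches the active LLM from ChatOpenAI to AzureChatOpenAI, so the constructor now depends on several environment variables named in the diff. A small sketch, not part of the commit, that fails fast when any of them is missing; the variable names come from the diff, while the guard itself is an assumption about how one might validate startup configuration.

import os

# Env vars read by the AzureChatOpenAI setup shown above (names taken from the diff).
REQUIRED_AZURE_VARS = [
    "AZURE_OPENAI_ENGINE",    # deployment_name
    "AZURE_OPENAI_ENDPOINT",  # openai_api_base
    "AZURE_OPENAI_KEY",       # openai_api_key
    "OPENAI_API_VERSION",     # openai_api_version (the AZURE_OPENAI_API_VERSION line is now commented out)
]

missing = [name for name in REQUIRED_AZURE_VARS if not os.getenv(name)]
if missing:
    raise RuntimeError(f"Missing Azure OpenAI configuration: {', '.join(missing)}")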
@@ -1283,9 +1284,11 @@ def reciprocal_rank_fusion(self, results: list[list], k=60):
for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
]
return reranked_results

def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]:
"""Here's a summary of the work.
"""
The original info-retrieval pipeline that uses vector search.
Here's a summary of the work.
/GET arguments
course name (optional) str: A json response with TBD fields.
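The tail of reciprocal_rank_fusion shown above sorts documents by their fused score. For readers unfamiliar with the technique, here is a standalone sketch of standard reciprocal rank fusion scoring (each ranked list contributes 1/(k + rank) per document, with k=60 as in the method signature); it is an illustration, not the repository's exact implementation.

from collections import defaultdict

def rrf(ranked_lists: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
    """Fuse several ranked lists of doc ids into one ranking, best first."""
    fused_scores: dict[str, float] = defaultdict(float)
    for ranked in ranked_lists:
        for rank, doc_id in enumerate(ranked, start=1):
            fused_scores[doc_id] += 1.0 / (k + rank)
    return sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)

# Example: doc "b" ranks highly in both lists, so it comes out on top.
print(rrf([["a", "b", "c"], ["b", "d"]]))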
@@ -1299,6 +1302,51 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
top_n = 80 # HARD CODE TO ENSURE WE HIT THE MAX TOKENS
start_time_overall = time.monotonic()

found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name)

pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n"
# count tokens at start and end, then also count each context.
token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + search_query) # type: ignore

valid_docs = []
num_tokens = 0
for doc in found_docs:
doc_string = f"Document: {doc.metadata['readable_filename']}{', page: ' + str(doc.metadata['pagenumber']) if doc.metadata['pagenumber'] else ''}\n{str(doc.page_content)}\n"
num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore

print(f"token_counter: {token_counter}, num_tokens: {num_tokens}, max_tokens: {token_limit}")
if token_counter + num_tokens <= token_limit:
token_counter += num_tokens
valid_docs.append(doc)
else:
# filled our token size, time to return
break

print(f"Total tokens used: {token_counter} total docs: {len(found_docs)} num docs used: {len(valid_docs)}")
print(f"Course: {course_name} ||| search_query: {search_query}")
print(f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds")
if len(valid_docs) == 0:
return []
return self.format_for_json(valid_docs)
except Exception as e:
# return full traceback to front end
err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore
print(err)
return err


def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]:
"""
New info-retrieval pipeline that uses multi-query retrieval + reciprocal rank fusion + context padding.
1. Generate multiple queries based on the input search query.
2. Retrieve relevant docs for each query.
3. Rank the docs based on the relevance score.
4. Pad the top 5 docs with context from the original document.
"""
try:
top_n = 80 # HARD CODE TO ENSURE WE HIT THE MAX TOKENS
start_time_overall = time.monotonic()

# Vector search with ONLY original query
# found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name)
mq_start_time = time.monotonic()
@@ -1325,7 +1373,6 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
print(f"⏰ Multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds")

# 'context padding' // 'parent document retriever'
# TODO maybe only do context padding for top 5 docs? Otherwise it's wasteful imo.
final_docs = context_processing(found_docs, search_query, course_name)
print(f"Number of final docs after context padding: {len(final_docs)}")

@@ -1359,13 +1406,13 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =

print(f"Total tokens used: {token_counter} total docs: {len(found_docs)} num docs used: {len(valid_docs)}")
print(f"Course: {course_name} ||| search_query: {search_query}")
print(f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds")
print(f"⏰ ^^ Runtime of getTopContextsWithMQR: {(time.monotonic() - start_time_overall):.2f} seconds")
if len(valid_docs) == 0:
return []
return self.format_for_json(valid_docs)
except Exception as e:
# return full traceback to front end
err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.format_exc()}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore
err: str = f"ERROR: In /getTopContextsWithMQR. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.format_exc()}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore
print(err)
return err
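Putting the pieces of getTopContextsWithMQR together, its docstring describes a four-step pipeline: generate query variants, retrieve per variant, fuse the rankings, then pad the top hits with surrounding context. A compact end-to-end sketch of that flow follows, using hypothetical helpers (generate_query_variants, vector_search, pad_with_context) since their actual implementations are outside this diff; rrf refers to the sketch shown earlier.

def multi_query_retrieve(search_query: str, course_name: str, top_n: int = 80):
    # 1. Generate alternative phrasings of the user's query with an LLM (hypothetical helper).
    queries = [search_query] + generate_query_variants(search_query, n=3)
    # 2. Retrieve a ranked list of document ids for each phrasing (hypothetical helper).
    ranked_lists = [vector_search(q, course_name, top_n=top_n) for q in queries]
    # 3. Fuse the per-query rankings with reciprocal rank fusion.
    fused = rrf(ranked_lists)
    # 4. Pad the top few fused hits with neighboring chunks from their source documents.
    return [pad_with_context(doc_id, course_name) for doc_id, _score in fused[:5]]

In the real endpoint, the padded documents are then trimmed to the token_limit budget and serialized with format_for_json, as the loop in the diff above shows.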

