From b177a30f93bb7c55ee978aba973d2a4fd78b0031 Mon Sep 17 00:00:00 2001
From: Kastan Day
Date: Fri, 8 Nov 2024 18:54:56 +0100
Subject: [PATCH] Faster `getTopContexts()` (#326)

* Reduce top_n from 120 to 60; remove token counting because it's slow and unnecessary

* Remove some of the logs and clean up code; should be a tad faster

* Set PostHog to sync_mode=False; should help quite a bit, shaving off 0.1 sec here and there

* Clean up prints

---
 ai_ta_backend/main.py                      |  35 +++---
 ai_ta_backend/service/posthog_service.py   |   2 +-
 ai_ta_backend/service/retrieval_service.py | 122 +++++++++------------
 3 files changed, 70 insertions(+), 89 deletions(-)

diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py
index a11e9c91..d9f0765d 100644
--- a/ai_ta_backend/main.py
+++ b/ai_ta_backend/main.py
@@ -106,12 +106,12 @@ def getTopContexts(service: RetrievalService) -> Response:
       Exception
         Testing how exceptions are handled.
   """
+  start_time = time.monotonic()
   data = request.get_json()
   search_query: str = data.get('search_query', '')
   course_name: str = data.get('course_name', '')
   token_limit: int = data.get('token_limit', 3000)
   doc_groups: List[str] = data.get('doc_groups', [])
-  start_time = time.monotonic()
 
   if search_query == '' or course_name == '':
     # proper web error "400 Bad request"
@@ -122,9 +122,9 @@ def getTopContexts(service: RetrievalService) -> Response:
     )
 
   found_documents = asyncio.run(service.getTopContexts(search_query, course_name, token_limit, doc_groups))
-  print(f"⏰ Runtime of getTopContexts in main.py: {(time.monotonic() - start_time):.2f} seconds")
 
   response = jsonify(found_documents)
   response.headers.add('Access-Control-Allow-Origin', '*')
+  print(f"⏰ Runtime of getTopContexts in main.py: {(time.monotonic() - start_time):.2f} seconds")
   return response
 
@@ -519,31 +519,34 @@ def switch_workflow(service: WorkflowService) -> Response:
   else:
     abort(400, description=f"Bad request: {e}")
 
+
 @app.route('/getConversationStats', methods=['GET'])
 def get_conversation_stats(service: RetrievalService) -> Response:
-    course_name = request.args.get('course_name', default='', type=str)
+  course_name = request.args.get('course_name', default='', type=str)
 
-    if course_name == '':
-        abort(400, description="Missing required parameter: 'course_name' must be provided.")
+  if course_name == '':
+    abort(400, description="Missing required parameter: 'course_name' must be provided.")
 
-    conversation_stats = service.getConversationStats(course_name)
+  conversation_stats = service.getConversationStats(course_name)
+
+  response = jsonify(conversation_stats)
+  response.headers.add('Access-Control-Allow-Origin', '*')
+  return response
 
-    response = jsonify(conversation_stats)
-    response.headers.add('Access-Control-Allow-Origin', '*')
-    return response
 
+
 @app.route('/getConversationHeatmapByHour', methods=['GET'])
 def get_questions_heatmap_by_hour(service: RetrievalService) -> Response:
-    course_name = request.args.get('course_name', default='', type=str)
+  course_name = request.args.get('course_name', default='', type=str)
 
-    if not course_name:
-        abort(400, description="Missing required parameter: 'course_name' must be provided.")
+  if not course_name:
+    abort(400, description="Missing required parameter: 'course_name' must be provided.")
 
-    heatmap_data = service.getConversationHeatmapByHour(course_name)
+  heatmap_data = service.getConversationHeatmapByHour(course_name)
+
+  response = jsonify(heatmap_data)
+  response.headers.add('Access-Control-Allow-Origin', '*')
+  return response
 
-    response = jsonify(heatmap_data)
-    response.headers.add('Access-Control-Allow-Origin', '*')
-    return response
 
+
 @app.route('/run_flow', methods=['POST'])
 def run_flow(service: WorkflowService) -> Response:
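Note on the two stats endpoints above: both now validate `course_name`, attach the CORS header, and return JSON. A minimal usage sketch with `requests`; the base URL and course name below are placeholders, not values from this patch:

    import requests

    BASE_URL = "http://localhost:8000"  # hypothetical local deployment

    resp = requests.get(f"{BASE_URL}/getConversationStats", params={"course_name": "example-course"})
    resp.raise_for_status()
    stats = resp.json()
    # Keys mirror the service's return value: 'per_day', 'per_hour', 'per_weekday'.
    print(stats['per_weekday'])

A request without `course_name` returns HTTP 400, matching the `abort(400, ...)` calls above.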
diff --git a/ai_ta_backend/service/posthog_service.py b/ai_ta_backend/service/posthog_service.py
index a6b3aaba..97c7dc1d 100644
--- a/ai_ta_backend/service/posthog_service.py
+++ b/ai_ta_backend/service/posthog_service.py
@@ -9,7 +9,7 @@ class PosthogService:
   @inject
   def __init__(self):
     self.posthog = Posthog(
-        sync_mode=True,
+        sync_mode=False,
         project_api_key=os.environ["POSTHOG_API_KEY"],
         host="https://app.posthog.com",
     )
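Note on `sync_mode`: with `sync_mode=True`, each `capture()` call blocks on an HTTP round-trip to PostHog before returning; with `sync_mode=False`, events go onto an in-memory queue and a background consumer thread batches them to the API, so instrumented code paths no longer pay network latency. The trade-off is that queued events can be lost if the process exits before the queue drains. A minimal sketch of the non-blocking client; the event name and properties are illustrative only:

    import os

    from posthog import Posthog

    posthog = Posthog(
        project_api_key=os.environ["POSTHOG_API_KEY"],
        host="https://app.posthog.com",
        sync_mode=False,  # enqueue events; a worker thread sends them in batches
    )

    posthog.capture(distinct_id="backend", event="example_event", properties={"latency_sec": 0.12})
    posthog.flush()  # optional: drain the queue, e.g. before shutdown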
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index 506d4af1..87032c48 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -3,12 +3,12 @@
 import os
 import time
 import traceback
-import pytz
-from typing import Dict, List, Union
-from dateutil import parser
 from collections import defaultdict
+from typing import Dict, List, Union
 
 import openai
+import pytz
+from dateutil import parser
 from injector import inject
 from langchain.chat_models import AzureChatOpenAI
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -21,7 +21,6 @@
 from ai_ta_backend.service.nomic_service import NomicService
 from ai_ta_backend.service.posthog_service import PosthogService
 from ai_ta_backend.service.sentry_service import SentryService
-from ai_ta_backend.utils.utils_tokenization import count_tokens_and_cost
 
 
 class RetrievalService:
@@ -59,11 +58,12 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos
     #     openai_api_type=os.environ['OPENAI_API_TYPE'],
     # )
 
-  async def getTopContexts(self,
-                           search_query: str,
-                           course_name: str,
-                           token_limit: int = 4_000,
-                           doc_groups: List[str] | None = None) -> Union[List[Dict], str]:
+  async def getTopContexts(
+      self,
+      search_query: str,
+      course_name: str,
+      token_limit: int = 4_000,  # Deprecated
+      doc_groups: List[str] | None = None) -> Union[List[Dict], str]:
     """Here's a summary of the work.
 
     /GET arguments
@@ -111,8 +111,8 @@ async def getTopContexts(self,
       public_doc_groups = [doc_group['doc_groups'] for doc_group in public_doc_groups_response.data]
 
     time_for_parallel_operations = time.monotonic() - start_time_overall
-
     start_time_vector_search = time.monotonic()
+
+    # Perform vector search
     found_docs: list[Document] = self.vector_search(search_query=search_query,
                                                     course_name=course_name,
@@ -122,36 +122,15 @@ async def getTopContexts(self,
                                                     public_doc_groups=public_doc_groups)
 
     time_to_retrieve_docs = time.monotonic() - start_time_vector_search
-    start_time_count_tokens = time.monotonic()
-
-    pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n"
-    # count tokens at start and end, then also count each context.
-    token_counter, _ = count_tokens_and_cost(pre_prompt + "\n\nNow please respond to my query: " +  # type: ignore
-                                             search_query)
 
     valid_docs = []
-    num_tokens = 0
     for doc in found_docs:
-      doc_string = f"Document: {doc.metadata['readable_filename']}{', page: ' + str(doc.metadata['pagenumber']) if doc.metadata['pagenumber'] else ''}\n{str(doc.page_content)}\n"
-      num_tokens, prompt_cost = count_tokens_and_cost(doc_string)  # type: ignore
-
-      print(
-          f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc.metadata['readable_filename']}"
-      )
-      if token_counter + num_tokens <= token_limit:
-        token_counter += num_tokens
-        valid_docs.append(doc)
-      else:
-        # filled our token size, time to return
-        break
-
-    time_to_count_tokens = time.monotonic() - start_time_count_tokens
-
-    print(f"Total tokens used: {token_counter}. Docs used: {len(valid_docs)} of {len(found_docs)} docs retrieved")
-    print(f"Course: {course_name} ||| search_query: {search_query}")
-    print(
-        f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds, time to count tokens: {time_to_count_tokens:.2f} seconds, time for parallel operations: {time_for_parallel_operations:.2f} seconds, time to retrieve docs: {time_to_retrieve_docs:.2f} seconds"
-    )
+      valid_docs.append(doc)
+
+    print(f"Course: {course_name} ||| search_query: {search_query}\n"
+          f"⏰ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds\n"
+          f"Runtime for parallel operations: {time_for_parallel_operations:.2f} seconds, "
+          f"Runtime to complete vector_search: {time_to_retrieve_docs:.2f} seconds")
 
     if len(valid_docs) == 0:
       return []
@@ -161,7 +140,7 @@ async def getTopContexts(self,
             "user_query": search_query,
             "course_name": course_name,
             "token_limit": token_limit,
-            "total_tokens_used": token_counter,
+            # "total_tokens_used": token_counter,
             "total_contexts_used": len(valid_docs),
             "total_unique_docs_retrieved": len(found_docs),
             "getTopContext_total_latency_sec": time.monotonic() - start_time_overall,
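Note on `token_limit`: the parameter is still accepted for API compatibility but is now ignored; every retrieved document is passed through, and trimming is left to the caller. If a hard token budget is ever needed again, a cheap greedy filter can stand in for the deleted `count_tokens_and_cost` helper. A sketch using `tiktoken`; the function name and the choice of encoding are assumptions, not code from this repo:

    import tiktoken

    def filter_docs_by_token_limit(found_docs, token_limit=3000, model='gpt-3.5-turbo'):
      """Greedily keep documents until the token budget is exhausted (hypothetical helper)."""
      enc = tiktoken.encoding_for_model(model)
      valid_docs, used = [], 0
      for doc in found_docs:
        n = len(enc.encode(doc.page_content))
        if used + n > token_limit:
          break  # budget filled; return what fits
        used += n
        valid_docs.append(doc)
      return valid_docs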
""" - start_time_overall = time.monotonic() - if doc_groups is None: doc_groups = [] @@ -415,7 +392,7 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q public_doc_groups = [] # Max number of search results to return - top_n = 120 + top_n = 60 # Capture the search invoked event to PostHog self._capture_search_invoked_event(search_query, course_name, doc_groups) @@ -436,10 +413,10 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q self._capture_search_succeeded_event(search_query, course_name, search_results) time_for_capture_search_succeeded_event = time.monotonic() - start_time_capture_search_succeeded_event - print( - f"time for vector search: {time_for_vector_search:.2f} seconds, time for process search results: {time_for_process_search_results:.2f} seconds, time for capture search succeeded event: {time_for_capture_search_succeeded_event:.2f} seconds" - ) - print(f"time for embedding query: {self.openai_embedding_latency:.2f} seconds") + print(f"Runtime for embedding query: {self.openai_embedding_latency:.2f} seconds\n" + f"Runtime for vector search: {time_for_vector_search:.2f} seconds\n" + f"Runtime for process search results: {time_for_process_search_results:.2f} seconds\n" + f"Runtime for capture search succeeded event: {time_for_capture_search_succeeded_event:.2f} seconds") return found_docs def _perform_vector_search(self, search_query, course_name, doc_groups, user_query_embedding, top_n, @@ -485,7 +462,8 @@ def _process_search_results(self, search_results, course_name): def _capture_search_succeeded_event(self, search_query, course_name, search_results): vector_score_calc_latency_sec = time.monotonic() - max_vector_score, min_vector_score, avg_vector_score = self._calculate_vector_scores(search_results) + # Removed because it takes 0.15 seconds to _calculate_vector_scores... not worth it rn. 
@@ -485,7 +462,8 @@ def _process_search_results(self, search_results, course_name):
 
   def _capture_search_succeeded_event(self, search_query, course_name, search_results):
     vector_score_calc_latency_sec = time.monotonic()
-    max_vector_score, min_vector_score, avg_vector_score = self._calculate_vector_scores(search_results)
+    # Removed because it takes 0.15 seconds to _calculate_vector_scores... not worth it rn.
+    # max_vector_score, min_vector_score, avg_vector_score = self._calculate_vector_scores(search_results)
     self.posthog.capture(
         event_name="vector_search_succeeded",
         properties={
@@ -493,9 +471,9 @@ def _capture_search_succeeded_event(self, search_query, course_name, search_resu
             "course_name": course_name,
             "qdrant_latency_sec": self.qdrant_latency_sec,
             "openai_embedding_latency_sec": self.openai_embedding_latency,
-            "max_vector_score": max_vector_score,
-            "min_vector_score": min_vector_score,
-            "avg_vector_score": avg_vector_score,
+            # "max_vector_score": max_vector_score,
+            # "min_vector_score": min_vector_score,
+            # "avg_vector_score": avg_vector_score,
             "vector_score_calculation_latency_sec": time.monotonic() - vector_score_calc_latency_sec,
         },
     )
@@ -569,30 +547,30 @@ def getConversationStats(self, course_name: str):
     }
 
     if response and hasattr(response, 'data') and response.data:
-        for record in response.data:
-            created_at = record['created_at']
+      for record in response.data:
+        created_at = record['created_at']
 
-            parsed_date = parser.parse(created_at)
+        parsed_date = parser.parse(created_at)
 
-            central_time = parsed_date.astimezone(central_tz)
+        central_time = parsed_date.astimezone(central_tz)
 
-            day = central_time.date()
-            hour = central_time.hour
-            day_of_week = central_time.strftime('%A')
+        day = central_time.date()
+        hour = central_time.hour
+        day_of_week = central_time.strftime('%A')
 
-            grouped_data['per_day'][str(day)] += 1
-            grouped_data['per_hour'][hour] += 1
-            grouped_data['per_weekday'][day_of_week] += 1
+        grouped_data['per_day'][str(day)] += 1
+        grouped_data['per_hour'][hour] += 1
+        grouped_data['per_weekday'][day_of_week] += 1
     else:
-        print("No valid response data. Check if the query is correct or if the response is empty.")
-        return {}
+      print("No valid response data. Check if the query is correct or if the response is empty.")
+      return {}
 
     return {
         'per_day': dict(grouped_data['per_day']),
         'per_hour': dict(grouped_data['per_hour']),
        'per_weekday': dict(grouped_data['per_weekday']),
     }
-
+
   def getConversationHeatmapByHour(self, course_name: str):
     """
     Fetches conversation data and groups them into a heatmap by day of the week and hour (Central Time).
@@ -609,18 +587,18 @@ def getConversationHeatmapByHour(self, course_name: str):
     heatmap_data = defaultdict(lambda: defaultdict(int))
 
     if response and hasattr(response, 'data') and response.data:
-        for record in response.data:
-            created_at = record['created_at']
+      for record in response.data:
+        created_at = record['created_at']
 
-            parsed_date = parser.parse(created_at)
-            central_time = parsed_date.astimezone(central_tz)
+        parsed_date = parser.parse(created_at)
+        central_time = parsed_date.astimezone(central_tz)
 
-            day_of_week = central_time.strftime('%A')
-            hour = central_time.hour
+        day_of_week = central_time.strftime('%A')
+        hour = central_time.hour
 
-            heatmap_data[day_of_week][hour] += 1
+        heatmap_data[day_of_week][hour] += 1
     else:
-        print("No valid response data. Check if the query is correct or if the response is empty.")
-        return {}
+      print("No valid response data. Check if the query is correct or if the response is empty.")
+      return {}
 
-    return dict(heatmap_data)
\ No newline at end of file
+    return dict(heatmap_data)
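Note on the commented-out vector scores: `_calculate_vector_scores` was disabled because it added roughly 0.15 seconds per request. If the max/min/avg metrics are ever wanted back, a single O(n) pass over the results keeps the cost negligible at `top_n = 60`. A sketch, assuming each search result exposes a numeric `score` attribute as Qdrant scored points do:

    def score_stats(search_results):
      # One pass: track max, min, and sum together instead of three traversals.
      max_s, min_s, total = float('-inf'), float('inf'), 0.0
      for point in search_results:
        s = float(point.score)
        max_s = max(max_s, s)
        min_s = min(min_s, s)
        total += s
      n = len(search_results)
      return (max_s, min_s, total / n) if n else (None, None, None)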