Faster getTopContexts() (#326)
* Reduce top_n from 120 -> 60; remove token counting because it's slow and unnecessary

* Remove some of the logs and clean up code; should be a tad faster

* Set PostHog to sync_mode=False; should help quite a bit, shaving off 0.1 sec here and there

* Clean up prints
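
For readers skimming the diff, a minimal sketch of what the sync_mode switch means in posthog-python (the API key and event below are invented for illustration; the constructor keywords match the service file changed in this commit):

from posthog import Posthog

# sync_mode=True sends each event over HTTP before capture() returns,
# adding network latency to every code path that records analytics.
# sync_mode=False enqueues the event and a background thread delivers it,
# so capture() returns almost immediately.
posthog = Posthog(
    sync_mode=False,
    project_api_key="phc_example_key",  # hypothetical key
    host="https://app.posthog.com",
)

posthog.capture(
    distinct_id="ai_ta_backend",  # hypothetical distinct_id
    event="vector_search_succeeded",
    properties={"course_name": "example-course"},
)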
KastanDay authored Nov 8, 2024
1 parent 60b1e7b commit b177a30
Showing 3 changed files with 70 additions and 89 deletions.
35 changes: 19 additions & 16 deletions ai_ta_backend/main.py
@@ -106,12 +106,12 @@ def getTopContexts(service: RetrievalService) -> Response:
Exception
Testing how exceptions are handled.
"""
start_time = time.monotonic()
data = request.get_json()
search_query: str = data.get('search_query', '')
course_name: str = data.get('course_name', '')
token_limit: int = data.get('token_limit', 3000)
doc_groups: List[str] = data.get('doc_groups', [])
start_time = time.monotonic()

if search_query == '' or course_name == '':
# proper web error "400 Bad request"
@@ -122,9 +122,9 @@ def getTopContexts(service: RetrievalService) -> Response:
)

found_documents = asyncio.run(service.getTopContexts(search_query, course_name, token_limit, doc_groups))
print(f"⏰ Runtime of getTopContexts in main.py: {(time.monotonic() - start_time):.2f} seconds")
response = jsonify(found_documents)
response.headers.add('Access-Control-Allow-Origin', '*')
print(f"⏰ Runtime of getTopContexts in main.py: {(time.monotonic() - start_time):.2f} seconds")
return response


@@ -519,31 +519,34 @@ def switch_workflow(service: WorkflowService) -> Response:
else:
abort(400, description=f"Bad request: {e}")


@app.route('/getConversationStats', methods=['GET'])
def get_conversation_stats(service: RetrievalService) -> Response:
course_name = request.args.get('course_name', default='', type=str)
course_name = request.args.get('course_name', default='', type=str)

if course_name == '':
abort(400, description="Missing required parameter: 'course_name' must be provided.")
if course_name == '':
abort(400, description="Missing required parameter: 'course_name' must be provided.")

conversation_stats = service.getConversationStats(course_name)
conversation_stats = service.getConversationStats(course_name)

response = jsonify(conversation_stats)
response.headers.add('Access-Control-Allow-Origin', '*')
return response

response = jsonify(conversation_stats)
response.headers.add('Access-Control-Allow-Origin', '*')
return response

@app.route('/getConversationHeatmapByHour', methods=['GET'])
def get_questions_heatmap_by_hour(service: RetrievalService) -> Response:
course_name = request.args.get('course_name', default='', type=str)
course_name = request.args.get('course_name', default='', type=str)

if not course_name:
abort(400, description="Missing required parameter: 'course_name' must be provided.")
if not course_name:
abort(400, description="Missing required parameter: 'course_name' must be provided.")

heatmap_data = service.getConversationHeatmapByHour(course_name)
heatmap_data = service.getConversationHeatmapByHour(course_name)

response = jsonify(heatmap_data)
response.headers.add('Access-Control-Allow-Origin', '*')
return response

response = jsonify(heatmap_data)
response.headers.add('Access-Control-Allow-Origin', '*')
return response

@app.route('/run_flow', methods=['POST'])
def run_flow(service: WorkflowService) -> Response:
2 changes: 1 addition & 1 deletion ai_ta_backend/service/posthog_service.py
@@ -9,7 +9,7 @@ class PosthogService:
@inject
def __init__(self):
self.posthog = Posthog(
sync_mode=True,
sync_mode=False,
project_api_key=os.environ["POSTHOG_API_KEY"],
host="https://app.posthog.com",
)
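
Design note: with sync_mode=False, events sit in an in-memory queue until the background worker sends them, so a clean shutdown should drain that queue. A sketch, assuming posthog-python's flush() client method (in this service the instance is self.posthog):

import atexit

# Drain any queued analytics events before the process exits.
atexit.register(posthog.flush)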
122 changes: 50 additions & 72 deletions ai_ta_backend/service/retrieval_service.py
@@ -3,12 +3,12 @@
import os
import time
import traceback
import pytz
from typing import Dict, List, Union
from dateutil import parser
from collections import defaultdict
from typing import Dict, List, Union

import openai
import pytz
from dateutil import parser
from injector import inject
from langchain.chat_models import AzureChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
@@ -21,7 +21,6 @@
from ai_ta_backend.service.nomic_service import NomicService
from ai_ta_backend.service.posthog_service import PosthogService
from ai_ta_backend.service.sentry_service import SentryService
from ai_ta_backend.utils.utils_tokenization import count_tokens_and_cost


class RetrievalService:
@@ -59,11 +58,12 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos
# openai_api_type=os.environ['OPENAI_API_TYPE'],
# )

async def getTopContexts(self,
search_query: str,
course_name: str,
token_limit: int = 4_000,
doc_groups: List[str] | None = None) -> Union[List[Dict], str]:
async def getTopContexts(
self,
search_query: str,
course_name: str,
token_limit: int = 4_000, # Deprecated
doc_groups: List[str] | None = None) -> Union[List[Dict], str]:
"""Here's a summary of the work.
/GET arguments
@@ -111,8 +111,8 @@ async def getTopContexts(self,
public_doc_groups = [doc_group['doc_groups'] for doc_group in public_doc_groups_response.data]

time_for_parallel_operations = time.monotonic() - start_time_overall

start_time_vector_search = time.monotonic()

# Perform vector search
found_docs: list[Document] = self.vector_search(search_query=search_query,
course_name=course_name,
@@ -122,36 +122,15 @@ async def getTopContexts(self,
public_doc_groups=public_doc_groups)

time_to_retrieve_docs = time.monotonic() - start_time_vector_search
start_time_count_tokens = time.monotonic()

pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n"
# count tokens at start and end, then also count each context.
token_counter, _ = count_tokens_and_cost(pre_prompt + "\n\nNow please respond to my query: " + # type: ignore
search_query)

valid_docs = []
num_tokens = 0
for doc in found_docs:
doc_string = f"Document: {doc.metadata['readable_filename']}{', page: ' + str(doc.metadata['pagenumber']) if doc.metadata['pagenumber'] else ''}\n{str(doc.page_content)}\n"
num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore

print(
f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc.metadata['readable_filename']}"
)
if token_counter + num_tokens <= token_limit:
token_counter += num_tokens
valid_docs.append(doc)
else:
# filled our token size, time to return
break

time_to_count_tokens = time.monotonic() - start_time_count_tokens

print(f"Total tokens used: {token_counter}. Docs used: {len(valid_docs)} of {len(found_docs)} docs retrieved")
print(f"Course: {course_name} ||| search_query: {search_query}")
print(
f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds, time to count tokens: {time_to_count_tokens:.2f} seconds, time for parallel operations: {time_for_parallel_operations:.2f} seconds, time to retrieve docs: {time_to_retrieve_docs:.2f} seconds"
)
valid_docs.append(doc)

print(f"Course: {course_name} ||| search_query: {search_query}\n"
f"⏰ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds\n"
f"Runtime for parallel operations: {time_for_parallel_operations:.2f} seconds, "
f"Runtime to complete vector_search: {time_to_retrieve_docs:.2f} seconds")
if len(valid_docs) == 0:
return []

Expand All @@ -161,7 +140,7 @@ async def getTopContexts(self,
"user_query": search_query,
"course_name": course_name,
"token_limit": token_limit,
"total_tokens_used": token_counter,
# "total_tokens_used": token_counter,
"total_contexts_used": len(valid_docs),
"total_unique_docs_retrieved": len(found_docs),
"getTopContext_total_latency_sec": time.monotonic() - start_time_overall,
@@ -403,8 +382,6 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q
"""
Search the vector database for a given query, course name, and document groups.
"""
start_time_overall = time.monotonic()

if doc_groups is None:
doc_groups = []

@@ -415,7 +392,7 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q
public_doc_groups = []

# Max number of search results to return
top_n = 120
top_n = 60

# Capture the search invoked event to PostHog
self._capture_search_invoked_event(search_query, course_name, doc_groups)
@@ -436,10 +413,10 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q
self._capture_search_succeeded_event(search_query, course_name, search_results)
time_for_capture_search_succeeded_event = time.monotonic() - start_time_capture_search_succeeded_event

print(
f"time for vector search: {time_for_vector_search:.2f} seconds, time for process search results: {time_for_process_search_results:.2f} seconds, time for capture search succeeded event: {time_for_capture_search_succeeded_event:.2f} seconds"
)
print(f"time for embedding query: {self.openai_embedding_latency:.2f} seconds")
print(f"Runtime for embedding query: {self.openai_embedding_latency:.2f} seconds\n"
f"Runtime for vector search: {time_for_vector_search:.2f} seconds\n"
f"Runtime for process search results: {time_for_process_search_results:.2f} seconds\n"
f"Runtime for capture search succeeded event: {time_for_capture_search_succeeded_event:.2f} seconds")
return found_docs

def _perform_vector_search(self, search_query, course_name, doc_groups, user_query_embedding, top_n,
@@ -485,17 +462,18 @@ def _process_search_results(self, search_results, course_name):

def _capture_search_succeeded_event(self, search_query, course_name, search_results):
vector_score_calc_latency_sec = time.monotonic()
max_vector_score, min_vector_score, avg_vector_score = self._calculate_vector_scores(search_results)
# Removed because it takes 0.15 seconds to _calculate_vector_scores... not worth it rn.
# max_vector_score, min_vector_score, avg_vector_score = self._calculate_vector_scores(search_results)
self.posthog.capture(
event_name="vector_search_succeeded",
properties={
"user_query": search_query,
"course_name": course_name,
"qdrant_latency_sec": self.qdrant_latency_sec,
"openai_embedding_latency_sec": self.openai_embedding_latency,
"max_vector_score": max_vector_score,
"min_vector_score": min_vector_score,
"avg_vector_score": avg_vector_score,
# "max_vector_score": max_vector_score,
# "min_vector_score": min_vector_score,
# "avg_vector_score": avg_vector_score,
"vector_score_calculation_latency_sec": time.monotonic() - vector_score_calc_latency_sec,
},
)
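
The inline comment explains the commented-out call: aggregating scores cost about 0.15 seconds per search. A plausible reconstruction of such a helper (an assumption; the real _calculate_vector_scores is outside this diff):

def _calculate_vector_scores(search_results):
    # Hypothetical sketch: a single pass over the scored results.
    scores = [result.score for result in search_results]
    if not scores:
        return 0.0, 0.0, 0.0
    return max(scores), min(scores), sum(scores) / len(scores)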
@@ -569,30 +547,30 @@ def getConversationStats(self, course_name: str):
}

if response and hasattr(response, 'data') and response.data:
for record in response.data:
created_at = record['created_at']
for record in response.data:
created_at = record['created_at']

parsed_date = parser.parse(created_at)
parsed_date = parser.parse(created_at)

central_time = parsed_date.astimezone(central_tz)
central_time = parsed_date.astimezone(central_tz)

day = central_time.date()
hour = central_time.hour
day_of_week = central_time.strftime('%A')
day = central_time.date()
hour = central_time.hour
day_of_week = central_time.strftime('%A')

grouped_data['per_day'][str(day)] += 1
grouped_data['per_hour'][hour] += 1
grouped_data['per_weekday'][day_of_week] += 1
grouped_data['per_day'][str(day)] += 1
grouped_data['per_hour'][hour] += 1
grouped_data['per_weekday'][day_of_week] += 1
else:
print("No valid response data. Check if the query is correct or if the response is empty.")
return {}
print("No valid response data. Check if the query is correct or if the response is empty.")
return {}

return {
'per_day': dict(grouped_data['per_day']),
'per_hour': dict(grouped_data['per_hour']),
'per_weekday': dict(grouped_data['per_weekday']),
}

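For reference, a hypothetical shape of the stats payload this method returns (values invented; per_day keys come from str(day), per_hour keys are integers, per_weekday keys from strftime('%A')):

{
    'per_day': {'2024-11-07': 42, '2024-11-08': 17},
    'per_hour': {9: 12, 10: 25, 14: 22},
    'per_weekday': {'Thursday': 42, 'Friday': 17},
}
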
def getConversationHeatmapByHour(self, course_name: str):
"""
Fetches conversation data and groups them into a heatmap by day of the week and hour (Central Time).
@@ -609,18 +587,18 @@ def getConversationHeatmapByHour(self, course_name: str):
heatmap_data = defaultdict(lambda: defaultdict(int))

if response and hasattr(response, 'data') and response.data:
for record in response.data:
created_at = record['created_at']
for record in response.data:
created_at = record['created_at']

parsed_date = parser.parse(created_at)
central_time = parsed_date.astimezone(central_tz)
parsed_date = parser.parse(created_at)
central_time = parsed_date.astimezone(central_tz)

day_of_week = central_time.strftime('%A')
hour = central_time.hour
day_of_week = central_time.strftime('%A')
hour = central_time.hour

heatmap_data[day_of_week][hour] += 1
heatmap_data[day_of_week][hour] += 1
else:
print("No valid response data. Check if the query is correct or if the response is empty.")
return {}
print("No valid response data. Check if the query is correct or if the response is empty.")
return {}

return dict(heatmap_data)
return dict(heatmap_data)
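
And a hypothetical shape of the heatmap payload, nested day-of-week -> hour -> count (values invented):

{
    'Monday': {9: 4, 10: 11},
    'Tuesday': {14: 7, 15: 3},
}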
