From 1d97a9896b99157ac4a86f5a530f50b2e4286a8e Mon Sep 17 00:00:00 2001
From: Vira Kasprova
Date: Fri, 1 Nov 2024 14:19:53 -0500
Subject: [PATCH 1/4] implemented pagination

---
 ai_ta_backend/database/sql.py              | 37 ++++++++++-
 ai_ta_backend/main.py                      | 13 ----
 ai_ta_backend/service/retrieval_service.py | 77 +++++++---------------
 3 files changed, 61 insertions(+), 66 deletions(-)

diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py
index 5224e3b8..358ea848 100644
--- a/ai_ta_backend/database/sql.py
+++ b/ai_ta_backend/database/sql.py
@@ -158,4 +158,39 @@ def getPreAssignedAPIKeys(self, email: str):
     return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute()
 
   def getConversationsCreatedAtByCourse(self, course_name: str):
-    return self.supabase_client.table("llm-convo-monitor").select("created_at").eq("course_name", course_name).execute()
\ No newline at end of file
+    count_response = self.supabase_client.table("llm-convo-monitor")\
+        .select("created_at", count="exact")\
+        .eq("course_name", course_name)\
+        .execute()
+
+    total_count = count_response.count
+    # print(f"Total entries available: {total_count}")
+
+    if total_count <= 0:
+      return [], 0
+
+    all_data = []
+    batch_size = 1000
+    start = 0
+
+    while start < total_count:
+      end = min(start + batch_size - 1, total_count - 1)
+      # print(f"Fetching data from {start} to {end}")
+
+      response = self.supabase_client.table("llm-convo-monitor")\
+          .select("created_at")\
+          .eq("course_name", course_name)\
+          .range(start, end)\
+          .execute()
+
+      # print(f"Fetched {len(response.data)} entries in this batch.")
+
+      if not response.data:
+        print(f"No data returned for range {start} to {end}.")
+        break
+
+      all_data.extend(response.data)
+      start += batch_size
+
+    # print(f"Total entries retrieved: {len(all_data)}")
+    return all_data, total_count
\ No newline at end of file
diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py
index a11e9c91..384b08fb 100644
--- a/ai_ta_backend/main.py
+++ b/ai_ta_backend/main.py
@@ -532,19 +532,6 @@ def get_conversation_stats(service: RetrievalService) -> Response:
   response.headers.add('Access-Control-Allow-Origin', '*')
   return response
 
-@app.route('/getConversationHeatmapByHour', methods=['GET'])
-def get_questions_heatmap_by_hour(service: RetrievalService) -> Response:
-  course_name = request.args.get('course_name', default='', type=str)
-
-  if not course_name:
-    abort(400, description="Missing required parameter: 'course_name' must be provided.")
-
-  heatmap_data = service.getConversationHeatmapByHour(course_name)
-
-  response = jsonify(heatmap_data)
-  response.headers.add('Access-Control-Allow-Origin', '*')
-  return response
-
 @app.route('/run_flow', methods=['POST'])
 def run_flow(service: WorkflowService) -> Response:
   """
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index 506d4af1..168cb9c9 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -18,7 +18,7 @@
 from ai_ta_backend.database.sql import SQLDatabase
 from ai_ta_backend.database.vector import VectorDatabase
 from ai_ta_backend.executors.thread_pool_executor import ThreadPoolExecutorAdapter
-from ai_ta_backend.service.nomic_service import NomicService
+# from ai_ta_backend.service.nomic_service import NomicService
 from ai_ta_backend.service.posthog_service import PosthogService
 from ai_ta_backend.service.sentry_service import SentryService
 from ai_ta_backend.utils.utils_tokenization import count_tokens_and_cost
@@ -31,13 +31,13 @@ class RetrievalService:
 
   @inject
   def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, posthog: PosthogService,
-               sentry: SentryService, nomicService: NomicService, thread_pool_executor: ThreadPoolExecutorAdapter):
+               sentry: SentryService, thread_pool_executor: ThreadPoolExecutorAdapter):  # nomicService: NomicService,
     self.vdb = vdb
     self.sqlDb = sqlDb
     self.aws = aws
     self.sentry = sentry
     self.posthog = posthog
-    self.nomicService = nomicService
+    # self.nomicService = nomicService
     self.thread_pool_executor = thread_pool_executor
 
     openai.api_key = os.environ["VLADS_OPENAI_KEY"]
@@ -548,17 +548,20 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
   def getConversationStats(self, course_name: str):
     """
     Fetches conversation data from the database and groups them by day, hour, and weekday.
-
-    Args:
-      course_name (str)
+    Args:
+      course_name (str)
+
     Returns:
-      dict: Aggregated conversation counts:
-        - 'per_day': By date (YYYY-MM-DD).
-        - 'per_hour': By hour (0-23).
-        - 'per_weekday': By weekday (Monday-Sunday).
+      dict: A dictionary containing:
+        - 'per_day': Counts of conversations by date (YYYY-MM-DD).
+        - 'per_hour': Counts of conversations by hour (0-23).
+        - 'per_weekday': Counts of conversations by weekday (Monday-Sunday).
+        - 'heatmap': A nested dictionary for heatmap data (days of the week as keys, hours as inner keys).
     """
-    response = self.sqlDb.getConversationsCreatedAtByCourse(course_name)
+    # print(f"Received course_name: {course_name}")
+
+    conversations, total_count = self.sqlDb.getConversationsCreatedAtByCourse(course_name)
 
     central_tz = pytz.timezone('America/Chicago')
 
@@ -566,23 +569,23 @@ def getConversationStats(self, course_name: str):
       'per_day': defaultdict(int),
       'per_hour': defaultdict(int),
       'per_weekday': defaultdict(int),
+      'heatmap': defaultdict(lambda: defaultdict(int)),
     }
 
-    if response and hasattr(response, 'data') and response.data:
-      for record in response.data:
+    if conversations:
+      for record in conversations:
         created_at = record['created_at']
+        parsed_date = parser.parse(created_at).astimezone(central_tz)
 
-        parsed_date = parser.parse(created_at)
-
-        central_time = parsed_date.astimezone(central_tz)
-
-        day = central_time.date()
-        hour = central_time.hour
-        day_of_week = central_time.strftime('%A')
+        day = parsed_date.date()
+        hour = parsed_date.hour
+        day_of_week = parsed_date.strftime('%A')
 
         grouped_data['per_day'][str(day)] += 1
         grouped_data['per_hour'][hour] += 1
         grouped_data['per_weekday'][day_of_week] += 1
+        grouped_data['heatmap'][day_of_week][hour] += 1
+
     else:
       print("No valid response data. Check if the query is correct or if the response is empty.")
       return {}
@@ -591,36 +594,6 @@ def getConversationStats(self, course_name: str):
       'per_day': dict(grouped_data['per_day']),
       'per_hour': dict(grouped_data['per_hour']),
       'per_weekday': dict(grouped_data['per_weekday']),
+      'heatmap': {day: dict(hours) for day, hours in grouped_data['heatmap'].items()},
     }
-
-  def getConversationHeatmapByHour(self, course_name: str):
-    """
-    Fetches conversation data and groups them into a heatmap by day of the week and hour (Central Time).
-
-    Args:
-      course_name (str)
-
-    Returns:
-      dict: A nested dictionary with days of the week as outer keys and hours (0-23) as inner keys, where values are conversation counts.
-    """
-    response = self.sqlDb.getConversationsCreatedAtByCourse(course_name)
-    central_tz = pytz.timezone('America/Chicago')
-
-    heatmap_data = defaultdict(lambda: defaultdict(int))
-
-    if response and hasattr(response, 'data') and response.data:
-      for record in response.data:
-        created_at = record['created_at']
-
-        parsed_date = parser.parse(created_at)
-        central_time = parsed_date.astimezone(central_tz)
-
-        day_of_week = central_time.strftime('%A')
-        hour = central_time.hour
-
-        heatmap_data[day_of_week][hour] += 1
-    else:
-      print("No valid response data. Check if the query is correct or if the response is empty.")
-      return {}
-
-    return dict(heatmap_data)
\ No newline at end of file
+ 
\ No newline at end of file
From 46a255e5f90e3c3db7f0dc391ed9d35987fc3f72 Mon Sep 17 00:00:00 2001
From: Vira Kasprova
Date: Mon, 11 Nov 2024 12:03:41 -0600
Subject: [PATCH 2/4] extracting course stats

---
 ai_ta_backend/database/sql.py              | 33 ++++++++++++++++++----
 ai_ta_backend/main.py                      | 14 +++++++++
 ai_ta_backend/service/retrieval_service.py | 15 ++++++++++
 3 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py
index 358ea848..5c350306 100644
--- a/ai_ta_backend/database/sql.py
+++ b/ai_ta_backend/database/sql.py
@@ -1,4 +1,5 @@
 import os
+from typing import Dict
 
 from injector import inject
 
@@ -164,7 +165,6 @@ def getConversationsCreatedAtByCourse(self, course_name: str):
         .execute()
 
     total_count = count_response.count
-    # print(f"Total entries available: {total_count}")
 
     if total_count <= 0:
       return [], 0
@@ -175,7 +175,6 @@ def getConversationsCreatedAtByCourse(self, course_name: str):
 
     while start < total_count:
       end = min(start + batch_size - 1, total_count - 1)
-      # print(f"Fetching data from {start} to {end}")
 
       response = self.supabase_client.table("llm-convo-monitor")\
           .select("created_at")\
@@ -183,8 +182,6 @@ def getConversationsCreatedAtByCourse(self, course_name: str):
           .range(start, end)\
           .execute()
 
-      # print(f"Fetched {len(response.data)} entries in this batch.")
-
       if not response.data:
         print(f"No data returned for range {start} to {end}.")
         break
@@ -192,5 +189,29 @@ def getConversationsCreatedAtByCourse(self, course_name: str):
       all_data.extend(response.data)
       start += batch_size
 
-    # print(f"Total entries retrieved: {len(all_data)}")
-    return all_data, total_count
\ No newline at end of file
+    return all_data, total_count
+
+  def getCourseStats(self, course_name: str) -> Dict[str, int]:
+    conversations_response = self.supabase_client.table("llm-convo-monitor") \
+        .select("id, user_email, convo_id", count="exact") \
+        .eq("course_name", course_name) \
+        .execute()
+
+    total_conversations = conversations_response.count if conversations_response.count else 0
+
+    unique_users = set(record["user_email"]
+                       for record in conversations_response.data
+                       if record.get("user_email"))
+    total_users = len(unique_users)
+
+    messages_response = self.supabase_client.rpc(
+        "get_message_count",
+        {"course": course_name}).execute()
+    total_messages = messages_response.data if messages_response.data else 0
+
+    return {
+        "total_conversations": total_conversations,
+        "total_users": total_users,
+        "total_messages": total_messages
+    }
+ 
\ No newline at end of file
diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py
index 384b08fb..a4ccda35 100644
--- a/ai_ta_backend/main.py
+++ b/ai_ta_backend/main.py
@@ -590,6 +590,20 @@ def createProject(service: ProjectService, flaskExecutor: ExecutorInterface) ->
   return response
 
 
+@app.route('/getCourseStats', methods=['GET'])
+def get_course_stats(service: RetrievalService) -> Response:
+  course_name = request.args.get('course_name', default='', type=str)
+
+  if course_name == '':
+    abort(400, description="Missing required parameter: 'course_name' must be provided.")
+
+  course_stats = service.getCourseStats(course_name)
+
+  response = jsonify(course_stats)
+  response.headers.add('Access-Control-Allow-Origin', '*')
+  return response
+
+
 def configure(binder: Binder) -> None:
   binder.bind(ThreadPoolExecutorInterface, to=ThreadPoolExecutorAdapter(max_workers=10), scope=SingletonScope)
   binder.bind(ProcessPoolExecutorInterface, to=ProcessPoolExecutorAdapter(max_workers=10), scope=SingletonScope)
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index 168cb9c9..5142631c 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -596,4 +596,19 @@ def getConversationStats(self, course_name: str):
       'per_weekday': dict(grouped_data['per_weekday']),
       'heatmap': {day: dict(hours) for day, hours in grouped_data['heatmap'].items()},
     }
+
+  def getCourseStats(self, course_name: str) -> Dict[str, int]:
+    """
+    Get statistics about conversations for a course.
+
+    Args:
+        course_name (str): Name of the course to get stats for
+
+    Returns:
+        Dict[str, int]: Dictionary containing:
+            - total_conversations: Total number of conversations
+            - total_users: Number of unique users
+            - total_messages: Total number of messages
+    """
+    return self.sqlDb.getCourseStats(course_name)
\ No newline at end of file
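
getCourseStats computes the unique-user count client-side from the fetched rows and delegates the message total to a get_message_count Postgres function via .rpc(); that function's body lives in the database, not in this diff. Once the route is registered, a plain GET makes a quick smoke test; the host, port, and course name below are placeholders:

import requests

# Hypothetical base URL; point this at wherever the Flask app is served.
resp = requests.get(
    "http://localhost:8000/getCourseStats",
    params={"course_name": "CS101"},
    timeout=10,
)
resp.raise_for_status()
stats = resp.json()

# Expected keys per the patch: total_conversations, total_users, total_messages.
print(stats["total_conversations"], stats["total_users"], stats["total_messages"])
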
From 6114c4b5ff430589dc44a1fd1eeb0f2ecbdc92a6 Mon Sep 17 00:00:00 2001
From: Vira Kasprova
Date: Thu, 14 Nov 2024 14:39:20 -0600
Subject: [PATCH 3/4] added getProjectStats

---
 ai_ta_backend/database/sql.py              | 29 +++++-----------------
 ai_ta_backend/main.py                      | 14 +++++------
 ai_ta_backend/service/retrieval_service.py |  8 +++---
 3 files changed, 17 insertions(+), 34 deletions(-)

diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py
index 5c350306..d0972cb0 100644
--- a/ai_ta_backend/database/sql.py
+++ b/ai_ta_backend/database/sql.py
@@ -191,27 +191,10 @@ def getConversationsCreatedAtByCourse(self, course_name: str):
 
     return all_data, total_count
 
-  def getCourseStats(self, course_name: str) -> Dict[str, int]:
-    conversations_response = self.supabase_client.table("llm-convo-monitor") \
-        .select("id, user_email, convo_id", count="exact") \
-        .eq("course_name", course_name) \
-        .execute()
-
-    total_conversations = conversations_response.count if conversations_response.count else 0
-
-    unique_users = set(record["user_email"]
-                       for record in conversations_response.data
-                       if record.get("user_email"))
-    total_users = len(unique_users)
-
-    messages_response = self.supabase_client.rpc(
-        "get_message_count",
-        {"course": course_name}).execute()
-    total_messages = messages_response.data if messages_response.data else 0
-
-    return {
-        "total_conversations": total_conversations,
-        "total_users": total_users,
-        "total_messages": total_messages
-    }
+  def getProjectStats(self, project_name: str):
+    response = self.supabase_client.table("project_stats").select("total_messages, total_conversations, unique_users")\
+        .eq("project_name", project_name).execute()
+
+    return response.data[0] if response.data else {"total_messages": 0, "total_conversations": 0, "unique_users": 0}
+ 
\ No newline at end of file
diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py
index ac40ca75..b20c7322 100644
--- a/ai_ta_backend/main.py
+++ b/ai_ta_backend/main.py
@@ -591,16 +591,16 @@ def createProject(service: ProjectService, flaskExecutor: ExecutorInterface) ->
   return response
 
 
-@app.route('/getCourseStats', methods=['GET'])
-def get_course_stats(service: RetrievalService) -> Response:
-  course_name = request.args.get('course_name', default='', type=str)
+@app.route('/getProjectStats', methods=['GET'])
+def get_project_stats(service: RetrievalService) -> Response:
+  project_name = request.args.get('project_name', default='', type=str)
 
-  if course_name == '':
-    abort(400, description="Missing required parameter: 'course_name' must be provided.")
+  if project_name == '':
+    abort(400, description="Missing required parameter: 'project_name' must be provided.")
 
-  course_stats = service.getCourseStats(course_name)
+  project_stats = service.getProjectStats(project_name)
 
-  response = jsonify(course_stats)
+  response = jsonify(project_stats)
   response.headers.add('Access-Control-Allow-Origin', '*')
   return response
 
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index 67f45905..28bc4561 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -575,12 +575,12 @@ def getConversationStats(self, course_name: str):
       'heatmap': {day: dict(hours) for day, hours in grouped_data['heatmap'].items()},
     }
 
-  def getCourseStats(self, course_name: str) -> Dict[str, int]:
+  def getProjectStats(self, project_name: str) -> Dict[str, int]:
     """
-    Get statistics about conversations for a course.
+    Get statistics for a project.
 
     Args:
-        course_name (str): Name of the course to get stats for
+        project_name (str)
 
     Returns:
         Dict[str, int]: Dictionary containing:
@@ -588,5 +588,5 @@ def getConversationStats(self, course_name: str):
             - total_users: Number of unique users
             - total_messages: Total number of messages
     """
-    return self.sqlDb.getCourseStats(course_name)
+    return self.sqlDb.getProjectStats(project_name)
\ No newline at end of file
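
Replacing getCourseStats with getProjectStats is also a design change: instead of scanning llm-convo-monitor and issuing an RPC on every request, the endpoint now reads one precomputed row from a project_stats table, so response time stays flat as conversation volume grows. A sketch of that lookup-with-fallback shape, assuming a supabase-py client and the project_stats columns from the diff:

DEFAULT_STATS = {"total_messages": 0, "total_conversations": 0, "unique_users": 0}

def get_project_stats(supabase, project_name: str) -> dict:
    resp = (supabase.table("project_stats")
            .select("total_messages, total_conversations, unique_users")
            .eq("project_name", project_name)
            .execute())
    # One precomputed row per project; fall back to zeros when absent
    # so callers always see the same response shape.
    return resp.data[0] if resp.data else dict(DEFAULT_STATS)
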
From 8212dbbc21282c887015a5118c409f1bea25d054 Mon Sep 17 00:00:00 2001
From: Vira Kasprova
Date: Thu, 14 Nov 2024 16:26:58 -0600
Subject: [PATCH 4/4] added error handling

---
 ai_ta_backend/database/sql.py              |  86 ++++++++++-------
 ai_ta_backend/service/retrieval_service.py | 104 ++++++++++++---------
 2 files changed, 113 insertions(+), 77 deletions(-)

diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py
index d0972cb0..9e154ab4 100644
--- a/ai_ta_backend/database/sql.py
+++ b/ai_ta_backend/database/sql.py
@@ -159,42 +159,64 @@ def getPreAssignedAPIKeys(self, email: str):
     return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute()
 
   def getConversationsCreatedAtByCourse(self, course_name: str):
-    count_response = self.supabase_client.table("llm-convo-monitor")\
-        .select("created_at", count="exact")\
-        .eq("course_name", course_name)\
-        .execute()
-
-    total_count = count_response.count
-
-    if total_count <= 0:
-      return [], 0
-
-    all_data = []
-    batch_size = 1000
-    start = 0
-
-    while start < total_count:
-      end = min(start + batch_size - 1, total_count - 1)
-
-      response = self.supabase_client.table("llm-convo-monitor")\
-          .select("created_at")\
-          .eq("course_name", course_name)\
-          .range(start, end)\
-          .execute()
-
-      if not response.data:
-        print(f"No data returned for range {start} to {end}.")
-        break
-
-      all_data.extend(response.data)
-      start += batch_size
-
-    return all_data, total_count
+    try:
+      count_response = self.supabase_client.table("llm-convo-monitor")\
+          .select("created_at", count="exact")\
+          .eq("course_name", course_name)\
+          .execute()
+
+      total_count = count_response.count if hasattr(count_response, 'count') else 0
+
+      if total_count <= 0:
+        print(f"No conversations found for course: {course_name}")
+        return [], 0
+
+      all_data = []
+      batch_size = 1000
+      start = 0
+
+      while start < total_count:
+        end = min(start + batch_size - 1, total_count - 1)
+
+        try:
+          response = self.supabase_client.table("llm-convo-monitor")\
+              .select("created_at")\
+              .eq("course_name", course_name)\
+              .range(start, end)\
+              .execute()
+
+          if not response or not hasattr(response, 'data') or not response.data:
+            print(f"No data returned for range {start} to {end}.")
+            break
+
+          all_data.extend(response.data)
+          start += batch_size
+
+        except Exception as batch_error:
+          print(f"Error fetching batch {start}-{end}: {str(batch_error)}")
+          # Advance past the failed batch; a bare `continue` here would retry the same range forever.
+          start += batch_size
+          continue
+
+      if not all_data:
+        print(f"No conversation data could be retrieved for course: {course_name}")
+        return [], 0
+
+      return all_data, len(all_data)
+
+    except Exception as e:
+      print(f"Error in getConversationsCreatedAtByCourse for {course_name}: {str(e)}")
+      return [], 0
 
   def getProjectStats(self, project_name: str):
-    response = self.supabase_client.table("project_stats").select("total_messages, total_conversations, unique_users")\
-        .eq("project_name", project_name).execute()
+    try:
+      response = self.supabase_client.table("project_stats").select("total_messages, total_conversations, unique_users")\
+          .eq("project_name", project_name).execute()
+
+      if response and hasattr(response, 'data') and response.data:
+        return response.data[0]
+    except Exception as e:
+      print(f"Error fetching project stats for {project_name}: {str(e)}")
 
-    return response.data[0] if response.data else {"total_messages": 0, "total_conversations": 0, "unique_users": 0}
- 
\ No newline at end of file
+    # Return default values if anything fails
+    return {"total_messages": 0, "total_conversations": 0, "unique_users": 0}
\ No newline at end of file
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index 28bc4561..68f91a1c 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -526,54 +526,68 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
   def getConversationStats(self, course_name: str):
     """
     Fetches conversation data from the database and groups them by day, hour, and weekday.
-
-    Args:
-        course_name (str)
-
-    Returns:
-      dict: A dictionary containing:
-        - 'per_day': Counts of conversations by date (YYYY-MM-DD).
-        - 'per_hour': Counts of conversations by hour (0-23).
-        - 'per_weekday': Counts of conversations by weekday (Monday-Sunday).
-        - 'heatmap': A nested dictionary for heatmap data (days of the week as keys, hours as inner keys).
""" - # print(f"Received course_name: {course_name}") - - conversations, total_count = self.sqlDb.getConversationsCreatedAtByCourse(course_name) - - central_tz = pytz.timezone('America/Chicago') - - grouped_data = { - 'per_day': defaultdict(int), - 'per_hour': defaultdict(int), - 'per_weekday': defaultdict(int), - 'heatmap': defaultdict(lambda: defaultdict(int)), - } + try: + conversations, total_count = self.sqlDb.getConversationsCreatedAtByCourse(course_name) + + # Initialize with empty data (all zeros) + response_data = { + 'per_day': {}, + 'per_hour': {str(hour): 0 for hour in range(24)}, # Convert hour to string for consistency + 'per_weekday': {day: 0 for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']}, + 'heatmap': {day: {str(hour): 0 for hour in range(24)} for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']}, + 'total_count': 0 + } + + if not conversations: + return response_data + + central_tz = pytz.timezone('America/Chicago') + grouped_data = { + 'per_day': defaultdict(int), + 'per_hour': defaultdict(int), + 'per_weekday': defaultdict(int), + 'heatmap': defaultdict(lambda: defaultdict(int)), + } - if conversations: for record in conversations: - created_at = record['created_at'] - parsed_date = parser.parse(created_at).astimezone(central_tz) - - day = parsed_date.date() - hour = parsed_date.hour - day_of_week = parsed_date.strftime('%A') - - grouped_data['per_day'][str(day)] += 1 - grouped_data['per_hour'][hour] += 1 - grouped_data['per_weekday'][day_of_week] += 1 - grouped_data['heatmap'][day_of_week][hour] += 1 - - else: - print("No valid response data. Check if the query is correct or if the response is empty.") - return {} - - return { - 'per_day': dict(grouped_data['per_day']), - 'per_hour': dict(grouped_data['per_hour']), - 'per_weekday': dict(grouped_data['per_weekday']), - 'heatmap': {day: dict(hours) for day, hours in grouped_data['heatmap'].items()}, - } + try: + created_at = record['created_at'] + parsed_date = parser.parse(created_at).astimezone(central_tz) + + day = parsed_date.date() + hour = parsed_date.hour + day_of_week = parsed_date.strftime('%A') + + grouped_data['per_day'][str(day)] += 1 + grouped_data['per_hour'][str(hour)] += 1 # Convert hour to string + grouped_data['per_weekday'][day_of_week] += 1 + grouped_data['heatmap'][day_of_week][str(hour)] += 1 # Convert hour to string + except Exception as e: + print(f"Error processing record: {str(e)}") + continue + + return { + 'per_day': dict(grouped_data['per_day']), + 'per_hour': {str(k): v for k, v in grouped_data['per_hour'].items()}, + 'per_weekday': dict(grouped_data['per_weekday']), + 'heatmap': {day: {str(h): count for h, count in hours.items()} + for day, hours in grouped_data['heatmap'].items()}, + 'total_count': total_count + } + + except Exception as e: + print(f"Error in getConversationStats for course {course_name}: {str(e)}") + self.sentry.capture_exception(e) + # Return empty data structure on error + return { + 'per_day': {}, + 'per_hour': {str(hour): 0 for hour in range(24)}, + 'per_weekday': {day: 0 for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']}, + 'heatmap': {day: {str(hour): 0 for hour in range(24)} + for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']}, + 'total_count': 0 + } def getProjectStats(self, project_name: str) -> Dict[str, int]: """