diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py index 5224e3b8..9e154ab4 100644 --- a/ai_ta_backend/database/sql.py +++ b/ai_ta_backend/database/sql.py @@ -1,4 +1,5 @@ import os +from typing import Dict from injector import inject @@ -158,4 +159,64 @@ def getPreAssignedAPIKeys(self, email: str): return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute() def getConversationsCreatedAtByCourse(self, course_name: str): - return self.supabase_client.table("llm-convo-monitor").select("created_at").eq("course_name", course_name).execute() \ No newline at end of file + try: + count_response = self.supabase_client.table("llm-convo-monitor")\ + .select("created_at", count="exact")\ + .eq("course_name", course_name)\ + .execute() + + total_count = count_response.count if hasattr(count_response, 'count') else 0 + + if total_count <= 0: + print(f"No conversations found for course: {course_name}") + return [], 0 + + all_data = [] + batch_size = 1000 + start = 0 + + while start < total_count: + end = min(start + batch_size - 1, total_count - 1) + + try: + response = self.supabase_client.table("llm-convo-monitor")\ + .select("created_at")\ + .eq("course_name", course_name)\ + .range(start, end)\ + .execute() + + if not response or not hasattr(response, 'data') or not response.data: + print(f"No data returned for range {start} to {end}.") + break + + all_data.extend(response.data) + start += batch_size + + except Exception as batch_error: + print(f"Error fetching batch {start}-{end}: {str(batch_error)}") + continue + + if not all_data: + print(f"No conversation data could be retrieved for course: {course_name}") + return [], 0 + + return all_data, len(all_data) + + except Exception as e: + print(f"Error in getConversationsCreatedAtByCourse for {course_name}: {str(e)}") + return [], 0 + + def getProjectStats(self, project_name: str): + try: + response = self.supabase_client.table("project_stats").select("total_messages, total_conversations, unique_users")\ + .eq("project_name", project_name).execute() + + if response and hasattr(response, 'data') and response.data: + return response.data[0] + except Exception as e: + print(f"Error fetching project stats for {project_name}: {str(e)}") + + # Return default values if anything fails + return {"total_messages": 0, "total_conversations": 0, "unique_users": 0} + + \ No newline at end of file diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index d9f0765d..b20c7322 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -533,21 +533,6 @@ def get_conversation_stats(service: RetrievalService) -> Response: response.headers.add('Access-Control-Allow-Origin', '*') return response - -@app.route('/getConversationHeatmapByHour', methods=['GET']) -def get_questions_heatmap_by_hour(service: RetrievalService) -> Response: - course_name = request.args.get('course_name', default='', type=str) - - if not course_name: - abort(400, description="Missing required parameter: 'course_name' must be provided.") - - heatmap_data = service.getConversationHeatmapByHour(course_name) - - response = jsonify(heatmap_data) - response.headers.add('Access-Control-Allow-Origin', '*') - return response - - @app.route('/run_flow', methods=['POST']) def run_flow(service: WorkflowService) -> Response: """ @@ -606,6 +591,20 @@ def createProject(service: ProjectService, flaskExecutor: ExecutorInterface) -> return response +@app.route('/getProjectStats', methods=['GET']) +def get_project_stats(service: RetrievalService) -> Response: + project_name = request.args.get('project_name', default='', type=str) + + if project_name == '': + abort(400, description="Missing required parameter: 'project_name' must be provided.") + + project_stats = service.getProjectStats(project_name) + + response = jsonify(project_stats) + response.headers.add('Access-Control-Allow-Origin', '*') + return response + + def configure(binder: Binder) -> None: binder.bind(ThreadPoolExecutorInterface, to=ThreadPoolExecutorAdapter(max_workers=10), scope=SingletonScope) binder.bind(ProcessPoolExecutorInterface, to=ProcessPoolExecutorAdapter(max_workers=10), scope=SingletonScope) diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index 87032c48..68f91a1c 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -18,7 +18,7 @@ from ai_ta_backend.database.sql import SQLDatabase from ai_ta_backend.database.vector import VectorDatabase from ai_ta_backend.executors.thread_pool_executor import ThreadPoolExecutorAdapter -from ai_ta_backend.service.nomic_service import NomicService +# from ai_ta_backend.service.nomic_service import NomicService from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.sentry_service import SentryService @@ -30,13 +30,13 @@ class RetrievalService: @inject def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, posthog: PosthogService, - sentry: SentryService, nomicService: NomicService, thread_pool_executor: ThreadPoolExecutorAdapter): + sentry: SentryService, thread_pool_executor: ThreadPoolExecutorAdapter): # nomicService: NomicService, self.vdb = vdb self.sqlDb = sqlDb self.aws = aws self.sentry = sentry self.posthog = posthog - self.nomicService = nomicService + # self.nomicService = nomicService self.thread_pool_executor = thread_pool_executor openai.api_key = os.environ["VLADS_OPENAI_KEY"] @@ -526,79 +526,81 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]: def getConversationStats(self, course_name: str): """ Fetches conversation data from the database and groups them by day, hour, and weekday. - - Args: - course_name (str) - - Returns: - dict: Aggregated conversation counts: - - 'per_day': By date (YYYY-MM-DD). - - 'per_hour': By hour (0-23). - - 'per_weekday': By weekday (Monday-Sunday). """ - response = self.sqlDb.getConversationsCreatedAtByCourse(course_name) - - central_tz = pytz.timezone('America/Chicago') - - grouped_data = { - 'per_day': defaultdict(int), - 'per_hour': defaultdict(int), - 'per_weekday': defaultdict(int), - } - - if response and hasattr(response, 'data') and response.data: - for record in response.data: - created_at = record['created_at'] - - parsed_date = parser.parse(created_at) - - central_time = parsed_date.astimezone(central_tz) - - day = central_time.date() - hour = central_time.hour - day_of_week = central_time.strftime('%A') - - grouped_data['per_day'][str(day)] += 1 - grouped_data['per_hour'][hour] += 1 - grouped_data['per_weekday'][day_of_week] += 1 - else: - print("No valid response data. Check if the query is correct or if the response is empty.") - return {} - - return { - 'per_day': dict(grouped_data['per_day']), - 'per_hour': dict(grouped_data['per_hour']), - 'per_weekday': dict(grouped_data['per_weekday']), - } + try: + conversations, total_count = self.sqlDb.getConversationsCreatedAtByCourse(course_name) + + # Initialize with empty data (all zeros) + response_data = { + 'per_day': {}, + 'per_hour': {str(hour): 0 for hour in range(24)}, # Convert hour to string for consistency + 'per_weekday': {day: 0 for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']}, + 'heatmap': {day: {str(hour): 0 for hour in range(24)} for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']}, + 'total_count': 0 + } + + if not conversations: + return response_data + + central_tz = pytz.timezone('America/Chicago') + grouped_data = { + 'per_day': defaultdict(int), + 'per_hour': defaultdict(int), + 'per_weekday': defaultdict(int), + 'heatmap': defaultdict(lambda: defaultdict(int)), + } + + for record in conversations: + try: + created_at = record['created_at'] + parsed_date = parser.parse(created_at).astimezone(central_tz) + + day = parsed_date.date() + hour = parsed_date.hour + day_of_week = parsed_date.strftime('%A') + + grouped_data['per_day'][str(day)] += 1 + grouped_data['per_hour'][str(hour)] += 1 # Convert hour to string + grouped_data['per_weekday'][day_of_week] += 1 + grouped_data['heatmap'][day_of_week][str(hour)] += 1 # Convert hour to string + except Exception as e: + print(f"Error processing record: {str(e)}") + continue + + return { + 'per_day': dict(grouped_data['per_day']), + 'per_hour': {str(k): v for k, v in grouped_data['per_hour'].items()}, + 'per_weekday': dict(grouped_data['per_weekday']), + 'heatmap': {day: {str(h): count for h, count in hours.items()} + for day, hours in grouped_data['heatmap'].items()}, + 'total_count': total_count + } - def getConversationHeatmapByHour(self, course_name: str): + except Exception as e: + print(f"Error in getConversationStats for course {course_name}: {str(e)}") + self.sentry.capture_exception(e) + # Return empty data structure on error + return { + 'per_day': {}, + 'per_hour': {str(hour): 0 for hour in range(24)}, + 'per_weekday': {day: 0 for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']}, + 'heatmap': {day: {str(hour): 0 for hour in range(24)} + for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']}, + 'total_count': 0 + } + + def getProjectStats(self, project_name: str) -> Dict[str, int]: """ - Fetches conversation data and groups them into a heatmap by day of the week and hour (Central Time). + Get statistics for a project. Args: - course_name (str) + project_name (str) Returns: - dict: A nested dictionary with days of the week as outer keys and hours (0-23) as inner keys, where values are conversation counts. + Dict[str, int]: Dictionary containing: + - total_conversations: Total number of conversations + - total_users: Number of unique users + - total_messages: Total number of messages """ - response = self.sqlDb.getConversationsCreatedAtByCourse(course_name) - central_tz = pytz.timezone('America/Chicago') - - heatmap_data = defaultdict(lambda: defaultdict(int)) - - if response and hasattr(response, 'data') and response.data: - for record in response.data: - created_at = record['created_at'] - - parsed_date = parser.parse(created_at) - central_time = parsed_date.astimezone(central_tz) - - day_of_week = central_time.strftime('%A') - hour = central_time.hour - - heatmap_data[day_of_week][hour] += 1 - else: - print("No valid response data. Check if the query is correct or if the response is empty.") - return {} - - return dict(heatmap_data) + return self.sqlDb.getProjectStats(project_name) + \ No newline at end of file