diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py
index 576226d8..5224e3b8 100644
--- a/ai_ta_backend/database/sql.py
+++ b/ai_ta_backend/database/sql.py
@@ -155,4 +155,10 @@ def insertProject(self, project_info):
     return self.supabase_client.table("projects").insert(project_info).execute()
 
   def getPreAssignedAPIKeys(self, email: str):
-    return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute()
\ No newline at end of file
+    return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute()
+
+  def getConversationsCreatedAtByCourse(self, course_name: str):
+    """Fetch the `created_at` value of every `llm-convo-monitor` row whose `course_name` matches."""
+    # NOTE(review): a single un-paginated select may be truncated by the API's max-rows setting
+    # (Supabase projects commonly default to 1000) -- confirm whether busy courses need pagination.
+    return self.supabase_client.table("llm-convo-monitor").select("created_at").eq("course_name", course_name).execute()
diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py
index 6282541c..a11e9c91 100644
--- a/ai_ta_backend/main.py
+++ b/ai_ta_backend/main.py
@@ -519,6 +519,33 @@ def switch_workflow(service: WorkflowService) -> Response:
     else:
       abort(400, description=f"Bad request: {e}")
 
+@app.route('/getConversationStats', methods=['GET'])
+def get_conversation_stats(service: RetrievalService) -> Response:
+  """Return per-day, per-hour and per-weekday conversation counts for the `course_name` query parameter."""
+  course_name = request.args.get('course_name', default='', type=str)
+
+  if not course_name:
+    abort(400, description="Missing required parameter: 'course_name' must be provided.")
+
+  conversation_stats = service.getConversationStats(course_name)
+
+  response = jsonify(conversation_stats)
+  response.headers.add('Access-Control-Allow-Origin', '*')
+  return response
+
+@app.route('/getConversationHeatmapByHour', methods=['GET'])
+def get_questions_heatmap_by_hour(service: RetrievalService) -> Response:
+  """Return a weekday-by-hour heatmap of conversation counts for the `course_name` query parameter."""
+  course_name = request.args.get('course_name', default='', type=str)
+
+  if not course_name:
+    abort(400, description="Missing required parameter: 'course_name' must be provided.")
+
+  heatmap_data = service.getConversationHeatmapByHour(course_name)
+
+  response = jsonify(heatmap_data)
+  response.headers.add('Access-Control-Allow-Origin', '*')
+  return response
 
 @app.route('/run_flow', methods=['POST'])
 def run_flow(service: WorkflowService) -> Response:
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index b098345b..506d4af1 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -3,7 +3,10 @@
 import os
 import time
 import traceback
+from collections import defaultdict
 from typing import Dict, List, Union
 
+import pytz
+from dateutil import parser
 import openai
 from injector import inject
@@ -541,3 +544,93 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
     ]
 
     return contexts
+
+  def _getConversationTimesCentral(self, course_name: str):
+    """
+    Fetches the creation timestamps of a course's conversations, converted to US Central Time.
+
+    Timestamps that carry no UTC offset are interpreted as UTC (rather than the host's local
+    timezone), and records whose 'created_at' is missing or unparsable are skipped so that one
+    bad row cannot fail the whole aggregation.
+
+    Args:
+      course_name (str)
+
+    Returns:
+      list: Timezone-aware datetimes in 'America/Chicago'; empty when nothing usable was returned.
+    """
+    response = self.sqlDb.getConversationsCreatedAtByCourse(course_name)
+    records = getattr(response, 'data', None) if response else None
+    if not records:
+      return []
+
+    central_tz = pytz.timezone('America/Chicago')
+    central_times = []
+    for record in records:
+      try:
+        parsed_date = parser.parse(record['created_at'])
+      except (KeyError, TypeError, ValueError, OverflowError):
+        continue
+
+      if parsed_date.tzinfo is None:
+        # astimezone() on a naive datetime would assume the host's local timezone.
+        parsed_date = parsed_date.replace(tzinfo=pytz.utc)
+
+      central_times.append(parsed_date.astimezone(central_tz))
+
+    return central_times
+
+  def getConversationStats(self, course_name: str):
+    """
+    Fetches conversation data from the database and groups them by day, hour, and weekday
+    (US Central Time).
+
+    Args:
+      course_name (str)
+
+    Returns:
+      dict: Aggregated conversation counts, or {} when the course has no usable conversations:
+        - 'per_day': By date (YYYY-MM-DD).
+        - 'per_hour': By hour (0-23).
+        - 'per_weekday': By weekday (Monday-Sunday).
+    """
+    central_times = self._getConversationTimesCentral(course_name)
+    if not central_times:
+      print("No valid response data. Check if the query is correct or if the response is empty.")
+      return {}
+
+    per_day = defaultdict(int)
+    per_hour = defaultdict(int)
+    per_weekday = defaultdict(int)
+    for central_time in central_times:
+      per_day[str(central_time.date())] += 1
+      per_hour[central_time.hour] += 1
+      per_weekday[central_time.strftime('%A')] += 1
+
+    return {
+        'per_day': dict(per_day),
+        'per_hour': dict(per_hour),
+        'per_weekday': dict(per_weekday),
+    }
+
+  def getConversationHeatmapByHour(self, course_name: str):
+    """
+    Fetches conversation data and groups them into a heatmap by day of the week and hour (Central Time).
+
+    Args:
+      course_name (str)
+
+    Returns:
+      dict: A nested dictionary with days of the week as outer keys and hours (0-23) as inner keys, where values are conversation counts; {} when the course has no usable conversations.
+    """
+    central_times = self._getConversationTimesCentral(course_name)
+    if not central_times:
+      print("No valid response data. Check if the query is correct or if the response is empty.")
+      return {}
+
+    heatmap_data = defaultdict(lambda: defaultdict(int))
+    for central_time in central_times:
+      heatmap_data[central_time.strftime('%A')][central_time.hour] += 1
+
+    # Hand back plain dicts at both levels so callers never receive auto-creating defaultdicts.
+    return {day_of_week: dict(hours) for day_of_week, hours in heatmap_data.items()}