Skip to content

Commit

Permalink
Merge pull request #319 from UIUC-Chatbot/statistics-for-analysis-page
Browse files Browse the repository at this point in the history
APIs for conversation analysis page
  • Loading branch information
kasprovav authored Oct 29, 2024
2 parents b0f1050 + d1e6642 commit 60b1e7b
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ai_ta_backend/database/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,7 @@ def insertProject(self, project_info):
return self.supabase_client.table("projects").insert(project_info).execute()

def getPreAssignedAPIKeys(self, email: str):
    """Fetch rows of pre_authorized_api_keys whose `emails` array contains `email`.

    The filter passes a JSON-array literal to the Supabase `contains`
    operator, so it matches rows where `email` is an element of the
    stored `emails` column.
    """
    return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute()
return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute()

def getConversationsCreatedAtByCourse(self, course_name: str):
    """Fetch only the `created_at` timestamps of conversations for a course.

    Selects the single `created_at` column from `llm-convo-monitor` for
    `course_name`; used by the analysis endpoints, which only need the
    timestamps for aggregation.
    """
    return self.supabase_client.table("llm-convo-monitor").select("created_at").eq("course_name", course_name).execute()
25 changes: 25 additions & 0 deletions ai_ta_backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,31 @@ def switch_workflow(service: WorkflowService) -> Response:
else:
abort(400, description=f"Bad request: {e}")

@app.route('/getConversationStats', methods=['GET'])
def get_conversation_stats(service: RetrievalService) -> Response:
    """Return aggregated conversation counts (per day / hour / weekday) for a course.

    Query params:
        course_name (str, required): course to aggregate.

    Returns:
        JSON body from RetrievalService.getConversationStats, with a
        permissive CORS header so the analysis frontend can fetch it.

    Raises:
        400 via abort() when 'course_name' is missing or empty.
    """
    course_name = request.args.get('course_name', default='', type=str)

    # Fix: use truthiness instead of `== ''`, consistent with the sibling
    # /getConversationHeatmapByHour route below.
    if not course_name:
        abort(400, description="Missing required parameter: 'course_name' must be provided.")

    conversation_stats = service.getConversationStats(course_name)

    response = jsonify(conversation_stats)
    response.headers.add('Access-Control-Allow-Origin', '*')
    return response

@app.route('/getConversationHeatmapByHour', methods=['GET'])
def get_questions_heatmap_by_hour(service: RetrievalService) -> Response:
    """Serve the weekday-by-hour conversation heatmap for a course.

    Query params:
        course_name (str, required): course whose heatmap is requested.

    Returns:
        JSON body from RetrievalService.getConversationHeatmapByHour with
        a permissive CORS header; aborts with 400 when 'course_name' is
        missing or empty.
    """
    course_name = request.args.get('course_name', default='', type=str)

    if not course_name:
        abort(400, description="Missing required parameter: 'course_name' must be provided.")

    payload = jsonify(service.getConversationHeatmapByHour(course_name))
    payload.headers.add('Access-Control-Allow-Origin', '*')
    return payload

@app.route('/run_flow', methods=['POST'])
def run_flow(service: WorkflowService) -> Response:
Expand Down
83 changes: 83 additions & 0 deletions ai_ta_backend/service/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import os
import time
import traceback
import pytz
from typing import Dict, List, Union
from dateutil import parser
from collections import defaultdict

import openai
from injector import inject
Expand Down Expand Up @@ -541,3 +544,83 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
]

return contexts

def getConversationStats(self, course_name: str):
    """
    Fetch conversation creation timestamps for a course and aggregate them
    by day, hour, and weekday, all evaluated in US Central time.

    Args:
        course_name (str): course whose conversations are counted.

    Returns:
        dict: Aggregated conversation counts:
            - 'per_day': By date (YYYY-MM-DD).
            - 'per_hour': By hour (0-23).
            - 'per_weekday': By weekday (Monday-Sunday).
        All three maps are empty when the course has no conversations.
    """
    response = self.sqlDb.getConversationsCreatedAtByCourse(course_name)

    central_tz = pytz.timezone('America/Chicago')

    grouped_data = {
        'per_day': defaultdict(int),
        'per_hour': defaultdict(int),
        'per_weekday': defaultdict(int),
    }

    if response and hasattr(response, 'data') and response.data:
        for record in response.data:
            created_at = record.get('created_at')
            if not created_at:
                # Robustness fix: skip malformed rows instead of crashing
                # the whole aggregation on one bad record.
                continue

            # NOTE(review): assumes `created_at` is an ISO timestamp with a
            # timezone offset; a naive timestamp would be interpreted in the
            # server's local zone by astimezone() — confirm against the DB.
            central_time = parser.parse(created_at).astimezone(central_tz)

            grouped_data['per_day'][str(central_time.date())] += 1
            grouped_data['per_hour'][central_time.hour] += 1
            grouped_data['per_weekday'][central_time.strftime('%A')] += 1
    else:
        print("No valid response data. Check if the query is correct or if the response is empty.")

    # Fix: always return the full structure. Previously an empty course
    # returned a bare {}, so consumers indexing 'per_day'/'per_hour'/
    # 'per_weekday' would fail on courses with no conversations.
    return {
        'per_day': dict(grouped_data['per_day']),
        'per_hour': dict(grouped_data['per_hour']),
        'per_weekday': dict(grouped_data['per_weekday']),
    }

def getConversationHeatmapByHour(self, course_name: str):
    """
    Build a weekday-by-hour heatmap of conversation counts for a course,
    with timestamps evaluated in US Central time.

    Args:
        course_name (str)

    Returns:
        dict: {weekday name: {hour (0-23): count}}; empty dict when the
        course has no conversation data.
    """
    rows = self.sqlDb.getConversationsCreatedAtByCourse(course_name)
    chicago = pytz.timezone('America/Chicago')

    # Guard clause: bail out early when the query returned nothing usable.
    if not (rows and hasattr(rows, 'data') and rows.data):
        print("No valid response data. Check if the query is correct or if the response is empty.")
        return {}

    heatmap = defaultdict(lambda: defaultdict(int))
    for row in rows.data:
        ts = parser.parse(row['created_at']).astimezone(chicago)
        heatmap[ts.strftime('%A')][ts.hour] += 1

    return dict(heatmap)

0 comments on commit 60b1e7b

Please sign in to comment.