Merge pull request #322 from UIUC-Chatbot/statistics-for-analysis-page
Implement Pagination
kasprovav authored Nov 14, 2024
2 parents b177a30 + 8212dbb commit ae5ed49
Showing 3 changed files with 149 additions and 87 deletions.
63 changes: 62 additions & 1 deletion ai_ta_backend/database/sql.py
@@ -1,4 +1,5 @@
import os
+from typing import Dict

from injector import inject

@@ -158,4 +159,64 @@ def getPreAssignedAPIKeys(self, email: str):
return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute()

  def getConversationsCreatedAtByCourse(self, course_name: str):
-    return self.supabase_client.table("llm-convo-monitor").select("created_at").eq("course_name", course_name).execute()
+    try:
+      count_response = self.supabase_client.table("llm-convo-monitor")\
+          .select("created_at", count="exact")\
+          .eq("course_name", course_name)\
+          .execute()
+
+      total_count = count_response.count if hasattr(count_response, 'count') else 0
+
+      if total_count <= 0:
+        print(f"No conversations found for course: {course_name}")
+        return [], 0
+
+      all_data = []
+      batch_size = 1000
+      start = 0
+
+      while start < total_count:
+        end = min(start + batch_size - 1, total_count - 1)
+
+        try:
+          response = self.supabase_client.table("llm-convo-monitor")\
+              .select("created_at")\
+              .eq("course_name", course_name)\
+              .range(start, end)\
+              .execute()
+
+          if not response or not hasattr(response, 'data') or not response.data:
+            print(f"No data returned for range {start} to {end}.")
+            break
+
+          all_data.extend(response.data)
+          start += batch_size
+
+        except Exception as batch_error:
+          print(f"Error fetching batch {start}-{end}: {str(batch_error)}")
+          start += batch_size  # advance past the failing batch so the loop cannot spin forever
+          continue
+
+      if not all_data:
+        print(f"No conversation data could be retrieved for course: {course_name}")
+        return [], 0
+
+      return all_data, len(all_data)
+
+    except Exception as e:
+      print(f"Error in getConversationsCreatedAtByCourse for {course_name}: {str(e)}")
+      return [], 0
+
+  def getProjectStats(self, project_name: str):
+    try:
+      response = self.supabase_client.table("project_stats").select("total_messages, total_conversations, unique_users")\
+          .eq("project_name", project_name).execute()
+
+      if response and hasattr(response, 'data') and response.data:
+        return response.data[0]
+    except Exception as e:
+      print(f"Error fetching project stats for {project_name}: {str(e)}")
+
+    # Return default values if anything fails
+    return {"total_messages": 0, "total_conversations": 0, "unique_users": 0}


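Note on the pagination pattern: Supabase's range(start, end) is inclusive on both ends, so each request fetches at most batch_size rows. Below is a minimal standalone sketch of the same loop, assuming a supabase-py client built from SUPABASE_URL and SUPABASE_API_KEY environment variables (names assumed here, not taken from this PR) and a hypothetical course name:

# Sketch of the batched-fetch pattern used above; not the repo's code.
import os
from supabase import create_client

supabase = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_API_KEY"])

def fetch_all_created_at(course_name: str, batch_size: int = 1000) -> list:
  # count="exact" returns the total matching row count alongside the data
  total = supabase.table("llm-convo-monitor") \
      .select("created_at", count="exact") \
      .eq("course_name", course_name) \
      .execute().count or 0

  rows = []
  for start in range(0, total, batch_size):
    end = min(start + batch_size - 1, total - 1)  # .range() bounds are inclusive
    batch = supabase.table("llm-convo-monitor") \
        .select("created_at") \
        .eq("course_name", course_name) \
        .range(start, end) \
        .execute()
    if not batch.data:
      break
    rows.extend(batch.data)
  return rows

print(len(fetch_all_created_at("example_course")))  # hypothetical course name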
29 changes: 14 additions & 15 deletions ai_ta_backend/main.py
@@ -533,21 +533,6 @@ def get_conversation_stats(service: RetrievalService) -> Response:
response.headers.add('Access-Control-Allow-Origin', '*')
return response


-@app.route('/getConversationHeatmapByHour', methods=['GET'])
-def get_questions_heatmap_by_hour(service: RetrievalService) -> Response:
-  course_name = request.args.get('course_name', default='', type=str)
-
-  if not course_name:
-    abort(400, description="Missing required parameter: 'course_name' must be provided.")
-
-  heatmap_data = service.getConversationHeatmapByHour(course_name)
-
-  response = jsonify(heatmap_data)
-  response.headers.add('Access-Control-Allow-Origin', '*')
-  return response


@app.route('/run_flow', methods=['POST'])
def run_flow(service: WorkflowService) -> Response:
"""
@@ -606,6 +591,20 @@ def createProject(service: ProjectService, flaskExecutor: ExecutorInterface) ->
return response


+@app.route('/getProjectStats', methods=['GET'])
+def get_project_stats(service: RetrievalService) -> Response:
+  project_name = request.args.get('project_name', default='', type=str)
+
+  if project_name == '':
+    abort(400, description="Missing required parameter: 'project_name' must be provided.")
+
+  project_stats = service.getProjectStats(project_name)
+
+  response = jsonify(project_stats)
+  response.headers.add('Access-Control-Allow-Origin', '*')
+  return response


def configure(binder: Binder) -> None:
binder.bind(ThreadPoolExecutorInterface, to=ThreadPoolExecutorAdapter(max_workers=10), scope=SingletonScope)
binder.bind(ProcessPoolExecutorInterface, to=ProcessPoolExecutorAdapter(max_workers=10), scope=SingletonScope)
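Usage note: the new /getProjectStats route mirrors the existing GET endpoints. A quick sketch of calling it with requests; the base URL and project name are placeholder assumptions, not values from this PR:

# Sketch: calling the new /getProjectStats endpoint.
import requests

resp = requests.get(
    "http://localhost:8000/getProjectStats",  # hypothetical host/port
    params={"project_name": "example-project"},  # hypothetical project
    timeout=10,
)
resp.raise_for_status()
stats = resp.json()
# Expected keys per getProjectStats above:
print(stats["total_messages"], stats["total_conversations"], stats["unique_users"])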
144 changes: 73 additions & 71 deletions ai_ta_backend/service/retrieval_service.py
@@ -18,7 +18,7 @@
from ai_ta_backend.database.sql import SQLDatabase
from ai_ta_backend.database.vector import VectorDatabase
from ai_ta_backend.executors.thread_pool_executor import ThreadPoolExecutorAdapter
-from ai_ta_backend.service.nomic_service import NomicService
+# from ai_ta_backend.service.nomic_service import NomicService
from ai_ta_backend.service.posthog_service import PosthogService
from ai_ta_backend.service.sentry_service import SentryService

@@ -30,13 +30,13 @@ class RetrievalService:

@inject
def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, posthog: PosthogService,
-               sentry: SentryService, nomicService: NomicService, thread_pool_executor: ThreadPoolExecutorAdapter):
+               sentry: SentryService, thread_pool_executor: ThreadPoolExecutorAdapter):  # nomicService: NomicService,
self.vdb = vdb
self.sqlDb = sqlDb
self.aws = aws
self.sentry = sentry
self.posthog = posthog
-    self.nomicService = nomicService
+    # self.nomicService = nomicService
self.thread_pool_executor = thread_pool_executor
openai.api_key = os.environ["VLADS_OPENAI_KEY"]

@@ -526,79 +526,81 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
def getConversationStats(self, course_name: str):
"""
    Fetches conversation data from the database and groups it by day, hour, and weekday.
-    Args:
-        course_name (str)
-    Returns:
-        dict: Aggregated conversation counts:
-            - 'per_day': By date (YYYY-MM-DD).
-            - 'per_hour': By hour (0-23).
-            - 'per_weekday': By weekday (Monday-Sunday).
"""
-    response = self.sqlDb.getConversationsCreatedAtByCourse(course_name)
-
-    central_tz = pytz.timezone('America/Chicago')
-
-    grouped_data = {
-        'per_day': defaultdict(int),
-        'per_hour': defaultdict(int),
-        'per_weekday': defaultdict(int),
-    }
-
-    if response and hasattr(response, 'data') and response.data:
-      for record in response.data:
-        created_at = record['created_at']
-
-        parsed_date = parser.parse(created_at)
-
-        central_time = parsed_date.astimezone(central_tz)
-
-        day = central_time.date()
-        hour = central_time.hour
-        day_of_week = central_time.strftime('%A')
-
-        grouped_data['per_day'][str(day)] += 1
-        grouped_data['per_hour'][hour] += 1
-        grouped_data['per_weekday'][day_of_week] += 1
-    else:
-      print("No valid response data. Check if the query is correct or if the response is empty.")
-      return {}
-
-    return {
-        'per_day': dict(grouped_data['per_day']),
-        'per_hour': dict(grouped_data['per_hour']),
-        'per_weekday': dict(grouped_data['per_weekday']),
-    }
+    try:
+      conversations, total_count = self.sqlDb.getConversationsCreatedAtByCourse(course_name)
+
+      # Initialize with empty data (all zeros)
+      response_data = {
+          'per_day': {},
+          'per_hour': {str(hour): 0 for hour in range(24)},  # hour keys as strings for consistency
+          'per_weekday': {day: 0 for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']},
+          'heatmap': {day: {str(hour): 0 for hour in range(24)} for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']},
+          'total_count': 0
+      }
+
+      if not conversations:
+        return response_data
+
+      central_tz = pytz.timezone('America/Chicago')
+      grouped_data = {
+          'per_day': defaultdict(int),
+          'per_hour': defaultdict(int),
+          'per_weekday': defaultdict(int),
+          'heatmap': defaultdict(lambda: defaultdict(int)),
+      }
+
+      for record in conversations:
+        try:
+          created_at = record['created_at']
+          parsed_date = parser.parse(created_at).astimezone(central_tz)
+
+          day = parsed_date.date()
+          hour = parsed_date.hour
+          day_of_week = parsed_date.strftime('%A')
+
+          grouped_data['per_day'][str(day)] += 1
+          grouped_data['per_hour'][str(hour)] += 1  # hour keys as strings
+          grouped_data['per_weekday'][day_of_week] += 1
+          grouped_data['heatmap'][day_of_week][str(hour)] += 1
+        except Exception as e:
+          print(f"Error processing record: {str(e)}")
+          continue
+
+      return {
+          'per_day': dict(grouped_data['per_day']),
+          'per_hour': {str(k): v for k, v in grouped_data['per_hour'].items()},
+          'per_weekday': dict(grouped_data['per_weekday']),
+          'heatmap': {day: {str(h): count for h, count in hours.items()}
+                      for day, hours in grouped_data['heatmap'].items()},
+          'total_count': total_count
+      }

+    except Exception as e:
+      print(f"Error in getConversationStats for course {course_name}: {str(e)}")
+      self.sentry.capture_exception(e)
+      # Return empty data structure on error
+      return {
+          'per_day': {},
+          'per_hour': {str(hour): 0 for hour in range(24)},
+          'per_weekday': {day: 0 for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']},
+          'heatmap': {day: {str(hour): 0 for hour in range(24)}
+                      for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']},
+          'total_count': 0
+      }
+
-  def getConversationHeatmapByHour(self, course_name: str):
-    """
-    Fetches conversation data and groups it into a heatmap by day of the week and hour (Central Time).
-    Args:
-        course_name (str)
-    Returns:
-        dict: A nested dictionary with days of the week as outer keys and hours (0-23) as inner keys, where values are conversation counts.
-    """
-    response = self.sqlDb.getConversationsCreatedAtByCourse(course_name)
-    central_tz = pytz.timezone('America/Chicago')
-
-    heatmap_data = defaultdict(lambda: defaultdict(int))
-
-    if response and hasattr(response, 'data') and response.data:
-      for record in response.data:
-        created_at = record['created_at']
-
-        parsed_date = parser.parse(created_at)
-        central_time = parsed_date.astimezone(central_tz)
-
-        day_of_week = central_time.strftime('%A')
-        hour = central_time.hour
-
-        heatmap_data[day_of_week][hour] += 1
-    else:
-      print("No valid response data. Check if the query is correct or if the response is empty.")
-      return {}
-
-    return dict(heatmap_data)
+
+  def getProjectStats(self, project_name: str) -> Dict[str, int]:
+    """
+    Get statistics for a project.
+    Args:
+        project_name (str)
+    Returns:
+        Dict[str, int]: Dictionary containing:
+            - total_conversations: Total number of conversations
+            - unique_users: Number of unique users
+            - total_messages: Total number of messages
+    """
+    return self.sqlDb.getProjectStats(project_name)
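Consumption note: getConversationStats now returns string-keyed hour buckets plus a nested heatmap, so a client can render a weekday-by-hour grid without reshaping. A small sketch of walking that structure; the stats dict below is a stand-in matching the shape above, not real data:

# Sketch: printing the weekday x hour heatmap from a getConversationStats result.
DAYS = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']

stats = {
    'per_day': {},
    'per_hour': {str(h): 0 for h in range(24)},
    'per_weekday': {d: 0 for d in DAYS},
    'heatmap': {d: {str(h): 0 for h in range(24)} for d in DAYS},
    'total_count': 0,
}

for day in DAYS:
  row = [stats['heatmap'][day][str(hour)] for hour in range(24)]
  print(f"{day:<9} " + " ".join(f"{n:2d}" for n in row))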
