GetTopContexts for Vyriad #334

Closed · wants to merge 8 commits
1 change: 1 addition & 0 deletions ai_ta_backend/database/aws.py
@@ -11,6 +11,7 @@ def __init__(self):
  # S3
  self.s3_client = boto3.client(
      's3',
      endpoint_url=os.environ.get('MINIO_API_URL'),  # for Self hosted MinIO bucket
      aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
      aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
  )
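For context, a minimal sketch of how this optional override behaves. The endpoint URL and bucket name below are hypothetical; when MINIO_API_URL is unset, endpoint_url is None and boto3 falls back to the default AWS S3 endpoint.

import os
import boto3

# Hypothetical MinIO endpoint; adjust to your deployment. Not part of this PR.
os.environ['MINIO_API_URL'] = 'http://localhost:9000'

s3 = boto3.client(
    's3',
    endpoint_url=os.environ.get('MINIO_API_URL'),  # None when unset, i.e. plain AWS S3
    aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
    aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
)
print(s3.list_objects_v2(Bucket='example-bucket', MaxKeys=1))  # hypothetical bucket name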
63 changes: 62 additions & 1 deletion ai_ta_backend/database/sql.py
@@ -1,4 +1,5 @@
import os
from typing import Dict

from injector import inject

@@ -158,4 +159,64 @@ def getPreAssignedAPIKeys(self, email: str):
  return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute()

def getConversationsCreatedAtByCourse(self, course_name: str):
  return self.supabase_client.table("llm-convo-monitor").select("created_at").eq("course_name", course_name).execute()
  try:
    count_response = self.supabase_client.table("llm-convo-monitor")\
        .select("created_at", count="exact")\
        .eq("course_name", course_name)\
        .execute()

    # count may be missing or None depending on the client version; default to 0.
    total_count = (count_response.count or 0) if hasattr(count_response, 'count') else 0

    if total_count <= 0:
      print(f"No conversations found for course: {course_name}")
      return [], 0

    all_data = []
    batch_size = 1000
    start = 0

    while start < total_count:
      end = min(start + batch_size - 1, total_count - 1)

      try:
        response = self.supabase_client.table("llm-convo-monitor")\
            .select("created_at")\
            .eq("course_name", course_name)\
            .range(start, end)\
            .execute()

        if not response or not hasattr(response, 'data') or not response.data:
          print(f"No data returned for range {start} to {end}.")
          break

        all_data.extend(response.data)
        start += batch_size

      except Exception as batch_error:
        print(f"Error fetching batch {start}-{end}: {str(batch_error)}")
        # Advance past the failing batch so a persistent error cannot loop forever.
        start += batch_size
        continue

    if not all_data:
      print(f"No conversation data could be retrieved for course: {course_name}")
      return [], 0

    return all_data, len(all_data)

  except Exception as e:
    print(f"Error in getConversationsCreatedAtByCourse for {course_name}: {str(e)}")
    return [], 0

def getProjectStats(self, project_name: str):
  try:
    response = self.supabase_client.table("project_stats").select("total_messages, total_conversations, unique_users")\
        .eq("project_name", project_name).execute()

    if response and hasattr(response, 'data') and response.data:
      return response.data[0]
  except Exception as e:
    print(f"Error fetching project stats for {project_name}: {str(e)}")

  # Return default values if anything fails
  return {"total_messages": 0, "total_conversations": 0, "unique_users": 0}


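For illustration, a minimal usage sketch of the two helpers added above. The SQLDatabase instance and the course/project names are hypothetical; both calls are written to fall back to empty or default values rather than raise.

# Hypothetical caller; `sql_db` is an already-constructed SQLDatabase instance.
rows, count = sql_db.getConversationsCreatedAtByCourse("example-course")  # ([], 0) when nothing is found
stats = sql_db.getProjectStats("example-project")                         # all-zero defaults on failure

print(f"{count} conversations fetched")
print(stats["total_messages"], stats["total_conversations"], stats["unique_users"])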
1 change: 1 addition & 0 deletions ai_ta_backend/database/vector.py
@@ -22,6 +22,7 @@ def __init__(self):
  self.qdrant_client = QdrantClient(
      url=os.environ['QDRANT_URL'],
      api_key=os.environ['QDRANT_API_KEY'],
      port=int(os.environ['QDRANT_PORT']) if os.getenv('QDRANT_PORT') else None,  # env vars are strings; the client expects an int port
      timeout=20,  # default is 5 seconds. Getting timeout errors w/ document groups.
  )

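A sketch of the environment this change assumes (all values hypothetical): QDRANT_PORT is optional, so a self-hosted Qdrant instance on a non-default port can set it, while deployments that omit it keep passing port=None and rely on QDRANT_URL alone.

import os

# Hypothetical settings for a self-hosted Qdrant instance; omit QDRANT_PORT to leave port=None.
os.environ['QDRANT_URL'] = 'https://qdrant.example.com'
os.environ['QDRANT_API_KEY'] = 'example-key'
os.environ['QDRANT_PORT'] = '6333'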
29 changes: 14 additions & 15 deletions ai_ta_backend/main.py
@@ -533,21 +533,6 @@ def get_conversation_stats(service: RetrievalService) -> Response:
  response.headers.add('Access-Control-Allow-Origin', '*')
  return response


@app.route('/getConversationHeatmapByHour', methods=['GET'])
def get_questions_heatmap_by_hour(service: RetrievalService) -> Response:
  course_name = request.args.get('course_name', default='', type=str)

  if not course_name:
    abort(400, description="Missing required parameter: 'course_name' must be provided.")

  heatmap_data = service.getConversationHeatmapByHour(course_name)

  response = jsonify(heatmap_data)
  response.headers.add('Access-Control-Allow-Origin', '*')
  return response


@app.route('/run_flow', methods=['POST'])
def run_flow(service: WorkflowService) -> Response:
  """
@@ -606,6 +591,20 @@ def createProject(service: ProjectService, flaskExecutor: ExecutorInterface) ->
  return response


@app.route('/getProjectStats', methods=['GET'])
def get_project_stats(service: RetrievalService) -> Response:
  project_name = request.args.get('project_name', default='', type=str)

  if project_name == '':
    abort(400, description="Missing required parameter: 'project_name' must be provided.")

  project_stats = service.getProjectStats(project_name)

  response = jsonify(project_stats)
  response.headers.add('Access-Control-Allow-Origin', '*')
  return response


def configure(binder: Binder) -> None:
  binder.bind(ThreadPoolExecutorInterface, to=ThreadPoolExecutorAdapter(max_workers=10), scope=SingletonScope)
  binder.bind(ProcessPoolExecutorInterface, to=ProcessPoolExecutorAdapter(max_workers=10), scope=SingletonScope)
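For reference, a hedged client-side sketch of calling the new endpoint. The base URL and project name are hypothetical; the JSON body mirrors what the SQL layer's getProjectStats returns.

import requests

# Hypothetical local dev base URL.
resp = requests.get(
    "http://localhost:8000/getProjectStats",
    params={"project_name": "example-project"},
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # e.g. {"total_messages": 0, "total_conversations": 0, "unique_users": 0}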
144 changes: 73 additions & 71 deletions ai_ta_backend/service/retrieval_service.py
@@ -18,7 +18,7 @@
from ai_ta_backend.database.sql import SQLDatabase
from ai_ta_backend.database.vector import VectorDatabase
from ai_ta_backend.executors.thread_pool_executor import ThreadPoolExecutorAdapter
from ai_ta_backend.service.nomic_service import NomicService
# from ai_ta_backend.service.nomic_service import NomicService
from ai_ta_backend.service.posthog_service import PosthogService
from ai_ta_backend.service.sentry_service import SentryService

@@ -30,13 +30,13 @@ class RetrievalService:

  @inject
  def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, posthog: PosthogService,
               sentry: SentryService, nomicService: NomicService, thread_pool_executor: ThreadPoolExecutorAdapter):
               sentry: SentryService, thread_pool_executor: ThreadPoolExecutorAdapter):  # nomicService: NomicService,
    self.vdb = vdb
    self.sqlDb = sqlDb
    self.aws = aws
    self.sentry = sentry
    self.posthog = posthog
    # self.nomicService = nomicService
    self.thread_pool_executor = thread_pool_executor
    openai.api_key = os.environ["VLADS_OPENAI_KEY"]

@@ -526,79 +526,81 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
def getConversationStats(self, course_name: str):
  """
  Fetches conversation data from the database and groups them by day, hour, and weekday.

  Args:
      course_name (str)

  Returns:
      dict: Aggregated conversation counts:
          - 'per_day': By date (YYYY-MM-DD).
          - 'per_hour': By hour (0-23).
          - 'per_weekday': By weekday (Monday-Sunday).
  """
  response = self.sqlDb.getConversationsCreatedAtByCourse(course_name)

  central_tz = pytz.timezone('America/Chicago')

  grouped_data = {
      'per_day': defaultdict(int),
      'per_hour': defaultdict(int),
      'per_weekday': defaultdict(int),
  }

  if response and hasattr(response, 'data') and response.data:
    for record in response.data:
      created_at = record['created_at']

      parsed_date = parser.parse(created_at)

      central_time = parsed_date.astimezone(central_tz)

      day = central_time.date()
      hour = central_time.hour
      day_of_week = central_time.strftime('%A')

      grouped_data['per_day'][str(day)] += 1
      grouped_data['per_hour'][hour] += 1
      grouped_data['per_weekday'][day_of_week] += 1
  else:
    print("No valid response data. Check if the query is correct or if the response is empty.")
    return {}

  return {
      'per_day': dict(grouped_data['per_day']),
      'per_hour': dict(grouped_data['per_hour']),
      'per_weekday': dict(grouped_data['per_weekday']),
  }
  try:
    conversations, total_count = self.sqlDb.getConversationsCreatedAtByCourse(course_name)

    # Initialize with empty data (all zeros)
    response_data = {
        'per_day': {},
        'per_hour': {str(hour): 0 for hour in range(24)},  # Convert hour to string for consistency
        'per_weekday': {day: 0 for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']},
        'heatmap': {day: {str(hour): 0 for hour in range(24)} for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']},
        'total_count': 0
    }

    if not conversations:
      return response_data

    central_tz = pytz.timezone('America/Chicago')
    grouped_data = {
        'per_day': defaultdict(int),
        'per_hour': defaultdict(int),
        'per_weekday': defaultdict(int),
        'heatmap': defaultdict(lambda: defaultdict(int)),
    }

    for record in conversations:
      try:
        created_at = record['created_at']
        parsed_date = parser.parse(created_at).astimezone(central_tz)

        day = parsed_date.date()
        hour = parsed_date.hour
        day_of_week = parsed_date.strftime('%A')

        grouped_data['per_day'][str(day)] += 1
        grouped_data['per_hour'][str(hour)] += 1  # Convert hour to string
        grouped_data['per_weekday'][day_of_week] += 1
        grouped_data['heatmap'][day_of_week][str(hour)] += 1  # Convert hour to string
      except Exception as e:
        print(f"Error processing record: {str(e)}")
        continue

    return {
        'per_day': dict(grouped_data['per_day']),
        'per_hour': {str(k): v for k, v in grouped_data['per_hour'].items()},
        'per_weekday': dict(grouped_data['per_weekday']),
        'heatmap': {day: {str(h): count for h, count in hours.items()}
                    for day, hours in grouped_data['heatmap'].items()},
        'total_count': total_count
    }

def getConversationHeatmapByHour(self, course_name: str):
  except Exception as e:
    print(f"Error in getConversationStats for course {course_name}: {str(e)}")
    self.sentry.capture_exception(e)
    # Return empty data structure on error
    return {
        'per_day': {},
        'per_hour': {str(hour): 0 for hour in range(24)},
        'per_weekday': {day: 0 for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']},
        'heatmap': {day: {str(hour): 0 for hour in range(24)}
                    for day in ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']},
        'total_count': 0
    }

def getProjectStats(self, project_name: str) -> Dict[str, int]:
  """
  Fetches conversation data and groups them into a heatmap by day of the week and hour (Central Time).
  Get statistics for a project.

  Args:
      course_name (str)
      project_name (str)

  Returns:
      dict: A nested dictionary with days of the week as outer keys and hours (0-23) as inner keys, where values are conversation counts.
      Dict[str, int]: Dictionary containing:
          - total_messages: Total number of messages
          - total_conversations: Total number of conversations
          - unique_users: Number of unique users
  """
  response = self.sqlDb.getConversationsCreatedAtByCourse(course_name)
  central_tz = pytz.timezone('America/Chicago')

  heatmap_data = defaultdict(lambda: defaultdict(int))

  if response and hasattr(response, 'data') and response.data:
    for record in response.data:
      created_at = record['created_at']

      parsed_date = parser.parse(created_at)
      central_time = parsed_date.astimezone(central_tz)

      day_of_week = central_time.strftime('%A')
      hour = central_time.hour

      heatmap_data[day_of_week][hour] += 1
  else:
    print("No valid response data. Check if the query is correct or if the response is empty.")
    return {}

  return dict(heatmap_data)
  return self.sqlDb.getProjectStats(project_name)

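For illustration, the shape of the payload the rewritten getConversationStats produces (values are hypothetical; hour keys are strings, as built above).

# Hypothetical example of the structure returned by getConversationStats after this change.
example_stats = {
    'per_day': {'2024-05-01': 12},
    'per_hour': {'13': 5, '14': 7},
    'per_weekday': {'Wednesday': 12},
    'heatmap': {'Wednesday': {'13': 5, '14': 7}},  # weekday -> hour -> count
    'total_count': 12,
}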