Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion oasst-data/oasst_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from oasst_data.analytics import analyze_conversation_depth, get_depth_summary
from oasst_data.reader import (
read_dataset_message_trees,
read_dataset_messages,
Expand All @@ -7,6 +8,7 @@
read_messages,
)
from oasst_data.schemas import (
DepthAnalysis,
ExportMessageEvent,
ExportMessageEventEmoji,
ExportMessageEventRanking,
Expand All @@ -18,12 +20,13 @@
LabelAvgValue,
LabelValues,
)
from oasst_data.traversal import visit_messages_depth_first, visit_threads_depth_first
from oasst_data.traversal import calculate_tree_depth, visit_messages_depth_first, visit_threads_depth_first
from oasst_data.writer import write_message_trees, write_messages

__all__ = [
"LabelAvgValue",
"LabelValues",
"DepthAnalysis",
"ExportMessageEvent",
"ExportMessageEventEmoji",
"ExportMessageEventRating",
Expand All @@ -38,6 +41,9 @@
"read_message_list",
"visit_threads_depth_first",
"visit_messages_depth_first",
"calculate_tree_depth",
"analyze_conversation_depth",
"get_depth_summary",
"write_message_trees",
"write_messages",
"read_dataset_message_trees",
Expand Down
77 changes: 77 additions & 0 deletions oasst-data/oasst_data/analytics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Conversation analytics module for analyzing message tree metrics.

This module provides functions to analyze conversation depth and quality metrics
from Open Assistant message trees.
"""

from typing import Optional

from .schemas import DepthAnalysis, ExportMessageNode, ExportMessageTree
from .traversal import calculate_tree_depth


def analyze_conversation_depth(tree: ExportMessageTree) -> Optional[DepthAnalysis]:
    """
    Analyze the depth metrics of a conversation tree.

    Args:
        tree: The conversation tree to analyze

    Returns:
        DepthAnalysis holding the maximum depth, the mean depth over all leaf
        nodes and the total number of messages, or None if the tree has no
        prompt.
    """
    if not tree.prompt:
        return None

    depth_result = calculate_tree_depth(tree.prompt)
    leaf_depths = depth_result["leaf_depths"]

    # The average is taken over leaves (conversation end points), so the
    # divisor is the number of leaves -- not the number of messages.
    average_depth = sum(leaf_depths) / len(leaf_depths) if leaf_depths else 0.0

    return DepthAnalysis(
        # calculate_tree_depth reports this value under "maximum_depth";
        # coerce to int because the schema field is declared as an int.
        max_depth=int(depth_result["maximum_depth"]),
        average_depth=average_depth,
        total_messages=depth_result["total_messages"],
    )


def get_depth_summary(trees: list[ExportMessageTree]) -> dict:
    """
    Aggregate depth statistics over a collection of conversation trees.

    Trees for which analyze_conversation_depth yields no result are left out
    of every figure.

    Args:
        trees: Conversation trees to summarize

    Returns:
        Dictionary with the keys "total_trees", "avg_max_depth",
        "avg_average_depth" and "total_messages"; all values are zero when no
        tree produced an analysis
    """
    results = [found for tree in trees if (found := analyze_conversation_depth(tree))]
    tree_count = len(results)

    if tree_count == 0:
        return {
            "total_trees": 0,
            "avg_max_depth": 0.0,
            "avg_average_depth": 0.0,
            "total_messages": 0,
        }

    max_depth_sum = sum(item.max_depth for item in results)
    average_depth_sum = sum(item.average_depth for item in results)
    message_sum = sum(item.total_messages for item in results)

    return {
        "total_trees": tree_count,
        "avg_max_depth": max_depth_sum / tree_count,
        "avg_average_depth": average_depth_sum / tree_count,
        "total_messages": message_sum,
    }

8 changes: 8 additions & 0 deletions oasst-data/oasst_data/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,11 @@ class ExportMessageTree(BaseModel):
tree_state: Optional[str]
prompt: Optional[ExportMessageNode]
origin: Optional[str]


class DepthAnalysis(BaseModel):
    """Analysis results for conversation tree depth metrics."""

    # Maximum depth of the conversation tree (deepest level reached by any message).
    max_depth: int
    # Average depth across all leaf nodes, i.e. messages that have no replies.
    average_depth: float
    # Total number of messages in the tree.
    total_messages: int
36 changes: 36 additions & 0 deletions oasst-data/oasst_data/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,39 @@ def visit_messages_depth_first(
if node.replies:
for c in node.replies:
visit_messages_depth_first(node=c, visitor=visitor, predicate=predicate)


def calculate_tree_depth(node: ExportMessageNode, current_depth: int = 0) -> dict:
    """
    Calculate depth metrics for a conversation tree.

    The tree is walked depth-first with an explicit stack, so very long
    threads cannot exhaust the interpreter's recursion limit.

    Args:
        node: The root node of the conversation tree
        current_depth: Depth assigned to ``node``; its replies are one level deeper

    Returns:
        Dictionary containing:
            maximum_depth: largest leaf depth as an int (``current_depth`` if
                no leaf was found, 0 if ``node`` is empty)
            leaf_depths: depth of every leaf, in depth-first left-to-right order
            total_messages: number of nodes visited
    """
    if not node:
        return {"maximum_depth": 0, "leaf_depths": [], "total_messages": 0}

    leaf_depths = []
    total = 0
    pending = [(node, current_depth)]

    while pending:
        current, depth = pending.pop()
        if not current:
            # An empty reply slot contributes neither a message nor a leaf.
            continue

        total += 1
        replies = current.replies
        if replies:
            # Push in reverse so the first reply is processed first, keeping
            # leaf_depths in left-to-right order.
            pending.extend((reply, depth + 1) for reply in reversed(replies))
        else:
            # No replies (None or empty): this message ends a thread.
            leaf_depths.append(depth)

    max_depth = max(leaf_depths) if leaf_depths else current_depth

    return {
        # Depth is a whole number of levels; keep it an int throughout.
        "maximum_depth": max_depth,
        "leaf_depths": leaf_depths,
        "total_messages": total,
    }