diff --git a/oasst-data/oasst_data/__init__.py b/oasst-data/oasst_data/__init__.py
index a2b8bc05c3..097cdd8c98 100644
--- a/oasst-data/oasst_data/__init__.py
+++ b/oasst-data/oasst_data/__init__.py
@@ -1,3 +1,4 @@
+from oasst_data.analytics import analyze_conversation_depth, get_depth_summary
 from oasst_data.reader import (
     read_dataset_message_trees,
     read_dataset_messages,
@@ -7,6 +8,7 @@
     read_messages,
 )
 from oasst_data.schemas import (
+    DepthAnalysis,
     ExportMessageEvent,
     ExportMessageEventEmoji,
     ExportMessageEventRanking,
@@ -18,12 +20,13 @@
     LabelAvgValue,
     LabelValues,
 )
-from oasst_data.traversal import visit_messages_depth_first, visit_threads_depth_first
+from oasst_data.traversal import calculate_tree_depth, visit_messages_depth_first, visit_threads_depth_first
 from oasst_data.writer import write_message_trees, write_messages
 
 __all__ = [
     "LabelAvgValue",
     "LabelValues",
+    "DepthAnalysis",
     "ExportMessageEvent",
     "ExportMessageEventEmoji",
     "ExportMessageEventRating",
@@ -38,6 +41,9 @@
     "read_message_list",
     "visit_threads_depth_first",
     "visit_messages_depth_first",
+    "calculate_tree_depth",
+    "analyze_conversation_depth",
+    "get_depth_summary",
     "write_message_trees",
     "write_messages",
     "read_dataset_message_trees",
diff --git a/oasst-data/oasst_data/analytics.py b/oasst-data/oasst_data/analytics.py
new file mode 100644
index 0000000000..c0045a8ae9
--- /dev/null
+++ b/oasst-data/oasst_data/analytics.py
@@ -0,0 +1,72 @@
+"""
+Conversation analytics module for analyzing message tree metrics.
+
+This module provides functions to analyze conversation depth metrics
+from Open Assistant message trees.
+"""
+
+from typing import Optional
+
+from .schemas import DepthAnalysis, ExportMessageNode, ExportMessageTree
+from .traversal import calculate_tree_depth
+
+
+def analyze_conversation_depth(tree: ExportMessageTree) -> Optional[DepthAnalysis]:
+    """
+    Analyze the depth metrics of a conversation tree.
+
+    Depth is counted in replies below the prompt, so the prompt itself is at depth 0.
+
+    Args:
+        tree: The conversation tree to analyze
+
+    Returns:
+        DepthAnalysis with the maximum depth, the mean depth of the leaf messages
+        and the total message count, or None if the tree has no prompt
+    """
+    if tree.prompt is None:
+        return None
+
+    depth_result = calculate_tree_depth(tree.prompt)
+    leaf_depths = depth_result["leaf_depths"]
+
+    # Average over the leaves only: every term of the sum belongs to a leaf, so
+    # dividing by the total message count would understate the mean.
+    average_depth = sum(leaf_depths) / len(leaf_depths) if leaf_depths else 0.0
+
+    return DepthAnalysis(
+        max_depth=depth_result["maximum_depth"],
+        average_depth=average_depth,
+        total_messages=depth_result["total_messages"],
+    )
+
+
+def get_depth_summary(trees: list[ExportMessageTree]) -> dict:
+    """
+    Get a summary of depth metrics across multiple conversation trees.
+
+    Trees without a prompt are skipped and do not count towards any statistic.
+
+    Args:
+        trees: List of conversation trees to analyze
+
+    Returns:
+        Dictionary with the keys total_trees, avg_max_depth, avg_average_depth
+        and total_messages
+    """
+    analyses = [a for a in map(analyze_conversation_depth, trees) if a is not None]
+
+    if not analyses:
+        return {
+            "total_trees": 0,
+            "avg_max_depth": 0.0,
+            "avg_average_depth": 0.0,
+            "total_messages": 0,
+        }
+
+    return {
+        "total_trees": len(analyses),
+        "avg_max_depth": sum(a.max_depth for a in analyses) / len(analyses),
+        "avg_average_depth": sum(a.average_depth for a in analyses) / len(analyses),
+        "total_messages": sum(a.total_messages for a in analyses),
+    }
diff --git a/oasst-data/oasst_data/schemas.py b/oasst-data/oasst_data/schemas.py
index 30de632b28..2bb2ef774f 100644
--- a/oasst-data/oasst_data/schemas.py
+++ b/oasst-data/oasst_data/schemas.py
@@ -93,3 +93,11 @@ class ExportMessageTree(BaseModel):
     tree_state: Optional[str]
     prompt: Optional[ExportMessageNode]
     origin: Optional[str]
+
+
+class DepthAnalysis(BaseModel):
+    """Analysis results for conversation tree depth metrics."""
+
+    max_depth: int  # Maximum depth of the conversation tree
+    average_depth: float  # Average depth across all leaf nodes
+    total_messages: int  # Total number of messages in the tree
diff --git a/oasst-data/oasst_data/traversal.py b/oasst-data/oasst_data/traversal.py
index e830815bc1..af4ef19fdf 100644
--- a/oasst-data/oasst_data/traversal.py
+++ b/oasst-data/oasst_data/traversal.py
@@ -33,3 +33,37 @@ def visit_messages_depth_first(
     if node.replies:
         for c in node.replies:
             visit_messages_depth_first(node=c, visitor=visitor, predicate=predicate)
+
+
+def calculate_tree_depth(node: ExportMessageNode, current_depth: int = 0) -> dict:
+    """
+    Calculate depth metrics for a conversation tree.
+
+    Args:
+        node: The root node of the conversation tree
+        current_depth: Depth assigned to node (used for recursion)
+
+    Returns:
+        Dictionary containing maximum_depth (int), leaf_depths (the depth of every
+        message without replies) and total_messages (messages in the subtree)
+    """
+    if not node:
+        return {"maximum_depth": 0, "leaf_depths": [], "total_messages": 0}
+
+    if not node.replies:
+        # A message without replies is a leaf at the current depth
+        return {"maximum_depth": current_depth, "leaf_depths": [current_depth], "total_messages": 1}
+
+    total = 1
+    leaf_depths = []
+    for reply in node.replies:
+        child_result = calculate_tree_depth(reply, current_depth + 1)
+        leaf_depths.extend(child_result["leaf_depths"])
+        total += child_result["total_messages"]
+
+    return {
+        # Kept as int: depths are whole numbers and DepthAnalysis.max_depth is an int
+        "maximum_depth": max(leaf_depths, default=current_depth),
+        "leaf_depths": leaf_depths,
+        "total_messages": total,
+    }