Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion oasst-data/oasst_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
LabelValues,
)
from oasst_data.traversal import visit_messages_depth_first, visit_threads_depth_first
from oasst_data.writer import write_message_trees, write_messages
from oasst_data.validator import (
ValidationError,
ValidationResult,
validate_and_log,
validate_message_tree,
validate_trees_batch,
)
from oasst_data.writer import write_message_trees, write_messages, write_validated_trees

__all__ = [
"LabelAvgValue",
Expand All @@ -40,6 +47,12 @@
"visit_messages_depth_first",
"write_message_trees",
"write_messages",
"write_validated_trees",
"validate_message_tree",
"validate_trees_batch",
"validate_and_log",
"ValidationResult",
"ValidationError",
"read_dataset_message_trees",
"read_dataset_messages",
]
175 changes: 175 additions & 0 deletions oasst-data/oasst_data/validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
Message validation module for ensuring data quality before export.

This module provides validation functions to check message trees
for common issues before exporting to files.
"""

from pathlib import Path
from typing import Optional

from .schemas import ExportMessageNode, ExportMessageTree


class ValidationError(Exception):
    """Exception signalling that validation did not pass.

    Attributes are documented inline; the exception message itself is
    whatever text is passed as ``message``.
    """

    def __init__(self, message: str, node_id: Optional[str] = None):
        # Let the base Exception carry the human-readable message.
        super().__init__(message)
        # Identifier of the offending node, or None when not supplied.
        self.node_id = node_id


class ValidationResult:
    """Accumulates the outcome of a validation run.

    Holds the collected error messages, warning messages and a counter of
    how many messages were validated.
    """

    def __init__(self):
        # Number of messages that have been checked so far.
        self.validated_count = 0
        # Non-fatal findings; these do not affect `is_valid`.
        self.warnings: list[str] = []
        # Fatal findings; any entry here makes the result invalid.
        self.errors: list[str] = []

    def add_warning(self, message: str) -> None:
        """Record a warning message."""
        self.warnings.append(message)

    def add_error(self, message: str) -> None:
        """Record an error message."""
        self.errors.append(message)

    @property
    def is_valid(self) -> bool:
        """True while no errors have been recorded (warnings are ignored)."""
        return not self.errors


def _check_message_text(message: ExportMessageNode, max_length: int = 10000) -> list[str]:
    """Return a list of human-readable problems found in a message's text.

    Two checks are performed, and their findings are reported in this order:

    1. The text is longer than ``max_length`` (a text of exactly
       ``max_length`` characters is accepted).
    2. The text is empty or consists only of whitespace.

    An empty list means no problem was found.
    """
    body = message.text
    label = message.message_id

    # Each entry pairs "did this check fail?" with the message to report.
    checks = (
        (len(body) > max_length, f"Message {label}: text exceeds max length of {max_length}"),
        (not body.strip(), f"Message {label}: text is empty or whitespace only"),
    )
    return [description for failed, description in checks if failed]


def _validate_message_node(
    node: ExportMessageNode,
    result: ValidationResult,
    max_text_length: int,
    depth: int = 0,
    max_depth: int = 100,
) -> None:
    """Validate ``node`` and then, depth-first, every reply beneath it.

    Findings are accumulated into ``result`` in place; nothing is returned.

    Args:
        node: The message node to check
        result: Accumulator that receives errors, warnings and the count
        max_text_length: Maximum allowed text length per message
        depth: Depth of ``node`` relative to the starting node
        max_depth: Deepest level that is still validated
    """
    # Guard: nodes below max_depth are not checked or counted; only a
    # warning is recorded for each such node that is reached.
    if depth > max_depth:
        result.add_warning(f"Max depth {max_depth} exceeded, skipping deeper nodes")
        return

    # Text problems of this node are all recorded as errors.
    for problem in _check_message_text(node, max_text_length):
        result.add_error(problem)
    result.validated_count += 1

    # `replies` may be None or empty; both mean there is nothing to descend into.
    child_depth = depth + 1
    for child in node.replies or ():
        _validate_message_node(child, result, max_text_length, child_depth, max_depth)


def validate_message_tree(
    tree: ExportMessageTree,
    max_text_length: int = 10000,
) -> ValidationResult:
    """
    Validate a message tree for common issues.

    Args:
        tree: The message tree to validate
        max_text_length: Maximum allowed text length per message

    Returns:
        ValidationResult containing any errors and warnings. If the tree has
        no prompt, the result holds a single error and no messages are
        counted as validated.
    """
    result = ValidationResult()

    prompt = tree.prompt
    if prompt is None:
        # Report a missing prompt as a validation error instead of letting
        # the attribute access below fail with AttributeError.
        result.add_error(f"Tree {tree.message_tree_id}: prompt is missing")
        return result

    # Tree-level check: the prompt itself must carry some text.
    if len(prompt.text) == 0:
        result.add_error(f"Tree {tree.message_tree_id}: prompt has no text")

    # Message-level checks for the prompt and all of its replies.
    _validate_message_node(prompt, result, max_text_length)

    return result


def validate_and_log(
    tree: ExportMessageTree,
    log_file_path: Optional[str] = None,
    max_text_length: int = 10000,
) -> ValidationResult:
    """
    Validate a message tree and optionally log results to file.

    When ``log_file_path`` is given (and non-empty), a summary is appended
    to that file: the tree id, the number of messages checked, each error
    on its own line, and the number of warnings.

    Args:
        tree: The message tree to validate
        log_file_path: Optional path to log file for validation results
        max_text_length: Maximum allowed text length

    Returns:
        ValidationResult with validation outcome
    """
    result = validate_message_tree(tree, max_text_length)

    if log_file_path:
        # Assemble the report first so the file is held open only for the write.
        report = [
            f"Validated tree: {tree.message_tree_id}\n",
            f" Messages checked: {result.validated_count}\n",
        ]

        if result.errors:
            report.append(f" Errors: {len(result.errors)}\n")
            report.extend(f" - {error}\n" for error in result.errors)

        if result.warnings:
            report.append(f" Warnings: {len(result.warnings)}\n")

        # The context manager closes the handle even if writing raises.
        with Path(log_file_path).open("a", encoding="utf-8") as log_file:
            log_file.writelines(report)

    return result


def validate_trees_batch(
    trees: list[ExportMessageTree],
    max_text_length: int = 10000,
) -> tuple[list[ExportMessageTree], ValidationResult]:
    """
    Validate each tree and keep the ones that produced no errors.

    Every tree's counts, errors and warnings are merged into one combined
    result, whether or not the tree itself is kept.

    Args:
        trees: List of trees to validate
        max_text_length: Maximum allowed text length

    Returns:
        Tuple of (valid_trees, combined_result)
    """
    totals = ValidationResult()
    accepted = []

    for candidate in trees:
        outcome = validate_message_tree(candidate, max_text_length)

        # Keep the tree only if its own validation recorded no errors.
        if outcome.is_valid:
            accepted.append(candidate)

        # Fold this tree's findings into the running totals.
        totals.warnings.extend(outcome.warnings)
        totals.errors.extend(outcome.errors)
        totals.validated_count += outcome.validated_count

    return accepted, totals

45 changes: 44 additions & 1 deletion oasst-data/oasst_data/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Iterable, TextIO
from typing import Iterable, Optional, TextIO

from oasst_data.schemas import ExportMessageNode, ExportMessageTree
from oasst_data.validator import validate_message_tree


def default_serializer(obj):
Expand Down Expand Up @@ -65,3 +66,45 @@ def write_messages(
# write one message per line
for message in messages:
write_message(file, message, exclude_none)


def write_validated_trees(
    output_file_name: str | Path,
    trees: Iterable[ExportMessageTree],
    exclude_none: bool = False,
    skip_invalid: bool = True,
    max_text_length: Optional[int] = None,
) -> dict:
    """
    Validate message trees and write the valid ones to file.

    Every tree is validated before it is written; only trees without
    validation errors reach the output file.

    Args:
        output_file_name: Path to output file
        trees: Trees to write
        exclude_none: Whether to exclude None fields
        skip_invalid: If True, skip invalid trees; if False, raise on invalid
        max_text_length: Max text length for validation (None to use the
            validator's default). Any integer, including 0, is passed through.

    Returns:
        Dict with counts of written and skipped trees, under the keys
        "written" and "skipped"

    Raises:
        ValueError: If a tree is invalid and skip_invalid is False. Trees
            written before the failure are not removed by this function.
    """
    written = 0
    skipped = 0

    # Compare against None explicitly: a truthiness test would treat an
    # explicit limit of 0 as "not provided" and fall back to the default.
    validation_kwargs = {}
    if max_text_length is not None:
        validation_kwargs["max_text_length"] = max_text_length

    with open_jsonl_write(output_file_name) as file:
        for tree in trees:
            result = validate_message_tree(tree=tree, **validation_kwargs)

            if result.is_valid:
                write_tree(file, tree, exclude_none)
                written += 1
            elif skip_invalid:
                skipped += 1
            else:
                raise ValueError(f"Invalid tree {tree.message_tree_id}: {result.errors}")

    return {"written": written, "skipped": skipped}