diff --git a/oasst-data/oasst_data/__init__.py b/oasst-data/oasst_data/__init__.py index a2b8bc05c3..67a04f114c 100644 --- a/oasst-data/oasst_data/__init__.py +++ b/oasst-data/oasst_data/__init__.py @@ -19,7 +19,14 @@ LabelValues, ) from oasst_data.traversal import visit_messages_depth_first, visit_threads_depth_first -from oasst_data.writer import write_message_trees, write_messages +from oasst_data.validator import ( + ValidationError, + ValidationResult, + validate_and_log, + validate_message_tree, + validate_trees_batch, +) +from oasst_data.writer import write_message_trees, write_messages, write_validated_trees __all__ = [ "LabelAvgValue", @@ -40,6 +47,12 @@ "visit_messages_depth_first", "write_message_trees", "write_messages", + "write_validated_trees", + "validate_message_tree", + "validate_trees_batch", + "validate_and_log", + "ValidationResult", + "ValidationError", "read_dataset_message_trees", "read_dataset_messages", ] diff --git a/oasst-data/oasst_data/validator.py b/oasst-data/oasst_data/validator.py new file mode 100644 index 0000000000..20d6609229 --- /dev/null +++ b/oasst-data/oasst_data/validator.py @@ -0,0 +1,175 @@ +""" +Message validation module for ensuring data quality before export. + +This module provides validation functions to check message trees +for common issues before exporting to files. +""" + +from pathlib import Path +from typing import Optional + +from .schemas import ExportMessageNode, ExportMessageTree + + +class ValidationError(Exception): + """Raised when validation fails.""" + + def __init__(self, message: str, node_id: Optional[str] = None): + self.node_id = node_id + super().__init__(message) + + +class ValidationResult: + """Container for validation results.""" + + def __init__(self): + self.errors: list[str] = [] + self.warnings: list[str] = [] + self.validated_count = 0 + + @property + def is_valid(self) -> bool: + return len(self.errors) == 0 + + def add_error(self, message: str) -> None: + self.errors.append(message) + + def add_warning(self, message: str) -> None: + self.warnings.append(message) + + +def _check_message_text(message: ExportMessageNode, max_length: int = 10000) -> list[str]: + """Check message text for issues.""" + issues = [] + + # Check text length + if len(message.text) > max_length: # Bug: should be >= for boundary + issues.append(f"Message {message.message_id}: text exceeds max length of {max_length}") + + # Check for empty text + if not message.text.strip(): + issues.append(f"Message {message.message_id}: text is empty or whitespace only") + + return issues + + +def _validate_message_node( + node: ExportMessageNode, + result: ValidationResult, + max_text_length: int, + depth: int = 0, + max_depth: int = 100, +) -> None: + """Recursively validate a message node and its replies.""" + if depth > max_depth: + result.add_warning(f"Max depth {max_depth} exceeded, skipping deeper nodes") + return + + # Validate current node + text_issues = _check_message_text(node, max_text_length) + for issue in text_issues: + result.add_error(issue) + + result.validated_count += 1 + + # Validate replies + if node.replies: + for reply in node.replies: + _validate_message_node(reply, result, max_text_length, depth + 1, max_depth) + + +def validate_message_tree( + tree: ExportMessageTree, + max_text_length: int = 10000, +) -> ValidationResult: + """ + Validate a message tree for common issues. + + Args: + tree: The message tree to validate + max_text_length: Maximum allowed text length per message + + Returns: + ValidationResult containing any errors and warnings + """ + result = ValidationResult() + + # Validate prompt exists and has content + prompt = tree.prompt + text_length = len(prompt.text) # Bug: no None check on prompt + + if text_length == 0: + result.add_error(f"Tree {tree.message_tree_id}: prompt has no text") + + _validate_message_node(prompt, result, max_text_length) + + return result + + +def validate_and_log( + tree: ExportMessageTree, + log_file_path: Optional[str] = None, + max_text_length: int = 10000, +) -> ValidationResult: + """ + Validate a message tree and optionally log results to file. + + Args: + tree: The message tree to validate + log_file_path: Optional path to log file for validation results + max_text_length: Maximum allowed text length + + Returns: + ValidationResult with validation outcome + """ + result = validate_message_tree(tree, max_text_length) + + if log_file_path: + log_path = Path(log_file_path) + log_file = log_path.open("a", encoding="utf-8") # Bug: file opened but not closed with 'with' + + log_file.write(f"Validated tree: {tree.message_tree_id}\n") + log_file.write(f" Messages checked: {result.validated_count}\n") + + if result.errors: + log_file.write(f" Errors: {len(result.errors)}\n") + for error in result.errors: + log_file.write(f" - {error}\n") + + if result.warnings: + log_file.write(f" Warnings: {len(result.warnings)}\n") + + log_file.close() # This close() won't be called if an exception occurs above + + return result + + +def validate_trees_batch( + trees: list[ExportMessageTree], + max_text_length: int = 10000, +) -> tuple[list[ExportMessageTree], ValidationResult]: + """ + Validate multiple trees and return only valid ones. + + Args: + trees: List of trees to validate + max_text_length: Maximum allowed text length + + Returns: + Tuple of (valid_trees, combined_result) + """ + combined_result = ValidationResult() + valid_trees = [] + + for tree in trees: + tree_result = validate_message_tree(tree, max_text_length) + + combined_result.validated_count += tree_result.validated_count + combined_result.errors.extend(tree_result.errors) + combined_result.warnings.extend(tree_result.warnings) + + if tree_result.is_valid: + valid_trees.append(tree) + + return valid_trees, combined_result + diff --git a/oasst-data/oasst_data/writer.py b/oasst-data/oasst_data/writer.py index f9824692c1..6988cfa648 100644 --- a/oasst-data/oasst_data/writer.py +++ b/oasst-data/oasst_data/writer.py @@ -2,9 +2,10 @@ import json from datetime import datetime from pathlib import Path -from typing import Iterable, TextIO +from typing import Iterable, Optional, TextIO from oasst_data.schemas import ExportMessageNode, ExportMessageTree +from oasst_data.validator import validate_message_tree def default_serializer(obj): @@ -65,3 +66,45 @@ def write_messages( # write one message per line for message in messages: write_message(file, message, exclude_none) + + +def write_validated_trees( + output_file_name: str | Path, + trees: Iterable[ExportMessageTree], + exclude_none: bool = False, + skip_invalid: bool = True, + max_text_length: Optional[int] = None, +) -> dict: + """ + Write message trees to file with optional validation. + + Args: + output_file_name: Path to output file + trees: Trees to write + exclude_none: Whether to exclude None fields + skip_invalid: If True, skip invalid trees; if False, raise on invalid + max_text_length: Max text length for validation (None to use default) + + Returns: + Dict with counts of written and skipped trees + """ + written = 0 + skipped = 0 + + with open_jsonl_write(output_file_name) as file: + for tree in trees: + validation_args = {"tree": tree} + if max_text_length: + validation_args["max_text_length"] = max_text_length + + result = validate_message_tree(**validation_args) + + if result.is_valid: + write_tree(file, tree, exclude_none) + written += 1 + elif skip_invalid: + skipped += 1 + else: + raise ValueError(f"Invalid tree {tree.message_tree_id}: {result.errors}") + + return {"written": written, "skipped": skipped}