diff --git a/pyproject.toml b/pyproject.toml index 7b0dceb..1fa1020 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ dependencies = [ "datasets>=3.0.0", "tinker", "matplotlib>=3.8.0", + "dagshub", + "mlflow", ] [project.optional-dependencies] diff --git a/scripts/analysis_data.py b/scripts/analysis_data.py new file mode 100644 index 0000000..da642cc --- /dev/null +++ b/scripts/analysis_data.py @@ -0,0 +1,664 @@ +#!/usr/bin/env python3 +""" +Logical Compression Analysis +Analyzes symbol usage, logical coherence, and pattern quality in compression training data. +Based on propositional logic principles: scope, precedence, ambiguity detection. +""" + +import json +import re +import statistics +from collections import Counter, defaultdict +from datetime import datetime +from pathlib import Path + +# ============================================================================ +# SYMBOL DEFINITIONS +# ============================================================================ + +SYMBOLS = { + "→": "implication", # if-then, leads to, results in (\u2192) + "|": "separator", # separates facts/alternatives + "@": "location", # at, located at + "∵": "causation", # because, due to (\u2235) + ":": "assignment", # label, definition +} + +# Logical operators from propositional logic +LOGICAL_CONNECTIVES = { + "&": "conjunction", # and + "|": "disjunction", # or (if used logically) + "→": "implication", # if-then + "⇔": "biconditional", # if and only if + "¬": "negation", # not + "~": "negation_alt", # not (alternative) +} + + +# ============================================================================ +# PATTERN ANALYSIS FUNCTIONS +# ============================================================================ + + +def extract_verbose_compressed(sample: dict) -> tuple[str, str]: + """Extract input (verbose) and output (compressed) from message structure.""" + messages = sample.get("messages", []) + + verbose = "" + compressed = "" + + for msg in messages: + 
role = msg.get("role", "") + content = msg.get("content", "") + + if role == "user" and "Compress:" in content: + # Extract text after "Compress:" + verbose = content.split("Compress:", 1)[1].strip() + elif role == "assistant": + compressed = content.strip() + + return verbose, compressed + + +def tokenize_compression(compressed: str) -> list[str]: + """Break compression into tokens: symbols and text chunks.""" + tokens = [] + current = [] + + for char in compressed: + if char in SYMBOLS or char in LOGICAL_CONNECTIVES: + if current: + tokens.append("".join(current).strip()) + current = [] + tokens.append(char) + else: + current.append(char) + + if current: + tokens.append("".join(current).strip()) + + return [t for t in tokens if t] + + +def extract_symbol_sequence(compressed: str) -> str: + """Extract just the symbol sequence (X for content, symbols as-is).""" + result = [] + for char in compressed: + if char in SYMBOLS or char in LOGICAL_CONNECTIVES: + result.append(char) + elif result and result[-1] != "X": + result.append("X") + + return "".join(result) + + +def detect_scope_ambiguity(compressed: str) -> list[str]: + """ + Detect scope ambiguities based on propositional logic. + E.g., A → B | C could mean (A → B) | C or A → (B | C) + """ + issues = [] + + # Pattern: multiple implications with separator, no parentheses + # A → B | C → D is ambiguous + # Check if there are parentheses to disambiguate + if ( + compressed.count("→") > 1 + and ("|" in compressed or "&" in compressed) + and "(" not in compressed + ): + issues.append("SCOPE_AMBIGUITY: Multiple → with | or & but no grouping") + + # Pattern: A | B → C (which binds first?) 
+ if "|" in compressed and "→" in compressed: + parts = re.split(r"[()]", compressed) + for part in parts: + if "|" in part and "→" in part: + # Check order + pipe_idx = part.find("|") + arrow_idx = part.find("→") + if pipe_idx < arrow_idx: + issues.append("PRECEDENCE_UNCLEAR: | before → without grouping") + + return issues + + +def detect_orphaned_symbols(compressed: str) -> list[str]: + """Detect symbols at start/end or consecutive symbols.""" + issues = [] + + # Symbol at start + if compressed and compressed[0] in SYMBOLS: + issues.append(f"ORPHAN_START: {compressed[0]} at beginning") + + # Symbol at end + if compressed and compressed[-1] in SYMBOLS: + issues.append(f"ORPHAN_END: {compressed[-1]} at end") + + # Consecutive symbols + for i in range(len(compressed) - 1): + if compressed[i] in SYMBOLS and compressed[i + 1] in SYMBOLS: + issues.append(f"CONSECUTIVE: {compressed[i]}{compressed[i + 1]}") + break + + return issues + + +def analyze_symbol_context(verbose: str, compressed: str, symbol: str) -> dict: + """ + Analyze when a symbol is used vs when it could/should be used. 
+ """ + result = { + "symbol": symbol, + "used": symbol in compressed, + "context_present": False, + "context_keywords": [], + } + + verbose_lower = verbose.lower() + + if symbol == "→": + # Implication keywords + keywords = ["if", "then", "therefore", "thus", "hence", "leads to", "results in", "causes"] + for kw in keywords: + if kw in verbose_lower: + result["context_present"] = True + result["context_keywords"].append(kw) + + elif symbol == "∵": + # Causation keywords + keywords = ["because", "since", "due to", "caused by", "reason", "as a result"] + for kw in keywords: + if kw in verbose_lower: + result["context_present"] = True + result["context_keywords"].append(kw) + + elif symbol == "@": + # Location keywords + keywords = ["at", "in", "located", "place", "where", "location", "based"] + for kw in keywords: + if kw in verbose_lower: + result["context_present"] = True + result["context_keywords"].append(kw) + + elif symbol == "|": + # Separator - lists, alternatives + keywords = ["and", "or", ",", ";"] + for kw in keywords: + if kw in verbose_lower: + result["context_present"] = True + result["context_keywords"].append(kw) + + elif symbol == ":": + # Assignment/definition + keywords = ["is", "are", "means", "defined as", "represents"] + for kw in keywords: + if kw in verbose_lower: + result["context_present"] = True + result["context_keywords"].append(kw) + + return result + + +def extract_logical_pattern(compressed: str) -> str: + """ + Extract logical structure pattern. + Similar to propositional logic form but for compression. + """ + # Replace content with P, Q, R... 
but keep symbols + pattern = compressed + + # Remove parentheses content but keep structure + pattern = re.sub(r"\([^)]+\)", "(P)", pattern) + + # Replace text chunks with P + pattern = re.sub(r"[^→|@∵:()]+", "P", pattern) + + # Collapse consecutive P's + pattern = re.sub(r"P+", "P", pattern) + + return pattern.strip() + + +def check_negation_preservation(verbose: str, compressed: str) -> dict: + """Check if negations are preserved.""" + # Words safe for word-boundary matching + negation_words = ["not", "no", "never", "neither", "nor", "without", "none"] + + verbose_lower = verbose.lower() + compressed_lower = compressed.lower() + + # Whole-word negations + verbose_has_word_neg = any( + re.search(r"\b" + word + r"\b", verbose_lower) for word in negation_words + ) + compressed_has_word_neg = any( + re.search(r"\b" + word + r"\b", compressed_lower) for word in negation_words + ) + + # Contracted negation (handles don't / can’t / isn’t) + verbose_has_contraction = bool(re.search(r"n[’']t\b", verbose_lower)) + compressed_has_contraction = bool(re.search(r"n[’']t\b", compressed_lower)) + + # Negation symbols in compressed form + has_neg_symbol = "¬" in compressed or "~" in compressed or "!" 
in compressed + + verbose_has_negation = verbose_has_word_neg or verbose_has_contraction + compressed_has_negation = ( + compressed_has_word_neg or compressed_has_contraction or has_neg_symbol + ) + + return { + "verbose_has_negation": verbose_has_negation, + "compressed_has_negation": compressed_has_negation, + "negation_lost": verbose_has_negation and not compressed_has_negation, + } + + +def analyze_symbol_combinations(compressed: str) -> list[tuple[str, str]]: + """Extract symbol pairs (bigrams) to find common combinations.""" + symbols_only = [char for char in compressed if char in SYMBOLS] + bigrams = [] + for i in range(len(symbols_only) - 1): + bigrams.append((symbols_only[i], symbols_only[i + 1])) + return bigrams + + +# ============================================================================ +# MAIN ANALYSIS +# ============================================================================ + + +def analyze_dataset(data: list[dict]) -> dict: + """Run comprehensive logical analysis on compression data.""" + + results = { + "total_samples": len(data), + "symbol_usage": defaultdict(int), + "symbol_context_analysis": defaultdict( + lambda: { + "used_when_context_present": 0, + "not_used_when_context_present": 0, + "used_without_context": 0, + } + ), + "scope_ambiguities": [], + "orphaned_symbols": [], + "logical_patterns": Counter(), + "symbol_combinations": Counter(), + "negation_analysis": { + "total_with_negation": 0, + "negation_preserved": 0, + "negation_lost": 0, + }, + "compression_ratios": [], + "problematic_samples": [], + "good_samples": [], + } + + for idx, sample in enumerate(data): + verbose, compressed = extract_verbose_compressed(sample) + + if not verbose or not compressed: + continue + + # Compression ratio + v_tokens = len(verbose.split()) + c_tokens = len(compressed.split()) + ratio = v_tokens / c_tokens if c_tokens > 0 else 0 + results["compression_ratios"].append(ratio) + + # Symbol usage + for symbol in SYMBOLS: + if symbol in 
compressed: + results["symbol_usage"][symbol] += 1 + + # Symbol context analysis + for symbol in SYMBOLS: + ctx = analyze_symbol_context(verbose, compressed, symbol) + + if ctx["context_present"] and ctx["used"]: + results["symbol_context_analysis"][symbol]["used_when_context_present"] += 1 + elif ctx["context_present"] and not ctx["used"]: + results["symbol_context_analysis"][symbol]["not_used_when_context_present"] += 1 + elif not ctx["context_present"] and ctx["used"]: + results["symbol_context_analysis"][symbol]["used_without_context"] += 1 + + # Scope ambiguity detection + scope_issues = detect_scope_ambiguity(compressed) + if scope_issues: + results["scope_ambiguities"].append( + {"id": idx, "issues": scope_issues, "compressed": compressed[:150]} + ) + + # Orphaned symbols + orphan_issues = detect_orphaned_symbols(compressed) + if orphan_issues: + results["orphaned_symbols"].append( + {"id": idx, "issues": orphan_issues, "compressed": compressed[:150]} + ) + + # Logical patterns + pattern = extract_logical_pattern(compressed) + results["logical_patterns"][pattern] += 1 + + # Symbol combinations + bigrams = analyze_symbol_combinations(compressed) + for bg in bigrams: + results["symbol_combinations"][bg] += 1 + + # Negation analysis + neg_check = check_negation_preservation(verbose, compressed) + if neg_check["verbose_has_negation"]: + results["negation_analysis"]["total_with_negation"] += 1 + if neg_check["compressed_has_negation"]: + results["negation_analysis"]["negation_preserved"] += 1 + else: + results["negation_analysis"]["negation_lost"] += 1 + + # Flag problematic samples + if ratio < 1.0 or scope_issues or orphan_issues: + results["problematic_samples"].append( + { + "id": idx, + "ratio": ratio, + "scope_issues": scope_issues, + "orphan_issues": orphan_issues, + "verbose": verbose[:100], + "compressed": compressed[:100], + } + ) + + # Flag good samples + if ratio > 3.0 and not scope_issues and not orphan_issues: + results["good_samples"].append( + 
{"id": idx, "ratio": ratio, "compressed": compressed[:150]} + ) + + return results + + +def generate_report(results: dict) -> str: + """Generate analysis report.""" + + lines = [] + lines.append("=" * 80) + lines.append("LOGICAL COMPRESSION ANALYSIS REPORT") + lines.append("=" * 80) + lines.append("") + + # Dataset overview + lines.append("DATASET OVERVIEW") + lines.append("-" * 80) + lines.append(f"Total samples analyzed: {results['total_samples']}") + lines.append("") + + # Compression quality + lines.append("COMPRESSION QUALITY") + lines.append("-" * 80) + ratios = results["compression_ratios"] + if ratios: + lines.append(f"Mean ratio: {statistics.mean(ratios):.2f}x") + lines.append(f"Median ratio: {statistics.median(ratios):.2f}x") + lines.append(f"Min: {min(ratios):.2f}x | Max: {max(ratios):.2f}x") + lines.append( + f"Samples with ratio < 1.0: {sum(1 for r in ratios if r < 1.0)} (WORSE than input)" + ) + lines.append( + f"Samples with ratio > 3.0: {sum(1 for r in ratios if r > 3.0)} (GOOD compression)" + ) + lines.append("") + + # Symbol usage + lines.append("SYMBOL USAGE") + lines.append("-" * 80) + total = results["total_samples"] + for symbol, name in SYMBOLS.items(): + count = results["symbol_usage"][symbol] + pct = (count / total * 100) if total > 0 else 0 + lines.append(f"{symbol} ({name:12s}): {count:4d} / {total} ({pct:5.1f}%)") + lines.append("") + + # Symbol context analysis (KEY INSIGHT) + lines.append("SYMBOL CONTEXT ANALYSIS (When should symbol be used?)") + lines.append("-" * 80) + for symbol in SYMBOLS: + ctx = results["symbol_context_analysis"][symbol] + used_when_should = ctx["used_when_context_present"] + missed_when_should = ctx["not_used_when_context_present"] + used_wrongly = ctx["used_without_context"] + + total_opportunities = used_when_should + missed_when_should + if total_opportunities > 0: + accuracy = used_when_should / total_opportunities * 100 + lines.append(f"\n{symbol} ({SYMBOLS[symbol]}):") + lines.append( + f" Correctly used 
when context present: {used_when_should} / {total_opportunities} ({accuracy:.1f}%)" + ) + lines.append(f" Missed opportunities: {missed_when_should}") + lines.append(f" Used without clear context: {used_wrongly}") + lines.append("") + + # Symbol combinations (PATTERN DISCOVERY) + lines.append("SYMBOL COMBINATIONS (Bigrams)") + lines.append("-" * 80) + lines.append("Most common symbol sequences:") + for (s1, s2), count in results["symbol_combinations"].most_common(15): + lines.append(f" {s1}{s2} : {count:3d} times") + lines.append("") + + # Logical patterns + lines.append("TOP LOGICAL PATTERNS") + lines.append("-" * 80) + for pattern, count in results["logical_patterns"].most_common(20): + pct = (count / total * 100) if total > 0 else 0 + lines.append(f"{pattern:50s} : {count:3d} ({pct:4.1f}%)") + lines.append("") + + # Scope ambiguities (CRITICAL ISSUE) + lines.append("SCOPE AMBIGUITIES (Propositional Logic Issues)") + lines.append("-" * 80) + if results["scope_ambiguities"]: + lines.append(f"Total samples with scope ambiguity: {len(results['scope_ambiguities'])}") + lines.append("\nFirst 5 examples:") + for item in results["scope_ambiguities"][:5]: + lines.append(f"\n ID {item['id']}:") + for issue in item["issues"]: + lines.append(f" - {issue}") + lines.append(f" Sample: {item['compressed'][:100]}...") + else: + lines.append("No scope ambiguities detected.") + lines.append("") + + # Orphaned symbols + lines.append("ORPHANED SYMBOLS (Syntax Errors)") + lines.append("-" * 80) + if results["orphaned_symbols"]: + lines.append(f"Total samples with orphaned symbols: {len(results['orphaned_symbols'])}") + lines.append("\nFirst 5 examples:") + for item in results["orphaned_symbols"][:5]: + lines.append(f"\n ID {item['id']}:") + for issue in item["issues"]: + lines.append(f" - {issue}") + lines.append(f" Sample: {item['compressed'][:100]}...") + else: + lines.append("No orphaned symbols detected.") + lines.append("") + + # Negation analysis + lines.append("NEGATION 
PRESERVATION") + lines.append("-" * 80) + neg = results["negation_analysis"] + if neg["total_with_negation"] > 0: + preservation_rate = neg["negation_preserved"] / neg["total_with_negation"] * 100 + lines.append(f"Samples with negation in input: {neg['total_with_negation']}") + lines.append(f"Negation preserved: {neg['negation_preserved']} ({preservation_rate:.1f}%)") + lines.append(f"Negation LOST: {neg['negation_lost']} ({100 - preservation_rate:.1f}%)") + else: + lines.append("No negations detected in inputs.") + lines.append("") + + # Problematic samples + lines.append("PROBLEMATIC SAMPLES (First 10)") + lines.append("-" * 80) + for item in results["problematic_samples"][:10]: + lines.append(f"\nID {item['id']}: Ratio {item['ratio']:.2f}x") + if item["scope_issues"]: + lines.append(f" Scope issues: {', '.join(item['scope_issues'])}") + if item["orphan_issues"]: + lines.append(f" Orphan issues: {', '.join(item['orphan_issues'])}") + lines.append(f" Input: {item['verbose']}...") + lines.append(f" Output: {item['compressed']}...") + lines.append("") + + # Good samples + lines.append("GOOD SAMPLES (High quality compressions, first 5)") + lines.append("-" * 80) + for item in sorted(results["good_samples"], key=lambda x: x["ratio"], reverse=True)[:5]: + lines.append(f"\nID {item['id']}: Ratio {item['ratio']:.2f}x") + lines.append(f" {item['compressed']}") + lines.append("") + + # RECOMMENDATIONS + lines.append("RECOMMENDATIONS") + lines.append("=" * 80) + + rec_num = 1 + + # Symbol usage recommendations + for symbol in ["@", "∵"]: + ctx = results["symbol_context_analysis"][symbol] + total_opportunities = ( + ctx["used_when_context_present"] + ctx["not_used_when_context_present"] + ) + if total_opportunities > 0: + accuracy = ctx["used_when_context_present"] / total_opportunities * 100 + if accuracy < 50: + lines.append( + f"{rec_num}. UNDERUSED SYMBOL {symbol}: Only used {accuracy:.1f}% when context present." 
+ ) + lines.append( + f" Action: Add training examples explicitly using {symbol} or filter samples missing it." + ) + rec_num += 1 + + # Scope ambiguity + if len(results["scope_ambiguities"]) > total * 0.1: + lines.append( + f"{rec_num}. SCOPE AMBIGUITY: {len(results['scope_ambiguities'])} samples have unclear precedence." + ) + lines.append( + " Action: Define operator precedence (e.g., | binds tighter than →) or require parentheses." + ) + rec_num += 1 + + # Negation loss + if neg["total_with_negation"] > 0: + loss_rate = neg["negation_lost"] / neg["total_with_negation"] * 100 + if loss_rate > 30: + lines.append(f"{rec_num}. NEGATION LOST: {loss_rate:.1f}% of negations are dropped.") + lines.append(" Action: Add negation symbol (¬ or ~) to preserve meaning.") + rec_num += 1 + + # Bad compressions + bad_count = sum(1 for r in ratios if r < 1.0) + if bad_count > 0: + lines.append( + f"{rec_num}. FILTER BAD SAMPLES: {bad_count} samples are LONGER than input (ratio < 1.0)." + ) + lines.append(" Action: Remove these from training data.") + rec_num += 1 + + # Symbol combinations insight + top_combo = results["symbol_combinations"].most_common(1) + if top_combo: + (s1, s2), count = top_combo[0] + lines.append(f"{rec_num}. DOMINANT PATTERN: {s1}{s2} appears {count} times.") + lines.append( + " Insight: This is your model's most common structure. Ensure it's logically sound." 
+ ) + rec_num += 1 + + lines.append("") + lines.append("=" * 80) + lines.append("END OF REPORT") + lines.append("=" * 80) + + return "\n".join(lines) + + +# ============================================================================ +# MAIN +# ============================================================================ + + +def main(): + # EDIT THESE PATHS + INPUT_FILE = "data/training/train.jsonl" # <-- PUT YOUR FILE PATH HERE + OUTPUT_DIR = Path("reports/logical_analysis") + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + OUTPUT_REPORT = OUTPUT_DIR / f"logical_analysis_report_{timestamp}.txt" + OUTPUT_JSON = OUTPUT_DIR / f"logical_analysis_data_{timestamp}.json" + + print(f"Loading data from {INPUT_FILE}...") + + try: + with open(INPUT_FILE, encoding="utf-8") as f: + data = [json.loads(line) for line in f if line.strip()] + except FileNotFoundError: + print(f"ERROR: File '{INPUT_FILE}' not found.") + print("Please edit the INPUT_FILE path in the script.") + return + + print(f"Loaded {len(data)} samples.") + print("Running logical analysis...") + + results = analyze_dataset(data) + + print("Generating report...") + report = generate_report(results) + + # Print to console + print("\n" + report) + + # Save report + with open(OUTPUT_REPORT, "w", encoding="utf-8") as f: + f.write(report) + print(f"\nReport saved to {OUTPUT_REPORT}") + + # Save raw data + serializable = { + "total_samples": results["total_samples"], + "symbol_usage": dict(results["symbol_usage"]), + "symbol_context_analysis": { + k: dict(v) for k, v in results["symbol_context_analysis"].items() + }, + "compression_ratios": { + "mean": statistics.mean(results["compression_ratios"]) + if results["compression_ratios"] + else 0, + "median": statistics.median(results["compression_ratios"]) + if results["compression_ratios"] + else 0, + "min": min(results["compression_ratios"]) if results["compression_ratios"] else 0, + "max": 
max(results["compression_ratios"]) if results["compression_ratios"] else 0, + }, + "logical_patterns": dict(results["logical_patterns"].most_common(50)), + "symbol_combinations": { + f"{s1}{s2}": count for (s1, s2), count in results["symbol_combinations"].most_common(30) + }, + "negation_analysis": results["negation_analysis"], + "scope_ambiguities_count": len(results["scope_ambiguities"]), + "orphaned_symbols_count": len(results["orphaned_symbols"]), + "problematic_count": len(results["problematic_samples"]), + "good_samples_count": len(results["good_samples"]), + } + + with open(OUTPUT_JSON, "w", encoding="utf-8") as f: + json.dump(serializable, f, indent=2) + print(f"Raw data saved to {OUTPUT_JSON}") + + +if __name__ == "__main__": + main() diff --git a/scripts/data_sanitization.py b/scripts/data_sanitization.py new file mode 100644 index 0000000..bfad34c --- /dev/null +++ b/scripts/data_sanitization.py @@ -0,0 +1,837 @@ +#!/usr/bin/env python3 +""" +Data Sanitization for Compression Training + +SCOPE: This script is designed specifically for data/training/train.jsonl + with chat-message format (system/user/assistant roles). + +Usage: + # Default paths + python data_sanitization.py + + # Custom paths + python data_sanitization.py \ + --input data/training/train.jsonl \ + --sanitized data/training/sanitized_train.jsonl \ + --unsanitized data/training/unsanitized_train.jsonl + +CHANGES FROM V3: +- Rule B now code-aware: allows leading @ for code samples (decorators) +- Token-based compression ratio aligned with src/utils/tokenizers.py +- Explicit format validation with parse error logging +- Guards for unexpected message structures +- CLI arguments for flexible path configuration + +Extracts BOTH sanitized and unsanitized samples in one pass. +Unsanitized samples are saved for recovery analysis. 
+""" + +import argparse +import json +import re +import sys +from pathlib import Path + +# Import tokenizer for token-based ratios +sys.path.insert(0, str(Path(__file__).parent.parent)) +from src.utils.tokenizers import compression_ratio + +# ============================================================================ +# SYMBOL DEFINITIONS +# ============================================================================ + +SYMBOLS = {"→", "|", "@", "∵", ":"} + +# Strict keywords for natural language only +LOCATION_KEYWORDS_NL = [ + "located in", + "located at", + "based in", + "situated in", + "found in", + "positioned at", + "positioned in", + "place is", + "place was", + "city of", + "town of", + "on the shores of", + "near the", + "by the", +] + +CAUSATION_KEYWORDS_NL = [ + "because of", + "due to", + "caused by", + "as a result of", + "leads to", + "results in", + "led to", + "resulted in", + "owing to", + "on account of", + "thanks to", + "consequently", + "therefore", + "thus", +] + +# Code detection indicators +CODE_INDICATORS = [ + "def ", + "class ", + "import ", + "return ", + "yield ", + "async ", + "await ", + "self.", + "__init__", + "__", + "lambda ", + "isinstance(", + "raise ", + "@classmethod", + "@staticmethod", + "@property", + "function ", + "const ", + "let ", + "var ", + "=>", + "async function", + "fn:", + "->", + "fn(", + "void ", + "int ", + "string ", + "bool ", + "```", + "```python", + "```javascript", + "```java", + " def ", + " class ", +] + + +# ============================================================================ +# HELPER FUNCTIONS +# ============================================================================ + + +def is_code_sample(verbose: str) -> bool: + """ + Detect if sample is code-related. + + Uses multiple signals to avoid misclassifying brace-heavy text (JSON, etc.) as code. + Requires at least 2 strong signals OR 1 very strong signal. 
+ """ + verbose_lower = verbose.lower() + + # Very strong signals (definitive code indicators) + very_strong_indicators = [ + "def ", + "class ", + "function ", + "import ", + "return ", + "async ", + "await ", + "yield ", + "@property", + "@staticmethod", + "@classmethod", + "fn:", + "lambda ", + "isinstance(", + "raise ", + ] + + for indicator in very_strong_indicators: + if indicator.lower() in verbose_lower: + return True # Single very strong signal is enough + + # Strong signals (likely code) + strong_signals = 0 + + # Check for code-specific keywords + code_keywords = [ + "self.", + "__init__", + "const ", + "let ", + "var ", + "void ", + "int ", + "string ", + "bool ", + ] + for keyword in code_keywords: + if keyword.lower() in verbose_lower: + strong_signals += 1 + break # Count once + + # Check for indentation pattern (multiple indented lines) + lines = verbose.split("\n") + indented_lines = sum(1 for line in lines if line.startswith(" ") or line.startswith("\t")) + if indented_lines >= 3: # Increased threshold from 2 to 3 + strong_signals += 1 + + # Check for code block markers + if ( + "```python" in verbose_lower + or "```javascript" in verbose_lower + or "```java" in verbose_lower + ): + strong_signals += 1 + + # Tightened code patterns - require more context + code_patterns = [ + r"\bdef\s+\w+\s*\(", + r"\bclass\s+\w+\s*[\(:]", + r"\bfunction\s+\w+\s*\(", + r"\w+\s*=\s*function\s*\(", + # Tightened type annotation pattern - require multiple + r"(\w+\s*:\s*\w+.*){2,}", # At least 2 type annotations + ] + + for pattern in code_patterns: + if re.search(pattern, verbose): + strong_signals += 1 + break # Count once + + # Tightened {...} check - only count if it looks like actual code block + # Must have semicolons, return statements, or assignment inside braces + brace_code_pattern = r"\{[^}]*(;|return\s|=\s)[^}]*\}" + if re.search(brace_code_pattern, verbose): + strong_signals += 1 + + # Require at least 2 strong signals to classify as code + return 
strong_signals >= 2 + + +def extract_verbose_compressed( + sample: dict, sample_id: int +) -> tuple[str | None, str | None, str | None]: + """ + Extract input (verbose) and output (compressed) from chat message structure. + + EXPECTED FORMAT for data/training/train.jsonl: + { + "messages": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "Compress:\n"}, + {"role": "assistant", "content": ""} + ] + } + + Returns: + (verbose, compressed, error_message) + error_message is None if parsing succeeded + """ + # Validate top-level structure + if not isinstance(sample, dict): + return None, None, f"Sample {sample_id}: Not a dict" + + if "messages" not in sample: + return None, None, f"Sample {sample_id}: Missing 'messages' key" + + messages = sample.get("messages", []) + + if not isinstance(messages, list): + return None, None, f"Sample {sample_id}: 'messages' is not a list" + + if len(messages) < 2: + return ( + None, + None, + f"Sample {sample_id}: Expected at least 2 messages (user + assistant), got {len(messages)}", + ) + + verbose = "" + compressed = "" + found_user = False + found_assistant = False + + for msg in messages: + if not isinstance(msg, dict): + continue + + role = msg.get("role", "") + content = msg.get("content", "") + + if role == "user": + found_user = True + if "Compress:" in content: + verbose = content.split("Compress:", 1)[1].strip() + else: + return ( + None, + None, + f"Sample {sample_id}: User message doesn't contain 'Compress:' marker", + ) + + elif role == "assistant": + found_assistant = True + compressed = content.strip() + + # Validate we found expected roles + if not found_user: + return None, None, f"Sample {sample_id}: No user message found" + + if not found_assistant: + return None, None, f"Sample {sample_id}: No assistant message found" + + if not verbose: + return None, None, f"Sample {sample_id}: Empty verbose text after 'Compress:'" + + if not compressed: + return None, None, f"Sample {sample_id}: Empty 
compressed text" + + return verbose, compressed, None + + +def compute_compression_ratio_tokens(verbose: str, compressed: str) -> float: + """ + Compute token-based compression ratio aligned with src/utils/tokenizers.py. + + Returns compression_ratio where ratio > 1.0 means expansion (bad). + For Rule A, we want ratio <= 1.0 (compressed <= verbose in tokens). + """ + return compression_ratio(verbose, compressed) + + +# ============================================================================ +# VALIDATION RULES +# ============================================================================ + + +def rule_a_ratio_check(verbose: str, compressed: str) -> tuple[bool, str]: + """ + Rule A: Remove samples with compression ratio > 1.0 (expansion). + + Uses token-based ratio from src/utils/tokenizers.py. + Ratio > 1.0 means compressed text is longer than input (bad compression). + """ + ratio = compute_compression_ratio_tokens(verbose, compressed) + + # compression_ratio returns compressed/verbose + # We want compressed <= verbose, so ratio <= 1.0 + if ratio > 1.0: + return False, f"Ratio {ratio:.2f} > 1.0 (expansion)" + return True, "" + + +def rule_b_orphaned_symbols(compressed: str, is_code: bool) -> tuple[bool, str]: + """ + Rule B: Remove samples with orphaned symbols. + + CODE-AWARE: Allows leading @ for code samples (Python decorators). 
+ """ + if not compressed: + return False, "Empty compression" + + # Check leading symbol - allow @ for code samples (decorators) + if compressed[0] in SYMBOLS: + if is_code and compressed[0] == "@": + # Valid decorator pattern + pass + else: + return False, f"Orphaned symbol at start: '{compressed[0]}'" + + # Check trailing symbol (: is allowed at end) + if compressed[-1] in SYMBOLS and compressed[-1] != ":": + return False, f"Orphaned symbol at end: '{compressed[-1]}'" + + # Check consecutive symbols (except ::) + for i in range(len(compressed) - 1): + if compressed[i] in SYMBOLS and compressed[i + 1] in SYMBOLS: + if compressed[i] == ":" and compressed[i + 1] == ":": + continue # :: is allowed (namespace separator) + return False, f"Consecutive symbols: '{compressed[i]}{compressed[i + 1]}'" + + return True, "" + + +""" +RULE C FIX: Drop-in replacement for rule_c_negation_preservation + +Replace the existing rule_c_negation_preservation function in your +sanitization script with these functions. + +FIXES: +1. Unicode apostrophes (can't vs can't) - CRITICAL +2. Flexible whitespace in multi-word phrases +3. Reduced false positives from symbols (~, !) +4. 
import re
from typing import Tuple

# Curly/modifier apostrophes and the backtick, all mapped onto ASCII "'".
_APOSTROPHE_TABLE = str.maketrans(
    {"\u2019": "'", "\u2018": "'", "\u02BC": "'", "\u0060": "'"}
)

# Patterns are compiled once at import time; the rule runs once per sample.
_CONTRACTION_RE = re.compile(r"\w+n't\b")
_STANDALONE_NEGATION_RE = re.compile(
    r"\b(?:not|no|never|neither|nor|without|none|nothing|nobody|nowhere|cannot)\b"
)
_MULTIWORD_NEGATION_RE = re.compile(r"\bno\s+(?:longer|more)\b|\bnot\s+anymore\b")
# "!" that is neither part of "!=" nor of a "!!" pair (either position).
_BANG_NEGATION_RE = re.compile(r"(?<!!)![^=!]")
# "~" that does not start a home path ("~/") or an approximation ("~5").
_TILDE_NEGATION_RE = re.compile(r"~(?![/\d])")


def _normalize_apostrophes(text: str) -> str:
    """Return *text* with Unicode apostrophe variants replaced by ASCII "'"."""
    return text.translate(_APOSTROPHE_TABLE)


def _normalize_whitespace(text: str) -> str:
    """Collapse every run of whitespace to a single space and strip the ends."""
    return re.sub(r"\s+", " ", text).strip()


def _has_contractions(text: str) -> bool:
    """Return True if *text* contains an "n't" contraction (don't, can’t, ...)."""
    return bool(_CONTRACTION_RE.search(_normalize_apostrophes(text.lower())))


def _has_standalone_negation(text: str) -> bool:
    """Return True if *text* contains a whole-word negation such as "not" or "never".

    "cannot" is included because the word boundary in ``\\bnot\\b`` does not
    match inside it.
    """
    return bool(_STANDALONE_NEGATION_RE.search(_normalize_whitespace(text.lower())))


def _has_multiword_negation(text: str) -> bool:
    """Return True for "no longer", "no more" or "not anymore" (any whitespace)."""
    return bool(_MULTIWORD_NEGATION_RE.search(_normalize_whitespace(text.lower())))


def _has_negation_symbols(text: str, strict: bool = True) -> bool:
    """Return True if *text* contains a symbol used to express negation.

    Args:
        text: Text to scan.
        strict: If True, only the unambiguous ``¬`` counts. If False, ``!``
            (but not ``!=`` or ``!!``) and ``~`` (but not ``~/`` or ``~`` followed
            by a digit) count as well.
    """
    if "¬" in text:
        return True
    if strict:
        return False
    return bool(_BANG_NEGATION_RE.search(text) or _TILDE_NEGATION_RE.search(text))


def rule_c_negation_preservation(verbose: str, compressed: str) -> Tuple[bool, str]:
    """Rule C: reject samples whose compression dropped a negation (NL only).

    The verbose text counts as negated when it contains an "n't" contraction
    (Unicode apostrophes are normalised first), a standalone negation word, or
    a multi-word negation phrase. The compressed text may additionally keep
    the negation as a symbol (``¬``, ``!``, ``~``).

    Args:
        verbose: Original natural-language input.
        compressed: Compressed output to validate.

    Returns:
        ``(passed, reason)``; ``passed`` is False with reason "Negation lost"
        when the input is negated and the output shows no negation at all.
    """
    input_negated = (
        _has_contractions(verbose)
        or _has_standalone_negation(verbose)
        or _has_multiword_negation(verbose)
    )
    if not input_negated:
        return True, ""  # nothing to preserve

    output_negated = (
        _has_contractions(compressed)
        or _has_standalone_negation(compressed)
        or _has_multiword_negation(compressed)
        # Symbols are an accepted way to compress a negation.
        or _has_negation_symbols(compressed, strict=False)
    )
    if not output_negated:
        return False, "Negation lost"
    return True, ""


def rule_d_semantic_symbol_usage_nl(verbose: str, compressed: str) -> tuple[bool, str]:
    """Rule D: reject samples that should use ``@`` or ``∵`` but do not (NL only).

    Keyword matching is a case-insensitive substring test against
    ``LOCATION_KEYWORDS_NL`` and ``CAUSATION_KEYWORDS_NL``, which are defined
    elsewhere in this file. The location check runs before the causation check.

    Returns:
        ``(passed, reason)``; ``reason`` is empty when the sample passes.
    """
    lowered = verbose.lower()

    if "@" not in compressed and any(kw in lowered for kw in LOCATION_KEYWORDS_NL):
        return False, "Location context but no '@'"

    if "∵" not in compressed and any(kw in lowered for kw in CAUSATION_KEYWORDS_NL):
        return False, "Causation context but no '∵'"

    return True, ""


# ============================================================================
# MAIN SANITIZATION + EXTRACTION
# ============================================================================
def sanitize_and_extract(input_path: Path, sanitized_path: Path, unsanitized_path: Path) -> dict:
    """Sanitize a JSONL training file and split it into passed / failed samples.

    Every non-blank line of *input_path* is parsed as JSON. Lines that fail to
    parse, samples rejected by ``extract_verbose_compressed`` and samples that
    fail one of rules A-D are written to *unsanitized_path*; everything else
    is written to *sanitized_path*. Rules are applied in order A, B, C, D and
    evaluation stops at the first failure; C and D are skipped for code
    samples. Malformed JSON lines are stored as ``{"raw_line": ...}``, all
    other records keep their original JSON structure.

    SCOPE: Designed for data/training/train.jsonl with chat-message format.

    Args:
        input_path: JSONL file to read.
        sanitized_path: Output file for samples that passed every rule.
        unsanitized_path: Output file for everything that was rejected.

    Returns:
        Statistics dict with counters and per-sample detail lists. ``id``
        values are 0-based line indices into *input_path*; ``total_input``
        counts every non-blank line, including lines that failed to parse.

    Raises:
        FileNotFoundError: If *input_path* does not exist.
    """
    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    if input_path.name != "train.jsonl":
        print(f"⚠ WARNING: Expected train.jsonl, got {input_path.name}")
        print(" This script is designed for data/training/train.jsonl format")

    # Created before loading: the JSON error handler below records into these.
    stats = {
        "total_input": 0,
        "code_samples": 0,
        "nl_samples": 0,
        "code_passed": 0,
        "nl_passed": 0,
        "rule_a_failed": 0,
        "rule_b_failed": 0,
        "rule_c_failed": 0,
        "rule_d_failed": 0,
        "parse_errors": 0,
        "passed_all": 0,
        "failed_all": 0,
        "failed_samples": [],
        "passed_samples": [],
        "parse_error_samples": [],
    }
    sanitized_data = []
    unsanitized_data = []

    print(f"Loading data from {input_path}...")

    data = []  # (line index, parsed sample) pairs
    with open(input_path, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            raw_line = line.strip()
            if not raw_line:
                continue

            stats["total_input"] += 1
            try:
                sample = json.loads(raw_line)
            except json.JSONDecodeError as e:
                stats["parse_errors"] += 1
                stats["failed_all"] += 1

                error_msg = f"Sample {idx}: JSON decode error ({e.msg})"
                unsanitized_data.append({"raw_line": raw_line})
                stats["parse_error_samples"].append(
                    {"id": idx, "error": error_msg, "raw_line": raw_line}
                )
                print(f"⚠ {error_msg}")
                continue

            data.append((idx, sample))

    print(f"✓ Loaded {len(data)} samples\n")
    print("Processing samples...\n")

    for idx, sample in data:
        # Extract with format validation
        verbose, compressed, parse_error = extract_verbose_compressed(sample, idx)

        # Parse errors are tracked separately and are NOT counted as Rule A failures.
        if parse_error:
            stats["parse_errors"] += 1
            stats["failed_all"] += 1
            unsanitized_data.append(sample)
            stats["parse_error_samples"].append({"id": idx, "error": parse_error, "sample": sample})
            print(f"⚠ {parse_error}")
            continue  # Skip rule validation for malformed samples

        is_code = is_code_sample(verbose)
        content_type = "code" if is_code else "nl"
        stats["code_samples" if is_code else "nl_samples"] += 1

        passed = True
        failure_reason = ""
        failed_rules = []

        # Rule A (universal) - token-based ratio
        rule_a_pass, reason = rule_a_ratio_check(verbose, compressed)
        if not rule_a_pass:
            stats["rule_a_failed"] += 1
            passed = False
            failure_reason = f"Rule A: {reason}"
            failed_rules.append("A")

        # Rule B (universal, code-aware)
        if passed:
            rule_b_pass, reason = rule_b_orphaned_symbols(compressed, is_code)
            if not rule_b_pass:
                stats["rule_b_failed"] += 1
                passed = False
                failure_reason = f"Rule B: {reason}"
                failed_rules.append("B")

        # Rule C (NL only)
        if passed and not is_code:
            rule_c_pass, reason = rule_c_negation_preservation(verbose, compressed)
            if not rule_c_pass:
                stats["rule_c_failed"] += 1
                passed = False
                failure_reason = f"Rule C: {reason}"
                failed_rules.append("C")

        # Rule D (NL only)
        if passed and not is_code:
            rule_d_pass, reason = rule_d_semantic_symbol_usage_nl(verbose, compressed)
            if not rule_d_pass:
                stats["rule_d_failed"] += 1
                passed = False
                failure_reason = f"Rule D: {reason}"
                failed_rules.append("D")

        # Sort into sanitized or unsanitized
        if passed:
            stats["passed_all"] += 1
            stats["code_passed" if is_code else "nl_passed"] += 1

            sanitized_data.append(sample)
            stats["passed_samples"].append(
                {
                    "id": idx,
                    "type": content_type,
                    "ratio": compute_compression_ratio_tokens(verbose, compressed),
                }
            )
        else:
            stats["failed_all"] += 1
            unsanitized_data.append(sample)
            stats["failed_samples"].append(
                {
                    "id": idx,
                    "type": content_type,
                    "reason": failure_reason,
                    "failed_rules": failed_rules,
                    "sample": sample,
                }
            )

    # Save sanitized
    print(f"\nSaving sanitized data to {sanitized_path}...")
    sanitized_path.parent.mkdir(parents=True, exist_ok=True)
    with open(sanitized_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(sample, ensure_ascii=False) + "\n" for sample in sanitized_data)
    print(f"✓ Saved {len(sanitized_data)} sanitized samples")

    # Save unsanitized
    print(f"Saving unsanitized data to {unsanitized_path}...")
    unsanitized_path.parent.mkdir(parents=True, exist_ok=True)
    with open(unsanitized_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(sample, ensure_ascii=False) + "\n" for sample in unsanitized_data)
    print(f"✓ Saved {len(unsanitized_data)} unsanitized samples\n")

    return stats


def print_statistics(stats: dict):
    """Print a human-readable report for the dict returned by ``sanitize_and_extract``.

    Shows totals, pass/fail counts with percentages, per-rule failure counts,
    and the first five parse errors and rule failures.
    """
    total = stats["total_input"]
    code = stats["code_samples"]
    nl = stats["nl_samples"]

    def pct(n: int, d: int) -> float:
        # Guard against empty categories so the report never divides by zero.
        return (n / d * 100.0) if d > 0 else 0.0

    print("=" * 80)
    print("PROCESSING STATISTICS")
    print("=" * 80)
    print()

    print(f"Total input: {total:5d}")
    print(f" Code samples: {code:5d} ({pct(code, total):5.1f}%)")
    print(f" NL samples: {nl:5d} ({pct(nl, total):5.1f}%)")
    print()

    print(f"✓ SANITIZED (passed): {stats['passed_all']:5d} ({pct(stats['passed_all'], total):5.1f}%)")
    print(f" Code: {stats['code_passed']:5d} ({pct(stats['code_passed'], code):5.1f}%)")
    print(f" NL: {stats['nl_passed']:5d} ({pct(stats['nl_passed'], nl):5.1f}%)")
    print()

    print(f"✗ UNSANITIZED (failed): {stats['failed_all']:5d} ({pct(stats['failed_all'], total):5.1f}%)")
    print()

    print("Failed by rule:")
    print(f" Rule A (ratio > 1.0): {stats['rule_a_failed']:5d}")
    print(f" Rule B (orphaned symbols):{stats['rule_b_failed']:5d}")
    print(f" Rule C (lost negation): {stats['rule_c_failed']:5d}")
    print(f" Rule D (missing @ or ∵): {stats['rule_d_failed']:5d}")
    print(f" Parse errors: {stats['parse_errors']:5d}")
    print()

    if stats["parse_errors"] > 0:
        print("=" * 80)
        print("PARSE ERRORS (First 5)")
        print("=" * 80)
        print()
        for item in stats["parse_error_samples"][:5]:
            print(f"Sample {item['id']}:")
            print(f" Error: {item['error']}")
            print()

    print("=" * 80)
    print("UNSANITIZED SAMPLES (First 5)")
    print("=" * 80)
    print()

    for item in stats["failed_samples"][:5]:
        print(f"Sample {item['id']} ({item.get('type', 'unknown').upper()}):")
        print(f" Reason: {item['reason']}")
        print(f" Failed rules: {', '.join(item.get('failed_rules', []))}")
        print()


# ============================================================================
# CLI
# ============================================================================


def main():
    """Parse CLI arguments, run the sanitization pass and print a summary."""
    parser = argparse.ArgumentParser(
        description="Sanitize compression training data with validation rules",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Use default paths
  python data_sanitization.py

  # Custom paths
  python data_sanitization.py \\
    --input data/training/train.jsonl \\
    --sanitized data/training/sanitized_train.jsonl \\
    --unsanitized data/training/unsanitized_train.jsonl

  # Different dataset location
  python data_sanitization.py \\
    --input data/experiments/custom_train.jsonl \\
    --sanitized data/experiments/custom_sanitized.jsonl \\
    --unsanitized data/experiments/custom_unsanitized.jsonl

Validation Rules:
  Rule A: Compression ratio > 1.0 (expansion)
  Rule B: Orphaned symbols (code-aware for @ decorators)
  Rule C: Lost negation (NL only)
  Rule D: Missing semantic symbols @ or ∵ (NL only)
        """,
    )

    parser.add_argument(
        "--input",
        type=str,
        default="data/training/train.jsonl",
        help="Path to input training file (default: data/training/train.jsonl)",
    )
    parser.add_argument(
        "--sanitized",
        type=str,
        default="data/training/sanitized_train.jsonl",
        help="Path to output sanitized file (default: data/training/sanitized_train.jsonl)",
    )
    parser.add_argument(
        "--unsanitized",
        type=str,
        default="data/training/unsanitized_train.jsonl",
        help="Path to output unsanitized file (default: data/training/unsanitized_train.jsonl)",
    )

    args = parser.parse_args()

    # Convert to Path objects
    input_path = Path(args.input)
    sanitized_path = Path(args.sanitized)
    unsanitized_path = Path(args.unsanitized)

    print("\n" + "=" * 80)
    print("DATA SANITIZATION + EXTRACTION (Single Pass)")
    print("SCOPE: data/training/train.jsonl with chat-message format")
    print("=" * 80)
    print()
    print(f"Input: {input_path}")
    print(f"Sanitized: {sanitized_path}")
    print(f"Unsanitized: {unsanitized_path}")
    print("=" * 80)
    print()

    # Process
    stats = sanitize_and_extract(input_path, sanitized_path, unsanitized_path)

    # Print stats
    print_statistics(stats)

    print("=" * 80)
    print("OUTPUT FILES")
    print("=" * 80)
    print()
    print(f"1. {sanitized_path}")
    print(f" → {stats['passed_all']} samples (clean, ready for training)")
    print()
    print(f"2. {unsanitized_path}")
    print(f" → {stats['failed_all']} samples (for recovery analysis)")
    print()
    print("=" * 80)
    print("NEXT STEPS")
    print("=" * 80)
    print()
    print("1. Train on sanitized data:")
    print(f" Use {sanitized_path} ({stats['passed_all']} samples)")
    print()
    print("2. Analyze unsanitized samples for recovery:")
    print(f" Use {unsanitized_path} ({stats['failed_all']} samples)")
    print()


if __name__ == "__main__":
    main()
+ +Usage: + python dataset_manager.py --update # Switch to sanitized dataset + python dataset_manager.py --revert # Switch back to original dataset + python dataset_manager.py --status # Check current state + python dataset_manager.py --log # View change history + + # With custom paths + python dataset_manager.py --update \ + --train-file data/training/train.jsonl \ + --sanitized data/training/sanitized_train.jsonl +""" + +import argparse +import json +import shutil +from datetime import datetime +from pathlib import Path + +# ============================================================================ +# CONFIGURATION +# ============================================================================ + + +def get_config(args) -> dict: + """Build configuration from CLI arguments with sensible defaults.""" + + # Use CLI args if provided, otherwise use defaults + train_file = Path(args.train_file) if args.train_file else Path("data/training/train.jsonl") + sanitized_file = ( + Path(args.sanitized) if args.sanitized else Path("data/training/sanitized_train.jsonl") + ) + + # Derive backup path from train file + backup_file = train_file.parent / f"{train_file.stem}.original{train_file.suffix}" + + # State files in same directory as train file + state_dir = train_file.parent + + return { + # Main training file (the one used by training scripts) + "active_train": train_file, + # Backup of original data + "original_backup": backup_file, + # Sanitized data + "sanitized_data": sanitized_file, + # State and log files + "state_file": state_dir / ".dataset_state.json", + "log_file": state_dir / ".dataset_changes.log", + } + + +# ============================================================================ +# STATE MANAGEMENT +# ============================================================================ + + +def load_state(config: dict) -> dict: + """Load current dataset state.""" + if config["state_file"].exists(): + with open(config["state_file"]) as f: + return json.load(f) + + 
return { + "current": "original", # 'original' or 'sanitized' + "last_action": None, # 'update' or 'revert' + "last_change": None, + "change_count": 0, + } + + +def save_state(config: dict, state: dict): + """Save dataset state.""" + config["state_file"].parent.mkdir(parents=True, exist_ok=True) + with open(config["state_file"], "w") as f: + json.dump(state, f, indent=2) + + +def log_change(config: dict, action: str, from_state: str, to_state: str, details: str = ""): + """Log a dataset change.""" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + log_entry = f"[{timestamp}] {action}: {from_state} → {to_state}" + if details: + log_entry += f" | {details}" + log_entry += "\n" + + config["log_file"].parent.mkdir(parents=True, exist_ok=True) + with open(config["log_file"], "a") as f: + f.write(log_entry) + + print(f"✓ Logged: {log_entry.strip()}") + + +# ============================================================================ +# FILE OPERATIONS +# ============================================================================ + + +def count_samples(path): + """ + Count number of JSON objects in a JSONL file. 
+ """ + path = Path(path) + + if not path.exists(): + return 0 + + count = 0 + with path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + json.loads(line) + count += 1 + except json.JSONDecodeError: + continue # skip malformed lines safely + + return count + + +def verify_files_exist(config: dict) -> bool: + """Verify required files exist.""" + errors = [] + + if not config["active_train"].exists(): + errors.append(f"❌ Active training file not found: {config['active_train']}") + + if not config["sanitized_data"].exists(): + errors.append(f"❌ Sanitized data not found: {config['sanitized_data']}") + + if errors: + print("\n".join(errors)) + print("\nPlease ensure sanitized data exists before updating.") + return False + + return True + + +def backup_original(config: dict): + """Create backup of original data if it doesn't exist.""" + if not config["original_backup"].exists() and config["active_train"].exists(): + print("Creating backup of original data...") + shutil.copy2(config["active_train"], config["original_backup"]) + print(f"✓ Backup created: {config['original_backup']}") + return True + return False + + +# ============================================================================ +# MAIN OPERATIONS +# ============================================================================ + + +def update_to_sanitized(config: dict) -> bool: + """Switch to sanitized dataset.""" + state = load_state(config) + + # Safety check: if last action was update, only allow revert + if state.get("last_action") == "update": + print("❌ Last action was already UPDATE.") + print(" You can only REVERT after an update.") + print() + print("To revert to original dataset, run:") + print(f" python dataset_manager.py --revert --train-file {config['active_train']}") + return False + + if state["current"] == "sanitized": + print("⚠ Already using sanitized dataset. 
No changes made.") + return False + + if not verify_files_exist(config): + return False + + # Create backup of original if needed + backup_original(config) + + # Get sample counts + original_count = count_samples(config["active_train"]) + sanitized_count = count_samples(config["sanitized_data"]) + + print("\nSwitching to sanitized dataset...") + print(f" Original samples: {original_count}") + print(f" Sanitized samples: {sanitized_count}") + print(f" Removed samples: {original_count - sanitized_count}") + + # Perform swap + try: + shutil.copy2(config["sanitized_data"], config["active_train"]) + + # Update state + state["current"] = "sanitized" + state["last_action"] = "update" + state["last_change"] = datetime.now().isoformat() + state["change_count"] += 1 + save_state(config, state) + + # Log change + log_change( + config, + action="UPDATE", + from_state="original", + to_state="sanitized", + details=f"{original_count} → {sanitized_count} samples", + ) + + print("\n✓ Successfully switched to sanitized dataset") + print(f"✓ {config['active_train']} now contains {sanitized_count} samples") + print() + print("⚠ To undo this change, run:") + print(f" python dataset_manager.py --revert --train-file {config['active_train']}") + return True + + except Exception as e: + print(f"\n❌ Error during update: {e}") + return False + + +def revert_to_original(config: dict) -> bool: + """Revert to original dataset.""" + state = load_state(config) + + # Safety check: if last action was revert, only allow update + if state.get("last_action") == "revert": + print("❌ Last action was already REVERT.") + print(" You can only UPDATE after a revert.") + print() + print("To switch to sanitized dataset, run:") + print(f" python dataset_manager.py --update --train-file {config['active_train']}") + return False + + if state["current"] == "original": + print("⚠ Already using original dataset. 
No changes made.") + return False + + if not config["original_backup"].exists(): + print(f"❌ Original backup not found: {config['original_backup']}") + print("Cannot revert without backup.") + return False + + # Get sample counts + sanitized_count = count_samples(config["active_train"]) + original_count = count_samples(config["original_backup"]) + + print("\nReverting to original dataset...") + print(f" Sanitized samples: {sanitized_count}") + print(f" Original samples: {original_count}") + + # Perform swap + try: + shutil.copy2(config["original_backup"], config["active_train"]) + + # Update state + state["current"] = "original" + state["last_action"] = "revert" + state["last_change"] = datetime.now().isoformat() + state["change_count"] += 1 + save_state(config, state) + + # Log change + log_change( + config, + action="REVERT", + from_state="sanitized", + to_state="original", + details=f"{sanitized_count} → {original_count} samples", + ) + + print("\n✓ Successfully reverted to original dataset") + print(f"✓ {config['active_train']} now contains {original_count} samples") + print() + print("⚠ To switch back to sanitized, run:") + print(f" python dataset_manager.py --update --train-file {config['active_train']}") + return True + + except Exception as e: + print(f"\n❌ Error during revert: {e}") + return False + + +def show_status(config: dict): + """Show current dataset status.""" + state = load_state(config) + + print("\n" + "=" * 80) + print("DATASET STATUS") + print("=" * 80) + print() + + # Configuration + print("Configuration:") + print(f" Train file: {config['active_train']}") + print(f" Sanitized file: {config['sanitized_data']}") + print(f" Backup file: {config['original_backup']}") + print() + + # Current state + current = state["current"].upper() + last_action = state.get("last_action", "None") + print(f"Current dataset: {current}") + print(f"Last action: {last_action.upper() if last_action else 'None'}") + print(f"Total changes: {state['change_count']}") + 
+ if state["last_change"]: + last_change = datetime.fromisoformat(state["last_change"]) + print(f"Last change: {last_change.strftime('%Y-%m-%d %H:%M:%S')}") + else: + print("Last change: Never") + + print() + + # File information + print("Files:") + + if config["active_train"].exists(): + active_count = count_samples(config["active_train"]) + print(f" ✓ {config['active_train'].name:25s} {active_count:4d} samples (ACTIVE)") + else: + print(f" ❌ {config['active_train'].name:25s} Not found") + + if config["original_backup"].exists(): + original_count = count_samples(config["original_backup"]) + print(f" ✓ {config['original_backup'].name:25s} {original_count:4d} samples (backup)") + else: + print( + f" ⚠ {config['original_backup'].name:25s} Not found (will be created on first update)" + ) + + if config["sanitized_data"].exists(): + sanitized_count = count_samples(config["sanitized_data"]) + print(f" ✓ {config['sanitized_data'].name:25s} {sanitized_count:4d} samples") + else: + print(f" ❌ {config['sanitized_data'].name:25s} Not found") + + print() + + # Recommendations + print("=" * 80) + print("AVAILABLE ACTIONS") + print("=" * 80) + print() + + # Show only the allowed action based on last action + if last_action == "update": + print("✓ You can REVERT (last action was UPDATE):") + print(f" python dataset_manager.py --revert --train-file {config['active_train']}") + print() + print("✗ You cannot UPDATE again (already updated)") + elif last_action == "revert": + print("✓ You can UPDATE (last action was REVERT):") + print(f" python dataset_manager.py --update --train-file {config['active_train']}") + print() + print("✗ You cannot REVERT again (already reverted)") + else: + # No previous action - allow either + if state["current"] == "original": + print("✓ You can UPDATE to sanitized dataset:") + print(f" python dataset_manager.py --update --train-file {config['active_train']}") + else: + print("✓ You can REVERT to original dataset:") + print(f" python dataset_manager.py 
--revert --train-file {config['active_train']}") + + print() + print("To view change history:") + print(" python dataset_manager.py --log") + print() + + +def show_log(config: dict, lines: int | None = None): + """Show change log.""" + if not config["log_file"].exists(): + print("No changes logged yet.") + return + + print("\n" + "=" * 80) + print("DATASET CHANGE LOG") + print("=" * 80) + print() + + with open(config["log_file"]) as f: + log_lines = f.readlines() + + # Show last N lines if specified + if lines: + log_lines = log_lines[-lines:] + + if not log_lines: + print("No changes logged yet.") + return + + for line in log_lines: + print(line.rstrip()) + + print() + print(f"Total entries: {len(log_lines)}") + print() + + +# ============================================================================ +# CLI +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="Manage dataset switching between sanitized and original training data", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Switch to sanitized dataset (default paths) + python dataset_manager.py --update + + # Switch with custom paths + python dataset_manager.py --update \\ + --train-file data/training/train.jsonl \\ + --sanitized data/training/sanitized_train.jsonl + + # Revert to original dataset + python dataset_manager.py --revert --train-file data/training/train.jsonl + + # Check current status + python dataset_manager.py --status + + # View full change log + python dataset_manager.py --log + + # View last 10 changes + python dataset_manager.py --log --lines 10 + +Safety: + - After UPDATE, you can only REVERT + - After REVERT, you can only UPDATE + - This prevents accidental double-operations + """, + ) + + # Actions + parser.add_argument("--update", action="store_true", help="Switch to sanitized dataset") + parser.add_argument("--revert", action="store_true", help="Revert to 
original dataset") + parser.add_argument("--status", action="store_true", help="Show current dataset status") + parser.add_argument("--log", action="store_true", help="Show change log") + parser.add_argument( + "--lines", type=int, metavar="N", help="Show last N log entries (use with --log)" + ) + + # Path configuration + parser.add_argument( + "--train-file", + type=str, + default=None, + help="Path to training file to manage (default: data/training/train.jsonl)", + ) + parser.add_argument( + "--sanitized", + type=str, + default=None, + help="Path to sanitized data file (default: data/training/sanitized_train.jsonl)", + ) + + args = parser.parse_args() + + # Build configuration from args + config = get_config(args) + + # Execute requested action + if args.update: + update_to_sanitized(config) + elif args.revert: + revert_to_original(config) + elif args.log: + show_log(config, args.lines) + elif args.status: + show_status(config) + else: + # Default: show status + show_status(config) + + +if __name__ == "__main__": + main() diff --git a/scripts/mlflow_logger.py b/scripts/mlflow_logger.py new file mode 100644 index 0000000..10580be --- /dev/null +++ b/scripts/mlflow_logger.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +""" +Post-training MLflow logger for existing runs. + +Behavior: +- If --run-dir is provided → use it +- If --run-dir is omitted → auto-detect latest run in models/runs/mlx/ +- DagsHub/MLflow destination can be configured via CLI or env vars. 
import argparse
import json
import os
import re
from pathlib import Path

# A plain decimal number; unlike "[0-9.]+" it cannot swallow a trailing ".".
_NUM = r"([0-9]+(?:\.[0-9]+)?)"
_TRAIN_LINE_RE = re.compile(
    rf"Iter (\d+): Train loss {_NUM}.*Tokens/sec {_NUM}.*Peak mem {_NUM} GB"
)
_VAL_LINE_RE = re.compile(rf"Iter (\d+): Val loss {_NUM}")


# ===============================
# ARGUMENTS
# ===============================
def parse_args() -> argparse.Namespace:
    """Parse the command line; ``--experiment-name`` is the only required option.

    DagsHub owner/repo and the MLflow tracking URI default to the
    ``DAGSHUB_OWNER``, ``DAGSHUB_REPO`` and ``MLFLOW_TRACKING_URI``
    environment variables when those are set.
    """
    p = argparse.ArgumentParser()

    p.add_argument(
        "--run-dir",
        type=Path,
        default=None,
        help="Path to a specific run directory. If omitted, latest run is used.",
    )
    p.add_argument(
        "--runs-root",
        type=Path,
        default=Path("models/runs/mlx"),
        help="Root directory containing run folders (default: models/runs/mlx)",
    )
    p.add_argument("--experiment-name", required=True, type=str)

    # Remote / tracking config
    p.add_argument(
        "--dagshub-owner",
        type=str,
        default=os.getenv("DAGSHUB_OWNER", "Gautam-Galada"),
        help="DagsHub repo owner (env: DAGSHUB_OWNER). Default: Gautam-Galada",
    )
    p.add_argument(
        "--dagshub-repo",
        type=str,
        default=os.getenv("DAGSHUB_REPO", "compression-layer"),
        help="DagsHub repo name (env: DAGSHUB_REPO). Default: compression-layer",
    )
    p.add_argument(
        "--mlflow-tracking-uri",
        type=str,
        default=os.getenv("MLFLOW_TRACKING_URI", ""),
        help=(
            "MLflow tracking URI (env: MLFLOW_TRACKING_URI). "
            "If omitted, derived from DagsHub owner/repo."
        ),
    )
    p.add_argument(
        "--no-dagshub-init",
        action="store_true",
        help="Skip dagshub.init() (useful if tracking is configured elsewhere).",
    )

    return p.parse_args()


# ===============================
# RUN DISCOVERY
# ===============================
def find_latest_run(runs_root: Path) -> Path:
    """Return the most recently modified run directory under *runs_root*.

    Only direct sub-directories are considered; an entry named ``latest`` is
    ignored.

    Raises:
        FileNotFoundError: If *runs_root* does not exist.
        RuntimeError: If *runs_root* contains no candidate directory.
    """
    if not runs_root.exists():
        raise FileNotFoundError(f"Runs root not found: {runs_root}")

    candidates = [p for p in runs_root.iterdir() if p.is_dir() and p.name != "latest"]
    if not candidates:
        raise RuntimeError(f"No runs found in {runs_root}")

    # Pick newest by mtime
    return max(candidates, key=lambda p: p.stat().st_mtime)


# ===============================
# CORE LOGIC
# ===============================
def _derive_dagshub_tracking_uri(owner: str, repo: str) -> str:
    """Build the DagsHub-hosted MLflow tracking URI for *owner*/*repo*."""
    return f"https://dagshub.com/{owner}/{repo}.mlflow"


def _parse_train_log(log_text: str) -> tuple[list[tuple], list[tuple]]:
    """Extract metric rows from the text of a training log.

    Returns:
        ``(train_rows, val_rows)``: ``train_rows`` holds
        ``(step, train_loss, tokens_per_sec, peak_mem_gb)`` tuples and
        ``val_rows`` holds ``(step, val_loss)`` tuples, in log order.
    """
    train_rows = [
        (int(m.group(1)), float(m.group(2)), float(m.group(3)), float(m.group(4)))
        for m in _TRAIN_LINE_RE.finditer(log_text)
    ]
    val_rows = [(int(m.group(1)), float(m.group(2))) for m in _VAL_LINE_RE.finditer(log_text)]
    return train_rows, val_rows


def log_run_dir_to_mlflow(
    run_dir: Path,
    experiment_name: str,
    dagshub_owner: str,
    dagshub_repo: str,
    mlflow_tracking_uri: str,
    no_dagshub_init: bool,
) -> None:
    """Log an existing training run directory to MLflow.

    Reads ``run.json`` (parameters) and ``train.log`` (metrics) from
    *run_dir*, then logs parameters, per-step metrics, both files, the
    ``adapter`` directory if present, and a loss-curve plot when both train
    and validation points were found. The plot is written to
    ``run_dir / "loss_curve.png"``.

    Args:
        run_dir: Directory of the run to log.
        experiment_name: MLflow experiment to log into.
        dagshub_owner: DagsHub repository owner.
        dagshub_repo: DagsHub repository name.
        mlflow_tracking_uri: Explicit tracking URI; when blank the URI is
            derived from the DagsHub owner and repo.
        no_dagshub_init: If True, ``dagshub.init()`` is not called.

    Raises:
        FileNotFoundError: If *run_dir*, ``run.json`` or ``train.log`` is missing.
    """
    # Imported here so this function does not depend on module-level imports.
    import dagshub
    import matplotlib.pyplot as plt
    import mlflow

    if not run_dir.exists():
        raise FileNotFoundError(f"Run dir not found: {run_dir}")

    # ===============================
    # INIT DAGSHUB + MLFLOW
    # ===============================
    tracking_uri = mlflow_tracking_uri.strip() or _derive_dagshub_tracking_uri(
        dagshub_owner, dagshub_repo
    )

    if not no_dagshub_init:
        dagshub.init(
            repo_owner=dagshub_owner,
            repo_name=dagshub_repo,
            mlflow=True,
        )

    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)

    # ===============================
    # LOAD FILES
    # ===============================
    run_json = run_dir / "run.json"
    train_log = run_dir / "train.log"

    if not run_json.exists():
        raise FileNotFoundError(f"Missing run.json in {run_dir}")
    if not train_log.exists():
        raise FileNotFoundError(f"Missing train.log in {run_dir}")

    with run_json.open("r", encoding="utf-8") as f:
        run_cfg = json.load(f)

    log_text = train_log.read_text(encoding="utf-8", errors="replace")
    train_rows, val_rows = _parse_train_log(log_text)

    # ===============================
    # MLFLOW RUN
    # ===============================
    # "or" also covers a present-but-null/empty "started_at".
    run_name = run_cfg.get("started_at") or run_dir.name

    with mlflow.start_run(run_name=run_name):
        # PARAMS
        mlflow.log_params(
            {
                "model": run_cfg.get("model"),
                "git_sha": run_cfg.get("git_sha"),
                "data_dir": run_cfg.get("data_dir"),
                "lora_rank": run_cfg.get("lora_rank"),
                "lora_alpha": run_cfg.get("lora_alpha"),
                "batch_size": run_cfg.get("batch_size"),
                "learning_rate": run_cfg.get("learning_rate"),
                "iters": run_cfg.get("iters"),
            }
        )

        # METRICS
        for step, loss, tok_per_sec, mem_gb in train_rows:
            mlflow.log_metric("train_loss", loss, step=step)
            mlflow.log_metric("tokens_per_sec", tok_per_sec, step=step)
            mlflow.log_metric("peak_mem_gb", mem_gb, step=step)

        for step, loss in val_rows:
            mlflow.log_metric("val_loss", loss, step=step)

        # ARTIFACTS
        mlflow.log_artifact(run_json)
        mlflow.log_artifact(train_log)

        adapter_dir = run_dir / "adapter"
        if adapter_dir.exists() and adapter_dir.is_dir():
            mlflow.log_artifacts(adapter_dir, artifact_path="weights")

        # PLOTS
        if train_rows and val_rows:
            plt.figure()
            plt.plot([r[0] for r in train_rows], [r[1] for r in train_rows], label="Train Loss")
            plt.plot([r[0] for r in val_rows], [r[1] for r in val_rows], label="Val Loss")
            plt.xlabel("Iteration")
            plt.ylabel("Loss")
            plt.legend()
            plt.title("Training vs Validation Loss")

            plot_path = run_dir / "loss_curve.png"
            plt.savefig(plot_path)
            plt.close()

            mlflow.log_artifact(plot_path)


# ===============================
# CLI ENTRYPOINT
# ===============================
def main() -> None:
    """CLI entry point: pick the run directory and log it to MLflow."""
    args = parse_args()

    run_dir = args.run_dir
    if run_dir is None:
        run_dir = find_latest_run(args.runs_root)
        print(f"📌 Auto-selected latest run: {run_dir}")

    log_run_dir_to_mlflow(
        run_dir=run_dir,
        experiment_name=args.experiment_name,
        dagshub_owner=args.dagshub_owner,
        dagshub_repo=args.dagshub_repo,
        mlflow_tracking_uri=args.mlflow_tracking_uri,
        no_dagshub_init=args.no_dagshub_init,
    )
    print("✅ Existing training run successfully logged to MLflow")


if __name__ == "__main__":
    main()
+# If the file is, e.g., scripts/analysis_data.py: +# import scripts.analysis_data as m +# +# Run with: PYTHONPATH=. python -m pytest -q +import scripts.analysis_data as m + + +def _sample(verbose: str, compressed: str) -> dict: + return { + "messages": [ + {"role": "user", "content": f"Compress: {verbose}"}, + {"role": "assistant", "content": compressed}, + ] + } + + +def test_extract_verbose_compressed_parses_message_structure(): + s = _sample("If A then B.", "A → B") + verbose, compressed = m.extract_verbose_compressed(s) + assert verbose == "If A then B." + assert compressed == "A → B" + + +def test_tokenize_compression_splits_symbols_and_text(): + tokens = m.tokenize_compression("A → B | C") + # Expect symbols as standalone tokens + assert "→" in tokens + assert "|" in tokens + # And content chunks + assert "A" in tokens + assert "B" in tokens + assert "C" in tokens + + +def test_extract_symbol_sequence_basic(): + # X for content, symbols preserved; adjacent content should collapse to single X + seq = m.extract_symbol_sequence("A → B | C") + # Leading "A " content is omitted, then "→", " B " -> X, "|", " C" -> X + assert seq == "→X|X" + + +def test_detect_scope_ambiguity_multiple_implications_without_grouping(): + issues = m.detect_scope_ambiguity("A → B | C → D") + assert any(i.startswith("SCOPE_AMBIGUITY") for i in issues) + + +def test_detect_scope_ambiguity_precedence_unclear_pipe_before_arrow(): + issues = m.detect_scope_ambiguity("A | B → C") + assert any(i.startswith("PRECEDENCE_UNCLEAR") for i in issues) + + +def test_detect_orphaned_symbols_start_end_and_consecutive(): + assert any("ORPHAN_START" in x for x in m.detect_orphaned_symbols("| A")) + assert any("ORPHAN_END" in x for x in m.detect_orphaned_symbols("A |")) + assert any("CONSECUTIVE" in x for x in m.detect_orphaned_symbols("A||B")) + + +def test_analyze_symbol_context_implication_detects_keywords_and_usage(): + verbose = "If the alarm triggers then evacuate. Therefore leave." 
+ compressed = "alarm → evacuate" + ctx = m.analyze_symbol_context(verbose, compressed, "→") + assert ctx["context_present"] + assert ctx["used"] + assert any(k in ctx["context_keywords"] for k in ["if", "then", "therefore"]) + + +def test_check_negation_preservation_detects_lost_negation(): + verbose = "Do not open the door." + compressed = "open door" # negation lost + out = m.check_negation_preservation(verbose, compressed) + assert out["verbose_has_negation"] + assert out["negation_lost"] + assert not out["compressed_has_negation"] + + +def test_check_negation_preservation_accepts_negation_symbol(): + verbose = "Do not open the door." + compressed = "¬ open door" + out = m.check_negation_preservation(verbose, compressed) + assert out["verbose_has_negation"] + assert out["compressed_has_negation"] + assert not out["negation_lost"] + + +def test_extract_logical_pattern_normalizes_structure(): + # Parentheses content becomes (P), text becomes P, symbols preserved + pat = m.extract_logical_pattern("(foo bar) → baz | qux") + assert "(P)" in pat + assert "→" in pat + assert "|" in pat + # Should be only P and symbols + parentheses + assert re.fullmatch(r"[P→|@∵:() ]+", pat) is not None + + +def test_analyze_symbol_combinations_bigrams_only_from_SYMBOLS_not_LOGICAL_CONNECTIVES(): + # Note: analyze_symbol_combinations uses SYMBOLS only (not LOGICAL_CONNECTIVES) + bigrams = m.analyze_symbol_combinations("A → B | C : D") + # SYMBOLS include →, |, : so we should see bigrams among these (order preserved) + assert ("→", "|") in bigrams + assert ("|", ":") in bigrams + + +def test_analyze_dataset_aggregates_core_fields_and_flags_problematic_and_good(): + data = [ + # Good: high ratio, no orphan/scope issues + _sample( + "This is a long verbose description with many words because it explains details clearly.", + "desc: details", + ), + # Problematic: ratio < 1 (compressed longer than verbose) + _sample("short", "this is longer than short"), + # Scope ambiguity + _sample("If A 
then B or if C then D", "A → B | C → D"), + # Orphaned symbol + _sample("List items A and B", "| A | B"), + # Negation lost + _sample("Do not proceed", "proceed"), + ] + + results = m.analyze_dataset(data) + + assert results["total_samples"] == len(data) + assert isinstance(results["compression_ratios"], list) + assert results["symbol_usage"]["|"] >= 1 # used in at least one sample + assert "logical_patterns" in results and len(results["logical_patterns"]) > 0 + assert "symbol_combinations" in results + + # We created at least one problematic (ratio < 1), plus scope and orphan cases + assert len(results["problematic_samples"]) >= 3 + + # Negation analysis: one verbose with negation, and it is lost in compressed + neg = results["negation_analysis"] + assert neg["total_with_negation"] >= 1 + assert neg["negation_lost"] >= 1 + + # Good samples: first sample likely yields ratio > 3.0 depending on tokenization + # Make the assertion robust by checking at least one good sample exists OR ratio > 3 appears. 
+ if results["good_samples"]: + assert all(x["ratio"] > 3.0 for x in results["good_samples"]) + + +def test_generate_report_contains_key_sections(): + # Minimal results skeleton for report generation + results = { + "total_samples": 2, + "symbol_usage": {"→": 1, "|": 1, "@": 0, "∵": 0, ":": 1}, + "symbol_context_analysis": { + s: { + "used_when_context_present": 0, + "not_used_when_context_present": 0, + "used_without_context": 0, + } + for s in m.SYMBOLS + }, + "scope_ambiguities": [], + "orphaned_symbols": [], + "logical_patterns": m.Counter({"P→P": 1}), + "symbol_combinations": m.Counter({("→", "|"): 2}), + "negation_analysis": { + "total_with_negation": 0, + "negation_preserved": 0, + "negation_lost": 0, + }, + "compression_ratios": [2.0, 3.5], + "problematic_samples": [], + "good_samples": [{"id": 1, "ratio": 3.5, "compressed": "x"}], + } + + report = m.generate_report(results) + assert "LOGICAL COMPRESSION ANALYSIS REPORT" in report + assert "DATASET OVERVIEW" in report + assert "COMPRESSION QUALITY" in report + assert "SYMBOL USAGE" in report + assert "SCOPE AMBIGUITIES" in report + assert "ORPHANED SYMBOLS" in report + assert "NEGATION PRESERVATION" in report + assert "RECOMMENDATIONS" in report + + +if __name__ == "__main__": + import pytest + + raise SystemExit(pytest.main([__file__, "-v"])) diff --git a/tests/test_mlflow_logger.py b/tests/test_mlflow_logger.py new file mode 100644 index 0000000..26e58da --- /dev/null +++ b/tests/test_mlflow_logger.py @@ -0,0 +1,473 @@ +# tests/test_mlflow_logger.py +""" +Comprehensive tests for mlflow_logger.py + +Coverage: +- Params/metrics/artifacts logging +- Adapter directory handling +- Log parsing (matching/non-matching) +- Latest run discovery +- CLI argument parsing +- Error handling +- DagsHub URI derivation +""" + +from __future__ import annotations + +import json +from contextlib import contextmanager +from pathlib import Path + +import pytest + + +class MlflowRecorder: + """Mock MLflow client that records all 
operations.""" + + def __init__(self): + self.tracking_uri = None + self.experiment = None + self.run_name = None + self.params = {} + self.metrics = [] # (key, value, step) + self.artifacts = [] # (path, artifact_path) + self.log_artifacts_calls = [] # (dir_path, artifact_path) + + def set_tracking_uri(self, uri: str): + self.tracking_uri = uri + + def set_experiment(self, name: str): + self.experiment = name + + @contextmanager + def start_run(self, run_name: str = None, **_kwargs): + self.run_name = run_name + yield + + def log_params(self, d: dict): + self.params.update(d) + + def log_metric(self, key: str, value: float, step: int | None = None): + self.metrics.append((key, value, step)) + + def log_artifact(self, path: Path, artifact_path: str | None = None): + self.artifacts.append((Path(path), artifact_path)) + + def log_artifacts(self, dir_path: Path, artifact_path: str | None = None): + self.log_artifacts_calls.append((Path(dir_path), artifact_path)) + + +def _write_run_dir( + base: Path, + name: str, + *, + with_adapter: bool = True, + with_matching_logs: bool = True, + with_run_json: bool = True, + with_train_log: bool = True, +) -> Path: + """Helper to create a run directory with configurable content.""" + run_dir = base / name + run_dir.mkdir(parents=True, exist_ok=True) + + if with_run_json: + (run_dir / "run.json").write_text( + json.dumps( + { + "started_at": name, + "model": "mistral", + "git_sha": "abc123", + "data_dir": "data/x", + "lora_rank": 8, + "lora_alpha": 16, + "batch_size": 4, + "learning_rate": 1e-4, + "iters": 100, + } + ), + encoding="utf-8", + ) + + if with_train_log: + if with_matching_logs: + # Must match regex patterns in mlflow_logger.py + (run_dir / "train.log").write_text( + "\n".join( + [ + "Iter 10: Train loss 1.23 | Tokens/sec 456.7 | Peak mem 12.3 GB", + "Iter 20: Train loss 1.10 | Tokens/sec 470.0 | Peak mem 12.4 GB", + "Iter 10: Val loss 1.50", + "Iter 20: Val loss 1.40", + ] + ) + + "\n", + encoding="utf-8", + ) + else: 
+ (run_dir / "train.log").write_text("no matching lines\n", encoding="utf-8") + + if with_adapter: + adapter = run_dir / "adapter" + adapter.mkdir() + (adapter / "adapter.safetensors").write_bytes(b"fake") + (adapter / "config.json").write_text("{}", encoding="utf-8") + + return run_dir + + +@pytest.fixture +def mlflow_recorder(monkeypatch) -> MlflowRecorder: + """ + Patch mlflow_logger.py's `mlflow` and `dagshub.init` to avoid real tracking. + """ + rec = MlflowRecorder() + + import scripts.mlflow_logger as m + + # Patch mlflow and dagshub + monkeypatch.setattr(m, "mlflow", rec, raising=True) + monkeypatch.setattr(m.dagshub, "init", lambda **kwargs: None, raising=True) + + # Non-interactive matplotlib + monkeypatch.setenv("MPLBACKEND", "Agg") + + return rec + + +# ============================================================================ +# CORE FUNCTIONALITY TESTS +# ============================================================================ + + +def test_logs_params_metrics_and_artifacts(tmp_path: Path, mlflow_recorder: MlflowRecorder): + """Test that all params, metrics, and artifacts are logged correctly.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + run_dir = _write_run_dir(runs_root, "2026-02-05_21-59-37", with_adapter=True) + + m.log_run_dir_to_mlflow( + run_dir=run_dir, + experiment_name="exp_test", + dagshub_owner="test-owner", + dagshub_repo="test-repo", + mlflow_tracking_uri="", + no_dagshub_init=False, + ) + + # Check experiment and tracking URI set + assert mlflow_recorder.experiment == "exp_test" + assert mlflow_recorder.tracking_uri is not None + + # Check params + assert mlflow_recorder.params["model"] == "mistral" + assert mlflow_recorder.params["lora_rank"] == 8 + assert mlflow_recorder.params["iters"] == 100 + assert mlflow_recorder.params["git_sha"] == "abc123" + + # Check metrics (2 train steps * 3 metrics + 2 val steps = 8) + assert len(mlflow_recorder.metrics) == 8 + assert ("train_loss", 1.23, 10) in 
mlflow_recorder.metrics + assert ("tokens_per_sec", 470.0, 20) in mlflow_recorder.metrics + assert ("val_loss", 1.40, 20) in mlflow_recorder.metrics + assert ("peak_mem_gb", 12.3, 10) in mlflow_recorder.metrics + + # Check artifacts + artifact_names = {p.name for (p, _) in mlflow_recorder.artifacts} + assert "run.json" in artifact_names + assert "train.log" in artifact_names + assert "loss_curve.png" in artifact_names + + # Check adapter logged + assert mlflow_recorder.log_artifacts_calls == [(run_dir / "adapter", "weights")] + + +def test_no_adapter_dir_does_not_log_weights(tmp_path: Path, mlflow_recorder: MlflowRecorder): + """Test that missing adapter directory doesn't cause errors.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + run_dir = _write_run_dir(runs_root, "2026-02-05_21-59-37", with_adapter=False) + + m.log_run_dir_to_mlflow( + run_dir=run_dir, + experiment_name="exp_test", + dagshub_owner="test-owner", + dagshub_repo="test-repo", + mlflow_tracking_uri="", + no_dagshub_init=False, + ) + + # No adapter artifacts should be logged + assert mlflow_recorder.log_artifacts_calls == [] + + +def test_nonmatching_log_produces_no_loss_plot(tmp_path: Path, mlflow_recorder: MlflowRecorder): + """Test that non-matching logs don't generate loss plot.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + run_dir = _write_run_dir( + runs_root, + "2026-02-05_21-59-37", + with_adapter=False, + with_matching_logs=False, + ) + + m.log_run_dir_to_mlflow( + run_dir=run_dir, + experiment_name="exp_test", + dagshub_owner="test-owner", + dagshub_repo="test-repo", + mlflow_tracking_uri="", + no_dagshub_init=False, + ) + + artifact_names = {p.name for (p, _) in mlflow_recorder.artifacts} + assert "run.json" in artifact_names + assert "train.log" in artifact_names + assert "loss_curve.png" not in artifact_names # No plot without metrics + + +# 
============================================================================ +# RUN DISCOVERY TESTS +# ============================================================================ + + +def test_find_latest_run_ignores_latest_and_uses_mtime(tmp_path: Path): + """Test that find_latest_run() ignores 'latest' symlink and uses mtime.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + runs_root.mkdir(parents=True, exist_ok=True) + + # Create two runs + older = _write_run_dir(runs_root, "2026-02-04_00-39-58", with_adapter=False) + newer = _write_run_dir(runs_root, "2026-02-06_11-52-46", with_adapter=False) + + # Create 'latest' directory that should be ignored + (runs_root / "latest").mkdir() + + # Set mtimes explicitly + import os + + older_ts = 1_000_000_000 + newer_ts = 1_000_000_100 + os.utime(older, (older_ts, older_ts)) + os.utime(newer, (newer_ts, newer_ts)) + + latest = m.find_latest_run(runs_root) + assert latest.name == newer.name + + +def test_find_latest_run_raises_if_no_runs(tmp_path: Path): + """Test that find_latest_run raises error when no runs found.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + runs_root.mkdir(parents=True, exist_ok=True) + + # Only create 'latest' directory (should be ignored) + (runs_root / "latest").mkdir() + + with pytest.raises(RuntimeError, match="No runs found"): + m.find_latest_run(runs_root) + + +def test_find_latest_run_raises_if_root_missing(tmp_path: Path): + """Test that find_latest_run raises error when root doesn't exist.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "nonexistent" / "runs" + + with pytest.raises(FileNotFoundError, match="Runs root not found"): + m.find_latest_run(runs_root) + + +# ============================================================================ +# ERROR HANDLING TESTS +# ============================================================================ + + +def test_missing_run_json_raises_error(tmp_path: Path, 
mlflow_recorder: MlflowRecorder): + """Test that missing run.json raises FileNotFoundError.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + run_dir = _write_run_dir( + runs_root, + "2026-02-05_21-59-37", + with_run_json=False, # Don't create run.json + with_adapter=False, + ) + + with pytest.raises(FileNotFoundError, match="Missing run.json"): + m.log_run_dir_to_mlflow( + run_dir=run_dir, + experiment_name="exp_test", + dagshub_owner="test-owner", + dagshub_repo="test-repo", + mlflow_tracking_uri="", + no_dagshub_init=False, + ) + + +def test_missing_train_log_raises_error(tmp_path: Path, mlflow_recorder: MlflowRecorder): + """Test that missing train.log raises FileNotFoundError.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + run_dir = _write_run_dir( + runs_root, + "2026-02-05_21-59-37", + with_train_log=False, # Don't create train.log + with_adapter=False, + ) + + with pytest.raises(FileNotFoundError, match="Missing train.log"): + m.log_run_dir_to_mlflow( + run_dir=run_dir, + experiment_name="exp_test", + dagshub_owner="test-owner", + dagshub_repo="test-repo", + mlflow_tracking_uri="", + no_dagshub_init=False, + ) + + +def test_nonexistent_run_dir_raises_error(tmp_path: Path, mlflow_recorder: MlflowRecorder): + """Test that nonexistent run directory raises FileNotFoundError.""" + import scripts.mlflow_logger as m + + run_dir = tmp_path / "nonexistent_run" + + with pytest.raises(FileNotFoundError, match="Run dir not found"): + m.log_run_dir_to_mlflow( + run_dir=run_dir, + experiment_name="exp_test", + dagshub_owner="test-owner", + dagshub_repo="test-repo", + mlflow_tracking_uri="", + no_dagshub_init=False, + ) + + +# ============================================================================ +# DAGSHUB URI DERIVATION TESTS +# ============================================================================ + + +def test_derive_dagshub_tracking_uri(): + """Test DagsHub tracking URI derivation.""" + 
import scripts.mlflow_logger as m + + uri = m._derive_dagshub_tracking_uri("my-owner", "my-repo") + assert uri == "https://dagshub.com/my-owner/my-repo.mlflow" + + +def test_custom_tracking_uri_used_when_provided(tmp_path: Path, mlflow_recorder: MlflowRecorder): + """Test that custom tracking URI is used when provided.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + run_dir = _write_run_dir(runs_root, "2026-02-05_21-59-37", with_adapter=False) + + custom_uri = "https://custom.mlflow.server/tracking" + + m.log_run_dir_to_mlflow( + run_dir=run_dir, + experiment_name="exp_test", + dagshub_owner="test-owner", + dagshub_repo="test-repo", + mlflow_tracking_uri=custom_uri, + no_dagshub_init=False, + ) + + assert mlflow_recorder.tracking_uri == custom_uri + + +def test_derived_uri_used_when_not_provided(tmp_path: Path, mlflow_recorder: MlflowRecorder): + """Test that URI is derived from DagsHub owner/repo when not provided.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + run_dir = _write_run_dir(runs_root, "2026-02-05_21-59-37", with_adapter=False) + + m.log_run_dir_to_mlflow( + run_dir=run_dir, + experiment_name="exp_test", + dagshub_owner="my-owner", + dagshub_repo="my-repo", + mlflow_tracking_uri="", # Empty string triggers derivation + no_dagshub_init=False, + ) + + assert mlflow_recorder.tracking_uri == "https://dagshub.com/my-owner/my-repo.mlflow" + + +# ============================================================================ +# CLI INTEGRATION TESTS +# ============================================================================ + + +def test_main_with_explicit_run_dir(tmp_path: Path, mlflow_recorder: MlflowRecorder, monkeypatch): + """Test main() with explicit --run-dir argument.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + run_dir = _write_run_dir(runs_root, "2026-02-05_21-59-37", with_adapter=False) + + # Mock sys.argv + monkeypatch.setattr( + 
"sys.argv", + [ + "mlflow_logger.py", + "--run-dir", + str(run_dir), + "--experiment-name", + "test-exp", + ], + ) + + m.main() + + assert mlflow_recorder.experiment == "test-exp" + assert mlflow_recorder.params["model"] == "mistral" + + +def test_main_auto_detects_latest_run(tmp_path: Path, mlflow_recorder: MlflowRecorder, monkeypatch): + """Test main() auto-detects latest run when --run-dir omitted.""" + import scripts.mlflow_logger as m + + runs_root = tmp_path / "runs" / "mlx" + + # Create two runs + _write_run_dir(runs_root, "2026-02-04_older", with_adapter=False) + newer = _write_run_dir(runs_root, "2026-02-06_newer", with_adapter=False) + + # Force mtime + import os + + os.utime(runs_root / "2026-02-04_older", (1_000_000_000, 1_000_000_000)) + os.utime(newer, (1_000_000_100, 1_000_000_100)) + + # Mock sys.argv without --run-dir + monkeypatch.setattr( + "sys.argv", + [ + "mlflow_logger.py", + "--runs-root", + str(runs_root), + "--experiment-name", + "test-exp", + ], + ) + + m.main() + + # Should use newer run + assert mlflow_recorder.run_name == "2026-02-06_newer" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_sanitization_manager.py b/tests/test_sanitization_manager.py new file mode 100644 index 0000000..74c80e0 --- /dev/null +++ b/tests/test_sanitization_manager.py @@ -0,0 +1,503 @@ +""" +Unit Tests for Sanitization and Dataset Manager +Run with: python -m pytest tests/test_sanitization_manager.py -v +""" + +import json +import shutil +import tempfile +from pathlib import Path + +import pytest + +# ============================================================================ +# IMPORTS - Fixed to match actual function signatures +# ============================================================================ + +try: + from scripts.data_sanitization import ( + compute_compression_ratio_tokens, # ✓ Fixed: was compute_compression_ratio + extract_verbose_compressed, + is_code_sample, + rule_a_ratio_check, + 
rule_b_orphaned_symbols, + rule_c_negation_preservation, + rule_d_semantic_symbol_usage_nl, + sanitize_and_extract, # ✓ Added: for integration tests + ) + + HAS_SANITIZE = True +except ImportError: + HAS_SANITIZE = False + +try: + from scripts.dataset_manager import ( + count_samples, + # get_config, # ✓ Added: to create config objects + load_state, + # log_change, # ✓ Added: for log tests + revert_to_original, # ✓ Added: for integration tests + save_state, + update_to_sanitized, # ✓ Added: for integration tests + ) + + HAS_DATASET_MANAGER = True +except ImportError: + HAS_DATASET_MANAGER = False + + +# ============================================================================ +# FIXTURES +# ============================================================================ + + +@pytest.fixture +def temp_dir(): + """Create temp directory, cleanup after test.""" + tmp = tempfile.mkdtemp() + yield Path(tmp) + shutil.rmtree(tmp) + + +@pytest.fixture +def mock_config(temp_dir): + """Create a mock config for testing.""" + return { + "active_train": temp_dir / "train.jsonl", + "original_backup": temp_dir / "train.original.jsonl", + "sanitized_data": temp_dir / "sanitized_train.jsonl", + "state_file": temp_dir / ".dataset_state.json", + "log_file": temp_dir / ".dataset_changes.log", + } + + +# ============================================================================ +# TESTS: SANITIZATION - Helper Functions +# ============================================================================ + + +@pytest.mark.skipif(not HAS_SANITIZE, reason="data_sanitization not available") +class TestCodeDetection: + """Test code vs natural language detection.""" + + def test_detects_python_code(self): + assert is_code_sample("def hello():\n return 'hi'") + + def test_detects_natural_language(self): + assert not is_code_sample("The cat sat on the mat.") + + def test_detects_javascript(self): + # Multi-line with multiple code signals + js_code = """ + const x = () => { + return true; + } + 
function test() { + let y = 5; + } + """ + assert is_code_sample(js_code) + + +@pytest.mark.skipif(not HAS_SANITIZE, reason="data_sanitization not available") +class TestExtraction: + """Test extracting verbose and compressed text.""" + + def test_extracts_correctly(self): + sample = { + "messages": [ + {"role": "user", "content": "Compress: Hello world"}, + {"role": "assistant", "content": "hi world"}, + ] + } + verbose, compressed, error = extract_verbose_compressed( + sample, 0 + ) # ✓ Fixed: added sample_id + assert verbose == "Hello world" + assert compressed == "hi world" + assert error is None + + def test_handles_missing_messages(self): + sample = {"messages": []} + verbose, compressed, error = extract_verbose_compressed( + sample, 0 + ) # ✓ Fixed: added sample_id + assert verbose is None + assert compressed is None + assert error is not None + + +@pytest.mark.skipif(not HAS_SANITIZE, reason="data_sanitization not available") +class TestCompressionRatio: + """Test compression ratio calculation.""" + + def test_basic_ratio(self): + # Note: compute_compression_ratio_tokens uses actual tokenizer + # This test assumes ratio > 1.0 means expansion + ratio = compute_compression_ratio_tokens("one two three four", "1 2") + assert ratio < 1.0 # Good compression + + +@pytest.mark.skipif(not HAS_SANITIZE, reason="data_sanitization not available") +class TestRuleA: + """Test Rule A: compression ratio validation.""" + + def test_passes_good_ratio(self): + passed, _ = rule_a_ratio_check("one two three", "1 2") + assert passed + + def test_fails_bad_ratio(self): + passed, _ = rule_a_ratio_check("hi", "hello there friend") + assert not passed + + +@pytest.mark.skipif(not HAS_SANITIZE, reason="data_sanitization not available") +class TestRuleB: + """Test Rule B: orphaned symbols.""" + + def test_passes_clean_text(self): + passed, _ = rule_b_orphaned_symbols( + "Paris @ France", is_code=False + ) # ✓ Fixed: added is_code + assert passed + + def 
test_fails_symbol_at_start(self): + passed, _ = rule_b_orphaned_symbols("→ bad start", is_code=False) # ✓ Fixed: added is_code + assert not passed + + def test_allows_colon_at_end(self): + passed, _ = rule_b_orphaned_symbols("function:", is_code=False) # ✓ Fixed: added is_code + assert passed + + def test_allows_decorator_for_code(self): + passed, _ = rule_b_orphaned_symbols( + "@classmethod", is_code=True + ) # ✓ New: test code-aware behavior + assert passed + + +@pytest.mark.skipif(not HAS_SANITIZE, reason="data_sanitization not available") +class TestRuleC: + """Test Rule C: negation preservation.""" + + def test_passes_no_negation(self): + passed, _ = rule_c_negation_preservation("I like it", "like") + assert passed + + def test_passes_preserved_negation(self): + passed, _ = rule_c_negation_preservation("I do not like it", "not like") + assert passed + + def test_fails_lost_negation(self): + passed, _ = rule_c_negation_preservation("I never eat meat", "eat meat") + assert not passed + + +@pytest.mark.skipif(not HAS_SANITIZE, reason="data_sanitization not available") +class TestRuleD: + """Test Rule D: semantic symbol usage.""" + + def test_passes_location_with_at(self): + passed, _ = rule_d_semantic_symbol_usage_nl("Paris is located in France", "Paris @ France") + assert passed + + def test_fails_location_without_at(self): + passed, _ = rule_d_semantic_symbol_usage_nl("Tokyo is located in Japan", "Tokyo Japan") + assert not passed + + +# ============================================================================ +# TESTS: SANITIZATION - Integration (Main Flow) +# ============================================================================ + + +@pytest.mark.skipif(not HAS_SANITIZE, reason="data_sanitization not available") +class TestSanitizeDataset: + """Test main sanitization flow.""" + + def test_sanitize_dataset_splits_correctly(self, temp_dir): + """Test that sanitize_and_extract splits good/bad samples correctly.""" + + # Create input file with 2 good + 2 
bad samples + input_file = temp_dir / "train.jsonl" + sanitized_file = temp_dir / "sanitized.jsonl" + unsanitized_file = temp_dir / "unsanitized.jsonl" + + good_sample_1 = { + "messages": [ + {"role": "user", "content": "Compress: one two three four"}, + {"role": "assistant", "content": "1 2"}, + ] + } + + good_sample_2 = { + "messages": [ + {"role": "user", "content": "Compress: hello world"}, + {"role": "assistant", "content": "hi"}, + ] + } + + # Bad sample: expansion (ratio > 1.0) + bad_sample_1 = { + "messages": [ + {"role": "user", "content": "Compress: hi"}, + {"role": "assistant", "content": "hello there my friend"}, + ] + } + + # Bad sample: orphaned symbol + bad_sample_2 = { + "messages": [ + {"role": "user", "content": "Compress: test"}, + {"role": "assistant", "content": "→ bad"}, + ] + } + + with open(input_file, "w", encoding="utf-8") as f: + for sample in [good_sample_1, good_sample_2, bad_sample_1, bad_sample_2]: + f.write(json.dumps(sample, ensure_ascii=False) + "\n") + + # Run sanitization + stats = sanitize_and_extract(input_file, sanitized_file, unsanitized_file) + + # Assert + assert stats["total_input"] == 4 + assert stats["passed_all"] == 2 + assert stats["failed_all"] == 2 + assert sanitized_file.exists() + assert unsanitized_file.exists() + + def test_unicode_symbols_preserved(self, temp_dir): + """Test that → ∵ @ symbols are NOT escaped in sanitized output.""" + + input_file = temp_dir / "train.jsonl" + sanitized_file = temp_dir / "sanitized.jsonl" + unsanitized_file = temp_dir / "unsanitized.jsonl" + + # Sample that PASSES all validation rules + # (avoid causation/location keywords to prevent Rule D failures) + sample = { + "messages": [ + {"role": "user", "content": "Compress: This implies that result"}, + {"role": "assistant", "content": "this → that"}, + ] + } + + with open(input_file, "w", encoding="utf-8") as f: + f.write(json.dumps(sample, ensure_ascii=False) + "\n") + + # Run sanitization + stats = sanitize_and_extract(input_file, 
sanitized_file, unsanitized_file) + + # Verify it passed validation + assert stats["passed_all"] == 1, f"Sample failed: {stats['failed_samples']}" + + # Read back from sanitized file + with open(sanitized_file, encoding="utf-8") as f: + result_text = f.read() + + # Assert NOT escaped + assert "→" in result_text, "Arrow symbol should be preserved, not escaped" + assert "\\u2192" not in result_text, "Arrow should not be escaped as \\u2192" + + +# ============================================================================ +# TESTS: DATASET MANAGER - File Operations +# ============================================================================ + + +@pytest.mark.skipif(not HAS_DATASET_MANAGER, reason="dataset_manager not available") +class TestFileOperations: + """Test file utilities.""" + + def test_count_samples(self, temp_dir): + test_file = temp_dir / "test.jsonl" + with open(test_file, "w", encoding="utf-8") as f: + f.write('{"test": 1}\n') + f.write('{"test": 2}\n') + f.write('{"test": 3}\n') + + assert count_samples(test_file) == 3 + + def test_count_nonexistent_file(self, temp_dir): + count = count_samples(temp_dir / "missing.jsonl") + assert count == 0 + + def test_count_empty_file(self, temp_dir): + test_file = temp_dir / "empty.jsonl" + test_file.touch() + assert count_samples(test_file) == 0 + + +# ============================================================================ +# TESTS: DATASET MANAGER - State Management +# ============================================================================ + + +@pytest.mark.skipif(not HAS_DATASET_MANAGER, reason="dataset_manager not available") +class TestState: + """Test state management.""" + + def test_load_default_state(self, mock_config): + state = load_state(mock_config) # ✓ Fixed: pass config + assert state["current"] == "original" + assert state["change_count"] == 0 + assert state["last_action"] is None + + def test_save_and_load(self, mock_config): + test_state = { + "current": "sanitized", + "last_action": 
"update", + "last_change": "2024-01-01T00:00:00", + "change_count": 1, + } + save_state(mock_config, test_state) # ✓ Fixed: pass config + loaded = load_state(mock_config) # ✓ Fixed: pass config + assert loaded == test_state + + +# ============================================================================ +# TESTS: DATASET MANAGER - Integration (Main Flows) +# ============================================================================ + + +@pytest.mark.skipif(not HAS_DATASET_MANAGER, reason="dataset_manager not available") +class TestUpdateToSanitized: + """Test update_to_sanitized flow.""" + + def test_update_creates_backup_first_time(self, mock_config): + """Test that first update creates backup.""" + + # Create original and sanitized files + with open(mock_config["active_train"], "w") as f: + f.write('{"original": 1}\n') + f.write('{"original": 2}\n') + + with open(mock_config["sanitized_data"], "w") as f: + f.write('{"sanitized": 1}\n') + + # Execute update + result = update_to_sanitized(mock_config) + + # Assert + assert result + assert mock_config["original_backup"].exists() # Backup created + + state = load_state(mock_config) + assert state["current"] == "sanitized" + assert state["last_action"] == "update" + + def test_update_blocks_consecutive_updates(self, mock_config): + """Test that you can't update twice in a row.""" + + # Setup files + with open(mock_config["active_train"], "w") as f: + f.write('{"original": 1}\n') + + with open(mock_config["sanitized_data"], "w") as f: + f.write('{"sanitized": 1}\n') + + # First update + update_to_sanitized(mock_config) + + # Second update (should be blocked) + result = update_to_sanitized(mock_config) + + # Assert + assert not result + + +@pytest.mark.skipif(not HAS_DATASET_MANAGER, reason="dataset_manager not available") +class TestRevertToOriginal: + """Test revert_to_original flow.""" + + def test_revert_restores_from_backup(self, mock_config): + """Test that revert copies backup to active.""" + + # Setup: 
Update first (creates backup) + with open(mock_config["active_train"], "w") as f: + f.write('{"original": 1}\n') + + with open(mock_config["sanitized_data"], "w") as f: + f.write('{"sanitized": 1}\n') + + update_to_sanitized(mock_config) + + # Execute: Revert + result = revert_to_original(mock_config) + + # Assert + assert result + + state = load_state(mock_config) + assert state["current"] == "original" + assert state["last_action"] == "revert" + + def test_revert_blocks_consecutive_reverts(self, mock_config): + """Test that you can't revert twice in a row.""" + + # Setup: Update then revert + with open(mock_config["active_train"], "w") as f: + f.write('{"original": 1}\n') + + with open(mock_config["sanitized_data"], "w") as f: + f.write('{"sanitized": 1}\n') + + update_to_sanitized(mock_config) + revert_to_original(mock_config) # First revert + + # Execute: Try again + result = revert_to_original(mock_config) # Second revert + + # Assert + assert not result + + +@pytest.mark.skipif(not HAS_DATASET_MANAGER, reason="dataset_manager not available") +class TestLogUpdates: + """Test log file updates.""" + + def test_operations_append_to_log(self, mock_config): + """Test that update/revert write to log file.""" + + # Setup files + with open(mock_config["active_train"], "w") as f: + f.write('{"test": 1}\n') + + with open(mock_config["sanitized_data"], "w") as f: + f.write('{"test": 1}\n') + + # Execute operations + update_to_sanitized(mock_config) + revert_to_original(mock_config) + + # Assert + assert mock_config["log_file"].exists() + log_lines = mock_config["log_file"].read_text().strip().split("\n") + assert len(log_lines) == 2 + assert "UPDATE" in log_lines[0] + assert "REVERT" in log_lines[1] + + +# ============================================================================ +# EDGE CASES +# ============================================================================ + + +@pytest.mark.skipif(not HAS_SANITIZE, reason="data_sanitization not available") +class 
TestEdgeCases: + """Test edge cases and safety.""" + + def test_empty_symbol_check(self): + passed, _ = rule_b_orphaned_symbols("", is_code=False) # ✓ Fixed: added is_code + assert not passed + + def test_unicode_handling(self): + passed, _ = rule_b_orphaned_symbols( + "test → result", is_code=False + ) # ✓ Fixed: added is_code + assert passed + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py new file mode 100644 index 0000000..eb0a4a4 --- /dev/null +++ b/tests/test_tokenizer.py @@ -0,0 +1,323 @@ +""" +Unit tests for tokenizer utilities with special token handling. + +Tests cover: +- Basic token counting +- Special token handling (<|endoftext|>, <|im_start|>, etc.) +- Compression ratio calculations +- Edge cases (empty strings, zero tokens, etc.) +- Regression test for tiktoken version changes +""" + +import sys +from pathlib import Path + +import pytest + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.utils.tokenizers import compression_ratio, count_tokens + + +class TestTokenCounting: + """Test basic token counting functionality.""" + + def test_simple_text(self): + """Test counting tokens in simple text.""" + text = "Hello, world!" + count = count_tokens(text) + assert count > 0, "Should count tokens in simple text" + assert isinstance(count, int), "Token count should be integer" + + def test_empty_string(self): + """Test empty string returns 0 tokens.""" + assert count_tokens("") == 0 + + def test_whitespace_only(self): + """Test whitespace-only string.""" + count = count_tokens(" ") + assert count >= 0, "Whitespace should have non-negative token count" + + def test_longer_text(self): + """Test longer text has more tokens.""" + short = "Hi" + long = "This is a much longer piece of text with many more words." 
+ + short_count = count_tokens(short) + long_count = count_tokens(long) + + assert long_count > short_count, "Longer text should have more tokens" + + +class TestSpecialTokens: + """Test handling of special tokens in text.""" + + def test_endoftext_token(self): + """ + REGRESSION TEST: Verify <|endoftext|> is handled as normal text. + + This is the token that caused the original error. We now encode it + as normal text rather than treating it as a special control token. + """ + text_with_special = "This is text with <|endoftext|> in it." + text_without_special = "This is text with PLACEHOLDER in it." + + # Should not raise ValueError + count_with = count_tokens(text_with_special) + count_without = count_tokens(text_without_special) + + assert count_with > 0, "Should count tokens with <|endoftext|>" + assert isinstance(count_with, int), "Token count should be integer" + + # The counts may differ, but both should work + assert count_with >= count_without - 5, "Token counts should be similar" + + def test_multiple_special_tokens(self): + """Test text with multiple special tokens.""" + text = "Start <|endoftext|> middle <|im_start|> end <|im_end|>" + + # Should not raise ValueError + count = count_tokens(text) + assert count > 0, "Should handle multiple special tokens" + + def test_only_special_token(self): + """Test text that is only a special token.""" + text = "<|endoftext|>" + + # Should not raise ValueError + count = count_tokens(text) + assert count > 0, "Should count special token as text" + + def test_special_token_variations(self): + """Test various special token formats.""" + special_tokens = [ + "<|endoftext|>", + "<|im_start|>", + "<|im_end|>", + "<|im_sep|>", + ] + + for token in special_tokens: + # Should not raise ValueError for any special token + count = count_tokens(token) + assert count > 0, f"Should handle {token}" + + +class TestCompressionRatio: + """Test compression ratio calculations.""" + + def test_perfect_compression(self): + """Test ratio 
when compressed is much shorter.""" + original = "This is a very long piece of text that should compress well." + compressed = "Short" + + ratio = compression_ratio(original, compressed) + assert ratio < 1.0, "Good compression should have ratio < 1.0" + + def test_no_compression(self): + """Test ratio when text is unchanged.""" + text = "Same text" + ratio = compression_ratio(text, text) + + assert ratio == 1.0, "Identical text should have ratio of 1.0" + + def test_expansion(self): + """Test ratio when compressed is longer (bad compression).""" + original = "Short" + compressed = "This is actually much longer than the original text." + + ratio = compression_ratio(original, compressed) + assert ratio > 1.0, "Expansion should have ratio > 1.0" + + def test_empty_original(self): + """Test ratio with empty original string.""" + ratio = compression_ratio("", "something") + assert ratio == 1.0, "Empty original should return 1.0" + + def test_both_empty(self): + """Test ratio when both strings are empty.""" + ratio = compression_ratio("", "") + assert ratio == 1.0, "Both empty should return 1.0" + + def test_ratio_with_special_tokens(self): + """ + REGRESSION TEST: Verify compression ratio works with special tokens. + + This ensures the special token handling doesn't break ratio calculations. + """ + original = "This text has <|endoftext|> special tokens in it." + compressed = "Text with <|endoftext|> tokens." 
+ + # Should not raise ValueError + ratio = compression_ratio(original, compressed) + + assert 0 < ratio <= 1.0, "Compression should have ratio between 0 and 1" + assert isinstance(ratio, float), "Ratio should be float" + + +class TestTokenizerConsistency: + """Test consistency across different inputs and edge cases.""" + + def test_unicode_characters(self): + """Test counting tokens in text with unicode characters.""" + text = "Hello 世界 🌍 café" + count = count_tokens(text) + assert count > 0, "Should handle unicode characters" + + def test_code_text(self): + """Test counting tokens in code snippets.""" + code = """ + def hello(): + return "world" + """ + count = count_tokens(code) + assert count > 0, "Should count tokens in code" + + def test_newlines_and_whitespace(self): + """Test text with various whitespace.""" + text = "Line 1\nLine 2\n\nLine 3\t\tTabbed" + count = count_tokens(text) + assert count > 0, "Should handle newlines and tabs" + + def test_repeated_counting(self): + """Test that counting is consistent across calls.""" + text = "Consistency test" + count1 = count_tokens(text) + count2 = count_tokens(text) + count3 = count_tokens(text) + + assert count1 == count2 == count3, "Token counting should be consistent" + + +class TestTiktokenVersionRegression: + """ + Regression tests to catch tiktoken version changes. + + These tests will fail if tiktoken's behavior changes significantly, + alerting us to review and update our special token handling. 
+ """ + + def test_known_token_counts(self): + """Test known strings have expected token counts (±1 for version tolerance).""" + test_cases = { + "Hello": (1, 2), # Expected range: 1-2 tokens + "Hello, world!": (3, 5), # Expected range: 3-5 tokens + "The quick brown fox": (4, 6), # Expected range: 4-6 tokens + } + + for text, (min_expected, max_expected) in test_cases.items(): + count = count_tokens(text) + assert min_expected <= count <= max_expected, ( + f"Token count for '{text}' outside expected range: {count} not in [{min_expected}, {max_expected}]" + ) + + def test_special_token_encoding_behavior(self): + """ + CRITICAL REGRESSION TEST: Verify special tokens are encoded as text. + + If this fails after a tiktoken update, it means: + 1. Our disallowed_special=() parameter may have changed behavior + 2. We need to review our special token handling strategy + """ + text_with_special = "Text <|endoftext|> here" + + # This should NOT raise ValueError + try: + count = count_tokens(text_with_special) + assert count > 0, "Should successfully count with special tokens" + except ValueError as e: + if "special token" in str(e).lower(): + pytest.fail( + "Special token handling broken! " + "disallowed_special=() may no longer work. " + "Review tokenizer implementation and tiktoken version." 
+ ) + else: + raise + + def test_tokenizer_caching(self): + """Test that tokenizer caching works (via consistent token counts).""" + # Instead of testing internal caching mechanism, test that + # repeated calls give consistent results (which implies caching works) + text = "Test tokenizer caching" + + counts = [count_tokens(text) for _ in range(5)] + + # All counts should be identical (proves caching works) + assert len(set(counts)) == 1, "Token counts should be consistent across calls" + assert counts[0] > 0, "Should count tokens" + + +class TestDataSanitizationIntegration: + """Integration tests for data sanitization use cases.""" + + def test_compression_ratio_threshold(self): + """Test that ratio threshold logic works correctly.""" + # Good compression (should pass) + original = "This is a long verbose sentence with many unnecessary words." + compressed = "Long verbose sentence." + ratio = compression_ratio(original, compressed) + assert ratio <= 1.0, "Good compression should pass <= 1.0 threshold" + + # Bad compression / expansion (should fail) + original = "Short" + compressed = "This became very long and verbose." 
+ ratio = compression_ratio(original, compressed) + assert ratio > 1.0, "Expansion should fail > 1.0 threshold" + + def test_real_world_training_sample(self): + """Test with realistic training data format.""" + verbose = """ + def validate(function: AnyCallableT) -> AnyCallableT: + _check_function_type(function) + validate_call_wrapper = _validate_call.ValidateCallWrapper( + cast(_generate_schema.ValidateCallSupportedTypes, function), + config, validate_return, parent_namespace + ) + return _validate_call.update_wrapper_attributes(function, validate_call_wrapper.__call__) + """ + + compressed = """fn:validate(function:AnyCallableT)->AnyCallableT = + _check_function_type(function) |> + _validate_call.ValidateCallWrapper(cast(...), config, validate_return, parent_namespace) |> + λwrapper: _validate_call.update_wrapper_attributes(function, wrapper.__call__)""" + + ratio = compression_ratio(verbose, compressed) + + # Should be valid compression + assert 0 < ratio < 1.0, f"Real training sample should compress well, got ratio: {ratio}" + + def test_sample_with_special_tokens_in_training_data(self): + """ + CRITICAL: Test that training data with special tokens works. + + This is the actual bug that was reported - training data contained + <|endoftext|> and caused ValueError during sanitization. + """ + verbose = "This is training data with <|endoftext|> token." 
+ compressed = "Training data with <|endoftext|>" + + # Should not raise ValueError + ratio = compression_ratio(verbose, compressed) + + assert 0 < ratio <= 1.0, "Should handle special tokens in training data" + + +# Pytest fixtures +@pytest.fixture +def sample_texts(): + """Fixture providing various text samples for testing.""" + return { + "simple": "Hello, world!", + "empty": "", + "with_special": "Text <|endoftext|> here", + "code": "def foo(): return 42", + "unicode": "café 世界 🌍", + "long": " ".join(["word"] * 100), + } + + +# Run tests with: pytest tests/test_tokenizer.py -v +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/visualization/corpus.ipynb b/visualization/corpus.ipynb new file mode 100644 index 0000000..1f20b92 --- /dev/null +++ b/visualization/corpus.ipynb @@ -0,0 +1,1213 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "7d2d49d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Analyzing: Natural Language\n", + "Dataset: ../nl_v2.jsonl\n", + "================================================================================\n", + "\n", + "✓ Loaded 8256 samples\n", + " Sample keys: ['verbose', 'compressed', 'domain', 'language', 'metadata']\n", + "\n", + "Example sample:\n", + " Verbose (first 100 chars): Artificial intelligence (AI) is revolutionizing quality control in rocket manufacturing by enhancing...\n", + " Compressed (first 100 chars): AI revolutionizing rocket manufacturing QC: +precision | +automated inspections | +failure predictio...\n", + "\n", + "✓ Computed metrics for 8256 samples\n", + "\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "================================================================================\n", + "NATURAL LANGUAGE COMPRESSION STATISTICS\n", + "================================================================================\n", + "\n", + "Dataset Size: 8256 samples\n", + "\n", + "COMPRESSION RATIO (compressed/verbose):\n", + " Mean: 1.399\n", + " Median: 0.650\n", + " Std Dev: 2.084\n", + " Min: 0.016\n", + " Max: 41.250\n", + " 25th %ile: 0.506\n", + " 75th %ile: 1.419\n", + "\n", + "TOKEN REDUCTION:\n", + " Mean reduction: -39.9%\n", + " Median reduction: 35.0%\n", + " Best reduction: 98.4%\n", + " Worst reduction: -4025.0%\n", + "\n", + "AVERAGE TOKEN COUNTS:\n", + " Verbose: 130 tokens\n", + " Compressed: 130 tokens\n", + " Saved: 0 tokens per sample\n", + "\n", + "TOTAL TOKEN SAVINGS:\n", + " Total verbose tokens: 1,076,104\n", + " Total compressed tokens: 1,072,110\n", + " Total tokens saved: 3,994\n", + " Overall reduction: 0.4%\n", + "\n", + "================================================================================\n", + "QUALITY CHECKS\n", + "================================================================================\n", + "\n", + "🚨 EXPANSION DETECTED: 2541 samples (30.8%)\n", + " These samples got LONGER after compression (BAD)\n", + " Worst expansion: 41.250 ratio\n", + " → ACTION: Investigate these samples manually\n", + "\n", + "⚠️ WEAK COMPRESSION: 2761 samples (33.4%)\n", + " Ratio > 0.9 means <10% token reduction\n", + " → ACTION: These samples may not be worth compressing\n", + "\n", + "COMPRESSION QUALITY BREAKDOWN:\n", + " Strong (>50% reduction): 1935 samples ( 23.4%)\n", + " Moderate (30-50% reduction): 2637 samples ( 31.9%)\n", + " Weak (<30% reduction): 1108 samples ( 13.4%)\n", + "\n", + "SAMPLE LENGTH DISTRIBUTION:\n", + " Verbose tokens:\n", + " Short (<100): 4113 samples\n", + " Medium (100-500): 
4137 samples\n", + " Long (500-1000): 6 samples\n", + " Very long (>1000): 0 samples\n", + "\n", + "================================================================================\n", + "KEY INSIGHTS & RECOMMENDATIONS\n", + "================================================================================\n", + "\n", + "📊 COMPRESSION EFFECTIVENESS:\n", + " 🚨 WEAK: -39.9% average reduction\n", + " Compression needs significant improvement\n", + "\n", + "💰 TOKEN SAVINGS (Cost Impact):\n", + " Total tokens saved: 3,994\n", + " At $3/1M tokens (Claude Sonnet input):\n", + " Cost with verbose: $3.23\n", + " Cost with compressed: $3.22\n", + " 💵 Savings: $0.01\n", + "\n", + "🎯 REALISTIC TARGET:\n", + " Based on this dataset, expect:\n", + " • Average ratio: 1.40x\n", + " • Average reduction: -39.9%\n", + " • Tokens saved per sample: 0\n", + "\n", + "================================================================================\n", + "NEXT STEPS\n", + "================================================================================\n", + "1. ✅ NL analysis complete\n", + "2. ⏭️ Change DATASET_PATH to 'code_dataset.jsonl' and re-run\n", + "3. ⏭️ Compare NL vs Code compression effectiveness\n", + "4. 
⏭️ Move to Analysis #2: Symbol Usage Patterns\n", + "================================================================================\n" + ] + } + ], + "source": [ + "# Analysis #1: Compression Ratio Distribution\n", + "# Single domain analysis (NL)\n", + "\n", + "import json\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import tiktoken\n", + "\n", + "# Set style\n", + "sns.set_style(\"whitegrid\")\n", + "plt.rcParams[\"figure.figsize\"] = (12, 6)\n", + "\n", + "# ============================================================================\n", + "# CONFIGURATION - CHANGE THIS FOR EACH RUN\n", + "# ============================================================================\n", + "\n", + "# Choose which dataset to analyze\n", + "DATASET_PATH = \"../nl_v2.jsonl\" # OR \"data/raw/code_dataset.jsonl\"\n", + "DOMAIN_NAME = \"Natural Language\" # OR \"Code\"\n", + "DOMAIN_TYPE = \"nl\" # OR \"code\"\n", + "\n", + "print(f\"Analyzing: {DOMAIN_NAME}\")\n", + "print(f\"Dataset: {DATASET_PATH}\")\n", + "print(\"=\" * 80 + \"\\n\")\n", + "\n", + "# ============================================================================\n", + "# 1. 
LOAD DATA\n", + "# ============================================================================\n", + "\n", + "\n", + "def load_dataset(filepath: str) -> list[dict]:\n", + " \"\"\"Load dataset with verbose/compressed pairs\"\"\"\n", + " data = []\n", + " with open(filepath, encoding=\"utf-8\") as f:\n", + " for line in f:\n", + " if line.strip():\n", + " data.append(json.loads(line))\n", + " return data\n", + "\n", + "\n", + "data = load_dataset(DATASET_PATH)\n", + "\n", + "print(f\"✓ Loaded {len(data)} samples\")\n", + "print(f\" Sample keys: {list(data[0].keys())}\")\n", + "\n", + "# Show one example\n", + "print(\"\\nExample sample:\")\n", + "print(f\" Verbose (first 100 chars): {data[0]['verbose'][:100]}...\")\n", + "print(f\" Compressed (first 100 chars): {data[0]['compressed'][:100]}...\")\n", + "\n", + "# ============================================================================\n", + "# 2. CALCULATE COMPRESSION RATIOS\n", + "# ============================================================================\n", + "\n", + "\n", + "def count_tokens(text: str) -> int:\n", + " \"\"\"Count tokens using tiktoken (same as your tokenizer.py)\"\"\"\n", + " enc = tiktoken.get_encoding(\"cl100k_base\")\n", + " return len(enc.encode(text, disallowed_special=()))\n", + "\n", + "\n", + "def compute_metrics(sample: dict) -> dict:\n", + " \"\"\"Compute compression metrics for a sample\"\"\"\n", + " verbose = sample[\"verbose\"]\n", + " compressed = sample[\"compressed\"]\n", + "\n", + " verbose_tokens = count_tokens(verbose)\n", + " compressed_tokens = count_tokens(compressed)\n", + "\n", + " # Compression ratio (compressed/verbose)\n", + " # Ratio < 1.0 = compression (good)\n", + " # Ratio > 1.0 = expansion (bad)\n", + " ratio = compressed_tokens / verbose_tokens if verbose_tokens > 0 else 1.0\n", + "\n", + " # Token reduction percentage\n", + " reduction_pct = (1 - ratio) * 100\n", + "\n", + " # Token savings\n", + " tokens_saved = verbose_tokens - compressed_tokens\n", + 
"\n", + " return {\n", + " \"verbose_tokens\": verbose_tokens,\n", + " \"compressed_tokens\": compressed_tokens,\n", + " \"compression_ratio\": ratio,\n", + " \"reduction_pct\": reduction_pct,\n", + " \"tokens_saved\": tokens_saved,\n", + " \"verbose_chars\": len(verbose),\n", + " \"compressed_chars\": len(compressed),\n", + " }\n", + "\n", + "\n", + "# Compute metrics for all samples\n", + "metrics = [compute_metrics(sample) for sample in data]\n", + "df = pd.DataFrame(metrics)\n", + "\n", + "print(f\"\\n✓ Computed metrics for {len(df)} samples\\n\")\n", + "\n", + "# ============================================================================\n", + "# 3. VISUALIZATION: COMPRESSION RATIO DISTRIBUTION\n", + "# ============================================================================\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n", + "\n", + "# Color scheme\n", + "color = \"#3498db\" if DOMAIN_TYPE == \"nl\" else \"#e74c3c\"\n", + "\n", + "# Plot 1: Histogram of compression ratios\n", + "axes[0].hist(df[\"compression_ratio\"], bins=30, alpha=0.7, color=color, edgecolor=\"black\")\n", + "axes[0].axvline(\n", + " x=df[\"compression_ratio\"].mean(),\n", + " color=\"darkgreen\",\n", + " linestyle=\"--\",\n", + " linewidth=2,\n", + " label=f\"Mean: {df['compression_ratio'].mean():.3f}\",\n", + ")\n", + "axes[0].axvline(\n", + " x=df[\"compression_ratio\"].median(),\n", + " color=\"orange\",\n", + " linestyle=\"--\",\n", + " linewidth=2,\n", + " label=f\"Median: {df['compression_ratio'].median():.3f}\",\n", + ")\n", + "axes[0].axvline(x=1.0, color=\"red\", linestyle=\":\", alpha=0.5, label=\"No compression (1.0)\")\n", + "axes[0].set_xlabel(\"Compression Ratio (compressed/verbose)\")\n", + "axes[0].set_ylabel(\"Frequency\")\n", + "axes[0].set_title(f\"{DOMAIN_NAME}: Compression Ratio Distribution\")\n", + "axes[0].legend()\n", + "axes[0].grid(True, alpha=0.3)\n", + "\n", + "# Plot 2: Box plot\n", + "bp = axes[1].boxplot(df[\"compression_ratio\"], 
vert=True, patch_artist=True, widths=0.5)\n", + "bp[\"boxes\"][0].set_facecolor(color)\n", + "bp[\"boxes\"][0].set_alpha(0.7)\n", + "axes[1].axhline(y=1.0, color=\"red\", linestyle=\":\", alpha=0.5, label=\"No compression\")\n", + "axes[1].set_ylabel(\"Compression Ratio\")\n", + "axes[1].set_title(f\"{DOMAIN_NAME}: Ratio Box Plot\")\n", + "axes[1].set_xticklabels([DOMAIN_NAME])\n", + "axes[1].grid(True, alpha=0.3)\n", + "\n", + "# Add outlier annotations\n", + "q1 = df[\"compression_ratio\"].quantile(0.25)\n", + "q3 = df[\"compression_ratio\"].quantile(0.75)\n", + "iqr = q3 - q1\n", + "outliers = df[\n", + " (df[\"compression_ratio\"] < q1 - 1.5 * iqr) | (df[\"compression_ratio\"] > q3 + 1.5 * iqr)\n", + "]\n", + "if len(outliers) > 0:\n", + " axes[1].text(\n", + " 1,\n", + " q3 + 1.5 * iqr + 0.1,\n", + " f\"{len(outliers)} outliers\",\n", + " ha=\"center\",\n", + " fontsize=9,\n", + " bbox=dict(boxstyle=\"round\", facecolor=\"yellow\", alpha=0.3),\n", + " )\n", + "\n", + "# Plot 3: Scatter plot (verbose vs compressed tokens)\n", + "axes[2].scatter(\n", + " df[\"verbose_tokens\"],\n", + " df[\"compressed_tokens\"],\n", + " alpha=0.5,\n", + " c=df[\"compression_ratio\"],\n", + " cmap=\"RdYlGn_r\", # Red = bad (high ratio), Green = good (low ratio)\n", + " s=50,\n", + " edgecolor=\"black\",\n", + " linewidth=0.5,\n", + ")\n", + "\n", + "# Add diagonal line (ratio = 1.0)\n", + "max_tokens = max(df[\"verbose_tokens\"].max(), df[\"compressed_tokens\"].max())\n", + "axes[2].plot([0, max_tokens], [0, max_tokens], \"r--\", alpha=0.5, label=\"No compression (1:1)\")\n", + "\n", + "# Add reference lines for common ratios\n", + "for ratio, label in [(0.5, \"50%\"), (0.3, \"70%\"), (0.7, \"30%\")]:\n", + " axes[2].plot([0, max_tokens], [0, max_tokens * ratio], linestyle=\":\", alpha=0.3, color=\"gray\")\n", + " axes[2].text(\n", + " max_tokens * 0.9, max_tokens * ratio * 0.9, f\"{label} reduction\", fontsize=8, alpha=0.5\n", + " )\n", + "\n", + "axes[2].set_xlabel(\"Verbose 
Tokens\")\n", + "axes[2].set_ylabel(\"Compressed Tokens\")\n", + "axes[2].set_title(f\"{DOMAIN_NAME}: Token Count Relationship\")\n", + "axes[2].legend()\n", + "axes[2].grid(True, alpha=0.3)\n", + "\n", + "# Add colorbar for ratio\n", + "cbar = plt.colorbar(axes[2].collections[0], ax=axes[2])\n", + "cbar.set_label(\"Compression Ratio\", rotation=270, labelpad=15)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"analysis_01_{DOMAIN_TYPE}_compression_ratio.png\", dpi=300, bbox_inches=\"tight\")\n", + "plt.show()\n", + "\n", + "# ============================================================================\n", + "# 4. SUMMARY STATISTICS\n", + "# ============================================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(f\"{DOMAIN_NAME.upper()} COMPRESSION STATISTICS\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(f\"\\nDataset Size: {len(df)} samples\")\n", + "\n", + "print(\"\\nCOMPRESSION RATIO (compressed/verbose):\")\n", + "print(f\" Mean: {df['compression_ratio'].mean():.3f}\")\n", + "print(f\" Median: {df['compression_ratio'].median():.3f}\")\n", + "print(f\" Std Dev: {df['compression_ratio'].std():.3f}\")\n", + "print(f\" Min: {df['compression_ratio'].min():.3f}\")\n", + "print(f\" Max: {df['compression_ratio'].max():.3f}\")\n", + "print(f\" 25th %ile: {df['compression_ratio'].quantile(0.25):.3f}\")\n", + "print(f\" 75th %ile: {df['compression_ratio'].quantile(0.75):.3f}\")\n", + "\n", + "print(\"\\nTOKEN REDUCTION:\")\n", + "print(f\" Mean reduction: {df['reduction_pct'].mean():.1f}%\")\n", + "print(f\" Median reduction: {df['reduction_pct'].median():.1f}%\")\n", + "print(f\" Best reduction: {df['reduction_pct'].max():.1f}%\")\n", + "print(f\" Worst reduction: {df['reduction_pct'].min():.1f}%\")\n", + "\n", + "print(\"\\nAVERAGE TOKEN COUNTS:\")\n", + "print(f\" Verbose: {df['verbose_tokens'].mean():.0f} tokens\")\n", + "print(f\" Compressed: {df['compressed_tokens'].mean():.0f} tokens\")\n", + 
"print(f\" Saved: {df['tokens_saved'].mean():.0f} tokens per sample\")\n", + "\n", + "print(\"\\nTOTAL TOKEN SAVINGS:\")\n", + "print(f\" Total verbose tokens: {df['verbose_tokens'].sum():,}\")\n", + "print(f\" Total compressed tokens: {df['compressed_tokens'].sum():,}\")\n", + "print(f\" Total tokens saved: {df['tokens_saved'].sum():,}\")\n", + "print(\n", + " f\" Overall reduction: {(1 - df['compressed_tokens'].sum() / df['verbose_tokens'].sum()) * 100:.1f}%\"\n", + ")\n", + "\n", + "# ============================================================================\n", + "# 5. QUALITY CHECKS & INSIGHTS\n", + "# ============================================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"QUALITY CHECKS\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Check for expansion (ratio > 1.0) - RED FLAG\n", + "expansions = df[df[\"compression_ratio\"] > 1.0]\n", + "if len(expansions) > 0:\n", + " print(\n", + " f\"\\n🚨 EXPANSION DETECTED: {len(expansions)} samples ({len(expansions) / len(df) * 100:.1f}%)\"\n", + " )\n", + " print(\" These samples got LONGER after compression (BAD)\")\n", + " print(f\" Worst expansion: {expansions['compression_ratio'].max():.3f} ratio\")\n", + " print(\" → ACTION: Investigate these samples manually\")\n", + "else:\n", + " print(\"\\n✅ NO EXPANSION: All samples compressed successfully\")\n", + "\n", + "# Check for very weak compression (ratio > 0.9) - WARNING\n", + "weak = df[df[\"compression_ratio\"] > 0.9]\n", + "if len(weak) > 0:\n", + " print(f\"\\n⚠️ WEAK COMPRESSION: {len(weak)} samples ({len(weak) / len(df) * 100:.1f}%)\")\n", + " print(\" Ratio > 0.9 means <10% token reduction\")\n", + " print(\" → ACTION: These samples may not be worth compressing\")\n", + "\n", + "# Distribution of compression quality\n", + "strong = df[df[\"compression_ratio\"] < 0.5] # >50% reduction\n", + "moderate = df[(df[\"compression_ratio\"] >= 0.5) & (df[\"compression_ratio\"] < 0.7)] # 30-50%\n", + "weak = 
df[(df[\"compression_ratio\"] >= 0.7) & (df[\"compression_ratio\"] < 1.0)] # <30%\n", + "\n", + "print(\"\\nCOMPRESSION QUALITY BREAKDOWN:\")\n", + "print(\n", + " f\" Strong (>50% reduction): {len(strong):4d} samples ({len(strong) / len(df) * 100:5.1f}%)\"\n", + ")\n", + "print(\n", + " f\" Moderate (30-50% reduction): {len(moderate):4d} samples ({len(moderate) / len(df) * 100:5.1f}%)\"\n", + ")\n", + "print(f\" Weak (<30% reduction): {len(weak):4d} samples ({len(weak) / len(df) * 100:5.1f}%)\")\n", + "\n", + "# Sample length distribution\n", + "print(\"\\nSAMPLE LENGTH DISTRIBUTION:\")\n", + "print(\" Verbose tokens:\")\n", + "print(f\" Short (<100): {len(df[df['verbose_tokens'] < 100])} samples\")\n", + "print(\n", + " f\" Medium (100-500): {len(df[(df['verbose_tokens'] >= 100) & (df['verbose_tokens'] < 500)])} samples\"\n", + ")\n", + "print(\n", + " f\" Long (500-1000): {len(df[(df['verbose_tokens'] >= 500) & (df['verbose_tokens'] < 1000)])} samples\"\n", + ")\n", + "print(f\" Very long (>1000): {len(df[df['verbose_tokens'] >= 1000])} samples\")\n", + "\n", + "# ============================================================================\n", + "# 6. 
KEY INSIGHTS & RECOMMENDATIONS\n", + "# ============================================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"KEY INSIGHTS & RECOMMENDATIONS\")\n", + "print(\"=\" * 80)\n", + "\n", + "mean_ratio = df[\"compression_ratio\"].mean()\n", + "mean_reduction = df[\"reduction_pct\"].mean()\n", + "\n", + "print(\"\\n📊 COMPRESSION EFFECTIVENESS:\")\n", + "if mean_reduction > 60:\n", + " print(f\" ✅ EXCELLENT: {mean_reduction:.1f}% average reduction\")\n", + " print(\" Your compression is highly effective\")\n", + "elif mean_reduction > 40:\n", + " print(f\" ✅ GOOD: {mean_reduction:.1f}% average reduction\")\n", + " print(\" Solid compression performance\")\n", + "elif mean_reduction > 20:\n", + " print(f\" ⚠️ MODERATE: {mean_reduction:.1f}% average reduction\")\n", + " print(\" Room for improvement in compression strategy\")\n", + "else:\n", + " print(f\" 🚨 WEAK: {mean_reduction:.1f}% average reduction\")\n", + " print(\" Compression needs significant improvement\")\n", + "\n", + "print(\"\\n💰 TOKEN SAVINGS (Cost Impact):\")\n", + "tokens_saved = df[\"tokens_saved\"].sum()\n", + "print(f\" Total tokens saved: {tokens_saved:,}\")\n", + "print(\" At $3/1M tokens (Claude Sonnet input):\")\n", + "print(f\" Cost with verbose: ${df['verbose_tokens'].sum() / 1_000_000 * 3:.2f}\")\n", + "print(f\" Cost with compressed: ${df['compressed_tokens'].sum() / 1_000_000 * 3:.2f}\")\n", + "print(f\" 💵 Savings: ${tokens_saved / 1_000_000 * 3:.2f}\")\n", + "\n", + "print(\"\\n🎯 REALISTIC TARGET:\")\n", + "print(\" Based on this dataset, expect:\")\n", + "print(f\" • Average ratio: {mean_ratio:.2f}x\")\n", + "print(f\" • Average reduction: {mean_reduction:.1f}%\")\n", + "print(f\" • Tokens saved per sample: {df['tokens_saved'].mean():.0f}\")\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"NEXT STEPS\")\n", + "print(\"=\" * 80)\n", + "if DOMAIN_TYPE == \"nl\":\n", + " print(\"1. 
✅ NL analysis complete\")\n", + " print(\"2. ⏭️ Change DATASET_PATH to 'code_dataset.jsonl' and re-run\")\n", + " print(\"3. ⏭️ Compare NL vs Code compression effectiveness\")\n", + " print(\"4. ⏭️ Move to Analysis #2: Symbol Usage Patterns\")\n", + "else:\n", + " print(\"1. ✅ Code analysis complete\")\n", + " print(\"2. 📊 Compare with NL results\")\n", + " print(\"3. ⏭️ Move to Analysis #2: Symbol Usage Patterns\")\n", + "print(\"=\" * 80)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0d54ac1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Analyzing: Natural Language\n", + "Dataset: ../nl_v2.jsonl\n", + "================================================================================\n", + "\n", + "Tracking symbols:\n", + " → - Arrow (implies/leads to)\n", + " | - Pipe (separator/or)\n", + " @ - At (location/decorator)\n", + " ∵ - Because (causation)\n", + " : - Colon (assignment/type)\n", + "\n", + "✓ Loaded 8256 samples\n", + "\n", + "✓ Analyzed symbol usage in 8256 samples\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/gautamgalada/compression-layer/.venv/lib/python3.14/site-packages/seaborn/utils.py:61: UserWarning: Glyph 8757 (\\N{BECAUSE}) missing from font(s) Arial.\n", + " fig.canvas.draw()\n", + "/var/folders/kz/qrlwnxhd4pg9q70k5m2_zw700000gp/T/ipykernel_48241/1961506372.py:237: UserWarning: Glyph 8757 (\\N{BECAUSE}) missing from font(s) Arial.\n", + " plt.savefig(f'analysis_02_{DOMAIN_TYPE}_symbol_usage.png', dpi=300, bbox_inches='tight')\n", + "/Users/gautamgalada/compression-layer/.venv/lib/python3.14/site-packages/IPython/core/pylabtools.py:170: UserWarning: Glyph 8757 (\\N{BECAUSE}) missing from font(s) Arial.\n", + " fig.canvas.print_figure(bytes_io, **kw)\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "================================================================================\n", + "NATURAL LANGUAGE SYMBOL USAGE STATISTICS\n", + "================================================================================\n", + "\n", + "Dataset Size: 8256 samples\n", + "\n", + "SYMBOL PRESENCE RATES:\n", + "Symbol Description Samples Rate \n", + "------------------------------------------------------------\n", + "→ Arrow 3223 39.0%\n", + "| Pipe 7532 91.2%\n", + "@ At 2048 24.8%\n", + "∵ Because 1052 12.7%\n", + ": Colon 7513 91.0%\n", + "\n", + "AVERAGE COUNTS (ALL SAMPLES):\n", + "Symbol Avg Count Total Uses \n", + "---------------------------------------------\n", + "→ 1.35 11150 \n", + "| 8.38 69168 \n", + "@ 0.44 3629 \n", + "∵ 0.17 1429 \n", + ": 5.97 49313 \n", + "\n", + "AVERAGE COUNTS (WHEN PRESENT):\n", + "Symbol Avg When Used \n", + "------------------------------\n", + "→ 3.46 \n", + "| 9.18 \n", + "@ 1.77 \n", + "∵ 1.36 \n", + ": 6.56 \n", + "\n", + "SYMBOL DIVERSITY:\n", + " Average unique symbols per sample: 2.59\n", + " Median unique symbols per sample: 2\n", + " Max unique symbols in one sample: 5\n", + " Samples using all 5 symbols: 261 (3.2%)\n", + " Samples using 0 symbols: 75 (0.9%)\n", + "\n", + "================================================================================\n", + "INSIGHTS & VALIDATION\n", + "================================================================================\n", + "\n", + "📋 EXPECTED PATTERNS FOR NATURAL LANGUAGE:\n", + " • @ symbol: Moderate usage (15-30%) for location contexts\n", + " • ∵ symbol: Moderate-High usage (20-40%) for causation\n", + " • → symbol: Moderate usage (20-40%) for implications\n", + " • | symbol: Low-Moderate usage for separators\n", + " • : symbol: High usage (60-90%) for definitions/assignments\n", + "\n", + "📊 ACTUAL RESULTS:\n", + " ✅ @ symbol usage 
normal (24.8%)\n", + " ⚠️ ∵ symbol underutilized (12.7%) - expected 20-40%\n", + " → ACTION: Review if causation contexts are being compressed\n", + " ✅ : symbol usage strong (91.0%)\n", + "\n", + "⚠️ SYMBOL COVERAGE ISSUE:\n", + " 75 samples (0.9%) use NO symbols\n", + " → ACTION: These samples may be poorly compressed\n", + " → Investigate: Are they very short? Failed compression?\n", + "\n", + "📊 HIGH SYMBOL USAGE:\n", + " Top 5% of samples use 55+ symbols\n", + " Max symbols in one sample: 220\n", + "\n", + "================================================================================\n", + "SYMBOL CO-OCCURRENCE INSIGHTS\n", + "================================================================================\n", + "\n", + "Most common symbol pairs (co-occur in same sample):\n", + " | + :: 83.6% of samples\n", + " → + |: 37.3% of samples\n", + " → + :: 37.2% of samples\n", + " | + @: 24.1% of samples\n", + " @ + :: 23.7% of samples\n", + "\n", + "================================================================================\n", + "NEXT STEPS\n", + "================================================================================\n", + "1. ✅ NL symbol analysis complete\n", + "2. ⏭️ Change to 'code_dataset.jsonl' and re-run\n", + "3. ⏭️ Compare NL vs Code symbol usage patterns\n", + "4. 
⏭️ Move to Analysis #3: Token Length Distribution\n", + "================================================================================\n" + ] + } + ], + "source": [ + "# Analysis #2: Symbol Usage Patterns\n", + "# Analyze how compression symbols are used (NL)\n", + "\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "\n", + "# Set style\n", + "sns.set_style(\"whitegrid\")\n", + "plt.rcParams[\"figure.figsize\"] = (14, 10)\n", + "\n", + "# ============================================================================\n", + "# CONFIGURATION - CHANGE THIS FOR EACH RUN\n", + "# ============================================================================\n", + "\n", + "DATASET_PATH = \"../nl_v2.jsonl\" # OR \"data/raw/code_dataset.jsonl\"\n", + "DOMAIN_NAME = \"Natural Language\" # OR \"Code\"\n", + "DOMAIN_TYPE = \"nl\" # OR \"code\"\n", + "\n", + "print(f\"Analyzing: {DOMAIN_NAME}\")\n", + "print(f\"Dataset: {DATASET_PATH}\")\n", + "print(\"=\" * 80 + \"\\n\")\n", + "\n", + "# ============================================================================\n", + "# SYMBOL DEFINITIONS\n", + "# ============================================================================\n", + "\n", + "# Your compression symbols\n", + "SYMBOLS = {\n", + " \"→\": \"Arrow (implies/leads to)\",\n", + " \"|\": \"Pipe (separator/or)\",\n", + " \"@\": \"At (location/decorator)\",\n", + " \"∵\": \"Because (causation)\",\n", + " \":\": \"Colon (assignment/type)\",\n", + "}\n", + "\n", + "SYMBOL_CHARS = list(SYMBOLS.keys())\n", + "\n", + "print(\"Tracking symbols:\")\n", + "for sym, desc in SYMBOLS.items():\n", + " print(f\" {sym} - {desc}\")\n", + "print()\n", + "\n", + "# ============================================================================\n", + "# 1. 
LOAD DATA\n", + "# ============================================================================\n", + "\n", + "\n", + "def load_dataset(filepath: str) -> list[dict]:\n", + " \"\"\"Load dataset with verbose/compressed pairs\"\"\"\n", + " data = []\n", + " with open(filepath, encoding=\"utf-8\") as f:\n", + " for line in f:\n", + " if line.strip():\n", + " data.append(json.loads(line))\n", + " return data\n", + "\n", + "\n", + "data = load_dataset(DATASET_PATH)\n", + "print(f\"✓ Loaded {len(data)} samples\\n\")\n", + "\n", + "# ============================================================================\n", + "# 2. SYMBOL ANALYSIS FUNCTIONS\n", + "# ============================================================================\n", + "\n", + "\n", + "def count_symbols(text: str) -> dict:\n", + " \"\"\"Count occurrences of each symbol in text\"\"\"\n", + " counts = {sym: text.count(sym) for sym in SYMBOL_CHARS}\n", + " return counts\n", + "\n", + "\n", + "def analyze_sample(sample: dict) -> dict:\n", + " \"\"\"Analyze symbol usage in a single sample\"\"\"\n", + " compressed = sample[\"compressed\"]\n", + "\n", + " symbol_counts = count_symbols(compressed)\n", + "\n", + " return {\n", + " \"has_arrow\": symbol_counts[\"→\"] > 0,\n", + " \"has_pipe\": symbol_counts[\"|\"] > 0,\n", + " \"has_at\": symbol_counts[\"@\"] > 0,\n", + " \"has_because\": symbol_counts[\"∵\"] > 0,\n", + " \"has_colon\": symbol_counts[\":\"] > 0,\n", + " \"count_arrow\": symbol_counts[\"→\"],\n", + " \"count_pipe\": symbol_counts[\"|\"],\n", + " \"count_at\": symbol_counts[\"@\"],\n", + " \"count_because\": symbol_counts[\"∵\"],\n", + " \"count_colon\": symbol_counts[\":\"],\n", + " \"total_symbols\": sum(symbol_counts.values()),\n", + " \"unique_symbols\": sum(1 for v in symbol_counts.values() if v > 0),\n", + " \"compressed_length\": len(compressed),\n", + " }\n", + "\n", + "\n", + "# Analyze all samples\n", + "results = [analyze_sample(sample) for sample in data]\n", + "df = 
pd.DataFrame(results)\n", + "\n", + "print(f\"✓ Analyzed symbol usage in {len(df)} samples\\n\")\n", + "\n", + "# ============================================================================\n", + "# 3. VISUALIZATION: SYMBOL FREQUENCY\n", + "# ============================================================================\n", + "\n", + "fig = plt.figure(figsize=(16, 10))\n", + "gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)\n", + "\n", + "# Calculate presence rates (% of samples containing each symbol)\n", + "presence_rates = {\n", + " \"→\": df[\"has_arrow\"].sum() / len(df) * 100,\n", + " \"|\": df[\"has_pipe\"].sum() / len(df) * 100,\n", + " \"@\": df[\"has_at\"].sum() / len(df) * 100,\n", + " \"∵\": df[\"has_because\"].sum() / len(df) * 100,\n", + " \":\": df[\"has_colon\"].sum() / len(df) * 100,\n", + "}\n", + "\n", + "# Average counts per sample (including samples without the symbol)\n", + "avg_counts = {\n", + " \"→\": df[\"count_arrow\"].mean(),\n", + " \"|\": df[\"count_pipe\"].mean(),\n", + " \"@\": df[\"count_at\"].mean(),\n", + " \"∵\": df[\"count_because\"].mean(),\n", + " \":\": df[\"count_colon\"].mean(),\n", + "}\n", + "\n", + "# Average counts per sample that HAS the symbol (conditional mean)\n", + "avg_counts_when_present = {\n", + " \"→\": df[df[\"has_arrow\"]][\"count_arrow\"].mean() if df[\"has_arrow\"].sum() > 0 else 0,\n", + " \"|\": df[df[\"has_pipe\"]][\"count_pipe\"].mean() if df[\"has_pipe\"].sum() > 0 else 0,\n", + " \"@\": df[df[\"has_at\"]][\"count_at\"].mean() if df[\"has_at\"].sum() > 0 else 0,\n", + " \"∵\": df[df[\"has_because\"]][\"count_because\"].mean() if df[\"has_because\"].sum() > 0 else 0,\n", + " \":\": df[df[\"has_colon\"]][\"count_colon\"].mean() if df[\"has_colon\"].sum() > 0 else 0,\n", + "}\n", + "\n", + "# Plot 1: Symbol Presence Rate (% of samples)\n", + "ax1 = fig.add_subplot(gs[0, 0])\n", + "symbols = list(presence_rates.keys())\n", + "rates = list(presence_rates.values())\n", + "colors = 
plt.cm.viridis(np.linspace(0, 1, len(symbols)))\n", + "\n", + "bars = ax1.barh(symbols, rates, color=colors, alpha=0.7, edgecolor=\"black\")\n", + "ax1.set_xlabel(\"% of Samples Containing Symbol\")\n", + "ax1.set_title(f\"{DOMAIN_NAME}: Symbol Presence Rate\")\n", + "ax1.set_xlim(0, 100)\n", + "\n", + "# Add value labels on bars\n", + "for i, (_bar, rate) in enumerate(zip(bars, rates, strict=True)):\n", + " ax1.text(rate + 2, i, f\"{rate:.1f}%\", va=\"center\", fontsize=10)\n", + "\n", + "\n", + "ax1.grid(True, alpha=0.3, axis=\"x\")\n", + "\n", + "# Plot 2: Average Count Per Sample\n", + "ax2 = fig.add_subplot(gs[0, 1])\n", + "counts = list(avg_counts.values())\n", + "\n", + "bars = ax2.barh(symbols, counts, color=colors, alpha=0.7, edgecolor=\"black\")\n", + "ax2.set_xlabel(\"Average Count Per Sample (All Samples)\")\n", + "ax2.set_title(f\"{DOMAIN_NAME}: Average Symbol Usage\")\n", + "\n", + "# Add value labels\n", + "for i, (_bar, count) in enumerate(zip(bars, counts, strict=True)):\n", + " ax2.text(count + 0.1, i, f\"{count:.2f}\", va=\"center\", fontsize=10)\n", + "\n", + "ax2.grid(True, alpha=0.3, axis=\"x\")\n", + "\n", + "# Plot 3: Average Count When Present\n", + "ax3 = fig.add_subplot(gs[1, 0])\n", + "counts_present = list(avg_counts_when_present.values())\n", + "\n", + "bars = ax3.barh(symbols, counts_present, color=colors, alpha=0.7, edgecolor=\"black\")\n", + "ax3.set_xlabel(\"Average Count (When Symbol Is Present)\")\n", + "ax3.set_title(f\"{DOMAIN_NAME}: Intensity When Used\")\n", + "\n", + "# Add value labels\n", + "for i, (_bar, count) in enumerate(zip(bars, counts_present, strict=True)):\n", + " ax3.text(count + 0.1, i, f\"{count:.2f}\", va=\"center\", fontsize=10)\n", + "\n", + "ax3.grid(True, alpha=0.3, axis=\"x\")\n", + "\n", + "# Plot 4: Distribution of unique symbols per sample\n", + "ax4 = fig.add_subplot(gs[1, 1])\n", + "unique_counts = df[\"unique_symbols\"].value_counts().sort_index()\n", + "\n", + "ax4.bar(unique_counts.index, 
unique_counts.values, color=\"steelblue\", alpha=0.7, edgecolor=\"black\")\n", + "ax4.set_xlabel(\"Number of Unique Symbols in Sample\")\n", + "ax4.set_ylabel(\"Number of Samples\")\n", + "ax4.set_title(f\"{DOMAIN_NAME}: Symbol Diversity Per Sample\")\n", + "ax4.grid(True, alpha=0.3, axis=\"y\")\n", + "\n", + "# Add mean line\n", + "mean_unique = df[\"unique_symbols\"].mean()\n", + "ax4.axvline(mean_unique, color=\"red\", linestyle=\"--\", linewidth=2, label=f\"Mean: {mean_unique:.1f}\")\n", + "ax4.legend()\n", + "\n", + "# Plot 5: Symbol co-occurrence heatmap\n", + "ax5 = fig.add_subplot(gs[2, :])\n", + "\n", + "# Build co-occurrence matrix\n", + "cooccurrence = np.zeros((5, 5))\n", + "symbol_names = [\"→\", \"|\", \"@\", \"∵\", \":\"]\n", + "\n", + "for _, row in df.iterrows():\n", + " present = [\n", + " row[\"has_arrow\"],\n", + " row[\"has_pipe\"],\n", + " row[\"has_at\"],\n", + " row[\"has_because\"],\n", + " row[\"has_colon\"],\n", + " ]\n", + "\n", + " for i in range(5):\n", + " for j in range(5):\n", + " if present[i] and present[j]:\n", + " cooccurrence[i, j] += 1\n", + "\n", + "# Normalize by total samples\n", + "cooccurrence = cooccurrence / len(df) * 100\n", + "\n", + "# Create heatmap\n", + "sns.heatmap(\n", + " cooccurrence,\n", + " annot=True,\n", + " fmt=\".1f\",\n", + " cmap=\"YlOrRd\",\n", + " xticklabels=symbol_names,\n", + " yticklabels=symbol_names,\n", + " cbar_kws={\"label\": \"% of Samples\"},\n", + " ax=ax5,\n", + " linewidths=0.5,\n", + " linecolor=\"gray\",\n", + ")\n", + "ax5.set_title(f\"{DOMAIN_NAME}: Symbol Co-occurrence Matrix\")\n", + "ax5.set_xlabel(\"Symbol\")\n", + "ax5.set_ylabel(\"Symbol\")\n", + "\n", + "plt.suptitle(f\"{DOMAIN_NAME}: Symbol Usage Analysis\", fontsize=16, fontweight=\"bold\", y=0.995)\n", + "plt.savefig(f\"analysis_02_{DOMAIN_TYPE}_symbol_usage.png\", dpi=300, bbox_inches=\"tight\")\n", + "plt.show()\n", + "\n", + "# ============================================================================\n", + "# 4. 
DETAILED STATISTICS\n", + "# ============================================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(f\"{DOMAIN_NAME.upper()} SYMBOL USAGE STATISTICS\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(f\"\\nDataset Size: {len(df)} samples\")\n", + "\n", + "print(\"\\nSYMBOL PRESENCE RATES:\")\n", + "print(f\"{'Symbol':<10} {'Description':<25} {'Samples':<10} {'Rate':<10}\")\n", + "print(\"-\" * 60)\n", + "for sym in SYMBOL_CHARS:\n", + " desc = SYMBOLS[sym].split(\"(\")[0].strip()\n", + " count = df[f\"has_{['arrow', 'pipe', 'at', 'because', 'colon'][SYMBOL_CHARS.index(sym)]}\"].sum()\n", + " rate = count / len(df) * 100\n", + " print(f\"{sym:<10} {desc:<25} {count:<10} {rate:>6.1f}%\")\n", + "\n", + "print(\"\\nAVERAGE COUNTS (ALL SAMPLES):\")\n", + "print(f\"{'Symbol':<10} {'Avg Count':<15} {'Total Uses':<15}\")\n", + "print(\"-\" * 45)\n", + "for sym in SYMBOL_CHARS:\n", + " col = f\"count_{['arrow', 'pipe', 'at', 'because', 'colon'][SYMBOL_CHARS.index(sym)]}\"\n", + " avg = df[col].mean()\n", + " total = df[col].sum()\n", + " print(f\"{sym:<10} {avg:<15.2f} {total:<15.0f}\")\n", + "\n", + "print(\"\\nAVERAGE COUNTS (WHEN PRESENT):\")\n", + "print(f\"{'Symbol':<10} {'Avg When Used':<15}\")\n", + "print(\"-\" * 30)\n", + "for sym, avg in avg_counts_when_present.items():\n", + " print(f\"{sym:<10} {avg:<15.2f}\")\n", + "\n", + "print(\"\\nSYMBOL DIVERSITY:\")\n", + "print(f\" Average unique symbols per sample: {df['unique_symbols'].mean():.2f}\")\n", + "print(f\" Median unique symbols per sample: {df['unique_symbols'].median():.0f}\")\n", + "print(f\" Max unique symbols in one sample: {df['unique_symbols'].max():.0f}\")\n", + "print(\n", + " f\" Samples using all 5 symbols: {(df['unique_symbols'] == 5).sum()} ({(df['unique_symbols'] == 5).sum() / len(df) * 100:.1f}%)\"\n", + ")\n", + "print(\n", + " f\" Samples using 0 symbols: {(df['unique_symbols'] == 0).sum()} ({(df['unique_symbols'] == 0).sum() / 
len(df) * 100:.1f}%)\"\n", + ")\n", + "\n", + "# ============================================================================\n", + "# 5. INSIGHTS & VALIDATION\n", + "# ============================================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"INSIGHTS & VALIDATION\")\n", + "print(\"=\" * 80)\n", + "\n", + "# Check for expected patterns based on domain\n", + "if DOMAIN_TYPE == \"nl\":\n", + " print(\"\\n📋 EXPECTED PATTERNS FOR NATURAL LANGUAGE:\")\n", + " print(\" • @ symbol: Moderate usage (15-30%) for location contexts\")\n", + " print(\" • ∵ symbol: Moderate-High usage (20-40%) for causation\")\n", + " print(\" • → symbol: Moderate usage (20-40%) for implications\")\n", + " print(\" • | symbol: Low-Moderate usage for separators\")\n", + " print(\" • : symbol: High usage (60-90%) for definitions/assignments\")\n", + "\n", + " print(\"\\n📊 ACTUAL RESULTS:\")\n", + " at_rate = presence_rates[\"@\"]\n", + " because_rate = presence_rates[\"∵\"]\n", + " arrow_rate = presence_rates[\"→\"]\n", + " colon_rate = presence_rates[\":\"]\n", + "\n", + " # Validate expectations\n", + " if at_rate < 10:\n", + " print(f\" ⚠️ @ symbol underutilized ({at_rate:.1f}%) - expected 15-30%\")\n", + " print(\" → ACTION: Review if location contexts are being compressed\")\n", + " elif at_rate > 30:\n", + " print(f\" ⚠️ @ symbol overutilized ({at_rate:.1f}%) - may be used incorrectly\")\n", + " else:\n", + " print(f\" ✅ @ symbol usage normal ({at_rate:.1f}%)\")\n", + "\n", + " if because_rate < 15:\n", + " print(f\" ⚠️ ∵ symbol underutilized ({because_rate:.1f}%) - expected 20-40%\")\n", + " print(\" → ACTION: Review if causation contexts are being compressed\")\n", + " else:\n", + " print(f\" ✅ ∵ symbol usage appropriate ({because_rate:.1f}%)\")\n", + "\n", + " if colon_rate < 50:\n", + " print(f\" ⚠️ : symbol underutilized ({colon_rate:.1f}%) - expected 60-90%\")\n", + " else:\n", + " print(f\" ✅ : symbol usage strong 
({colon_rate:.1f}%)\")\n", + "\n", + "else: # code\n", + " print(\"\\n📋 EXPECTED PATTERNS FOR CODE:\")\n", + " print(\" • @ symbol: High usage (70-95%) for decorators\")\n", + " print(\" • : symbol: Very high usage (85-100%) for type hints/assignments\")\n", + " print(\" • → symbol: Moderate-High usage (40-70%) for return indicators\")\n", + " print(\" • | symbol: High usage (60-90%) for union types, pipes\")\n", + " print(\" • ∵ symbol: Low usage (<10%) - rarely used in code\")\n", + "\n", + " print(\"\\n📊 ACTUAL RESULTS:\")\n", + " at_rate = presence_rates[\"@\"]\n", + " colon_rate = presence_rates[\":\"]\n", + " arrow_rate = presence_rates[\"→\"]\n", + " pipe_rate = presence_rates[\"|\"]\n", + " because_rate = presence_rates[\"∵\"]\n", + "\n", + " if at_rate < 60:\n", + " print(f\" ⚠️ @ symbol underutilized ({at_rate:.1f}%) - expected 70-95%\")\n", + " print(\" → ACTION: Decorators may not be preserved properly\")\n", + " else:\n", + " print(f\" ✅ @ symbol usage strong ({at_rate:.1f}%)\")\n", + "\n", + " if colon_rate < 80:\n", + " print(f\" ⚠️ : symbol underutilized ({colon_rate:.1f}%) - expected 85-100%\")\n", + " else:\n", + " print(f\" ✅ : symbol usage excellent ({colon_rate:.1f}%)\")\n", + "\n", + " if because_rate > 10:\n", + " print(f\" ⚠️ ∵ symbol unexpectedly high ({because_rate:.1f}%) - should be <10% in code\")\n", + " print(\" → ACTION: May be misused, check samples\")\n", + " else:\n", + " print(f\" ✅ ∵ symbol usage appropriate ({because_rate:.1f}%)\")\n", + "\n", + "# Symbol diversity check\n", + "zero_symbols = (df[\"unique_symbols\"] == 0).sum()\n", + "if zero_symbols > 0:\n", + " print(\"\\n⚠️ SYMBOL COVERAGE ISSUE:\")\n", + " print(f\" {zero_symbols} samples ({zero_symbols / len(df) * 100:.1f}%) use NO symbols\")\n", + " print(\" → ACTION: These samples may be poorly compressed\")\n", + " print(\" → Investigate: Are they very short? 
Failed compression?\")\n", + "\n", + "# High symbol usage\n", + "high_symbol_samples = df[df[\"total_symbols\"] > df[\"total_symbols\"].quantile(0.95)]\n", + "if len(high_symbol_samples) > 0:\n", + " print(\"\\n📊 HIGH SYMBOL USAGE:\")\n", + " print(f\" Top 5% of samples use {high_symbol_samples['total_symbols'].min():.0f}+ symbols\")\n", + " print(f\" Max symbols in one sample: {df['total_symbols'].max():.0f}\")\n", + "\n", + "# ============================================================================\n", + "# 6. CO-OCCURRENCE INSIGHTS\n", + "# ============================================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"SYMBOL CO-OCCURRENCE INSIGHTS\")\n", + "print(\"=\" * 80)\n", + "\n", + "print(\"\\nMost common symbol pairs (co-occur in same sample):\")\n", + "cooccur_pairs = []\n", + "for i in range(len(symbol_names)):\n", + " for j in range(i + 1, len(symbol_names)):\n", + " if cooccurrence[i, j] > 0:\n", + " cooccur_pairs.append((symbol_names[i], symbol_names[j], cooccurrence[i, j]))\n", + "\n", + "cooccur_pairs.sort(key=lambda x: x[2], reverse=True)\n", + "for sym1, sym2, rate in cooccur_pairs[:5]:\n", + " print(f\" {sym1} + {sym2}: {rate:.1f}% of samples\")\n", + "\n", + "# ============================================================================\n", + "# 7. NEXT STEPS\n", + "# ============================================================================\n", + "\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"NEXT STEPS\")\n", + "print(\"=\" * 80)\n", + "\n", + "if DOMAIN_TYPE == \"nl\":\n", + " print(\"1. ✅ NL symbol analysis complete\")\n", + " print(\"2. ⏭️ Change to 'code_dataset.jsonl' and re-run\")\n", + " print(\"3. ⏭️ Compare NL vs Code symbol usage patterns\")\n", + " print(\"4. ⏭️ Move to Analysis #3: Token Length Distribution\")\n", + "else:\n", + " print(\"1. ✅ Code symbol analysis complete\")\n", + " print(\"2. 
📊 Compare with NL results:\")\n", + " print(\" • Is @ high in code (decorators) and low in NL?\")\n", + " print(\" • Is ∵ high in NL (causation) and low in code?\")\n", + " print(\" • Are patterns as expected?\")\n", + " print(\"3. ⏭️ Move to Analysis #3: Token Length Distribution\")\n", + "\n", + "print(\"=\" * 80)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "bd0fbff4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Split complete\n", + "Input file: ../nl_v2.jsonl\n", + "Output dir: ../splits\n", + "Total samples: 8256\n", + "GOOD (>= 20% reduction): 5326\n", + "BAD (< 20% or elongated): 2930\n", + "Skipped (invalid): 0\n" + ] + } + ], + "source": [ + "from typing import Any\n", + "\n", + "# ===== PARAMETERS =====\n", + "INPUT_PATH = Path(\"../nl_v2.jsonl\")\n", + "OUT_DIR = Path(\"../splits\")\n", + "\n", + "THRESHOLD = 0.20 # 20% reduction\n", + "MIN_LEN = 1 # skip empty samples\n", + "\n", + "\n", + "def safe_strip_len(x: Any) -> int:\n", + " if not isinstance(x, str):\n", + " return 0\n", + " return len(x.strip())\n", + "\n", + "\n", + "def split_by_compression(\n", + " input_path: Path,\n", + " out_dir: Path,\n", + " threshold: float,\n", + " min_len: int,\n", + ") -> None:\n", + " if not input_path.exists():\n", + " raise FileNotFoundError(f\"Input file not found: {input_path}\")\n", + "\n", + " out_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + " bad_path = out_dir / \"bad_or_undercompressed.jsonl\"\n", + " good_path = out_dir / \"good_20pct_or_more.jsonl\"\n", + " skipped_path = out_dir / \"skipped_invalid.jsonl\"\n", + "\n", + " n_total = n_bad = n_good = n_skipped = 0\n", + "\n", + " with (\n", + " input_path.open(\"r\", encoding=\"utf-8\") as f_in,\n", + " bad_path.open(\"w\", encoding=\"utf-8\") as f_bad,\n", + " good_path.open(\"w\", encoding=\"utf-8\") as f_good,\n", + " skipped_path.open(\"w\", encoding=\"utf-8\") as f_skip,\n", + " ):\n", + " for line_no, line in 
enumerate(f_in, start=1):\n", + " line = line.strip()\n", + " if not line:\n", + " continue\n", + "\n", + " n_total += 1\n", + "\n", + " try:\n", + " obj: dict[str, Any] = json.loads(line)\n", + " except json.JSONDecodeError:\n", + " f_skip.write(\n", + " json.dumps({\"line_no\": line_no, \"reason\": \"json_decode_error\"}, ensure_ascii=False)\n", + " + \"\\n\"\n", + " )\n", + " n_skipped += 1\n", + " continue\n", + "\n", + " verbose = obj.get(\"verbose\", \"\")\n", + " compressed = obj.get(\"compressed\", \"\")\n", + "\n", + " if not isinstance(verbose, str) or not isinstance(compressed, str):\n", + " f_skip.write(\n", + " json.dumps(\n", + " {\"line_no\": line_no, \"reason\": \"non_string_fields\", \"keys\": list(obj.keys())},\n", + " ensure_ascii=False,\n", + " )\n", + " + \"\\n\"\n", + " )\n", + " n_skipped += 1\n", + " continue\n", + "\n", + " v_len = safe_strip_len(verbose)\n", + " c_len = safe_strip_len(compressed)\n", + "\n", + " if v_len < min_len or c_len < min_len:\n", + " f_skip.write(\n", + " json.dumps(\n", + " {\"line_no\": line_no, \"reason\": \"too_short\", \"v_len\": v_len, \"c_len\": c_len},\n", + " ensure_ascii=False,\n", + " )\n", + " + \"\\n\"\n", + " )\n", + " n_skipped += 1\n", + " continue\n", + "\n", + " ratio = c_len / v_len\n", + " reduction = 1.0 - ratio\n", + "\n", + " obj[\"_metrics\"] = {\n", + " \"char_verbose\": v_len,\n", + " \"char_compressed\": c_len,\n", + " \"ratio_c_over_v\": round(ratio, 4),\n", + " \"reduction\": round(reduction, 4),\n", + " \"threshold\": threshold,\n", + " }\n", + "\n", + " if ratio > 1.0 or reduction < threshold:\n", + " f_bad.write(json.dumps(obj, ensure_ascii=False) + \"\\n\")\n", + " n_bad += 1\n", + " else:\n", + " f_good.write(json.dumps(obj, ensure_ascii=False) + \"\\n\")\n", + " n_good += 1\n", + "\n", + " print(\"✅ Split complete\")\n", + " print(f\"Input file: {input_path}\")\n", + " print(f\"Output dir: {out_dir}\")\n", + " print(f\"Total samples: {n_total}\")\n", + " print(f\"GOOD (>= 
{threshold:.0%} reduction): {n_good}\")\n", + " print(f\"BAD (< {threshold:.0%} or elongated): {n_bad}\")\n", + " print(f\"Skipped (invalid): {n_skipped}\")\n", + "\n", + "\n", + "# Run\n", + "split_by_compression(INPUT_PATH, OUT_DIR, THRESHOLD, MIN_LEN)\n" + ] + }, + { + "cell_type": "markdown", + "id": "f8deb7f1", + "metadata": {}, + "source": [ + "# Summary\n", + "Verbose samples are sometimes expanded rather than compressed, and the distribution suggests that a large share of our samples (2,930 of 8,256, about 35%, in the split above) are either expanded or compressed by less than 20%. Hence, there is a very slight chance that the model has learnt to use symbols instead." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "965144b1", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}