6 changes: 4 additions & 2 deletions pyproject.toml
@@ -22,11 +22,13 @@ dependencies = [
"datasets>=3.0.0",
"tinker",
"matplotlib>=3.8.0",
"dagshub",
"mlflow",
]

[project.optional-dependencies]
mlflow = [
"dagshub",
"mlflow",
]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
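With dagshub and mlflow moved to the optional [mlflow] extra, code that imports them should degrade gracefully when the extra is not installed. A minimal sketch of the guard pattern (the MLFLOW_AVAILABLE flag is illustrative, not part of this PR):

try:
    import mlflow  # provided by the optional [mlflow] extra
    MLFLOW_AVAILABLE = True
except ImportError:  # extra not installed
    mlflow = None
    MLFLOW_AVAILABLE = False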
148 changes: 73 additions & 75 deletions scripts/data_sanitization.py
@@ -356,26 +356,9 @@ def rule_b_orphaned_symbols(compressed: str, is_code: bool) -> tuple[bool, str]:
return True, ""


"""
RULE C FIX: Drop-in replacement for rule_c_negation_preservation

Replace the existing rule_c_negation_preservation function in your
sanitization script with these functions.

FIXES:
1. Unicode apostrophes (can't vs can’t) - CRITICAL
2. Flexible whitespace in multi-word phrases
3. Reduced false positives from symbols (~, !)
4. Proper contraction handling as suffix, not standalone word
"""

import re
from typing import Tuple


def _normalize_apostrophes(text: str) -> str:
"""Normalize Unicode apostrophes to ASCII."""
unicode_apostrophes = ['\u2019', '\u2018', '\u02BC', '\u0060']
unicode_apostrophes = ["\u2019", "\u2018", "\u02bc", "\u0060"]
normalized = text
for unicode_apos in unicode_apostrophes:
normalized = normalized.replace(unicode_apos, "'")
@@ -384,8 +367,8 @@ def _normalize_apostrophes(text: str) -> str:

def _normalize_whitespace(text: str) -> str:
"""Normalize whitespace for consistent matching."""
normalized = text.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ')
normalized = re.sub(r'\s+', ' ', normalized)
normalized = text.replace("\n", " ").replace("\t", " ").replace("\r", " ")
normalized = re.sub(r"\s+", " ", normalized)
return normalized.strip()
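# Illustrative behavior of the two normalizers above (assumed, not asserted
# anywhere in this diff):
# _normalize_apostrophes("can\u2019t")     -> "can't"
# _normalize_whitespace("no \n\t longer")  -> "no longer"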


@@ -398,79 +381,86 @@ def _has_contractions(text: str) -> bool:
def _has_standalone_negation(text: str) -> bool:
"""Detect standalone negation words."""
normalized = _normalize_whitespace(text.lower())
words = ["not", "no", "never", "neither", "nor",
"without", "none", "nothing", "nobody", "nowhere"]

for word in words:
if re.search(r'\b' + re.escape(word) + r'\b', normalized):
return True
return False
words = [
"not",
"no",
"never",
"neither",
"nor",
"without",
"none",
"nothing",
"nobody",
"nowhere",
]

return any(re.search(r"\b" + re.escape(word) + r"\b", normalized) for word in words)
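# Example (assumed): the \b word boundaries stop substrings from matching.
# _has_standalone_negation("do not enter")  -> True
# _has_standalone_negation("notable notes") -> False ("not"/"no" appear only inside words)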


def _has_multiword_negation(text: str) -> bool:
"""Detect multi-word negation phrases with flexible whitespace."""
normalized = _normalize_whitespace(text.lower())
patterns = [r'\bno\s+longer\b', r'\bno\s+more\b', r'\bnot\s+anymore\b']
patterns = [r"\bno\s+longer\b", r"\bno\s+more\b", r"\bnot\s+anymore\b"]

return any(re.search(pattern, normalized) for pattern in patterns)


def _has_negation_symbols(text: str, strict: bool = True) -> bool:
"""
Detect negation symbols with reduced false positives.

Args:
strict: If True, only check ¬ (unambiguous)
If False, also check ~ and ! (more permissive)
"""
if '¬' in text:
if "¬" in text:
return True

if not strict:
# Check for ! but avoid != and !!
if re.search(r'![^=!]', text):
if re.search(r"![^=!]", text):
return True
# Check for ~ but avoid ~/ and ~digits
if re.search(r'~(?![/\d])', text):
if re.search(r"~(?![/\d])", text):
return True

return False
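# Examples (assumed) of the reduced false positives:
# _has_negation_symbols("x != y", strict=False)  -> False  (!= is excluded)
# _has_negation_symbols("~/path", strict=False)  -> False  (~/ is excluded)
# _has_negation_symbols("~42", strict=False)     -> False  (~digit is excluded)
# _has_negation_symbols("¬valid")                -> True   (¬ is unambiguous)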


def rule_c_negation_preservation(verbose: str, compressed: str) -> Tuple[bool, str]:
def rule_c_negation_preservation(verbose: str, compressed: str) -> tuple[bool, str]:
"""
Rule C: Remove samples that lost negation (NL only).

FIXED VERSION addressing:
- Unicode apostrophes in contractions (can't vs can’t)
- Flexible whitespace in multi-word phrases
- Reduced false positives from symbols
- Proper contraction detection

Returns:
(passed, reason) - passed=False if negation was lost
"""
# Check if input has any form of negation
has_input_negation = (
_has_contractions(verbose) or
_has_standalone_negation(verbose) or
_has_multiword_negation(verbose)
_has_contractions(verbose)
or _has_standalone_negation(verbose)
or _has_multiword_negation(verbose)
)

if not has_input_negation:
return True, "" # No negation to preserve

# Input has negation - check if output preserved it
has_output_negation = (
_has_contractions(compressed) or
_has_standalone_negation(compressed) or
_has_multiword_negation(compressed) or
_has_negation_symbols(compressed, strict=False) # Allow symbols in compressed
_has_contractions(compressed)
or _has_standalone_negation(compressed)
or _has_multiword_negation(compressed)
or _has_negation_symbols(compressed, strict=False) # Allow symbols in compressed
)

if not has_output_negation:
return False, "Negation lost"

return True, ""
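# End-to-end sketch (assumed behavior): negation in the verbose input must
# survive in some form (word, phrase, contraction, or symbol) or the sample
# is rejected.
# rule_c_negation_preservation("no longer valid", "valid")  -> (False, "Negation lost")
# rule_c_negation_preservation("no longer valid", "¬valid") -> (True, "")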


@@ -514,6 +504,27 @@ def sanitize_and_extract(input_path: Path, sanitized_path: Path, unsanitized_pat

data = []

stats = {
"total_input": 0,
"code_samples": 0,
"nl_samples": 0,
"code_passed": 0,
"nl_passed": 0,
"rule_a_failed": 0,
"rule_b_failed": 0,
"rule_c_failed": 0,
"rule_d_failed": 0,
"parse_errors": 0,
"passed_all": 0,
"failed_all": 0,
"failed_samples": [],
"passed_samples": [],
"parse_error_samples": [],
}

sanitized_data = []
unsanitized_data = []

with open(input_path, encoding="utf-8") as f:
for idx, line in enumerate(f):
if not line.strip():
@@ -540,29 +551,9 @@
print(f"⚠ {error_msg}")
continue


print(f"✓ Loaded {len(data)} samples\n")

stats = {
"total_input": len(data),
"code_samples": 0,
"nl_samples": 0,
"code_passed": 0,
"nl_passed": 0,
"rule_a_failed": 0,
"rule_b_failed": 0,
"rule_c_failed": 0,
"rule_d_failed": 0,
"parse_errors": 0,
"passed_all": 0,
"failed_all": 0,
"failed_samples": [],
"passed_samples": [],
"parse_error_samples": [],
}

sanitized_data = []
unsanitized_data = []
stats["total_input"] = len(data)

print("Processing samples...\n")

@@ -695,12 +686,20 @@ def pct(n: int, d: int) -> float:
print(f" NL samples: {nl:5d} ({pct(nl, total):5.1f}%)")
print()

print(f"✓ SANITIZED (passed): {stats['passed_all']:5d} ({pct(stats['passed_all'], total):5.1f}%)")
print(f" Code: {stats['code_passed']:5d} ({pct(stats['code_passed'], code):5.1f}%)")
print(f" NL: {stats['nl_passed']:5d} ({pct(stats['nl_passed'], nl):5.1f}%)")
print(
f"✓ SANITIZED (passed): {stats['passed_all']:5d} ({pct(stats['passed_all'], total):5.1f}%)"
)
print(
f" Code: {stats['code_passed']:5d} ({pct(stats['code_passed'], code):5.1f}%)"
)
print(
f" NL: {stats['nl_passed']:5d} ({pct(stats['nl_passed'], nl):5.1f}%)"
)
print()

print(f"✗ UNSANITIZED (failed): {stats['failed_all']:5d} ({pct(stats['failed_all'], total):5.1f}%)")
print(
f"✗ UNSANITIZED (failed): {stats['failed_all']:5d} ({pct(stats['failed_all'], total):5.1f}%)"
)
print()

print("Failed by rule:")
Expand Down Expand Up @@ -733,7 +732,6 @@ def pct(n: int, d: int) -> float:
print()



# ============================================================================
# CLI
# ============================================================================
4 changes: 4 additions & 0 deletions tests/test_mlflow_logger.py
@@ -20,6 +20,10 @@

import pytest

# Skip entire module if dagshub/mlflow not installed (optional dependencies)
pytest.importorskip("dagshub", reason="dagshub not installed (optional [mlflow] extra)")
pytest.importorskip("mlflow", reason="mlflow not installed (optional [mlflow] extra)")


class MlflowRecorder:
"""Mock MLflow client that records all operations."""
26 changes: 26 additions & 0 deletions tests/test_sanitization_manager.py
@@ -267,6 +267,32 @@ def test_sanitize_dataset_splits_correctly(self, temp_dir):
assert sanitized_file.exists()
assert unsanitized_file.exists()

def test_malformed_json_line_does_not_crash(self, temp_dir):
"""Test that malformed JSON lines are tracked as parse errors, not NameError."""

input_file = temp_dir / "train.jsonl"
sanitized_file = temp_dir / "sanitized.jsonl"
unsanitized_file = temp_dir / "unsanitized.jsonl"

good_sample = {
"messages": [
{"role": "user", "content": "Compress: one two three four"},
{"role": "assistant", "content": "1 2"},
]
}

with open(input_file, "w", encoding="utf-8") as f:
f.write(json.dumps(good_sample, ensure_ascii=False) + "\n")
f.write("{this is not valid json}\n") # malformed line
f.write(json.dumps(good_sample, ensure_ascii=False) + "\n")

# Should NOT raise NameError
stats = sanitize_and_extract(input_file, sanitized_file, unsanitized_file)

assert stats["parse_errors"] == 1
assert len(stats["parse_error_samples"]) == 1
assert stats["total_input"] == 2 # only 2 valid samples loaded

def test_unicode_symbols_preserved(self, temp_dir):
"""Test that → ∵ @ symbols are NOT escaped in sanitized output."""
