diff --git a/configs/training.yaml b/configs/training.yaml index 3dbf801..ff4aa92 100644 --- a/configs/training.yaml +++ b/configs/training.yaml @@ -50,9 +50,9 @@ cloud: # - Qwen/Qwen3-30B-A3B (MoE, good quality/cost ratio) lora: - rank: 64 # Higher rank for better quality - alpha: 128 # Usually 2x rank - dropout: 0.0 + rank: 16 # Down from 64; r=16 is sweet spot for 8B models + alpha: 32 # 2x rank convention (was 128) + dropout: 0.05 # Mild regularization (was 0.0) target_modules: - q_proj - k_proj @@ -63,18 +63,20 @@ cloud: - down_proj training: - epochs: 3 - # Tinker training duration is controlled by epochs in scripts/train_tinker.py + epochs: 2 # Down from 3; expect early stopping at ~1-1.5 epochs batch_size: 4 learning_rate: 2.0e-4 warmup_ratio: 0.03 max_seq_length: 2048 log_interval_steps: 10 checkpoint_interval_steps: 250 - eval_interval_steps: 0 # 0 disables mid-epoch evals + eval_interval_steps: 250 # Was 0 (disabled!); now eval every 250 steps eval_at_epoch_end: true checkpoint_ttl_seconds: null resume_from_checkpoint: true + # Early stopping + early_stopping_patience: 5 # NEW: stop after 5 evals with no improvement + early_stopping_threshold: 0.01 # NEW: min improvement to count as "better" (absolute) # === OUTPUT PATHS === output: @@ -83,7 +85,7 @@ output: fused_model: "models/fused" # === COST ESTIMATES (Tinker) === -# Based on ~2,200 training examples with 80/10/10 split +# Based on ~19,845 training examples # Average ~500 tokens per example (system + user + assistant) # # Model pricing (per 1M tokens): @@ -91,8 +93,9 @@ output: # - Qwen3-4B: $0.20 # - Qwen3-30B-A3B: $0.45 # -# Estimates for 1,759 training examples × 3 epochs: -# - Qwen3-8B: ~$1.05 -# - With validation: ~$1.50-2.00 total +# With epochs=2 and early stopping (expect ~1.5 effective epochs): +# - Full 2 epochs: 19,845 × 500 × 2 × $0.40/1M = ~$7.94 +# - With early stop (~1.5 epochs): ~$5.95 +# - Add ~10% for validation passes: ~$6.50-8.75 total # -# Budget for 5 training runs: ~$10-15 +# Budget for 3 training runs: ~$20-25 diff --git a/scripts/visualize_tinker_training.py b/scripts/visualize_tinker_training.py new file mode 100644 index 0000000..b572fc6 --- /dev/null +++ b/scripts/visualize_tinker_training.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +"""Tinker Cloud Training Log Parser and Visualizer. + +Parses metrics.jsonl from Tinker training runs and generates +training curve visualizations including loss curves, throughput, +and key training events (epoch boundaries, best val loss, early stopping). +""" + +import argparse +import json +import sys +from dataclasses import dataclass, field +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +@dataclass +class TinkerTrainingMetrics: + """Container for parsed Tinker training metrics.""" + + # Training metrics + train_steps: list[int] = field(default_factory=list) + train_losses: list[float] = field(default_factory=list) + train_losses_total: list[float] = field(default_factory=list) # raw sum-reduced + tokens_per_sec: list[float] = field(default_factory=list) + train_epochs: list[int] = field(default_factory=list) # epoch for each train step + + # Validation metrics + val_steps: list[int] = field(default_factory=list) + val_losses: list[float] = field(default_factory=list) + val_epochs: list[int] = field(default_factory=list) + + # Events + early_stop_step: int | None = None + best_val_loss: float | None = None + + # Derived + epoch_boundaries: list[int] = field(default_factory=list) # steps where epochs change + + +class TinkerMetricsParser: + """Parse metrics.jsonl from Tinker training runs.""" + + def parse(self, metrics_path: Path) -> TinkerTrainingMetrics: + """Parse a metrics.jsonl file into TinkerTrainingMetrics. + + Args: + metrics_path: Path to the metrics.jsonl file. + + Returns: + TinkerTrainingMetrics with all parsed data. + """ + metrics = TinkerTrainingMetrics() + prev_epoch = None + + with open(metrics_path) as f: + for line in f: + line = line.strip() + if not line: + continue + record = json.loads(line) + record_type = record.get("type") + + if record_type == "train": + step = record["step"] + epoch = record.get("epoch", 1) + metrics.train_steps.append(step) + metrics.train_losses.append(record["train_loss"]) + if "train_loss_total" in record: + metrics.train_losses_total.append(record["train_loss_total"]) + metrics.tokens_per_sec.append(record.get("tokens_per_sec", 0.0)) + metrics.train_epochs.append(epoch) + + # Detect epoch boundary + if prev_epoch is not None and epoch != prev_epoch: + metrics.epoch_boundaries.append(step) + prev_epoch = epoch + + elif record_type == "val": + metrics.val_steps.append(record["step"]) + metrics.val_losses.append(record["val_loss"]) + metrics.val_epochs.append(record.get("epoch", 1)) + + elif record_type == "early_stop": + metrics.early_stop_step = record.get("step") + metrics.best_val_loss = record.get("best_val_loss") + + # Find best val loss if not set by early stopping + if metrics.val_losses and metrics.best_val_loss is None: + best_idx = int(np.argmin(metrics.val_losses)) + metrics.best_val_loss = metrics.val_losses[best_idx] + + return metrics + + +class TinkerVisualizer: + """Generate training curve visualizations from Tinker metrics.""" + + def __init__(self, dpi: int = 150, ema_alpha: float = 0.1): + self.dpi = dpi + self.ema_alpha = ema_alpha + + def _compute_ema(self, values: list[float], alpha: float) -> list[float]: + """Compute exponential moving average. + + Args: + values: Raw values to smooth. + alpha: Smoothing factor (0-1). Higher = more responsive to recent values. + + Returns: + List of EMA-smoothed values, same length as input. + """ + ema = [] + current = values[0] if values else 0.0 + for v in values: + current = alpha * v + (1 - alpha) * current + ema.append(current) + return ema + + def plot(self, metrics: TinkerTrainingMetrics, output_path: Path) -> None: + """Generate and save the training curves plot. + + Creates a two-subplot figure: + - Top: Loss curves (train scatter + EMA, val line, epoch boundaries, best val, early stop) + - Bottom: Throughput (tokens/sec raw + EMA) + + Args: + metrics: Parsed training metrics. + output_path: Path to save the output PNG. + """ + fig, (ax_loss, ax_throughput) = plt.subplots( + 2, + 1, + figsize=(14, 10), + height_ratios=[3, 1], + sharex=True, + ) + fig.suptitle("Tinker Training Curves", fontsize=14, fontweight="bold") + + # === Top: Loss curves === + # Raw train loss (faint scatter) + if metrics.train_steps and metrics.train_losses: + ax_loss.scatter( + metrics.train_steps, + metrics.train_losses, + s=4, + alpha=0.15, + color="tab:blue", + label="_nolegend_", + ) + # EMA smoothed train loss + ema_losses = self._compute_ema(metrics.train_losses, self.ema_alpha) + ax_loss.plot( + metrics.train_steps, + ema_losses, + color="tab:blue", + linewidth=1.5, + label="Train Loss (EMA)", + ) + + # Val loss + if metrics.val_steps and metrics.val_losses: + ax_loss.plot( + metrics.val_steps, + metrics.val_losses, + color="tab:orange", + linewidth=1.5, + marker="o", + markersize=3, + label="Val Loss", + ) + + # Best val loss star + if metrics.best_val_loss is not None: + # Find the closest val loss to best_val_loss (handles rounding differences) + try: + best_idx = metrics.val_losses.index(metrics.best_val_loss) + except ValueError: + # best_val_loss not exact match; find closest + best_idx = int( + np.argmin([abs(v - metrics.best_val_loss) for v in metrics.val_losses]) + ) + ax_loss.plot( + metrics.val_steps[best_idx], + metrics.best_val_loss, + marker="*", + color="gold", + markersize=15, + markeredgecolor="black", + markeredgewidth=0.5, + zorder=5, + label=f"Best Val Loss: {metrics.best_val_loss:.2f}", + ) + + # Epoch boundaries + for boundary_step in metrics.epoch_boundaries: + ax_loss.axvline(x=boundary_step, color="gray", linestyle="--", alpha=0.5, linewidth=1) + ax_loss.text( + boundary_step, + ax_loss.get_ylim()[1] * 0.95, + " Epoch", + fontsize=8, + color="gray", + va="top", + ) + + # Early stopping marker + if metrics.early_stop_step is not None: + ax_loss.axvline( + x=metrics.early_stop_step, + color="red", + linestyle="-.", + alpha=0.7, + linewidth=1.5, + ) + ax_loss.text( + metrics.early_stop_step, + ax_loss.get_ylim()[1] * 0.85, + " Early Stop", + fontsize=8, + color="red", + va="top", + ) + + ax_loss.set_ylabel("Loss") + ax_loss.legend(loc="upper right") + ax_loss.grid(True, alpha=0.3) + ax_loss.set_title("Training & Validation Loss") + + # === Bottom: Throughput === + if metrics.train_steps and metrics.tokens_per_sec: + ax_throughput.plot( + metrics.train_steps, + metrics.tokens_per_sec, + color="tab:green", + linewidth=0.8, + alpha=0.5, + ) + # EMA smoothed + ema_throughput = self._compute_ema(metrics.tokens_per_sec, self.ema_alpha) + ax_throughput.plot( + metrics.train_steps, + ema_throughput, + color="tab:green", + linewidth=1.5, + label="Tokens/sec (EMA)", + ) + ax_throughput.legend(loc="upper right") + + for boundary_step in metrics.epoch_boundaries: + ax_throughput.axvline( + x=boundary_step, color="gray", linestyle="--", alpha=0.5, linewidth=1 + ) + + ax_throughput.set_xlabel("Training Step") + ax_throughput.set_ylabel("Tokens/sec") + ax_throughput.grid(True, alpha=0.3) + ax_throughput.set_title("Throughput") + + plt.tight_layout() + fig.savefig(output_path, dpi=self.dpi, bbox_inches="tight") + plt.close(fig) + print(f"Saved training curves to {output_path}") + + +def main() -> None: + """CLI entry point for Tinker training visualization.""" + parser = argparse.ArgumentParser( + description="Visualize Tinker training metrics from metrics.jsonl" + ) + parser.add_argument( + "--metrics", + type=Path, + default=Path("models/adapters/tinker/metrics.jsonl"), + help="Path to metrics.jsonl (default: models/adapters/tinker/metrics.jsonl)", + ) + parser.add_argument( + "--output", + type=Path, + default=None, + help="Output PNG path (default: /training_curves.png)", + ) + parser.add_argument( + "--dpi", + type=int, + default=150, + help="DPI for output (default: 150)", + ) + parser.add_argument( + "--ema-alpha", + type=float, + default=0.1, + help="EMA smoothing factor (0-1, higher = more responsive) (default: 0.1)", + ) + args = parser.parse_args() + + if not args.metrics.exists(): + print(f"Error: metrics file not found: {args.metrics}", file=sys.stderr) + sys.exit(1) + + output = args.output or args.metrics.parent / "training_curves.png" + + parser_obj = TinkerMetricsParser() + metrics = parser_obj.parse(args.metrics) + + print(f"Parsed {len(metrics.train_steps)} train entries, {len(metrics.val_steps)} val entries") + if metrics.best_val_loss is not None: + print(f"Best val loss: {metrics.best_val_loss:.4f}") + if metrics.epoch_boundaries: + print(f"Epoch boundaries at steps: {metrics.epoch_boundaries}") + if metrics.early_stop_step: + print(f"Early stopping at step: {metrics.early_stop_step}") + + visualizer = TinkerVisualizer(dpi=args.dpi, ema_alpha=args.ema_alpha) + visualizer.plot(metrics, output) + + +if __name__ == "__main__": + main() diff --git a/src/training/train_tinker.py b/src/training/train_tinker.py index c711a84..6d9b806 100644 --- a/src/training/train_tinker.py +++ b/src/training/train_tinker.py @@ -39,9 +39,9 @@ class TinkerLoRAConfig: """LoRA configuration for Tinker training.""" - rank: int = 64 - alpha: int = 128 - dropout: float = 0.0 + rank: int = 16 + alpha: int = 32 + dropout: float = 0.05 target_modules: list[str] = field( default_factory=lambda: [ "q_proj", @@ -73,7 +73,7 @@ class TinkerTrainingConfig: lora: TinkerLoRAConfig = field(default_factory=TinkerLoRAConfig) # Training parameters - epochs: int = 3 + epochs: int = 2 batch_size: int = 4 learning_rate: float = 2e-4 warmup_ratio: float = 0.03 @@ -91,6 +91,10 @@ class TinkerTrainingConfig: checkpoint_ttl_seconds: int | None = None resume_from_checkpoint: bool = True + # Early stopping + early_stopping_patience: int = 0 # 0 = disabled + early_stopping_threshold: float = 0.01 + def to_dict(self) -> dict[str, Any]: """Convert to dictionary for API calls.""" return { @@ -537,10 +541,15 @@ def _iter_training_batches( tokenizer: Any, tinker_module: ModuleType, batch_size: int, -) -> Iterator[tuple[list[Any], int]]: - """Yield SDK-ready training batches with token counts from chat JSONL data.""" +) -> Iterator[tuple[list[Any], int, int]]: + """Yield SDK-ready training batches with token counts from chat JSONL data. + + Returns (batch, total_tokens, completion_tokens) tuples. + completion_tokens = number of tokens with weight=1 (assistant response tokens). + """ batch: list[Any] = [] batch_tokens = 0 + batch_completion_tokens = 0 with open(train_file, encoding="utf-8") as f: for line_number, line in enumerate(f, start=1): @@ -562,14 +571,16 @@ def _iter_training_batches( sdk_datum, token_count = _to_sdk_datum(local_datum, tinker_module) batch.append(sdk_datum) batch_tokens += token_count + batch_completion_tokens += sum(local_datum.loss_fn_inputs["weights"]) if len(batch) >= batch_size: - yield batch, batch_tokens + yield batch, batch_tokens, batch_completion_tokens batch = [] batch_tokens = 0 + batch_completion_tokens = 0 if batch: - yield batch, batch_tokens + yield batch, batch_tokens, batch_completion_tokens def _extract_loss(metrics: dict[str, float]) -> float | None: @@ -592,23 +603,60 @@ def _run_validation( tinker_module: ModuleType, batch_size: int, ) -> tuple[float | None, int]: - """Run validation pass and return mean validation loss and batch count.""" + """Run validation pass and return mean per-token validation loss and batch count.""" if not valid_file.exists() or not hasattr(training_client, "forward"): return None, 0 - losses: list[float] = [] + total_loss = 0.0 + total_completion_tokens = 0 val_batches = 0 - for batch, _ in _iter_training_batches(valid_file, tokenizer, tinker_module, batch_size): + for batch, _, completion_tokens in _iter_training_batches( + valid_file, tokenizer, tinker_module, batch_size + ): forward_result = training_client.forward(batch, "cross_entropy").result() metrics = getattr(forward_result, "metrics", {}) or {} val_loss = _extract_loss(metrics) if val_loss is not None: - losses.append(val_loss) + total_loss += val_loss + total_completion_tokens += completion_tokens val_batches += 1 - if not losses: + if total_completion_tokens == 0: return None, val_batches - return sum(losses) / len(losses), val_batches + return total_loss / total_completion_tokens, val_batches + + +def _check_early_stopping( + val_loss: float | None, + best_val_loss: float | None, + evals_without_improvement: int, + config: TinkerTrainingConfig, + current_step: int, +) -> tuple[float | None, int, bool]: + """Check if early stopping should trigger. + + Returns: + Tuple of (best_val_loss, evals_without_improvement, should_stop). + """ + if val_loss is None or config.early_stopping_patience <= 0: + return best_val_loss, evals_without_improvement, False + + if best_val_loss is None or val_loss < best_val_loss - config.early_stopping_threshold: + return val_loss, 0, False + + evals_without_improvement += 1 + logger.info( + "No val improvement for %d/%d evals (best=%.4f, current=%.4f)", + evals_without_improvement, + config.early_stopping_patience, + best_val_loss, + val_loss, + ) + if evals_without_improvement >= config.early_stopping_patience: + logger.info("Early stopping triggered at step %d", current_step) + return best_val_loss, evals_without_improvement, True + + return best_val_loss, evals_without_improvement, False def _train_with_service_client_sdk( @@ -735,9 +783,23 @@ def _train_with_service_client_sdk( start_epoch = (completed_steps // steps_per_epoch) + 1 skip_batches = completed_steps % steps_per_epoch + + # Early stopping state - restore from checkpoint if resuming + best_val_loss_raw = existing_state.get("best_val_loss") + best_val_loss: float | None = ( + float(best_val_loss_raw) if isinstance(best_val_loss_raw, (int, float)) else None + ) + evals_without_improvement_raw = existing_state.get("evals_without_improvement", 0) + evals_without_improvement = ( + int(evals_without_improvement_raw) + if isinstance(evals_without_improvement_raw, (int, float)) + else 0 + ) + stopped_early = False + for epoch in range(start_epoch, config.epochs + 1): epoch_batch_index = 0 - for batch, batch_tokens in _iter_training_batches( + for batch, batch_tokens, completion_tokens in _iter_training_batches( train_file, tokenizer, tinker, @@ -760,8 +822,13 @@ def _train_with_service_client_sdk( current_step += 1 metrics = getattr(fwdbwd, "metrics", {}) or {} step_loss = _extract_loss(metrics) - if step_loss is not None: - final_loss = step_loss + step_loss_per_token: float | None = None + if step_loss is not None and completion_tokens > 0: + step_loss_per_token = step_loss / completion_tokens + elif step_loss is not None: + step_loss_per_token = step_loss + if step_loss_per_token is not None: + final_loss = step_loss_per_token if current_step % config.log_interval_steps == 0 or current_step == total_steps: logger.info( @@ -770,14 +837,14 @@ def _train_with_service_client_sdk( config.epochs, current_step, total_steps, - f"{step_loss:.4f}" if step_loss is not None else "n/a", + f"{step_loss_per_token:.4f}" if step_loss_per_token is not None else "n/a", ) - if step_loss is not None: + if step_loss_per_token is not None: _append_train_log_line( train_log_path, ( - f"Iter {current_step}: Train loss {step_loss:.4f} " + f"Iter {current_step}: Train loss {step_loss_per_token:.4f} " f"| Tokens/sec {tokens_per_sec:.1f} | Peak mem 0.0 GB" ), ) @@ -788,14 +855,18 @@ def _train_with_service_client_sdk( "type": "train", "step": current_step, "epoch": epoch, - "train_loss": step_loss, + "train_loss": step_loss_per_token, + "train_loss_total": step_loss, + "completion_tokens": completion_tokens, "tokens_per_sec": tokens_per_sec, }, ) state["completed_steps"] = current_step - state["last_train_loss"] = step_loss + state["last_train_loss"] = step_loss_per_token state["updated_at"] = _utc_now_iso() + state["best_val_loss"] = best_val_loss + state["evals_without_improvement"] = evals_without_improvement _write_service_run_state(run_state_path, state) if config.eval_interval_steps > 0 and current_step % config.eval_interval_steps == 0: @@ -823,8 +894,21 @@ def _train_with_service_client_sdk( ) state["last_val_loss"] = val_loss state["updated_at"] = _utc_now_iso() + state["best_val_loss"] = best_val_loss + state["evals_without_improvement"] = evals_without_improvement _write_service_run_state(run_state_path, state) + best_val_loss, evals_without_improvement, should_stop = _check_early_stopping( + val_loss, + best_val_loss, + evals_without_improvement, + config, + current_step, + ) + if should_stop: + stopped_early = True + break + if ( config.checkpoint_interval_steps > 0 and current_step % config.checkpoint_interval_steps == 0 @@ -851,6 +935,9 @@ def _train_with_service_client_sdk( state["updated_at"] = _utc_now_iso() _write_service_run_state(run_state_path, state) + if stopped_early: + break + if config.eval_at_epoch_end: val_loss, val_batches = _run_validation( training_client, @@ -878,6 +965,42 @@ def _train_with_service_client_sdk( state["updated_at"] = _utc_now_iso() _write_service_run_state(run_state_path, state) + best_val_loss, evals_without_improvement, should_stop = _check_early_stopping( + val_loss, + best_val_loss, + evals_without_improvement, + config, + current_step, + ) + if should_stop: + stopped_early = True + break + + if stopped_early: + _append_train_log_line( + train_log_path, + f"Early stopping at step {current_step} (best val loss: {best_val_loss:.4f})" + if best_val_loss is not None + else f"Early stopping at step {current_step}", + ) + _append_jsonl( + metrics_path, + { + "timestamp": _utc_now_iso(), + "type": "early_stop", + "step": current_step, + "best_val_loss": best_val_loss, + "evals_without_improvement": evals_without_improvement, + }, + ) + state["status"] = "completed" + state["early_stopped"] = True + state["early_stopped_step"] = current_step + state["completed_steps"] = current_step + state["final_loss"] = final_loss + state["updated_at"] = _utc_now_iso() + _write_service_run_state(run_state_path, state) + if current_step == completed_steps: return TinkerTrainingResult( success=False, @@ -944,8 +1067,9 @@ def load_config_from_yaml(config_path: Path) -> TinkerTrainingConfig: return TinkerTrainingConfig( model=cloud_config.get("model", "Qwen/Qwen3-8B"), lora=TinkerLoRAConfig( - rank=lora_config.get("rank", 64), - alpha=lora_config.get("alpha", 128), + rank=lora_config.get("rank", 16), + alpha=lora_config.get("alpha", 32), + dropout=lora_config.get("dropout", 0.05), target_modules=lora_config.get( "target_modules", ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], @@ -960,6 +1084,8 @@ def load_config_from_yaml(config_path: Path) -> TinkerTrainingConfig: eval_at_epoch_end=training_config.get("eval_at_epoch_end", True), checkpoint_ttl_seconds=training_config.get("checkpoint_ttl_seconds", None), resume_from_checkpoint=training_config.get("resume_from_checkpoint", True), + early_stopping_patience=training_config.get("early_stopping_patience", 0), + early_stopping_threshold=training_config.get("early_stopping_threshold", 0.01), ) diff --git a/tests/test_train_tinker.py b/tests/test_train_tinker.py index 8a966f6..2ce892d 100644 --- a/tests/test_train_tinker.py +++ b/tests/test_train_tinker.py @@ -7,6 +7,7 @@ import pytest import src.training.train_tinker as train_tinker +from src.training.train_tinker import _check_early_stopping class FakeClient: @@ -772,3 +773,330 @@ def has_service_client_api(self) -> bool: assert result.success is False assert result.error == "log_interval_steps must be greater than 0" + + +# --------------------------------------------------------------------------- +# _check_early_stopping unit tests +# --------------------------------------------------------------------------- + + +def _make_config(patience: int = 0, threshold: float = 0.01) -> train_tinker.TinkerTrainingConfig: + """Helper to build a TinkerTrainingConfig with early stopping knobs.""" + return train_tinker.TinkerTrainingConfig( + early_stopping_patience=patience, + early_stopping_threshold=threshold, + ) + + +def test_check_early_stopping_disabled_when_patience_zero() -> None: + config = _make_config(patience=0) + best, count, stop = _check_early_stopping(3.0, 5.0, 0, config, 100) + assert best == 5.0 + assert count == 0 + assert stop is False + + +def test_check_early_stopping_noop_when_val_loss_none() -> None: + config = _make_config(patience=5) + best, count, stop = _check_early_stopping(None, 5.0, 2, config, 100) + assert best == 5.0 + assert count == 2 + assert stop is False + + +def test_check_early_stopping_sets_baseline_on_first_eval() -> None: + config = _make_config(patience=5) + best, count, stop = _check_early_stopping(3.0, None, 0, config, 100) + assert best == 3.0 + assert count == 0 + assert stop is False + + +def test_check_early_stopping_improvement_resets_counter() -> None: + config = _make_config(patience=5) + # 2.9 < 3.0 - 0.01 = 2.99 → improvement + best, count, stop = _check_early_stopping(2.9, 3.0, 3, config, 100) + assert best == 2.9 + assert count == 0 + assert stop is False + + +def test_check_early_stopping_no_improvement_increments_counter() -> None: + config = _make_config(patience=5) + # 3.0 is NOT < 3.0 - 0.01 = 2.99 → no improvement + best, count, stop = _check_early_stopping(3.0, 3.0, 2, config, 100) + assert best == 3.0 + assert count == 3 + assert stop is False + + +def test_check_early_stopping_triggers_at_patience() -> None: + config = _make_config(patience=5) + # evals_without_improvement=4, incremented to 5 = patience → stop + best, count, stop = _check_early_stopping(3.0, 2.5, 4, config, 100) + assert best == 2.5 + assert count == 5 + assert stop is True + + +def test_check_early_stopping_threshold_boundary() -> None: + config = _make_config(patience=5, threshold=0.01) + # 2.995 is NOT < 3.0 - 0.01 = 2.99 → NOT an improvement + best, count, stop = _check_early_stopping(2.995, 3.0, 0, config, 100) + assert best == 3.0 + assert count == 1 + assert stop is False + + +# --------------------------------------------------------------------------- +# Early stopping integration test +# --------------------------------------------------------------------------- + + +def test_train_on_tinker_service_client_stops_early( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + data_dir = tmp_path / "training" + data_dir.mkdir() + # 4 training examples so there are 4 steps per epoch (batch_size=1) + train_lines = "\n".join( + [ + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u1"},{"role":"assistant","content":"a1"}]}', + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u2"},{"role":"assistant","content":"a2"}]}', + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u3"},{"role":"assistant","content":"a3"}]}', + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u4"},{"role":"assistant","content":"a4"}]}', + ] + ) + (data_dir / "train.jsonl").write_text(train_lines + "\n", encoding="utf-8") + (data_dir / "valid.jsonl").write_text( + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"vu"},{"role":"assistant","content":"va"}]}\n', + encoding="utf-8", + ) + + class _Future: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + class _Tokenizer: + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + tokens = [1 for _ in text] + if add_special_tokens: + return [0] + tokens + return tokens + + class _TrainingClient: + forward_call_count = 0 + + def get_tokenizer(self): + return _Tokenizer() + + def forward_backward(self, data, loss_fn): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={"loss": 0.5}, loss_fn_outputs=[])) + + def optim_step(self, adam_params): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={})) + + def forward(self, data, loss_fn): # noqa: ARG002 + _TrainingClient.forward_call_count += 1 + # Return increasing loss to trigger early stopping + val_loss = 1.0 + _TrainingClient.forward_call_count * 0.5 + return _Future(types.SimpleNamespace(metrics={"loss": val_loss})) + + def save_state(self, name: str, ttl_seconds=None): # noqa: ARG002 + return _Future(types.SimpleNamespace(path=f"tinker://run-123/weights/{name}")) + + def get_info(self): + return types.SimpleNamespace(model_id="run-123") + + class _ServiceClient: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + def create_lora_training_client(self, base_model: str, rank: int): # noqa: ARG002 + return _TrainingClient() + + class _Datum: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _ModelInput: + @classmethod + def from_ints(cls, tokens): # noqa: ARG003 + return cls() + + class _TensorData: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _AdamParams: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + fake_tinker = types.ModuleType("tinker") + fake_tinker.Client = None + fake_tinker.ServiceClient = _ServiceClient + fake_tinker.Datum = _Datum + fake_tinker.ModelInput = _ModelInput + fake_tinker.TensorData = _TensorData + fake_tinker.AdamParams = _AdamParams + + monkeypatch.setattr(train_tinker.TinkerClient, "_check_sdk", lambda self: True) + monkeypatch.setitem(sys.modules, "tinker", fake_tinker) + + # Reset class-level counter before test + _TrainingClient.forward_call_count = 0 + + output_dir = tmp_path / "out" + config = train_tinker.TinkerTrainingConfig( + dataset_path=data_dir, + output_dir=output_dir, + wait_for_completion=True, + # 4 examples × 3 epochs = 12 steps total with batch_size=1 + epochs=3, + batch_size=1, + log_interval_steps=1, + eval_interval_steps=1, + eval_at_epoch_end=False, + early_stopping_patience=2, + early_stopping_threshold=0.0, + ) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + + assert result.success is True + + # Should have stopped before all 12 steps + tinker_run = json.loads((output_dir / "tinker_run.json").read_text(encoding="utf-8")) + assert tinker_run["early_stopped"] is True + assert tinker_run["completed_steps"] < 12 + + # metrics.jsonl should contain an early_stop entry + metrics_lines = (output_dir / "metrics.jsonl").read_text(encoding="utf-8").splitlines() + metric_entries = [json.loads(line) for line in metrics_lines] + early_stop_entries = [e for e in metric_entries if e["type"] == "early_stop"] + assert len(early_stop_entries) == 1 + + # train.log should mention early stopping + train_log = (output_dir / "train.log").read_text(encoding="utf-8") + assert "Early stopping" in train_log + + +# --------------------------------------------------------------------------- +# Loss normalization test +# --------------------------------------------------------------------------- + + +def test_train_on_tinker_service_client_loss_normalization( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that train_loss in metrics.jsonl is per-token (raw_loss / completion_tokens).""" + data_dir = tmp_path / "training" + data_dir.mkdir() + # Use longer assistant content "assistant" (9 chars) for more completion tokens + (data_dir / "train.jsonl").write_text( + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u"},{"role":"assistant","content":"assistant"}]}\n', + encoding="utf-8", + ) + + class _Future: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + class _Tokenizer: + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + tokens = [1 for _ in text] + if add_special_tokens: + return [0] + tokens + return tokens + + raw_loss = 0.9 + + class _TrainingClient: + def get_tokenizer(self): + return _Tokenizer() + + def forward_backward(self, data, loss_fn): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={"loss": raw_loss}, loss_fn_outputs=[])) + + def optim_step(self, adam_params): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={})) + + def save_state(self, name: str, ttl_seconds=None): # noqa: ARG002 + return _Future(types.SimpleNamespace(path=f"tinker://run-123/weights/{name}")) + + def get_info(self): + return types.SimpleNamespace(model_id="run-123") + + class _ServiceClient: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + def create_lora_training_client(self, base_model: str, rank: int): # noqa: ARG002 + return _TrainingClient() + + class _Datum: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _ModelInput: + @classmethod + def from_ints(cls, tokens): # noqa: ARG003 + return cls() + + class _TensorData: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _AdamParams: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + fake_tinker = types.ModuleType("tinker") + fake_tinker.Client = None + fake_tinker.ServiceClient = _ServiceClient + fake_tinker.Datum = _Datum + fake_tinker.ModelInput = _ModelInput + fake_tinker.TensorData = _TensorData + fake_tinker.AdamParams = _AdamParams + + monkeypatch.setattr(train_tinker.TinkerClient, "_check_sdk", lambda self: True) + monkeypatch.setitem(sys.modules, "tinker", fake_tinker) + + output_dir = tmp_path / "out" + config = train_tinker.TinkerTrainingConfig( + dataset_path=data_dir, + output_dir=output_dir, + wait_for_completion=True, + epochs=1, + batch_size=1, + log_interval_steps=1, + eval_at_epoch_end=False, + ) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + assert result.success is True + + metrics_lines = (output_dir / "metrics.jsonl").read_text(encoding="utf-8").splitlines() + all_entries = [json.loads(line) for line in metrics_lines] + train_entries = [e for e in all_entries if e["type"] == "train"] + assert len(train_entries) >= 1 + + entry = train_entries[0] + + # With _Tokenizer on messages [system "s", user "u", assistant "assistant"]: + # prompt = "s\nu" → encode(add_special=True) = [0,1,1,1] (4 tokens) + # completion = "assistant" → encode(add_special=False) = [1]*9 (9 tokens) + # weights[1:] has sum = 9 → completion_tokens = 9 + expected_completion_tokens = 9 + expected_per_token_loss = raw_loss / expected_completion_tokens + + assert entry["completion_tokens"] == expected_completion_tokens + assert entry["train_loss_total"] == pytest.approx(raw_loss) + assert entry["train_loss"] == pytest.approx(expected_per_token_loss) + assert entry["train_loss"] < entry["train_loss_total"] diff --git a/tests/test_visualize_tinker_training.py b/tests/test_visualize_tinker_training.py new file mode 100644 index 0000000..f7c650b --- /dev/null +++ b/tests/test_visualize_tinker_training.py @@ -0,0 +1,267 @@ +"""Tests for Tinker training visualizer parser.""" + +import json +import tempfile +from pathlib import Path + +from scripts.visualize_tinker_training import TinkerMetricsParser + + +class TestTinkerMetricsParser: + def _write_metrics(self, records: list[dict]) -> Path: + """Write records to a temp JSONL file and return the path.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp: + for record in records: + tmp.write(json.dumps(record) + "\n") + return Path(tmp.name) + + def test_parse_train_entries(self): + path = self._write_metrics( + [ + { + "type": "train", + "step": 10, + "epoch": 1, + "train_loss": 5.0, + "tokens_per_sec": 100.0, + }, + { + "type": "train", + "step": 20, + "epoch": 1, + "train_loss": 4.0, + "tokens_per_sec": 200.0, + }, + ] + ) + metrics = TinkerMetricsParser().parse(path) + assert metrics.train_steps == [10, 20] + assert metrics.train_losses == [5.0, 4.0] + assert metrics.tokens_per_sec == [100.0, 200.0] + assert metrics.train_epochs == [1, 1] + + def test_parse_val_entries(self): + path = self._write_metrics( + [ + {"type": "val", "step": 100, "epoch": 1, "val_loss": 3.5, "val_batches": 10}, + {"type": "val", "step": 200, "epoch": 1, "val_loss": 3.0, "val_batches": 10}, + ] + ) + metrics = TinkerMetricsParser().parse(path) + assert metrics.val_steps == [100, 200] + assert metrics.val_losses == [3.5, 3.0] + assert metrics.best_val_loss == 3.0 + + def test_epoch_boundary_detection(self): + path = self._write_metrics( + [ + { + "type": "train", + "step": 490, + "epoch": 1, + "train_loss": 5.0, + "tokens_per_sec": 100.0, + }, + { + "type": "train", + "step": 500, + "epoch": 2, + "train_loss": 4.0, + "tokens_per_sec": 100.0, + }, + ] + ) + metrics = TinkerMetricsParser().parse(path) + assert metrics.epoch_boundaries == [500] + + def test_early_stop_parsing(self): + path = self._write_metrics( + [ + { + "type": "train", + "step": 10, + "epoch": 1, + "train_loss": 5.0, + "tokens_per_sec": 100.0, + }, + { + "type": "early_stop", + "step": 1500, + "best_val_loss": 2.5, + "evals_without_improvement": 5, + }, + ] + ) + metrics = TinkerMetricsParser().parse(path) + assert metrics.early_stop_step == 1500 + assert metrics.best_val_loss == 2.5 + + def test_empty_file(self): + path = self._write_metrics([]) + metrics = TinkerMetricsParser().parse(path) + assert metrics.train_steps == [] + assert metrics.val_steps == [] + assert metrics.best_val_loss is None + + def test_mixed_entries(self): + path = self._write_metrics( + [ + { + "type": "train", + "step": 10, + "epoch": 1, + "train_loss": 5.0, + "tokens_per_sec": 100.0, + }, + {"type": "val", "step": 100, "epoch": 1, "val_loss": 4.0, "val_batches": 10}, + { + "type": "train", + "step": 110, + "epoch": 1, + "train_loss": 3.0, + "tokens_per_sec": 150.0, + }, + {"type": "val", "step": 200, "epoch": 1, "val_loss": 3.5, "val_batches": 10}, + ] + ) + metrics = TinkerMetricsParser().parse(path) + assert len(metrics.train_steps) == 2 + assert len(metrics.val_steps) == 2 + # min(4.0, 3.5) = 3.5 + assert metrics.best_val_loss == 3.5 + + def test_train_loss_total_field(self): + """New format entries with train_loss_total are captured.""" + path = self._write_metrics( + [ + { + "type": "train", + "step": 10, + "epoch": 1, + "train_loss": 5.0, + "train_loss_total": 50.0, + "tokens_per_sec": 100.0, + }, + { + "type": "train", + "step": 20, + "epoch": 1, + "train_loss": 4.0, + "tokens_per_sec": 200.0, + }, + ] + ) + metrics = TinkerMetricsParser().parse(path) + # Only one entry had train_loss_total + assert metrics.train_losses_total == [50.0] + # Both entries still have train_loss + assert metrics.train_losses == [5.0, 4.0] + + def test_missing_tokens_per_sec_defaults_to_zero(self): + """Entries without tokens_per_sec default to 0.0.""" + path = self._write_metrics( + [ + {"type": "train", "step": 10, "epoch": 1, "train_loss": 5.0}, + ] + ) + metrics = TinkerMetricsParser().parse(path) + assert metrics.tokens_per_sec == [0.0] + + def test_blank_lines_skipped(self): + """Blank lines in the JSONL file are skipped gracefully.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp: + tmp.write( + json.dumps( + { + "type": "train", + "step": 10, + "epoch": 1, + "train_loss": 5.0, + "tokens_per_sec": 100.0, + } + ) + + "\n" + ) + tmp.write("\n") + tmp.write( + json.dumps( + { + "type": "train", + "step": 20, + "epoch": 1, + "train_loss": 4.0, + "tokens_per_sec": 200.0, + } + ) + + "\n" + ) + tmp_path = Path(tmp.name) + metrics = TinkerMetricsParser().parse(tmp_path) + assert metrics.train_steps == [10, 20] + + def test_multiple_epoch_boundaries(self): + """Multiple epoch transitions are all captured.""" + path = self._write_metrics( + [ + { + "type": "train", + "step": 100, + "epoch": 1, + "train_loss": 5.0, + "tokens_per_sec": 100.0, + }, + { + "type": "train", + "step": 200, + "epoch": 2, + "train_loss": 4.0, + "tokens_per_sec": 100.0, + }, + { + "type": "train", + "step": 300, + "epoch": 2, + "train_loss": 3.5, + "tokens_per_sec": 100.0, + }, + { + "type": "train", + "step": 400, + "epoch": 3, + "train_loss": 3.0, + "tokens_per_sec": 100.0, + }, + ] + ) + metrics = TinkerMetricsParser().parse(path) + assert metrics.epoch_boundaries == [200, 400] + + def test_early_stop_overrides_best_val_loss(self): + """Early stop's best_val_loss takes precedence over computed min.""" + path = self._write_metrics( + [ + {"type": "val", "step": 100, "epoch": 1, "val_loss": 4.0, "val_batches": 10}, + {"type": "val", "step": 200, "epoch": 1, "val_loss": 3.0, "val_batches": 10}, + { + "type": "early_stop", + "step": 300, + "best_val_loss": 2.8, + "evals_without_improvement": 3, + }, + ] + ) + metrics = TinkerMetricsParser().parse(path) + # early_stop sets best_val_loss directly; parser skips argmin computation + assert metrics.best_val_loss == 2.8 + + def test_default_epoch_when_missing(self): + """Entries without epoch field default to epoch 1.""" + path = self._write_metrics( + [ + {"type": "train", "step": 10, "train_loss": 5.0, "tokens_per_sec": 100.0}, + {"type": "val", "step": 100, "val_loss": 3.0, "val_batches": 10}, + ] + ) + metrics = TinkerMetricsParser().parse(path) + assert metrics.train_epochs == [1] + assert metrics.val_epochs == [1]