diff --git a/configs/training.yaml b/configs/training.yaml index aef05e5..3dbf801 100644 --- a/configs/training.yaml +++ b/configs/training.yaml @@ -69,6 +69,12 @@ cloud: learning_rate: 2.0e-4 warmup_ratio: 0.03 max_seq_length: 2048 + log_interval_steps: 10 + checkpoint_interval_steps: 250 + eval_interval_steps: 0 # 0 disables mid-epoch evals + eval_at_epoch_end: true + checkpoint_ttl_seconds: null + resume_from_checkpoint: true # === OUTPUT PATHS === output: diff --git a/docs/SETUP.md b/docs/SETUP.md index 997c922..5f065dd 100644 --- a/docs/SETUP.md +++ b/docs/SETUP.md @@ -122,50 +122,66 @@ TINKER_API_KEY=tk_... ### 2.2 Training with Tinker -```python -# src/training/train_tinker.py -import os -from tinker import ServiceClient -from pathlib import Path - -from src.training.train_tinker import TinkerTrainingConfig, run_training_loop, write_run_metadata - -# Initialize SDK clients -service_client = ServiceClient(api_key=os.environ["TINKER_API_KEY"]) -training_client = service_client.create_lora_training_client( - base_model="Qwen/Qwen3-8B", -) - -# Configure training -config = TinkerTrainingConfig( - base_model="Qwen/Qwen3-8B", - epochs=3, -) - -# Run training loop and persist metadata -metadata = run_training_loop(training_client, config) -metadata_path = write_run_metadata(metadata, output_dir=Path("models/adapters/tinker")) -print(f"Run ID: {metadata.run_id}") -print(f"Run metadata: {metadata_path}") +Use the CLI wrapper for end-to-end training: + +```bash +python scripts/train_tinker.py \ + --config configs/training.yaml \ + --output models/adapters/tinker ``` +ServiceClient mode writes artifacts directly in `--output`: + +- `tinker_run.json` (resume state + checkpoint history) +- `run.json` (MLflow-compatible run metadata) +- `train.log` (train/val lines parsed by `scripts/mlflow_logger.py`) +- `metrics.jsonl` (structured metric events) + +Auto-resume is enabled by default and uses `latest_checkpoint_path` from +`tinker_run.json`. + ### 2.3 Tinker CLI Workflow ```bash -# Start training (records run metadata under models/adapters/tinker/runs) +# Start training with config defaults python scripts/train_tinker.py \ --config configs/training.yaml \ --output models/adapters/tinker -# Check status +# Customize training telemetry/checkpoints python scripts/train_tinker.py \ - --status \ - --output models/adapters/tinker + --config configs/training.yaml \ + --output models/adapters/tinker \ + --log-interval-steps 10 \ + --checkpoint-interval-steps 250 \ + --eval-interval-steps 100 -# Inspect run metadata -cat models/adapters/tinker/runs/.json +# Disable auto-resume for a fresh run in same output dir +python scripts/train_tinker.py \ + --config configs/training.yaml \ + --output models/adapters/tinker \ + --no-resume + +# Check status (legacy Client API mode) +python scripts/train_tinker.py --status + +# Inspect service-mode run artifacts +cat models/adapters/tinker/tinker_run.json +cat models/adapters/tinker/run.json + +# Send artifacts to MLflow/DagsHub +python scripts/mlflow_logger.py \ + --run-dir models/adapters/tinker \ + --experiment-name "compression-v2" \ + --dagshub-owner Sudhendra \ + --dagshub-repo compression-layer ``` +Additional useful flags: + +- `--no-eval-at-epoch-end` +- `--checkpoint-ttl-seconds ` + ### 2.4 Cost Estimation | Model | Per 1M Tokens | 10K pairs (~5M tok) | 50K pairs (~25M tok) | @@ -258,8 +274,10 @@ python scripts/validate_batch.py --input data/seed/pairs.jsonl ### Tinker job failed - Check dataset format (JSONL with `text` or `messages` field) - Verify API key in `.env` or shell: `TINKER_API_KEY` -- Inspect run metadata: `models/adapters/tinker/runs/.json` -- Re-run status: `python scripts/train_tinker.py --status --output models/adapters/tinker` +- Inspect run metadata: `models/adapters/tinker/tinker_run.json` +- Inspect MLflow metadata/logs: `models/adapters/tinker/run.json`, `models/adapters/tinker/train.log` +- Re-run same command to resume from latest checkpoint (default) +- Add `--no-resume` to force a fresh run in an existing output directory ### Slow local inference - Ensure using 4-bit model: `*-4bit` diff --git a/scripts/train_tinker.py b/scripts/train_tinker.py index 466796f..f6eae8d 100644 --- a/scripts/train_tinker.py +++ b/scripts/train_tinker.py @@ -144,6 +144,40 @@ def parse_args() -> argparse.Namespace: default=128, help="LoRA alpha (default: 128)", ) + parser.add_argument( + "--log-interval-steps", + type=int, + default=10, + help="How often to log train metrics (default: 10)", + ) + parser.add_argument( + "--checkpoint-interval-steps", + type=int, + default=250, + help="How often to save resumable checkpoints (default: 250)", + ) + parser.add_argument( + "--eval-interval-steps", + type=int, + default=0, + help="How often to run validation during an epoch (0 disables, default: 0)", + ) + parser.add_argument( + "--no-eval-at-epoch-end", + action="store_true", + help="Disable validation pass at epoch end", + ) + parser.add_argument( + "--checkpoint-ttl-seconds", + type=int, + default=None, + help="Optional TTL for saved checkpoints in seconds", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Disable auto-resume from latest checkpoint in output directory", + ) # Job control parser.add_argument( @@ -207,6 +241,11 @@ def print_config(config: TinkerTrainingConfig) -> None: table.add_row("Learning Rate", f"{config.learning_rate:.0e}") table.add_row("LoRA Rank", str(config.lora.rank)) table.add_row("LoRA Alpha", str(config.lora.alpha)) + table.add_row("Log Interval", str(config.log_interval_steps)) + table.add_row("Checkpoint Every", str(config.checkpoint_interval_steps)) + table.add_row("Eval Every", str(config.eval_interval_steps)) + table.add_row("Eval Epoch End", str(config.eval_at_epoch_end)) + table.add_row("Auto Resume", str(config.resume_from_checkpoint)) console.print(table) @@ -240,6 +279,12 @@ def main() -> int: config.lora = TinkerLoRAConfig(rank=args.lora_rank, alpha=args.lora_alpha) config.wait_for_completion = not args.no_wait config.dataset_name = args.dataset_name + config.log_interval_steps = args.log_interval_steps + config.checkpoint_interval_steps = args.checkpoint_interval_steps + config.eval_interval_steps = args.eval_interval_steps + config.eval_at_epoch_end = not args.no_eval_at_epoch_end + config.checkpoint_ttl_seconds = args.checkpoint_ttl_seconds + config.resume_from_checkpoint = not args.no_resume # Handle status check if args.status: diff --git a/src/training/train_tinker.py b/src/training/train_tinker.py index b9762f3..c711a84 100644 --- a/src/training/train_tinker.py +++ b/src/training/train_tinker.py @@ -21,7 +21,9 @@ import math import os import time +from collections.abc import Iterator from dataclasses import dataclass, field +from datetime import datetime, timezone from pathlib import Path from types import ModuleType from typing import Any @@ -81,6 +83,14 @@ class TinkerTrainingConfig: wait_for_completion: bool = True poll_interval: int = 30 # seconds + # ServiceClient mode controls + log_interval_steps: int = 10 + checkpoint_interval_steps: int = 250 + eval_interval_steps: int = 0 # 0 = disabled + eval_at_epoch_end: bool = True + checkpoint_ttl_seconds: int | None = None + resume_from_checkpoint: bool = True + def to_dict(self) -> dict[str, Any]: """Convert to dictionary for API calls.""" return { @@ -90,6 +100,12 @@ def to_dict(self) -> dict[str, Any]: "learning_rate": self.learning_rate, "warmup_ratio": self.warmup_ratio, "max_seq_length": self.max_seq_length, + "log_interval_steps": self.log_interval_steps, + "checkpoint_interval_steps": self.checkpoint_interval_steps, + "eval_interval_steps": self.eval_interval_steps, + "eval_at_epoch_end": self.eval_at_epoch_end, + "checkpoint_ttl_seconds": self.checkpoint_ttl_seconds, + "resume_from_checkpoint": self.resume_from_checkpoint, "lora": { "r": self.lora.rank, "alpha": self.lora.alpha, @@ -448,28 +464,83 @@ def train_on_tinker( ) -def _to_sdk_datum(local_datum: Any, tinker_module: Any) -> Any: - """Convert local tinker_data datum to SDK Datum.""" +def _utc_now_iso() -> str: + """Return UTC timestamp in ISO format.""" + return datetime.now(tz=timezone.utc).isoformat() # noqa: UP017 + + +def _append_jsonl(file_path: Path, payload: dict[str, Any]) -> None: + """Append one JSON object to a JSONL file.""" + with open(file_path, "a", encoding="utf-8") as f: + f.write(json.dumps(payload, ensure_ascii=False) + "\n") + + +def _append_train_log_line(train_log_path: Path, line: str) -> None: + """Append one line to train.log.""" + with open(train_log_path, "a", encoding="utf-8") as f: + f.write(line + "\n") + + +def _write_service_run_state(state_path: Path, state: dict[str, Any]) -> None: + """Persist service-mode run state and a mlflow-compatible run.json.""" + state_path.parent.mkdir(parents=True, exist_ok=True) + with open(state_path, "w", encoding="utf-8") as f: + json.dump(state, f, indent=2, ensure_ascii=False) + + run_json_path = state_path.parent / "run.json" + run_json = { + "started_at": state.get("started_at", _utc_now_iso()), + "model": state.get("model"), + "git_sha": state.get("git_sha", os.environ.get("GITHUB_SHA", "unknown")), + "data_dir": state.get("data_dir"), + "lora_rank": state.get("lora_rank"), + "lora_alpha": state.get("lora_alpha"), + "batch_size": state.get("batch_size"), + "learning_rate": state.get("learning_rate"), + "iters": state.get("total_steps"), + } + with open(run_json_path, "w", encoding="utf-8") as f: + json.dump(run_json, f, indent=2, ensure_ascii=False) + + +def _load_service_run_state(state_path: Path) -> dict[str, Any]: + """Load prior run state if present.""" + if not state_path.exists(): + return {} + + try: + with open(state_path, encoding="utf-8") as f: + loaded = json.load(f) + return loaded if isinstance(loaded, dict) else {} + except json.JSONDecodeError: + logger.warning("Unable to parse run state at %s; starting fresh", state_path) + return {} + + +def _to_sdk_datum(local_datum: Any, tinker_module: ModuleType) -> tuple[Any, int]: + """Convert local tinker_data datum to SDK Datum and return token count.""" target_tokens = [int(token) for token in local_datum.loss_fn_inputs["target_tokens"]] weights = [float(weight) for weight in local_datum.loss_fn_inputs["weights"]] - return tinker_module.Datum( + datum = tinker_module.Datum( model_input=tinker_module.ModelInput.from_ints(local_datum.model_input.tokens), loss_fn_inputs={ "target_tokens": tinker_module.TensorData(data=target_tokens, dtype="int64"), "weights": tinker_module.TensorData(data=weights, dtype="float32"), }, ) + return datum, len(target_tokens) def _iter_training_batches( train_file: Path, tokenizer: Any, - tinker_module: Any, + tinker_module: ModuleType, batch_size: int, -) -> Any: - """Yield SDK-ready training batches from chat JSONL data.""" +) -> Iterator[tuple[list[Any], int]]: + """Yield SDK-ready training batches with token counts from chat JSONL data.""" batch: list[Any] = [] + batch_tokens = 0 with open(train_file, encoding="utf-8") as f: for line_number, line in enumerate(f, start=1): @@ -488,14 +559,17 @@ def _iter_training_batches( continue local_datum = render_chat_example(messages, tokenizer) - batch.append(_to_sdk_datum(local_datum, tinker_module)) + sdk_datum, token_count = _to_sdk_datum(local_datum, tinker_module) + batch.append(sdk_datum) + batch_tokens += token_count if len(batch) >= batch_size: - yield batch + yield batch, batch_tokens batch = [] + batch_tokens = 0 if batch: - yield batch + yield batch, batch_tokens def _extract_loss(metrics: dict[str, float]) -> float | None: @@ -511,21 +585,48 @@ def _extract_loss(metrics: dict[str, float]) -> float | None: return None +def _run_validation( + training_client: Any, + valid_file: Path, + tokenizer: Any, + tinker_module: ModuleType, + batch_size: int, +) -> tuple[float | None, int]: + """Run validation pass and return mean validation loss and batch count.""" + if not valid_file.exists() or not hasattr(training_client, "forward"): + return None, 0 + + losses: list[float] = [] + val_batches = 0 + for batch, _ in _iter_training_batches(valid_file, tokenizer, tinker_module, batch_size): + forward_result = training_client.forward(batch, "cross_entropy").result() + metrics = getattr(forward_result, "metrics", {}) or {} + val_loss = _extract_loss(metrics) + if val_loss is not None: + losses.append(val_loss) + val_batches += 1 + + if not losses: + return None, val_batches + return sum(losses) / len(losses), val_batches + + def _train_with_service_client_sdk( config: TinkerTrainingConfig, api_key: str, ) -> TinkerTrainingResult: """Train via the modern Tinker ServiceClient/TrainingClient SDK flow.""" - tinker = __import__("tinker") + if config.log_interval_steps <= 0: + return TinkerTrainingResult( + success=False, + error="log_interval_steps must be greater than 0", + ) - service_client = tinker.ServiceClient(api_key=api_key) - training_client = service_client.create_lora_training_client( - base_model=config.model, - rank=config.lora.rank, - ) - tokenizer = training_client.get_tokenizer() + tinker = __import__("tinker") train_file = config.dataset_path / "train.jsonl" + valid_file = config.dataset_path / "valid.jsonl" + with open(train_file, encoding="utf-8") as train_stream: total_examples = sum(1 for line in train_stream if line.strip()) if total_examples == 0: @@ -533,10 +634,97 @@ def _train_with_service_client_sdk( success=False, error=f"No training examples found in {train_file}" ) + config.output_dir.mkdir(parents=True, exist_ok=True) + train_log_path = config.output_dir / "train.log" + metrics_path = config.output_dir / "metrics.jsonl" + run_state_path = config.output_dir / "tinker_run.json" + existing_state = _load_service_run_state(run_state_path) + if existing_state.get("sdk_mode") not in {None, "service_client"}: + existing_state = {} + + service_client = tinker.ServiceClient(api_key=api_key) + + latest_checkpoint_path = "" + completed_steps = 0 + if config.resume_from_checkpoint: + latest_checkpoint_path = str(existing_state.get("latest_checkpoint_path", "")) + completed_steps = int(existing_state.get("completed_steps", 0) or 0) + + resumed = False + if latest_checkpoint_path and hasattr( + service_client, "create_training_client_from_state_with_optimizer" + ): + training_client = service_client.create_training_client_from_state_with_optimizer( + latest_checkpoint_path + ) + resumed = True + logger.info("Resuming ServiceClient training from checkpoint: %s", latest_checkpoint_path) + else: + if latest_checkpoint_path: + logger.warning( + "Resume checkpoint found but SDK cannot restore optimizer state; " + "starting fresh training for %s", + config.model, + ) + latest_checkpoint_path = "" + completed_steps = 0 + training_client = service_client.create_lora_training_client( + base_model=config.model, + rank=config.lora.rank, + ) + + tokenizer = training_client.get_tokenizer() + info = training_client.get_info() + training_run_id = str(getattr(info, "model_id", "unknown")) + steps_per_epoch = math.ceil(total_examples / config.batch_size) total_steps = steps_per_epoch * config.epochs - current_step = 0 + completed_final_loss_raw = existing_state.get("final_loss") + completed_final_loss: float | None = None + if isinstance(completed_final_loss_raw, (int, float)): + completed_final_loss = float(completed_final_loss_raw) + if completed_steps >= total_steps: + return TinkerTrainingResult( + success=True, + job_id=training_run_id, + adapter_path=config.output_dir, + final_loss=completed_final_loss, + total_epochs=config.epochs, + metrics={"message": "Training already complete for configured epochs"}, + ) + + checkpoints = existing_state.get("checkpoints", []) + if not isinstance(checkpoints, list): + checkpoints = [] + + state: dict[str, Any] = { + "sdk_mode": "service_client", + "training_run_id": training_run_id, + "model": config.model, + "data_dir": str(config.dataset_path), + "epochs": config.epochs, + "batch_size": config.batch_size, + "learning_rate": config.learning_rate, + "lora_rank": config.lora.rank, + "lora_alpha": config.lora.alpha, + "total_examples": total_examples, + "steps_per_epoch": steps_per_epoch, + "total_steps": total_steps, + "completed_steps": completed_steps, + "latest_checkpoint_path": latest_checkpoint_path, + "checkpoints": checkpoints, + "status": "running", + "started_at": str(existing_state.get("started_at", _utc_now_iso())), + "updated_at": _utc_now_iso(), + "resumed": resumed, + } + _write_service_run_state(run_state_path, state) + + current_step = completed_steps final_loss: float | None = None + existing_final_loss_raw = existing_state.get("final_loss") + if isinstance(existing_final_loss_raw, (int, float)): + final_loss = float(existing_final_loss_raw) logger.info( "Starting ServiceClient training: %s examples, %s epochs, %s total steps", @@ -545,17 +733,29 @@ def _train_with_service_client_sdk( total_steps, ) - for epoch in range(1, config.epochs + 1): - for batch in _iter_training_batches( + start_epoch = (completed_steps // steps_per_epoch) + 1 + skip_batches = completed_steps % steps_per_epoch + for epoch in range(start_epoch, config.epochs + 1): + epoch_batch_index = 0 + for batch, batch_tokens in _iter_training_batches( train_file, tokenizer, tinker, config.batch_size, ): - fwdbwd = training_client.forward_backward(batch, "cross_entropy").result() - training_client.optim_step( + epoch_batch_index += 1 + if epoch == start_epoch and epoch_batch_index <= skip_batches: + continue + + step_start = time.perf_counter() + fwdbwd_future = training_client.forward_backward(batch, "cross_entropy") + optim_future = training_client.optim_step( tinker.AdamParams(learning_rate=config.learning_rate), - ).result() + ) + fwdbwd = fwdbwd_future.result() + optim_future.result() + step_elapsed = max(time.perf_counter() - step_start, 1e-6) + tokens_per_sec = batch_tokens / step_elapsed if batch_tokens > 0 else 0.0 current_step += 1 metrics = getattr(fwdbwd, "metrics", {}) or {} @@ -563,7 +763,7 @@ def _train_with_service_client_sdk( if step_loss is not None: final_loss = step_loss - if current_step % 10 == 0 or current_step == total_steps: + if current_step % config.log_interval_steps == 0 or current_step == total_steps: logger.info( "Epoch %s/%s Step %s/%s Loss %s", epoch, @@ -573,37 +773,141 @@ def _train_with_service_client_sdk( f"{step_loss:.4f}" if step_loss is not None else "n/a", ) - if current_step == 0: + if step_loss is not None: + _append_train_log_line( + train_log_path, + ( + f"Iter {current_step}: Train loss {step_loss:.4f} " + f"| Tokens/sec {tokens_per_sec:.1f} | Peak mem 0.0 GB" + ), + ) + _append_jsonl( + metrics_path, + { + "timestamp": _utc_now_iso(), + "type": "train", + "step": current_step, + "epoch": epoch, + "train_loss": step_loss, + "tokens_per_sec": tokens_per_sec, + }, + ) + + state["completed_steps"] = current_step + state["last_train_loss"] = step_loss + state["updated_at"] = _utc_now_iso() + _write_service_run_state(run_state_path, state) + + if config.eval_interval_steps > 0 and current_step % config.eval_interval_steps == 0: + val_loss, val_batches = _run_validation( + training_client, + valid_file, + tokenizer, + tinker, + config.batch_size, + ) + if val_loss is not None: + _append_train_log_line( + train_log_path, f"Iter {current_step}: Val loss {val_loss:.4f}" + ) + _append_jsonl( + metrics_path, + { + "timestamp": _utc_now_iso(), + "type": "val", + "step": current_step, + "epoch": epoch, + "val_loss": val_loss, + "val_batches": val_batches, + }, + ) + state["last_val_loss"] = val_loss + state["updated_at"] = _utc_now_iso() + _write_service_run_state(run_state_path, state) + + if ( + config.checkpoint_interval_steps > 0 + and current_step % config.checkpoint_interval_steps == 0 + ): + checkpoint_name = f"step-{current_step:06d}" + checkpoint_response = training_client.save_state( + checkpoint_name, + ttl_seconds=config.checkpoint_ttl_seconds, + ).result() + checkpoint_path = str(getattr(checkpoint_response, "path", "")) + if checkpoint_path: + checkpoints.append( + { + "name": checkpoint_name, + "step": current_step, + "epoch": epoch, + "path": checkpoint_path, + "created_at": _utc_now_iso(), + } + ) + state["completed_steps"] = current_step + state["latest_checkpoint_path"] = checkpoint_path + state["checkpoints"] = checkpoints + state["updated_at"] = _utc_now_iso() + _write_service_run_state(run_state_path, state) + + if config.eval_at_epoch_end: + val_loss, val_batches = _run_validation( + training_client, + valid_file, + tokenizer, + tinker, + config.batch_size, + ) + if val_loss is not None: + _append_train_log_line( + train_log_path, f"Iter {current_step}: Val loss {val_loss:.4f}" + ) + _append_jsonl( + metrics_path, + { + "timestamp": _utc_now_iso(), + "type": "val", + "step": current_step, + "epoch": epoch, + "val_loss": val_loss, + "val_batches": val_batches, + }, + ) + state["last_val_loss"] = val_loss + state["updated_at"] = _utc_now_iso() + _write_service_run_state(run_state_path, state) + + if current_step == completed_steps: return TinkerTrainingResult( success=False, error="No valid training examples were found after parsing train.jsonl", ) - save_name = config.dataset_name or f"compression-{int(time.time())}" - save_response = training_client.save_state(save_name).result() + save_name = "final" + save_response = training_client.save_state( + save_name, + ttl_seconds=config.checkpoint_ttl_seconds, + ).result() checkpoint_path = str(getattr(save_response, "path", "")) - info = training_client.get_info() - training_run_id = str(getattr(info, "model_id", "unknown")) - - config.output_dir.mkdir(parents=True, exist_ok=True) - metadata_path = config.output_dir / "tinker_run.json" - metadata_path.write_text( - json.dumps( + if checkpoint_path: + checkpoints.append( { - "training_run_id": training_run_id, - "checkpoint_path": checkpoint_path, - "model": config.model, - "epochs": config.epochs, - "batch_size": config.batch_size, - "learning_rate": config.learning_rate, - "lora_rank": config.lora.rank, - "final_loss": final_loss, - "sdk_mode": "service_client", - }, - indent=2, - ), - encoding="utf-8", - ) + "name": save_name, + "step": current_step, + "epoch": config.epochs, + "path": checkpoint_path, + "created_at": _utc_now_iso(), + } + ) + + state["status"] = "completed" + state["completed_steps"] = current_step + state["latest_checkpoint_path"] = checkpoint_path or state.get("latest_checkpoint_path", "") + state["checkpoints"] = checkpoints + state["final_loss"] = final_loss + state["updated_at"] = _utc_now_iso() + _write_service_run_state(run_state_path, state) return TinkerTrainingResult( success=True, @@ -611,7 +915,12 @@ def _train_with_service_client_sdk( adapter_path=config.output_dir, final_loss=final_loss, total_epochs=config.epochs, - metrics={"checkpoint_path": checkpoint_path, "metadata_path": str(metadata_path)}, + metrics={ + "checkpoint_path": checkpoint_path, + "metadata_path": str(run_state_path), + "metrics_path": str(metrics_path), + "train_log_path": str(train_log_path), + }, ) @@ -645,6 +954,12 @@ def load_config_from_yaml(config_path: Path) -> TinkerTrainingConfig: epochs=training_config.get("epochs", 3), batch_size=training_config.get("batch_size", 4), learning_rate=training_config.get("learning_rate", 2e-4), + log_interval_steps=training_config.get("log_interval_steps", 10), + checkpoint_interval_steps=training_config.get("checkpoint_interval_steps", 250), + eval_interval_steps=training_config.get("eval_interval_steps", 0), + eval_at_epoch_end=training_config.get("eval_at_epoch_end", True), + checkpoint_ttl_seconds=training_config.get("checkpoint_ttl_seconds", None), + resume_from_checkpoint=training_config.get("resume_from_checkpoint", True), ) diff --git a/tests/test_train_tinker.py b/tests/test_train_tinker.py index 2d90913..8a966f6 100644 --- a/tests/test_train_tinker.py +++ b/tests/test_train_tinker.py @@ -1,3 +1,5 @@ +import json +import re import sys import types from pathlib import Path @@ -282,3 +284,491 @@ def __init__(self, **kwargs): # noqa: ARG002 assert result.error is not None assert "No valid training examples" in result.error assert _TrainingClient.save_called is False + + +def test_train_on_tinker_service_client_writes_metrics_validation_and_checkpoints( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + data_dir = tmp_path / "training" + data_dir.mkdir() + (data_dir / "train.jsonl").write_text( + "\n".join( + [ + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u1"},{"role":"assistant","content":"a1"}]}', + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u2"},{"role":"assistant","content":"a2"}]}', + ] + ) + + "\n", + encoding="utf-8", + ) + (data_dir / "valid.jsonl").write_text( + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"vu"},{"role":"assistant","content":"va"}]}\n', + encoding="utf-8", + ) + + class _Future: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + class _Tokenizer: + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + tokens = [1 for _ in text] + if add_special_tokens: + return [0] + tokens + return tokens + + class _TrainingClient: + checkpoint_names: list[str] = [] + forward_calls = 0 + + def get_tokenizer(self): + return _Tokenizer() + + def forward_backward(self, data, loss_fn): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={"loss": 0.5}, loss_fn_outputs=[])) + + def optim_step(self, adam_params): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={})) + + def forward(self, data, loss_fn): # noqa: ARG002 + _TrainingClient.forward_calls += 1 + return _Future(types.SimpleNamespace(metrics={"loss": 0.25}, loss_fn_outputs=[])) + + def save_state(self, name: str, ttl_seconds=None): # noqa: ARG002 + _TrainingClient.checkpoint_names.append(name) + return _Future(types.SimpleNamespace(path=f"tinker://run-123/weights/{name}")) + + def get_info(self): + return types.SimpleNamespace(model_id="run-123") + + class _ServiceClient: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + def create_lora_training_client(self, base_model: str, rank: int): # noqa: ARG002 + return _TrainingClient() + + def create_training_client_from_state_with_optimizer(self, path: str): # noqa: ARG002 + return _TrainingClient() + + class _Datum: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _ModelInput: + @classmethod + def from_ints(cls, tokens): # noqa: ARG003 + return cls() + + class _TensorData: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _AdamParams: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + fake_tinker = types.ModuleType("tinker") + fake_tinker.Client = None + fake_tinker.ServiceClient = _ServiceClient + fake_tinker.Datum = _Datum + fake_tinker.ModelInput = _ModelInput + fake_tinker.TensorData = _TensorData + fake_tinker.AdamParams = _AdamParams + + monkeypatch.setattr(train_tinker.TinkerClient, "_check_sdk", lambda self: True) + monkeypatch.setitem(sys.modules, "tinker", fake_tinker) + + output_dir = tmp_path / "out" + config = train_tinker.TinkerTrainingConfig( + dataset_path=data_dir, + output_dir=output_dir, + wait_for_completion=True, + epochs=1, + batch_size=1, + log_interval_steps=1, + checkpoint_interval_steps=1, + eval_at_epoch_end=True, + ) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + + assert result.success is True + assert (output_dir / "tinker_run.json").exists() + assert (output_dir / "run.json").exists() + assert (output_dir / "metrics.jsonl").exists() + assert (output_dir / "train.log").exists() + assert _TrainingClient.forward_calls > 0 + assert "final" in _TrainingClient.checkpoint_names + + run_json = json.loads((output_dir / "run.json").read_text(encoding="utf-8")) + assert run_json["model"] == config.model + assert run_json["lora_rank"] == config.lora.rank + assert run_json["lora_alpha"] == config.lora.alpha + assert run_json["iters"] == 2 + + train_log_text = (output_dir / "train.log").read_text(encoding="utf-8") + assert re.search( + r"Iter (\d+): Train loss ([0-9.]+).*Tokens/sec ([0-9.]+).*Peak mem ([0-9.]+) GB", + train_log_text, + ) + assert re.search(r"Iter (\d+): Val loss ([0-9.]+)", train_log_text) + + metric_types = [ + json.loads(line)["type"] + for line in (output_dir / "metrics.jsonl").read_text(encoding="utf-8").splitlines() + ] + assert "train" in metric_types + assert "val" in metric_types + + +def test_train_on_tinker_service_client_resumes_from_checkpoint( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + data_dir = tmp_path / "training" + data_dir.mkdir() + (data_dir / "train.jsonl").write_text( + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u"},{"role":"assistant","content":"a"}]}\n', + encoding="utf-8", + ) + + output_dir = tmp_path / "out" + output_dir.mkdir() + (output_dir / "tinker_run.json").write_text( + json.dumps( + { + "sdk_mode": "service_client", + "latest_checkpoint_path": "tinker://run-123/weights/step-000050", + "completed_steps": 0, + } + ), + encoding="utf-8", + ) + + class _Future: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + class _Tokenizer: + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + tokens = [1 for _ in text] + if add_special_tokens: + return [0] + tokens + return tokens + + class _TrainingClient: + def get_tokenizer(self): + return _Tokenizer() + + def forward_backward(self, data, loss_fn): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={"loss": 0.5}, loss_fn_outputs=[])) + + def optim_step(self, adam_params): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={})) + + def save_state(self, name: str, ttl_seconds=None): # noqa: ARG002 + return _Future(types.SimpleNamespace(path=f"tinker://run-123/weights/{name}")) + + def get_info(self): + return types.SimpleNamespace(model_id="run-123") + + class _ServiceClient: + resume_path = None + + def __init__(self, **kwargs): # noqa: ARG002 + pass + + def create_lora_training_client(self, base_model: str, rank: int): # noqa: ARG002 + return _TrainingClient() + + def create_training_client_from_state_with_optimizer(self, path: str): + _ServiceClient.resume_path = path + return _TrainingClient() + + class _Datum: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _ModelInput: + @classmethod + def from_ints(cls, tokens): # noqa: ARG003 + return cls() + + class _TensorData: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _AdamParams: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + fake_tinker = types.ModuleType("tinker") + fake_tinker.Client = None + fake_tinker.ServiceClient = _ServiceClient + fake_tinker.Datum = _Datum + fake_tinker.ModelInput = _ModelInput + fake_tinker.TensorData = _TensorData + fake_tinker.AdamParams = _AdamParams + + monkeypatch.setattr(train_tinker.TinkerClient, "_check_sdk", lambda self: True) + monkeypatch.setitem(sys.modules, "tinker", fake_tinker) + + config = train_tinker.TinkerTrainingConfig( + dataset_path=data_dir, + output_dir=output_dir, + wait_for_completion=True, + epochs=1, + batch_size=1, + log_interval_steps=1, + ) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + + assert result.success is True + assert _ServiceClient.resume_path == "tinker://run-123/weights/step-000050" + + +def test_train_on_tinker_service_client_checkpoint_persists_completed_steps( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + data_dir = tmp_path / "training" + data_dir.mkdir() + (data_dir / "train.jsonl").write_text( + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u"},{"role":"assistant","content":"a"}]}\n', + encoding="utf-8", + ) + + class _Future: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + class _Tokenizer: + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + tokens = [1 for _ in text] + if add_special_tokens: + return [0] + tokens + return tokens + + class _TrainingClient: + def get_tokenizer(self): + return _Tokenizer() + + def forward_backward(self, data, loss_fn): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={"loss": 0.5}, loss_fn_outputs=[])) + + def optim_step(self, adam_params): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={})) + + def save_state(self, name: str, ttl_seconds=None): # noqa: ARG002 + if name == "final": + raise RuntimeError("final save failed") + return _Future(types.SimpleNamespace(path=f"tinker://run-123/weights/{name}")) + + def get_info(self): + return types.SimpleNamespace(model_id="run-123") + + class _ServiceClient: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + def create_lora_training_client(self, base_model: str, rank: int): # noqa: ARG002 + return _TrainingClient() + + class _Datum: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _ModelInput: + @classmethod + def from_ints(cls, tokens): # noqa: ARG003 + return cls() + + class _TensorData: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _AdamParams: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + fake_tinker = types.ModuleType("tinker") + fake_tinker.Client = None + fake_tinker.ServiceClient = _ServiceClient + fake_tinker.Datum = _Datum + fake_tinker.ModelInput = _ModelInput + fake_tinker.TensorData = _TensorData + fake_tinker.AdamParams = _AdamParams + + monkeypatch.setattr(train_tinker.TinkerClient, "_check_sdk", lambda self: True) + monkeypatch.setitem(sys.modules, "tinker", fake_tinker) + + output_dir = tmp_path / "out" + config = train_tinker.TinkerTrainingConfig( + dataset_path=data_dir, + output_dir=output_dir, + wait_for_completion=True, + epochs=1, + batch_size=1, + log_interval_steps=10, + checkpoint_interval_steps=1, + ) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + + assert result.success is False + saved_state = json.loads((output_dir / "tinker_run.json").read_text(encoding="utf-8")) + assert saved_state["completed_steps"] == 1 + assert saved_state["latest_checkpoint_path"].endswith("step-000001") + + +def test_train_on_tinker_service_client_resets_progress_without_resume_support( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + data_dir = tmp_path / "training" + data_dir.mkdir() + (data_dir / "train.jsonl").write_text( + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u"},{"role":"assistant","content":"a"}]}\n', + encoding="utf-8", + ) + + output_dir = tmp_path / "out" + output_dir.mkdir() + (output_dir / "tinker_run.json").write_text( + json.dumps( + { + "sdk_mode": "service_client", + "latest_checkpoint_path": "tinker://run-123/weights/step-000001", + "completed_steps": 1, + } + ), + encoding="utf-8", + ) + + class _Future: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + class _Tokenizer: + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + tokens = [1 for _ in text] + if add_special_tokens: + return [0] + tokens + return tokens + + class _TrainingClient: + train_steps = 0 + + def get_tokenizer(self): + return _Tokenizer() + + def forward_backward(self, data, loss_fn): # noqa: ARG002 + _TrainingClient.train_steps += 1 + return _Future(types.SimpleNamespace(metrics={"loss": 0.5}, loss_fn_outputs=[])) + + def optim_step(self, adam_params): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={})) + + def save_state(self, name: str, ttl_seconds=None): # noqa: ARG002 + return _Future(types.SimpleNamespace(path=f"tinker://run-123/weights/{name}")) + + def get_info(self): + return types.SimpleNamespace(model_id="run-123") + + class _ServiceClient: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + def create_lora_training_client(self, base_model: str, rank: int): # noqa: ARG002 + return _TrainingClient() + + class _Datum: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _ModelInput: + @classmethod + def from_ints(cls, tokens): # noqa: ARG003 + return cls() + + class _TensorData: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _AdamParams: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + fake_tinker = types.ModuleType("tinker") + fake_tinker.Client = None + fake_tinker.ServiceClient = _ServiceClient + fake_tinker.Datum = _Datum + fake_tinker.ModelInput = _ModelInput + fake_tinker.TensorData = _TensorData + fake_tinker.AdamParams = _AdamParams + + monkeypatch.setattr(train_tinker.TinkerClient, "_check_sdk", lambda self: True) + monkeypatch.setitem(sys.modules, "tinker", fake_tinker) + + config = train_tinker.TinkerTrainingConfig( + dataset_path=data_dir, + output_dir=output_dir, + wait_for_completion=True, + epochs=2, + batch_size=1, + log_interval_steps=1, + ) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + + assert result.success is True + assert _TrainingClient.train_steps == 2 + + +def test_train_on_tinker_service_client_rejects_zero_log_interval( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + data_dir = tmp_path / "training" + data_dir.mkdir() + (data_dir / "train.jsonl").write_text( + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u"},{"role":"assistant","content":"a"}]}\n', + encoding="utf-8", + ) + + class _ServiceModeClient: + def __init__(self, api_key: str | None = None): + self.api_key = api_key or "" + + @property + def is_available(self) -> bool: + return True + + def has_legacy_client_api(self) -> bool: + return False + + def has_service_client_api(self) -> bool: + return True + + monkeypatch.setattr(train_tinker, "TinkerClient", _ServiceModeClient) + + config = train_tinker.TinkerTrainingConfig( + dataset_path=data_dir, + output_dir=tmp_path / "out", + log_interval_steps=0, + ) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + + assert result.success is False + assert result.error == "log_interval_steps must be greater than 0"