diff --git a/src/training/train_tinker.py b/src/training/train_tinker.py index 7fde9bd..b9762f3 100644 --- a/src/training/train_tinker.py +++ b/src/training/train_tinker.py @@ -16,15 +16,20 @@ result = train_on_tinker(config) """ +import json import logging +import math import os import time from dataclasses import dataclass, field from pathlib import Path +from types import ModuleType from typing import Any import yaml +from src.training.tinker_data import render_chat_example + logger = logging.getLogger(__name__) @@ -150,6 +155,20 @@ def _check_sdk(self) -> bool: except ImportError: return False + def _load_tinker_module(self) -> ModuleType: + """Import and return the installed tinker module.""" + return __import__("tinker") + + def has_legacy_client_api(self) -> bool: + """Whether SDK exposes legacy Client(api_key=...) API.""" + tinker = self._load_tinker_module() + return callable(getattr(tinker, "Client", None)) + + def has_service_client_api(self) -> bool: + """Whether SDK exposes ServiceClient API.""" + tinker = self._load_tinker_module() + return callable(getattr(tinker, "ServiceClient", None)) + @property def is_available(self) -> bool: """Check if Tinker is properly configured.""" @@ -169,7 +188,14 @@ def upload_dataset(self, data_dir: Path, name: str) -> str: if not self._sdk_available: raise RuntimeError("Tinker SDK not available. Install with: pip install tinker") - tinker = __import__("tinker") # Dynamic import to avoid static analysis errors + tinker = self._load_tinker_module() # Dynamic import to avoid static analysis errors + + if not self.has_legacy_client_api(): + raise RuntimeError( + "Installed tinker SDK does not expose Client API required for job-style upload. " + "Use scripts/train_tinker.py in ServiceClient mode (default in this repo) or " + "install the legacy SDK variant." + ) client = tinker.Client(api_key=self.api_key) @@ -197,7 +223,12 @@ def start_training( if not self._sdk_available: raise RuntimeError("Tinker SDK not available. Install with: pip install tinker") - tinker = __import__("tinker") + tinker = self._load_tinker_module() + + if not self.has_legacy_client_api(): + raise RuntimeError( + "Installed tinker SDK does not expose Client API required for job-style training." + ) client = tinker.Client(api_key=self.api_key) @@ -228,7 +259,12 @@ def get_job_status(self, job_id: str) -> TinkerJobStatus: if not self._sdk_available: raise RuntimeError("Tinker SDK not available") - tinker = __import__("tinker") + tinker = self._load_tinker_module() + + if not self.has_legacy_client_api(): + raise RuntimeError( + "Installed tinker SDK does not expose Client API required for job status polling." + ) client = tinker.Client(api_key=self.api_key) job = client.get_job(job_id) @@ -290,7 +326,12 @@ def download_adapter(self, job_id: str, output_path: Path) -> Path: if not self._sdk_available: raise RuntimeError("Tinker SDK not available") - tinker = __import__("tinker") + tinker = self._load_tinker_module() + + if not self.has_legacy_client_api(): + raise RuntimeError( + "Installed tinker SDK does not expose Client API required for adapter download." + ) client = tinker.Client(api_key=self.api_key) job = client.get_job(job_id) @@ -346,6 +387,14 @@ def train_on_tinker( ) try: + if client.has_service_client_api() and not client.has_legacy_client_api(): + if not config.wait_for_completion: + return TinkerTrainingResult( + success=False, + error="--no-wait is not supported with the installed ServiceClient SDK mode.", + ) + return _train_with_service_client_sdk(config, api_key=client.api_key) + # Generate dataset name if not provided dataset_name = config.dataset_name or f"compression-{int(time.time())}" @@ -399,6 +448,173 @@ def train_on_tinker( ) +def _to_sdk_datum(local_datum: Any, tinker_module: Any) -> Any: + """Convert local tinker_data datum to SDK Datum.""" + target_tokens = [int(token) for token in local_datum.loss_fn_inputs["target_tokens"]] + weights = [float(weight) for weight in local_datum.loss_fn_inputs["weights"]] + + return tinker_module.Datum( + model_input=tinker_module.ModelInput.from_ints(local_datum.model_input.tokens), + loss_fn_inputs={ + "target_tokens": tinker_module.TensorData(data=target_tokens, dtype="int64"), + "weights": tinker_module.TensorData(data=weights, dtype="float32"), + }, + ) + + +def _iter_training_batches( + train_file: Path, + tokenizer: Any, + tinker_module: Any, + batch_size: int, +) -> Any: + """Yield SDK-ready training batches from chat JSONL data.""" + batch: list[Any] = [] + + with open(train_file, encoding="utf-8") as f: + for line_number, line in enumerate(f, start=1): + if not line.strip(): + continue + + try: + record = json.loads(line) + except json.JSONDecodeError: + logger.warning("Skipping invalid JSON at line %s", line_number) + continue + + messages = record.get("messages") if isinstance(record, dict) else None + if not isinstance(messages, list) or not messages: + logger.warning("Skipping malformed record at line %s", line_number) + continue + + local_datum = render_chat_example(messages, tokenizer) + batch.append(_to_sdk_datum(local_datum, tinker_module)) + + if len(batch) >= batch_size: + yield batch + batch = [] + + if batch: + yield batch + + +def _extract_loss(metrics: dict[str, float]) -> float | None: + """Extract best-effort scalar loss value from metrics dict.""" + for key, value in metrics.items(): + if "loss" in key.lower(): + return float(value) + + if metrics: + first_value = next(iter(metrics.values())) + return float(first_value) + + return None + + +def _train_with_service_client_sdk( + config: TinkerTrainingConfig, + api_key: str, +) -> TinkerTrainingResult: + """Train via the modern Tinker ServiceClient/TrainingClient SDK flow.""" + tinker = __import__("tinker") + + service_client = tinker.ServiceClient(api_key=api_key) + training_client = service_client.create_lora_training_client( + base_model=config.model, + rank=config.lora.rank, + ) + tokenizer = training_client.get_tokenizer() + + train_file = config.dataset_path / "train.jsonl" + with open(train_file, encoding="utf-8") as train_stream: + total_examples = sum(1 for line in train_stream if line.strip()) + if total_examples == 0: + return TinkerTrainingResult( + success=False, error=f"No training examples found in {train_file}" + ) + + steps_per_epoch = math.ceil(total_examples / config.batch_size) + total_steps = steps_per_epoch * config.epochs + current_step = 0 + final_loss: float | None = None + + logger.info( + "Starting ServiceClient training: %s examples, %s epochs, %s total steps", + total_examples, + config.epochs, + total_steps, + ) + + for epoch in range(1, config.epochs + 1): + for batch in _iter_training_batches( + train_file, + tokenizer, + tinker, + config.batch_size, + ): + fwdbwd = training_client.forward_backward(batch, "cross_entropy").result() + training_client.optim_step( + tinker.AdamParams(learning_rate=config.learning_rate), + ).result() + + current_step += 1 + metrics = getattr(fwdbwd, "metrics", {}) or {} + step_loss = _extract_loss(metrics) + if step_loss is not None: + final_loss = step_loss + + if current_step % 10 == 0 or current_step == total_steps: + logger.info( + "Epoch %s/%s Step %s/%s Loss %s", + epoch, + config.epochs, + current_step, + total_steps, + f"{step_loss:.4f}" if step_loss is not None else "n/a", + ) + + if current_step == 0: + return TinkerTrainingResult( + success=False, + error="No valid training examples were found after parsing train.jsonl", + ) + + save_name = config.dataset_name or f"compression-{int(time.time())}" + save_response = training_client.save_state(save_name).result() + checkpoint_path = str(getattr(save_response, "path", "")) + info = training_client.get_info() + training_run_id = str(getattr(info, "model_id", "unknown")) + + config.output_dir.mkdir(parents=True, exist_ok=True) + metadata_path = config.output_dir / "tinker_run.json" + metadata_path.write_text( + json.dumps( + { + "training_run_id": training_run_id, + "checkpoint_path": checkpoint_path, + "model": config.model, + "epochs": config.epochs, + "batch_size": config.batch_size, + "learning_rate": config.learning_rate, + "lora_rank": config.lora.rank, + "final_loss": final_loss, + "sdk_mode": "service_client", + }, + indent=2, + ), + encoding="utf-8", + ) + + return TinkerTrainingResult( + success=True, + job_id=training_run_id, + adapter_path=config.output_dir, + final_loss=final_loss, + total_epochs=config.epochs, + metrics={"checkpoint_path": checkpoint_path, "metadata_path": str(metadata_path)}, + ) + + def load_config_from_yaml(config_path: Path) -> TinkerTrainingConfig: """ Load Tinker training config from YAML file. diff --git a/tests/test_train_tinker.py b/tests/test_train_tinker.py index ece8cdc..2d90913 100644 --- a/tests/test_train_tinker.py +++ b/tests/test_train_tinker.py @@ -1,3 +1,5 @@ +import sys +import types from pathlib import Path import pytest @@ -17,6 +19,12 @@ def __init__(self, api_key: str | None = None): def is_available(self) -> bool: return True + def has_legacy_client_api(self) -> bool: + return True + + def has_service_client_api(self) -> bool: + return False + def upload_dataset(self, data_dir: Path, name: str) -> str: return "dataset-123" @@ -48,3 +56,229 @@ def test_train_on_tinker_uses_api_key(tmp_path: Path, monkeypatch: pytest.Monkey assert result.success is True assert FakeClient.last_api_key == "test-key" + + +def test_train_on_tinker_supports_service_client_sdk( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + data_dir = tmp_path / "training" + data_dir.mkdir() + (data_dir / "train.jsonl").write_text( + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u"},{"role":"assistant","content":"a"}]}\n', + encoding="utf-8", + ) + + class _Future: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + class _Tokenizer: + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + tokens = [1 for _ in text] + if add_special_tokens: + return [0] + tokens + return tokens + + class _TrainingClient: + def get_tokenizer(self): + return _Tokenizer() + + def forward_backward(self, data, loss_fn): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={"loss": 0.5}, loss_fn_outputs=[])) + + def optim_step(self, adam_params): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={})) + + def save_state(self, name: str, ttl_seconds=None): # noqa: ARG002 + return _Future(types.SimpleNamespace(path="tinker://run-123/weights/final")) + + def get_info(self): + return types.SimpleNamespace(model_id="run-123") + + class _ServiceClient: + init_kwargs = None + + def __init__(self, **kwargs): # noqa: ARG002 + _ServiceClient.init_kwargs = kwargs + + def create_lora_training_client(self, base_model: str, rank: int): # noqa: ARG002 + return _TrainingClient() + + class _Datum: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _ModelInput: + @classmethod + def from_ints(cls, tokens): # noqa: ARG003 + return cls() + + class _TensorData: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _AdamParams: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + fake_tinker = types.ModuleType("tinker") + fake_tinker.Client = None + fake_tinker.ServiceClient = _ServiceClient + fake_tinker.Datum = _Datum + fake_tinker.ModelInput = _ModelInput + fake_tinker.TensorData = _TensorData + fake_tinker.AdamParams = _AdamParams + + monkeypatch.setattr(train_tinker.TinkerClient, "_check_sdk", lambda self: True) + monkeypatch.setitem(sys.modules, "tinker", fake_tinker) + + config = train_tinker.TinkerTrainingConfig( + dataset_path=data_dir, + output_dir=tmp_path / "out", + wait_for_completion=True, + epochs=1, + batch_size=1, + ) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + + assert result.success is True + assert result.job_id == "run-123" + assert _ServiceClient.init_kwargs is not None + assert _ServiceClient.init_kwargs.get("api_key") == "test-key" + + +def test_train_on_tinker_service_client_errors_are_captured( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + data_dir = tmp_path / "training" + data_dir.mkdir() + (data_dir / "train.jsonl").write_text( + '{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u"},{"role":"assistant","content":"a"}]}\n', + encoding="utf-8", + ) + + class _ServiceModeClient: + def __init__(self, api_key: str | None = None): + self.api_key = api_key or "" + + @property + def is_available(self) -> bool: + return True + + def has_legacy_client_api(self) -> bool: + return False + + def has_service_client_api(self) -> bool: + return True + + config = train_tinker.TinkerTrainingConfig(dataset_path=data_dir) + monkeypatch.setattr(train_tinker, "TinkerClient", _ServiceModeClient) + + def _raise_service_error(config, api_key): # noqa: ARG001 + raise RuntimeError("service path exploded") + + monkeypatch.setattr(train_tinker, "_train_with_service_client_sdk", _raise_service_error) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + + assert result.success is False + assert result.error is not None + assert "service path exploded" in result.error + + +def test_train_on_tinker_service_client_rejects_no_valid_examples( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + data_dir = tmp_path / "training" + data_dir.mkdir() + (data_dir / "train.jsonl").write_text( + 'not-json\n{"foo":"bar"}\n', + encoding="utf-8", + ) + + class _Future: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + class _Tokenizer: + def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: + tokens = [1 for _ in text] + if add_special_tokens: + return [0] + tokens + return tokens + + class _TrainingClient: + save_called = False + + def get_tokenizer(self): + return _Tokenizer() + + def forward_backward(self, data, loss_fn): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={"loss": 0.5}, loss_fn_outputs=[])) + + def optim_step(self, adam_params): # noqa: ARG002 + return _Future(types.SimpleNamespace(metrics={})) + + def save_state(self, name: str, ttl_seconds=None): # noqa: ARG002 + _TrainingClient.save_called = True + return _Future(types.SimpleNamespace(path="tinker://run-123/weights/final")) + + def get_info(self): + return types.SimpleNamespace(model_id="run-123") + + class _ServiceClient: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + def create_lora_training_client(self, base_model: str, rank: int): # noqa: ARG002 + return _TrainingClient() + + class _Datum: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _ModelInput: + @classmethod + def from_ints(cls, tokens): # noqa: ARG003 + return cls() + + class _TensorData: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + class _AdamParams: + def __init__(self, **kwargs): # noqa: ARG002 + pass + + fake_tinker = types.ModuleType("tinker") + fake_tinker.Client = None + fake_tinker.ServiceClient = _ServiceClient + fake_tinker.Datum = _Datum + fake_tinker.ModelInput = _ModelInput + fake_tinker.TensorData = _TensorData + fake_tinker.AdamParams = _AdamParams + + monkeypatch.setattr(train_tinker.TinkerClient, "_check_sdk", lambda self: True) + monkeypatch.setitem(sys.modules, "tinker", fake_tinker) + + config = train_tinker.TinkerTrainingConfig( + dataset_path=data_dir, + output_dir=tmp_path / "out", + wait_for_completion=True, + epochs=1, + batch_size=1, + ) + + result = train_tinker.train_on_tinker(config, api_key="test-key") + + assert result.success is False + assert result.error is not None + assert "No valid training examples" in result.error + assert _TrainingClient.save_called is False