217 changes: 213 additions & 4 deletions src/training/train_tinker.py
@@ -16,7 +16,9 @@
result = train_on_tinker(config)
"""

import json
import logging
import math
import os
import time
from dataclasses import dataclass, field
@@ -25,6 +27,8 @@

import yaml

from src.training.tinker_data import render_chat_example

logger = logging.getLogger(__name__)


@@ -150,6 +154,20 @@ def _check_sdk(self) -> bool:
except ImportError:
return False

def _load_tinker_module(self):
"""Import and return the installed tinker module."""
return __import__("tinker")

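    # Some SDK builds expose the legacy job-style Client API, others the newer
    # ServiceClient API; these probes let callers branch without assuming
    # either SDK generation is installed.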
def has_legacy_client_api(self) -> bool:
"""Whether SDK exposes legacy Client(api_key=...) API."""
tinker = self._load_tinker_module()
return callable(getattr(tinker, "Client", None))

def has_service_client_api(self) -> bool:
"""Whether SDK exposes ServiceClient API."""
tinker = self._load_tinker_module()
return callable(getattr(tinker, "ServiceClient", None))

@property
def is_available(self) -> bool:
"""Check if Tinker is properly configured."""
@@ -169,7 +187,14 @@ def upload_dataset(self, data_dir: Path, name: str) -> str:
if not self._sdk_available:
raise RuntimeError("Tinker SDK not available. Install with: pip install tinker")

-        tinker = __import__("tinker")  # Dynamic import to avoid static analysis errors
+        tinker = self._load_tinker_module()  # Dynamic import to avoid static analysis errors

if not self.has_legacy_client_api():
raise RuntimeError(
"Installed tinker SDK does not expose Client API required for job-style upload. "
"Use scripts/train_tinker.py in ServiceClient mode (default in this repo) or "
"install the legacy SDK variant."
)

client = tinker.Client(api_key=self.api_key)

@@ -197,7 +222,12 @@ def start_training(
if not self._sdk_available:
raise RuntimeError("Tinker SDK not available. Install with: pip install tinker")

-        tinker = __import__("tinker")
+        tinker = self._load_tinker_module()

if not self.has_legacy_client_api():
raise RuntimeError(
"Installed tinker SDK does not expose Client API required for job-style training."
)

client = tinker.Client(api_key=self.api_key)

@@ -228,7 +258,12 @@ def get_job_status(self, job_id: str) -> TinkerJobStatus:
if not self._sdk_available:
raise RuntimeError("Tinker SDK not available")

-        tinker = __import__("tinker")
+        tinker = self._load_tinker_module()

if not self.has_legacy_client_api():
raise RuntimeError(
"Installed tinker SDK does not expose Client API required for job status polling."
)

client = tinker.Client(api_key=self.api_key)
job = client.get_job(job_id)
@@ -290,7 +325,12 @@ def download_adapter(self, job_id: str, output_path: Path) -> Path:
if not self._sdk_available:
raise RuntimeError("Tinker SDK not available")

-        tinker = __import__("tinker")
+        tinker = self._load_tinker_module()

if not self.has_legacy_client_api():
raise RuntimeError(
"Installed tinker SDK does not expose Client API required for adapter download."
)

client = tinker.Client(api_key=self.api_key)
job = client.get_job(job_id)
@@ -345,6 +385,14 @@ def train_on_tinker(
error=f"Training file not found: {train_file}",
)

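    # Newer SDK builds drop the job-style Client in favor of ServiceClient;
    # when only ServiceClient is available, train synchronously in-process
    # instead of submitting a remote job.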
if client.has_service_client_api() and not client.has_legacy_client_api():
if not config.wait_for_completion:
return TinkerTrainingResult(
success=False,
error="--no-wait is not supported with the installed ServiceClient SDK mode.",
)
return _train_with_service_client_sdk(config, api_key=client.api_key)

try:
# Generate dataset name if not provided
dataset_name = config.dataset_name or f"compression-{int(time.time())}"
@@ -399,6 +447,167 @@
)


def _to_sdk_datum(local_datum: Any, tinker_module: Any) -> Any:
"""Convert local tinker_data datum to SDK Datum."""
target_tokens = [int(token) for token in local_datum.loss_fn_inputs["target_tokens"]]
weights = [float(weight) for weight in local_datum.loss_fn_inputs["weights"]]

return tinker_module.Datum(
model_input=tinker_module.ModelInput.from_ints(local_datum.model_input.tokens),
loss_fn_inputs={
"target_tokens": tinker_module.TensorData(data=target_tokens, dtype="int64"),
"weights": tinker_module.TensorData(data=weights, dtype="float32"),
},
)


def _iter_training_batches(
train_file: Path,
tokenizer: Any,
tinker_module: Any,
batch_size: int,
) -> Any:
"""Yield SDK-ready training batches from chat JSONL data."""
batch: list[Any] = []

with open(train_file, encoding="utf-8") as f:
for line_number, line in enumerate(f, start=1):
if not line.strip():
continue

try:
record = json.loads(line)
except json.JSONDecodeError:
logger.warning("Skipping invalid JSON at line %s", line_number)
continue

messages = record.get("messages") if isinstance(record, dict) else None
if not isinstance(messages, list) or not messages:
logger.warning("Skipping malformed record at line %s", line_number)
continue

local_datum = render_chat_example(messages, tokenizer)
batch.append(_to_sdk_datum(local_datum, tinker_module))

if len(batch) >= batch_size:
yield batch
batch = []

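    # Flush the trailing partial batch so the final examples still train.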
if batch:
yield batch


def _extract_loss(metrics: dict[str, float]) -> float | None:
"""Extract best-effort scalar loss value from metrics dict."""
for key, value in metrics.items():
if "loss" in key.lower():
return float(value)

if metrics:
first_value = next(iter(metrics.values()))
return float(first_value)

return None


def _train_with_service_client_sdk(
config: TinkerTrainingConfig,
api_key: str,
) -> TinkerTrainingResult:
"""Train via the modern Tinker ServiceClient/TrainingClient SDK flow."""
tinker = __import__("tinker")

service_client = tinker.ServiceClient(api_key=api_key)
training_client = service_client.create_lora_training_client(
base_model=config.model,
rank=config.lora.rank,
)
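    # Use the training client's tokenizer so locally rendered chat examples
    # match the base model's server-side tokenization.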
tokenizer = training_client.get_tokenizer()

train_file = config.dataset_path / "train.jsonl"
with open(train_file, encoding="utf-8") as train_stream:
total_examples = sum(1 for line in train_stream if line.strip())
if total_examples == 0:
return TinkerTrainingResult(
success=False, error=f"No training examples found in {train_file}"
)

steps_per_epoch = math.ceil(total_examples / config.batch_size)
total_steps = steps_per_epoch * config.epochs
current_step = 0
final_loss: float | None = None

logger.info(
"Starting ServiceClient training: %s examples, %s epochs, %s total steps",
total_examples,
config.epochs,
total_steps,
)

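    # Re-stream the JSONL file each epoch; batches are built lazily, so the
    # dataset never has to fit in memory (file order is preserved, no shuffling).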
for epoch in range(1, config.epochs + 1):
for batch in _iter_training_batches(
train_file,
tokenizer,
tinker,
config.batch_size,
):
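            # One optimization step per batch: forward_backward computes
            # gradients for the cross-entropy loss, optim_step applies an Adam
            # update, and .result() blocks until each call completes.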
fwdbwd = training_client.forward_backward(batch, "cross_entropy").result()
training_client.optim_step(
tinker.AdamParams(learning_rate=config.learning_rate),
).result()

current_step += 1
metrics = getattr(fwdbwd, "metrics", {}) or {}
step_loss = _extract_loss(metrics)
if step_loss is not None:
final_loss = step_loss

if current_step % 10 == 0 or current_step == total_steps:
logger.info(
"Epoch %s/%s Step %s/%s Loss %s",
epoch,
config.epochs,
current_step,
total_steps,
f"{step_loss:.4f}" if step_loss is not None else "n/a",
)

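    # Persist final weights server-side; getattr keeps response parsing
    # tolerant of minor shape differences across SDK versions.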
save_name = config.dataset_name or f"compression-{int(time.time())}"
save_response = training_client.save_state(save_name).result()
checkpoint_path = str(getattr(save_response, "path", ""))
info = training_client.get_info()
training_run_id = str(getattr(info, "model_id", "unknown"))

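    # ServiceClient mode leaves weights in Tinker-hosted storage, so record
    # the tinker:// checkpoint path and run settings locally for later steps.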
config.output_dir.mkdir(parents=True, exist_ok=True)
metadata_path = config.output_dir / "tinker_run.json"
metadata_path.write_text(
json.dumps(
{
"training_run_id": training_run_id,
"checkpoint_path": checkpoint_path,
"model": config.model,
"epochs": config.epochs,
"batch_size": config.batch_size,
"learning_rate": config.learning_rate,
"lora_rank": config.lora.rank,
"final_loss": final_loss,
"sdk_mode": "service_client",
},
indent=2,
),
encoding="utf-8",
)

return TinkerTrainingResult(
success=True,
job_id=training_run_id,
adapter_path=config.output_dir,
final_loss=final_loss,
total_epochs=config.epochs,
metrics={"checkpoint_path": checkpoint_path, "metadata_path": str(metadata_path)},
)


def load_config_from_yaml(config_path: Path) -> TinkerTrainingConfig:
"""
Load Tinker training config from YAML file.
101 changes: 101 additions & 0 deletions tests/test_train_tinker.py
@@ -1,3 +1,5 @@
import sys
import types
from pathlib import Path

import pytest
@@ -17,6 +19,12 @@ def is_available(self) -> bool:
def is_available(self) -> bool:
return True

def has_legacy_client_api(self) -> bool:
return True

def has_service_client_api(self) -> bool:
return False

def upload_dataset(self, data_dir: Path, name: str) -> str:
return "dataset-123"

@@ -48,3 +56,96 @@ def test_train_on_tinker_uses_api_key(tmp_path: Path, monkeypatch: pytest.MonkeyPatch

assert result.success is True
assert FakeClient.last_api_key == "test-key"


def test_train_on_tinker_supports_service_client_sdk(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
data_dir = tmp_path / "training"
data_dir.mkdir()
(data_dir / "train.jsonl").write_text(
'{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u"},{"role":"assistant","content":"a"}]}\n',
encoding="utf-8",
)

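    # Minimal fakes covering only the SDK surface the ServiceClient path uses.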
class _Future:
def __init__(self, value):
self._value = value

def result(self):
return self._value

class _Tokenizer:
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
tokens = [1 for _ in text]
if add_special_tokens:
return [0] + tokens
return tokens

class _TrainingClient:
def get_tokenizer(self):
return _Tokenizer()

def forward_backward(self, data, loss_fn): # noqa: ARG002
return _Future(types.SimpleNamespace(metrics={"loss": 0.5}, loss_fn_outputs=[]))

def optim_step(self, adam_params): # noqa: ARG002
return _Future(types.SimpleNamespace(metrics={}))

def save_state(self, name: str, ttl_seconds=None): # noqa: ARG002
return _Future(types.SimpleNamespace(path="tinker://run-123/weights/final"))

def get_info(self):
return types.SimpleNamespace(model_id="run-123")

class _ServiceClient:
init_kwargs = None

def __init__(self, **kwargs): # noqa: ARG002
_ServiceClient.init_kwargs = kwargs

def create_lora_training_client(self, base_model: str, rank: int): # noqa: ARG002
return _TrainingClient()

class _Datum:
def __init__(self, **kwargs): # noqa: ARG002
pass

class _ModelInput:
@classmethod
def from_ints(cls, tokens): # noqa: ARG003
return cls()

class _TensorData:
def __init__(self, **kwargs): # noqa: ARG002
pass

class _AdamParams:
def __init__(self, **kwargs): # noqa: ARG002
pass

fake_tinker = types.ModuleType("tinker")
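    # Client=None makes has_legacy_client_api() return False, which forces
    # train_on_tinker down the ServiceClient path.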
fake_tinker.Client = None
fake_tinker.ServiceClient = _ServiceClient
fake_tinker.Datum = _Datum
fake_tinker.ModelInput = _ModelInput
fake_tinker.TensorData = _TensorData
fake_tinker.AdamParams = _AdamParams

monkeypatch.setattr(train_tinker.TinkerClient, "_check_sdk", lambda self: True)
monkeypatch.setitem(sys.modules, "tinker", fake_tinker)

config = train_tinker.TinkerTrainingConfig(
dataset_path=data_dir,
output_dir=tmp_path / "out",
wait_for_completion=True,
epochs=1,
batch_size=1,
)

result = train_tinker.train_on_tinker(config, api_key="test-key")

assert result.success is True
assert result.job_id == "run-123"
assert _ServiceClient.init_kwargs is not None
assert _ServiceClient.init_kwargs.get("api_key") == "test-key"