6 changes: 6 additions & 0 deletions configs/training.yaml
@@ -69,6 +69,12 @@ cloud:
  learning_rate: 2.0e-4
  warmup_ratio: 0.03
  max_seq_length: 2048
  log_interval_steps: 10
  checkpoint_interval_steps: 250
  eval_interval_steps: 0 # 0 disables mid-epoch evals
  eval_at_epoch_end: true
  checkpoint_ttl_seconds: null
  resume_from_checkpoint: true

# === OUTPUT PATHS ===
output:
84 changes: 51 additions & 33 deletions docs/SETUP.md
@@ -122,50 +122,66 @@ TINKER_API_KEY=tk_...

### 2.2 Training with Tinker

```python
# src/training/train_tinker.py
import os
from tinker import ServiceClient
from pathlib import Path

from src.training.train_tinker import TinkerTrainingConfig, run_training_loop, write_run_metadata

# Initialize SDK clients
service_client = ServiceClient(api_key=os.environ["TINKER_API_KEY"])
training_client = service_client.create_lora_training_client(
    base_model="Qwen/Qwen3-8B",
)

# Configure training
config = TinkerTrainingConfig(
    base_model="Qwen/Qwen3-8B",
    epochs=3,
)

# Run training loop and persist metadata
metadata = run_training_loop(training_client, config)
metadata_path = write_run_metadata(metadata, output_dir=Path("models/adapters/tinker"))
print(f"Run ID: {metadata.run_id}")
print(f"Run metadata: {metadata_path}")
Use the CLI wrapper for end-to-end training:

```bash
python scripts/train_tinker.py \
--config configs/training.yaml \
--output models/adapters/tinker
```

ServiceClient mode writes run artifacts directly to the `--output` directory:

- `tinker_run.json` (resume state + checkpoint history)
- `run.json` (MLflow-compatible run metadata)
- `train.log` (train/val lines parsed by `scripts/mlflow_logger.py`)
- `metrics.jsonl` (structured metric events)

Auto-resume is enabled by default and uses `latest_checkpoint_path` from
`tinker_run.json`.
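
To verify what a run will resume from, you can read the state file by hand. The snippet below is a minimal sketch, assuming `tinker_run.json` is plain JSON containing the `latest_checkpoint_path` field described above and that `metrics.jsonl` holds one JSON object per line; the exact schema may differ.

```python
# Hypothetical helper -- not part of the repo; file layouts are assumptions.
import json
from pathlib import Path

run_dir = Path("models/adapters/tinker")

# Resume state: auto-resume picks up latest_checkpoint_path from tinker_run.json.
run_state = json.loads((run_dir / "tinker_run.json").read_text())
print("Would resume from:", run_state.get("latest_checkpoint_path"))

# Structured metric events: assumed to be one JSON object per line.
with (run_dir / "metrics.jsonl").open() as fh:
    for line in fh:
        print(json.loads(line))
```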

### 2.3 Tinker CLI Workflow

```bash
# Start training with config defaults
python scripts/train_tinker.py \
--config configs/training.yaml \
--output models/adapters/tinker

# Customize training telemetry/checkpoints
python scripts/train_tinker.py \
--config configs/training.yaml \
--output models/adapters/tinker \
--log-interval-steps 10 \
--checkpoint-interval-steps 250 \
--eval-interval-steps 100

# Disable auto-resume for a fresh run in same output dir
python scripts/train_tinker.py \
--config configs/training.yaml \
--output models/adapters/tinker \
--no-resume

# Check status (legacy Client API mode)
python scripts/train_tinker.py --status <job-id>

# Inspect service-mode run artifacts
cat models/adapters/tinker/tinker_run.json
cat models/adapters/tinker/run.json

# Send artifacts to MLflow/DagsHub
python scripts/mlflow_logger.py \
--run-dir models/adapters/tinker \
--experiment-name "compression-v2" \
--dagshub-owner Sudhendra \
--dagshub-repo compression-layer
```

Additional useful flags:

- `--no-eval-at-epoch-end`
- `--checkpoint-ttl-seconds <seconds>`
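
For a rough idea of what the MLflow upload involves, the sketch below replays `metrics.jsonl` into an MLflow run with the plain MLflow client. It is illustrative only, not the actual `scripts/mlflow_logger.py`; the event field names (`name`, `value`, `step`) are assumptions about the metric events, and the DagsHub tracking URI and credentials setup is omitted.

```python
# Illustrative sketch -- prefer scripts/mlflow_logger.py for real runs.
import json
from pathlib import Path

import mlflow

run_dir = Path("models/adapters/tinker")

mlflow.set_experiment("compression-v2")
with mlflow.start_run():
    # Replay structured metric events (assumed fields: name, value, step).
    with (run_dir / "metrics.jsonl").open() as fh:
        for line in fh:
            event = json.loads(line)
            mlflow.log_metric(event["name"], event["value"], step=event.get("step", 0))
    # Attach the raw training log and run metadata as artifacts.
    mlflow.log_artifact(str(run_dir / "train.log"))
    mlflow.log_artifact(str(run_dir / "run.json"))
```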

### 2.4 Cost Estimation

| Model | Per 1M Tokens | 10K pairs (~5M tok) | 50K pairs (~25M tok) |
@@ -258,8 +274,10 @@ python scripts/validate_batch.py --input data/seed/pairs.jsonl
### Tinker job failed
- Check dataset format (JSONL with `text` or `messages` field)
- Verify API key in `.env` or shell: `TINKER_API_KEY`
- Inspect run metadata: `models/adapters/tinker/tinker_run.json`
- Inspect MLflow metadata and logs: `models/adapters/tinker/run.json`, `models/adapters/tinker/train.log`
- Check a legacy Client API job: `python scripts/train_tinker.py --status <job-id>`
- Re-run the same command to resume from the latest checkpoint (auto-resume is the default)
- Add `--no-resume` to force a fresh run in an existing output directory

### Slow local inference
- Ensure you are using the 4-bit model: `*-4bit`
45 changes: 45 additions & 0 deletions scripts/train_tinker.py
@@ -144,6 +144,40 @@ def parse_args() -> argparse.Namespace:
        default=128,
        help="LoRA alpha (default: 128)",
    )
    parser.add_argument(
        "--log-interval-steps",
        type=int,
        default=10,
        help="How often to log train metrics (default: 10)",
    )
    parser.add_argument(
        "--checkpoint-interval-steps",
        type=int,
        default=250,
        help="How often to save resumable checkpoints (default: 250)",
    )
    parser.add_argument(
        "--eval-interval-steps",
        type=int,
        default=0,
        help="How often to run validation during an epoch (0 disables, default: 0)",
    )
    parser.add_argument(
        "--no-eval-at-epoch-end",
        action="store_true",
        help="Disable validation pass at epoch end",
    )
    parser.add_argument(
        "--checkpoint-ttl-seconds",
        type=int,
        default=None,
        help="Optional TTL for saved checkpoints in seconds",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Disable auto-resume from latest checkpoint in output directory",
    )

    # Job control
    parser.add_argument(
@@ -207,6 +241,11 @@ def print_config(config: TinkerTrainingConfig) -> None:
table.add_row("Learning Rate", f"{config.learning_rate:.0e}")
table.add_row("LoRA Rank", str(config.lora.rank))
table.add_row("LoRA Alpha", str(config.lora.alpha))
table.add_row("Log Interval", str(config.log_interval_steps))
table.add_row("Checkpoint Every", str(config.checkpoint_interval_steps))
table.add_row("Eval Every", str(config.eval_interval_steps))
table.add_row("Eval Epoch End", str(config.eval_at_epoch_end))
table.add_row("Auto Resume", str(config.resume_from_checkpoint))

console.print(table)

@@ -240,6 +279,12 @@ def main() -> int:
    config.lora = TinkerLoRAConfig(rank=args.lora_rank, alpha=args.lora_alpha)
    config.wait_for_completion = not args.no_wait
    config.dataset_name = args.dataset_name
    config.log_interval_steps = args.log_interval_steps
    config.checkpoint_interval_steps = args.checkpoint_interval_steps
    config.eval_interval_steps = args.eval_interval_steps
    config.eval_at_epoch_end = not args.no_eval_at_epoch_end
    config.checkpoint_ttl_seconds = args.checkpoint_ttl_seconds
    config.resume_from_checkpoint = not args.no_resume

    # Handle status check
    if args.status: