
Commit 01724f3

anthony-maio and claude committed
Fix TTT: use eval_model (int6 artifact) not base_model, honor EVAL_STRIDE
P1: TTT was running on the pre-quantization base_model instead of the int6 round-tripped eval_model. This overstated TTT gains, since the artifact model carries quantization noise. Now matches PR openai#473's approach.

P2: TTT hardcoded stride=64 instead of using args.eval_stride. Now honors the configured stride so TTT results stay consistent with the sliding-window eval path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
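The stride handling in P2 amounts to a fall-back rule: use the configured eval stride when it is set to a positive value, otherwise keep the old hardcoded default of 64. A minimal sketch of that rule as a standalone helper (the function name `resolve_stride` is hypothetical, not part of the repo):

```python
def resolve_stride(eval_stride: int, default: int = 64) -> int:
    """Return the configured sliding-window stride, falling back to the
    old hardcoded default (64) when eval_stride is unset or non-positive.

    Mirrors the inline expression in the patch:
        stride=args.eval_stride if args.eval_stride > 0 else 64
    """
    return eval_stride if eval_stride > 0 else default
```

With this rule, a run configured with `--eval_stride 128` evaluates TTT at stride 128, matching the sliding-window eval path, while a run that never sets the flag behaves exactly as before.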
1 parent c3442df commit 01724f3

1 file changed

Lines changed: 3 additions & 2 deletions

records/track_10min_16mb/2026-03-23_Reproduce414_LegalTTT/train_gpt.py
```diff
@@ -1564,9 +1564,10 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     torch.cuda.synchronize()
     t_ttt = time.perf_counter()
     ttt_val_loss, ttt_val_bpb = eval_val_sliding_ttt(
-        args, base_model, rank, world_size, device,
+        args, eval_model, rank, world_size, device,
         val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
-        stride=64, batch_seqs=args.ttt_batch_seqs, log0=log0,
+        stride=args.eval_stride if args.eval_stride > 0 else 64,
+        batch_seqs=args.ttt_batch_seqs, log0=log0,
     )
     torch.cuda.synchronize()
     log0(f"final_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} "
```
