Skip to content

Commit a0e0841

Browse files
Octavian authored and claude committed
Add AdamW TTT option — PR #462 shows 5x better TTT gain vs SGD
PR #462 achieves 1.0672 BPB. Their key finding: switching the TTT optimizer from SGD to AdamW gives 5x more improvement (0.053 vs 0.011 BPB). AdamW's per-parameter adaptive LR handles the heterogeneous update needs of attention/MLP/control params naturally — exactly what we were trying to do manually.

New defaults (matching the PR #462 recipe):
- TTT_OPTIMIZER=adamw (was implicit SGD)
- TTT_LR=0.0005 (was 0.002)
- TTT_EPOCHS=10 (was 3)
- TTT_FREEZE_BLOCKS=0 (was 2)

Fallback to SGD: TTT_OPTIMIZER=sgd TTT_LR=0.002 TTT_EPOCHS=3

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8b22b10 commit a0e0841

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

train_gpt_v7.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,11 @@ class Hyperparameters:
8686
ve_layers = os.environ.get("VE_LAYERS", "9,10")
8787
# Legal score-first TTT eval (PR #461 recipe)
8888
ttt_eval_enabled = bool(int(os.environ.get("TTT_EVAL_ENABLED", "1")))
89-
ttt_lr = float(os.environ.get("TTT_LR", 0.002))
90-
ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3))
89+
ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") # "sgd" or "adamw" (PR #462: AdamW 5x better)
90+
ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) # 0.0005 for AdamW (PR #462), 0.002 for SGD
91+
ttt_epochs = int(os.environ.get("TTT_EPOCHS", 10)) # 10 for AdamW (PR #462), 3 for SGD
9192
ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768))
92-
ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2))
93+
ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) # PR #462 freezes 0
9394
ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
9495
ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
9596
ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0))
@@ -928,7 +929,12 @@ def eval_val_sliding_ttt(
928929
else:
929930
p.requires_grad_(True); ttt_params.append(p)
930931
log0(f"ttt_sliding:unfrozen={sum(p.numel() for p in ttt_params)} freeze_embed={args.ttt_freeze_embed}")
931-
optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum)
932+
if args.ttt_optimizer == "adamw":
933+
optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0)
934+
log0(f"ttt_sliding:optimizer=AdamW lr={args.ttt_lr}")
935+
else:
936+
optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum)
937+
log0(f"ttt_sliding:optimizer=SGD lr={args.ttt_lr} momentum={args.ttt_momentum}")
932938
# TTT-EMA: maintain smoothed weights for scoring
933939
ema_decay = args.ttt_ema_decay
934940
ema_state = None

0 commit comments

Comments (0)