
Commit b5835f1

Chidera Ibe authored and claude committed
Revert DyT to RMSNorm + SGD momentum SLOT (novel eval improvement)
DyT hurt on the sliding eval (1.1307 vs 1.1263), so revert to RMSNorm. Novel change: replace AdamW with SGD + momentum(0.9) for SLOT optimization. PR openai#995 showed SGD + momentum beats AdamW for TTT by 0.036 BPB; nobody has tried SGD for SLOT specifically.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ee34ab6 · commit b5835f1

1 file changed: train_gpt.py (4 additions, 3 deletions)
@@ -540,8 +540,8 @@ def __init__(
         dtg: bool = False,
     ):
         super().__init__()
-        self.attn_norm = DyT(dim)  # DyT replaces RMSNorm (arXiv:2503.10622)
-        self.mlp_norm = DyT(dim)
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
         self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
         self.mlp = MLP(dim, mlp_mult)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
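The hunk above swaps DyT back for a parameter-free RMSNorm (the diff calls `RMSNorm()` with no arguments). A minimal runnable sketch of what such a module could look like, under the assumption that the repo's variant has no learnable gain; this is illustrative, not the repo's actual implementation:

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Sketch of a parameter-free RMSNorm (assumed variant; no learnable gain)."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each vector by the reciprocal of its root-mean-square
        # over the last (feature) dimension.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


torch.manual_seed(0)
x = torch.randn(2, 4, 8)
y = RMSNorm()(x)
```

Unlike LayerNorm, RMSNorm does not subtract the mean, and this no-gain variant adds zero parameters, which is why the diff can construct it without passing `dim`.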
@@ -874,7 +874,8 @@ def eval_val_sliding_slot(
         valid_count = mask.sum()
         delta = torch.zeros(bsz, 1, hidden_f.size(-1), device=device, dtype=torch.float32, requires_grad=True)
         logit_bias = torch.zeros(bsz, 1, proj_w.size(0), device=device, dtype=torch.float32, requires_grad=True)
-        slot_opt = torch.optim.AdamW([delta, logit_bias], lr=slot_lr)
+        # SGD+momentum for SLOT (inspired by PR #995: SGD beats AdamW for TTT)
+        slot_opt = torch.optim.SGD([delta, logit_bias], lr=slot_lr, momentum=0.9)
         targets_flat = y_batch.reshape(-1)
         for _step in range(slot_steps):
             _lr = slot_lr_min + 0.5 * (slot_lr - slot_lr_min) * (1 + math.cos(math.pi * _step / slot_steps))
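To make the change above concrete, here is a self-contained sketch of a SLOT-style inner loop with the new `torch.optim.SGD(momentum=0.9)` optimizer and the cosine learning-rate decay visible in the hunk. Shapes and the loss computation (`hidden_f + delta` projected through a frozen `proj_w`, plus `logit_bias`) are illustrative stand-ins for the real `eval_val_sliding_slot` locals, not the repo's exact code:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative stand-in shapes and hyperparameters (not the repo's values).
bsz, seq, hidden, vocab = 2, 5, 16, 32
slot_lr, slot_lr_min, slot_steps = 1e-2, 1e-3, 3

hidden_f = torch.randn(bsz, seq, hidden)   # frozen final hidden states
proj_w = torch.randn(vocab, hidden)        # frozen output projection
targets_flat = torch.randint(0, vocab, (bsz * seq,))

# Test-time parameters: a per-sequence hidden-state delta and a logit bias,
# broadcast across the sequence dimension.
delta = torch.zeros(bsz, 1, hidden, requires_grad=True)
logit_bias = torch.zeros(bsz, 1, vocab, requires_grad=True)
slot_opt = torch.optim.SGD([delta, logit_bias], lr=slot_lr, momentum=0.9)

for _step in range(slot_steps):
    # Cosine decay from slot_lr down to slot_lr_min, as in the diff.
    _lr = slot_lr_min + 0.5 * (slot_lr - slot_lr_min) * (
        1 + math.cos(math.pi * _step / slot_steps)
    )
    for group in slot_opt.param_groups:
        group["lr"] = _lr

    logits = (hidden_f + delta) @ proj_w.T + logit_bias
    loss = F.cross_entropy(logits.reshape(-1, vocab), targets_flat)
    slot_opt.zero_grad()
    loss.backward()
    slot_opt.step()
```

Since only two small tensors are optimized for a handful of steps, SGD's momentum buffer is the sole optimizer state, whereas AdamW would also track per-element second moments; the commit's claim is that in this few-step test-time regime the simpler update works better.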

0 commit comments