
Commit c1fef72

Star-ReLU MLP (learnable scale+bias) + run_swiglu.sh for openai#462 architecture + cosine TTT
1 parent a2ed631 commit c1fef72

2 files changed: 63 additions & 3 deletions


records/track_10min_16mb/2026-03-21_11L_XSA_EMA_TTT/train_gpt.py

Lines changed: 6 additions & 3 deletions
@@ -880,17 +880,20 @@ def forward(self, x: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor
 
 
 class MLP(nn.Module):
-    # relu^2 MLP from the original modded-nanogpt setup
+    # Star-ReLU: relu(x)^2 with learnable per-channel scale and bias (PR #462)
     def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0):
         super().__init__()
         hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim
         self.fc = CastedLinear(dim, hidden, bias=False)
         self.proj = CastedLinear(hidden, dim, bias=False)
         self.proj._zero_init = True
+        self.scale = nn.Parameter(torch.ones(hidden, dtype=torch.float32))
+        self.bias = nn.Parameter(torch.zeros(hidden, dtype=torch.float32))
 
     def forward(self, x: Tensor) -> Tensor:
-        x = torch.relu(self.fc(x))
-        return self.proj(x.square())
+        activated = torch.relu(self.fc(x)).square()
+        activated = activated * self.scale.to(dtype=activated.dtype) + self.bias.to(dtype=activated.dtype)
+        return self.proj(activated)
 
 
 class Block(nn.Module):
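For reference, a minimal standalone sketch of the activation path this diff introduces, assuming a plain nn.Linear in place of the project's CastedLinear: the block computes proj(scale * relu(fc(x))^2 + bias), i.e. the previous squared-ReLU MLP followed by a learnable per-channel affine on the hidden activation.

import torch
from torch import nn, Tensor

class StarReLUMLP(nn.Module):
    # Minimal sketch; nn.Linear stands in for the repo's CastedLinear so the snippet is self-contained.
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.fc = nn.Linear(dim, hidden, bias=False)
        self.proj = nn.Linear(hidden, dim, bias=False)
        # Learnable per-channel affine applied to relu(fc(x))^2
        self.scale = nn.Parameter(torch.ones(hidden))
        self.bias = nn.Parameter(torch.zeros(hidden))

    def forward(self, x: Tensor) -> Tensor:
        activated = torch.relu(self.fc(x)).square()
        return self.proj(activated * self.scale + self.bias)

Since scale initializes to ones and bias to zeros, the module starts out numerically identical to the old relu^2 MLP and only departs from it as the affine parameters are trained.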

run_swiglu.sh

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+#!/bin/bash
+# RUN_SWIGLU: Star-ReLU architecture (#462) + our cosine per-layer TTT
+#
+# Architecture from PR #462 (JoeProAI):
+#   Star-ReLU MLP with learnable scale+bias
+#   8 KV heads (full MHA, not GQA)
+#   MLP hidden=1792 (wider)
+#   BigramHash 8192 buckets
+#   XSA4 enabled, EMA decay=0.9985, warmdown=6000
+#
+# TTT schedule (ours):
+#   Cosine lr decay, per-layer lr, 50 epochs
+#
+# Target: sub-1.05
+
+set -e
+cd /workspace/parameter-golf
+git fetch origin && git checkout swiglu-cosine-ttt && git reset --hard origin/swiglu-cosine-ttt
+
+# Architecture (from #462)
+export TRAIN_SEQ_LEN=2048 EVAL_SEQ_LEN=2048 UNET_SKIPS=1
+export ROPE_DIMS=16 LN_SCALE=1 ROPE_BASE=10000
+export EVAL_STRIDE=64 DOC_ISOLATED_EVAL=0
+export NUM_KV_HEADS=8
+export MLP_HIDDEN=1792
+export BIGRAM_HASH_BUCKETS=8192
+export XSA_LAST_N=4
+export EMA_DECAY=0.9985
+export WARMDOWN_ITERS=6000
+export QAT=0
+export LATE_K_FP16=0 FP16_EMBED_EXPORT=0
+
+# Cosine per-layer TTT (ours)
+export TTT_OPTIMIZER=adamw
+export TTT_LR=0.0005
+export TTT_EPOCHS=50
+export TTT_COSINE=1
+export TTT_PERLAYER=1
+export TTT_FREEZE_BLOCKS=0
+export TTT_BATCH_SEQS=64
+export TTT_MAX_STEPS=9999
+
+# Seed from argument or default 1337
+export SEED=${1:-1337}
+
+unset MLP_HIDDEN_OLD QUANT_BITS RUN_ID TIER2_MODE MLP_MULT \
+      BACKOUT LAYER_DROP HEAD_DROP EVAL_TEMPERATURE \
+      MLP_QUANT_BITS USE_FA3 TRAIN_BATCH_TOKENS SWA PRUNE_PCT \
+      REPTILE_TTT VE_ENABLED TTT_TWO_PHASE
+
+echo "=== SWIGLU + COSINE TTT ==="
+echo "SEED=$SEED KV=$NUM_KV_HEADS MLP=$MLP_HIDDEN BIGRAM=$BIGRAM_HASH_BUCKETS XSA=$XSA_LAST_N"
+echo "TTT: AdamW ${TTT_EPOCHS}ep cosine perlayer EMA=$EMA_DECAY WD=$WARMDOWN_ITERS"
+echo "==========================="
+
+torchrun --standalone --nproc_per_node=8 \
+    records/track_10min_16mb/2026-03-21_11L_XSA_EMA_TTT/train_gpt.py
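The TTT_* variables above are only configuration; the code that consumes them lives in train_gpt.py and is not part of this diff. As an illustration only, the combination TTT_OPTIMIZER=adamw, TTT_LR=0.0005, TTT_COSINE=1, TTT_PERLAYER=1, TTT_EPOCHS=50 could be wired up roughly as below (build_ttt_optimizer and the one-group-per-block layout are hypothetical, not the repo's actual code):

import torch
from torch import nn

def build_ttt_optimizer(blocks: list[nn.Module], base_lr: float = 5e-4, epochs: int = 50):
    # One AdamW parameter group per transformer block, so each layer can carry its own LR;
    # the concrete per-layer values used by the repo are not visible in this diff.
    groups = [{"params": block.parameters(), "lr": base_lr} for block in blocks]
    opt = torch.optim.AdamW(groups)
    # Cosine decay of every group's LR toward zero over the TTT epochs (stepped once per epoch).
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    return opt, sched

As for running the script itself, the only argument is the optional seed, e.g. ./run_swiglu.sh 2024 overrides the default SEED=1337.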
