Skip to content

Commit 81e9896

Browse files
committed
Clean base: QK_GAIN=5.0, WARMDOWN=4000, disable multi-res, revert BIGRAM
- QK_GAIN_INIT: 1.5 -> 5.0 (matches openai#1296 proven config)
- WARMDOWN_ITERS: already 4000 (matches openai#1290 run command)
- MULTIRES_ENABLED: 1 -> 0 (multi-res failed: only 1.13x speedup)
- BIGRAM: revert to 2048x128 (3072x112 exceeded 16MB artifact limit)
1 parent 75d76a8 commit 81e9896

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

train_gpt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Hyperparameters:
4242
train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
4343
eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048))
4444
max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
45-
qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
45+
qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.0))
4646
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
4747
num_layers = int(os.environ.get("NUM_LAYERS", 11))
4848
num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
@@ -78,8 +78,8 @@ class Hyperparameters:
7878
muon_wd = float(os.environ.get("MUON_WD", 0.04))
7979
adam_wd = float(os.environ.get("ADAM_WD", 0.04))
8080
qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0")))
81-
bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 3072))
82-
bigram_dim = int(os.environ.get("BIGRAM_DIM", 112))
81+
bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048))
82+
bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
8383
trigram_enabled = bool(int(os.environ.get("TRIGRAM", "0"))) # TrigramHash (off by default, risky)
8484
xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) # XSA on ALL layers (our novel contribution)
8585
rope_dims = int(os.environ.get("ROPE_DIMS", 16))
@@ -99,7 +99,7 @@ class Hyperparameters:
9999
recur_start_step = int(os.environ.get("RECUR_START_STEP", 3000))
100100
parallel_start_layer = int(os.environ.get("PARALLEL_START_LAYER", "7"))
101101
# Multi-resolution sequence length
102-
multires_enabled = bool(int(os.environ.get("MULTIRES_ENABLED", "1")))
102+
multires_enabled = bool(int(os.environ.get("MULTIRES_ENABLED", "0")))
103103
multires_short_seq = int(os.environ.get("MULTIRES_SHORT_SEQ", 512))
104104
multires_switch_frac = float(os.environ.get("MULTIRES_SWITCH_FRAC", 0.70)) # fraction of wallclock at short seq
105105
# TTT (test-time training)

0 commit comments

Comments
 (0)