-
Notifications
You must be signed in to change notification settings - Fork 3.1k
[Record] 3-Layer Depth Recurrence + EMA 0.9965 + WD 0.095 — val_bpb 1.0889 #1445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,89 @@ | ||||||||
| ## Record: 3-Layer Depth Recurrence + EMA 0.9965 + WD 0.095 (val_bpb: 1.0889) | ||||||||
|
|
||||||||
| **val_bpb: 1.0889** (sliding window stride=64, 3-seed mean, std 0.0005) | **~15.89 MB** | 8xH100 SXM, 590s | ||||||||
|
|
||||||||
| ### 3-Seed Results (8×H100 80GB SXM) | ||||||||
|
|
||||||||
| | Seed | Pre-quant BPB | Sliding BPB (s64) | Artifact | | ||||||||
| |------|---------------|-------------------|----------| | ||||||||
| | 42 | 1.0950 | **1.0885** | 15,890,417 B | | ||||||||
| | 1337 | 1.0959 | **1.0894** | — | | ||||||||
| | 2024 | 1.0954 | **1.0888** | 15,895,711 B | | ||||||||
|
|
||||||||
| **Mean: 1.0889 | Std: 0.0005** | All artifacts under 16,000,000 bytes | ||||||||
|
|
||||||||
| Current merged SOTA: **1.1147** (PR #1019). Delta: **−0.0258 BPB**. | ||||||||
|
|
||||||||
| ### Key Changes | ||||||||
|
|
||||||||
| Four refinements stacked on top of PR #1334's depth recurrence architecture: | ||||||||
|
|
||||||||
| | Parameter | PR #1334 | This | Source | | ||||||||
| |-----------|----------|------|--------| | ||||||||
| | **Recurrence layers** | 4,5 (2-layer) | **3,4,5 (3-layer)** | PR #1331 | | ||||||||
| | **Weight decay** | 0.090 | **0.095** | PR #1331 | | ||||||||
| | **Matrix LR** | 0.020 | **0.022** | PR #1331 | | ||||||||
| | **EMA decay** | 0.997 | **0.9965** | PR #1421 (this author) | | ||||||||
| | **Recurrence start** | step 3000 | **step 2000** | This work | | ||||||||
| | **Warmdown fraction** | 0.667 | **0.72** | This work | | ||||||||
|
|
||||||||
| ### Why This Combination Works | ||||||||
|
|
||||||||
| 1. **3-layer recurrence (layers 3,4,5)**: Repeats 3 layers instead of 2, producing 14 virtual layers from 11 physical layers. More compute per forward pass without additional parameters. | ||||||||
|
|
||||||||
| 2. **WD=0.095 + MLR=0.022**: Higher weight decay compresses weights more aggressively, improving GPTQ quantization. Higher matrix LR compensates for the regularization. Only 134K-186K values pruned (vs 290K+ at WD=0.090). | ||||||||
|
|
||||||||
| 3. **EMA decay=0.9965**: Assigns slightly more weight to recent training steps, producing a final checkpoint that quantizes more cleanly under GPTQ int6. | ||||||||
|
|
||||||||
| 4. **Early recurrence (step 2000)**: Activating depth recurrence 1000 steps earlier gives the model more training time with 14 virtual layers, improving final quality. | ||||||||
|
|
||||||||
| 5. **Extended warmdown (72%)**: Longer learning rate decay allows weights to fully settle before GPTQ quantization, reducing the quant gap. | ||||||||
|
|
||||||||
| ### Architecture (from PR #1334) | ||||||||
|
|
||||||||
| - 11 transformer layers, 512-dim, 8 heads (4 KV heads, GQA) | ||||||||
| - **Depth recurrence**: layers 3,4,5 repeat (virtual 14 layers), activated at step 2000 | ||||||||
| - Skip gates (learnable residual gating) | ||||||||
| - Parallel residuals from layer 7 | ||||||||
| - QK-Gain 5.0 | ||||||||
| - Shared Value Embedding (dim=128, layers 9,10) | ||||||||
| - Tied embeddings, logit softcap=30.0 | ||||||||
| - SP4096 tokenizer (SentencePiece BPE) | ||||||||
|
|
||||||||
| ### Training | ||||||||
|
|
||||||||
| - FlashAttention 3 (Hopper-optimized) | ||||||||
| - Muon optimizer (matrices): lr=0.022, momentum=0.99, WD=0.095, backend_steps=5 | ||||||||
| - Adam (head): lr=0.008, fused=True | ||||||||
| - AdamW (embeddings): lr=0.6, WD=0.095, fused=True | ||||||||
| - AdamW (scalars): lr=0.02, WD=0.02, fused=True | ||||||||
| - Gradient clip: 0.3, Batch: 786,432 tokens/step, seq_len=2048 | ||||||||
| - Warmdown: 72%, **EMA decay=0.9965** | ||||||||
| - Wallclock: 590s effective (10s reserved for GPTQ) | ||||||||
|
|
||||||||
| ### Quantization | ||||||||
|
|
||||||||
| - GPTQ int6 with percdamp=0.05, 64 calibration batches | ||||||||
| - Selective pruning (~134K-186K lowest-error ±1 values) | ||||||||
| - Brotli compression | ||||||||
|
|
||||||||
| ### Run Command | ||||||||
|
|
||||||||
| ```bash | ||||||||
| SEED=42 RECUR_START_STEP=2000 WARMDOWN_FRAC=0.72 \ | ||||||||
| DATA_PATH=./data/datasets/fineweb10B_sp4096/ \ | ||||||||
| TOKENIZER_PATH=./data/tokenizers/fineweb_4096_bpe.model \ | ||||||||
|
Comment on lines
+74
to
+75
|
||||||||
| DATA_PATH=./data/datasets/fineweb10B_sp4096/ \ | |
| TOKENIZER_PATH=./data/tokenizers/fineweb_4096_bpe.model \ | |
| DATA_DIR=./data \ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| { | ||
| "author": "Abhishek Leji", | ||
| "github_id": "X-Abhishek-X", | ||
| "name": "Record: 3-Layer Depth Recurrence + EMA 0.9965 + WD 0.095 + Early Recurrence", | ||
| "blurb": "3-layer depth recurrence (layers 3,4,5) with EMA decay 0.9965, WD=0.095, MLR=0.022, early recurrence activation (step 2000), and extended warmdown (72%). Built on PR #1334 architecture with innovations from PR #1331.", | ||
| "date": "2026-04-07T00:00:00Z", | ||
| "val_loss": 2.50552189, | ||
| "val_bpb": 1.08887087, | ||
| "bytes_total": 15895711 | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| W0407 16:22:23.785000 48806 torch/distributed/run.py:803] | ||
| W0407 16:22:23.785000 48806 torch/distributed/run.py:803] ***************************************** | ||
| W0407 16:22:23.785000 48806 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. | ||
| W0407 16:22:23.785000 48806 torch/distributed/run.py:803] ***************************************** | ||
| Hyperparameters: | ||
| adam_eps: 1e-08 | ||
| adam_wd: 0.02 | ||
| beta1: 0.9 | ||
| beta2: 0.95 | ||
| compressor: brotli | ||
| data_dir: ./data/ | ||
| datasets_dir: ./data/datasets/fineweb10B_sp4096 | ||
| distributed: True | ||
| ema_decay: 0.9965 | ||
| embed_lr: 0.6 | ||
| embed_wd: 0.095 | ||
| embedding_dim: 512 | ||
| eval_seq_len: 2048 | ||
| eval_stride: 64 | ||
| gptq_calibration_batches: 64 | ||
| gptq_enabled: True | ||
| gptq_reserve_seconds: 10.0 | ||
| grad_accum_steps: 1 | ||
| grad_clip_norm: 0.3 | ||
| head_lr: 0.008 | ||
| is_main_process: True | ||
| iterations: 20000 | ||
| ln_scale: True | ||
| local_rank: 0 | ||
| logfile: logs/bbc00e44-7393-4d92-a67c-239184601d85.txt | ||
| logit_softcap: 30.0 | ||
| matrix_lr: 0.022 | ||
| max_wallclock_seconds: 600.0 | ||
| min_lr: 0.0 | ||
| mlp_mult: 4.0 | ||
| model_dim: 512 | ||
| model_path: final_model.pt | ||
| muon_backend_steps: 5 | ||
| muon_beta2: 0.95 | ||
| muon_momentum: 0.99 | ||
| muon_momentum_warmup_start: 0.92 | ||
| muon_momentum_warmup_steps: 1500 | ||
| muon_wd: 0.095 | ||
| num_heads: 8 | ||
| num_kv_heads: 4 | ||
| num_layers: 11 | ||
| parallel_start_layer: 7 | ||
| qk_gain_init: 5.0 | ||
| quantized_model_path: final_model.int6.ptz | ||
| rank: 0 | ||
| recur_layers: 3,4,5 | ||
| recur_start_step: 2000 | ||
| rope_base: 10000.0 | ||
| rope_dims: 16 | ||
| rope_train_seq_len: 2048 | ||
| run_id: bbc00e44-7393-4d92-a67c-239184601d85 | ||
| scalar_lr: 0.02 | ||
| seed: 42 | ||
| skip_gates_enabled: True | ||
| sliding_window_enabled: True | ||
| tie_embeddings: True | ||
| tied_embed_init_std: 0.005 | ||
| tied_embed_lr: 0.03 | ||
| tokenizer_path: ./data/tokenizers/fineweb_4096_bpe.model | ||
| train_batch_tokens: 786432 | ||
| train_files: ./data/datasets/fineweb10B_sp4096/fineweb_train_*.bin | ||
| train_log_every: 500 | ||
| train_seq_len: 2048 | ||
| ttt_batch_seqs: 32 | ||
| ttt_chunk_tokens: 32768 | ||
| ttt_enabled: False | ||
| ttt_epochs: 3 | ||
| ttt_freeze_blocks: 0 | ||
| ttt_grad_clip: 1.0 | ||
| ttt_lr: 0.002 | ||
| ttt_momentum: 0.9 | ||
| val_batch_tokens: 524288 | ||
| val_files: ./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin | ||
| val_loss_every: 4000 | ||
| ve_dim: 128 | ||
| ve_enabled: True | ||
| ve_layers: 9,10 | ||
| vocab_size: 4096 | ||
| warmdown_frac: 0.72 | ||
| warmup_steps: 20 | ||
| world_size: 8 | ||
| xsa_last_n: 11 | ||
| train_shards: 143 | ||
| val_tokens: 45508608 | ||
| model_params:34401372 | ||
| gptq:reserving 10s, effective=590000ms | ||
| warmup_step: 1/20 | ||
| warmup_step: 2/20 | ||
| warmup_step: 3/20 | ||
| warmup_step: 4/20 | ||
| warmup_step: 5/20 | ||
| warmup_step: 6/20 | ||
| warmup_step: 10/20 | ||
| warmup_step: 20/20 | ||
| 0/20000 val_loss: 8.3187 val_bpb: 3.6152 | ||
| 1/20000 train_loss: 8.3178 train_time: 0.0m tok/s: 8488752 | ||
| 2/20000 train_loss: 12.0820 train_time: 0.0m tok/s: 8383316 | ||
| 3/20000 train_loss: 10.6643 train_time: 0.0m tok/s: 8277075 | ||
| 4/20000 train_loss: 8.9470 train_time: 0.0m tok/s: 8230819 | ||
| 5/20000 train_loss: 7.7086 train_time: 0.0m tok/s: 8197168 | ||
| 500/20000 train_loss: 2.9983 train_time: 0.8m tok/s: 7974261 | ||
| 1000/20000 train_loss: 2.9965 train_time: 1.6m tok/s: 7956910 | ||
| 1500/20000 train_loss: 2.9090 train_time: 2.5m tok/s: 7950269 | ||
| 2000/20000 train_loss: 2.7506 train_time: 3.3m tok/s: 7947268 | ||
| recurrence:activated at step 2000, virtual_layers=[0, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 8, 9, 10] | ||
| 2500/20000 train_loss: 2.7363 train_time: 4.5m tok/s: 7250427 | ||
| 3000/20000 train_loss: 2.6969 train_time: 5.5m tok/s: 7096412 | ||
| 3500/20000 train_loss: 2.6178 train_time: 6.6m tok/s: 6990908 | ||
| 4000/20000 train_loss: 2.6167 train_time: 7.6m tok/s: 6912506 | ||
| 4000/20000 val_loss: 2.6187 val_bpb: 1.1381 | ||
| 4500/20000 train_loss: 2.5537 train_time: 8.6m tok/s: 6854112 | ||
| 5000/20000 train_loss: 2.5098 train_time: 9.6m tok/s: 6808010 | ||
| 5102/20000 val_loss: 2.5227 val_bpb: 1.0963 | ||
| stopping_early: wallclock_cap train_time: 590087ms step: 5102/20000 | ||
| peak memory allocated: 32292 MiB reserved: 32332 MiB | ||
| ema:applying EMA weights | ||
| pre-quantization post-ema val_loss:2.51973533 val_bpb:1.09504795 eval_time:2150ms | ||
| Serialized model: 132406149 bytes | ||
| Code size: 83569 bytes | ||
| GPTQ:collecting Hessians from calibration data... | ||
| GPTQ:collected 66 Hessians in 10.4s | ||
| GPTQ quantization: 66 layers with full GPTQ, 0 fallback to clip-search | ||
| selective_prune: unpruned=15.89MB target=16.0MB | ||
| selective_prune: already fits, no pruning needed | ||
| Serialized model int6+brotli: 15806848 bytes | ||
| Total submission size int6+brotli: 15890417 bytes | ||
| final_int6_roundtrip val_loss:2.54747195 val_bpb:1.10710196 eval_time:8613ms | ||
| final_int6_sliding_window val_loss:2.50459726 val_bpb:1.08846912 eval_time:81043ms |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
README claims selective pruning of "~134K-186K" ±1 values, but the included logs show
selective_prune: already fits, no pruning neededfor all three seeds (42/1337/2024). Please update the pruning claims (lines 34 and 67) to match what actually happened in these runs, or point to the specific seed/config where pruning occurred.