@@ -0,0 +1,91 @@
# Non-record: Parallel Residuals + Hessian-Aware SDClip (3-seed mean 1.08354 BPB)

**val bpb: 1.08354** (3-seed mean, std=0.00050)

Not a record. This is a small 3-seed experiment on top of [PR #1394](https://github.com/openai/parameter-golf/pull/1394); the seed count is too small to support a statistical claim about the gap. Posting because the changes are zero-cost, reproducible, and may be useful to others trying out different techniques.

| Seed | Steps | Pre-quant BPB | Post-quant BPB | **Sliding BPB** | Artifact (bytes) |
|-|-|-|-|-|-|
| 1337 | 5178 | 1.08765 | 1.09959 | **1.08301** | 15,976,275 |
| 42 | 5180 | 1.08816 | 1.10013 | **1.08363** | 15,978,439 |
| 3141 | 5182 | 1.08872 | 1.10044 | **1.08399** | 15,979,649 |
| **Mean** | | 1.08818 | 1.10005 | **1.08354** | 15,978,121 |

## Changes

Three zero-cost modifications on top of [PR #1394](https://github.com/openai/parameter-golf/pull/1394), adding no extra parameters or bytes:

### 1. Parallel Residuals (Layers 7+)

GPT-J style parallel attention+MLP ([Wang & Komatsuzaki, 2021](https://github.com/kingoflolz/mesh-transformer-jax)) for the last 4 layers. Both attention and MLP read from the same input and their outputs are added in parallel:

```
# Parallel (layers 7-10):
x_out = x + attn_scale * Attn(norm(x)) + mlp_scale * MLP(norm(x))

# Sequential (layers 0-6, unchanged):
h = x + attn_scale * Attn(norm(x))
x_out = h + mlp_scale * MLP(norm(h))
```

I expected parallel residuals to reduce interference between attention and MLP during GPTQ calibration. Pre-quant BPB barely moved, but the quantization gap tightened across all 3 seeds, which made this the most useful change in practice.
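The two wirings can be sketched numerically. A minimal sketch with NumPy, where `attn` and `mlp` are nonlinear placeholders standing in for the real sublayers (names, scales, and the RMSNorm-style `norm` are illustrative, not the repo's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=d)

def norm(v):
    # RMSNorm-style normalization (illustrative)
    return v / (np.sqrt(np.mean(v * v)) + 1e-6)

def attn(v):
    # nonlinear placeholder for Attn(...)
    return 0.5 * np.tanh(v)

def mlp(v):
    # nonlinear placeholder for MLP(...)
    return 0.25 * v * np.abs(v)

attn_scale, mlp_scale = 1.0, 1.0

# Sequential (layers 0-6): the MLP sees the attention-updated stream.
h = x + attn_scale * attn(norm(x))
x_seq = h + mlp_scale * mlp(norm(h))

# Parallel (layers 7-10): both branches read the same input x.
x_par = x + attn_scale * attn(norm(x)) + mlp_scale * mlp(norm(x))
```

The only difference is which stream the MLP normalizes and reads; with nonlinear sublayers the two outputs diverge, which is where the interference (or lack of it) during calibration comes from.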

### 2. Hessian-Aware SDClip

I used GPTQ's existing Hessian diagonal as a cheap importance signal to slightly modulate SDClip thresholds by row:

$$c_i = k \cdot \sigma_i \cdot [1 + \lambda(r_i - 1)], \quad \lambda = 0.175$$

where $\sigma_i$ is the standard deviation of row $i$ and $r_i$ is the row importance derived from Hessian-weighted magnitude. The effect is small but directionally consistent. I initially used $\lambda = 0.30$ but found $\lambda = 0.175$ better across all seeds, with both lower BPB and a smaller artifact: higher $\lambda$ reduces rounding error but increases weight entropy, which makes Brotli compression less effective.
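As a sketch, the modulated thresholds can be computed directly from a weight matrix and the GPTQ Hessian diagonal. The exact row-importance definition below (Hessian-weighted magnitude, mean-normalized so $r_i$ averages 1) is my assumption about the formula; `k=12.85` matches `matrix_clip_sigmas` from the log:

```python
import numpy as np

def hessian_aware_clip(W, h_diag, k=12.85, lam=0.175):
    """Per-row clip thresholds c_i = k * sigma_i * (1 + lam * (r_i - 1)).

    h_diag: GPTQ Hessian diagonal, one entry per input feature.
    The importance normalization is an assumption, not the repo's code.
    """
    sigma = W.std(axis=1)                              # per-row std sigma_i
    imp = (np.abs(W) * h_diag[None, :]).mean(axis=1)   # Hessian-weighted magnitude
    r = imp / imp.mean()                               # r_i, mean 1 across rows
    return k * sigma * (1.0 + lam * (r - 1.0))

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 32))
h_diag = rng.uniform(0.5, 2.0, size=32)
c = hessian_aware_clip(W, h_diag)
```

At `lam=0` this reduces exactly to the plain SDClip threshold `k * sigma_i`, so the modulation is a strict generalization of the baseline.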

### 3. Progressive Recurrence

Depth recurrence split into two phases: first loop enabled at 50% of training, second at 65%. The split points were not optimized — 50% matches the original and 65% was a single manual choice. Enabling both loops at once causes a sharper loss spike; splitting gives the model time to adapt to each additional pass before adding the next.
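The two-phase gating amounts to a step function over training progress. A minimal sketch, assuming the schedule is keyed on a progress fraction (the log's `frac:` values suggest the actual trainer uses wallclock fraction); the function name is illustrative:

```python
def num_active_loops(frac, phase1_at=0.5, phase2_at=0.65):
    """Extra recurrent passes active at training fraction `frac`.

    Thresholds mirror enable_looping_at / loop_phase2_at from the config.
    """
    if frac < phase1_at:
        return 0   # no recurrence yet
    if frac < phase2_at:
        return 1   # first loop enabled
    return 2       # both loops enabled
```

Each threshold crossing also triggers its own short loop-warmup (the `loop_warmup_p1_step` / `loop_warmup_p2_step` lines in the log), which is what softens the loss spike relative to enabling both passes at once.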

## Hessian Analysis (Cross-Seed)

Hessian diagnostics from 3 seeds, 67 matrices each:

- **Group-level traces** (early/loop/mid/late blocks): $r=0.997$ across seeds
- **Per-matrix traces**: $r=0.994$
- **Per-row importance**: $r=0.12$ (noise)

Importance hierarchy: early blocks (trace ~30× that of late blocks) >> loop >> mid >> late. Per-row importance is too noisy to be a reliable signal, but group-level traces are very stable across seeds. This suggests per-group clip allocation could be a useful direction.

## Future Directions

Several ideas I'd like to explore with more compute time:

- **Per-group clip allocation**: Non-uniform $k$ across layer groups, using the stable group-level trace hierarchy as a guide.
- **Output-Hessian weighting**: Using backward-pass gradients for output-side row importance rather than input-side alone.
- **More seeds**: 3 seeds is not enough for strong statistical claims. I'd want 5+ to be confident about the gap vs PR #1394.
- **YAQA**: I like the idea of the paper ([arXiv:2505.22988](https://arxiv.org/abs/2505.22988)), but I couldn't get a working backward pass for it. It might still adapt to the parameter-golf problem in an interesting way. I also like the math in Mousse ([arXiv:2603.09697](https://arxiv.org/abs/2603.09697)), but exploiting curvature in small LMs seems tough.

## Run Command

```bash
HESSIAN_CLIP_LAMBDA=0.175 LOOP_PHASE2_AT=0.65 PARALLEL_RESIDUAL_START=7 SEED=1337 \
torchrun --standalone --nproc_per_node=8 train_gpt_sweep.py
```

## Requirements

Flash Attention 3 (Hopper GPUs only) is required, plus the SP8192 tokenizer trained on FineWeb 10B (SentencePiece BPE, 8192 vocab).

```bash
pip install torch --index-url https://download.pytorch.org/whl/cu130
pip install --no-cache-dir \
"https://download.pytorch.org/whl/cu130/flash_attn_3-3.0.0-cp39-abi3-manylinux_2_28_x86_64.whl"
pip install -r requirements.txt
```

## Compliance (Track A — Fixed Predictor)

- No TTT, SLOT, n-gram cache, or eval-time adaptation
- GPTQ calibration within training budget
- Standard autoregressive sliding-window eval (stride=64)


## Credits

Learned from and inspired by [PR #1394](https://github.com/openai/parameter-golf/pull/1394) (@clarkkev) — SDClip, depth recurrence, and GPTQ embedding quantization ideas. Parallel residuals from GPT-J ([Wang & Komatsuzaki, 2021](https://github.com/kingoflolz/mesh-transformer-jax)). Additional credits: PR #1204 (@msisovic, depth recurrence), PR #1217 (@bigbag, MuonEq-R), PR #1019 (@abaybektursun, previous SOTA).
@@ -0,0 +1,34 @@
{
  "author": "Robby Sneiderman",
  "github_id": "Robby955",
  "name": "Non-record: Parallel Residuals + Hessian-Aware SDClip",
  "blurb": "Three zero-cost modifications on PR #1394: GPT-J parallel residuals (layers 7+), Hessian-diagonal SDClip modulation (lambda=0.175), two-phase progressive recurrence. 3-seed mean 1.08354 BPB.",
  "date": "2026-04-06T00:00:00Z",
  "val_loss": 2.7995,
  "val_bpb": 1.08354,
  "val_bpb_std": 0.00050,
  "seeds": [1337, 42, 3141],
  "seed_results": {
    "1337": {
      "val_loss": 2.79749,
      "val_bpb": 1.08301,
      "artifact_bytes": 15976275,
      "steps": 5178
    },
    "42": {
      "val_loss": 2.79910,
      "val_bpb": 1.08363,
      "artifact_bytes": 15978439,
      "steps": 5180
    },
    "3141": {
      "val_loss": 2.80002,
      "val_bpb": 1.08399,
      "artifact_bytes": 15979649,
      "steps": 5182
    }
  },
  "hardware": "8xH100 80GB SXM",
  "bytes_total": 15978121,
  "based_on": "PR #1394 (@clarkkev)"
}
@@ -0,0 +1,165 @@
W0406 05:48:52.138000 4182076 torch/distributed/run.py:803]
W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] *****************************************
W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0406 05:48:52.138000 4182076 torch/distributed/run.py:803] *****************************************
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp8192
distributed: True
ema_decay: 0.997
embed_bits: 8
embed_clip_sigmas: 20.0
embed_lr: 0.6
embed_wd: 0.085
embedding_dim: 512
enable_looping_at: 0.5
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_reserve_seconds: 12.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
hessian_clip_lambda: 0.175
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/f3971278-d577-499b-8fde-755434809ba9.txt
logit_softcap: 30.0
loop_end: 5
loop_layer_bits: 0
loop_layer_clip_sigmas: 0.0
loop_phase2_at: 0.65
loop_start: 4
matrix_bits: 6
matrix_clip_sigmas: 12.85
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_row_normalize: True
muon_wd: 0.085
num_heads: 8
num_kv_heads: 4
num_layers: 11
num_loops: 2
parallel_residual_start: 7
qk_gain_init: 4.0
quantized_model_path: final_model.int6.ptz
rank: 0
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: f3971278-d577-499b-8fde-755434809ba9
scalar_lr: 0.02
seed: 1337
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_chunk_tokens: 32768
ttt_enabled: False
ttt_entropy_high: 2.1
ttt_entropy_low: 1.75
ttt_epochs: 4
ttt_freeze_blocks: 2
ttt_lr: 0.0005
ttt_ns_steps: 3
untie_loop_mlps: False
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin
val_loss_every: 4000
vocab_size: 8192
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 128
val_tokens: 40540160
model_params:35943512
hessian_clip: lambda=0.175
parallel_residuals: ON (layers 7-10)
progressive_recurrence: phase1=0.5 phase2=0.65
gptq:reserving 12s, effective=588000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
loop_warmup_phase1: encoder:[0, 1, 2, 3, 4, 5] decoder:[4, 5, 6, 7, 8, 9, 10]
loop_warmup_p1_step: 1/20
loop_warmup_p1_step: 2/20
loop_warmup_p1_step: 3/20
loop_warmup_p1_step: 4/20
loop_warmup_p1_step: 5/20
loop_warmup_p1_step: 6/20
loop_warmup_p1_step: 10/20
loop_warmup_p1_step: 20/20
loop_warmup_phase2: encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10]
loop_warmup_p2_step: 1/20
loop_warmup_p2_step: 2/20
loop_warmup_p2_step: 3/20
loop_warmup_p2_step: 4/20
loop_warmup_p2_step: 5/20
loop_warmup_p2_step: 6/20
loop_warmup_p2_step: 10/20
loop_warmup_p2_step: 20/20
0/20000 val_loss: 9.0052 val_bpb: 3.4862
1/20000 train_loss: 9.0086 train_time: 0.0m tok/s: 8101558 looping:False
2/20000 train_loss: 12.1911 train_time: 0.0m tok/s: 8047236 looping:False
3/20000 train_loss: 11.0242 train_time: 0.0m tok/s: 7763017 looping:False
4/20000 train_loss: 9.5010 train_time: 0.0m tok/s: 7744899 looping:False
5/20000 train_loss: 8.3911 train_time: 0.0m tok/s: 7764479 looping:False
500/20000 train_loss: 3.3106 train_time: 0.8m tok/s: 7721577 looping:False
1000/20000 train_loss: 3.2012 train_time: 1.7m tok/s: 7711891 looping:False
1500/20000 train_loss: 3.1827 train_time: 2.5m tok/s: 7711479 looping:False
2000/20000 train_loss: 2.9936 train_time: 3.4m tok/s: 7711845 looping:False
2500/20000 train_loss: 3.0679 train_time: 4.2m tok/s: 7713114 looping:False
layer_loop:phase1 step:2884 frac:0.500
3000/20000 train_loss: 3.1068 train_time: 5.1m tok/s: 7668374 looping:True
3500/20000 train_loss: 2.9483 train_time: 6.1m tok/s: 7510155 looping:True
layer_loop:phase2 step:3634 frac:0.650
4000/20000 train_loss: 2.9482 train_time: 7.2m tok/s: 7297673 looping:True
4000/20000 val_loss: 2.9279 val_bpb: 1.1335
4500/20000 train_loss: 2.8499 train_time: 8.3m tok/s: 7110716 looping:True
5000/20000 train_loss: 2.8598 train_time: 9.4m tok/s: 6967320 looping:True
5178/20000 val_loss: 2.8121 val_bpb: 1.0887
stopping_early: wallclock_cap train_time: 588103ms step: 5178/20000
peak memory allocated: 34604 MiB reserved: 34634 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.80947408 val_bpb:1.08764765 eval_time:6554ms
Serialized model: 135426937 bytes
Code size: 78688 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 67 Hessians in 11.3s
GPTQ:saved Hessian diagnostics to hessian_diagnostics.pt (67 matrices)
Quantized weights:
gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight
gptq (int8): tok_emb.weight
passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights
Serialized model quantized+brotli: 15976275 bytes
Total submission size quantized+brotli: 16054963 bytes
quantized val_loss:2.84032998 val_bpb:1.09959307 eval_time:8134ms
quantized_sliding_window val_loss:2.79749368 val_bpb:1.08300961 eval_time:82837ms
@@ -0,0 +1 @@
exec(open(__file__.replace("train_gpt.py","train_gpt_decode.py")).read())