Record: SP8192 + GPTQ Embeddings + Depth Recurrence + MuonEq-R + SDClip — val_bpb 1.08563 (5 seed mean)#1394
Open
clarkkev wants to merge 1 commit intoopenai:mainfrom
Conversation
resouer
pushed a commit
to resouer/parameter-golf
that referenced
this pull request
Apr 6, 2026
resouer
pushed a commit
to resouer/parameter-golf
that referenced
this pull request
Apr 6, 2026
resouer
pushed a commit
to resouer/parameter-golf
that referenced
this pull request
Apr 6, 2026
vaibhav-i
added a commit
to vaibhav-i/parameter-golf
that referenced
this pull request
Apr 6, 2026
New base: PR openai#1394 (clarkkev SP8192 + SDClip + GPTQ embeddings, 1.08563 BPB) Experiments (all build on new_base_pr1394): - exp_polar_express: 4-step minimax-optimal NS (arXiv:2505.16932), ~-0.002 BPB - exp_causal_slot: per-window delta on context tokens, AdamW 16 steps, ~-0.013 BPB - exp_log_bias: streaming online log-bias (Nacrith arXiv:2602.19626), ~-0.015 BPB Research briefs: - research/2026-04-04-full-scan-brief.md - research/2026-04-05-scan-brief.md (updated: pre-quant TTT ruled illegal) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
dexhunter
added a commit
to dexhunter/parameter-golf
that referenced
this pull request
Apr 6, 2026
…(3-seed mean) On PR openai#1394 (@clarkkev): added single-knob QK_GAIN_INIT=5.0 and a legal score-first TTT eval pass (TTT_LR=0.005, epochs=3, freeze=0) on top of the clean sp8192 base. Three independent seeds (0, 42, 1234) on 8xH100 SXM, all fitting 16MB with 7-11K margin. Per-seed (post-TTT): - seed 0 : 1.08210 (val_loss 2.79517) - seed 42 : 1.08315 (val_loss 2.79788) - seed 1234: 1.08314 (val_loss 2.79785) - mean : 1.08279 (2.79697 nats per token) Improvement vs PR openai#1394 (1.08563 mean): -0.00284 bpb = -0.00731 nats/token, clearing the 0.005 nats record threshold by 0.00231 nats per seed. No SLOT, no pre-quant TTT, no ETLB, no n-gram cache, no tokenizer change. Score-first TTT matches PR openai#549 precedent: every chunk scored under inference_mode() before any parameter update.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Record: SP8192 + GPTQ Embeddings + Depth Recurrence + MuonEq-R + SDClip + Simplifications — val_bpb 1.08563
val bpb: 1.08563 (5-seed mean, std=0.0007)
Changes
This script builds on #1218. The main changes are:
Quantization–Compression Tradeoffs
Quantization and compression interact in interesting ways. The compressed size depends not just on bitwidth, but also on the clip range (also called the scale) used during quantization. An int5 quantized network can actually compress smaller than an int4 one if the int5 quantization uses a much wider clip range. The reason is that the effectiveness of compression algorithms like
brotlidepends on the entropy of the data they are compressing, and increasing the clip range can lower that entropy.An example
Neural network weights are approximately normally distributed (a). In this example, we could clip the weights to [-1, 1] and uniformly quantize them into int5 (b). But this seems a bit wasteful because many of those bins are spent modeling the tails of the distribution, where very few weights lie. Instead, we could clip to [-0.5, 0.5] and use int4 (c). Or we could go one step further and use a non-uniform quantizer such as NF4 (d) so there are approximately the same number of weights at each quantized value.
Now here is the surprising part: after compression, int4 is only slightly smaller than int5, and NF4 is quite a bit larger. Why? Because the effectiveness of compression depends on not just the raw number of bits, but also the entropy of the quantized values. When we moved from int5 to int4, we made the histogram flatter, which increases entropy. NF4 flattens it even further by design, pushing the entropy higher still.
Another view is that the int4 and int5 parameters are mostly the same. The only difference is that the weights that would have been clipped to +-7 by int4 can take on larger values in int5, but as there are very few of them, this does not substantially increase compressed size.
Mathematical explanation
Suppose our network has$n$ weights and we quantize each one to $b$ bits. The quantized model size is $s_q = n b$ . However, we also compress our network after quantizing. A useful first approximation is that the compressed size $s$ is proportional to $H(q)$ , the entropy of the quantized weights:
This is not exact: compressors can also exploit structure beyond the marginal distribution. But neural network weights usually contain much less structure than natural data, so in practice their compressed size is often very close to what their entropy would suggest. So what is$H(q)$ ? Suppose our weights are normally distributed:
The differential entropy is
Now, suppose we clip our weights between$[-c, c]$ and quantize them into $2^b$ evenly spaced bins, i.e, we uniformly quantize them into int-$b$. Each bin then has width
The entropy of the resulting quantized weights, which we call$q$ , is approximately
If we measure entropy in bits, this becomes
This approximation becomes more accurate when$c \gg \sigma$ (since in that case only a small fraction of the weights are clipped), when $b$ is large enough that the quantization bins are small, and when $n$ is large enough that we still have many weights per bin.
A natural choice is to set the clip range proportional to the standard deviation, writing$c = k\sigma$ for some hyperparameter $k$ . This makes the amount of clipping scale-invariant: if the weights become 2x larger, the clip range should also become 2x larger. Substituting $c = k\sigma$ into the expression above gives
This gives two ways to reduce compressed model size: decrease$b$ (for example, go from int5 to int4), or increase $k$ (use a wider clip range so the quantized values get more concentrated near the center, which lowers their entropy). In fact, increasing $b$ and increasing $k$ have roughly opposite effects. The histogram produced by $(b, k)$ exactly matches the middle $2^b$ bins of $(b + 1, 2k)$ . The $(b + 1, 2k)$ quantization also includes additional outer bins, but very few weights lie in those bins, so $H(q)$ may not increase by much. This is exactly what we saw in the int5 versus int4 example.
Of course our approximations do not hold exactly in practice: the derivation ignores clipping, the weight distribution is only approximately normal, and compression depends on the full byte representation, not just the marginal histogram of quantized values. However, when I examined some trained networks, I found the standard deviation of a matrix (an estimate of$\sigma$ ) correlated very strongly ($R^2=0.995$ ) with the compression ratio of that matrix under a fixed clip width, suggesting the approximations are reasonable in practice. Lastly, I should note that usually each row is quantized separately, but the same reasoning applies on a per-row basis.
Improved clipping
The previous practice was to search over multiple clip thresholds to find the one that minimized reconstruction error. In the new version, the clipping threshold for a matrix row is just set at
In practice, I used$b = 6, k = 12.85$ for matrix parameters (tuned so the artifact is close to 16MB) and $b=8, k = 20$ for embeddings (they are more sensitive to quantization). As the above analysis suggests, upping the matrix params to int7 or int8 while doubling/quadrupling $k$ produced similarly-sized models, but I stuck with int6 to keep the script consistent with the previous version. Compared with the old approach, the new standard-deviation-based clipping has several advantages: