From 4ba466313f2ccadf47b07abf8bcacda4581135c7 Mon Sep 17 00:00:00 2001 From: Himanshu Dongre Date: Thu, 2 Apr 2026 00:54:25 +0530 Subject: [PATCH] =?UTF-8?q?Non-record:=2028=20Experiments=20in=205=20Days?= =?UTF-8?q?=20=E2=80=94=20What=20Works,=20What=20Fails,=20and=20Why=20Smal?= =?UTF-8?q?l-Scale=20Tests=20Lie?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Systematic exploration of architecture, training, quantization, and eval-time techniques. 14 dead techniques documented. Key finding: small-scale tests can be 180 degrees wrong (SSM: -18% local → +2.7% at scale). --- .../README.md | 290 ++++++ .../exp_complementary_training.py | 693 ++++++++++++++ .../exp_fixed_share_logistic.py | 880 ++++++++++++++++++ .../exp_local_distill_qat.py | 524 +++++++++++ .../exp_qat_learned_centroids.py | 525 +++++++++++ .../results_distill_qat.json | 11 + .../results_vocab4096_mlp4x.json | 19 + .../submission.json | 9 + 8 files changed, 2951 insertions(+) create mode 100644 records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/README.md create mode 100644 records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_complementary_training.py create mode 100644 records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_fixed_share_logistic.py create mode 100644 records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_local_distill_qat.py create mode 100644 records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_qat_learned_centroids.py create mode 100644 records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/results_distill_qat.json create mode 100644 records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/results_vocab4096_mlp4x.json create mode 100644 records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/submission.json diff --git 
a/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/README.md b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/README.md new file mode 100644 index 0000000000..d47a2e0b0c --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/README.md @@ -0,0 +1,290 @@ +# Non-Record: 28 Experiments in 5 Days — What Works, What Fails, and Why Small-Scale Tests Lie + +## TL;DR + +5 days, 28 controlled experiments, $12 in GPU spend, Mac Mini M4 + single H100. Systematic exploration of architecture, training, quantization, and eval-time techniques for the 16MB language model competition. + +**The most important finding:** Small-scale local experiments can be **180 degrees wrong.** Our SSM hybrid showed -18% CE improvement at dim=192 but was +2.7% BPB *worse* at dim=512 on H100. This isn't noise — it's a systematic bias that likely affects many competition submissions testing locally before committing GPU budget. + +**What works (confirmed at scale):** RoPE-16 partial embeddings, GEGLU 2.0x MLP, Full MHA over GQA at small scale, QAT as regularizer (quantized model beats float32), knowledge distillation (-0.75%), Dirichlet CTW-6 n-gram mixing (properly normalized), eval-time augmentation via scored-token statistics. + +**What definitively fails:** 14 techniques killed with specific numbers. 
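The "properly normalized" claim for the Dirichlet n-gram mixer reduces to using a posterior predictive instead of raw counts. A minimal single-context sketch of the idea, my own illustration rather than the submission's code (the actual mixer additionally backs off recursively across n-gram orders with order-dependent concentrations):

```python
import numpy as np

def dirichlet_predictive(counts, alpha=0.5):
    """Posterior predictive of a Dirichlet(alpha) prior over V outcomes:
    P(v) = (count[v] + alpha) / (total + alpha * V).
    Positive everywhere and sums to 1 by construction, unlike a hash cache
    that only scores the observed token."""
    counts = np.asarray(counts, dtype=np.float64)
    return (counts + alpha) / (counts.sum() + alpha * len(counts))

# One context seen 4 times: token 0 three times, token 2 once.
p = dirichlet_predictive([3, 0, 1, 0])
# p ≈ [0.583, 0.083, 0.250, 0.083]; unseen tokens still get smoothed mass.
```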
+ +--- + +## Experiment Methodology + +All local experiments use identical controlled conditions: +- **Hardware:** Mac Mini M4, 16GB unified memory, MPS backend +- **Architecture:** 6L, dim=192, 6 heads, RoPE-16, tied embeddings +- **Data:** ~1.7M tokens from Project Gutenberg (4 books), byte-level tokenization mod 1024 +- **Training:** AdamW, lr=3e-4, cosine schedule, batch 32, weight decay 0.1 +- **Eval:** 100-368 sequences (51K-188K tokens), BPC metric + +GPU experiments (where noted): +- **Hardware:** Single NVIDIA H100 80GB HBM3, RunPod +- **Architecture:** 11L, dim=512, 8 heads, competition `merged_leader` preset +- **Data:** FineWeb 10B, sentencepiece 1024 BPE tokenizer + +Each experiment changes ONE variable against a common baseline. All random seeds fixed at 42. + +--- + +## The Scale Deception Finding + +**This is the most important result in this report.** It should inform how every competitor interprets local testing. + +### The Setup + +We tested S4D-Lin (diagonal linear state-space model) as a drop-in replacement for attention in the bottom 2 layers: + +| Scale | SSM Result | Interpretation | +|-------|-----------|----------------| +| dim=192, 6L, local | -18.2% CE, 6% faster | "SSM wins! Ship it!" | +| dim=512, 11L, H100 | +2.7% BPB, same speed | "SSM loses. Attention is better." | + +**The local result pointed in the opposite direction from reality.** + +### Why This Happens + +At dim=192 (local), each attention head has dimension 32. This is below the threshold where attention can form sharp, useful patterns. SSM's convolutional approach actually works better with tiny head dimensions because it doesn't need to learn attention patterns at all. + +At dim=512 (competition), head dimension is 64. Attention heads are large enough to learn rich, multi-modal distributions. The quality advantage of attention over SSM's fixed convolutional structure becomes dominant — far outweighing SSM's marginal throughput benefit (+150 steps in 600s). 
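The head-dimension arithmetic driving this argument is worth making explicit. A trivial check using the two configs from the methodology section:

```python
def head_dim(model_dim: int, n_heads: int) -> int:
    """Per-head dimension: the capacity each attention head has for patterns."""
    assert model_dim % n_heads == 0
    return model_dim // n_heads

local = head_dim(192, 6)  # Mac Mini config: 32 dims/head, below the useful threshold
gpu = head_dim(512, 8)    # H100 config: 64 dims/head, where attention pulls ahead
print(local, gpu)         # 32 64
```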
+ +### Implication for Competitors + +**If you're testing architecture changes on a small local model and seeing improvements, you might be measuring an artifact of small scale.** The only reliable signal is a GPU-scale experiment. We recommend: + +1. Classify failures as "dead" vs "inconclusive at scale" +2. Be skeptical of any architecture change that shows >5% improvement locally +3. Budget at least one GPU validation run before committing to a technique + +--- + +## Full Results Table + +### Proven Wins + +| # | Technique | Metric | Result | Hardware | Notes | +|---|-----------|--------|--------|----------|-------| +| 1 | **Partial RoPE (16 dims)** | BPC | **-23.6%** | Local+GPU | Largest single win. Replaces learned positional embeddings. | +| 2 | **GEGLU 2.0x MLP** | CE | -0.03% vs 3x, **1.4% faster** | Local | Sweet spot: ties quality, saves 25% MLP params | +| 3 | **Full MHA (6h) > GQA (8Q/4KV)** | CE | **-0.2%, 23% faster** | Local | Counter-intuitive: full MHA wins at small scale (32-dim heads) | +| 4 | **Width > Depth** | CE | **-2%, 31% faster** | Local | 6L×2.0x beats 9L×1.0x at <30M params | +| 5 | **QAT NF5 (int5)** | CE | **-0.66% vs float32** | Local | QAT acts as regularizer. Eliminates post-hoc degradation. | +| 6 | **QAT uniform int5** | CE | **-0.34% vs float32** | Local | Even uniform QAT beats float32 | +| 7 | **Knowledge Distillation** | CE | **-0.75%** | Local | T=2.0, alpha=0.5. Sequential distill-then-quantize works. | +| 8 | **Dirichlet CTW-6 n-gram** | BPC | **-5.76%** | Local | Properly normalized Bayesian n-gram. Order 6 > 4 > 3 > 2. | +| 9 | **Entropy-adaptive mixing** | BPC | **-2.57%** | Local | Sigmoid gating: high neural entropy → lean on n-grams | +| 10 | **N-gram Refiner** | CE | **-2.6%** | Local | +2.3% overhead. Dedicated n-gram refinement head. | +| 11 | **Score-first TTT** | BPB | **-0.12%** | H100 | LoRA rank 8, SGD, 10 epochs, cosine LR. Small but real. 
| 12 | **Eval-time augmentation** | BPC | **varies** | Local+H100 | Multiple methods tested (details in eval section below) | + +### Proven Failures (Dead) + +| # | Technique | Result | Why It Failed | +|---|-----------|--------|---------------| +| 13 | **SSM S4D-Lin Hybrid** | +2.7% BPB (H100) | Scale deception. Attention quality > SSM at dim=512. | +| 14 | **JEPA-LM** | -0.24% real text | Synthetic Markov success (-19.5%) didn't transfer. +39.8% overhead. | +| 15 | **Mixture of Softmaxes (MoS)** | +1.7% | Output rank bottleneck not binding at vocab=1024. Model ignores aux heads. | +| 16 | **Monarch Matrices** | Dense beats all | 8-12x compression per MLP too aggressive. Fragile across configs. | +| 17 | **DenseFormer (DWA)** | +0% | Weights stay at init (96.6% on previous layer). Needs 48+ layers. | +| 18 | **Complementary Training** | +2.6% worse | Mean bigram entropy 9.21/10 at vocab=1024. No easy/hard token separation. | +| 19 | **Residual Lambdas** | +0.28% | Model learns residual weighting implicitly. Extra LR complexity = zero benefit. | +| 20 | **QK-Norm** | +3.5 to +4.5% | Needs longer training. At 1500 steps, actively harmful. | +| 21 | **Softcapping** | +2.1% | Unnecessary at this scale. Doesn't help quantization. | +| 22 | **LN Scaling** | +11.4% | Catastrophic. RMSNorm alone is strictly better. | +| 23 | **Fixed-Share Mixer** | +0% | Ties entropy-adaptive. Global weights can't beat per-token gating. | +| 24 | **PAQ Logistic Mixing** | BPC=19 (broken) | **Fundamentally broken for multi-class.** Sigmoid(0)=0.5 for all tokens. Only works for binary prediction. Important negative result. | +| 25 | **Product Quantization** | +292% (random) | K-means with 256 centroids per group can't preserve weight structure. | +| 26 | **Local Linear Prediction** | +9.84% | Local weighted regression from nearest neighbors worse than global LM head. Noise dominates with few neighbors.
| + +### Inconclusive (Scale-Dependent) + +| # | Technique | Local Result | Status | +|---|-----------|-------------|--------| +| 27 | **In-context N-gram (cumsum)** | gate=0.18 (barely used) | Model can't learn complex signal in 1500 steps. Needs 50K+. | +| 28 | **4096 Vocabulary** | +2.2% (bad tokenizer) | Our byte-level tokenizer doesn't simulate real BPE. Needs real tokenizer. | + +--- + +## Deep Dives + +### QAT as Regularizer (Experiments 5-6) + +This was unexpected. Training with simulated int5 quantization (STE gradients) from step 0 produces a model that's **better than float32 training:** + +| Method | Best CE | vs Float32 | +|--------|---------|-----------| +| Float32 baseline | 1.3237 | — | +| Post-hoc int5 (simulate GPTQ) | 1.3447 | +1.59% (degradation) | +| QAT uniform int5 | 1.3192 | **-0.34%** (better!) | +| QAT NF5 centroids | 1.3177 | **-0.45%** | +| QAT NF5 + learned centroids | 1.3150 | **-0.66%** | + +**The quantization noise acts as implicit regularization**, similar to dropout or weight noise. NormalFloat-5 centroids (Gaussian-optimal placement) outperform uniform grids. However, learned centroids don't actually learn to move — STE gradients are too weak to shift centroid positions at this scale. + +**Practical takeaway:** If using GPTQ post-hoc, you're leaving 1-2% on the table. QAT eliminates the degradation AND improves over float32. + +### Knowledge Distillation (Experiment 7) + +Train a larger teacher, distill to smaller student: + +| Method | CE | vs Direct Training | +|--------|-----|-------------------| +| Direct small (4L 192d) | 1.3194 | — | +| Teacher (6L 256d) | 1.3158 | -0.27% | +| Distilled small (T=2.0, alpha=0.5) | 1.3094 | **-0.75%** | +| Distilled + QAT int4 | 1.3313 | +0.90% (doesn't stack!) | + +Distillation works but **does NOT stack with QAT** when done simultaneously. The recommended approach: distill in FP32 first, then quantize post-hoc. 
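The distillation recipe above (T=2.0, alpha=0.5) corresponds to the standard Hinton-style KD objective. A minimal sketch of the loss, assuming per-token `student_logits`/`teacher_logits` tensors; this is an illustration, not code lifted from `exp_local_distill_qat.py`:

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, T=2.0, alpha=0.5):
    """alpha * soft-label KL at temperature T + (1 - alpha) * hard-label CE."""
    ce = F.cross_entropy(student_logits, targets)
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.log_softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
        log_target=True,
    ) * (T * T)  # T^2 keeps the soft-loss gradient scale comparable to CE
    return alpha * kl + (1.0 - alpha) * ce

# Sanity check: when student == teacher, the KL term vanishes and only CE remains.
s = torch.randn(4, 1024)
y = torch.randint(0, 1024, (4,))
loss = kd_loss(s, s.clone(), y)
```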
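The QAT-as-regularizer setup from the previous section can be sketched as fake quantization with straight-through gradients. A minimal illustration of the uniform-int5 variant on a single tensor, my sketch rather than the submission's `exp_qat_learned_centroids.py`:

```python
import torch

def fake_quant_int5(w: torch.Tensor) -> torch.Tensor:
    """Simulated symmetric uniform int5 quantization with a straight-through
    estimator (STE): the forward pass sees quantized weights, the backward
    pass treats round/clamp as identity, so gradients reach the float
    master weights."""
    qmax = 15.0  # symmetric int5 grid: integer levels in [-15, 15]
    scale = w.detach().abs().max().clamp(min=1e-8) / qmax
    w_q = torch.clamp(torch.round(w / scale), -qmax, qmax) * scale
    return w + (w_q - w).detach()  # STE: forward = w_q, backward = identity

w = torch.randn(64, 64, requires_grad=True)
fake_quant_int5(w).pow(2).sum().backward()
# w.grad exists despite the non-differentiable rounding in the forward pass.
```

Swapping the uniform grid for NormalFloat-5 centroids changes only the lookup used to snap `w / scale`; the STE trick is identical.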
+ +**Competition implication:** Nobody in the competition does in-run distillation. Training a 50M teacher for 70% of the budget, then distilling to 27M for 30%, should produce a better model than direct 27M training. Untested at 8xH100 scale. + +### Dirichlet CTW N-gram: The Right Way to Do Eval-Time N-grams (Experiment 8) + +93% of n-gram cache submissions were closed for invalid normalization. Our Bayesian approach is properly normalized by construction: + +**Method:** Dirichlet-multinomial posterior predictive distribution with recursive Bayesian updates. At each order k, the prior is the (k-1)-order posterior. The concentration parameter scales as 0.5k (order-dependent smoothing). + +| Order | BPC | vs Neural | +|-------|-----|-----------| +| Unigram | 1.8942 | — (no change, too uniform) | +| Bigram | 1.8719 | -1.18% | +| Trigram | 1.8214 | -3.84% | +| 4-gram | 1.7988 | -5.04% | +| **6-gram** | **1.7851** | **-5.76%** | + +**Why this works when hash-based caches don't:** Hash-based caches only score P(correct_token) and ignore the other 1023 tokens. The resulting "distribution" sums to ~410, not 1.0. Our Bayesian approach maintains a full posterior over ALL tokens at ALL times. + +**Limitation:** Per-token sequential update is slow. At 62M eval tokens, estimated 200-300s on CPU. Needs C/CUDA vectorization for competition tractability. + +### PAQ Logistic Mixing: Why It's Fundamentally Broken for Multi-Class (Experiment 24) + +PAQ-style logistic mixing is the gold standard in data compression. We attempted to apply it to mix n-gram experts with the neural model. **Result: BPC=19 (vs 1.9 expected).** A catastrophic 10x blowup. + +**Root cause:** PAQ's logistic mixing works on BINARY predictions (probability of each bit). 
With 1024-class prediction: +- Logistic mixing operates on `log(p/(1-p))` — the log-odds +- At initialization (weights=0), sigmoid(0)=0.5 for ALL tokens +- The mixture assigns 0.5 probability to every token; renormalized over 1024 classes, that is the uniform distribution → BPC ≈ log2(1024) = 10 +- Weight updates try to fix this but diverge because the gradient landscape is pathological for 1024 classes + +**This is a fundamental incompatibility**, not an implementation bug. PAQ works because it predicts ONE BIT at a time. Multi-class logistic mixing requires a different formulation (softmax-space mixing, which we tested as "Fixed-Share" — ties entropy-adaptive). + +**Implication:** Anyone trying to port PAQ-style compression to neural LM mixing will hit this wall. The fix is either binary decomposition (predict 10 bits sequentially) or linear mixing in probability space (which is what works). + +### Score-First TTT at Scale (Experiment 11) + +Tested on single H100 with competition's `merged_leader` architecture: + +| Config | BPB | vs Baseline | +|--------|-----|-------------| +| No TTT (2000 steps) | 1.4859 | — | +| TTT 5ep SGD | 1.4850 | -0.06% | +| **TTT 10ep cosine** | **1.4841** | **-0.12%** | +| TTT 20ep | 1.4858 | -0.01% | +| TTT AdamW | 1.4842 | -0.11% | + +**Key finding:** TTT works but gains are modest at 2000 training steps. 10 epochs with cosine LR schedule is optimal. 20 epochs overfits to each chunk. At full scale (7000+ steps, stronger model), gains may be larger. + +--- + +## Eval-Time Techniques Explored + +We systematically tested 7 eval-time augmentation methods using hidden states from the neural model. All methods follow the score-first protocol: predict → score → update.
| Method | BPC | vs Neural | Artifact Cost | +|--------|-----|-----------|---------------| +| Neural only | 1.8942 | — | — | +| Dirichlet CTW-6 | 1.7851 | -5.76% | Zero | +| Linear Attention Memory | 1.8456 | -2.57% (ties entropy-adaptive, adds nothing over it) | Zero | +| Delta Rule (DeltaNet) | overflow | Failed | Zero | +| EMA Hidden State | 1.8456 | -2.57% (ties entropy-adaptive) | Zero | +| JEPA Surprise-Adaptive | 1.8456 | -2.57% (ties entropy-adaptive) | ~200KB | +| Local Linear Prediction | 2.0809 | +9.84% | Zero | + +We have additional eval-time techniques under development showing promising results, pending validation at scale for a record-track submission. + +--- + +## Quantization Deep Dive + +### GPTQ int5 vs int6 (H100) + +| Quantization | BPB | Degradation | +|-------------|-----|-------------| +| Float32 (pre-export) | 1.3774 | — | +| GPTQ int6 | (export broken on 1GPU) | — | +| GPTQ int5 | 1.5247 | +2.61% | + +Int5 loses 2.61% — too aggressive for competition. Int6 is the standard for a reason. + +### The INT6 Scale Clamp Bug + +Community reports (and the matotezitanka competition analysis) indicate 93% of submissions use a minimum scale clamp of 0.032 in int6 quantization. This wastes resolution on small-magnitude weight rows. We tested lowering the clamp but didn't find significant improvement on our model — likely because our model was only 2000 steps (undertrained weights have different distributions than 7000-step SOTA weights). + +--- + +## Budget Analysis + +| Resource | Cost | Experiments | +|----------|------|-------------| +| Mac Mini M4 (owned) | $0 | 25 local experiments | +| RunPod H100 single (~4h) | ~$12 | 8 GPU experiments (TTT sweep, quantization, eval) | +| **Total** | **~$12** | **28 experiments** | + +For comparison, the depth recurrence PR (#363) reports "4 days, ~35 runs." Our 28 experiments in 5 days at $12 total demonstrates that significant research contributions are possible on a constrained budget, especially when: +1. Local experiments are used to filter dead ideas before GPU spend +2.
GPU time is focused on scale-dependent validations only +3. Results are cached aggressively (model checkpoints, precomputed probabilities) + +--- + +## What We'd Do With More Compute + +We applied for the development grant on March 27 and are still awaiting approval. With GPU access, the highest-priority experiments are: + +1. **Full-scale validation of eval-time augmentation** — our redacted technique needs 8xH100 validation +2. **In-run distillation** — train 50M teacher → distill to 27M student within 600s +3. **SSM interleaved placement** — reviewer suggestion from PR #1013, untested at scale +4. **Multi-resolution training** — seq_len=256 phase then seq_len=2048 fine-tune +5. **Self-distillation (born-again networks)** — train → freeze → retrain with soft targets + +--- + +## Key Takeaways for Other Competitors + +1. **Test at scale or don't trust the result.** Local dim=192 experiments are useful for filtering dead ideas but should NEVER be used to declare a technique "works." + +2. **QAT beats post-hoc quantization.** If you're using GPTQ as an afterthought, you're losing 1-2%. Train with simulated quantization from step 0. + +3. **PAQ logistic mixing is broken for multi-class.** Don't try to port PAQ to neural LM mixing. Use linear mixing in probability space. + +4. **Bayesian n-grams are properly normalized by construction.** If your n-gram cache sums to anything other than 1.0, your BPB numbers are meaningless. + +5. **Knowledge distillation is free improvement.** Nobody in the competition does in-run distillation. Train bigger, distill smaller, quantize. + +6. **The eval budget is massively underutilized.** 600 seconds on 8xH100 is enormous. Most submissions spend <30s on eval. There's a lot of room for eval-time techniques. + +7. **TTT at 10 epochs with cosine LR** is the sweet spot. By 20 epochs, TTT already overfits to each chunk. + +--- + +## Code and Reproduction + +All experiment scripts are included in this submission.
Each is self-contained with data loading, model definition, training, and evaluation. + +| File | Description | +|------|-------------| +| `exp_qat_learned_centroids.py` | QAT with NF5 centroids — quantization experiments | +| `exp_local_distill_qat.py` | Knowledge distillation + QAT int4 | +| `exp_fixed_share_logistic.py` | Fixed-Share mixer + logistic mixing failure analysis | +| `exp_complementary_training.py` | Complementary training (dead technique, documented) | +| `results_distill_qat.json` | Raw results for distillation experiments | +| `results_vocab4096_mlp4x.json` | Raw results for vocab/MLP sweep | + +--- + +*Self-funded research on Mac Mini M4 + RunPod single H100. Total GPU spend: ~$12.* + +*Author: Himanshu Dongre (@himanshudongre) — also author of PR #1013 (SSM Hybrid) and PR #1012 (JEPA-LM).* diff --git a/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_complementary_training.py b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_complementary_training.py new file mode 100644 index 0000000000..05e6f84478 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_complementary_training.py @@ -0,0 +1,693 @@ +""" +Complementary Training: Divide & Conquer Between Train and Eval +================================================================ + +KEY INSIGHT (from Himanshu): + Training budget: 10 min on 8×H100 → 16MB artifact + Eval budget: 10 min on 8×H100 → online n-gram caches + + If the neural model wastes gradients learning what n-grams can learn + for free at eval time, we're throwing away training compute. + + SOLUTION: Train neural model to focus on what n-grams CAN'T learn. + Then at eval time, n-grams handle the easy patterns, neural handles hard ones. + Effective learning time = 20 minutes instead of 10. 
+ +METHODS TESTED: + A: Standard training (baseline) + B: Entropy-weighted loss — upweight tokens where n-gram is uncertain + C: Residual training — train neural on (target - ngram_prediction) effectively + D: KL divergence penalty — push neural AWAY from n-gram where n-gram is good + E: Hard token mining — only train on tokens where n-gram entropy > threshold + F: Curriculum — start standard, gradually increase complementary weight + +All methods combined with entropy-adaptive CTW-6 at eval time. +""" +import sys +sys.stdout.reconfigure(line_buffering=True) + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +import time +import json +import os +import urllib.request +from collections import defaultdict + +VOCAB_SIZE = 1024 +SEQ_LEN = 512 +DIM = 192 +N_HEADS = 6 +N_LAYERS = 6 +MLP_EXP = 2.0 +TRAIN_STEPS = 1500 +BATCH_SIZE = 32 +LR = 3e-4 +DEVICE = "mps" if torch.backends.mps.is_available() else "cpu" + +print(f"Device: {DEVICE}") +print(f"Complementary Training: Divide & Conquer") +print() + +# ============================================================ +# Data Loading (reuse from other experiments) +# ============================================================ +def download_text_corpus(): + cache_path = "/Users/himanshudongre/Documents/GitHub/parameter_golf/text_corpus.txt" + if os.path.exists(cache_path): + with open(cache_path, 'r', encoding='utf-8', errors='ignore') as f: + return f.read() + urls = [ + "https://www.gutenberg.org/cache/epub/1342/pg1342.txt", + "https://www.gutenberg.org/cache/epub/11/pg11.txt", + "https://www.gutenberg.org/cache/epub/84/pg84.txt", + "https://www.gutenberg.org/cache/epub/1661/pg1661.txt", + ] + all_text = [] + for url in urls: + try: + req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) + response = urllib.request.urlopen(req, timeout=30) + text = response.read().decode('utf-8', errors='ignore') + all_text.append(text) + except Exception as e: + print(f" 
Failed to download {url}: {e}") + combined = "\n\n".join(all_text) + with open(cache_path, 'w', encoding='utf-8') as f: + f.write(combined) + return combined + +def tokenize_text(text, vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN): + raw_bytes = text.encode('utf-8') + tokens = [b % vocab_size for b in raw_bytes] + n_seq = len(tokens) // (seq_len + 1) + tokens = tokens[:n_seq * (seq_len + 1)] + return torch.tensor(tokens, dtype=torch.long).view(n_seq, seq_len + 1) + +# ============================================================ +# Model (same RoPE 16 architecture) +# ============================================================ +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + def forward(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight + +class GEGLU_MLP(nn.Module): + def __init__(self, dim, expansion=2.0): + super().__init__() + hidden = int(dim * expansion) + self.w1 = nn.Linear(dim, hidden, bias=False) + self.w2 = nn.Linear(dim, hidden, bias=False) + self.out = nn.Linear(hidden, dim, bias=False) + def forward(self, x): + return self.out(F.gelu(self.w1(x)) * self.w2(x)) + +class FullMHA(nn.Module): + def __init__(self, dim, n_heads, rope_dims=16): + super().__init__() + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.qkv = nn.Linear(dim, 3 * dim, bias=False) + self.out = nn.Linear(dim, dim, bias=False) + self.rope_dims = rope_dims + freqs = 1.0 / (10000.0 ** (torch.arange(0, rope_dims, 2).float() / rope_dims)) + t = torch.arange(SEQ_LEN).float() + freqs = torch.outer(t, freqs) + self.register_buffer('cos_cache', freqs.cos().unsqueeze(0).unsqueeze(0), persistent=False) + self.register_buffer('sin_cache', freqs.sin().unsqueeze(0).unsqueeze(0), persistent=False) + + def _apply_rope(self, x): + rd = self.rope_dims + x_rope, x_pass = x[..., :rd], x[..., rd:] + x1, x2 = x_rope[..., :rd//2], x_rope[..., rd//2:] + cos = 
self.cos_cache[:, :, :x.size(2), :] + sin = self.sin_cache[:, :, :x.size(2), :] + x_rope_out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + return torch.cat([x_rope_out, x_pass], dim=-1) + + def forward(self, x): + B, T, C = x.shape + qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim) + q, k, v = qkv.unbind(2) + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + q, k = self._apply_rope(q), self._apply_rope(k) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.out(y.transpose(1, 2).reshape(B, T, C)) + +class TransformerBlock(nn.Module): + def __init__(self, dim, n_heads, mlp_expansion=2.0): + super().__init__() + self.ln1 = RMSNorm(dim) + self.attn = FullMHA(dim, n_heads) + self.ln2 = RMSNorm(dim) + self.mlp = GEGLU_MLP(dim, expansion=mlp_expansion) + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class Transformer(nn.Module): + def __init__(self): + super().__init__() + self.tok_emb = nn.Embedding(VOCAB_SIZE, DIM) + self.blocks = nn.ModuleList([ + TransformerBlock(DIM, N_HEADS, MLP_EXP) for _ in range(N_LAYERS) + ]) + self.ln_f = RMSNorm(DIM) + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + + def forward(self, idx): + x = self.tok_emb(idx) + for block in self.blocks: + x = block(x) + return F.linear(self.ln_f(x), self.tok_emb.weight) + +# ============================================================ +# N-gram Statistics from Training Data +# ============================================================ +class TrainingNgramStats: + """Pre-compute n-gram statistics from training data. 
+ Used to create per-token weights for complementary training.""" + + def __init__(self, train_sequences, vocab_size=VOCAB_SIZE): + self.V = vocab_size + print(" Building training n-gram statistics...", flush=True) + t0 = time.time() + + # Bigram counts from training data + self.bigram_counts = np.zeros((vocab_size, vocab_size), dtype=np.float32) + self.bigram_totals = np.zeros(vocab_size, dtype=np.float32) + + for i in range(len(train_sequences)): + seq = train_sequences[i].numpy() + for t in range(len(seq) - 1): + prev, curr = seq[t], seq[t + 1] + self.bigram_counts[prev, curr] += 1 + self.bigram_totals[prev] += 1 + + # Pre-compute bigram entropy for each context token + # H(next | prev) = -Σ p(next|prev) * log2(p(next|prev)) + self.bigram_entropy = np.zeros(vocab_size, dtype=np.float32) + for prev in range(vocab_size): + total = self.bigram_totals[prev] + if total > 0: + probs = self.bigram_counts[prev] / total + probs = probs[probs > 0] # filter zeros + self.bigram_entropy[prev] = -np.sum(probs * np.log2(probs)) + else: + self.bigram_entropy[prev] = np.log2(vocab_size) # max entropy + + # Trigram counts (sparse, for contexts that appear enough) + self.trigram_entropy = {} # (prev2, prev1) -> entropy + trigram_counts = defaultdict(lambda: defaultdict(int)) + trigram_totals = defaultdict(int) + + for i in range(len(train_sequences)): + seq = train_sequences[i].numpy() + for t in range(1, len(seq) - 1): + ctx = (int(seq[t-1]), int(seq[t])) + trigram_counts[ctx][int(seq[t+1])] += 1 + trigram_totals[ctx] += 1 + + for ctx, total in trigram_totals.items(): + if total >= 5: # only compute for frequent contexts + counts = np.array(list(trigram_counts[ctx].values()), dtype=np.float32) + probs = counts / total + self.trigram_entropy[ctx] = -np.sum(probs * np.log2(probs)) + + print(f" N-gram stats built in {time.time()-t0:.1f}s", flush=True) + print(f" Bigram entropy range: [{self.bigram_entropy.min():.2f}, {self.bigram_entropy.max():.2f}]") + print(f" Mean bigram entropy: 
{self.bigram_entropy.mean():.2f}") + print(f" Trigram contexts with stats: {len(self.trigram_entropy)}") + + def get_token_weights(self, sequences, method='entropy', **kwargs): + """Compute per-token training weights based on n-gram difficulty. + + Args: + sequences: (B, T+1) tensor of token sequences + method: weight computation method + + Returns: + weights: (B, T) tensor of per-token loss weights + """ + B, L = sequences.shape + T = L - 1 + weights = torch.ones(B, T, dtype=torch.float32) + + if method == 'entropy': + # Weight = bigram_entropy(prev_token) / max_entropy + # High entropy = n-gram uncertain = neural should focus here + max_H = np.log2(self.V) + for b in range(B): + for t in range(T): + prev = sequences[b, t].item() + H = self.bigram_entropy[prev] + weights[b, t] = H / max_H # normalize to [0, 1] + + elif method == 'entropy_trigram': + # Use trigram entropy when available, fall back to bigram + max_H = np.log2(self.V) + for b in range(B): + for t in range(T): + prev = sequences[b, t].item() + if t > 0: + prev2 = sequences[b, t-1].item() + ctx = (prev2, prev) + H = self.trigram_entropy.get(ctx, self.bigram_entropy[prev]) + else: + H = self.bigram_entropy[prev] + weights[b, t] = H / max_H + + elif method == 'hard_only': + # Binary: train only on tokens where n-gram entropy > threshold + threshold = kwargs.get('threshold', 5.0) # bits + for b in range(B): + for t in range(T): + prev = sequences[b, t].item() + H = self.bigram_entropy[prev] + weights[b, t] = 1.0 if H > threshold else 0.0 + + elif method == 'inverse_confidence': + # Weight = 1 - max_bigram_prob(prev) + # If n-gram is very confident (one dominant next token), downweight + for b in range(B): + for t in range(T): + prev = sequences[b, t].item() + total = self.bigram_totals[prev] + if total > 0: + max_prob = self.bigram_counts[prev].max() / total + weights[b, t] = 1.0 - max_prob + else: + weights[b, t] = 1.0 + + elif method == 'softmax_temp': + # Soft version: w = sigmoid(scale * (H - 
threshold)) + threshold = kwargs.get('threshold', 5.0) + scale = kwargs.get('scale', 2.0) + for b in range(B): + for t in range(T): + prev = sequences[b, t].item() + H = self.bigram_entropy[prev] + weights[b, t] = 1.0 / (1.0 + math.exp(-scale * (H - threshold))) + + # Ensure mean weight ≈ 1 so effective learning rate is preserved + w_mean = weights.mean() + if w_mean > 0: + weights = weights / w_mean + + return weights + +# ============================================================ +# Pre-compute weights for speed (vectorized) +# ============================================================ +def precompute_all_weights(train_seq, ngram_stats, method, **kwargs): + """Pre-compute all token weights to avoid per-step overhead.""" + print(f" Pre-computing weights (method={method})...", flush=True) + t0 = time.time() + + B, L = train_seq.shape + T = L - 1 + + if method == 'entropy': + max_H = np.log2(VOCAB_SIZE) + # Vectorized: map prev tokens to their bigram entropies + prev_tokens = train_seq[:, :-1].numpy() # (B, T) + entropies = ngram_stats.bigram_entropy[prev_tokens] # fancy indexing + weights = torch.tensor(entropies / max_H, dtype=torch.float32) + + elif method == 'inverse_confidence': + prev_tokens = train_seq[:, :-1].numpy() + # For each prev token, get max bigram probability + max_probs = np.zeros_like(prev_tokens, dtype=np.float32) + for prev in range(VOCAB_SIZE): + total = ngram_stats.bigram_totals[prev] + if total > 0: + max_probs[prev_tokens == prev] = ngram_stats.bigram_counts[prev].max() / total + else: + max_probs[prev_tokens == prev] = 0.0 + weights = torch.tensor(1.0 - max_probs, dtype=torch.float32) + + elif method == 'softmax_temp': + max_H = np.log2(VOCAB_SIZE) + threshold = kwargs.get('threshold', 5.0) + scale = kwargs.get('scale', 2.0) + prev_tokens = train_seq[:, :-1].numpy() + entropies = ngram_stats.bigram_entropy[prev_tokens] + sigmoid_weights = 1.0 / (1.0 + np.exp(-scale * (entropies - threshold))) + weights = torch.tensor(sigmoid_weights, 
dtype=torch.float32) + + elif method == 'hard_only': + threshold = kwargs.get('threshold', 5.0) + prev_tokens = train_seq[:, :-1].numpy() + entropies = ngram_stats.bigram_entropy[prev_tokens] + weights = torch.tensor((entropies > threshold).astype(np.float32)) + + elif method == 'standard': + weights = torch.ones(B, T, dtype=torch.float32) + + else: + weights = torch.ones(B, T, dtype=torch.float32) + + # Normalize so mean = 1 + w_mean = weights.mean() + if w_mean > 0: + weights = weights / w_mean + + # Stats + print(f" Weights computed in {time.time()-t0:.1f}s", flush=True) + print(f" Weight stats: min={weights.min():.3f}, max={weights.max():.3f}, " + f"mean={weights.mean():.3f}, std={weights.std():.3f}", flush=True) + frac_low = (weights < 0.5).float().mean().item() + frac_high = (weights > 1.5).float().mean().item() + print(f" Fraction w<0.5: {frac_low:.1%}, w>1.5: {frac_high:.1%}", flush=True) + + return weights + +# ============================================================ +# Training with per-token weights +# ============================================================ +def train_model_weighted(train_seq, eval_seq, all_weights, label=""): + """Train model with per-token loss weights.""" + model = Transformer().to(DEVICE) + n_params = sum(p.numel() for p in model.parameters()) + print(f" [{label}] Training: {n_params:,} params", flush=True) + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TRAIN_STEPS) + + t0 = time.time() + for step in range(TRAIN_STEPS + 1): + if step % 500 == 0: + model.eval() + with torch.no_grad(): + eb = eval_seq[:100].to(DEVICE) + logits = model(eb[:, :-1]) + ce = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), eb[:, 1:].reshape(-1)) + print(f" Step {step:4d} | CE: {ce:.4f} | {(time.time()-t0)/max(step,1)*1000:.0f}ms/step", flush=True) + model.train() + if step >= TRAIN_STEPS: + break + + bi = torch.randint(0, train_seq.size(0), 
(BATCH_SIZE,)) + batch = train_seq[bi].to(DEVICE) + weights = all_weights[bi].to(DEVICE) # (B, T) + + logits = model(batch[:, :-1]) # (B, T, V) + + # Per-token weighted cross-entropy + # Standard: F.cross_entropy averages over all tokens equally + # Complementary: weight each token by n-gram difficulty + per_token_ce = F.cross_entropy( + logits.reshape(-1, VOCAB_SIZE), + batch[:, 1:].reshape(-1), + reduction='none' + ).reshape(BATCH_SIZE, -1) # (B, T) + + loss = (per_token_ce * weights).mean() + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + scheduler.step() + + elapsed = time.time() - t0 + print(f" [{label}] Training done in {elapsed:.1f}s", flush=True) + return model + +# ============================================================ +# Eval: Entropy-adaptive CTW-6 (our best eval method) +# ============================================================ +class DirichletCTWExpert: + def __init__(self, vocab_size=VOCAB_SIZE, max_order=6, concentrations=None): + self.V = vocab_size + self.max_order = max_order + if concentrations is None: + self.concentrations = {k: 0.5 * k for k in range(1, max_order + 1)} + else: + self.concentrations = concentrations + self.unigram_counts = np.zeros(vocab_size, dtype=np.uint32) + self.unigram_total = 0 + self.bigram_counts = np.zeros((vocab_size, vocab_size), dtype=np.uint32) + self.bigram_totals = np.zeros(vocab_size, dtype=np.uint32) + self.higher_counts = {} + self.higher_totals = {} + for k in range(3, max_order + 1): + self.higher_counts[k] = defaultdict(lambda: defaultdict(int)) + self.higher_totals[k] = defaultdict(int) + self.history = [] + + def update(self, token): + self.unigram_counts[token] += 1 + self.unigram_total += 1 + if len(self.history) >= 1: + prev = self.history[-1] + self.bigram_counts[prev, token] += 1 + self.bigram_totals[prev] += 1 + for k in range(3, self.max_order + 1): + if len(self.history) >= k - 1: + ctx = 
tuple(self.history[-(k-1):]) + self.higher_counts[k][ctx][token] += 1 + self.higher_totals[k][ctx] += 1 + self.history.append(token) + + def get_distribution(self, context_tokens): + V = self.V + # Start with uniform + p = np.ones(V, dtype=np.float64) / V + + # Unigram + if self.unigram_total > 0: + c1 = self.concentrations.get(1, 0.5) + p = (self.unigram_counts.astype(np.float64) + c1 * p) / (self.unigram_total + c1) + + # Bigram + if len(context_tokens) >= 1: + prev = context_tokens[-1] + total = self.bigram_totals[prev] + if total > 0: + c2 = self.concentrations.get(2, 1.0) + p = (self.bigram_counts[prev].astype(np.float64) + c2 * p) / (total + c2) + + # Higher order + for k in range(3, min(self.max_order + 1, len(context_tokens) + 2)): + if len(context_tokens) >= k - 1: + ctx = tuple(context_tokens[-(k-1):]) + total = self.higher_totals[k].get(ctx, 0) + if total > 0: + ck = self.concentrations.get(k, 0.5 * k) + counts_dict = self.higher_counts[k][ctx] + counts = np.zeros(V, dtype=np.float64) + for tok, cnt in counts_dict.items(): + counts[tok] = cnt + p = (counts + ck * p) / (total + ck) + + return p + +def eval_with_entropy_ctw(model, eval_seq, label=""): + """Evaluate model with entropy-adaptive CTW-6 mixing.""" + model.eval() + + # Get neural probabilities + with torch.no_grad(): + eb = eval_seq[:100].to(DEVICE) + logits = model(eb[:, :-1]) + probs = F.softmax(logits, dim=-1).cpu().numpy() + sequences = eval_seq[:100].numpy() + + # Also compute neural-only BPC for comparison + neural_bits = 0.0 + scored = 0 + for i in range(len(sequences)): + for t in range(sequences.shape[1] - 1): + target = sequences[i, t + 1] + p = max(probs[i, t, target], 1e-30) + neural_bits += -math.log2(p) + scored += 1 + neural_bpc = neural_bits / scored + + # Now eval with entropy-adaptive CTW-6 + ctw = DirichletCTWExpert(max_order=6) + + total_bits = 0.0 + scored = 0 + + for i in range(len(sequences)): + context_tokens = [] + for t in range(sequences.shape[1] - 1): + target = 
sequences[i, t + 1] + + # Neural prediction + neural_p = probs[i, t].astype(np.float64) + neural_p = np.clip(neural_p, 1e-10, None) + neural_p = neural_p / neural_p.sum() + + # CTW prediction (context = tokens scored so far in this doc) + ctw_p = ctw.get_distribution(context_tokens) + + # Entropy-adaptive mixing + H = -np.sum(neural_p * np.log2(np.maximum(neural_p, 1e-30))) + alpha = 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (H - 4.0))) + + mixed = (1 - alpha) * neural_p + alpha * ctw_p + mixed = mixed / mixed.sum() + + p = max(mixed[target], 1e-30) + total_bits += -math.log2(p) + scored += 1 + + # Update CTW (AFTER scoring) and context + ctw.update(int(target)) + context_tokens.append(int(target)) + if len(context_tokens) > 20: + context_tokens = context_tokens[-20:] + + mixed_bpc = total_bits / scored + improvement = (mixed_bpc - neural_bpc) / neural_bpc * 100 + + print(f" [{label}] Neural BPC: {neural_bpc:.4f} | Mixed BPC: {mixed_bpc:.4f} | " + f"CTW helps: {improvement:.2f}%", flush=True) + + return neural_bpc, mixed_bpc + +# ============================================================ +# Main Experiment +# ============================================================ +if __name__ == "__main__": + print("=" * 70) + print("Loading data...", flush=True) + corpus = download_text_corpus() + all_sequences = tokenize_text(corpus) + n_train = int(len(all_sequences) * 0.9) + train_seq = all_sequences[:n_train] + eval_seq = all_sequences[n_train:] + print(f" Train: {train_seq.shape}, Eval: {eval_seq.shape}") + + # Build n-gram statistics from training data + print("\n" + "=" * 70) + print("Building N-gram Statistics from Training Data") + print("=" * 70) + ngram_stats = TrainingNgramStats(train_seq) + + results = {} + + # ============================================================ + # A: Standard Training (baseline) + # ============================================================ + print("\n" + "=" * 70) + print("A: Standard Training (baseline)") + print("=" * 70) + 
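The entropy-adaptive mixing coefficient used in `eval_with_entropy_ctw` above can be checked in isolation. A minimal sketch — the constants 0.05, 0.55, slope 2.0 and midpoint 4.0 bits are taken from that function; the two test distributions are synthetic, not from any experiment:

```python
# Illustrative sketch only — not part of the experiment run above.
import math
import numpy as np

def adaptive_alpha(neural_p):
    """CTW share rises smoothly from ~0.05 to ~0.60 with neural entropy H (bits)."""
    H = -np.sum(neural_p * np.log2(np.maximum(neural_p, 1e-30)))
    return 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (H - 4.0)))

V = 1024
confident = np.full(V, 1e-6)
confident[0] = 1.0 - 1e-6 * (V - 1)      # near-deterministic: H ~ 0.02 bits
uncertain = np.full(V, 1.0 / V)          # uniform: H = 10 bits

assert adaptive_alpha(confident) < 0.06  # stays near the 0.05 floor
assert adaptive_alpha(uncertain) > 0.59  # approaches the 0.05 + 0.55 ceiling
```

The design intent: when the neural model is confident (low entropy) its prediction keeps almost all the mass, and only as its entropy grows does up to ~60% shift to the n-gram expert.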
weights_standard = precompute_all_weights(train_seq, ngram_stats, 'standard') + model_a = train_model_weighted(train_seq, eval_seq, weights_standard, label="Standard") + neural_a, mixed_a = eval_with_entropy_ctw(model_a, eval_seq, label="Standard") + results["standard"] = {"neural": neural_a, "mixed": mixed_a} + del model_a + torch.mps.empty_cache() if DEVICE == "mps" else None + + # ============================================================ + # B: Entropy-Weighted Training + # ============================================================ + print("\n" + "=" * 70) + print("B: Entropy-Weighted Training (upweight hard tokens)") + print("=" * 70) + weights_entropy = precompute_all_weights(train_seq, ngram_stats, 'entropy') + model_b = train_model_weighted(train_seq, eval_seq, weights_entropy, label="Entropy") + neural_b, mixed_b = eval_with_entropy_ctw(model_b, eval_seq, label="Entropy") + results["entropy_weighted"] = {"neural": neural_b, "mixed": mixed_b} + del model_b + torch.mps.empty_cache() if DEVICE == "mps" else None + + # ============================================================ + # C: Inverse Confidence Training + # ============================================================ + print("\n" + "=" * 70) + print("C: Inverse Confidence (downweight where n-gram is confident)") + print("=" * 70) + weights_inv = precompute_all_weights(train_seq, ngram_stats, 'inverse_confidence') + model_c = train_model_weighted(train_seq, eval_seq, weights_inv, label="InvConf") + neural_c, mixed_c = eval_with_entropy_ctw(model_c, eval_seq, label="InvConf") + results["inverse_confidence"] = {"neural": neural_c, "mixed": mixed_c} + del model_c + torch.mps.empty_cache() if DEVICE == "mps" else None + + # ============================================================ + # D: Sigmoid Threshold (soft version of hard mining) + # ============================================================ + print("\n" + "=" * 70) + print("D: Sigmoid Threshold (smooth transition at entropy=5.0)") + 
print("=" * 70) + weights_sig = precompute_all_weights(train_seq, ngram_stats, 'softmax_temp', + threshold=5.0, scale=2.0) + model_d = train_model_weighted(train_seq, eval_seq, weights_sig, label="Sigmoid") + neural_d, mixed_d = eval_with_entropy_ctw(model_d, eval_seq, label="Sigmoid") + results["sigmoid_threshold"] = {"neural": neural_d, "mixed": mixed_d} + del model_d + torch.mps.empty_cache() if DEVICE == "mps" else None + + # ============================================================ + # E: Hard Token Mining (only train on hard tokens) + # ============================================================ + print("\n" + "=" * 70) + print("E: Hard Token Mining (only tokens with bigram entropy > 5.0)") + print("=" * 70) + weights_hard = precompute_all_weights(train_seq, ngram_stats, 'hard_only', threshold=5.0) + model_e = train_model_weighted(train_seq, eval_seq, weights_hard, label="HardOnly") + neural_e, mixed_e = eval_with_entropy_ctw(model_e, eval_seq, label="HardOnly") + results["hard_only"] = {"neural": neural_e, "mixed": mixed_e} + del model_e + torch.mps.empty_cache() if DEVICE == "mps" else None + + # ============================================================ + # F: Lower Threshold Sigmoid + # ============================================================ + print("\n" + "=" * 70) + print("F: Sigmoid Threshold (threshold=3.0, gentler)") + print("=" * 70) + weights_sig2 = precompute_all_weights(train_seq, ngram_stats, 'softmax_temp', + threshold=3.0, scale=1.5) + model_f = train_model_weighted(train_seq, eval_seq, weights_sig2, label="Sigmoid-3.0") + neural_f, mixed_f = eval_with_entropy_ctw(model_f, eval_seq, label="Sigmoid-3.0") + results["sigmoid_3.0"] = {"neural": neural_f, "mixed": mixed_f} + del model_f + torch.mps.empty_cache() if DEVICE == "mps" else None + + # ============================================================ + # Summary + # ============================================================ + print("\n" + "=" * 70) + print("SUMMARY: Complementary 
Training Results") + print("=" * 70) + + baseline_neural = results["standard"]["neural"] + baseline_mixed = results["standard"]["mixed"] + + print(f"\n{'Method':<30} {'Neural BPC':>12} {'Mixed BPC':>12} {'vs Std Neural':>14} {'vs Std Mixed':>14}") + print("-" * 86) + for name, r in results.items(): + vs_neural = (r["neural"] - baseline_neural) / baseline_neural * 100 + vs_mixed = (r["mixed"] - baseline_mixed) / baseline_mixed * 100 + tag = " *** BEST ***" if r["mixed"] < baseline_mixed - 0.001 else "" + print(f" {name:<28} {r['neural']:>12.4f} {r['mixed']:>12.4f} {vs_neural:>+13.2f}% {vs_mixed:>+13.2f}%{tag}") + + print(f"\nKEY QUESTION: Does complementary training make the neural model") + print(f" better when COMBINED with eval-time n-grams?") + print(f" Standard neural+CTW: {baseline_mixed:.4f}") + best_method = min(results.items(), key=lambda x: x[1]["mixed"]) + print(f" Best complementary: {best_method[0]} = {best_method[1]['mixed']:.4f}") + delta = (best_method[1]["mixed"] - baseline_mixed) / baseline_mixed * 100 + print(f" Delta: {delta:+.2f}%") + + # Save results + with open("complementary_training_results.json", "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to complementary_training_results.json") diff --git a/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_fixed_share_logistic.py b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_fixed_share_logistic.py new file mode 100644 index 0000000000..8395f486d2 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_fixed_share_logistic.py @@ -0,0 +1,880 @@ +""" +Fixed-Share Logistic Mixer + Dense Trigram Cache +=================================================== +The FOUNDATION of our compression system. + +KEY INSIGHT: Eval-time memory is UNLIMITED (640GB GPU HBM available). +The 16MB limit is on the ARTIFACT only. 
At eval time, we can build: + - Dense trigram table: 1024³ × 4B = 4GB (zero hash collisions!) + - Sparse higher-order tables: ~1-5GB + - Total: ~10GB of 640GB = 1.5% utilization + +LEGALITY PROTOCOL (every step annotated): + 1. PREDICT: P_mix computed from expert distributions + current mixer weights + Uses ONLY past information (already-scored tokens) ✓ + 2. SCORE: -log2(P_mix(x_true)) is the official score ✓ + 3. OBSERVE: see x_true ✓ + 4. UPDATE: mixer weights (Fixed-Share), n-gram caches (all orders) ✓ + All updates happen AFTER scoring, same as n-gram cache updates + +Two innovations vs our failed Bayesian attempt: + 1. Fixed-Share: prevents weight collapse by redistributing α fraction uniformly + 2. Logistic mixing: operates in log-odds space, amplifies confident experts + P = sigmoid(Σ w_i * logit(p_i)) [PAQ/cmix style] + +Experts: + 1. Neural model (RoPE 16, fixed after training) + 2. Dense bigram cache (online, exact, 4MB) + 3. Dense trigram cache (online, exact, 4GB — FITS in GPU memory!) + 4. Dirichlet CTW order 6 (online, sparse for orders 3-6) + 5. 
Dirichlet CTW order 12 (online, sparse for orders 7-12) +""" +import sys +sys.stdout.reconfigure(line_buffering=True) + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +import time +import json +import os +import urllib.request +from collections import defaultdict + +VOCAB_SIZE = 1024 +SEQ_LEN = 512 +DIM = 192 +N_HEADS = 6 +N_LAYERS = 6 +MLP_EXP = 2.0 +TRAIN_STEPS = 1500 +BATCH_SIZE = 32 +LR = 3e-4 +DEVICE = "mps" if torch.backends.mps.is_available() else "cpu" + +print(f"Device: {DEVICE}") +print(f"Fixed-Share Logistic Mixer + Dense Trigram Cache") +print() + +# ============================================================ +# Data Loading +# ============================================================ +def download_text_corpus(): + cache_path = "/Users/himanshudongre/Documents/GitHub/parameter_golf/text_corpus.txt" + if os.path.exists(cache_path): + with open(cache_path, 'r', encoding='utf-8', errors='ignore') as f: + return f.read() + urls = [ + "https://www.gutenberg.org/cache/epub/1342/pg1342.txt", + "https://www.gutenberg.org/cache/epub/11/pg11.txt", + "https://www.gutenberg.org/cache/epub/84/pg84.txt", + "https://www.gutenberg.org/cache/epub/1661/pg1661.txt", + ] + all_text = [] + for url in urls: + try: + req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) + response = urllib.request.urlopen(req, timeout=30) + text = response.read().decode('utf-8', errors='ignore') + start = text.find("*** START OF") + if start != -1: start = text.find("\n", start) + 1 + else: start = 0 + end = text.find("*** END OF") + if end == -1: end = len(text) + all_text.append(text[start:end]) + except Exception as e: + print(f" Failed: {e}", flush=True) + corpus = "\n\n".join(all_text) + with open(cache_path, 'w', encoding='utf-8') as f: + f.write(corpus) + return corpus + +def tokenize_text(text, vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN+1): + tokens = [] + text_bytes = text.encode('utf-8', errors='ignore') + 
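The Fixed-Share update that `FixedShareMixer.observe()` implements later in this file — w_k ← (1-α)·w_k·P_k(x_true)/Z + α/K — can be sketched standalone. The expert probabilities and α below are illustrative numbers, not values from any experiment:

```python
# Illustrative sketch only; mirrors the FixedShareMixer.observe() update rule.
import numpy as np

K, alpha = 3, 0.01
w = np.ones(K) / K                      # uniform prior over experts
p_true = np.array([0.50, 0.10, 0.01])   # each expert's probability on the observed token

w = w * p_true                          # multiplicative (Bayesian) update
w = w / w.sum()                         # normalize (the /Z step)
w = (1 - alpha) * w + alpha / K         # Fixed-Share: redistribute alpha uniformly

assert abs(w.sum() - 1.0) < 1e-9
assert w.min() >= alpha / K - 1e-12     # no expert can ever collapse to zero weight
assert w[0] > w[1] > w[2]               # better predictors still gain weight
```

The α/K floor is exactly what prevents the weight collapse that killed the earlier pure-Bayesian mixer: an expert that predicts badly for a stretch retains at least α/K weight and can recover.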
for i in range(len(text_bytes)): + byte_val = text_bytes[i] + if i + 1 < len(text_bytes): + bigram = (text_bytes[i] << 8) | text_bytes[i + 1] + bigram_slot = 256 + (bigram % (vocab_size - 256)) + if bigram % 3 == 0: + tokens.append(bigram_slot) + continue + tokens.append(byte_val) + n_seq = len(tokens) // seq_len + tokens = tokens[:n_seq * seq_len] + return torch.tensor(tokens, dtype=torch.long).reshape(n_seq, seq_len) + +# ============================================================ +# Model (RoPE 16) +# ============================================================ +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + def forward(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * self.scale + +class GEGLU_MLP(nn.Module): + def __init__(self, dim, expansion=2.0): + super().__init__() + hidden = int(dim * expansion) + self.gate = nn.Linear(dim, hidden, bias=False) + self.up = nn.Linear(dim, hidden, bias=False) + self.down = nn.Linear(hidden, dim, bias=False) + def forward(self, x): + return self.down(F.gelu(self.gate(x)) * self.up(x)) + +class FullMHA(nn.Module): + def __init__(self, dim, n_heads, rope_dims=16): + super().__init__() + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.qkv = nn.Linear(dim, 3 * dim, bias=False) + self.out = nn.Linear(dim, dim, bias=False) + self.rope_dims = rope_dims + freqs = 1.0 / (10000.0 ** (torch.arange(0, rope_dims, 2).float() / rope_dims)) + t = torch.arange(SEQ_LEN).float() + freqs = torch.outer(t, freqs) + self.register_buffer('cos_cache', freqs.cos().unsqueeze(0).unsqueeze(0), persistent=False) + self.register_buffer('sin_cache', freqs.sin().unsqueeze(0).unsqueeze(0), persistent=False) + + def _apply_rope(self, x): + rd = self.rope_dims + x_rope, x_pass = x[..., :rd], x[..., rd:] + x1, x2 = x_rope[..., :rd//2], x_rope[..., rd//2:] + cos = self.cos_cache[:, :, :x.size(2), :] + sin = self.sin_cache[:, :, :x.size(2), :] + 
x_rope_out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + return torch.cat([x_rope_out, x_pass], dim=-1) + + def forward(self, x): + B, T, C = x.shape + qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim) + q, k, v = qkv.unbind(2) + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + q, k = self._apply_rope(q), self._apply_rope(k) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.out(y.transpose(1, 2).reshape(B, T, C)) + +class TransformerBlock(nn.Module): + def __init__(self, dim, n_heads, mlp_expansion=2.0): + super().__init__() + self.ln1 = RMSNorm(dim) + self.attn = FullMHA(dim, n_heads) + self.ln2 = RMSNorm(dim) + self.mlp = GEGLU_MLP(dim, expansion=mlp_expansion) + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class Transformer(nn.Module): + def __init__(self): + super().__init__() + self.tok_emb = nn.Embedding(VOCAB_SIZE, DIM) + self.blocks = nn.ModuleList([ + TransformerBlock(DIM, N_HEADS, MLP_EXP) for _ in range(N_LAYERS) + ]) + self.ln_f = RMSNorm(DIM) + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + + def forward(self, idx): + x = self.tok_emb(idx) + for block in self.blocks: + x = block(x) + return F.linear(self.ln_f(x), self.tok_emb.weight) + +# ============================================================ +# EXPERT 1: Dense Bigram Cache (4MB, exact, O(1)) +# ============================================================ +class DenseBigramCache: + """Dense exact bigram table. 1024² × 4B = 4MB. + LEGALITY: Built causally from scored tokens. 
update() called AFTER scoring.""" + def __init__(self, vocab_size=VOCAB_SIZE, smoothing=0.5): + self.V = vocab_size + self.smoothing = smoothing + self.counts = np.zeros((vocab_size, vocab_size), dtype=np.uint32) + self.totals = np.zeros(vocab_size, dtype=np.uint32) + + def get_distribution(self, prev_token): + """BEFORE scoring: predict next token given previous.""" + counts = self.counts[prev_token].astype(np.float64) + total = self.totals[prev_token] + c = self.smoothing + uniform = np.ones(self.V, dtype=np.float64) / self.V + if total == 0: + return uniform + return (counts + c * uniform) / (total + c) + + def update(self, prev_token, curr_token): + """AFTER scoring: add observation.""" + self.counts[prev_token][curr_token] += 1 + self.totals[prev_token] += 1 + +# ============================================================ +# EXPERT 2: Dense Trigram Cache (4GB on GPU, exact, O(1)) +# ============================================================ +class DenseTrigramCache: + """Dense exact trigram table. 1024³ × 4B = 4GB. + On 8×H100 (640GB HBM), this is trivial. On Mac Mini, we use uint16 + to keep it at 2GB, or skip if not enough RAM. + + LEGALITY: Built causally from scored tokens. update() called AFTER scoring. 
+ """ + def __init__(self, vocab_size=VOCAB_SIZE, smoothing=1.0, use_compact=True): + self.V = vocab_size + self.smoothing = smoothing + self.use_compact = use_compact + + if use_compact: + # uint16 = 2 bytes per entry, max count 65535 + # 1024³ × 2B = 2GB — fits on Mac Mini with 16GB RAM (tight) + try: + self.counts = np.zeros((vocab_size, vocab_size, vocab_size), dtype=np.uint16) + self.totals = np.zeros((vocab_size, vocab_size), dtype=np.uint32) + self.available = True + print(f" Dense trigram cache allocated: {self.counts.nbytes / 1e9:.1f}GB", flush=True) + except MemoryError: + print(f" Dense trigram cache: not enough RAM, falling back to sparse", flush=True) + self.available = False + self.sparse_counts = defaultdict(lambda: defaultdict(int)) + self.sparse_totals = defaultdict(int) + else: + # Full uint32 = 4GB + self.counts = np.zeros((vocab_size, vocab_size, vocab_size), dtype=np.uint32) + self.totals = np.zeros((vocab_size, vocab_size), dtype=np.uint32) + self.available = True + + def get_distribution(self, prev2, prev1): + """BEFORE scoring: predict next given 2-token context.""" + if self.available: + counts = self.counts[prev2, prev1].astype(np.float64) + total = self.totals[prev2, prev1] + else: + key = (prev2, prev1) + total = self.sparse_totals[key] + counts = np.zeros(self.V, dtype=np.float64) + if key in self.sparse_counts: + for tok, cnt in self.sparse_counts[key].items(): + counts[tok] = cnt + + c = self.smoothing + uniform = np.ones(self.V, dtype=np.float64) / self.V + if total == 0: + return uniform + return (counts + c * uniform) / (total + c) + + def update(self, prev2, prev1, curr_token): + """AFTER scoring: add observation.""" + if self.available: + self.counts[prev2, prev1, curr_token] = min( + self.counts[prev2, prev1, curr_token] + 1, 65535 + ) if self.use_compact else self.counts[prev2, prev1, curr_token] + 1 + self.totals[prev2, prev1] += 1 + else: + key = (prev2, prev1) + self.sparse_counts[key][curr_token] += 1 + 
self.sparse_totals[key] += 1 + + # ============================================================ + # EXPERT 3: Dirichlet CTW (sparse higher-order, orders 3-K) + # ============================================================ + class DirichletCTWExpert: + """Dirichlet-smoothed CTW with backoff from order K down to uniform. + Orders 1-2 are supplied by the dense bigram/trigram caches; this class + keeps only sparse dicts for orders 3+ (plus unigram counts). + LEGALITY: Built from scored tokens. update() called AFTER scoring.""" + def __init__(self, vocab_size=VOCAB_SIZE, max_order=6, concentrations=None): + self.V = vocab_size + self.max_order = max_order + if concentrations is None: + self.concentrations = {k: 0.5 * k for k in range(1, max_order + 1)} + else: + self.concentrations = concentrations + self.unigram_counts = np.zeros(vocab_size, dtype=np.uint32) + self.unigram_total = 0 + self.higher_counts = {} + self.higher_totals = {} + for k in range(3, max_order + 1): + self.higher_counts[k] = defaultdict(lambda: defaultdict(int)) + self.higher_totals[k] = defaultdict(int) + self.history = [] + + def update(self, token): + """AFTER scoring.""" + self.unigram_counts[token] += 1 + self.unigram_total += 1 + for k in range(3, self.max_order + 1): + if len(self.history) >= k - 1: + ctx = tuple(self.history[-(k-1):]) + self.higher_counts[k][ctx][token] += 1 + self.higher_totals[k][ctx] += 1 + self.history.append(token) + + def get_distribution(self, bigram_dist, trigram_dist, context_tokens): + """BEFORE scoring. 
Takes bigram/trigram distributions from dense caches, + applies higher-order backoff on top.""" + # Base: the dense trigram distribution (which backs off to uniform). + # bigram_dist is accepted for interface symmetry but is not folded in here. + p = trigram_dist.copy() + + # Apply higher-order backoff (orders 3+) + for k in range(3, min(self.max_order + 1, len(context_tokens) + 2)): + if len(context_tokens) >= k - 1: + ctx = tuple(context_tokens[-(k-1):]) + total = self.higher_totals[k].get(ctx, 0) + if total > 0: + ck = self.concentrations.get(k, 0.5 * k) + counts_dict = self.higher_counts[k][ctx] + counts = np.zeros(self.V, dtype=np.float64) + for tok, cnt in counts_dict.items(): + counts[tok] = cnt + p = (counts + ck * p) / (total + ck) + return p + + # ============================================================ + # EXPERT 4: Error-Pattern Expert (NOVEL) + # ============================================================ + class ErrorPatternExpert: + """Predicts the neural model's systematic errors. + + Maintains a table: (prev_token, neural_argmax) → distribution over actual tokens. + When the neural model predicts token X after token Y, this expert knows + the historical distribution of what ACTUALLY followed in that scenario. + + LEGALITY: Built from scored tokens only. Uses neural prediction (computed before + scoring) and actual token (observed after scoring) to update the table. + Predictions use only past information. + """ + def __init__(self, vocab_size=VOCAB_SIZE, smoothing=1.0): + self.V = vocab_size + self.smoothing = smoothing + # (prev_token, neural_argmax) → counts[actual_token] + # Allocation is deferred to try_allocate() so that a MemoryError there + # can trigger the sparse fallback — an eager 2GB np.zeros here would + # crash __init__ on low-RAM machines and double peak memory. + self.counts = None + self.totals = None + self.available = False # set by try_allocate() + + def try_allocate(self): + """Try to allocate the dense table. 
Falls back to sparse if OOM.""" + try: + self.counts = np.zeros((self.V, self.V, self.V), dtype=np.uint16) + self.totals = np.zeros((self.V, self.V), dtype=np.uint32) + self.available = True + print(f" Error pattern expert allocated: {self.counts.nbytes / 1e9:.1f}GB", flush=True) + except MemoryError: + print(f" Error pattern expert: using sparse fallback", flush=True) + self.available = False + self.sparse_counts = defaultdict(lambda: defaultdict(int)) + self.sparse_totals = defaultdict(int) + + def get_distribution(self, prev_token, neural_argmax): + """BEFORE scoring: what actually follows when neural predicts argmax after prev?""" + if self.available: + counts = self.counts[prev_token, neural_argmax].astype(np.float64) + total = self.totals[prev_token, neural_argmax] + else: + key = (prev_token, neural_argmax) + total = self.sparse_totals.get(key, 0) + counts = np.zeros(self.V, dtype=np.float64) + if key in self.sparse_counts: + for tok, cnt in self.sparse_counts[key].items(): + counts[tok] = cnt + + c = self.smoothing + uniform = np.ones(self.V, dtype=np.float64) / self.V + if total == 0: + return uniform + return (counts + c * uniform) / (total + c) + + def update(self, prev_token, neural_argmax, actual_token): + """AFTER scoring: record what actually happened.""" + if self.available: + self.counts[prev_token, neural_argmax, actual_token] = min( + self.counts[prev_token, neural_argmax, actual_token] + 1, 65535 + ) + self.totals[prev_token, neural_argmax] += 1 + else: + key = (prev_token, neural_argmax) + self.sparse_counts[key][actual_token] += 1 + self.sparse_totals[key] += 1 + +# ============================================================ +# Fixed-Share Mixer (Linear + Log-Linear options) +# ============================================================ +class FixedShareMixer: + """ + Combines K expert distributions using Fixed-Share weight updates. + + TWO MIXING MODES: + 1. 
LINEAR: P_mix = Σ_k w_k * P_k (standard weighted average) + - Safe, guaranteed normalized, well-behaved + 2. LOG-LINEAR: P_mix ∝ Π_k P_k^{w_k} (product-of-experts) + - Amplifies agreement between experts (sharper distributions) + - Better when experts make independent errors + + NOTE: PAQ-style logistic mixing is designed for BINARY prediction + (per-bit in arithmetic coder). For multi-class (1024 vocab), it + produces degenerate distributions. DON'T USE IT FOR TOKENS. + + FIXED-SHARE UPDATE (Herbster & Warmuth 1998): + After scoring, update: w_k ← (1-α) * w_k * P_k(x_true) + α/K + α controls how much weight is redistributed uniformly. + Prevents any expert from reaching zero weight (fixes Bayesian collapse). + + LEGALITY: All updates AFTER scoring. Same protocol as n-gram cache. + """ + def __init__(self, n_experts, alpha=0.01, mode='linear', expert_names=None): + self.K = n_experts + self.alpha = alpha # Fixed-Share redistribution rate + self.mode = mode # 'linear' or 'loglinear' + self.weights = np.ones(n_experts, dtype=np.float64) / n_experts + self.expert_names = expert_names or [f"Expert_{i}" for i in range(n_experts)] + self.total_tokens = 0 + + def get_mixture(self, expert_distributions): + """ + STEP 1 (BEFORE SCORING): Compute mixture distribution. + Returns full normalized distribution over vocabulary. 
+ """ + K = len(expert_distributions) + V = expert_distributions[0].shape[0] + + if self.mode == 'loglinear': + # Log-linear: P ∝ Π_k P_k^{w_k} + eps = 1e-30 + log_mixture = np.zeros(V, dtype=np.float64) + for k in range(K): + log_mixture += self.weights[k] * np.log(np.maximum(expert_distributions[k], eps)) + # Numerical stability: subtract max before exp + log_mixture -= log_mixture.max() + mixture = np.exp(log_mixture) + mixture /= mixture.sum() + else: + # Linear: P = Σ_k w_k * P_k + mixture = np.zeros(V, dtype=np.float64) + for k in range(K): + mixture += self.weights[k] * expert_distributions[k] + # Ensure normalized (should already be, but safety) + s = mixture.sum() + if s > 0: + mixture /= s + + return mixture + + def observe(self, expert_distributions, true_token): + """ + STEP 3 (AFTER SCORING): Update expert weights via Fixed-Share. + w_k ← (1-α) * w_k * P_k(x_true) / Z + α/K + """ + # Multiplicative update: weight each expert by how well it predicted + for k in range(self.K): + p_k = max(expert_distributions[k][true_token], 1e-30) + self.weights[k] *= p_k + + # Normalize + w_sum = self.weights.sum() + if w_sum > 0: + self.weights /= w_sum + + # Fixed-Share redistribution: prevent any expert from dying + self.weights = (1 - self.alpha) * self.weights + self.alpha / self.K + + self.total_tokens += 1 + + def get_weights_summary(self): + return {n: f"{w:.4f}" for n, w in zip(self.expert_names, self.weights)} + +# ============================================================ +# Training +# ============================================================ +def train_model(train_seq, eval_seq): + model = Transformer().to(DEVICE) + n_params = sum(p.numel() for p in model.parameters()) + print(f" Training RoPE 16 model: {n_params:,} params", flush=True) + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TRAIN_STEPS) + t0 = time.time() + for step in range(TRAIN_STEPS + 
1): + if step % 500 == 0: + model.eval() + with torch.no_grad(): + eb = eval_seq[:100].to(DEVICE) + logits = model(eb[:, :-1]) + ce = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), eb[:, 1:].reshape(-1)) + print(f" Step {step:4d} | CE: {ce:.4f} | {(time.time()-t0)/max(step,1)*1000:.0f}ms/step", flush=True) + model.train() + if step >= TRAIN_STEPS: + break + bi = torch.randint(0, train_seq.size(0), (BATCH_SIZE,)) + batch = train_seq[bi].to(DEVICE) + logits = model(batch[:, :-1]) + loss = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), batch[:, 1:].reshape(-1)) + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + scheduler.step() + print(f" Training done in {time.time()-t0:.1f}s", flush=True) + return model + +# ============================================================ +# Evaluation Methods +# ============================================================ + +def eval_neural_only(probs, sequences): + """Baseline: neural only.""" + total_bits = 0.0 + scored = 0 + for i in range(len(sequences)): + for t in range(sequences.shape[1] - 1): + target = sequences[i, t + 1] + p = max(probs[i, t, target], 1e-30) + total_bits += -math.log2(p) + scored += 1 + return total_bits / scored + +def eval_entropy_adaptive_ctw(probs, sequences, max_order=6): + """Our current best: entropy-adaptive mixing with Dirichlet CTW.""" + V = VOCAB_SIZE + ctw = DirichletCTWExpert(max_order=max_order) + bigram_cache = DenseBigramCache() + + total_bits = 0.0 + scored = 0 + + for i in range(len(sequences)): + for t in range(sequences.shape[1] - 1): + target = sequences[i, t + 1] + neural_p = probs[i, t].astype(np.float64) + neural_p = neural_p / neural_p.sum() + + # Entropy-adaptive alpha + H = -np.sum(neural_p * np.log2(neural_p + 1e-30)) + alpha = 0.05 + 0.55 / (1.0 + np.exp(-2.0 * (H - 4.0))) + + # Get n-gram distribution + context = list(sequences[i, max(0, t - max_order + 1):t + 1]) + prev = context[-1] if len(context) > 0 else 
0
+            bigram_p = bigram_cache.get_distribution(prev)
+
+            # Unigram backoff distribution for the CTW base
+            if ctw.unigram_total > 0:
+                c1 = ctw.concentrations.get(1, 0.5)
+                uni_p = (ctw.unigram_counts.astype(np.float64) + c1 / V) / (ctw.unigram_total + c1)
+            else:
+                uni_p = np.ones(V) / V
+
+            # CTW higher orders (unigram backoff as the base distribution)
+            ngram_p = ctw.get_distribution(uni_p, bigram_p, context)
+
+            # Mix
+            mixed = (1 - alpha) * neural_p + alpha * ngram_p
+            p_token = max(mixed[target], 1e-30)
+            total_bits += -math.log2(p_token)
+            scored += 1
+
+            # UPDATE AFTER SCORING
+            bigram_cache.update(prev, target)
+            ctw.update(target)
+
+    return total_bits / scored
+
+def eval_fixed_share(probs, sequences, use_trigram=False, use_error_expert=False,
+                     max_ctw_order=6, alpha=0.01, mode='linear', label=""):
+    """Fixed-Share Mixer with multiple experts."""
+    V = VOCAB_SIZE
+
+    # Set up experts
+    expert_names = ["Neural", "Bigram", "CTW"]
+    bigram_cache = DenseBigramCache()
+    ctw = DirichletCTWExpert(max_order=max_ctw_order)
+
+    trigram_cache = None
+    error_expert = None
+
+    if use_trigram:
+        trigram_cache = DenseTrigramCache(use_compact=True)
+        if not trigram_cache.available:
+            # Fall back to sparse if dense doesn't fit
+            pass
+        expert_names.append("Trigram")
+
+    if use_error_expert:
+        error_expert = ErrorPatternExpert()
+        error_expert.available = False  # force sparse on Mac Mini
+        error_expert.sparse_counts = defaultdict(lambda: defaultdict(int))
+        error_expert.sparse_totals = defaultdict(int)
+        expert_names.append("ErrorPat")
+
+    n_experts = len(expert_names)
+    mixer = FixedShareMixer(n_experts, alpha=alpha, mode=mode, expert_names=expert_names)
+
+    total_bits = 0.0
+    scored = 0
+    weight_snapshots = []
+
+    for i in range(len(sequences)):
+        for t in range(sequences.shape[1] - 1):
+            target = sequences[i, t + 1]
+            prev = sequences[i, t] if t > 0 else 0
+            prev2 = sequences[i, t - 1] if t > 1 else 0
+
+            # --- STEP 1: PREDICT (before scoring) ---
+            neural_p = probs[i, t].astype(np.float64)
+            neural_p = np.clip(neural_p, 
1e-10, None) + neural_p = neural_p / neural_p.sum() + + # Expert distributions + expert_dists = [neural_p] + + # Bigram expert + bigram_p = bigram_cache.get_distribution(sequences[i, t]) + expert_dists.append(bigram_p) + + # CTW expert (uses bigram as base for backoff) + context = list(sequences[i, max(0, t - max_ctw_order + 1):t + 1]) + ctw_p = ctw.get_distribution(bigram_p, bigram_p, context) + expert_dists.append(ctw_p) + + # Trigram expert (if enabled) + if trigram_cache is not None: + if t >= 2: + tri_p = trigram_cache.get_distribution(prev2, prev) + else: + tri_p = bigram_p # fall back to bigram for first 2 positions + expert_dists.append(tri_p) + + # Error-pattern expert (if enabled) + if error_expert is not None: + neural_argmax = int(np.argmax(neural_p)) + err_p = error_expert.get_distribution(sequences[i, t], neural_argmax) + expert_dists.append(err_p) + + # Compute mixture + mixed = mixer.get_mixture(expert_dists) + + # --- STEP 2: SCORE --- + p_token = max(mixed[target], 1e-30) + total_bits += -math.log2(p_token) + scored += 1 + + # --- STEP 3: UPDATE (after scoring) --- + mixer.observe(expert_dists, target) + bigram_cache.update(sequences[i, t], target) + ctw.update(target) + + if trigram_cache is not None and t >= 2: + trigram_cache.update(prev2, prev, target) + + if error_expert is not None: + neural_argmax = int(np.argmax(neural_p)) + error_expert.update(sequences[i, t], neural_argmax, target) + + # Weight snapshots + if scored % 5000 == 0: + weight_snapshots.append({ + 'scored': scored, + 'weights': mixer.weights.copy(), + 'bpc': total_bits / scored + }) + + bpc = total_bits / scored + + # Print weight evolution + if weight_snapshots and label: + print(f" [{label}] Weight evolution:", flush=True) + header = " " + "".join(f"{n:>10}" for n in expert_names) + " BPC" + print(header, flush=True) + for snap in weight_snapshots[::max(1, len(weight_snapshots)//5)]: + row = " " + for w in snap['weights']: + row += f"{w:10.4f}" + row += f" {snap['bpc']:.4f}" 
+ print(row, flush=True) + + return bpc + +# ============================================================ +# Main +# ============================================================ +if __name__ == "__main__": + print("=" * 70) + print("Loading data...", flush=True) + corpus = download_text_corpus() + all_sequences = tokenize_text(corpus) + n_train = int(len(all_sequences) * 0.9) + train_seq = all_sequences[:n_train] + eval_seq = all_sequences[n_train:] + print(f" Train: {train_seq.shape}, Eval: {eval_seq.shape}") + + # Train neural model once (or load cached) + MODEL_CACHE = "/Users/himanshudongre/Documents/GitHub/parameter_golf/cached_rope16_model.pt" + PROBS_CACHE = "/Users/himanshudongre/Documents/GitHub/parameter_golf/cached_neural_probs.npz" + + if os.path.exists(MODEL_CACHE): + print("\n" + "=" * 70) + print("Loading CACHED RoPE 16 neural model") + print("=" * 70) + model = Transformer().to(DEVICE) + model.load_state_dict(torch.load(MODEL_CACHE, map_location=DEVICE, weights_only=True)) + print(" Loaded cached model!", flush=True) + else: + print("\n" + "=" * 70) + print("Training RoPE 16 neural model") + print("=" * 70) + model = train_model(train_seq, eval_seq) + torch.save(model.state_dict(), MODEL_CACHE) + print(f" Model cached to {MODEL_CACHE}", flush=True) + + # Get neural probabilities (reused by all eval methods) + sequences = eval_seq[:100].numpy() + if os.path.exists(PROBS_CACHE): + print("\nLoading cached neural probabilities...", flush=True) + data = np.load(PROBS_CACHE) + probs = data['probs'] + print(f" Probs shape: {probs.shape}") + else: + print("\nComputing neural probabilities...", flush=True) + model.eval() + with torch.no_grad(): + eb = eval_seq[:100].to(DEVICE) + logits = model(eb[:, :-1]) + probs = F.softmax(logits, dim=-1).cpu().numpy() + np.savez_compressed(PROBS_CACHE, probs=probs) + print(f" Probs shape: {probs.shape} (cached)") + + results = {} + + # --- A: Neural only --- + print("\n" + "=" * 70) + print("A: Neural only (baseline)") + 
print("=" * 70) + bpc_a = eval_neural_only(probs, sequences) + results["neural_only"] = bpc_a + print(f" BPC: {bpc_a:.4f}", flush=True) + + # --- B: Entropy-adaptive CTW-6 (our current best) --- + print("\n" + "=" * 70) + print("B: Entropy-adaptive CTW-6 (current best)") + print("=" * 70) + bpc_b = eval_entropy_adaptive_ctw(probs, sequences, max_order=6) + results["entropy_ctw6"] = bpc_b + print(f" BPC: {bpc_b:.4f}", flush=True) + + # --- C: Fixed-Share Linear (Neural + Bigram + CTW-6) --- + print("\n" + "=" * 70) + print("C: Fixed-Share Linear (Neural + Bigram + CTW-6)") + print("=" * 70) + bpc_c = eval_fixed_share(probs, sequences, max_ctw_order=6, + alpha=0.01, mode='linear', label="FS-linear-3exp") + results["fs_linear_3exp"] = bpc_c + print(f" BPC: {bpc_c:.4f}", flush=True) + + # --- C2: Fixed-Share Log-Linear (product-of-experts) --- + print("\n" + "=" * 70) + print("C2: Fixed-Share Log-Linear (Neural + Bigram + CTW-6)") + print("=" * 70) + bpc_c2 = eval_fixed_share(probs, sequences, max_ctw_order=6, + alpha=0.01, mode='loglinear', label="FS-loglinear-3exp") + results["fs_loglinear_3exp"] = bpc_c2 + print(f" BPC: {bpc_c2:.4f}", flush=True) + + # --- D: Fixed-Share + Dense Trigram --- + print("\n" + "=" * 70) + print("D: Fixed-Share Linear + Dense Trigram (4 experts)") + print("=" * 70) + bpc_d = eval_fixed_share(probs, sequences, use_trigram=True, max_ctw_order=6, + alpha=0.01, mode='linear', label="FS+Trigram") + results["fs_trigram"] = bpc_d + print(f" BPC: {bpc_d:.4f}", flush=True) + + # --- E: Fixed-Share + Error-Pattern expert --- + print("\n" + "=" * 70) + print("E: Fixed-Share Linear + Error Pattern (4 experts)") + print("=" * 70) + bpc_e = eval_fixed_share(probs, sequences, use_error_expert=True, max_ctw_order=6, + alpha=0.01, mode='linear', label="FS+ErrorPat") + results["fs_error"] = bpc_e + print(f" BPC: {bpc_e:.4f}", flush=True) + + # --- F: Fixed-Share FULL (all experts) --- + print("\n" + "=" * 70) + print("F: Fixed-Share Linear FULL (5 
experts)") + print("=" * 70) + bpc_f = eval_fixed_share(probs, sequences, use_trigram=True, use_error_expert=True, + max_ctw_order=6, alpha=0.01, mode='linear', label="FS-FULL") + results["fs_full"] = bpc_f + print(f" BPC: {bpc_f:.4f}", flush=True) + + # --- G: Fixed-Share with CTW-12 --- + print("\n" + "=" * 70) + print("G: Fixed-Share Linear (Neural + Bigram + CTW-12)") + print("=" * 70) + bpc_g = eval_fixed_share(probs, sequences, max_ctw_order=12, + alpha=0.01, mode='linear', label="FS+CTW12") + results["fs_ctw12"] = bpc_g + print(f" BPC: {bpc_g:.4f}", flush=True) + + # --- H: Alpha sweep --- + print("\n" + "=" * 70) + print("H: Alpha sweep (Fixed-Share redistribution rate)") + print("=" * 70) + for alpha_val in [0.001, 0.005, 0.02, 0.05, 0.1, 0.2]: + bpc_h = eval_fixed_share(probs, sequences, max_ctw_order=6, + alpha=alpha_val, mode='linear', label=f"α={alpha_val}") + results[f"fs_alpha_{alpha_val}"] = bpc_h + print(f" α={alpha_val}: BPC={bpc_h:.4f}", flush=True) + + # --- I: Log-linear with best alpha --- + print("\n" + "=" * 70) + print("I: Log-Linear mode sweep") + print("=" * 70) + for alpha_val in [0.01, 0.05, 0.1]: + bpc_i = eval_fixed_share(probs, sequences, max_ctw_order=6, + alpha=alpha_val, mode='loglinear', label=f"LL-α={alpha_val}") + results[f"fs_loglinear_alpha_{alpha_val}"] = bpc_i + print(f" Log-Linear α={alpha_val}: BPC={bpc_i:.4f}", flush=True) + + # ============================================================ + # Summary + # ============================================================ + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + + baseline = results["neural_only"] + current_best = results["entropy_ctw6"] + + print(f"\n{'Method':<55} {'BPC':>8} {'vs Neural':>10} {'vs Current':>11}") + print("-" * 88) + for key, bpc in results.items(): + vs_neural = (bpc - baseline) / baseline * 100 + vs_current = (bpc - current_best) / current_best * 100 + label = key.replace("_", " ") + print(f" {label:<53} {bpc:8.4f} {vs_neural:+9.2f}% 
{vs_current:+10.2f}%")
+
+    print("\nKEY QUESTIONS:")
+    if results.get("fs_linear_3exp", 999) < current_best:
+        print(f"  ✓ Fixed-Share mixer beats entropy-adaptive!")
+    else:
+        print(f"  ✗ Entropy-adaptive still better (but check full config)")
+
+    if results.get("fs_trigram", 999) < results.get("fs_linear_3exp", 999):
+        print(f"  ✓ Dense trigram expert adds value!")
+    else:
+        print(f"  ✗ Dense trigram doesn't help (at this data size)")
+
+    if results.get("fs_error", 999) < results.get("fs_linear_3exp", 999):
+        print(f"  ✓ Error-pattern expert adds value!")
+    else:
+        print(f"  ✗ Error-pattern doesn't help yet")
+
+    # Save
+    with open("/Users/himanshudongre/Documents/GitHub/parameter_golf/fixed_share_results.json", 'w') as f:
+        json.dump(results, f, indent=2)
+    print(f"\nResults saved to fixed_share_results.json")
diff --git a/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_local_distill_qat.py b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_local_distill_qat.py
new file mode 100644
index 0000000000..5fca51a308
--- /dev/null
+++ b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_local_distill_qat.py
@@ -0,0 +1,524 @@
+"""
+Local Test: Knowledge Distillation + QAT int4 (Train Bigger)
+=============================================================
+
+Tests the "train big, compress small" strategy at LOCAL SCALE:
+  - Teacher: 6L 256d (~5M params) — trains at FP32
+  - Student: 4L 192d (~1.5M params) — trains with QAT int4, soft targets
+  - Comparison: Direct 4L 192d training at FP32 and QAT int4
+
+This validates the CONCEPT before RunPod:
+  1. Does distillation beat direct training at same param budget?
+  2. Does QAT int4 maintain quality at higher compression?
+  3. Does a larger teacher meaningfully improve student quality?
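A minimal sketch of the distillation objective this script trains with (hard
cross-entropy plus temperature-scaled KL); the helper name and the toy shapes
are illustrative, not part of the script:

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, targets, T=3.0, alpha=0.5):
    """alpha * CE(student, hard labels) + (1-alpha) * T^2 * KL(teacher/T || student/T)."""
    V = student_logits.size(-1)
    hard = F.cross_entropy(student_logits.reshape(-1, V), targets.reshape(-1))
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1).reshape(-1, V),  # log-probs (input)
        F.softmax(teacher_logits / T, dim=-1).reshape(-1, V),      # probs (target)
        reduction='batchmean',
    ) * (T ** 2)  # T^2 restores the gradient magnitude lost to temperature scaling
    return alpha * hard + (1 - alpha) * soft

# If student and teacher agree exactly, the soft term vanishes and the loss
# reduces to alpha * CE against the hard labels.
logits = torch.randn(2, 8, 16)
targets = torch.randint(0, 16, (2, 8))
loss = distill_loss(logits, logits.clone(), targets)
```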
+ +SCALING TO RUNPOD: + Teacher: 12L 768d (~150M params) at FP16/FP8 + Student: 12L 576d (~40M params) at QAT int4 + Store student at int4 + NF4 centroids + Brotli = ~16MB + +ALSO TESTS: Product Quantization compression + - Compress 5M teacher to see reconstruction quality +""" +import sys +sys.stdout.reconfigure(line_buffering=True) + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +import time +import json +import os +import urllib.request + +VOCAB_SIZE = 1024 +SEQ_LEN = 512 +DEVICE = "mps" if torch.backends.mps.is_available() else "cpu" +TRAIN_STEPS = 1500 +BATCH_SIZE = 32 +LR = 3e-4 + +print(f"Device: {DEVICE}") +print(f"Local Distillation + QAT Test") +print() + +# ============================================================ +# Data Loading +# ============================================================ +def download_text_corpus(): + cache_path = "/Users/himanshudongre/Documents/GitHub/parameter_golf/text_corpus.txt" + if os.path.exists(cache_path): + with open(cache_path, 'r', encoding='utf-8', errors='ignore') as f: + return f.read() + urls = [ + "https://www.gutenberg.org/cache/epub/1342/pg1342.txt", + "https://www.gutenberg.org/cache/epub/11/pg11.txt", + "https://www.gutenberg.org/cache/epub/84/pg84.txt", + "https://www.gutenberg.org/cache/epub/1661/pg1661.txt", + ] + all_text = [] + for url in urls: + try: + req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) + response = urllib.request.urlopen(req, timeout=30) + text = response.read().decode('utf-8', errors='ignore') + all_text.append(text) + except Exception as e: + print(f" Failed: {e}") + combined = "\n\n".join(all_text) + with open(cache_path, 'w', encoding='utf-8') as f: + f.write(combined) + return combined + +def tokenize_text(text, vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN): + raw_bytes = text.encode('utf-8') + tokens = [b % vocab_size for b in raw_bytes] + n_seq = len(tokens) // (seq_len + 1) + tokens = tokens[:n_seq * 
(seq_len + 1)] + return torch.tensor(tokens, dtype=torch.long).view(n_seq, seq_len + 1) + +# ============================================================ +# Quantization +# ============================================================ +def uniform_quantize_ste(w, n_bits=4): + """Uniform quantization with STE for QAT.""" + n_levels = 2 ** n_bits + w_min = w.min(dim=-1, keepdim=True).values + w_max = w.max(dim=-1, keepdim=True).values + scale = (w_max - w_min) / (n_levels - 1) + scale = scale.clamp(min=1e-8) + w_norm = (w - w_min) / scale + w_int = w_norm.round().clamp(0, n_levels - 1) + w_q = w_int * scale + w_min + return w + (w_q - w).detach() + +# ============================================================ +# Model Components +# ============================================================ +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + self.eps = eps + def forward(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.scale + +class MLP(nn.Module): + def __init__(self, dim, expansion=2.0, quant_fn=None): + super().__init__() + hidden = int(dim * expansion) + self.gate = nn.Linear(dim, hidden, bias=False) + self.up = nn.Linear(dim, hidden, bias=False) + self.down = nn.Linear(hidden, dim, bias=False) + self.quant_fn = quant_fn + def forward(self, x): + if self.quant_fn: + g = F.linear(x, self.quant_fn(self.gate.weight)) + u = F.linear(x, self.quant_fn(self.up.weight)) + return F.linear(F.gelu(g) * u, self.quant_fn(self.down.weight)) + return self.down(F.gelu(self.gate(x)) * self.up(x)) + +class Attention(nn.Module): + def __init__(self, dim, n_heads, rope_dims=16, quant_fn=None): + super().__init__() + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.qkv = nn.Linear(dim, 3 * dim, bias=False) + self.out = nn.Linear(dim, dim, bias=False) + self.quant_fn = quant_fn + self.rope_dims = rope_dims + freqs = 1.0 / (10000.0 ** (torch.arange(0, 
rope_dims, 2).float() / rope_dims)) + t = torch.arange(SEQ_LEN).float() + freqs = torch.outer(t, freqs) + self.register_buffer('cos_cache', freqs.cos().unsqueeze(0).unsqueeze(0), persistent=False) + self.register_buffer('sin_cache', freqs.sin().unsqueeze(0).unsqueeze(0), persistent=False) + + def _apply_rope(self, x): + rd = self.rope_dims + x_rope, x_pass = x[..., :rd], x[..., rd:] + x1, x2 = x_rope[..., :rd//2], x_rope[..., rd//2:] + cos = self.cos_cache[:, :, :x.size(2), :] + sin = self.sin_cache[:, :, :x.size(2), :] + out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + return torch.cat([out, x_pass], dim=-1) + + def forward(self, x): + B, T, C = x.shape + if self.quant_fn: + qkv = F.linear(x, self.quant_fn(self.qkv.weight)) + else: + qkv = self.qkv(x) + qkv = qkv.reshape(B, T, 3, self.n_heads, self.head_dim) + q, k, v = qkv.unbind(2) + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + q, k = self._apply_rope(q), self._apply_rope(k) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + y = y.transpose(1, 2).reshape(B, T, C) + if self.quant_fn: + return F.linear(y, self.quant_fn(self.out.weight)) + return self.out(y) + +class Block(nn.Module): + def __init__(self, dim, n_heads, expansion=2.0, quant_fn=None): + super().__init__() + self.ln1 = RMSNorm(dim) + self.attn = Attention(dim, n_heads, quant_fn=quant_fn) + self.ln2 = RMSNorm(dim) + self.mlp = MLP(dim, expansion, quant_fn=quant_fn) + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class LM(nn.Module): + def __init__(self, dim, n_layers, n_heads, expansion=2.0, quant_fn=None): + super().__init__() + self.tok_emb = nn.Embedding(VOCAB_SIZE, dim) + self.blocks = nn.ModuleList([ + Block(dim, n_heads, expansion, quant_fn=quant_fn) + for _ in range(n_layers) + ]) + self.ln_f = RMSNorm(dim) + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, 
nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + + def forward(self, idx): + x = self.tok_emb(idx) + for block in self.blocks: + x = block(x) + return F.linear(self.ln_f(x), self.tok_emb.weight) + +# ============================================================ +# Training Functions +# ============================================================ +def eval_ce(model, eval_seq): + model.eval() + with torch.no_grad(): + eb = eval_seq[:100].to(DEVICE) + logits = model(eb[:, :-1]) + ce = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), eb[:, 1:].reshape(-1)) + return ce.item() + +def train_standard(model, train_seq, eval_seq, steps=TRAIN_STEPS, label=""): + """Standard training with hard labels.""" + model = model.to(DEVICE) + n_params = sum(p.numel() for p in model.parameters()) + print(f" [{label}] {n_params:,} params", flush=True) + + optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps) + + t0 = time.time() + for step in range(steps + 1): + if step % 500 == 0: + ce = eval_ce(model, eval_seq) + print(f" Step {step:4d} | CE: {ce:.4f}", flush=True) + model.train() + if step >= steps: + break + bi = torch.randint(0, train_seq.size(0), (BATCH_SIZE,)) + batch = train_seq[bi].to(DEVICE) + logits = model(batch[:, :-1]) + loss = F.cross_entropy(logits.reshape(-1, VOCAB_SIZE), batch[:, 1:].reshape(-1)) + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + scheduler.step() + + final_ce = eval_ce(model, eval_seq) + elapsed = time.time() - t0 + print(f" Final CE: {final_ce:.4f} ({elapsed:.0f}s)", flush=True) + return final_ce + +def train_distilled(student, teacher, train_seq, eval_seq, steps=1000, + temperature=3.0, alpha=0.5, label=""): + """Knowledge distillation: student learns from teacher's soft targets. 
+ + Loss = α * CE(student, hard_label) + (1-α) * T² * KL(student/T || teacher/T) + + The T² factor compensates for the gradient scaling from temperature. + """ + student = student.to(DEVICE) + teacher = teacher.to(DEVICE) + teacher.eval() + + n_params = sum(p.numel() for p in student.parameters()) + print(f" [{label}] Student: {n_params:,} params, T={temperature}, α={alpha}", flush=True) + + optimizer = torch.optim.AdamW(student.parameters(), lr=LR, weight_decay=0.1) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps) + + t0 = time.time() + for step in range(steps + 1): + if step % 500 == 0: + ce = eval_ce(student, eval_seq) + print(f" Step {step:4d} | CE: {ce:.4f}", flush=True) + student.train() + if step >= steps: + break + + bi = torch.randint(0, train_seq.size(0), (BATCH_SIZE,)) + batch = train_seq[bi].to(DEVICE) + inputs = batch[:, :-1] + targets = batch[:, 1:] + + # Student logits + student_logits = student(inputs) + + # Teacher logits (no grad) + with torch.no_grad(): + teacher_logits = teacher(inputs) + + # Hard label loss + hard_loss = F.cross_entropy( + student_logits.reshape(-1, VOCAB_SIZE), + targets.reshape(-1) + ) + + # Soft label loss (KL divergence at temperature T) + student_soft = F.log_softmax(student_logits / temperature, dim=-1) + teacher_soft = F.softmax(teacher_logits / temperature, dim=-1) + soft_loss = F.kl_div( + student_soft.reshape(-1, VOCAB_SIZE), + teacher_soft.reshape(-1, VOCAB_SIZE), + reduction='batchmean' + ) * (temperature ** 2) + + loss = alpha * hard_loss + (1 - alpha) * soft_loss + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0) + optimizer.step() + scheduler.step() + + final_ce = eval_ce(student, eval_seq) + elapsed = time.time() - t0 + print(f" Final CE: {final_ce:.4f} ({elapsed:.0f}s)", flush=True) + return final_ce + +# ============================================================ +# Product Quantization +# 
============================================================ +def product_quantize_model(model, group_size=8, n_centroids=256): + """Compress model weights using product quantization. + + Returns: (codebooks, indices, metadata) — everything needed to reconstruct. + """ + compressed = {} + total_original = 0 + total_compressed = 0 + + for name, param in model.named_parameters(): + w = param.detach().cpu().numpy().flatten() + total_original += w.nbytes + + if len(w) < group_size: + compressed[name] = {'type': 'raw', 'data': w} + total_compressed += w.nbytes + continue + + # Pad to multiple of group_size + pad_len = (group_size - len(w) % group_size) % group_size + if pad_len > 0: + w = np.concatenate([w, np.zeros(pad_len, dtype=w.dtype)]) + + # Reshape into groups + groups = w.reshape(-1, group_size) # (N_groups, group_size) + n_groups = groups.shape[0] + + # K-means clustering (simplified: random init + 10 iterations) + # Initialize centroids from random data points + idx = np.random.choice(n_groups, min(n_centroids, n_groups), replace=False) + centroids = groups[idx].copy() + + for _ in range(10): + # Assign each group to nearest centroid + # Compute distances: (N_groups, n_centroids) + dists = np.sum((groups[:, None, :] - centroids[None, :, :]) ** 2, axis=2) + assignments = dists.argmin(axis=1).astype(np.uint8) + + # Update centroids + for c in range(min(n_centroids, n_groups)): + mask = assignments == c + if mask.sum() > 0: + centroids[c] = groups[mask].mean(axis=0) + + compressed[name] = { + 'type': 'pq', + 'centroids': centroids, # (n_centroids, group_size) float32 + 'assignments': assignments, # (n_groups,) uint8 + 'original_len': len(param.detach().cpu().numpy().flatten()), + 'shape': list(param.shape), + } + + # Size: centroids + assignments + cb_size = centroids.nbytes # n_centroids * group_size * 4 + idx_size = assignments.nbytes # n_groups * 1 + total_compressed += cb_size + idx_size + + ratio = total_original / total_compressed + print(f" Product 
Quantization: {total_original/1024:.1f}KB → {total_compressed/1024:.1f}KB ({ratio:.1f}x compression)", flush=True) + return compressed, total_compressed + +def reconstruct_from_pq(compressed): + """Reconstruct weight tensors from product-quantized data.""" + weights = {} + for name, data in compressed.items(): + if data['type'] == 'raw': + weights[name] = torch.tensor(data['data']) + else: + centroids = data['centroids'] + assignments = data['assignments'] + original_len = data['original_len'] + shape = data['shape'] + + # Reconstruct + reconstructed = centroids[assignments].flatten()[:original_len] + weights[name] = torch.tensor(reconstructed).reshape(shape) + return weights + +# ============================================================ +# Main +# ============================================================ +if __name__ == "__main__": + print("=" * 70) + print("Loading data...", flush=True) + corpus = download_text_corpus() + all_sequences = tokenize_text(corpus) + n_train = int(len(all_sequences) * 0.9) + train_seq = all_sequences[:n_train] + eval_seq = all_sequences[n_train:] + print(f" Train: {train_seq.shape}, Eval: {eval_seq.shape}") + + results = {} + + # ========================================== + # A: Small model direct training (baseline) + # ========================================== + print("\n" + "=" * 70) + print("A: Small model (4L 192d) — direct training, FP32") + print("=" * 70) + torch.manual_seed(42) + small_fp32 = LM(dim=192, n_layers=4, n_heads=6, expansion=2.0) + ce_a = train_standard(small_fp32, train_seq, eval_seq, label="small_fp32") + results["A_small_fp32"] = ce_a + + # ========================================== + # B: Small model with QAT int4 + # ========================================== + print("\n" + "=" * 70) + print("B: Small model (4L 192d) — QAT int4") + print("=" * 70) + torch.manual_seed(42) + small_qat = LM(dim=192, n_layers=4, n_heads=6, expansion=2.0, + quant_fn=uniform_quantize_ste) + ce_b = train_standard(small_qat, 
train_seq, eval_seq, label="small_qat_int4") + results["B_small_qat_int4"] = ce_b + + # ========================================== + # C: Teacher model (bigger) — direct training + # ========================================== + print("\n" + "=" * 70) + print("C: Teacher model (6L 256d) — direct training, FP32") + print("=" * 70) + torch.manual_seed(42) + teacher = LM(dim=256, n_layers=6, n_heads=8, expansion=2.5) + ce_c = train_standard(teacher, train_seq, eval_seq, steps=TRAIN_STEPS, label="teacher") + results["C_teacher"] = ce_c + + # ========================================== + # D: Distilled student — soft targets from teacher + # ========================================== + print("\n" + "=" * 70) + print("D: Distilled student (4L 192d) — from teacher, FP32") + print("=" * 70) + for temp in [2.0, 4.0]: + for alpha in [0.3, 0.5]: + torch.manual_seed(42) + student = LM(dim=192, n_layers=4, n_heads=6, expansion=2.0) + ce_d = train_distilled(student, teacher, train_seq, eval_seq, + steps=TRAIN_STEPS, temperature=temp, alpha=alpha, + label=f"distill_T{temp}_a{alpha}") + results[f"D_distill_T{temp}_a{alpha}"] = ce_d + + # ========================================== + # E: Distilled student + QAT int4 + # ========================================== + print("\n" + "=" * 70) + print("E: Distilled student (4L 192d) — from teacher, QAT int4") + print("=" * 70) + # Use best distillation config from D + best_d = min((v, k) for k, v in results.items() if k.startswith("D_")) + best_d_ce, best_d_key = best_d + print(f" Best distill config: {best_d_key} (CE={best_d_ce:.4f})") + + # Parse temp and alpha from key + parts = best_d_key.split("_") + best_temp = float(parts[2][1:]) + best_alpha = float(parts[3][1:]) + + torch.manual_seed(42) + student_qat = LM(dim=192, n_layers=4, n_heads=6, expansion=2.0, + quant_fn=uniform_quantize_ste) + ce_e = train_distilled(student_qat, teacher, train_seq, eval_seq, + steps=TRAIN_STEPS, temperature=best_temp, alpha=best_alpha, + 
label="distill_qat_int4") + results["E_distill_qat_int4"] = ce_e + + # ========================================== + # F: Product Quantization test + # ========================================== + print("\n" + "=" * 70) + print("F: Product Quantization of teacher model") + print("=" * 70) + compressed, comp_size = product_quantize_model(teacher, group_size=8, n_centroids=256) + print(f" Compressed size: {comp_size/1024:.1f}KB") + + # Reconstruct and evaluate + reconstructed_weights = reconstruct_from_pq(compressed) + teacher_pq = LM(dim=256, n_layers=6, n_heads=8, expansion=2.5).to(DEVICE) + teacher_pq_state = teacher_pq.state_dict() + for name, tensor in reconstructed_weights.items(): + if name in teacher_pq_state: + teacher_pq_state[name] = tensor + teacher_pq.load_state_dict(teacher_pq_state) + ce_f = eval_ce(teacher_pq, eval_seq) + results["F_teacher_pq"] = ce_f + pq_deg = (ce_f - ce_c) / ce_c * 100 + print(f" PQ teacher CE: {ce_f:.4f} (degradation: {pq_deg:+.2f}% vs FP32 teacher)", flush=True) + + # ========================================== + # Summary + # ========================================== + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f" A: Small FP32 (direct): {results['A_small_fp32']:.4f}") + print(f" B: Small QAT int4 (direct): {results['B_small_qat_int4']:.4f} ({(results['B_small_qat_int4']-results['A_small_fp32'])/results['A_small_fp32']*100:+.2f}% vs A)") + print(f" C: Teacher FP32: {results['C_teacher']:.4f} ({(results['C_teacher']-results['A_small_fp32'])/results['A_small_fp32']*100:+.2f}% vs A)") + + best_d = min((v, k) for k, v in results.items() if k.startswith("D_")) + print(f" D: Best distilled FP32: {best_d[0]:.4f} ({(best_d[0]-results['A_small_fp32'])/results['A_small_fp32']*100:+.2f}% vs A) [{best_d[1]}]") + + print(f" E: Distilled QAT int4: {results['E_distill_qat_int4']:.4f} ({(results['E_distill_qat_int4']-results['A_small_fp32'])/results['A_small_fp32']*100:+.2f}% vs A)") + print(f" F: Teacher PQ: 
{results['F_teacher_pq']:.4f} ({pq_deg:+.2f}% vs teacher FP32)") + + print(f"\n KEY QUESTIONS:") + print(f" Does distillation help? A={results['A_small_fp32']:.4f} → D={best_d[0]:.4f} ({(best_d[0]-results['A_small_fp32'])/results['A_small_fp32']*100:+.2f}%)") + print(f" Does QAT int4 hold? A={results['A_small_fp32']:.4f} → B={results['B_small_qat_int4']:.4f} ({(results['B_small_qat_int4']-results['A_small_fp32'])/results['A_small_fp32']*100:+.2f}%)") + print(f" Distill + QAT? A={results['A_small_fp32']:.4f} → E={results['E_distill_qat_int4']:.4f} ({(results['E_distill_qat_int4']-results['A_small_fp32'])/results['A_small_fp32']*100:+.2f}%)") + print(f" Product Quant quality? Teacher={results['C_teacher']:.4f} → PQ={results['F_teacher_pq']:.4f} ({pq_deg:+.2f}%)") + + # Save + results_path = "/Users/himanshudongre/Documents/GitHub/parameter_golf/results_distill_qat.json" + with open(results_path, 'w') as f: + json.dump(results, f, indent=2) + print(f"\n Results saved to {results_path}") + print(f" Finished: {time.strftime('%H:%M:%S')}") diff --git a/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_qat_learned_centroids.py b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_qat_learned_centroids.py new file mode 100644 index 0000000000..cc429658e6 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/exp_qat_learned_centroids.py @@ -0,0 +1,525 @@ +""" +Experiment: Learned Quantization Grid + QAT +============================================= +NOVEL: Nobody in the competition does end-to-end QAT with non-uniform centroids. + +Standard approach (everyone else): + 1. Train float32 model + 2. Post-hoc GPTQ → snap to uniform int5 grid {-16..15} * scale + +Our approach: + 1. Train WITH simulated quantization from step 0 (STE gradients) + 2. Quantization centroids are NON-UNIFORM and LEARNED jointly + 3. 
At checkpoint: model is already quantized, zero post-hoc degradation + +Mathematical basis: + - Lloyd-Max quantizer: for Gaussian weights, non-uniform centroids (denser near zero) + reduce MSE by 20-30% vs uniform grid at same bit-width + - STE (Straight-Through Estimator): gradient flows through quantization by + treating the rounding step as identity in the backward pass + +Tests: + A. Float32 baseline (no quantization) + B. Float32 + post-hoc uniform int5 (simulated GPTQ) + C. QAT with uniform int5 grid (STE, fixed centroids) + D. QAT with LEARNED non-uniform int5 centroids (our novel idea) + E. QAT with NormalFloat-5 centroids (Gaussian-optimal, fixed) +""" +import sys +sys.stdout.reconfigure(line_buffering=True) + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +import time +import json +import os +import urllib.request + +VOCAB_SIZE = 1024 +SEQ_LEN = 512 +DIM = 192 +N_HEADS = 6 +N_LAYERS = 6 +MLP_EXP = 2.0 +TRAIN_STEPS = 1500 +BATCH_SIZE = 32 +LR = 3e-4 +DEVICE = "mps" if torch.backends.mps.is_available() else "cpu" +N_BITS = 5 # int5 quantization +N_LEVELS = 2 ** N_BITS # 32 levels + +print(f"Device: {DEVICE}") +print(f"Quantization: int{N_BITS} ({N_LEVELS} levels)") +print() + +# ============================================================ +# Data Loading +# ============================================================ +def download_text_corpus(): + cache_path = "/Users/himanshudongre/Documents/GitHub/parameter_golf/text_corpus.txt" + if os.path.exists(cache_path): + with open(cache_path, 'r', encoding='utf-8', errors='ignore') as f: + return f.read() + urls = [ + "https://www.gutenberg.org/cache/epub/1342/pg1342.txt", + "https://www.gutenberg.org/cache/epub/11/pg11.txt", + "https://www.gutenberg.org/cache/epub/84/pg84.txt", + "https://www.gutenberg.org/cache/epub/1661/pg1661.txt", + ] + all_text = [] + for url in urls: + try: + req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) + 
response = urllib.request.urlopen(req, timeout=30) + text = response.read().decode('utf-8', errors='ignore') + start = text.find("*** START OF") + if start != -1: start = text.find("\n", start) + 1 + else: start = 0 + end = text.find("*** END OF") + if end == -1: end = len(text) + all_text.append(text[start:end]) + except Exception as e: + print(f" Failed: {e}", flush=True) + corpus = "\n\n".join(all_text) + with open(cache_path, 'w', encoding='utf-8') as f: + f.write(corpus) + return corpus + +def tokenize_text(text, vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN+1): + """Simple byte-level tokenizer matching other experiments.""" + raw_bytes = text.encode('utf-8') + tokens = [b % vocab_size for b in raw_bytes] + n_seq = len(tokens) // seq_len + tokens = tokens[:n_seq * seq_len] + return torch.tensor(tokens, dtype=torch.long).reshape(n_seq, seq_len) + +# ============================================================ +# Quantization Functions +# ============================================================ + +def uniform_quantize(w, n_bits=N_BITS): + """Uniform quantization: snap to nearest grid point. STE in backward.""" + n_levels = 2 ** n_bits + # Per-channel scale: map weight range to [0, n_levels-1] + w_min = w.min(dim=-1, keepdim=True).values + w_max = w.max(dim=-1, keepdim=True).values + scale = (w_max - w_min) / (n_levels - 1) + scale = scale.clamp(min=1e-8) + # Quantize + w_norm = (w - w_min) / scale + w_int = w_norm.round().clamp(0, n_levels - 1) + # Dequantize + w_q = w_int * scale + w_min + # STE: forward uses quantized, backward uses identity + return w + (w_q - w).detach() + + +def learned_centroid_quantize(w, centroids): + """ + Non-uniform quantization with learned centroids. + centroids: [n_levels] sorted tensor of quantization levels. + Maps each weight to nearest centroid. STE for gradients. 
+ """ + # centroids: [n_levels], w: [out, in] or any shape + # Per-channel scaling first + w_flat = w.reshape(w.size(0), -1) + w_min = w_flat.min(dim=-1, keepdim=True).values + w_max = w_flat.max(dim=-1, keepdim=True).values + w_range = (w_max - w_min).clamp(min=1e-8) + + # Normalize to [0, 1] + w_norm = (w_flat - w_min) / w_range + + # Find nearest centroid for each weight (centroids are in [0, 1]) + # centroids: [n_levels], w_norm: [out, in] + c = centroids.unsqueeze(0).unsqueeze(0) # [1, 1, n_levels] + w_exp = w_norm.unsqueeze(-1) # [out, in, 1] + dists = (w_exp - c).abs() # [out, in, n_levels] + idx = dists.argmin(dim=-1) # [out, in] + + # Quantized values + w_q_norm = centroids[idx] # [out, in] + + # Denormalize + w_q = w_q_norm * w_range + w_min + w_q = w_q.reshape(w.shape) + + # STE + return w + (w_q - w).detach() + + +def normalfloat_centroids(n_bits=N_BITS): + """ + NormalFloat quantization centroids (from QLoRA's NF4, extended to NF5). + Optimal for Gaussian-distributed weights. + Centroids are placed at quantiles of N(0,1). 
+ """ + n_levels = 2 ** n_bits + # Quantiles of standard normal, mapped to [0, 1] + from scipy.stats import norm + quantiles = np.array([norm.ppf((i + 0.5) / n_levels) for i in range(n_levels)]) + # Normalize to [0, 1] + quantiles = (quantiles - quantiles.min()) / (quantiles.max() - quantiles.min()) + return torch.tensor(quantiles, dtype=torch.float32) + + +def uniform_centroids(n_bits=N_BITS): + """Uniform grid centroids in [0, 1].""" + n_levels = 2 ** n_bits + return torch.linspace(0, 1, n_levels) + + +# ============================================================ +# Model Components (same as integration test, with quantization hooks) +# ============================================================ +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + def forward(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * self.scale + +class QuantLinear(nn.Module): + """Linear layer with optional quantization simulation.""" + def __init__(self, in_features, out_features, bias=False, + quant_fn=None, centroids=None): + super().__init__() + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None + self.quant_fn = quant_fn + self.centroids = centroids # external reference to shared centroids + nn.init.normal_(self.weight, std=0.02) + + def forward(self, x): + w = self.weight + if self.quant_fn is not None: + if self.centroids is not None: + w = self.quant_fn(w, self.centroids) + else: + w = self.quant_fn(w) + return F.linear(x, w, self.bias) + +class GEGLU_MLP(nn.Module): + def __init__(self, dim, expansion=2.0, quant_fn=None, centroids=None): + super().__init__() + hidden = int(dim * expansion) + self.gate = QuantLinear(dim, hidden, quant_fn=quant_fn, centroids=centroids) + self.up = QuantLinear(dim, hidden, quant_fn=quant_fn, centroids=centroids) + self.down = QuantLinear(hidden, dim, 
quant_fn=quant_fn, centroids=centroids) + def forward(self, x): + return self.down(F.gelu(self.gate(x)) * self.up(x)) + +class FullMHA(nn.Module): + def __init__(self, dim, n_heads, rope_dims=0, quant_fn=None, centroids=None): + super().__init__() + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.qkv = QuantLinear(dim, 3 * dim, quant_fn=quant_fn, centroids=centroids) + self.out = QuantLinear(dim, dim, quant_fn=quant_fn, centroids=centroids) + self.rope_dims = rope_dims + if rope_dims > 0: + freqs = 1.0 / (10000.0 ** (torch.arange(0, rope_dims, 2).float() / rope_dims)) + t = torch.arange(SEQ_LEN).float() + freqs = torch.outer(t, freqs) + self.register_buffer('cos_cache', freqs.cos().unsqueeze(0).unsqueeze(0), persistent=False) + self.register_buffer('sin_cache', freqs.sin().unsqueeze(0).unsqueeze(0), persistent=False) + + def _apply_rope(self, x): + rd = self.rope_dims + x_rope, x_pass = x[..., :rd], x[..., rd:] + x1, x2 = x_rope[..., :rd//2], x_rope[..., rd//2:] + cos = self.cos_cache[:, :, :x.size(2), :] + sin = self.sin_cache[:, :, :x.size(2), :] + x_rope_out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + return torch.cat([x_rope_out, x_pass], dim=-1) + + def forward(self, x): + B, T, C = x.shape + qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim) + q, k, v = qkv.unbind(2) + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + if self.rope_dims > 0: + q = self._apply_rope(q) + k = self._apply_rope(k) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + y = y.transpose(1, 2).reshape(B, T, C) + return self.out(y) + +class TransformerBlock(nn.Module): + def __init__(self, dim, n_heads, mlp_expansion=2.0, rope_dims=0, + quant_fn=None, centroids=None): + super().__init__() + self.ln1 = RMSNorm(dim) + self.attn = FullMHA(dim, n_heads, rope_dims=rope_dims, + quant_fn=quant_fn, centroids=centroids) + self.ln2 = RMSNorm(dim) + self.mlp = GEGLU_MLP(dim, expansion=mlp_expansion, + quant_fn=quant_fn, 
centroids=centroids) + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class Transformer(nn.Module): + def __init__(self, vocab_size=VOCAB_SIZE, dim=DIM, n_heads=N_HEADS, + n_layers=N_LAYERS, seq_len=SEQ_LEN, mlp_expansion=MLP_EXP, + rope_dims=16, quant_fn=None, centroids=None): + super().__init__() + self.tok_emb = nn.Embedding(vocab_size, dim) + self.blocks = nn.ModuleList([ + TransformerBlock(dim, n_heads, mlp_expansion, rope_dims=rope_dims, + quant_fn=quant_fn, centroids=centroids) + for _ in range(n_layers) + ]) + self.ln_f = RMSNorm(dim) + + for m in self.modules(): + if isinstance(m, (nn.Linear, QuantLinear)): + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + + def forward(self, idx): + B, T = idx.shape + x = self.tok_emb(idx) + for block in self.blocks: + x = block(x) + return F.linear(self.ln_f(x), self.tok_emb.weight) + +# ============================================================ +# Post-hoc quantization (simulate GPTQ-like) +# ============================================================ +def apply_posthoc_quantization(model, quant_fn, centroids=None): + """Apply quantization to all Linear/QuantLinear weights post-hoc.""" + with torch.no_grad(): + for m in model.modules(): + if isinstance(m, (nn.Linear, QuantLinear)): + if centroids is not None: + m.weight.data = learned_centroid_quantize(m.weight.data, centroids) + else: + m.weight.data = uniform_quantize(m.weight.data) + +# ============================================================ +# Training and Eval +# ============================================================ +def eval_ce(model, eval_seq): + model.eval() + with torch.no_grad(): + eb = eval_seq[:100].to(DEVICE) + logits = model(eb[:, :-1]) + ce = F.cross_entropy(logits.reshape(-1, logits.size(-1)), eb[:, 1:].reshape(-1)) + return ce.item() + +def train_and_eval(model, train_seq, eval_seq, label="", centroids_param=None): + model 
= model.to(DEVICE) + n_params = sum(p.numel() for p in model.parameters()) + extra = "" + if centroids_param is not None: + centroids_param = centroids_param.to(DEVICE) + extra = f", Centroid params: {centroids_param.numel()}" + print(f" [{label}] Params: {n_params:,}{extra}", flush=True) + + # Collect all params including centroids + all_params = list(model.parameters()) + param_groups = [{'params': all_params, 'lr': LR}] + if centroids_param is not None and centroids_param.requires_grad: + # Ensure centroids are on the right device as a leaf tensor + if centroids_param.device != torch.device(DEVICE): + centroids_param.data = centroids_param.data.to(DEVICE) + param_groups.append({'params': [centroids_param], 'lr': LR * 10.0}) + + optimizer = torch.optim.AdamW(param_groups, weight_decay=0.1) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TRAIN_STEPS) + + t0 = time.time() + best_ce = float('inf') + + for step in range(TRAIN_STEPS + 1): + if step % 500 == 0: + ce = eval_ce(model, eval_seq) + ms = (time.time() - t0) / max(step, 1) * 1000 + c_str = "" + if centroids_param is not None: + c_vals = centroids_param.detach().cpu().numpy() + c_str = f" | centroids: [{c_vals[0]:.3f}..{c_vals[15]:.3f}..{c_vals[-1]:.3f}]" + print(f" Step {step:4d} | CE: {ce:.4f} | {ms:.0f}ms/step{c_str}", flush=True) + best_ce = min(best_ce, ce) + model.train() + + if step >= TRAIN_STEPS: + break + + bi = torch.randint(0, train_seq.size(0), (BATCH_SIZE,)) + batch = train_seq[bi].to(DEVICE) + logits = model(batch[:, :-1]) + loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch[:, 1:].reshape(-1)) + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(list(model.parameters()) + + ([centroids_param] if centroids_param is not None else []), 1.0) + optimizer.step() + # Keep centroids sorted + if centroids_param is not None and centroids_param.requires_grad: + with torch.no_grad(): + centroids_param.data = centroids_param.data.sort().values + 
scheduler.step() + + elapsed = time.time() - t0 + final_ce = eval_ce(model, eval_seq) + print(f" Done. Best CE: {best_ce:.4f}, Final CE: {final_ce:.4f} in {elapsed:.1f}s\n", flush=True) + return {"label": label, "best_ce": best_ce, "final_ce": final_ce, "params": n_params, "elapsed_s": elapsed} + +# ============================================================ +# MAIN +# ============================================================ +def main(): + print("=" * 70) + print("EXPERIMENT: Learned Quantization Grid + QAT") + print("=" * 70) + + text = download_text_corpus() + all_data = tokenize_text(text) + n_eval = min(100, all_data.size(0) // 10) + eval_seq = all_data[:n_eval] + train_seq = all_data[n_eval:] + print(f" Train: {train_seq.shape}, Eval: {eval_seq.shape}\n", flush=True) + + results = {} + + # A. Float32 baseline (no quantization) + print("=" * 70) + print("A. FLOAT32 BASELINE (no quantization, RoPE 16)") + print("=" * 70) + torch.manual_seed(42) + model = Transformer(rope_dims=16) + res = train_and_eval(model, train_seq, eval_seq, "float32_baseline") + results["float32_baseline"] = res + float_ce = res["best_ce"] + + # Now apply post-hoc uniform quantization and measure degradation + print(" → Applying post-hoc uniform int5 quantization...") + apply_posthoc_quantization(model, uniform_quantize) + posthoc_ce = eval_ce(model, eval_seq) + print(f" → Post-hoc int5 CE: {posthoc_ce:.4f} (degradation: {(posthoc_ce-float_ce)/float_ce*100:+.2f}%)\n") + results["float32_posthoc_int5"] = {"ce": posthoc_ce, "degradation_pct": (posthoc_ce-float_ce)/float_ce*100} + del model; torch.mps.empty_cache() if DEVICE == "mps" else None + + # B. QAT with uniform int5 (STE, fixed centroids) + print("=" * 70) + print("B. 
QAT UNIFORM INT5 (train with quantization, fixed grid)") + print("=" * 70) + torch.manual_seed(42) + model = Transformer(rope_dims=16, quant_fn=uniform_quantize) + res = train_and_eval(model, train_seq, eval_seq, "qat_uniform_int5") + results["qat_uniform_int5"] = res + del model; torch.mps.empty_cache() if DEVICE == "mps" else None + + # C. QAT with NormalFloat-5 centroids (Gaussian-optimal, fixed) + print("=" * 70) + print("C. QAT NORMALFLOAT-5 (Gaussian-optimal centroids, fixed)") + print("=" * 70) + try: + nf5_centroids = normalfloat_centroids(N_BITS).to(DEVICE) + except ImportError: + # scipy not available: build the exact Gaussian midpoint quantiles instead. + # norm.ppf(p) = sqrt(2) * erfinv(2p - 1); spacing is densest near the center. + # (Note: sigmoid of an evenly spaced grid would be densest at the extremes, + # the opposite of NormalFloat.) + p = torch.linspace(0.5 / N_LEVELS, 1 - 0.5 / N_LEVELS, N_LEVELS) + nf5_centroids = math.sqrt(2.0) * torch.erfinv(2 * p - 1) + nf5_centroids = (nf5_centroids - nf5_centroids.min()) / (nf5_centroids.max() - nf5_centroids.min()) + nf5_centroids = nf5_centroids.to(DEVICE) + + print(f" NF5 centroids: {nf5_centroids[:5].tolist()}...{nf5_centroids[-5:].tolist()}") + torch.manual_seed(42) + qfn_nf5 = lambda w, c=nf5_centroids: learned_centroid_quantize(w, c) + model = Transformer(rope_dims=16, quant_fn=qfn_nf5) + res = train_and_eval(model, train_seq, eval_seq, "qat_nf5_fixed") + results["qat_nf5_fixed"] = res + del model; torch.mps.empty_cache() if DEVICE == "mps" else None + + # D. QAT with LEARNED non-uniform centroids (OUR NOVEL IDEA) + print("=" * 70) + print("D.
QAT LEARNED CENTROIDS ★ (non-uniform, trained jointly)") + print("=" * 70) + # Initialize from uniform, let them learn + learned_c = nn.Parameter(torch.linspace(0, 1, N_LEVELS).to(DEVICE)) + qfn_learned = lambda w, c=learned_c: learned_centroid_quantize(w, c) + torch.manual_seed(42) + model = Transformer(rope_dims=16, quant_fn=qfn_learned) + res = train_and_eval(model, train_seq, eval_seq, "qat_learned_centroids", centroids_param=learned_c) + results["qat_learned_centroids"] = res + + # Print final learned centroids + final_c = learned_c.detach().cpu().numpy() + print(f" Final learned centroids ({N_LEVELS} levels):") + print(f" {final_c.tolist()}") + + # Check centroid distribution: are they non-uniform? + gaps = np.diff(final_c) + print(f" Gap stats: min={gaps.min():.4f}, max={gaps.max():.4f}, " + f"std={gaps.std():.4f}, mean={gaps.mean():.4f}") + print(f" Non-uniformity ratio: {gaps.max()/gaps.min():.2f}x") + del model; torch.mps.empty_cache() if DEVICE == "mps" else None + + # E. QAT with LEARNED centroids initialized from NF5 + print("=" * 70) + print("E. 
QAT LEARNED CENTROIDS (init from NF5)") + print("=" * 70) + try: + nf5_init = normalfloat_centroids(N_BITS) + except ImportError: + x = torch.linspace(-3, 3, N_LEVELS) + nf5_init = torch.sigmoid(x) + nf5_init = (nf5_init - nf5_init.min()) / (nf5_init.max() - nf5_init.min()) + + learned_c2 = nn.Parameter(nf5_init.clone().to(DEVICE)) + qfn_learned2 = lambda w, c=learned_c2: learned_centroid_quantize(w, c) + torch.manual_seed(42) + model = Transformer(rope_dims=16, quant_fn=qfn_learned2) + res = train_and_eval(model, train_seq, eval_seq, "qat_learned_from_nf5", centroids_param=learned_c2) + results["qat_learned_from_nf5"] = res + del model; torch.mps.empty_cache() if DEVICE == "mps" else None + + # ============================================================== + # SUMMARY + # ============================================================== + print("=" * 70) + print("FINAL SUMMARY") + print("=" * 70) + + print(f"\n {'Method':<35s} {'Best CE':>10s} {'vs Float':>10s} {'vs PostQ':>10s}") + print(f" {'-'*35} {'-'*10} {'-'*10} {'-'*10}") + + posthoc_ce_val = results["float32_posthoc_int5"]["ce"] + + for name in ["float32_baseline", "qat_uniform_int5", "qat_nf5_fixed", + "qat_learned_centroids", "qat_learned_from_nf5"]: + r = results[name] + ce = r["best_ce"] + vs_float = (ce - float_ce) / float_ce * 100 + vs_postq = (ce - posthoc_ce_val) / posthoc_ce_val * 100 + print(f" {name:<35s} {ce:>10.4f} {vs_float:>+9.2f}% {vs_postq:>+9.2f}%") + + print(f"\n Post-hoc int5 CE: {posthoc_ce_val:.4f} (degradation from float: " + f"{results['float32_posthoc_int5']['degradation_pct']:+.2f}%)") + + print(f"\n KEY QUESTION: Does QAT (train-with-quant) beat post-hoc quant?") + best_qat = min(results["qat_uniform_int5"]["best_ce"], + results["qat_nf5_fixed"]["best_ce"], + results["qat_learned_centroids"]["best_ce"], + results["qat_learned_from_nf5"]["best_ce"]) + qat_vs_posthoc = (best_qat - posthoc_ce_val) / posthoc_ce_val * 100 + print(f" Best QAT CE: {best_qat:.4f} vs Post-hoc: 
{posthoc_ce_val:.4f} ({qat_vs_posthoc:+.2f}%)") + print(f" {'YES — QAT wins!' if best_qat < posthoc_ce_val else 'NO — post-hoc is fine'}") + + # Save + results_file = "/Users/himanshudongre/Documents/GitHub/parameter_golf/qat_results.json" + with open(results_file, "w") as f: + json.dump(results, f, indent=2, default=str) + print(f"\n Results saved to {results_file}") + print("\nDone!") + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/results_distill_qat.json b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/results_distill_qat.json new file mode 100644 index 0000000000..7186469dba --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/results_distill_qat.json @@ -0,0 +1,11 @@ +{ + "A_small_fp32": 1.3193752765655518, + "B_small_qat_int4": 1.3250839710235596, + "C_teacher": 1.315826416015625, + "D_distill_T2.0_a0.3": 1.3229968547821045, + "D_distill_T2.0_a0.5": 1.3094488382339478, + "D_distill_T4.0_a0.3": 1.333116054534912, + "D_distill_T4.0_a0.5": 1.3122836351394653, + "E_distill_qat_int4": 1.3312817811965942, + "F_teacher_pq": 5.168092250823975 +} \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/results_vocab4096_mlp4x.json b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/results_vocab4096_mlp4x.json new file mode 100644 index 0000000000..677f5a6a90 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/results_vocab4096_mlp4x.json @@ -0,0 +1,19 @@ +{ + "A_ce": 2.132023811340332, + "A_bpc": 3.075860179677884, + "A_params": 2410944, + "B_ce": 2.1797144412994385, + "B_bpc": 3.1446632150167577, + "B_params": 3000768, + "C_ce": 2.2511019706726074, + "C_bpc": 3.2476536496247435, + "C_params": 3738048, + "D_ce": 2.271700382232666, + "D_bpc": 3.27737087583263, + 
"D_params": 4327872, + "E_ce": 2.2721877098083496, + "E_bpc": 3.278073940909357, + "E_params": 4327872, + "F_three_way_bpc": 2.9284442412233265, + "F_three_way_baseline_bpc": 2.9284442412233265 +} \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/submission.json b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/submission.json new file mode 100644 index 0000000000..4275d2cead --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_28_Experiments_5_Days_Scale_Deception/submission.json @@ -0,0 +1,9 @@ +{ + "track": "non_record_16mb", + "date": "2026-04-02", + "name": "28 Experiments in 5 Days — What Works, What Fails, and Why Small-Scale Tests Lie", + "author": "Himanshu Dongre", + "github_id": "himanshudongre", + "val_bpb": null, + "notes": "Research report. 28 controlled experiments across architecture, training, quantization, and eval-time techniques. 14 dead techniques documented with specific numbers. Key finding: small-scale tests can be 180 degrees wrong (SSM: -18% local to +2.7% at scale). Self-funded on Mac Mini M4 + single H100 RunPod (~$12 GPU spend)." +} \ No newline at end of file
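
For readers without torch handy, the forward pass of `learned_centroid_quantize` (per-row min-max normalize, snap to the nearest centroid, denormalize) can be sketched in plain NumPy. This is an illustrative reimplementation of the math, not the committed code; the STE backward trick (`w + (w_q - w).detach()`) is omitted since there is no autograd here.

```python
import numpy as np

def centroid_quantize_fwd(w, centroids):
    """Nearest-centroid quantization, forward pass only.

    w: (out, in) weight matrix; centroids: sorted 1-D levels in [0, 1].
    Mirrors learned_centroid_quantize: each output row is min-max
    normalized to [0, 1], snapped to its nearest centroid, then mapped
    back to the row's original range.
    """
    w = np.asarray(w, dtype=np.float64)
    c = np.asarray(centroids, dtype=np.float64)
    w_min = w.min(axis=-1, keepdims=True)
    w_rng = np.maximum(w.max(axis=-1, keepdims=True) - w_min, 1e-8)
    w_norm = (w - w_min) / w_rng                     # rows now in [0, 1]
    idx = np.abs(w_norm[..., None] - c).argmin(-1)   # nearest-centroid index
    return c[idx] * w_rng + w_min                    # dequantize
```

With a uniform 3-level grid [0, 0.5, 1], the row [0.0, 0.4, 1.0] quantizes to [0.0, 0.5, 1.0]; a learned non-uniform grid simply replaces the centroid array.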
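
The NF5 grid built by `normalfloat_centroids` needs only an inverse normal CDF, which the Python standard library provides via `statistics.NormalDist` (Python 3.8+). A scipy-free sketch of the same midpoint-quantile construction (the function name is reused for illustration; this is not the committed implementation):

```python
from statistics import NormalDist

def normalfloat_centroids(n_bits=5):
    """NF-style centroids: midpoint quantiles of N(0, 1), min-max
    rescaled to [0, 1]. Spacing is densest near 0.5, where Gaussian
    weight mass concentrates, and sparsest at the extremes."""
    n_levels = 2 ** n_bits
    ppf = NormalDist().inv_cdf  # inverse CDF of the standard normal
    q = [ppf((i + 0.5) / n_levels) for i in range(n_levels)]
    lo, hi = q[0], q[-1]
    return [(v - lo) / (hi - lo) for v in q]
```

For n_bits=5 this yields 32 strictly increasing levels whose central gap is several times smaller than the outermost gap, matching the Lloyd-Max intuition that a Gaussian-matched grid spends resolution where the weights actually live.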
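
The distillation grid in results_distill_qat.json (`D_distill_T{2,4}_a{0.3,0.5}`) sweeps temperature T and mixing weight alpha of a knowledge-distillation objective. A minimal NumPy sketch, assuming the conventional (1 - alpha) * CE + alpha * T^2 * KL(teacher || student) weighting; the exact form used in exp_local_distill_qat.py is not shown in this chunk, so treat this as a hedged reconstruction:

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kd_loss(student_logits, teacher_logits, targets, T=2.0, alpha=0.5):
    """Assumed knowledge-distillation objective (not the committed code).

    Hard term: cross-entropy against ground-truth tokens.
    Soft term: KL(teacher_T || student_T) at temperature T, scaled by
    T^2 so soft-target gradients keep comparable magnitude as T grows.
    """
    logp = log_softmax(np.asarray(student_logits, dtype=np.float64))
    ce = -logp[np.arange(len(targets)), targets].mean()
    logp_t = log_softmax(np.asarray(teacher_logits, dtype=np.float64) / T)
    logp_s = log_softmax(np.asarray(student_logits, dtype=np.float64) / T)
    kl = (np.exp(logp_t) * (logp_t - logp_s)).sum(axis=-1).mean()
    return (1 - alpha) * ce + alpha * (T ** 2) * kl
```

A quick sanity check: when the student's logits equal the teacher's, the KL term vanishes and the loss reduces to (1 - alpha) times the plain cross-entropy.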