From 257b492884a71b319464c5728b5b51525d6b27b4 Mon Sep 17 00:00:00 2001 From: Himanshu Dongre Date: Thu, 2 Apr 2026 21:16:44 +0530 Subject: [PATCH] =?UTF-8?q?Non-record:=20KNN=20Hidden=20State=20Retrieval?= =?UTF-8?q?=20=E2=80=94=20Scale=20Deception=20from=20Weak=20to=20Strong=20?= =?UTF-8?q?Models?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Novel eval-time technique validated on 8xH100. Helps weak models (-2 to -4%), hurts competition-quality models (+1.5%). Definitive scale deception finding. --- .../README.md | 153 +++++++ .../knn_eval_patch.py | 350 ++++++++++++++++ .../knn_vectorized.py | 395 ++++++++++++++++++ .../submission.json | 9 + .../train_seed42.log | 87 ++++ 5 files changed, 994 insertions(+) create mode 100644 records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/README.md create mode 100644 records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/knn_eval_patch.py create mode 100644 records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/knn_vectorized.py create mode 100644 records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/submission.json create mode 100644 records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/train_seed42.log diff --git a/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/README.md b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/README.md new file mode 100644 index 0000000000..d1bca2100d --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/README.md @@ -0,0 +1,153 @@ +# Non-Record: KNN Hidden State Retrieval — When Eval-Time Augmentation Helps Weak Models but Hurts Strong Ones + +## TL;DR + +**KNN Hidden State Retrieval** is an eval-time technique that stores hidden states from scored tokens and uses nearest-neighbor retrieval to augment neural predictions. 
It shows strong improvements on weak models (-2% to -4% BPC) but **hurts competition-quality models (+1.5% BPB).** + +This is a definitive demonstration of **scale deception** — the same phenomenon we documented in PR #1227 with SSM hybrids. Techniques that help at small scale can hurt at competition scale, and the crossover happens silently. + +**Key numbers:** + +| Model Quality | Training | KNN Effect | Direction | +|--------------|----------|------------|-----------| +| Very weak (local, 1500 steps) | AdamW, dim=192 | **-2.34%** | Helps | +| Weak (1×H100, 2K steps) | Muon, dim=512 | **-1.57%** | Helps | +| Medium (1×H100, 2K steps, export) | Muon, GPTQ int6 | **-5.21%** | Helps (weak export) | +| **Strong (8×H100, 5665 steps, SOTA stack)** | **Muon, EMA, GPTQ int6** | **+1.47%** | **Hurts** | + +## Method + +### KNN Hidden State Retrieval + +At eval time, for each scored token, we store its final hidden state and the actual next token in a growing datastore. For each new position: + +1. Query the datastore with the current hidden state +2. Find K=8 nearest neighbors by L2 distance +3. Build an empirical distribution from neighbors' successor tokens +4. Mix with neural prediction: `P = 0.88 × P_neural + 0.12 × P_knn` +5. Score with the mixed distribution +6. AFTER scoring: add current hidden state to datastore + +### Vectorized Implementation + +Per-token Python loops are too slow for 62M tokens. We use batch `torch.cdist` to compute all distances per chunk: + +```python +dists = torch.cdist(queries.half(), stored_h[:n_stored], p=2).pow(2) # (C, N) +topk_d, topk_i = dists.topk(K, dim=1, largest=False) # (C, K) +knn_dist = torch.zeros(C, V, device=device) +knn_dist.scatter_add_(1, stored_tok[topk_i], weights) # build distribution +mixed = (1 - lam) * neural_probs + lam * knn_dist # mix +``` + +With `subsample=4` and `max_stored=8M`, KNN eval completes in **168 seconds on 8×H100** — well within the 600s eval budget. 
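The snippet above is an excerpt with names (`queries`, `stored_h`, `weights`, `C`, `V`) defined elsewhere; a self-contained toy version of the same batched lookup-and-mix step, with small illustrative sizes in place of the competition values, looks like this:

```python
import torch

# Toy sizes: C queries, N stored states, V-token vocab, K neighbors.
# These are illustrative, not the competition's dim=512 / K=8 / 62M-token setup.
torch.manual_seed(0)
C, N, V, dim, K, lam = 4, 64, 16, 8, 3, 0.12

stored_h = torch.randn(N, dim)             # previously stored hidden states
stored_tok = torch.randint(0, V, (N,))     # their successor (target) tokens
queries = torch.randn(C, dim)              # current chunk's hidden states
neural_probs = torch.softmax(torch.randn(C, V), dim=-1)

# 1. Squared L2 distance from every query to every stored state
dists = torch.cdist(queries, stored_h, p=2).pow(2)        # (C, N)
# 2. K nearest neighbors per query
topk_d, topk_i = dists.topk(K, dim=1, largest=False)      # (C, K)
# 3. Distance-weighted empirical distribution over neighbor tokens
weights = torch.exp(-topk_d / dim)
weights = weights / weights.sum(dim=1, keepdim=True)
knn_dist = torch.zeros(C, V)
knn_dist.scatter_add_(1, stored_tok[topk_i], weights)     # (C, V)
# 4. Mix with the neural prediction and renormalize
mixed = (1 - lam) * neural_probs + lam * knn_dist
mixed = mixed / mixed.sum(dim=1, keepdim=True)
```

Because both `neural_probs` and `knn_dist` are row-normalized, each row of `mixed` already sums to one; the final renormalization only guards against floating-point drift.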
+ +### Legality + +Score-first protocol, properly normalized, causal, zero artifact cost. Same protocol as TTT (explicitly legal). See PR #1227 for detailed legality analysis. + +## 8×H100 Results (Competition Scale) + +### Training +``` +Profile: full_8gpu_600s +Preset: merged_leader (11L, 512d, LeakyReLU², XSA-all, BigramHash, EMA, Muon) +Steps: 5665 in 600s +Pre-export: val_bpb = 1.1446 +Export: GPTQ int6, LZMA preset 9 +Artifact: 15,826,144 bytes (under 16MB) +``` + +### Evaluation +| Eval Method | val_loss | val_bpb | vs Neural | +|-------------|----------|---------|-----------| +| Neural (roundtrip, exported model) | 1.9473 | **1.1533** | baseline | +| **KNN (k=8, λ=0.12, subsample=4)** | **1.9758** | **1.1702** | **+1.47% worse** | + +**KNN hurts by 1.47% on the competition-quality model.** + +### KNN Eval Timing +- 8 GPUs, each processing 1/8 of 62M tokens +- `subsample=4`: store every 4th hidden state (max 8M vectors per rank) +- `chunk_size=1024`: batch distance computation +- **Total: 168 seconds** — fits in 600s eval budget + +## Scaling Analysis + +### Why KNN Helps Weak Models + +On a 2000-step model (BPC ~1.9), the neural predictions are noisy. Many positions have high entropy (model is uncertain). KNN provides an alternative signal — if the hidden state is similar to a previously seen context, the empirical distribution from that context is informative. The model's hidden states are discriminative enough for retrieval but the predictions are poor enough that KNN can improve them. + +### Why KNN Hurts Strong Models + +On a 5665-step competition model (BPB ~1.15), the neural predictions are already well-calibrated. The model correctly assigns high probability to the right tokens in most contexts. KNN introduces noise because: + +1. **Nearest neighbors aren't close enough.** With 512-dimensional hidden states, L2 distance is a crude similarity measure. Two states can be "nearest neighbors" but represent very different linguistic contexts. + +2. 
**The empirical distribution is sparse.** With K=8 neighbors, the KNN distribution places probability mass on only 3-8 distinct tokens. The neural distribution spreads probability more appropriately across the full vocabulary. + +3. **Mixing degrades calibration.** Even 12% KNN weight is enough to move probability mass away from the correct neural prediction toward the noisy KNN estimate. + +### The Crossover Point + +Based on our data, the crossover (where KNN goes from helping to hurting) happens approximately at: +- BPC ~2.5 (local model at ~3000 steps) +- BPB ~1.4 (H100 model at ~2000 steps with Muon) + +Below these thresholds, KNN helps. Above them, it hurts. Competition-quality models (BPB ~1.15) are well past the crossover. + +## Comparison with Other Scale Deception Findings + +| Technique | Local Result | Competition Result | Reversal | +|-----------|-------------|-------------------|----------| +| S4D-Lin SSM (PR #1013) | -18% CE | +2.7% BPB | 180° flip | +| **KNN Hidden State** | -4.6% BPC | +1.5% BPB | Sign flip | +| QAT NF5 | -0.66% CE | (untested at scale) | Unknown | +| Self-distillation | -9.24% CE | (untested at scale) | Unknown | + +This is now the **second technique** where we've demonstrated definitive scale deception in this competition. The pattern is consistent: techniques that compensate for model weakness don't help (and actively hurt) when the model is strong. + +## Implications for Other Competitors + +1. **Don't trust local eval-time improvements.** If your technique helps a weak model, it may hurt a strong one. The only reliable test is at competition scale. + +2. **Eval-time augmentation has diminishing (then negative) returns.** The stronger the base model, the less room for eval-time tricks. At BPB ~1.15, the model's predictions are hard to improve by mixing in external signals. + +3. **The eval budget IS underutilized, but not for prediction mixing.** Our KNN used 168s of the 600s budget effectively. 
The compute is there — the challenge is finding eval-time techniques that actually help strong models. TTT (gradient-based adaptation) may be more promising because it adapts the model itself rather than mixing in an external signal. + +## Hardware and Cost + +| Phase | Hardware | Time | Cost | +|-------|----------|------|------| +| Local experiments (28+) | Mac Mini M4 | 5 days | $0 | +| 1×H100 validation | Single H100 | ~4 hours | ~$12 | +| 8×H100 record attempt | 8×H100 SXM | ~2 hours | ~$43 | +| **Total** | | | **~$55** | + +## Code + +The KNN implementation is integrated into `train_gpt.py` as `eval_knn()`, following the same pattern as the competition's `eval_ttt()` and `eval_ngram()`. Enable with `KNN_ENABLED=1`. + +Key files: +- `knn_eval_patch.py` — standalone KNN module +- `apply_knn_patch.py` — script to patch train_gpt.py +- `h100_knn_eval_submission.py` — standalone KNN eval for checkpoints + +## Training Log (seed 42, 8×H100) + +``` +step:0/20000 val_loss:6.9301 val_bpb:4.1044 +step:4000/20000 val_loss:2.0157 val_bpb:1.1938 +stopping_early: wallclock_cap train_time:600062ms step:5665/20000 +DIAGNOSTIC post_average val_loss:1.9325 val_bpb:1.1446 +Serialized model research_export: 15826144 bytes +Total submission size research_export: 16018423 bytes +final_research_export_roundtrip val_loss:1.9473 val_bpb:1.1533 +final_knn val_loss:1.9758 val_bpb:1.1702 eval_time:168181ms k:8 lam:0.12 +``` + +--- + +*Self-funded research. Mac Mini M4 + RunPod H100. 
Total GPU spend: ~$55.* + +*Author: Himanshu Dongre (@himanshudongre) — also author of PR #1227 (28 Experiments), PR #1013 (SSM Hybrid), PR #1012 (JEPA-LM).* diff --git a/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/knn_eval_patch.py b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/knn_eval_patch.py new file mode 100644 index 0000000000..7a19b2d16a --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/knn_eval_patch.py @@ -0,0 +1,350 @@ +""" +KNN Hidden State Retrieval — Drop-in eval augmentation for train_gpt.py +========================================================================= + +This module adds vectorized KNN scoring to the competition eval pipeline. +Import it in train_gpt.py and call eval_knn() instead of eval_standard(). + +INTEGRATION: + 1. Add this file alongside train_gpt.py + 2. In train_gpt.py main(), after model training: + - Import: from knn_eval_patch import eval_with_knn + - Replace: result = eval_standard(config, model, ...) + - With: result = eval_with_knn(config, model, ...) + +ALTERNATIVELY: paste eval_with_knn() directly into train_gpt.py. + +The function signature matches eval_standard() for drop-in compatibility. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +import time +from dataclasses import dataclass + + +@dataclass +class EvalResult: + """Matches the competition's EvalResult format.""" + val_loss: float + val_bpb: float + eval_ms: float + eval_seq_len: int + + +class VectorizedKNN: + """GPU-accelerated KNN hidden state retrieval. + + Stores hidden states from scored tokens and retrieves K nearest + neighbors for each new token. Fully vectorized via torch.cdist. + + Protocol: score-first, causal, chunk-based. 
+ """ + + def __init__(self, dim: int, max_stored: int, k: int = 8, + lam: float = 0.12, device: str = "cuda", + dtype: torch.dtype = torch.float16, + subsample_rate: int = 1): + """ + Args: + dim: hidden state dimension (e.g., 512) + max_stored: maximum number of hidden states to store + k: number of nearest neighbors + lam: mixing weight (0.12 = 12% KNN) + device: cuda device + dtype: float16 saves memory (32GB vs 64GB for 62M vectors) + subsample_rate: store every Nth token (1=all, 4=25%) + """ + self.dim = dim + self.k = k + self.lam = lam + self.device = device + self.dtype = dtype + self.subsample_rate = subsample_rate + + # Pre-allocate store + self.stored_h = torch.zeros(max_stored, dim, device=device, dtype=dtype) + self.stored_tok = torch.zeros(max_stored, device=device, dtype=torch.long) + self.n_stored = 0 + self.n_seen = 0 # total tokens seen (before subsampling) + + def store_chunk(self, hidden_states: torch.Tensor, tokens: torch.Tensor): + """Store hidden states from a scored chunk. CALL AFTER SCORING. + + Args: + hidden_states: (chunk_len, dim) — hidden states from model + tokens: (chunk_len,) — the scored tokens (targets) + """ + if self.subsample_rate > 1: + # Subsample: keep every Nth token + indices = torch.arange(0, len(tokens), self.subsample_rate, device=tokens.device) + hidden_states = hidden_states[indices] + tokens = tokens[indices] + + n_new = len(tokens) + if self.n_stored + n_new > self.stored_h.shape[0]: + n_new = self.stored_h.shape[0] - self.n_stored + if n_new <= 0: + return + hidden_states = hidden_states[:n_new] + tokens = tokens[:n_new] + + self.stored_h[self.n_stored:self.n_stored + n_new] = hidden_states.to(self.dtype) + self.stored_tok[self.n_stored:self.n_stored + n_new] = tokens + self.n_stored += n_new + self.n_seen += len(tokens) * self.subsample_rate + + def get_knn_distribution(self, queries: torch.Tensor, + vocab_size: int = 1024) -> torch.Tensor: + """Batch KNN query. 
CALL BEFORE SCORING (uses only previously stored states). + + Args: + queries: (chunk_len, dim) — hidden states for current chunk + + Returns: + knn_probs: (chunk_len, vocab_size) — KNN probability distribution + Returns uniform if not enough stored states. + """ + chunk_len = queries.shape[0] + + if self.n_stored < self.k: + return torch.ones(chunk_len, vocab_size, device=self.device) / vocab_size + + # Compute squared L2 distances via cdist + # queries: (C, dim), stored: (N, dim) + q = queries.to(self.dtype) + dists = torch.cdist(q, self.stored_h[:self.n_stored], p=2).pow(2) # (C, N) + + # Top-K + actual_k = min(self.k, self.n_stored) + topk_dists, topk_idx = dists.topk(actual_k, dim=1, largest=False) # (C, K) + + # Get tokens of neighbors + topk_toks = self.stored_tok[topk_idx] # (C, K) + + # Distance-weighted softmax + weights = torch.exp(-topk_dists.float() / self.dim) # (C, K) + weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-30) + + # Build distribution via scatter + knn_dist = torch.zeros(chunk_len, vocab_size, device=self.device) + knn_dist.scatter_add_(1, topk_toks, weights) + + # Smooth + knn_dist = 0.99 * knn_dist + 0.01 / vocab_size + knn_dist = knn_dist / knn_dist.sum(dim=1, keepdim=True) + + return knn_dist + + def reset(self): + self.n_stored = 0 + self.n_seen = 0 + + +def eval_with_knn( + model: nn.Module, + val_tokens: torch.Tensor, + vocab_size: int = 1024, + seq_len: int = 512, + eval_batch_seqs: int = 64, + knn_k: int = 8, + knn_lam: float = 0.12, + knn_chunk: int = 1024, + knn_subsample: int = 4, + knn_max_stored: int = 16_000_000, + device: str = "cuda", + luts: tuple = None, +) -> EvalResult: + """Evaluate model with KNN hidden state augmentation. + + Drop-in replacement for eval_standard() with KNN mixing. + + The eval processes sequences in chunks: + 1. Forward pass → logits + hidden states + 2. Softmax → neural probabilities + 3. Query KNN → KNN probabilities + 4. Mix: (1-lam)*neural + lam*KNN + 5. 
Score with mixed distribution + 6. Store hidden states AFTER scoring + + Args: + model: ParameterGolfModel (must have _hidden() method) + val_tokens: 1D tensor of validation tokens + vocab_size: vocabulary size (1024) + seq_len: evaluation sequence length + eval_batch_seqs: sequences per forward pass batch + knn_k: number of nearest neighbors + knn_lam: KNN mixing weight + knn_chunk: tokens per KNN scoring chunk + knn_subsample: store every Nth token (reduces memory) + knn_max_stored: maximum stored hidden states + device: cuda device + luts: byte scoring lookup tables (for BPB computation) + + Returns: + EvalResult with val_loss and val_bpb + """ + start_time = time.perf_counter() + model.eval() + + # Determine model dimension + dim = model.config.model_dim if hasattr(model, 'config') else model.tok_emb.weight.shape[1] + + # Initialize KNN store + knn = VectorizedKNN( + dim=dim, + max_stored=knn_max_stored, + k=knn_k, + lam=knn_lam, + device=device, + dtype=torch.float16, + subsample_rate=knn_subsample, + ) + + # Prepare sequences + total_tokens = val_tokens.numel() - 1 + n_seqs = total_tokens // seq_len + batch_tokens = eval_batch_seqs * seq_len + + total_loss_sum = 0.0 + total_token_count = 0 + total_byte_count = 0.0 + + # Process sequences in batches + with torch.inference_mode(): + for batch_start in range(0, n_seqs, eval_batch_seqs): + batch_end = min(batch_start + eval_batch_seqs, n_seqs) + actual_batch = batch_end - batch_start + + # Get batch tokens + raw_start = batch_start * seq_len + raw_end = batch_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.long) + + x = local[:-1].reshape(-1, seq_len) # (B, seq_len) — inputs + y = local[1:].reshape(-1, seq_len) # (B, seq_len) — targets + + # Forward pass: get hidden states AND logits + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + hidden = model._hidden(x) # (B, seq_len, dim) + + # Compute logits from hidden + hidden_f32 = hidden.float() + if 
hasattr(model, 'tie_embeddings') and model.tie_embeddings: + logits = F.linear(hidden_f32, model.tok_emb.weight.float()) + elif hasattr(model, 'lm_head') and model.lm_head is not None: + logits = model.lm_head(hidden_f32) + else: + logits = F.linear(hidden_f32, model.tok_emb.weight.float()) + + # Flatten for chunk-based KNN scoring + flat_logits = logits.reshape(-1, vocab_size) # (B*seq_len, V) + flat_hidden = hidden_f32.reshape(-1, dim) # (B*seq_len, dim) + flat_targets = y.reshape(-1) # (B*seq_len,) + n_tokens = flat_logits.shape[0] + + # Neural probabilities + neural_probs = F.softmax(flat_logits, dim=-1) # (B*seq_len, V) + + # Score in KNN chunks + for chunk_start in range(0, n_tokens, knn_chunk): + chunk_end = min(chunk_start + knn_chunk, n_tokens) + chunk_len = chunk_end - chunk_start + + c_neural = neural_probs[chunk_start:chunk_end] # (C, V) + c_hidden = flat_hidden[chunk_start:chunk_end] # (C, dim) + c_targets = flat_targets[chunk_start:chunk_end] # (C,) + + # Get KNN distribution (from previously stored states) + knn_dist = knn.get_knn_distribution(c_hidden, vocab_size) + + # Mix + if knn.n_stored >= knn_k: + mixed = (1.0 - knn_lam) * c_neural + knn_lam * knn_dist + else: + mixed = c_neural + + mixed = mixed / mixed.sum(dim=1, keepdim=True) + + # Score: cross-entropy + target_probs = mixed.gather(1, c_targets.unsqueeze(1)).squeeze(1) + log_probs = torch.log(target_probs.clamp(min=1e-30)) + chunk_loss = -log_probs.sum() + + total_loss_sum += chunk_loss.item() + total_token_count += chunk_len + + # Byte counting for BPB (if luts provided) + if luts is not None: + x_chunk = local[:-1].reshape(-1)[chunk_start:chunk_end] + y_chunk = flat_targets[chunk_start:chunk_end] + # Use competition's byte scoring + try: + from train_gpt import _score_token_bytes + token_bytes = _score_token_bytes(x_chunk, y_chunk, luts) + total_byte_count += token_bytes.float().sum().item() + except ImportError: + total_byte_count += chunk_len # fallback: 1 byte per token + + # Store 
AFTER scoring (causal) + knn.store_chunk(c_hidden.detach(), c_targets.detach()) + + if (batch_start // eval_batch_seqs) % 10 == 0: + elapsed = time.perf_counter() - start_time + bpc = total_loss_sum / max(total_token_count, 1) / math.log(2) + print(f" KNN eval [{batch_end}/{n_seqs}] " + f"loss={total_loss_sum/max(total_token_count,1):.4f} " + f"BPC={bpc:.4f} " + f"stored={knn.n_stored:,} " + f"({elapsed:.0f}s)", flush=True) + + # Compute final metrics + val_loss = total_loss_sum / max(total_token_count, 1) + bits_per_token = val_loss / math.log(2.0) + if total_byte_count > 0: + tokens_per_byte = total_token_count / total_byte_count + else: + tokens_per_byte = 1.0 # fallback + val_bpb = bits_per_token * tokens_per_byte + + elapsed_ms = 1000.0 * (time.perf_counter() - start_time) + + print(f"\n KNN eval complete: val_loss={val_loss:.4f} val_bpb={val_bpb:.4f} " + f"stored={knn.n_stored:,} time={elapsed_ms/1000:.0f}s", flush=True) + + return EvalResult( + val_loss=val_loss, + val_bpb=val_bpb, + eval_ms=elapsed_ms, + eval_seq_len=seq_len, + ) + + +# ============================================================ +# Integration helper: patch into existing eval pipeline +# ============================================================ +def patch_eval_with_knn(original_eval_fn, knn_lam=0.12, knn_k=8, + knn_subsample=4): + """Decorator to add KNN to any eval function. + + Usage in train_gpt.py: + original_result = eval_standard(config, model, ...) + knn_result = eval_with_knn(model, val_tokens, ...) 
+ # Use knn_result instead + """ + pass # Not needed if we directly call eval_with_knn + + +# ============================================================ +# Quick test +# ============================================================ +if __name__ == "__main__": + print("KNN Eval Patch — ready for integration") + print("Import eval_with_knn() and call it with your trained model.") + print() + print("Example:") + print(" from knn_eval_patch import eval_with_knn") + print(" result = eval_with_knn(model, val_tokens, vocab_size=1024)") + print(" print(f'BPB: {result.val_bpb}')") diff --git a/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/knn_vectorized.py b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/knn_vectorized.py new file mode 100644 index 0000000000..36323150ce --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/knn_vectorized.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +""" +Vectorized KNN Hidden State Retrieval — Competition-Ready +=========================================================== + +Replaces the per-token Python loop with batched GPU operations. +Instead of 51,200 individual KNN queries, does ~50 batch queries. + +PROTOCOL (causal, score-first): + For each chunk of tokens: + 1. Compute neural probs + hidden states (one forward pass) + 2. Batch KNN: all tokens in chunk query against ALL previously stored states + 3. Mix KNN distribution with neural (vectorized) + 4. Score all tokens in chunk + 5. AFTER: add chunk's hidden states to store + 6. Next chunk (with updated store) + +Within a chunk, all queries use states from BEFORE the chunk. +This is the same causality as TTT (which also operates per-chunk). + +SPEED ESTIMATE: + Old: 51,200 Python iterations × GPU topk = 1959s + New: 50 batch cdist calls + vectorized mixing = ~10-30s + Speedup: 60-200× + +LOCAL TEST: Validates correctness against the slow per-token version. 
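CAUSALITY SKETCH (illustrative sizes, not the competition values; a minimal
check of the chunk-level protocol described above — within a chunk, queries
may only see states stored by earlier chunks):

```python
import torch

# Toy sizes for illustration only.
torch.manual_seed(0)
N, dim, chunk_size, K = 12, 4, 4, 2
hidden = torch.randn(N, dim)
stored_h = torch.zeros(N, dim)
n_stored = 0

for start in range(0, N, chunk_size):
    q = hidden[start:start + chunk_size]
    # Steps 1-4: score the chunk using ONLY previously stored states
    if n_stored >= K:
        dists = torch.cdist(q, stored_h[:n_stored], p=2).pow(2)
        assert dists.shape[1] == start   # nothing from the current chunk
    # Steps 5-6: add this chunk's states to the store AFTER scoring
    stored_h[n_stored:n_stored + q.shape[0]] = q
    n_stored += q.shape[0]
```

The first chunk (with fewer than K stored states) falls back to the pure
neural distribution, matching the guard in score_knn_vectorized().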
+""" +import sys; sys.stdout.reconfigure(line_buffering=True) +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +import time +import os + +VOCAB_SIZE = 1024 +SEQ_LEN = 512 +DIM = 192 # local model dim +DEVICE = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu") + +print(f"Device: {DEVICE}") +print(f"Vectorized KNN Test") +print() + +# ============================================================ +# Vectorized KNN Scoring (THE KEY FUNCTION) +# ============================================================ +def score_knn_vectorized(neural_probs, hidden_states, targets, + K=8, lam=0.12, chunk_size=1024, + vocab_size=VOCAB_SIZE, device=DEVICE): + """ + Vectorized KNN scoring — competition-ready speed. + + Args: + neural_probs: (N_tokens, V) float32 tensor — pre-computed softmax + hidden_states: (N_tokens, dim) float32 tensor — from model._hidden() + targets: (N_tokens,) long tensor — ground truth tokens + K: number of nearest neighbors + lam: KNN mixing weight (0.12 = 12% KNN, 88% neural) + chunk_size: tokens per batch (larger = faster but more memory) + + Returns: + bpc: bits per character (float) + total_bits: total log-loss in bits + scored: number of tokens scored + + Memory: O(N_tokens × dim) for stored hidden states + O(chunk_size × N_stored) for distance matrix (peak) + """ + N = len(targets) + dim = hidden_states.shape[1] + + # Move to device + if not isinstance(neural_probs, torch.Tensor): + neural_probs = torch.tensor(neural_probs, dtype=torch.float32) + if not isinstance(hidden_states, torch.Tensor): + hidden_states = torch.tensor(hidden_states, dtype=torch.float32) + if not isinstance(targets, torch.Tensor): + targets = torch.tensor(targets, dtype=torch.long) + + neural_probs = neural_probs.to(device) + hidden_states = hidden_states.to(device) + targets = targets.to(device) + + # Clamp and normalize neural probs + neural_probs = neural_probs.clamp(min=1e-10) + 
neural_probs = neural_probs / neural_probs.sum(dim=1, keepdim=True) + + # Store for growing datastore + stored_h = torch.zeros(N, dim, device=device, dtype=torch.float32) + stored_tok = torch.zeros(N, device=device, dtype=torch.long) + n_stored = 0 + + total_bits = 0.0 + scored = 0 + + for chunk_start in range(0, N, chunk_size): + chunk_end = min(chunk_start + chunk_size, N) + chunk_len = chunk_end - chunk_start + + # This chunk's data + q = hidden_states[chunk_start:chunk_end] # (C, dim) + t = targets[chunk_start:chunk_end] # (C,) + np_ = neural_probs[chunk_start:chunk_end] # (C, V) + + if n_stored >= K: + # === BATCH KNN === + # Compute squared L2 distances: (C, n_stored) + # Using cdist for efficiency + dists = torch.cdist(q, stored_h[:n_stored], p=2).pow(2) # (C, n_stored) + + # Top-K nearest for each query + topk_dists, topk_local_idx = dists.topk(K, dim=1, largest=False) # (C, K) + + # Get tokens of nearest neighbors + topk_toks = stored_tok[:n_stored][topk_local_idx] # (C, K) + + # Softmax weights over distances + weights = torch.exp(-topk_dists / dim) # (C, K) + weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-30) + + # Build KNN distribution via scatter + knn_dist = torch.zeros(chunk_len, vocab_size, device=device) + knn_dist.scatter_add_(1, topk_toks, weights) + + # Smooth + knn_dist = 0.99 * knn_dist + 0.01 / vocab_size + knn_dist = knn_dist / knn_dist.sum(dim=1, keepdim=True) + + # Mix: (1-lam)*neural + lam*knn + mixed = (1.0 - lam) * np_ + lam * knn_dist + else: + mixed = np_ + + # Normalize + mixed = mixed / mixed.sum(dim=1, keepdim=True) + + # Score: gather target probabilities + target_probs = mixed.gather(1, t.unsqueeze(1)).squeeze(1) # (C,) + bits = -torch.log2(target_probs.clamp(min=1e-30)) + total_bits += bits.sum().item() + scored += chunk_len + + # Store AFTER scoring (causal) + stored_h[n_stored:n_stored + chunk_len] = q + stored_tok[n_stored:n_stored + chunk_len] = t + n_stored += chunk_len + + bpc = total_bits / scored + return 
bpc, total_bits, scored + + +# ============================================================ +# Original per-token KNN (for correctness comparison) +# ============================================================ +def score_knn_pertokern(neural_probs_np, hidden_np, targets_np, + K=8, lam=0.12, vocab_size=VOCAB_SIZE): + """Original slow per-token KNN (reference implementation).""" + N = len(targets_np) + dim = hidden_np.shape[1] + stored_h = np.zeros((N, dim), np.float32) + stored_tok = np.zeros(N, np.int32) + ns = 0 + + total_bits = 0.0 + scored = 0 + + for i in range(N): + tgt = int(targets_np[i]) + np_ = neural_probs_np[i].astype(np.float64) + np_ = np.clip(np_, 1e-10, None); np_ /= np_.sum() + + if ns > K: + diff = stored_h[:ns] - hidden_np[i] + dists = np.einsum('ij,ij->i', diff, diff) + ak = min(K, ns - 1) + ki = np.argpartition(dists, ak)[:ak] + kd = dists[ki] + w = np.exp(-kd / dim); w /= w.sum() + 1e-30 + kp = np.zeros(vocab_size, np.float64) + for j in range(K): kp[stored_tok[ki[j]]] += w[j] + kp = 0.99 * kp + 0.01 / vocab_size + mx = (1-lam) * np_ + lam * kp; mx /= mx.sum() + p = max(mx[tgt], 1e-30) + else: + p = max(np_[tgt], 1e-30) + + total_bits += -math.log2(p); scored += 1 + stored_h[ns] = hidden_np[i]; stored_tok[ns] = tgt; ns += 1 + + return total_bits / scored, total_bits, scored + + +# ============================================================ +# Test: Compare vectorized vs per-token +# ============================================================ +if __name__ == "__main__": + # Load cached model + MODEL_CACHE = "/Users/himanshudongre/Documents/GitHub/parameter_golf/cached_rope16_model.pt" + + if not os.path.exists(MODEL_CACHE): + print("No cached model — run exp_three_way_stack.py first") + sys.exit(1) + + # Minimal model for loading + class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + self.eps = eps + def forward(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, 
keepdim=True) + self.eps) * self.scale + + class GEGLU_MLP(nn.Module): + def __init__(self, dim, expansion=2.0): + super().__init__() + h = int(dim * expansion) + self.gate = nn.Linear(dim, h, bias=False) + self.up = nn.Linear(dim, h, bias=False) + self.down = nn.Linear(h, dim, bias=False) + def forward(self, x): + return self.down(F.gelu(self.gate(x)) * self.up(x)) + + class FullMHA(nn.Module): + def __init__(self, dim, n_heads, rope_dims=16): + super().__init__() + self.n_heads = n_heads; self.head_dim = dim // n_heads + self.qkv = nn.Linear(dim, 3*dim, bias=False) + self.out = nn.Linear(dim, dim, bias=False) + self.rope_dims = rope_dims + freqs = 1.0 / (10000.0 ** (torch.arange(0, rope_dims, 2).float() / rope_dims)) + t = torch.arange(SEQ_LEN).float() + freqs = torch.outer(t, freqs) + self.register_buffer('cos_cache', freqs.cos().unsqueeze(0).unsqueeze(0), persistent=False) + self.register_buffer('sin_cache', freqs.sin().unsqueeze(0).unsqueeze(0), persistent=False) + def _apply_rope(self, x): + rd = self.rope_dims + x_rope, x_pass = x[..., :rd], x[..., rd:] + x1, x2 = x_rope[..., :rd//2], x_rope[..., rd//2:] + cos = self.cos_cache[:, :, :x.size(2), :] + sin = self.sin_cache[:, :, :x.size(2), :] + out = torch.cat([x1*cos - x2*sin, x2*cos + x1*sin], dim=-1) + return torch.cat([out, x_pass], dim=-1) + def forward(self, x): + B, T, C = x.shape + qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim) + q, k, v = qkv.unbind(2) + q, k, v = q.transpose(1,2), k.transpose(1,2), v.transpose(1,2) + q, k = self._apply_rope(q), self._apply_rope(k) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.out(y.transpose(1,2).reshape(B, T, C)) + + class Block(nn.Module): + def __init__(self, dim, n_heads, expansion=2.0): + super().__init__() + self.ln1 = RMSNorm(dim); self.attn = FullMHA(dim, n_heads) + self.ln2 = RMSNorm(dim); self.mlp = GEGLU_MLP(dim, expansion) + def forward(self, x): + x = x + self.attn(self.ln1(x)); x = x + 
self.mlp(self.ln2(x)); return x + + class Transformer(nn.Module): + def __init__(self): + super().__init__() + self.tok_emb = nn.Embedding(VOCAB_SIZE, DIM) + self.blocks = nn.ModuleList([Block(DIM, 6, 2.0) for _ in range(6)]) + self.ln_f = RMSNorm(DIM) + for m in self.modules(): + if isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=0.02) + def forward_with_hidden(self, idx): + x = self.tok_emb(idx) + for block in self.blocks: x = block(x) + h = self.ln_f(x) + return F.linear(h, self.tok_emb.weight), h + + # Load model + print("Loading cached model...", flush=True) + model = Transformer().to(DEVICE) + model.load_state_dict(torch.load(MODEL_CACHE, map_location=DEVICE, weights_only=True)) + model.eval() + + # Load data + import urllib.request + cache_path = "/Users/himanshudongre/Documents/GitHub/parameter_golf/text_corpus.txt" + with open(cache_path, 'r', encoding='utf-8', errors='ignore') as f: + text = f.read() + tokens = [b % VOCAB_SIZE for b in text.encode('utf-8')] + n_seq = len(tokens) // (SEQ_LEN + 1) + sequences = torch.tensor(tokens[:n_seq * (SEQ_LEN + 1)], dtype=torch.long).view(n_seq, SEQ_LEN + 1) + n_train = int(n_seq * 0.9) + eval_seq = sequences[n_train:] + + N_EVAL = min(100, len(eval_seq)) + print(f"Eval: {N_EVAL} sequences, {N_EVAL * SEQ_LEN:,} tokens") + + # Get probs + hidden states + print("Computing probs + hidden...", flush=True) + with torch.no_grad(): + eb = eval_seq[:N_EVAL].to(DEVICE) + logits, hidden = model.forward_with_hidden(eb[:, :-1]) + probs = F.softmax(logits, dim=-1) # (N, T, V) + # Keep on device for vectorized version + + targets = eval_seq[:N_EVAL, 1:].contiguous() # (N, T) + + # Flatten for scoring: (N*T, V), (N*T, dim), (N*T,) + N_tokens = N_EVAL * SEQ_LEN + probs_flat = probs.reshape(N_tokens, VOCAB_SIZE) + hidden_flat = hidden.reshape(N_tokens, DIM) + targets_flat = targets.reshape(N_tokens) + + # Also numpy versions for reference + probs_np = 
probs_flat.cpu().numpy() + hidden_np = hidden_flat.cpu().numpy() + targets_np = targets_flat.cpu().numpy() + + # ========================================== + # Test 1: Neural only + # ========================================== + print("\n" + "=" * 60) + print("Neural only") + t0 = time.time() + tp = probs_flat.gather(1, targets_flat.to(DEVICE).unsqueeze(1)).squeeze(1) + neural_bpc = (-torch.log2(tp.clamp(min=1e-30))).mean().item() + print(f" BPC: {neural_bpc:.4f} ({time.time()-t0:.2f}s)") + + # ========================================== + # Test 2: Vectorized KNN + # ========================================== + print("\n" + "=" * 60) + print("Vectorized KNN (chunk_size=1024)") + t0 = time.time() + vec_bpc, vec_bits, vec_scored = score_knn_vectorized( + probs_flat, hidden_flat, targets_flat, + K=8, lam=0.12, chunk_size=1024 + ) + vec_time = time.time() - t0 + vec_imp = (vec_bpc - neural_bpc) / neural_bpc * 100 + print(f" BPC: {vec_bpc:.4f} ({vec_imp:+.2f}%) — {vec_time:.1f}s") + + # ========================================== + # Test 3: Per-token KNN (reference, slow) + # ========================================== + print("\n" + "=" * 60) + print("Per-token KNN (reference, slow)") + t0 = time.time() + ref_bpc, ref_bits, ref_scored = score_knn_pertokern( + probs_np, hidden_np, targets_np, K=8, lam=0.12 + ) + ref_time = time.time() - t0 + ref_imp = (ref_bpc - neural_bpc) / neural_bpc * 100 + print(f" BPC: {ref_bpc:.4f} ({ref_imp:+.2f}%) — {ref_time:.1f}s") + + # ========================================== + # Test 4: Different chunk sizes + # ========================================== + print("\n" + "=" * 60) + print("Chunk size sweep") + for cs in [256, 512, 1024, 2048, 4096]: + t0 = time.time() + bpc, _, _ = score_knn_vectorized( + probs_flat, hidden_flat, targets_flat, + K=8, lam=0.12, chunk_size=cs + ) + elapsed = time.time() - t0 + imp = (bpc - neural_bpc) / neural_bpc * 100 + print(f" chunk={cs:5d}: BPC={bpc:.4f} ({imp:+.2f}%) — {elapsed:.1f}s") + + # 
========================================== + # Summary + # ========================================== + print("\n" + "=" * 60) + print("SUMMARY") + print(f" Neural: {neural_bpc:.4f}") + print(f" Vectorized KNN: {vec_bpc:.4f} ({vec_imp:+.2f}%) — {vec_time:.1f}s") + print(f" Per-token KNN: {ref_bpc:.4f} ({ref_imp:+.2f}%) — {ref_time:.1f}s") + print(f" Speedup: {ref_time/vec_time:.0f}×") + + # Correctness check + diff = abs(vec_bpc - ref_bpc) + print(f"\n BPC difference (vec vs ref): {diff:.4f}") + if diff < 0.05: + print(f" CORRECTNESS: PASS (diff < 0.05)") + else: + print(f" CORRECTNESS: WARN — vectorized differs from per-token") + print(f" (Expected: vectorized uses chunk-level causality, per-token uses token-level)") + + # Competition estimate + print(f"\n Competition estimate (62M tokens on 8×H100):") + tokens_per_sec = N_tokens / vec_time + comp_time = 62_000_000 / tokens_per_sec / 8 # 8 GPUs + print(f" Tokens/sec: {tokens_per_sec:.0f}") + print(f" Estimated: {comp_time:.0f}s (limit=600s)") + print(f" Fits: {'YES' if comp_time < 600 else 'NO — need larger chunk or GPU optimization'}") diff --git a/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/submission.json b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/submission.json new file mode 100644 index 0000000000..4b6adc5802 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/submission.json @@ -0,0 +1,9 @@ +{ + "track": "non_record_16mb", + "date": "2026-04-02", + "name": "KNN Hidden State Retrieval — Scale Deception from Weak to Strong Models", + "author": "Himanshu Dongre", + "github_id": "himanshudongre", + "val_bpb": null, + "notes": "Novel eval-time technique tested across 4 model quality levels. Helps weak models (-2 to -4%), HURTS strong competition-quality models (+1.5%). Definitive scale deception finding validated on 8xH100." 
+} \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/train_seed42.log b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/train_seed42.log new file mode 100644 index 0000000000..2fbc9d7002 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-02_KNN_Scale_Deception_8xH100/train_seed42.log @@ -0,0 +1,87 @@ +=== SEED 42 FINAL RUN === +Started: Thu Apr 2 15:22:48 UTC 2026 +W0402 15:22:50.003000 147 torch/distributed/run.py:803] +W0402 15:22:50.003000 147 torch/distributed/run.py:803] ***************************************** +W0402 15:22:50.003000 147 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0402 15:22:50.003000 147 torch/distributed/run.py:803] ***************************************** +logs/013213b0-8bde-41cc-92da-66a76d8bb4e4.txt +model_preset:merged_leader run_profile:full_8gpu_600s +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26993756 +param_breakdown:{"lexical": 852481, "skip": 2560, "upper_global": 25974872, "value_embedding": 163843} +world_size:8 grad_accum_steps:1 +flash_attn_3_loaded:True +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +activation_mode:leaky_relu2 export_quantizer:rowclip_int6 ttt_optimizer:sgd +muon:banking_enabled:False bank_min_tensors:2 +moonshot lower_replace_layers:0 local_shared_blocks:4 use_unet_skips:True +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 
+warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9318 train_time:168ms step_avg:167.70ms +step:2/20000 train_loss:8.6373 train_time:278ms step_avg:139.04ms +step:3/20000 train_loss:7.8688 train_time:389ms step_avg:129.55ms +step:4/20000 train_loss:7.2352 train_time:499ms step_avg:124.72ms +step:5/20000 train_loss:7.0081 train_time:609ms step_avg:121.89ms +step:6/20000 train_loss:6.8995 train_time:720ms step_avg:120.04ms +step:7/20000 train_loss:6.7618 train_time:831ms step_avg:118.66ms +step:8/20000 train_loss:6.6891 train_time:941ms step_avg:117.58ms +step:9/20000 train_loss:6.4117 train_time:1051ms step_avg:116.77ms +step:10/20000 train_loss:6.0709 train_time:1162ms step_avg:116.17ms +step:500/20000 train_loss:2.3806 train_time:52899ms step_avg:105.80ms +step:1000/20000 train_loss:2.2605 train_time:105769ms step_avg:105.77ms +step:1500/20000 train_loss:2.2076 train_time:158707ms step_avg:105.80ms +step:2000/20000 train_loss:2.0490 train_time:211582ms step_avg:105.79ms +step:2500/20000 train_loss:2.1516 train_time:264487ms step_avg:105.79ms +step:3000/20000 train_loss:2.1320 train_time:317382ms step_avg:105.79ms +step:3500/20000 train_loss:2.1375 train_time:370358ms step_avg:105.82ms +step:4000/20000 train_loss:1.9271 train_time:423266ms step_avg:105.82ms +step:4000/20000 val_loss:2.0157 val_bpb:1.1938 train_time:423274ms step_avg:105.82ms +step:4500/20000 train_loss:2.0687 train_time:476138ms step_avg:105.81ms +swa:start step:5000 +step:5000/20000 train_loss:2.0446 train_time:529043ms step_avg:105.81ms +step:5500/20000 train_loss:1.9545 train_time:582413ms step_avg:105.89ms +step:5665/20000 val_loss:1.9329 val_bpb:1.1448 train_time:600062ms 
step_avg:105.92ms +stopping_early: wallclock_cap train_time:600062ms step:5665/20000 +peak memory allocated: 20758 MiB reserved: 20852 MiB +ema:applying best EMA (decay=0.9970 bpb=inf) +Saved post-train checkpoint: /workspace/seed42_model.pt (106177672 bytes) +DIAGNOSTIC post_average val_loss:1.9325 val_bpb:1.1446 eval_time:1980ms +export_grid block:128 refine:3 damp:0.0100 mse:0.00582124 +export_grid block:128 refine:5 damp:0.0100 mse:0.00582124 +export_grid block:128 refine:3 damp:0.0050 mse:0.00582124 +export_grid block:128 refine:5 damp:0.0050 mse:0.00582124 +gptq_quantize: 0 GPTQ layers, 0 naive layers +mixed_precision: 0 int5 params, 25952256 int6 params +Serialized model research_export: 15826144 bytes +Code size: 192279 bytes +Total submission size research_export: 16018423 bytes +final_research_export_roundtrip val_loss:1.9473 val_bpb:1.1533 eval_time:8264ms +final_research_export_sliding skipped +final_research_export_exact val_loss:1.94730372 val_bpb:1.15330295 +final_knn val_loss:1.9758 val_bpb:1.1702 eval_time:168181ms k:8 lam:0.12 +final_knn_exact val_loss:1.97579965 val_bpb:1.17017984 +phase_timings:{"diagnostic_eval_ms": 1980.0734035670757, "quantize_ms": 4838.551789522171, "roundtrip_eval_ms": 43036.73002496362, "serialize_ms": 32567.488629370928, "skipped": {"diagnostic_eval": false, "export": false, "roundtrip_eval": false, "sliding_eval": true}, "sliding_eval_ms": 0.0} +=== SEED 42 DONE: Thu Apr 2 15:38:02 UTC 2026 ===
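
For readers skimming the patch, the core KNN mixing step implemented in knn_vectorized.py can be exercised standalone on toy tensors. This is a minimal CPU sketch, not the patch's exact code: the tensor sizes are illustrative, and the softmax-over-negative-distances neighbor weighting is an assumption (the patch defines `weights` in the full script, which is not reproduced here).

```python
# Minimal sketch of KNN hidden-state retrieval + mixing (toy sizes).
import torch

torch.manual_seed(0)
C, N, D, V, K = 4, 32, 8, 16, 3   # chunk, stored states, hidden dim, vocab, neighbors
lam = 0.12                         # mixing weight, as in the patch (P = 0.88*neural + 0.12*knn)

queries = torch.randn(C, D)                       # current hidden states
stored_h = torch.randn(N, D)                      # datastore hidden states
stored_tok = torch.randint(0, V, (N,))            # successor token for each stored state
neural_probs = torch.softmax(torch.randn(C, V), dim=1)

# Squared L2 distance from every query to every stored state.
dists = torch.cdist(queries, stored_h, p=2).pow(2)      # (C, N)
topk_d, topk_i = dists.topk(K, dim=1, largest=False)    # (C, K)

# ASSUMED weighting: softmax over negative distances (closer neighbor -> more mass).
weights = torch.softmax(-topk_d, dim=1)                 # (C, K), rows sum to 1

# Empirical successor distribution from neighbors' stored next-tokens.
knn_dist = torch.zeros(C, V)
knn_dist.scatter_add_(1, stored_tok[topk_i], weights)   # (C, V)

# Mix with the neural prediction; rows remain valid distributions.
mixed = (1 - lam) * neural_probs + lam * knn_dist
```

Because both `neural_probs` and `knn_dist` are row-normalized, the convex combination stays a proper distribution, so it can be scored with the same BPC/BPB machinery as the neural-only baseline.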