
Commit ca1a938

abaybektursun and claude committed
Non-record: Training data ordering by model perplexity (-0.0033 BPB)
Case study: reordering training shards by model difficulty (hardest first) gives a
-0.0033 BPB improvement over sequential ordering. Zero architecture changes, zero
compute cost, ten lines of code.

Key finding: token-level statistics (KL divergence) find a 0.0009 range across
shards. Model perplexity finds a 0.0475 range -- 100x more variation. The two
metrics are uncorrelated (r = -0.056).

3-seed validated on PR #549 (merged #1):
  Seed 1337: 1.1217 -> 1.1183 (-0.0034)
  Seed 42:   1.1222 -> 1.1181 (-0.0041)
  Seed 2025: 1.1221 -> 1.1198 (-0.0023)
  Mean:      1.1220 -> 1.1187 (-0.0033)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 630bb5e commit ca1a938

File tree

13 files changed: +3033 -0 lines

Lines changed: 100 additions & 0 deletions
# You're Training on 57% of Your Data. Does It Matter Which 57%?

**Non-record submission: a case study on training data ordering.**

Everyone in this competition trains sequentially through the shards. Nobody questions the order. We did — and found a free **-0.0033 BPB** improvement by reordering shards based on model perplexity. Zero architecture changes, zero hyperparameter tuning, ten lines of code.
## The Setup

At 83 ms/step we get ~7,200 steps in 600 s. That's 5.66B tokens out of 10B — **57% of the dataset**. Shards 0 through ~45 get seen; shards 46-79 never do.
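The budget arithmetic is easy to check directly. A quick sketch using the figures from the text (83 ms/step, 600 s budget, 10B-token dataset):

```python
# Back-of-envelope check of the training budget described above.
step_time_s = 0.083     # ~83 ms per step (from the text)
budget_s = 600          # total wall-clock budget
steps = int(budget_s / step_time_s)      # ~7,200 steps

tokens_seen = 5.66e9    # tokens consumed within the budget (from the text)
dataset_tokens = 10e9   # full dataset
fraction = tokens_seen / dataset_tokens  # fraction of data ever seen

print(f"{steps} steps, {fraction:.0%} of the dataset")
```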
## Token Statistics Say It Doesn't Matter

We computed KL(val || shard) for all 80 training shards. Every shard has essentially the same token distribution. Range: 0.0009, which translates to ~0.00005 BPB. Dead end.
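The KL(val || shard) statistic needs nothing beyond unigram histograms. A minimal sketch of that computation (the helpers, smoothing constant, and synthetic shards here are illustrative, not the submission's exact code):

```python
import numpy as np

def token_dist(tokens, vocab=1024, eps=1e-9):
    """Smoothed unigram distribution over the token vocabulary."""
    counts = np.bincount(tokens, minlength=vocab).astype(np.float64) + eps
    return counts / counts.sum()

def kl(p, q):
    """KL(p || q) in nats."""
    return float(np.sum(p * np.log(p / q)))

# Toy check on synthetic shards: KL is zero against itself, positive otherwise.
rng = np.random.default_rng(0)
val = token_dist(rng.integers(0, 1024, size=200_000))
shard = token_dist(rng.integers(0, 512, size=200_000))  # skewed distribution
print(f"KL(val || val)   = {kl(val, val):.6f}")
print(f"KL(val || shard) = {kl(val, shard):.6f}")
```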
![Token frequencies vs model perplexity](figures/fig1_kl_vs_perplexity.png)

Left panel: all shards look identical by token frequency.
## The Model Disagrees

We trained a model for 500 steps on one shard, then scored all 80 shards by cross-entropy loss (right panel above). Range: **0.0475** — 100× larger than the KL signal. The shards are NOT all the same; token statistics just can't see the difference.

The direction is expected — KL counts tokens, the model scores sequences — but the magnitude of the gap is what matters: the model finds **100× more variation** than token statistics do.
![KL rank vs perplexity rank](figures/fig4_the_insight.png)

**r = -0.056.** The two metrics are uncorrelated: the shard most similar to val by token frequency is middling by model difficulty.
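The r = -0.056 figure is a plain Pearson correlation between the two per-shard score vectors. A sketch (the score arrays below are stand-ins, not the real measurements):

```python
import numpy as np

def pearson_r(a, b):
    """Pearson correlation coefficient between two score vectors."""
    return float(np.corrcoef(a, b)[0, 1])

# Stand-in scores for 80 shards: KL-to-val vs. model cross-entropy.
rng = np.random.default_rng(1337)
kl_scores = rng.normal(size=80)
ppl_scores = rng.normal(size=80)  # drawn independently, so r should be near 0
print(f"r = {pearson_r(kl_scores, ppl_scores):.3f}")
```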
## Where the Hard Shards Are

The difficulty is about content, not position: hard and easy shards are scattered randomly across the dataset.

![Shard difficulty heatmap](figures/fig2_shard_heatmap.png)

Shard 44 (hardest, rank #1) sits next to shard 43 (rank #48). Sequential ordering — the default in every training framework — isn't optimized for anything except simplicity.

![Sequential vs optimal ordering](figures/fig3_sequential_vs_optimal.png)
## Results: 3-Seed Validated

We reran our merged #1 submission (PR #549) with shards reordered hardest-first. Same code, same hyperparameters, same compute budget.

| Seed | Sequential (PR #549) | Hardest-first | Delta |
|------|---------------------|--------------|-------|
| 1337 | 1.1217 | **1.1183** | **-0.0034** |
| 42 | 1.1222 | **1.1181** | **-0.0041** |
| 2025 | 1.1221 | **1.1198** | **-0.0023** |
| **Mean** | **1.1220** | **1.1187** | **-0.0033** |
Every seed improves. Mean improvement: **-0.0033 BPB.**

![Three-seed comparison](figures/fig5_three_seed_comparison.png)

![Improvement consistency and cost](figures/fig6_delta_consistency.png)

For context: our last three PRs each took days of architecture and quantization work to gain 0.001-0.003 BPB. This took ten lines of code.
## The Change

```python
class TokenStream:
    def __init__(self, pattern: str):
        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
        # NEW: reorder shards by model difficulty (hardest first)
        shard_order = os.environ.get("SHARD_ORDER", "")
        if shard_order:
            order = [int(x) for x in shard_order.split(",")]
            reordered = [self.files[i] for i in order if i < len(self.files)]
            remaining = [f for i, f in enumerate(self.files) if i not in set(order)]
            self.files = reordered + remaining
```

```bash
SHARD_ORDER=44,63,65,42,18,67,30,69,61,3,13,19,50,49,56,45,73,79,57,32,\
28,68,66,34,46,38,17,77,0,14,26,74,59,62,41,9,58,22,78,4,48,8,12,27,75,\
36,16,43,52,15,33,47,25,55,54,23,37,51,31,21,60,1,20,72,24,53,39,35,71,\
76,40,5,10,2,7,6,70,11,64,29
```
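The reordering logic is easy to sanity-check in isolation. A self-contained sketch (`reorder` is a hypothetical standalone helper mirroring the `TokenStream` change above):

```python
def reorder(files, shard_order: str):
    """Mirror of the TokenStream change: explicit order first, leftovers after."""
    if not shard_order:
        return files
    order = [int(x) for x in shard_order.split(",")]
    # Indices past the end of the file list are silently skipped.
    reordered = [files[i] for i in order if i < len(files)]
    remaining = [f for i, f in enumerate(files) if i not in set(order)]
    return reordered + remaining

print(reorder(["s0", "s1", "s2", "s3"], "2,0"))  # ['s2', 's0', 's1', 's3']
```

Note that shards absent from `SHARD_ORDER` are appended in their original order, so a partial ordering still consumes the whole dataset if the budget allows.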
## Method: How We Ranked Shards

1. Train a 6-layer, 512d model for 500 steps on shard 0 (single GPU, ~40 seconds)
2. Score all 80 shards by cross-entropy loss with this partially-trained model
3. Sort shards by loss, descending (hardest first)
4. Pass the ordering via the `SHARD_ORDER` environment variable

The ranking model is deliberately small and undertrained — it captures which shards have patterns the model hasn't learned yet. A fully-trained model would rank differently (everything is "easy" by then).
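Steps 3-4 reduce to a couple of lines once per-shard losses are in hand. A sketch (the `losses` values here are illustrative, not the measured ones):

```python
# Map shard index -> remaining cross-entropy loss after the 500-step probe.
losses = {0: 1.42, 1: 1.55, 2: 1.48, 3: 1.60}  # illustrative values

# Hardest first: sort shard indices by loss, descending,
# then serialize for the SHARD_ORDER environment variable.
order = sorted(losses, key=losses.get, reverse=True)
shard_order = ",".join(map(str, order))
print(shard_order)  # 3,1,2,0
```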
## Open Questions

- **Adaptive curriculum**: re-rank shards every 1,000 steps as the model learns. The optimal ordering probably changes during training.
- **Anti-curriculum**: we haven't tested easiest-first. It might help build foundations before tackling hard patterns. *(Experiment running.)*
- **Transfer across architectures**: our ranking was done with a 6-layer model. Does it transfer to 11 layers? To different hyperparameters?
- **Interaction with SWA/EMA**: does the ordering effect survive weight averaging?
## Credits

- Original idea (train on similar data): Lucas Fievet
- Base model: [PR #549](https://github.com/openai/parameter-golf/pull/549) by @abaybektursun (merged #1)
- Analysis and implementation: @abaybektursun
Lines changed: 198 additions & 0 deletions
"""
Score each training shard by model perplexity using a simple approach.

1. Score all shards with a random model (inherent difficulty baseline)
2. Train 500 steps on shard 0
3. Score all shards again (what's still hard after partial training)
4. Rank by remaining loss

Usage (single GPU is fine):
    python3 analysis/score_shards_simple.py --data-dir ./data/datasets/fineweb10B_sp1024
"""

import argparse
import glob

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path


class MiniGPT(nn.Module):
    """Minimal GPT for shard scoring. Same architecture shape as the competition model."""

    def __init__(self, vocab=1024, dim=512, layers=6, heads=8):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab, dim)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim * 3,
                                       batch_first=True, norm_first=True, dropout=0.0)
            for _ in range(layers)
        ])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab, bias=False)
        self.head.weight = self.tok_emb.weight  # tie embeddings

    def forward(self, x):
        B, T = x.shape
        h = self.tok_emb(x)
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        for block in self.blocks:
            h = block(h, src_mask=mask, is_causal=True)
        return self.head(self.norm(h))


def load_shard(path, vocab_size=1024):
    tokens = np.fromfile(path, dtype=np.uint16).astype(np.int64)
    tokens = np.clip(tokens, 0, vocab_size - 1)
    return torch.from_numpy(tokens)


def score_shard(model, tokens, device, seq_len=1024, max_batches=50, batch_size=16):
    model.eval()
    n_seqs = len(tokens) // (seq_len + 1)
    if n_seqs == 0:
        return float('inf')

    # Stride through the shard so the sampled batches cover it evenly.
    step = max(1, n_seqs // (max_batches * batch_size))
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        for bi in range(0, min(n_seqs, max_batches * batch_size), batch_size):
            batch_starts = [((bi + b) * step) * (seq_len + 1) for b in range(batch_size)
                            if (bi + b) * step < n_seqs]
            if not batch_starts:
                break
            x = torch.stack([tokens[s:s + seq_len] for s in batch_starts]).to(device)
            y = torch.stack([tokens[s + 1:s + seq_len + 1] for s in batch_starts]).to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(),
                                   y.reshape(-1), reduction="sum")
            total_loss += loss.item()
            total_tokens += y.numel()

    return total_loss / max(total_tokens, 1)


def train_steps(model, tokens, device, steps=500, seq_len=1024, batch_size=16, lr=0.001):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    n_seqs = len(tokens) // (seq_len + 1)

    for step in range(steps):
        idx = (step * batch_size) % n_seqs
        batch_starts = [(idx + b) * (seq_len + 1) for b in range(batch_size) if idx + b < n_seqs]
        if not batch_starts:
            continue
        x = torch.stack([tokens[s:s + seq_len] for s in batch_starts]).to(device)
        y = torch.stack([tokens[s + 1:s + seq_len + 1] for s in batch_starts]).to(device)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), y.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % 100 == 0:
            print(f"  Train step {step + 1}/{steps}, loss: {loss.item():.4f}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", required=True)
    parser.add_argument("--train-steps", type=int, default=500)
    parser.add_argument("--device", default="cuda:0")
    args = parser.parse_args()

    device = torch.device(args.device)
    train_files = sorted(glob.glob(str(Path(args.data_dir) / "fineweb_train_*.bin")))
    val_files = sorted(glob.glob(str(Path(args.data_dir) / "fineweb_val_*.bin")))
    print(f"Found {len(train_files)} train shards, {len(val_files)} val shards")

    model = MiniGPT(vocab=1024, dim=512, layers=6, heads=8).to(device)
    print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

    # Phase 1: score with the random (untrained) model
    print(f"\n{'=' * 60}")
    print("PHASE 1: Random model scoring")
    print(f"{'=' * 60}")
    random_scores = {}
    for i, f in enumerate(train_files):
        tokens = load_shard(f)
        loss = score_shard(model, tokens, device)
        random_scores[i] = loss
        if (i + 1) % 10 == 0 or i == len(train_files) - 1:
            print(f"  [{i + 1}/{len(train_files)}] shard {i}: loss={loss:.4f}")

    val_tokens = torch.cat([load_shard(f) for f in val_files])
    val_random = score_shard(model, val_tokens, device)
    print(f"  Val loss (random): {val_random:.4f}")

    # Phase 2: train on shard 0
    print(f"\n{'=' * 60}")
    print(f"PHASE 2: Training {args.train_steps} steps on shard 0")
    print(f"{'=' * 60}")
    train_tokens = load_shard(train_files[0])
    train_steps(model, train_tokens, device, steps=args.train_steps)

    # Phase 3: score with the partially trained model
    print(f"\n{'=' * 60}")
    print("PHASE 3: Trained model scoring")
    print(f"{'=' * 60}")
    trained_scores = {}
    for i, f in enumerate(train_files):
        tokens = load_shard(f)
        loss = score_shard(model, tokens, device)
        trained_scores[i] = loss
        if (i + 1) % 10 == 0 or i == len(train_files) - 1:
            print(f"  [{i + 1}/{len(train_files)}] shard {i}: loss={loss:.4f}")

    val_trained = score_shard(model, val_tokens, device)
    print(f"  Val loss (trained): {val_trained:.4f}")

    # Results
    print(f"\n{'=' * 60}")
    print("RESULTS: Shard ranking by remaining loss (highest = most to learn)")
    print(f"{'=' * 60}")
    print(f"{'Rank':>4} {'Shard':>6} {'Random':>10} {'Trained':>10} {'Remaining':>10} {'Learned':>10}")
    print("-" * 60)

    # "Remaining" is the loss under the partially trained model; "Learned" is
    # how much the 500-step probe already shaved off relative to random.
    shards = [(i, random_scores[i], trained_scores[i],
               trained_scores[i], random_scores[i] - trained_scores[i])
              for i in range(len(train_files))]
    shards.sort(key=lambda x: -x[3])  # sort by remaining loss, descending

    for rank, (idx, rand, trained, remaining, learned) in enumerate(shards):
        print(f"{rank + 1:>4} {idx:>6} {rand:>10.4f} {trained:>10.4f} {remaining:>10.4f} {learned:>10.4f}")

    # Key metrics
    losses = [s[3] for s in shards]
    loss_range = max(losses) - min(losses)
    loss_std = np.std(losses)

    print(f"\n{'=' * 60}")
    print("SUMMARY")
    print(f"{'=' * 60}")
    print(f"Remaining loss range: {min(losses):.4f} to {max(losses):.4f} (range: {loss_range:.4f})")
    print(f"Remaining loss std: {loss_std:.4f}")
    print(f"Val loss: {val_trained:.4f}")
    print()
    if loss_range < 0.01:
        print("VERDICT: Small range — shards are similar even at model-perplexity level.")
        print("Shard reordering is unlikely to help significantly.")
    else:
        print(f"VERDICT: Range of {loss_range:.4f} — meaningful variation!")
        print("Shard reordering could improve BPB.")

    recommended = [s[0] for s in shards]  # already sorted by remaining loss, descending
    print(f"\nRecommended order (hardest first): {recommended[:20]}...")
    print(f"Skip (easiest): {recommended[-10:]}")


if __name__ == "__main__":
    main()
Lines changed: 9 additions & 0 deletions
{
  "name": "Case Study: Training Data Ordering by Model Perplexity",
  "val_bpb": 1.1187,
  "bytes_total": 15921633,
  "blurb": "Non-record case study: reordering training shards by model perplexity (hardest first) gives -0.0033 BPB for free. Token-level statistics (KL divergence) miss 100x of the variation that model-level scoring reveals. 3-seed validated. Ten lines of code, zero compute cost.",
  "author": "abaybektursun",
  "github_id": "abaybektursun",
  "date": "2026-03-24"
}
