# PR703 + Shard-Order Curriculum + GPTQ Cache-Backout

This is a `non-record-16mb` submission built from the PR703-style quant/cache-backout branch and improved with a score-ranked shard curriculum.

This run is being submitted as **non-record** for two reasons:

- it is a single-seed result, not a statistically significant multi-seed record package
- the improvement over the currently accepted leaderboard entry is well below the `0.005`-nat record threshold described in the root `README.md`

## Result

- `final_int6_sliding_window_exact`: `1.11709895`
- `final_int6_roundtrip_exact`: `1.14068680`
- `post_ema`: `1.1368`
- `step_stop`: `6918`
- `step_avg`: `86.75ms`
- `total submission size`: `15,909,560` bytes
- `bytes under 16MB`: `90,440`
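
As a quick arithmetic check on the headroom figure (this sketch assumes the cap is the decimal 16,000,000 bytes that the reported numbers imply, not the binary 16 * 1024 * 1024):

```python
# Headroom under the artifact cap, assuming a decimal 16MB cap
# (16,000,000 bytes); a binary cap would give different headroom.
CAP_BYTES = 16_000_000
submission_bytes = 15_909_560  # total submission size from this run

headroom = CAP_BYTES - submission_bytes
print(f"bytes under 16MB: {headroom:,}")  # prints "bytes under 16MB: 90,440"
```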

## Core Change Relative to the Forked PR703 Base

The base PR703 carryover result was:

- final bpb: `1.11748714`
- total submission size: `15,963,300` bytes

This submission improves that branch mainly by:

1. `Shard-order curriculum`
Training shards are reordered by a lightweight scorer so the run sees harder shards earlier.

2. `Tighter final compression`
The final int6 payload uses a stronger `lzma` preset, keeping the same core model family while freeing more artifact headroom.

The winning model is still the same general PR703-style configuration:

- 11-layer trunk
- cache/backout path
- full-Hessian GPTQ over the banked-attn/MLP surface
- `BIGRAM_VOCAB_SIZE=1536`
- no TTT

## Reproduction

This submission depends on a generated `shard_order.json`; the run used the shard scorer included here as `score_shards.py` to produce it.

First, generate the shard order:

```bash
python score_shards.py \
  --data-dir ./data/datasets/fineweb10B_sp1024 \
  --device cuda:0 \
  --seq-len 1024 \
  --train-steps 500 \
  --max-batches 50 \
  --batch-size 16 \
  --output shard_order.json
```
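
`score_shards.py` writes a `summary.order` list (hardest shards first, i.e. highest post-warmup loss) plus a per-shard `ranking`. A small sanity check of that structure (`check_shard_order` is illustrative, not part of the submission):

```python
import json
from pathlib import Path


def check_shard_order(path: str = "shard_order.json") -> list[int]:
    """Validate the structure score_shards.py writes and return the order."""
    data = json.loads(Path(path).read_text())
    order = data["summary"]["order"]
    ranking = data["ranking"]
    # Every ranked shard appears exactly once in the order list.
    assert sorted(order) == sorted(e["shard_index"] for e in ranking)
    # Hardest-first: ranking is sorted by trained_loss descending.
    losses = [e["trained_loss"] for e in ranking]
    assert losses == sorted(losses, reverse=True)
    return order
```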

Then launch training:

```bash
SHARD_ORDER_FILE=./shard_order.json SEED=2025 \
NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=1536 XSA_LAST_N=4 \
EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \
ROPE_DIMS=16 LN_SCALE=1 VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \
MILE_GAMMA=1.1 MUON_QUANT_MOMENTUM=1 CACHE_LAYER=7 BACKOUT_LAMBDA_INIT=0.1 \
MUON_WD=0.04 ADAM_WD=0.04 MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500 \
WARMDOWN_ITERS=3500 ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \
GPTQ_CALIB_BATCHES=256 GPTQ_BLOCK_SIZE=128 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Included Files

- `train_gpt.py`: exact code snapshot used by the winning run
- `score_shards.py`: shard-order scorer used to generate the curriculum input
- `train.log`: exact controller log for the submitted run
- `submission.json`: leaderboard metadata
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import glob
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


HEADER_WORDS = 256
HEADER_MAGIC = 20240520
HEADER_VERSION = 1


def load_data_shard(path: Path, vocab_size: int) -> torch.Tensor:
    header = np.fromfile(path, dtype="<i4", count=HEADER_WORDS)
    if header.size != HEADER_WORDS or int(header[0]) != HEADER_MAGIC or int(header[1]) != HEADER_VERSION:
        raise ValueError(f"unexpected shard header: {path}")
    num_tokens = int(header[2])
    tokens = np.fromfile(
        path,
        dtype="<u2",
        count=num_tokens,
        offset=HEADER_WORDS * np.dtype("<i4").itemsize,
    )
    if tokens.size != num_tokens:
        raise ValueError(f"short read for {path}")
    return torch.from_numpy(np.clip(tokens.astype(np.int64, copy=False), 0, vocab_size - 1))


class MiniGPT(nn.Module):
    def __init__(self, vocab_size: int, model_dim: int, num_layers: int, num_heads: int, mlp_mult: float):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=model_dim,
                    nhead=num_heads,
                    dim_feedforward=int(model_dim * mlp_mult),
                    batch_first=True,
                    norm_first=True,
                    dropout=0.0,
                    activation="gelu",
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(model_dim)
        self.head = nn.Linear(model_dim, vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, seq_len = x.shape
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=x.device)
        hidden = self.tok_emb(x)
        for block in self.blocks:
            hidden = block(hidden, src_mask=mask, is_causal=True)
        return self.head(self.norm(hidden))


def score_shard(
    model: nn.Module,
    tokens: torch.Tensor,
    device: torch.device,
    seq_len: int,
    max_batches: int,
    batch_size: int,
) -> float:
    model.eval()
    num_sequences = len(tokens) // (seq_len + 1)
    if num_sequences == 0:
        return float("inf")
    step = max(1, num_sequences // max(max_batches * batch_size, 1))
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        for batch_index in range(0, min(num_sequences, max_batches * batch_size), batch_size):
            starts = [
                ((batch_index + offset) * step) * (seq_len + 1)
                for offset in range(batch_size)
                if (batch_index + offset) * step < num_sequences
            ]
            if not starts:
                break
            x = torch.stack([tokens[start : start + seq_len].to(device) for start in starts])
            y = torch.stack([tokens[start + 1 : start + seq_len + 1].to(device) for start in starts])
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), y.reshape(-1), reduction="sum")
            total_loss += float(loss.item())
            total_tokens += int(y.numel())
    return total_loss / max(total_tokens, 1)


def train_steps(
    model: nn.Module,
    tokens: torch.Tensor,
    device: torch.device,
    *,
    steps: int,
    seq_len: int,
    batch_size: int,
    learning_rate: float,
) -> None:
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    num_sequences = len(tokens) // (seq_len + 1)
    for step in range(steps):
        start_index = (step * batch_size) % max(num_sequences, 1)
        starts = [
            (start_index + offset) * (seq_len + 1)
            for offset in range(batch_size)
            if start_index + offset < num_sequences
        ]
        if not starts:
            continue
        x = torch.stack([tokens[start : start + seq_len].to(device) for start in starts])
        y = torch.stack([tokens[start + 1 : start + seq_len + 1].to(device) for start in starts])
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), y.reshape(-1))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if (step + 1) % 100 == 0:
            print(f"train_step:{step + 1}/{steps} loss:{loss.item():.4f}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Rank training shards by remaining loss after a short warmup.")
    parser.add_argument("--data-dir", default="./data/datasets/fineweb10B_sp1024")
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--vocab-size", type=int, default=1024)
    parser.add_argument("--model-dim", type=int, default=512)
    parser.add_argument("--layers", type=int, default=6)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--mlp-mult", type=float, default=3.0)
    parser.add_argument("--seq-len", type=int, default=1024)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--max-batches", type=int, default=50)
    parser.add_argument("--train-steps", type=int, default=500)
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument(
        "--output",
        default="shard_order.json",
        help="JSON file to write in the current directory.",
    )
    return parser


def main() -> None:
    args = build_parser().parse_args()
    device = torch.device(args.device)
    train_files = sorted(glob.glob(str(Path(args.data_dir) / "fineweb_train_*.bin")))
    if not train_files:
        raise FileNotFoundError(f"no training shards found under {args.data_dir}")

    model = MiniGPT(args.vocab_size, args.model_dim, args.layers, args.heads, args.mlp_mult).to(device)
    print(f"train_shards:{len(train_files)} model_params:{sum(p.numel() for p in model.parameters()):,}")

    random_scores: dict[int, float] = {}
    for idx, shard_path in enumerate(train_files):
        tokens = load_data_shard(Path(shard_path), args.vocab_size)
        loss = score_shard(model, tokens, device, args.seq_len, args.max_batches, args.batch_size)
        random_scores[idx] = loss
        print(f"random_score:{idx} loss:{loss:.6f}")

    first_shard = load_data_shard(Path(train_files[0]), args.vocab_size)
    train_steps(
        model,
        first_shard,
        device,
        steps=args.train_steps,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
    )

    trained_scores: dict[int, float] = {}
    for idx, shard_path in enumerate(train_files):
        tokens = load_data_shard(Path(shard_path), args.vocab_size)
        loss = score_shard(model, tokens, device, args.seq_len, args.max_batches, args.batch_size)
        trained_scores[idx] = loss
        print(f"trained_score:{idx} loss:{loss:.6f}")

    ranking = [
        {
            "shard_index": idx,
            "random_loss": random_scores[idx],
            "trained_loss": trained_scores[idx],
            "learned_delta": random_scores[idx] - trained_scores[idx],
        }
        for idx in range(len(train_files))
    ]
    ranking.sort(key=lambda item: item["trained_loss"], reverse=True)
    order = [item["shard_index"] for item in ranking]
    summary = {
        "order": order,
        "train_steps": args.train_steps,
        "seq_len": args.seq_len,
        "batch_size": args.batch_size,
        "max_batches": args.max_batches,
        "remaining_loss_min": ranking[-1]["trained_loss"],
        "remaining_loss_max": ranking[0]["trained_loss"],
        "remaining_loss_std": float(np.std([item["trained_loss"] for item in ranking])),
    }
    output_path = Path(args.output).resolve()
    output_path.write_text(json.dumps({"summary": summary, "ranking": ranking}, indent=2), encoding="utf-8")
    print(f"recommended_order:{','.join(str(idx) for idx in order)}")
    print(f"wrote:{output_path}")


if __name__ == "__main__":
    main()
{
  "author": "Peter",
  "github_id": "petergpt",
  "name": "PR703 + Shard-Order Curriculum + GPTQ Cache-Backout",
  "blurb": "Non-record 8xH100 SXM submission: a PR703-style quant/cache-backout branch improved with score-ranked shard curriculum and stronger final int6+lzma packing. Single-seed result reached 1.11709895 sliding-window exact under the 16MB artifact cap.",
  "date": "2026-03-25T22:44:46Z",
  "track": "non-record-16mb",
  "val_loss": 1.88616979,
  "val_bpb": 1.11709895,
  "pre_quant_val_loss": 1.9195,
  "pre_quant_val_bpb": 1.1368,
  "step_stop": 6918,
  "wallclock_seconds": 600.12,
  "bytes_total": 15909560,
  "bytes_model_int6_lzma": 15798304,
  "bytes_code": 111256,
  "gpu": "8xH100 SXM"
}
experiment:pr703_curriculum_only
W0325 22:47:14.475000 582 torch/distributed/run.py:803]
W0325 22:47:14.475000 582 torch/distributed/run.py:803] *****************************************
W0325 22:47:14.475000 582 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0325 22:47:14.475000 582 torch/distributed/run.py:803] *****************************************
logs/e1e9df51-7f27-4a41-9cec-f3b0bdebb012.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/parameter-golf-winning-stack/data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=/workspace/parameter-golf-winning-stack/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26928221
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
XSA:last_4 active_layers:[7, 8, 9, 10]
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000
seed:2025
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/9000 val_loss:6.9302 val_bpb:4.1045 train_time:0ms step_avg:0.01ms
step:1/9000 train_loss:6.9311 train_time:107ms step_avg:106.67ms
step:2/9000 train_loss:8.6119 train_time:143ms step_avg:71.72ms
step:3/9000 train_loss:7.6526 train_time:225ms step_avg:75.03ms
step:4/9000 train_loss:7.3924 train_time:307ms step_avg:76.78ms
step:5/9000 train_loss:7.2751 train_time:389ms step_avg:77.83ms
step:6/9000 train_loss:7.0734 train_time:471ms step_avg:78.52ms
step:7/9000 train_loss:6.9761 train_time:554ms step_avg:79.08ms
step:8/9000 train_loss:6.9682 train_time:635ms step_avg:79.41ms
step:9/9000 train_loss:6.6304 train_time:718ms step_avg:79.81ms
step:10/9000 train_loss:6.2255 train_time:800ms step_avg:79.98ms
step:500/9000 train_loss:2.9876 train_time:42721ms step_avg:85.44ms
step:1000/9000 train_loss:2.8630 train_time:85990ms step_avg:85.99ms
step:1500/9000 train_loss:2.8506 train_time:129254ms step_avg:86.17ms
step:2000/9000 train_loss:2.7479 train_time:172590ms step_avg:86.30ms
step:2500/9000 train_loss:2.8058 train_time:215988ms step_avg:86.40ms
step:3000/9000 train_loss:2.7985 train_time:259389ms step_avg:86.46ms
step:3500/9000 train_loss:2.8133 train_time:302803ms step_avg:86.52ms
step:4000/9000 train_loss:2.6067 train_time:346188ms step_avg:86.55ms
step:4000/9000 val_loss:2.0520 val_bpb:1.2153 train_time:346235ms step_avg:86.56ms
step:4500/9000 train_loss:2.6510 train_time:389586ms step_avg:86.57ms
step:5000/9000 train_loss:2.5632 train_time:432989ms step_avg:86.60ms
step:5500/9000 train_loss:2.3861 train_time:476351ms step_avg:86.61ms
step:6000/9000 train_loss:2.1448 train_time:519703ms step_avg:86.62ms
swa:start step:6250
late_qat:enabled step:6400 scale:0.1495
step:6500/9000 train_loss:2.0681 train_time:563373ms step_avg:86.67ms
step:6918/9000 val_loss:1.9211 val_bpb:1.1378 train_time:600120ms step_avg:86.75ms
stopping_early: wallclock_cap train_time:600120ms step:6918/9000
peak memory allocated: 22403 MiB reserved: 23086 MiB
ema:applying EMA weights
DIAGNOSTIC post_ema val_loss:1.9195 val_bpb:1.1368 eval_time:2017ms
Serialized model: 106027703 bytes
Code size: 111256 bytes
gptq:building non-banked model for Hessian collection...
gptq:calibrating with 256 batches...
gptq:collected hessians for 68 layers
Serialized model int6+lzma: 15798304 bytes
Total submission size int6+lzma: 15909560 bytes
final_int6_roundtrip val_loss:1.9260 val_bpb:1.1407 eval_time:15937ms
final_int6_roundtrip_exact val_loss:1.92600187 val_bpb:1.14068680
final_int6_sliding_window val_loss:1.8862 val_bpb:1.1171 stride:64 eval_time:92158ms
final_int6_sliding_window_exact val_loss:1.88616979 val_bpb:1.11709895
final_int8_zlib_roundtrip_exact val_loss:1.88616979 val_bpb:1.11709895