# Approach F: Fused Triton MLP Activation Kernel

- **val_bpb:** TBD (pending 8xH100 run)
- **Artifact:** TBD

## Key Innovation: Fused Triton Activation Kernel

A custom Triton kernel fuses `relu(x).square()` into a single GPU kernel, eliminating the intermediate hidden-dimension tensor write to HBM. This is a pure systems optimization -- the output is mathematically identical.

### What's fused

The standard MLP activation path launches two separate elementwise kernels:

```python
# Standard: 2 kernel launches, writes 1792-dim intermediate to HBM
h = torch.relu(self.fc(x)) # elementwise relu, writes to HBM
h = h.square() # elementwise square, reads+writes HBM
out = self.proj(h)
```

The fused kernel combines both operations:

```python
# Fused: 1 kernel launch, no intermediate write
h = fused_relu_sq(self.fc(x)) # relu + square in one pass
out = self.proj(h)
```

This saves one full read+write of the hidden dimension tensor (batch * seq_len * 1792 elements) per layer, per forward and backward pass. With 11 layers:
- Forward: 11 fewer HBM roundtrips
- Backward: 11 fewer HBM roundtrips (fused backward kernel too)
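
The fused op's forward/backward semantics can be sketched in plain PyTorch (a reference for what the Triton kernel computes, not the kernel itself; the class and function names here are illustrative):

```python
import torch

# Reference autograd semantics for the fused op. The real Triton kernel
# computes the same math in a single pass over the tensor.
class ReluSq(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        r = torch.relu(x)
        ctx.save_for_backward(r)
        return r * r

    @staticmethod
    def backward(ctx, grad_out):
        (r,) = ctx.saved_tensors
        # d/dx relu(x)^2 = 2 * relu(x), and 0 where x <= 0
        return grad_out * 2.0 * r

def relu_sq(x):
    return ReluSq.apply(x)
```

Note that the backward pass only needs `relu(x)`, which is why a fused backward kernel can avoid re-reading the pre-activation from HBM.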

### Expected performance improvement

Based on PR #1072 results (87 ms -> 70 ms/step with a similar fused kernel, a ~20% reduction), we expect a ~15-20% step-time reduction here. Even a conservative 10% step-time reduction yields roughly 11% more training steps within the 590 s budget.

### Triton dependency

Requires Triton (ships with PyTorch on CUDA). Falls back to standard PyTorch ops if Triton is unavailable. The RunPod `runpod/parameter-golf:latest` image includes Triton.

### Also provides: fused LeakyReLU(0.5)^2

A `fused_leaky_relu_sq(x, neg_slope)` kernel is included for future use with LeakyReLU activation variants.
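
For reference, the eager-mode semantics that `fused_leaky_relu_sq(x, neg_slope)` would need to match (a sketch; the helper name below is illustrative):

```python
import torch
import torch.nn.functional as F

# Eager reference: LeakyReLU followed by square. Note that squaring makes
# the result non-negative even on the leaky (negative) branch.
def leaky_relu_sq_ref(x, neg_slope=0.5):
    return F.leaky_relu(x, neg_slope).square()
```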

## Architecture (unchanged from Approach B)

| Component | Detail |
|-----------|--------|
| Layers | 11 |
| Dimension | 512 |
| Heads | 8 query / 8 KV |
| MLP | ReLU² with fused Triton kernel, 3.5x expansion (1792 hidden) |
| Attention | XSA on all 11 layers, Partial RoPE (16/64 dims), QK-norm |
| Embeddings | BigramHash 6144 (128-dim), Value Embeddings on layers 9-10 |
| Skip connections | U-Net with learned per-dim scaling |
| Other | SmearGate, LN depth scaling, logit softcap (30.0) |
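
As one concrete detail from the table above, logit softcapping is commonly implemented as a scaled tanh; the exact form used in this codebase is an assumption, based only on the stated cap of 30.0:

```python
import torch

# Assumed softcap form: smoothly bounds logits to (-cap, cap).
# Near zero it is approximately the identity.
def softcap(logits, cap=30.0):
    return cap * torch.tanh(logits / cap)
```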

## Training

| Parameter | Value |
|-----------|-------|
| Optimizer | Muon + AdamW |
| Batch size | 786,432 tokens |
| Warmdown | 3,500 steps (wallclock-aware) |
| QAT | Late QAT at scale < 0.5 |
| EMA | 0.997 decay |
| SWA | Every 50 steps during warmdown |
| Quantization | Int5 GPTQ + zstd/zlib |
| Pruning | 10% magnitude pruning |
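
A sketch of what 10% magnitude pruning typically means: zeroing the smallest 10% of weights by absolute value. The repository's pruning granularity (global vs. per-tensor) is not specified here, so the in-place per-tensor variant below is an assumption:

```python
import torch

# Zero the smallest `pct` fraction of entries by |w|, in place.
# Ties at the threshold may zero slightly more than pct of entries.
def magnitude_prune_(w, pct=0.10):
    k = int(w.numel() * pct)
    if k == 0:
        return w
    thresh = w.abs().flatten().kthvalue(k).values
    w[w.abs() <= thresh] = 0.0
    return w
```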

## Rule Compliance

- Fused kernels are a systems-only optimization (legal, no significance test needed)
- No eval changes -- identical scoring methodology
- All assertions preserved (artifact budget, wallclock budget)
- Falls back to standard PyTorch if Triton unavailable

## Credits

Built on Approach B (Int5 GPTQ + larger model). Fused kernel pattern inspired by PR #1072 (Vilhelm Toivonen).
---

**`records/track_10min_16mb/2026-04-01_ApproachF_FusedTriton/run.sh`** (31 additions):

```bash
#!/bin/bash
# Approach F: Fused Triton MLP kernel for faster training
# Run on 8xH100 SXM with RunPod
set -euo pipefail

cd /workspace/parameter-golf

# Ensure data is available
if [ ! -d "data/datasets/fineweb10B_sp1024" ]; then
  python3 data/cached_challenge_fineweb.py --variant sp1024
fi

export NCCL_IB_DISABLE=1

# Main training run (NCCL_IB_DISABLE is already exported above)
RUN_ID=approach_f_fused_triton \
NUM_LAYERS=11 \
MODEL_DIM=512 \
NUM_HEADS=8 \
NUM_KV_HEADS=8 \
MLP_MULT=3.5 \
BIGRAM_VOCAB_SIZE=6144 \
BIGRAM_DIM=128 \
MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 \
ITERATIONS=20000 MAX_WALLCLOCK_SECONDS=590 \
TRAIN_SEQ_LEN=2048 EVAL_SEQ_LEN=2048 EVAL_STRIDE=64 \
PRUNE_PCT=0.10 \
torchrun --standalone --nproc_per_node=8 train_gpt.py 2>&1 | tee /workspace/run.log
```
---

**Record metadata** (JSON, 9 additions):

```json
{
  "name": "Approach F: Fused Triton MLP Activation Kernel",
  "val_bpb": null,
  "bytes_total": null,
  "blurb": "Fused Triton kernel for relu(x).square() activation eliminates intermediate HBM write in MLP forward+backward across 11 layers. Combined with Int5 GPTQ, 33.6M params, U-Net skips, XSA, BigramHash, SmearGate, value embeddings, SWA, EMA, and score-first TTT.",
  "author": "Alex Ibarra",
  "github_id": "elninja",
  "date": "2026-04-01"
}
```