diff --git a/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/README.md b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/README.md new file mode 100644 index 0000000000..34c8746aec --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/README.md @@ -0,0 +1,20 @@ +# GDN-Hybrid + Sliding Window Attention (3-seed mean 1.02045733 BPB) + +## Per-seed authoritative results + +| Seed | Steps | EMA BPB | Quantized BPB | XSA BPB | Artifact bytes | +|------|------:|--------:|--------------:|--------:|---------------:| +| 42 | 1864 | 1.017723 | 1.026791 | 1.031731 | 15,313,984 | +| 1337 | 2239 | 1.007375 | 1.016586 | 1.020691 | 15,830,308 | +| 2024 | 2241 | 1.008736 | 1.017995 | 1.023138 | 15,820,201 | +| **Mean** | — | **1.011278** | **1.02045733** | **1.025187** | **15,654,831.00** | +| **Std (sample)** | — | — | **0.00553017** | — | — | + +## Technique stack + +1. **SP1024 tokenizer** with a GDN-hybrid backbone (`[GDN×5] → SWA → [GDN×5] → SWA_shared`). +2. **Fixed-predictor / no-TTT Track-A lane** — no eval-time or pre-quant adaptation in the scored artifact. +3. **MuonEq-R + AdamW** training mix, EMA `0.997`, late QAT threshold `0.15`. +4. **GPTQ int6 + zstd-22** packaging. +5. **Sliding-window attention side path** present in-model, but submission authority remains the pulled `quantized_bpb` values above. + diff --git a/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/architectures.py b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/architectures.py new file mode 100644 index 0000000000..06af64d660 --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/architectures.py @@ -0,0 +1,637 @@ +"""GDN Hybrid Architecture — modular blocks using FLA native layers. + +Supports model variants for the Parameter Golf Direction-5 experiments. 
+Each model is a stack of mixed {GDN, DeltaProduct, Mamba-2, SWA} blocks +with shared MLP, RMSNorm, and residual connections. + +Key design choices: +- FLA layers handle recurrent attention (GatedDeltaNet, GatedDeltaProduct, Mamba2) +- Sliding Window Attention (SWA) uses flash attention with a causal window mask +- All blocks follow the same pre-norm residual pattern for uniform gradient flow +- Weight sharing for SWA layers in Griffin/Zamba-style models +- forward_hidden() exposes (hidden_states, logits) for RLS eval +""" +from __future__ import annotations +import math +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +# ─── FLA backend selection ────────────────────────────────────────────────── +# Set FLA_USE_NAIVE=1 to force pure-PyTorch (naive) kernels instead of Triton. +_USE_NAIVE = os.environ.get("FLA_USE_NAIVE", "0") == "1" + +if _USE_NAIVE: + import fla.ops.gated_delta_rule.chunk as _gdr_chunk + import fla.ops.gated_delta_rule.naive as _gdr_naive + + def _patched_chunk_gated_delta_rule( + q, k, v, g, beta, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdr_naive.naive_chunk_gated_delta_rule( + q, k, v, g, beta, + chunk_size=64, scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + ) + + _gdr_chunk.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + import fla.layers.gated_deltanet as _gdn_layer + _gdn_layer.chunk_gated_delta_rule = _patched_chunk_gated_delta_rule + + import fla.ops.gated_delta_product.chunk as _gdp_chunk + import fla.ops.gated_delta_product.naive as _gdp_naive + + def _patched_chunk_gated_delta_product( + q, k, v, g, beta, num_householder=1, scale=None, initial_state=None, + output_final_state=False, use_qk_l2norm_in_kernel=False, **kwargs + ): + if use_qk_l2norm_in_kernel: + q 
= F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + return _gdp_naive.naive_recurrent_gated_delta_product( + q, k, v, g, beta, + scale=scale, cu_seqlens=None, + initial_state=initial_state, + output_final_state=output_final_state, + num_householder=num_householder, + ) + + _gdp_chunk.chunk_gated_delta_product = _patched_chunk_gated_delta_product + import fla.layers.gated_deltaproduct as _gdp_layer + _gdp_layer.chunk_gated_delta_product = _patched_chunk_gated_delta_product + + print("[FLA] Using NAIVE (pure-PyTorch) kernels — set FLA_USE_NAIVE=0 for Triton", flush=True) + +# FLA imports +from fla.layers import GatedDeltaNet, GatedDeltaProduct, Mamba2 +try: + from fla.layers import RWKV7Attention +except Exception: + RWKV7Attention = None # type: ignore + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func +except ImportError: + def flash_attn_3_func(q, k, v, causal=False, window_size=(-1, -1)): + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + if k2.size(1) != q2.size(1): + rep = q2.size(1) // k2.size(1) + k2 = k2.repeat_interleave(rep, dim=1) + v2 = v2.repeat_interleave(rep, dim=1) + out = torch.nn.functional.scaled_dot_product_attention(q2, k2, v2, is_causal=causal) + return out.transpose(1, 2) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int | None = None, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + """Linear layer that casts input to weight dtype for mixed precision. 
+    Supports late QAT (int6 STE) when _qat_enabled is set."""
+    _qat_enabled: bool = False
+
+    def forward(self, x: Tensor) -> Tensor:
+        w = self.weight.to(dtype=x.dtype)
+        if CastedLinear._qat_enabled and self.training and w.ndim == 2:
+            with torch.no_grad():
+                w32 = self.weight.float()
+                row_max = w32.abs().amax(dim=1)
+                scale = (row_max / 31.0).clamp_min(1.0 / 31.0)
+                w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) * scale[:, None]).to(x.dtype)
+            w = w + (w_q - w).detach()  # STE: forward uses quantized, backward uses full
+        bias = self.bias.to(dtype=x.dtype) if self.bias is not None else None
+        return F.linear(x, w, bias)
+
+
+class Rotary(nn.Module):
+    """RoPE embeddings for sliding window attention."""
+    def __init__(self, dim: int, base: float = 10000.0, max_len: int = 4096):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.max_len = max_len
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype):
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+        freqs = torch.outer(t, self.inv_freq.to(device))
+        cos = freqs.cos().to(dtype)
+        sin = freqs.sin().to(dtype)
+        return cos, sin
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    """Apply RoPE along the time axis of x shaped (B, T, H, D)."""
+    d = x.shape[-1] // 2
+    x1, x2 = x[..., :d], x[..., d:]
+    out1 = x1 * cos[:x.shape[-3], None, :] - x2 * sin[:x.shape[-3], None, :]
+    out2 = x2 * cos[:x.shape[-3], None, :] + x1 * sin[:x.shape[-3], None, :]
+    return torch.cat([out1, out2], dim=-1)
+
+
+class MLP(nn.Module):
+    """Feed-forward MLP with configurable activation."""
+    def __init__(self, dim: int, mult: float = 3.0, act: str = "relu_sq", leaky_slope: float = 0.5):
+        super().__init__()
+        hidden = int(mult * dim)
+        self.fc = CastedLinear(dim, hidden, bias=False)
+        self.proj = CastedLinear(hidden, dim, bias=False)
+        nn.init.zeros_(self.proj.weight)
+        self.act = act
+        self.leaky_slope = leaky_slope
+
+ def forward(self, x: Tensor) -> Tensor: + x = self.fc(x) + if self.act == "leaky_relu_sq": + x = F.leaky_relu(x, negative_slope=self.leaky_slope) + else: + x = F.relu(x) + return self.proj(x.square()) + + +class SlidingWindowAttention(nn.Module): + """Sliding window causal attention for hybrid models. + + Supports XSA (cross-segment attention) at eval time for extending context + across eval chunks. Window is enforced during training but can be relaxed at eval. + KV can be shared across layers (Zamba-style) by reusing the same module. + """ + def __init__( + self, + dim: int, + num_heads: int = 8, + num_kv_heads: int = 4, + window_size: int = 512, + rope_base: float = 10000.0, + qk_gain_init: float = 1.5, + ): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.window_size = window_size + + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + nn.init.zeros_(self.proj.weight) + + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """XSA: subtract self-value projection (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + B, T, D = x.shape + q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(B, T, 
self.num_kv_heads, self.head_dim)
+
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(T, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
+
+        if q.is_cuda and q.dtype not in (torch.float16, torch.bfloat16):
+            q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)
+
+        y = flash_attn_3_func(q, k, v, causal=True, window_size=(self.window_size - 1, 0) if self.training else (-1, -1))  # enforce SWA window in training; relax at eval (XSA)
+
+        if self.use_xsa:
+            y = self._xsa_efficient(y, v)
+
+        y = y.reshape(B, T, D)
+        return self.proj(y)
+
+
+class RecurrentBlock(nn.Module):
+    """Wraps any FLA recurrent layer (GDN, DeltaProduct, Mamba-2) with
+    pre-norm residual connection and MLP."""
+
+    def __init__(
+        self,
+        dim: int,
+        recurrent_layer: nn.Module,
+        mlp_mult: float = 3.0,
+        mlp_act: str = "relu_sq",
+        layer_idx: int = 0,
+    ):
+        super().__init__()
+        self.attn_norm = RMSNorm(dim)
+        self.mlp_norm = RMSNorm(dim)
+        self.recurrent = recurrent_layer
+        self.mlp = MLP(dim, mlp_mult, act=mlp_act)
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+        self.layer_idx = layer_idx
+
+    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
+        mix = self.resid_mix.to(dtype=x.dtype)
+        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+
+        recurrent_out = self.recurrent(self.attn_norm(x_in))
+        if isinstance(recurrent_out, tuple):
+            recurrent_out = recurrent_out[0]
+
+        x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * recurrent_out
+        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out))
+        return x_out
+
+
+class AttentionBlock(nn.Module):
+    """SWA block with pre-norm residual and MLP."""
+
+    def __init__(
+        self,
+        dim: int,
+        swa: SlidingWindowAttention,
+        mlp_mult: float = 3.0,
+        mlp_act: str = 
"relu_sq", + layer_idx: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm(dim) + self.mlp_norm = RMSNorm(dim) + self.attn = swa + self.mlp = MLP(dim, mlp_mult, act=mlp_act) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in), v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out)) + return x_out + + +class SmearGate(nn.Module): + """Weighted average of current and previous token embeddings.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash-based bigram/trigram embedding for additional context.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int, + trigram: bool = False): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self._trigram = trigram + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: 
Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def trigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., :2] = mod + out[..., 2:] = (36313 * t[..., 2:] ^ 27191 * t[..., 1:-1] ^ 51497 * t[..., :-2]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self._trigram: + h = h + self.embed(self.trigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +def _parse_layout(layout_str: str) -> list[tuple[str, int]]: + """Parse a layout string into a sequence of (layer_type, count) pairs. + + Examples: + "gdn_only" -> [("gdn", 11)] (count filled in by caller) + "gdn5_swa_gdn5_swa_shared" -> [("gdn", 5), ("swa", 1), ("gdn", 5), ("swa_shared", 1)] + """ + if layout_str == "gdn_only": + return [("gdn", -1)] + if layout_str == "mamba_only": + return [("mamba", -1)] + + parts = layout_str.split("_") + result = [] + i = 0 + while i < len(parts): + part = parts[i] + if part.startswith("gdn") and len(part) > 3: + count = int(part[3:]) + result.append(("gdn", count)) + elif part.startswith("mamba") and len(part) > 5: + count = int(part[5:]) + result.append(("mamba", count)) + elif part == "swa": + if i + 1 < len(parts) and parts[i + 1] == "shared": + result.append(("swa_shared", 1)) + i += 1 + else: + result.append(("swa", 1)) + elif part == "shared": + pass + i += 1 + return result + + +class HybridGDN(nn.Module): + """Hybrid GDN architecture supporting mixed recurrent/attention layers. 
+ + Builds a stack of blocks according to the layer_layout specification: + - "gdn" blocks use GatedDeltaNet (or GatedDeltaProduct) + - "mamba" blocks use Mamba-2 + - "swa" blocks use SlidingWindowAttention + - "swa_shared" reuses the same SWA module (Griffin/Zamba-style weight sharing) + + All models share: token embedding, bigram hash, smear gate, final norm, lm_head. + """ + def __init__(self, config: dict, vocab_size: int = 1024): + super().__init__() + dim = config["model_dim"] + num_heads = config["num_heads"] + mlp_mult = config["mlp_mult"] + self.arch_name = config["arch_name"] + self.model_dim = dim + self.vocab_size = vocab_size + self.logit_softcap = 30.0 + + # Embeddings + self.tok_emb = nn.Embedding(vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=0.005) + self.bigram = BigramHashEmbedding( + config.get("bigram_vocab_size", 2048), + config.get("bigram_dim", 128), + dim, + trigram=config.get("trigram", False), + ) + self.smear = SmearGate(dim) + + # Meta tokens (Hymba-style) + n_meta = config.get("meta_tokens", 0) + if n_meta > 0: + self.meta_tokens = nn.Parameter(torch.randn(1, n_meta, dim) * 0.02) + self.n_meta = n_meta + else: + self.meta_tokens = None + self.n_meta = 0 + + # Build layer stack + layout = _parse_layout(config["layer_layout"]) + self.blocks = nn.ModuleList() + self._block_types = [] + self._shared_swa = None + + layer_idx = 0 + for layer_type, count in layout: + if count == -1: + if layer_type == "gdn": + count = config["num_gdn_layers"] + elif layer_type == "mamba": + count = config["num_mamba_layers"] + + for _ in range(count): + if layer_type == "gdn": + recurrent = self._make_recurrent_layer(config, layer_idx) + block = RecurrentBlock(dim, recurrent, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("gdn") + + elif layer_type == "mamba": + mamba_expand = config.get("mamba_expand", 2) + mamba_head_dim = config.get("gdn_head_dim", 64) + mamba_num_heads = (dim * mamba_expand) // 
mamba_head_dim + mamba = Mamba2( + num_heads=mamba_num_heads, + head_dim=mamba_head_dim, + hidden_size=dim, + state_size=config.get("mamba_state_size", 64), + expand=mamba_expand, + layer_idx=layer_idx, + ) + block = RecurrentBlock(dim, mamba, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("mamba") + + elif layer_type in ("swa", "swa_shared"): + if layer_type == "swa_shared" and self._shared_swa is not None: + swa = self._shared_swa + else: + swa = SlidingWindowAttention( + dim=dim, + num_heads=num_heads, + num_kv_heads=config.get("swa_num_kv_heads", 4), + window_size=config.get("swa_window", 512), + qk_gain_init=config.get("qk_gain_init", 1.5), # Direction-5: 5.0 + ) + if config.get("swa_shared", False): + self._shared_swa = swa + + block = AttentionBlock(dim, swa, mlp_mult, layer_idx=layer_idx) + self.blocks.append(block) + self._block_types.append("swa" if layer_type == "swa" else "swa_shared") + + layer_idx += 1 + + self.final_norm = RMSNorm(dim) + self.lm_head = None # tied to tok_emb + self._init_weights() + + def _make_recurrent_layer(self, config: dict, layer_idx: int) -> nn.Module: + """Create the appropriate recurrent layer based on config.""" + dim = config["model_dim"] + num_heads = config["num_heads"] + + if config.get("use_rwkv7", False): + total_layers = config.get("num_gdn_layers", 11) + return RWKV7Attention( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + layer_idx=layer_idx, + num_hidden_layers=total_layers, + mode="chunk", + ) + elif config.get("use_deltaproduct", False): + return GatedDeltaProduct( + hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + num_householder=config.get("dp_num_householder", 2), + allow_neg_eigval=config.get("dp_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + layer_idx=layer_idx, + mode="chunk", + ) + else: + return GatedDeltaNet( + 
hidden_size=dim, + head_dim=config.get("gdn_head_dim", 64), + num_heads=num_heads, + allow_neg_eigval=config.get("gdn_allow_neg_eigval", False), + use_short_conv=config.get("gdn_use_short_conv", True), + expand_v=config.get("gdn_expand_v", 1), + layer_idx=layer_idx, + mode="chunk", + ) + + def _init_weights(self) -> None: + total_layers = len(self.blocks) + for name, p in self.named_parameters(): + if ".recurrent." in name: + continue + if p.ndim == 2 and "proj" in name and "bigram" not in name: + with torch.no_grad(): + p.mul_(1.0 / math.sqrt(2 * total_layers)) + + def set_xsa(self, enable: bool = True) -> None: + """Enable/disable XSA on all attention blocks.""" + for block, btype in zip(self.blocks, self._block_types): + if btype in ("swa", "swa_shared"): + block.attn.use_xsa = enable + + def _compute_logits(self, x: Tensor) -> Tensor: + """Compute logits with tied embeddings and softcap.""" + logits = F.linear(x, self.tok_emb.weight) + return self.logit_softcap * torch.tanh(logits / self.logit_softcap) + + def _run_blocks(self, x: Tensor, x0: Tensor) -> Tensor: + """Run all blocks on x with residual anchor x0.""" + for block in self.blocks: + x = block(x, x0) + return x + + def _embed(self, input_ids: Tensor) -> tuple[Tensor, Tensor]: + """Shared embedding + smear, returns (x, x0).""" + x = self.tok_emb(input_ids) + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + if self.meta_tokens is not None: + B = x.shape[0] + meta = self.meta_tokens.expand(B, -1, -1).to(dtype=x.dtype) + x = torch.cat([meta, x], dim=1) + x0 = torch.cat([meta, x0], dim=1) + return x, x0 + + def _strip_meta(self, x: Tensor) -> Tensor: + if self.meta_tokens is not None: + x = x[:, self.n_meta:] + return x + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + """Forward pass returning cross-entropy loss.""" + x, x0 = self._embed(input_ids) + x = self._run_blocks(x, x0) + x = self._strip_meta(x) + x = self.final_norm(x) + 
logits = self._compute_logits(x.reshape(-1, x.size(-1))) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Forward pass returning softcapped logits (for evaluation).""" + x, x0 = self._embed(input_ids) + x = self._run_blocks(x, x0) + x = self._strip_meta(x) + x = self.final_norm(x) + return self._compute_logits(x) + + def forward_hidden(self, input_ids: Tensor) -> tuple[Tensor, Tensor]: + """Forward pass returning (hidden_states, softcapped_logits) for RLS eval. + + hidden_states: [B, T, dim] — final norm output before lm_head + softcapped_logits: [B, T, vocab] — softcap * tanh(linear / softcap) + + Compliance note: called in inference_mode during eval. + The hidden states are purely causal (each h[t] depends only on x[0:t]). + """ + x, x0 = self._embed(input_ids) + x = self._run_blocks(x, x0) + x = self._strip_meta(x) + x = self.final_norm(x) + logits = self._compute_logits(x) + return x, logits + + def get_diagnostics(self) -> dict: + """Collect per-layer weight statistics for checkpoint diagnostics.""" + diag = {} + for i, (block, btype) in enumerate(zip(self.blocks, self._block_types)): + prefix = f"layer_{i}_{btype}" + for name, param in block.named_parameters(): + if param.ndim >= 2: + w = param.data.float() + diag[f"{prefix}/{name}/std"] = w.std().item() + diag[f"{prefix}/{name}/kurtosis"] = (((w - w.mean()) / (w.std() + 1e-8)) ** 4).mean().item() - 3.0 + return diag + + def count_params(self) -> dict: + """Count parameters by category.""" + cats = {"embedding": 0, "recurrent": 0, "attention": 0, "mlp": 0, "other": 0} + for name, p in self.named_parameters(): + n = p.numel() + if "tok_emb" in name or "bigram" in name: + cats["embedding"] += n + elif any(k in name for k in ["recurrent", "gdn", "mamba", "rwkv", "delta"]): + cats["recurrent"] += n + elif "attn" in name or "c_q" in name or "c_k" in name or "c_v" in name: + cats["attention"] += n + 
elif "mlp" in name or "fc" in name: + cats["mlp"] += n + else: + cats["other"] += n + cats["total"] = sum(cats.values()) + return cats diff --git a/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/configs.py b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/configs.py new file mode 100644 index 0000000000..71a2b6c3cd --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/configs.py @@ -0,0 +1,137 @@ +"""Model architecture configurations for Direction-5 (GDN-Hybrid + RLS). + +Model D is our primary target: + GDN×5 → shared SWA → GDN×5 → shared SWA (Griffin-style) + dim=512, qk_gain_init=5.0, bigram 3072×112 + trigram + +Model J (Phase 4 if needed): + 12-layer GDN, dim=480, KV-sharing stride=2 (GDN-native depth recurrence) + +All models sized to fit ~16MB at int6+zstd-22. +""" +from __future__ import annotations + + +def model_a_pure_gdn() -> dict: + """Model A: Pure GDN (Baseline) — 10 layers Gated DeltaNet.""" + return dict( + arch_name="A_PureGDN", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + qk_gain_init=1.5, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + ) + + +def model_d_gdn_hybrid() -> dict: + """Model D: GDN + Shared SWA — our Direction-5 primary target. + + Architecture: [GDN×5] → [SWA] → [GDN×5] → [SWA_shared] + This is Griffin-style: interleaved recurrent + local attention with weight sharing. 
+ + Key changes vs PR #1370 Model D: + - qk_gain_init=5.0 (stronger initial attention, following competition SOTA) + - bigram_vocab_size=3072, bigram_dim=112, trigram=True (matches Model A best setting) + - swa_window=512 (standard; can increase for Phase 4) + """ + return dict( + arch_name="D_GDN_Hybrid", + num_gdn_layers=10, + num_mamba_layers=0, + num_swa_layers=1, + swa_shared=True, + model_dim=512, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=64, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + qk_gain_init=5.0, # Direction-5: strong initial attention gain + meta_tokens=0, + # Layout: [GDN×5] → [SWA] → [GDN×5] → [SWA_shared] + layer_layout="gdn5_swa_gdn5_swa_shared", + bigram_vocab_size=3072, # Direction-5: full bigram table (vs 2048 default) + bigram_dim=112, # Direction-5: matches Model A best setting + trigram=True, # Direction-5: add trigram for extra context + ) + + +def model_d_smoke() -> dict: + """Model D Smoke: Same architecture, smaller for CPU sanity checks.""" + cfg = model_d_gdn_hybrid() + cfg["arch_name"] = "D_Smoke" + cfg["model_dim"] = 128 + cfg["num_heads"] = 4 + cfg["gdn_head_dim"] = 32 + cfg["swa_num_kv_heads"] = 2 + cfg["bigram_vocab_size"] = 512 + cfg["bigram_dim"] = 64 + return cfg + + +def model_j_kv_share() -> dict: + """Model J: 12-layer GDN + KV-sharing (Phase 4 depth ablation). + + GDN-native equivalent of transformer depth recurrence. + Adjacent GDN layer pairs share K/V projections, freeing ~528K params per pair. + Those params → more layers (12 vs 10) at narrower dim (480 vs 512). + + Note: KV-sharing in GDN requires architectures.py to support it. + This config is defined here for planning; implement in Phase 4 if needed. 
+ """ + return dict( + arch_name="J_GDN_KVShare", + num_gdn_layers=12, + num_mamba_layers=0, + num_swa_layers=0, + swa_shared=False, + model_dim=480, + num_heads=8, + mlp_mult=3.0, + gdn_expand_v=1, + gdn_head_dim=60, + gdn_allow_neg_eigval=False, + gdn_use_short_conv=True, + swa_window=512, + swa_num_kv_heads=4, + qk_gain_init=1.5, + meta_tokens=0, + layer_layout="gdn_only", + bigram_vocab_size=3072, + bigram_dim=112, + trigram=True, + kv_share_stride=2, # share K/V every 2 GDN layers (not yet implemented) + ) + + +ALL_CONFIGS = { + "A": model_a_pure_gdn, + "D": model_d_gdn_hybrid, + "D_smoke": model_d_smoke, + "J": model_j_kv_share, +} + + +def get_config(model_id: str) -> dict: + """Get config by model ID (A, D, D_smoke, J).""" + if model_id not in ALL_CONFIGS: + raise ValueError(f"Unknown model ID '{model_id}'. Choose from {list(ALL_CONFIGS.keys())}") + return ALL_CONFIGS[model_id]() diff --git a/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/requirements.txt b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/requirements.txt new file mode 100644 index 0000000000..9068ce0f70 --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/requirements.txt @@ -0,0 +1,5 @@ +flash-linear-attention +zstandard +sentencepiece +# flash_attn_3 is pre-installed in runpod/parameter-golf:latest +# torch, numpy, and other standard deps are pre-installed in the image diff --git a/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/submission.json b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/submission.json new file mode 100644 index 0000000000..f1dc3f6711 --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/submission.json @@ -0,0 +1,45 @@ +{ + "author": "Joshua Martinez", + "github_id": "joshkmartinez", + "name": "GDN-Hybrid: Gated DeltaNet + Sliding Window Attention", + "blurb": "Joshua-owned SAFE_SUBMISSION reproduction of PR #1545. 
Pulled TensorPool artifacts for run037-safe017 / j-jvquftkrwd show a 3-seed mean quantized_bpb of 1.02045733 (std 0.00553017) with all artifact sizes under 16,000,000 bytes.", + "date": "2026-04-11T17:34:30Z", + "baseline_pr": 1545, + "lane": "SAFE_SUBMISSION", + "evaluation": "quantized_bpb", + "val_bpb": 1.02045733, + "val_bpb_std": 0.00553017, + "best_seed_bpb": 1.016586, + "seed": 1337, + "three_seed_mean_bpb": 1.02045733, + "three_seed_std_bpb": 0.00553017, + "artifact_bytes_max": 15830308, + "artifact_bytes_mean": 15654831.0, + "seed_results": [ + { + "seed": "42", + "ema_bpb": 1.017723, + "quantized_bpb": 1.026791, + "xsa_bpb": 1.031731, + "artifact_bytes": 15313984, + "steps": 1864 + }, + { + "seed": "1337", + "ema_bpb": 1.007375, + "quantized_bpb": 1.016586, + "xsa_bpb": 1.020691, + "artifact_bytes": 15830308, + "steps": 2239 + }, + { + "seed": "2024", + "ema_bpb": 1.008736, + "quantized_bpb": 1.017995, + "xsa_bpb": 1.023138, + "artifact_bytes": 15820201, + "steps": 2241 + } + ], + "artifact_authority": "state/tp-pulls/run037-safe017/artifacts/train_seed42.log,state/tp-pulls/run037-safe017/artifacts/train_seed1337.log,state/tp-pulls/run037-safe017/artifacts/train_seed2024.log" +} diff --git a/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_gpt.py b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_gpt.py new file mode 100644 index 0000000000..7e01fee198 --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_gpt.py @@ -0,0 +1,1141 @@ +#!/usr/bin/env python3 +"""Direction-5 Training Script — GDN Hybrid, wallclock-limited. + +Trains the GDN-Hybrid backbone (Model D: GDN×5 → SWA → GDN×5 → SWA_shared) +within the competition's 10-minute training budget on 8×H100 SXM. 
+ +Key differences from PR #1370 train_gdn_7k.py: + - TRAIN_SEQ_LEN=2048 (longer context forces better GDN recurrence) + - MAX_WALLCLOCK_SECONDS=590 (10-min budget minus 10s safety margin) + - ITERATIONS=9999 (wallclock is the real limit) + - WARMDOWN_ITERS=3000 (30% of expected ~9000 steps) + - MuonEq-R: row-normalize before Newton-Schulz for better equivariance + - ARCH_MODE=D (Model D GDN Hybrid) + - No TTT in post-training eval (use eval_rls.py separately) + +Environment variables: + ARCH_MODE: Model config key (default: D) + TRAIN_SEQ_LEN: Training context length (default: 2048) + MAX_WALLCLOCK_SECONDS: Hard stop time (default: 590) + ITERATIONS: Max steps (default: 9999, wallclock-limited) + WARMDOWN_ITERS: Steps in LR warmdown (default: 3000) + QK_GAIN_INIT: SWA Q-gain init override (default: use config value) + SEED: Random seed (default: 42) + DATA_PATH: Dataset directory + CKPT_DIR: Checkpoint output directory (default: checkpoints) +""" +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch._dynamo +import torch.distributed as dist +import torch.nn.functional as F +import zstandard +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Safety guard: if dynamo is ever invoked on code paths containing GDN layers +# (e.g. FLA internal usage), each unique `layer_idx` integer attribute would be +# treated as a static guard and trigger a separate recompilation. The default +# limit=8 would cause layers 8-9 to permanently fall back to eager mode. +# We no longer call torch.compile on the eval forward (see evaluate_sliding_window), +# so this guard is mainly defensive. 64 > 10 GDN layers, so it's always safe. 
torch._dynamo.config.recompile_limit = 64

sys.path.insert(0, str(Path(__file__).resolve().parent))
from architectures import HybridGDN, CastedLinear
from configs import get_config


# ─── Hyperparameters ────────────────────────────────────────────────────────

class Hyperparameters:
    """Run configuration, sourced from environment variables with defaults.

    NOTE: all attributes are evaluated once at class-definition (import) time,
    so environment changes after import have no effect.
    """
    arch_mode = os.environ.get("ARCH_MODE", "D")
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    # Glob patterns for the binary token shards under data_path.
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 42))
    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))

    # Training length — wallclock-limited
    iterations = int(os.environ.get("ITERATIONS", 9999))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))  # Direction-5: 2048
    eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 590.0))  # 9m50s

    # Validation
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500))
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100))
    save_every = int(os.environ.get("SAVE_EVERY", 1000))

    # Optimizer
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.02))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.02))
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
    muon_wd = float(os.environ.get("MUON_WD", 0.04))
    adam_wd = float(os.environ.get("ADAM_WD", 0.04))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))

    # Eval
    eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
    xsa_eval = bool(int(os.environ.get("XSA_EVAL", "0")))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))

    # Checkpoint
    ckpt_dir = os.environ.get("CKPT_DIR", "checkpoints")

    # Compile
    compile_enabled = bool(int(os.environ.get("COMPILE_ENABLED", "1")))

    # Resume
    resume_ckpt = os.environ.get("RESUME_CKPT", "")

    # EMA / SWA
    ema_decay = float(os.environ.get("EMA_DECAY", 0.997))
    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
    swa_every = int(os.environ.get("SWA_EVERY", 50))

    # Late QAT
    late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15))

    # Chained job support
    auto_save_seconds = float(os.environ.get("AUTO_SAVE_SECONDS", "0"))
    total_iterations = int(os.environ.get("TOTAL_ITERATIONS", "0"))

    # Direction-5: QK gain override (set in config; this overrides config value if set)
    qk_gain_init_override = os.environ.get("QK_GAIN_INIT", "")


# ─── Data Loading ───────────────────────────────────────────────────────────

def load_data_shard(file: Path) -> Tensor:
    """Load one binary token shard as an int64 1-D tensor.

    Shard layout: 256-word uint32 header ([0]=magic 20240520, [1]=version,
    [2]=token count), followed by uint16 token ids.
    """
    header = np.fromfile(file, dtype=np.uint32, count=256)
    assert header[0] == 20240520, f"Bad magic: {header[0]}"
    assert header[1] in (1, 7), f"Bad version: {header[1]}"
    ntok = int(header[2])
    # Skip the 256*4-byte header, truncate to the declared token count.
    return torch.from_numpy(np.fromfile(file, dtype=np.uint16, offset=256 * 4)[:ntok].astype(np.int64))


class TokenStream:
    """Reads shards sequentially, supports coprime ordering via SHARD_ORDER_FILE."""
    def __init__(self, pattern: str):
        # SHARD_ORDER_FILE (one path per line) overrides sorted-glob ordering;
        # see generate_coprime_shard_order for how that file is produced.
        shard_order_file = os.environ.get("SHARD_ORDER_FILE", "")
        if shard_order_file and os.path.exists(shard_order_file):
            with open(shard_order_file) as f:
                self.files = [Path(line.strip()) for line in f if line.strip()]
        else:
            self.files = [Path(p) for p in sorted(glob.glob(pattern))]
        assert self.files, f"No files matching {pattern}"
        self.idx = 0            # index of the shard currently buffered
        self.buf = load_data_shard(self.files[self.idx])
        self.pos = 0            # read cursor within self.buf

    def _advance_file(self) -> None:
        # Wraps around modulo the shard list, so the stream is effectively infinite.
        self.idx = (self.idx + 1) % len(self.files)
        self.buf = load_data_shard(self.files[self.idx])
        self.pos = 0

    def take(self, n: int) -> Tensor:
        """Return the next n tokens, concatenating across shard boundaries."""
        parts = []
        remaining = n
        while remaining > 0:
            avail = self.buf.numel() - self.pos
            if avail <= 0:
                self._advance_file()
                continue
            take_n = min(avail, remaining)
            parts.append(self.buf[self.pos:self.pos + take_n])
            self.pos += take_n
            remaining -= take_n
        return torch.cat(parts)


class DistributedTokenLoader:
    """Feeds each rank a disjoint contiguous slice of a shared token stream.

    Every rank reads the same underlying stream (same files, same order) and
    advances it by the same amount per call, so slices stay aligned across
    ranks without communication.
    """
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.stream = TokenStream(pattern)
        self.rank = rank
        self.world_size = world_size
        self.device = device

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        # NOTE(review): grad_accum_steps is accepted but never used here —
        # micro-batch splitting happens in the caller. Confirm before removing.
        tokens_per_rank = global_tokens // self.world_size
        seqs_per_rank = tokens_per_rank // seq_len
        total_seqs = seqs_per_rank * self.world_size
        total_needed = total_seqs * seq_len + 1   # +1 so targets can be shifted by one
        all_tokens = self.stream.take(total_needed)
        start = self.rank * seqs_per_rank * seq_len
        chunk = all_tokens[start:start + seqs_per_rank * seq_len + 1]
        x = chunk[:-1].reshape(seqs_per_rank, seq_len)   # inputs
        y = chunk[1:].reshape(seqs_per_rank, seq_len)    # next-token targets
        return x.to(self.device), y.to(self.device)


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards, trimmed to a whole number of
    seq_len sequences plus one extra token for the shifted targets."""
    files = sorted(glob.glob(pattern))
    parts = [load_data_shard(Path(f)) for f in files]
    combined = torch.cat(parts)
    return combined[:((combined.numel() - 1) // seq_len) * seq_len + 1]


def
build_sentencepiece_luts(sp, vocab_size: int, device: torch.device):
    """Build per-token lookup tables used for bits-per-byte accounting.

    Returns (base_bytes, has_space, is_boundary):
      base_bytes  — UTF-8 byte length each token contributes when decoded
      has_space   — token begins with the SentencePiece U+2581 word marker
      is_boundary — control/unknown tokens (suppress the extra space byte)
    """
    base_bytes = torch.zeros(vocab_size, dtype=torch.float32, device=device)
    has_space = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    is_boundary = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    for i in range(vocab_size):
        piece = sp.id_to_piece(i)
        raw = piece.encode("utf-8")
        base_bytes[i] = len(raw)
        if piece.startswith("\u2581"):
            has_space[i] = True
            # Count the piece without the marker, plus one byte for the space itself.
            base_bytes[i] = len(piece[1:].encode("utf-8")) + 1
        if sp.is_control(i) or sp.is_unknown(i):
            is_boundary[i] = True
    return base_bytes, has_space, is_boundary


def generate_coprime_shard_order(shard_files: list, seed: int = 42) -> list:
    """Permute shards by stepping with a stride coprime to the shard count.

    The stride starts near n/φ (golden ratio) and is bumped until
    gcd(stride, n) == 1; n-1 is always coprime to n, so the loop terminates.
    A coprime stride guarantees every shard is visited exactly once.
    """
    n = len(shard_files)
    if n <= 1:
        return shard_files
    target = max(1, int(n / 1.618))
    stride = target
    while math.gcd(stride, n) != 1:
        stride += 1
    rng = random.Random(seed)
    start = rng.randint(0, n - 1)   # seeded random starting shard
    order = []
    pos = start
    for _ in range(n):
        order.append(shard_files[pos])
        pos = (pos + stride) % n
    return order


# ─── Muon Optimizer (MuonEq-R) ──────────────────────────────────────────────

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Newton-Schulz 5th-order iteration with MuonEq-R row normalization.

    MuonEq-R: row-normalize each gradient row before the Frobenius normalization.
    This makes the update equivariant to row-wise rescaling (~0.001 BPB gain
    observed in transformer competition experiments).
    """
    # Quintic NS coefficients; iteration runs in bfloat16 for speed.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # MuonEq-R: row-normalize before NS
    if X.ndim == 2:
        row_norms = X.norm(dim=1, keepdim=True).clamp_min(eps)
        X = X / row_norms
    X /= X.norm() + eps   # Frobenius normalization so the spectrum lies in NS's basin
    # Operate on the smaller Gram matrix: transpose tall matrices first.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X


class Muon(torch.optim.Optimizer):
    """Momentum SGD whose 2-D updates are orthogonalized via Newton-Schulz,
    with decoupled weight decay applied multiplicatively before the update."""
    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
                 nesterov: bool = True, weight_decay: float = 0.0):
        defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
                        nesterov=nesterov, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure=None):
        # NOTE(review): `closure` is accepted but never invoked; torch.optim
        # convention is to call it under enable_grad and return the loss.
        # Confirm callers never pass a closure before relying on it.
        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            wd = group.get("weight_decay", 0.0)
            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g)
                if nesterov:
                    g = g + momentum * buf   # Nesterov lookahead
                else:
                    g = buf
                # Only genuinely 2-D matrices are orthogonalized; vectors/scalars
                # and degenerate matrices fall through as plain momentum SGD.
                if g.ndim == 2 and min(g.shape) >= 2:
                    g = zeropower_via_newtonschulz5(g, steps=group["backend_steps"])
                if wd > 0:
                    p.data.mul_(1.0 - lr * wd)   # decoupled weight decay
                p.data.add_(g, alpha=-lr)


# ─── Evaluation ─────────────────────────────────────────────────────────────

def eval_val_sliding(
    model: nn.Module,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    rank: int,
    world_size: int,
    device: torch.device,
    seq_len: int = 2048,
    stride: int = 64,
    batch_seqs: int = 128,
    xsa_eval: bool = False,
) -> tuple[float, float]:
    """Standard sliding window evaluation.

    Returns (mean token NLL, bits-per-byte). Windows advance by `stride`;
    only the last `stride` tokens of each window are scored (the first window
    scores everything), so every token is scored exactly once with maximal
    left context.
    """
    total_tokens = val_tokens.numel() - 1
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= 1]
    total_windows = len(window_starts)
    # Contiguous block partition of windows across ranks.
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    # float64 accumulators so the all_reduce sums stay exact.
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    base_model = model.module if hasattr(model, 'module') else model   # unwrap DDP
    if xsa_eval and hasattr(base_model, 'set_xsa'):
        base_model.set_xsa(True)

    # Do NOT torch.compile here. FLA's GatedDeltaNet has integer `layer_idx`
    # attributes; dynamo treats each as a unique static guard and recompiles once
    # per layer (10 layers = 10 compilations). On a warm Triton cache this is
    # ~3s total. On a cold cache (fresh pod) it is ~107s — eating 18% of the
    # 590s budget and causing ~314 fewer training steps. FLA's Triton kernels
    # are already hand-optimized; there is nothing for dynamo to gain here.
    compiled_logits = base_model.forward_logits

    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi:bi + batch_seqs]
            bsz = len(batch_ws)
            # Zero-padded batch; per-window valid lengths tracked in wlens.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens = []
            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = compiled_logits(x_batch)

            # Per-token NLL; loss is computed in float32 for stability.
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)

            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # Score only the last `stride` tokens (full window for ws == 0).
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                loss_sum += scored_nll.sum()
                token_count += float(wlen - s)
                tgt = y_batch[i, s:wlen]
                prev = x_batch[i, s:wlen]
                # Byte count per target token: its UTF-8 length, plus one space
                # byte when it carries the word marker and the previous token is
                # not a boundary (control/unknown) token.
                tb = base_bytes_lut[tgt].to(torch.float64)
                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
                byte_count += tb.sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = (loss_sum / token_count).item()
    bits_per_token = val_loss / math.log(2.0)   # nats → bits
    tokens_per_byte = token_count.item() / byte_count.item()

    if xsa_eval and hasattr(base_model, 'set_xsa'):
        base_model.set_xsa(False)   # restore non-XSA mode for training

    model.train()
    # bits/token * tokens/byte = bits per byte (BPB).
    return val_loss, bits_per_token * tokens_per_byte


# ─── Quantization ───────────────────────────────────────────────────────────

# Parameters matching any of these substrings are kept in fp16 (never int-quantized);
# they are small per-layer control scalars/gains.
CONTROL_PATTERNS = (
    "resid_mix", "q_gain", "smear", "skip_weight", "attn_scale", "mlp_scale",
)


def generate_autoregressive_calib(model, device, num_seqs=64, seq_len=2048,
                                  vocab_size=1024, temperature=0.8, batch_size=8, seed=42):
    """Sample calibration sequences from the model itself for GPTQ Hessians."""
    # RoPE bug workaround: apply_rotary_emb uses x.shape[-2] (= num_heads=8) to slice cos.
    # When T < num_heads, cos[:num_heads] clips to [T, D//2] which fails to broadcast with
    # the head dimension. Fix: start with init_len >= num_heads+1 tokens to skip T in [2,8).
    init_len = 16
    # NOTE(review): model is switched to eval() and not restored to train mode
    # here — callers appear to use this only post-training; confirm.
    model.eval()
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)   # deterministic sampling for reproducible calibration
    all_tokens = []
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        for batch_start in range(0, num_seqs, batch_size):
            bs = min(batch_size, num_seqs - batch_start)
            # Random prompt of init_len tokens, then sample autoregressively.
            tokens = torch.randint(0, vocab_size, (bs, init_len), device=device, generator=rng)
            for pos in range(seq_len - init_len):
                logits = model.forward_logits(tokens)
                next_logit = logits[:, -1, :]
                probs = torch.softmax(next_logit / temperature, dim=-1)
                next_tok = torch.multinomial(probs, 1, generator=rng)
                tokens = torch.cat([tokens, next_tok], dim=1)
            for i in range(bs):
                all_tokens.append(tokens[i:i+1])   # keep each sequence as [1, seq_len]
    return all_tokens


def collect_hessians_from_tokens(hessian_model, token_seqs, device):
    """Accumulate GPTQ input Hessians (X^T X of layer inputs) per CastedLinear.

    Returns {param_name -> [cols, cols] float32 CPU tensor}, averaged over the
    number of calibration sequences.
    """
    hessians = {}
    hooks = []
    for name, module in hessian_model.named_modules():
        if isinstance(module, CastedLinear):
            param_name = name + ".weight"
            cols = module.weight.shape[1]
            hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device='cpu')
            # Factory binds param_name by value, avoiding the late-binding
            # closure pitfall inside this loop.
            def make_hook(pname):
                def hook_fn(module, input, output):
                    x = input[0].detach().float()
                    if x.ndim == 3:
                        x = x.reshape(-1, x.shape[-1])   # flatten (batch, seq) dims
                    hessians[pname] += (x.T @ x).cpu()
                return hook_fn
            h = module.register_forward_hook(make_hook(param_name))
            hooks.append(h)
    hessian_model.eval()
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        for seq in token_seqs:
            x = seq[:, :-1].to(device)
            y = seq[:, 1:].to(device)
            hessian_model(x, y)
    for h in hooks:
        h.remove()
    # NOTE(review): "num_batches" is actually the number of sequences; the
    # averaging is a uniform rescale either way.
    num_batches = len(token_seqs)
    for name in hessians:
        hessians[name] /= num_batches
    return hessians


def quantize_int6_gptq(weight, hessian=None, clip_range=31, block_size=128):
    """GPTQ int6 quantization (symmetric, per-row fp16 scales).

    Falls back to plain per-row rounding when no Hessian is available or the
    tensor is not 2-D. Returns (int8 tensor holding [-31, 31] codes, scales).
    """
    t32 = weight.float()
    if t32.ndim != 2 or hessian is None:
        return quantize_int6_per_row(t32)
    rows, cols = t32.shape
    H = hessian.float().clone()
    # Dead columns (never activated): give them unit curvature and zero weight.
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    damp
= 0.01 * torch.mean(torch.diag(H))   # standard GPTQ dampening
    H[torch.arange(cols), torch.arange(cols)] += damp
    # Act-order: process columns in order of decreasing diagonal curvature.
    perm = torch.argsort(torch.diag(H), descending=True)
    inv_perm = torch.argsort(perm)
    W = t32[:, perm].clone()
    W[:, dead[perm]] = 0
    H = H[perm][:, perm]
    # Upper-triangular Cholesky factor of H^-1 (the GPTQ error propagator).
    Hinv = torch.linalg.cholesky(H)
    Hinv = torch.cholesky_inverse(Hinv)
    Hinv = torch.linalg.cholesky(Hinv, upper=True)
    best_q = None; best_scale = None; best_err = float('inf')
    # Search clip percentiles; keep the variant with the lowest reconstruction MSE.
    for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
        if pct < 1.0:
            row_clip = torch.quantile(t32.abs(), pct, dim=1)
        else:
            row_clip = t32.abs().amax(dim=1)
        s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
        sf = s.float()
        Q = torch.zeros_like(W, dtype=torch.int8)
        W_work = W.clone()
        for i1 in range(0, cols, block_size):
            i2 = min(i1 + block_size, cols)
            count = i2 - i1
            W1 = W_work[:, i1:i2].clone()
            Q1 = torch.zeros(rows, count, dtype=torch.int8)
            Err1 = torch.zeros(rows, count)   # kept for cross-block propagation below
            Hinv1 = Hinv[i1:i2, i1:i2]
            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]
                q = torch.clamp(torch.round(w / sf), -clip_range, clip_range).to(torch.int8)
                Q1[:, i] = q
                # Propagate the quantization error into the remaining columns.
                err = (w - q.float() * sf) / d
                W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0)
                Err1[:, i] = err
            Q[:, i1:i2] = Q1
            if i2 < cols:
                W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:]
        recon = Q.float() * sf[:, None]
        mse = (W - recon).pow(2).mean().item()
        if mse < best_err:
            best_q, best_scale, best_err = Q, s, mse
    # Undo the act-order column permutation (scales are per-row, unaffected).
    best_q = best_q[:, inv_perm]
    return best_q, best_scale


def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
    """Symmetric int6 rounding: per-row fp16 scales for 2-D tensors (with the
    same clip-percentile search as GPTQ), per-tensor scale otherwise."""
    t32 = t.float()
    if t32.ndim == 2:
        best_q, best_s, best_err = None, None, float('inf')
        for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
            if pct < 1.0:
                row_clip = torch.quantile(t32.abs(), pct, dim=1)
            else:
                row_clip = t32.abs().amax(dim=1)
            s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16)
            q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8)
            recon = q.float() * s.float()[:, None]
            err = (t32 - recon).pow(2).mean().item()
            if err < best_err:
                best_q, best_s, best_err = q, s, err
        return best_q, best_s
    # Non-2-D: single per-tensor scale from the absolute max.
    amax = t32.abs().max().item()
    scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16)
    q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8)
    return q, scale


def quantize_int8_per_row(t: Tensor) -> tuple[Tensor, Tensor]:
    """Symmetric int8 rounding with a fixed high clip quantile (per-row scales
    for 2-D tensors, per-tensor scale otherwise)."""
    t32 = t.float()
    clip_q = 0.9999984
    if t32.ndim == 2:
        clip_abs = torch.quantile(t32.abs(), clip_q, dim=1) if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32)
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0).to(torch.float16)
        q = torch.clamp(torch.round(clipped / scale.float()[:, None]), -127, 127).to(torch.int8)
        return q, scale
    clip_abs = float(torch.quantile(t32.abs().flatten(), clip_q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float16)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale.float()), -127, 127).to(torch.int8)
    return q, scale


def mixed_quantize(state_dict: dict[str, Tensor], hessians: dict[str, Tensor] | None = None) -> tuple[dict[str, Tensor], dict[str, object]]:
    """Quantize a state dict with a per-tensor policy.

    Policy: control tensors, non-float tensors, and anything <= 64K elements
    pass through (floats as fp16); large 2-D matrices go int6 (GPTQ when a
    Hessian is available); everything else goes int8. Returns (packed tensors,
    per-name metadata) consumable by dequantize_mixed.
    """
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        if any(p in name for p in CONTROL_PATTERNS):
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        if not t.is_floating_point():
            result[name] = t
            meta[name] = "passthrough"
            continue
        if t.numel() <= 65536:
            result[name] = t.to(torch.float16)
            meta[name] = "passthrough"
            continue
        if t.ndim == 2 and t.numel() > 65536:
            H = hessians.get(name) if hessians else None
            q, s = quantize_int6_gptq(t, hessian=H) if H is not None else quantize_int6_per_row(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int6"}
        else:
            q, s = quantize_int8_per_row(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    return result, meta


def dequantize_mixed(result: dict[str, Tensor], meta: dict[str, object],
                     template_sd: dict[str, Tensor]) -> dict[str, Tensor]:
    """Reconstruct a float state dict from mixed_quantize output.

    template_sd supplies the target dtypes. NOTE(review): template entries
    absent from `meta` are silently skipped — the caller must tolerate a
    partial state dict (e.g. load with strict=False); confirm.
    """
    out: dict[str, Tensor] = {}
    for name, orig in template_sd.items():
        info = meta.get(name)
        if info is None:
            continue
        orig_dtype = orig.dtype
        if info == "passthrough":
            t = result[name]
            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig_dtype)
            out[name] = t
            continue
        q, s = result[name + ".q"], result[name + ".scale"]
        if s.ndim > 0:
            # Per-row scales: broadcast over all trailing dims.
            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
        else:
            # Per-tensor scalar scale.
            out[name] = (q.float() * float(s.item())).to(orig_dtype)
    return out


# ─── Checkpoint Saving ──────────────────────────────────────────────────────

def save_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed):
    """Save a lightweight (model-only) checkpoint; returns the file path."""
    base = model.module if hasattr(model, 'module') else model   # unwrap DDP
    ckpt = {
        "step": step, "val_bpb": val_bpb,
        "arch_name": arch_name, "seed": seed,
        "model_state_dict": base.state_dict(),
    }
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"{arch_name}_step{step}_seed{seed}.pt")
    torch.save(ckpt, path)
    return path


def save_full_checkpoint(model, step, val_bpb, ckpt_dir, arch_name, seed,
                         muon_opt, adam_opt, ema_state, swa_state, swa_count,
                         qat_enabled, rng_states=None, stream_state=None):
    """Save a resumable checkpoint: model, both optimizers, EMA/SWA averages,
    QAT flag, and optionally RNG + data-stream positions. Returns the path."""
    base = model.module if hasattr(model, 'module') else model   # unwrap DDP
    ckpt = {
        "step": step, "val_bpb": val_bpb,
        "arch_name": arch_name, "seed": seed,
        # Everything is moved to CPU so the checkpoint loads on any device.
        "model_state_dict": {k: v.cpu() for k, v in base.state_dict().items()},
        "muon_opt_state": muon_opt.state_dict(),
"adam_opt_state": adam_opt.state_dict(), + "ema_state": {k: v.cpu() for k, v in ema_state.items()}, + "swa_state": {k: v.cpu() for k, v in swa_state.items()} if swa_state is not None else None, + "swa_count": swa_count, + "qat_enabled": qat_enabled, + } + if rng_states is not None: + ckpt["rng_states"] = rng_states + if stream_state is not None: + ckpt["stream_state"] = stream_state + os.makedirs(ckpt_dir, exist_ok=True) + path = os.path.join(ckpt_dir, f"full_ckpt_step{step}_seed{seed}.pt") + torch.save(ckpt, path) + return path + + +def _find_latest_full_ckpt(ckpt_dir): + import re + pattern = os.path.join(ckpt_dir, "full_ckpt_step*_seed*.pt") + files = glob.glob(pattern) + if not files: + return None + step_re = re.compile(r"full_ckpt_step(\d+)_seed") + best_step, best_path = -1, None + for f in files: + m = step_re.search(os.path.basename(f)) + if m: + s = int(m.group(1)) + if s > best_step: + best_step, best_path = s, f + return best_path + + +# ─── Main Training Loop ───────────────────────────────────────────────────── + +def main(): + global zeropower_via_newtonschulz5 + args = Hyperparameters() + config = get_config(args.arch_mode) + + # Direction-5: optional QK_GAIN_INIT override from env + if args.qk_gain_init_override: + config["qk_gain_init"] = float(args.qk_gain_init_override) + + # Distributed setup + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + grad_accum_steps = max(1, 8 // world_size) + master_process = rank == 0 + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if args.compile_enabled: + 
zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master_process else None + + def log0(msg: str, console: bool = True): + if not master_process: + return + if console: + print(msg, flush=True) + if logfile: + with open(logfile, "a") as f: + print(msg, file=f) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + log0(f"=== Direction-5: GDN Hybrid Training ===") + log0(f"Arch: {config['arch_name']} (ARCH_MODE={args.arch_mode})") + log0(f"Seed: {args.seed}, Max steps: {args.iterations}, Warmdown: {args.warmdown_iters}") + log0(f"Train seq_len: {args.train_seq_len}, Wallclock budget: {args.max_wallclock_seconds}s") + log0(f"QK_GAIN_INIT: {config.get('qk_gain_init', 1.5)}") + log0(f"World size: {world_size}, Grad accum: {grad_accum_steps}") + log0(f"EMA decay: {args.ema_decay}, SWA: {args.swa_enabled} (every {args.swa_every})") + log0(f"Late QAT threshold: {args.late_qat_threshold}") + log0(f"MuonEq-R: enabled") + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + assert int(sp.vocab_size()) == args.vocab_size + + val_tokens = load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"Validation tokens: {val_tokens.numel()-1:,}") + + _t0 = time.time() + model = HybridGDN(config, args.vocab_size) + model = model.to(device).bfloat16() + log0(f"Model built in {time.time()-_t0:.1f}s") + + for module in model.modules(): + if isinstance(module, CastedLinear): + module.float() + for name, p in model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() + + param_counts = model.count_params() + log0(f"Parameters: {param_counts}") + log0(f"Total params: {param_counts['total']:,}") + + start_step = 0 + resume_state = None + resume_ckpt_path = 
args.resume_ckpt + if resume_ckpt_path == "auto": + resume_ckpt_path = _find_latest_full_ckpt(args.ckpt_dir) or "" + if resume_ckpt_path: + log0(f"Auto-detected resume checkpoint: {resume_ckpt_path}") + else: + log0("Auto-resume: no full checkpoint found, starting fresh") + if resume_ckpt_path and os.path.exists(resume_ckpt_path): + log0(f"Resuming from checkpoint: {resume_ckpt_path}") + ckpt = torch.load(resume_ckpt_path, map_location="cpu", weights_only=False) + base_sd = ckpt["model_state_dict"] + model.load_state_dict({k: v.to(device) for k, v in base_sd.items()}, strict=True) + start_step = ckpt.get("step", 0) + log0(f"Resumed model at step {start_step}, val_bpb={ckpt.get('val_bpb', 'N/A')}") + if "muon_opt_state" in ckpt: + resume_state = ckpt + log0(" Full checkpoint detected — will restore optimizers, EMA, SWA, RNG") + else: + log0(" Lightweight checkpoint — model only") + del ckpt + + base_model = model + if distributed: + model = DDP(model, device_ids=[local_rank], find_unused_parameters=False) + + matrix_params = [] + scalar_params = [] + embed_params = [] + for name, p in base_model.named_parameters(): + if not p.requires_grad: + continue + if "tok_emb" in name: + embed_params.append(p) + elif p.ndim == 2 and min(p.shape) >= 2: + matrix_params.append(p) + else: + scalar_params.append(p) + + log0(f"Matrix params: {sum(p.numel() for p in matrix_params):,}") + log0(f"Scalar params: {sum(p.numel() for p in scalar_params):,}") + log0(f"Embed params: {sum(p.numel() for p in embed_params):,}") + + muon_opt = Muon( + matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + adam_opt = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr}, + {"params": embed_params, "lr": args.tied_embed_lr}], + betas=(args.beta1, args.beta2), + weight_decay=args.adam_wd, + fused=True, + ) + + if resume_state is not None: + muon_opt.load_state_dict(resume_state["muon_opt_state"]) + 
adam_opt.load_state_dict(resume_state["adam_opt_state"]) + log0(" Restored optimizer states (Muon + Adam)") + + shard_order_file = os.environ.get("SHARD_ORDER_FILE", "") + if not shard_order_file: + shard_files = sorted(glob.glob(args.train_files)) + if shard_files: + ordered = generate_coprime_shard_order(shard_files, seed=args.seed) + # Use rank-specific path to avoid concurrent write race across 8 processes + shard_order_path = f"/tmp/shard_order_{args.run_id}_rank{rank}.txt" + with open(shard_order_path, "w") as f: + for sf in ordered: + f.write(str(sf) + "\n") + os.environ["SHARD_ORDER_FILE"] = shard_order_path + log0(f"Generated coprime shard order: stride across {len(shard_files)} shards") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def lr_schedule(step: int) -> float: + warmdown_start = args.iterations - args.warmdown_iters + if step < args.warmup_steps: + return step / max(1, args.warmup_steps) + elif step >= warmdown_start: + progress = (step - warmdown_start) / args.warmdown_iters + return 0.5 * (1.0 + math.cos(math.pi * progress)) + return 1.0 + + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + if resume_state is not None: + saved_ema = resume_state.get("ema_state") + if saved_ema is not None: + ema_state = {k: v.to(device).float() for k, v in saved_ema.items()} + log0(" Restored EMA state") + saved_swa = resume_state.get("swa_state") + if saved_swa is not None: + swa_state = {k: v.cpu() for k, v in saved_swa.items()} + swa_count = resume_state.get("swa_count", 0) + log0(f" Restored SWA state (count={swa_count})") + else: + swa_count = resume_state.get("swa_count", 0) + if resume_state.get("qat_enabled", False): + CastedLinear._qat_enabled = True + log0(" Restored QAT enabled state") + saved_rng = resume_state.get("rng_states") + if saved_rng is not None: + 
torch.set_rng_state(saved_rng["torch_cpu"]) + torch.cuda.set_rng_state(saved_rng["torch_cuda"]) + np.random.set_state(saved_rng["numpy"]) + random.setstate(saved_rng["python"]) + log0(" Restored RNG states") + saved_stream = resume_state.get("stream_state") + if saved_stream is not None: + s_idx, s_pos = saved_stream + stream = train_loader.stream + while stream.idx != s_idx: + stream._advance_file() + stream.pos = s_pos + log0(f" Restored stream state (shard={s_idx}, pos={s_pos})") + else: + if start_step > 0: + log0(f" Fast-forwarding data loader by {start_step} steps...") + for _ in range(start_step): + train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + log0(f" Data loader advanced to step {start_step}") + del resume_state + log0(" Full checkpoint restore complete") + + # ─── Training Loop ─────────────────────────────────────────────────── + stale_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(stale_marker): + os.remove(stale_marker) + + log0(f"\n{'='*80}") + log0(f"Starting training: max {args.iterations} steps (from step {start_step})") + log0(f"Wallclock budget: {args.max_wallclock_seconds}s") + log0(f"{'='*80}\n") + + t0 = time.time() + running_loss = 0.0 + loss_count = 0 + stop_after_step = None + step = start_step + + for step in range(start_step + 1, args.iterations + 1): + if stop_after_step is not None and step > stop_after_step: + log0(f"Stopping early at step {step} (wallclock limit)") + break + + lr_mul = lr_schedule(step) + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + current_muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in muon_opt.param_groups: + group["lr"] = args.matrix_lr * lr_mul + group["momentum"] = current_muon_momentum + for i, pg in enumerate(adam_opt.param_groups): + if i == 0: + pg["lr"] = args.scalar_lr * lr_mul + else: + pg["lr"] = 
args.tied_embed_lr * lr_mul + + warmdown_start = args.iterations - args.warmdown_iters + if (args.late_qat_threshold > 0 and step >= warmdown_start + and lr_mul < args.late_qat_threshold and not CastedLinear._qat_enabled): + CastedLinear._qat_enabled = True + log0(f"Late QAT enabled at step {step} (lr_mul={lr_mul:.4f})") + + model.train() + total_loss = 0.0 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + micro_batch = x.shape[0] // grad_accum_steps + for micro_step in range(grad_accum_steps): + x_micro = x[micro_step * micro_batch:(micro_step + 1) * micro_batch] + y_micro = y[micro_step * micro_batch:(micro_step + 1) * micro_batch] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x_micro, y_micro) + loss = loss / grad_accum_steps + loss.backward() + total_loss += loss.item() + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm) + + muon_opt.step() + adam_opt.step() + muon_opt.zero_grad(set_to_none=True) + adam_opt.zero_grad(set_to_none=True) + + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(args.ema_decay).add_(t.detach().float(), alpha=1.0 - args.ema_decay) + + if args.swa_enabled and lr_mul < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"SWA started at step {step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + running_loss += total_loss + loss_count += 1 + + if step % args.train_log_every == 0 or step <= 10: + avg_loss = running_loss / max(loss_count, 1) + elapsed = time.time() - t0 + steps_per_sec = step / elapsed + log0(f"step {step:5d}/{args.iterations} | loss {avg_loss:.4f} | lr_mul {lr_mul:.4f} | " + f"mom {current_muon_momentum:.3f} | {steps_per_sec:.2f} steps/s | 
{elapsed:.0f}s") + running_loss = 0.0 + loss_count = 0 + + if step % args.val_loss_every == 0 or step == args.iterations: + val_loss, val_bpb = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=args.xsa_eval, + ) + log0(f"step {step:5d} | val_loss {val_loss:.4f} | val_bpb {val_bpb:.4f}") + + if master_process and args.save_every > 0 and (step % args.save_every == 0 or step == args.iterations): + ckpt_path = save_checkpoint( + model, step, val_bpb, args.ckpt_dir, config["arch_name"], args.seed, + ) + log0(f" Saved: {ckpt_path}") + + if args.max_wallclock_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.max_wallclock_seconds and stop_after_step is None: + stop_after_step = step + log0(f"Wallclock limit reached ({elapsed:.0f}s), will stop after this step") + + if args.auto_save_seconds > 0: + elapsed = time.time() - t0 + if elapsed > args.auto_save_seconds: + log0(f"Auto-save triggered at step {step} ({elapsed:.0f}s elapsed)") + if master_process: + rng_states = { + "torch_cpu": torch.get_rng_state(), + "torch_cuda": torch.cuda.get_rng_state(), + "numpy": np.random.get_state(), + "python": random.getstate(), + } + stream = train_loader.stream + stream_state = (stream.idx, stream.pos) + ckpt_path = save_full_checkpoint( + model, step, 0.0, args.ckpt_dir, config["arch_name"], args.seed, + muon_opt, adam_opt, ema_state, swa_state, swa_count, + CastedLinear._qat_enabled, + rng_states=rng_states, stream_state=stream_state, + ) + marker_path = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + with open(marker_path, "w") as f: + f.write(ckpt_path + "\n") + log0(f" Full checkpoint saved: {ckpt_path}") + break + + # ─── Check if exited due to auto-save ──────────────────────────────── + chain_marker = os.path.join(args.ckpt_dir, f"CHAIN_RESUME_FROM_seed{args.seed}") + if os.path.exists(chain_marker): + 
log0("\nExiting for chained job resume (skipping post-training)") + if distributed: + dist.destroy_process_group() + return + + effective_total = args.total_iterations if args.total_iterations > 0 else args.iterations + if master_process and step >= effective_total: + complete_marker = os.path.join(args.ckpt_dir, f"TRAINING_COMPLETE_seed{args.seed}") + with open(complete_marker, "w") as f: + f.write(f"step={step}\n") + + # ─── Post-Training: Apply EMA ──────────────────────────────────────── + elapsed_total = time.time() - t0 + log0(f"\nTraining complete in {elapsed_total:.0f}s ({step} steps)") + log0(f"Peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB") + log0(f"Steps/sec: {step / elapsed_total:.2f}") + + log0("\n=== Applying EMA weights ===") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) for name, t in ema_state.items()} + if swa_state is not None and swa_count > 0: + log0(f"SWA: averaging {swa_count} checkpoints with EMA") + swa_avg = {k: v / swa_count for k, v in swa_state.items()} + for name in avg_state: + if name in swa_avg: + dtype = avg_state[name].dtype + avg_state[name] = (0.5 * avg_state[name].float() + 0.5 * swa_avg[name].float()).to(dtype) + + base_model.load_state_dict(avg_state, strict=True) + + val_loss_ema, val_bpb_ema = eval_val_sliding( + model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + ) + log0(f"EMA BPB (no XSA): {val_bpb_ema:.6f}") + + if master_process: + torch.save(base_model.state_dict(), os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.pt")) + log0("Saved raw EMA model") + + # ─── GPTQ Calibration (optional) ───────────────────────────────────── + gptq_enabled = bool(int(os.environ.get("GPTQ_ENABLED", "0"))) + hessians = None + if gptq_enabled: + log0("\n=== GPTQ: generating autoregressive calibration data ===") + calib_seqs = 
generate_autoregressive_calib( + base_model, device, num_seqs=64, seq_len=args.train_seq_len, + vocab_size=args.vocab_size, temperature=0.8, batch_size=8, seed=args.seed, + ) + log0(f"GPTQ: generated {len(calib_seqs)} sequences, collecting hessians...") + hessians = collect_hessians_from_tokens(base_model, calib_seqs, device) + log0(f"GPTQ: collected hessians for {len(hessians)} layers") + + # ─── Quantization + Artifact Creation ──────────────────────────────── + log0("\n=== Quantizing to int6 + zstd-22 ===") + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize(sd_cpu, hessians=hessians) + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + + artifact_path = os.path.join(args.ckpt_dir, f"final_model_{config['arch_name']}_seed{args.seed}.int6.ptz") + if master_process: + with open(artifact_path, "wb") as f: + f.write(quant_blob) + artifact_bytes = len(quant_blob) + log0(f"Artifact: {artifact_bytes:,} bytes ({artifact_bytes / 1024 / 1024:.2f} MB)") + if artifact_bytes > 16 * 1024 * 1024: + log0(f"WARNING: Artifact exceeds 16MB budget by {(artifact_bytes - 16*1024*1024) / 1024:.1f} KB") + + # ─── Roundtrip Validation ──────────────────────────────────────────── + log0("\n=== Roundtrip Validation (quantized model) ===") + if distributed: + dist.barrier() + + with open(artifact_path, "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = HybridGDN(config, args.vocab_size).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + for name, p in eval_model.named_parameters(): + if p.ndim <= 1: + p.data = p.data.float() 
+ eval_model.load_state_dict(deq_state, strict=True) + + val_loss_q, val_bpb_q = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=False, + ) + log0(f"Quantized BPB (no XSA): {val_bpb_q:.6f}") + log0(f"Quantization degradation: {val_bpb_q - val_bpb_ema:+.6f}") + + block_types = eval_model._block_types + if any(bt in ("swa", "swa_shared") for bt in block_types): + val_loss_qx, val_bpb_qx = eval_val_sliding( + eval_model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + rank, world_size, device, + seq_len=args.eval_seq_len, stride=args.eval_stride, + xsa_eval=True, + ) + log0(f"Quantized BPB (XSA-all): {val_bpb_qx:.6f}") + + log0(f"\n{'='*80}") + log0(f"FINAL RESULTS — {config['arch_name']} seed={args.seed}") + log0(f" Training: {step} steps, {elapsed_total:.0f}s") + log0(f" EMA BPB: {val_bpb_ema:.6f}") + log0(f" Quantized BPB: {val_bpb_q:.6f}") + if any(bt in ("swa", "swa_shared") for bt in block_types): + log0(f" XSA BPB: {val_bpb_qx:.6f}") + log0(f" Artifact: {artifact_path}") + log0(f"{'='*80}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_seed1337.log b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_seed1337.log new file mode 100644 index 0000000000..17b78403a1 --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_seed1337.log @@ -0,0 +1,90 @@ +W0411 17:07:52.727000 707551 torch/distributed/run.py:851] +W0411 17:07:52.727000 707551 torch/distributed/run.py:851] ***************************************** +W0411 17:07:52.727000 707551 torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the 
variable for optimal performance in your application as needed. +W0411 17:07:52.727000 707551 torch/distributed/run.py:851] ***************************************** +=== Direction-5: GDN Hybrid Training === +Arch: D_GDN_Hybrid (ARCH_MODE=D) +Seed: 1337, Max steps: 9999, Warmdown: 3000 +Train seq_len: 2048, Wallclock budget: 590.0s +QK_GAIN_INIT: 5.0 +World size: 8, Grad accum: 1 +EMA decay: 0.997, SWA: True (every 50) +Late QAT threshold: 0.15 +MuonEq-R: enabled +Validation tokens: 62,021,632 +Model built in 0.2s +Parameters: {'embedding': 925697, 'recurrent': 13251360, 'attention': 792584, 'mlp': 18880512, 'other': 12800, 'total': 33862953} +Total params: 33,862,953 +Matrix params: 33,263,616 +Scalar params: 75,049 +Embed params: 524,288 +Generated coprime shard order: stride across 80 shards + +================================================================================ +Starting training: max 9999 steps (from step 0) +Wallclock budget: 590.0s +================================================================================ + +step 1/9999 | loss 6.9316 | lr_mul 0.0500 | mom 0.850 | 0.36 steps/s | 3s +step 2/9999 | loss 6.7037 | lr_mul 0.1000 | mom 0.850 | 0.66 steps/s | 3s +step 3/9999 | loss 6.1494 | lr_mul 0.1500 | mom 0.851 | 0.91 steps/s | 3s +step 4/9999 | loss 5.8673 | lr_mul 0.2000 | mom 0.851 | 1.12 steps/s | 4s +step 5/9999 | loss 5.8454 | lr_mul 0.2500 | mom 0.851 | 1.30 steps/s | 4s +step 6/9999 | loss 5.7918 | lr_mul 0.3000 | mom 0.851 | 1.46 steps/s | 4s +step 7/9999 | loss 5.7141 | lr_mul 0.3500 | mom 0.851 | 1.60 steps/s | 4s +step 8/9999 | loss 5.7341 | lr_mul 0.4000 | mom 0.852 | 1.73 steps/s | 5s +step 9/9999 | loss 5.6277 | lr_mul 0.4500 | mom 0.852 | 1.84 steps/s | 5s +step 10/9999 | loss 5.5282 | lr_mul 0.5000 | mom 0.852 | 1.94 steps/s | 5s +step 100/9999 | loss 3.6789 | lr_mul 1.0000 | mom 0.870 | 3.48 steps/s | 29s +step 200/9999 | loss 2.7071 | lr_mul 1.0000 | mom 0.890 | 3.61 steps/s | 55s +step 300/9999 | loss 2.4812 | lr_mul 
1.0000 | mom 0.910 | 3.68 steps/s | 82s +step 400/9999 | loss 2.3836 | lr_mul 1.0000 | mom 0.930 | 3.71 steps/s | 108s +step 500/9999 | loss 2.3140 | lr_mul 1.0000 | mom 0.950 | 3.73 steps/s | 134s +step 600/9999 | loss 2.2772 | lr_mul 1.0000 | mom 0.950 | 3.75 steps/s | 160s +step 700/9999 | loss 2.2463 | lr_mul 1.0000 | mom 0.950 | 3.76 steps/s | 186s +step 800/9999 | loss 2.2272 | lr_mul 1.0000 | mom 0.950 | 3.76 steps/s | 213s +step 900/9999 | loss 2.2331 | lr_mul 1.0000 | mom 0.950 | 3.77 steps/s | 239s +step 1000/9999 | loss 2.1977 | lr_mul 1.0000 | mom 0.950 | 3.77 steps/s | 265s +step 1100/9999 | loss 2.1808 | lr_mul 1.0000 | mom 0.950 | 3.78 steps/s | 291s +step 1200/9999 | loss 2.1808 | lr_mul 1.0000 | mom 0.950 | 3.78 steps/s | 317s +step 1300/9999 | loss 2.1580 | lr_mul 1.0000 | mom 0.950 | 3.78 steps/s | 344s +step 1400/9999 | loss 2.1307 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 370s +step 1500/9999 | loss 2.1267 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 396s +step 1600/9999 | loss 2.1344 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 422s +step 1700/9999 | loss 2.1253 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 448s +step 1800/9999 | loss 2.1133 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 475s +step 1900/9999 | loss 2.1171 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 501s +step 2000/9999 | loss 2.1097 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 527s +step 2100/9999 | loss 2.0967 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 554s +step 2200/9999 | loss 2.0996 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 580s +Wallclock limit reached (590s), will stop after this step +Stopping early at step 2239 (wallclock limit) + +Training complete in 590s (2239 steps) +Peak memory: 35748 MiB +Steps/sec: 3.79 + +=== Applying EMA weights === +EMA BPB (no XSA): 1.007375 +Saved raw EMA model + +=== GPTQ: generating autoregressive calibration data === +GPTQ: generated 64 sequences, collecting hessians... 
+GPTQ: collected hessians for 29 layers + +=== Quantizing to int6 + zstd-22 === +Artifact: 15,830,308 bytes (15.10 MB) + +=== Roundtrip Validation (quantized model) === +Quantized BPB (no XSA): 1.016586 +Quantization degradation: +0.009211 +Quantized BPB (XSA-all): 1.020691 + +================================================================================ +FINAL RESULTS — D_GDN_Hybrid seed=1337 + Training: 2239 steps, 590s + EMA BPB: 1.007375 + Quantized BPB: 1.016586 + XSA BPB: 1.020691 + Artifact: /root/pg-repo/records/track_10min_16mb/2026-04-11_JM_GDN_Hybrid_DeltaRule/checkpoints/final_model_D_GDN_Hybrid_seed1337.int6.ptz +================================================================================ diff --git a/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_seed2024.log b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_seed2024.log new file mode 100644 index 0000000000..e9efe8f4ef --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_seed2024.log @@ -0,0 +1,90 @@ +W0411 17:34:30.068000 738543 torch/distributed/run.py:851] +W0411 17:34:30.068000 738543 torch/distributed/run.py:851] ***************************************** +W0411 17:34:30.068000 738543 torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0411 17:34:30.068000 738543 torch/distributed/run.py:851] ***************************************** +=== Direction-5: GDN Hybrid Training === +Arch: D_GDN_Hybrid (ARCH_MODE=D) +Seed: 2024, Max steps: 9999, Warmdown: 3000 +Train seq_len: 2048, Wallclock budget: 590.0s +QK_GAIN_INIT: 5.0 +World size: 8, Grad accum: 1 +EMA decay: 0.997, SWA: True (every 50) +Late QAT threshold: 0.15 +MuonEq-R: enabled +Validation tokens: 62,021,632 +Model built in 0.2s +Parameters: {'embedding': 925697, 'recurrent': 13251360, 'attention': 792584, 'mlp': 18880512, 'other': 12800, 'total': 33862953} +Total params: 33,862,953 +Matrix params: 33,263,616 +Scalar params: 75,049 +Embed params: 524,288 +Generated coprime shard order: stride across 80 shards + +================================================================================ +Starting training: max 9999 steps (from step 0) +Wallclock budget: 590.0s +================================================================================ + +step 1/9999 | loss 6.9320 | lr_mul 0.0500 | mom 0.850 | 0.36 steps/s | 3s +step 2/9999 | loss 6.7120 | lr_mul 0.1000 | mom 0.850 | 0.66 steps/s | 3s +step 3/9999 | loss 6.1764 | lr_mul 0.1500 | mom 0.851 | 0.91 steps/s | 3s +step 4/9999 | loss 5.8613 | lr_mul 0.2000 | mom 0.851 | 1.12 steps/s | 4s +step 5/9999 | loss 5.8159 | lr_mul 0.2500 | mom 0.851 | 1.31 steps/s | 4s +step 6/9999 | loss 5.7956 | lr_mul 0.3000 | mom 0.851 | 1.47 steps/s | 4s +step 7/9999 | loss 5.7609 | lr_mul 0.3500 | mom 0.851 | 1.61 steps/s | 4s +step 8/9999 | loss 5.7081 | lr_mul 0.4000 | mom 0.852 | 1.74 steps/s | 5s +step 9/9999 | loss 5.6192 | lr_mul 0.4500 | mom 0.852 | 1.85 steps/s | 5s +step 10/9999 | loss 5.5308 | lr_mul 0.5000 | mom 0.852 | 1.95 steps/s | 5s +step 100/9999 | loss 3.6523 | lr_mul 1.0000 | mom 0.870 | 3.49 steps/s | 29s +step 200/9999 | loss 2.7056 | lr_mul 1.0000 | mom 0.890 | 3.62 steps/s | 55s +step 300/9999 | loss 2.4972 | lr_mul 1.0000 | mom 0.910 | 3.67 steps/s | 82s +step 400/9999 | loss 2.4008 
| lr_mul 1.0000 | mom 0.930 | 3.70 steps/s | 108s +step 500/9999 | loss 2.3438 | lr_mul 1.0000 | mom 0.950 | 3.73 steps/s | 134s +step 600/9999 | loss 2.3031 | lr_mul 1.0000 | mom 0.950 | 3.74 steps/s | 160s +step 700/9999 | loss 2.2505 | lr_mul 1.0000 | mom 0.950 | 3.75 steps/s | 187s +step 800/9999 | loss 2.2285 | lr_mul 1.0000 | mom 0.950 | 3.76 steps/s | 213s +step 900/9999 | loss 2.2177 | lr_mul 1.0000 | mom 0.950 | 3.77 steps/s | 239s +step 1000/9999 | loss 2.1924 | lr_mul 1.0000 | mom 0.950 | 3.77 steps/s | 265s +step 1100/9999 | loss 2.1775 | lr_mul 1.0000 | mom 0.950 | 3.78 steps/s | 291s +step 1200/9999 | loss 2.1650 | lr_mul 1.0000 | mom 0.950 | 3.78 steps/s | 318s +step 1300/9999 | loss 2.1552 | lr_mul 1.0000 | mom 0.950 | 3.78 steps/s | 344s +step 1400/9999 | loss 2.1446 | lr_mul 1.0000 | mom 0.950 | 3.78 steps/s | 370s +step 1500/9999 | loss 2.1311 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 396s +step 1600/9999 | loss 2.1357 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 422s +step 1700/9999 | loss 2.1499 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 448s +step 1800/9999 | loss 2.1116 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 475s +step 1900/9999 | loss 2.1249 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 501s +step 2000/9999 | loss 2.1299 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 527s +step 2100/9999 | loss 2.1175 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 554s +step 2200/9999 | loss 2.1043 | lr_mul 1.0000 | mom 0.950 | 3.79 steps/s | 580s +Wallclock limit reached (590s), will stop after this step +Stopping early at step 2241 (wallclock limit) + +Training complete in 590s (2241 steps) +Peak memory: 35748 MiB +Steps/sec: 3.80 + +=== Applying EMA weights === +EMA BPB (no XSA): 1.008736 +Saved raw EMA model + +=== GPTQ: generating autoregressive calibration data === +GPTQ: generated 64 sequences, collecting hessians... 
+GPTQ: collected hessians for 29 layers + +=== Quantizing to int6 + zstd-22 === +Artifact: 15,820,201 bytes (15.09 MB) + +=== Roundtrip Validation (quantized model) === +Quantized BPB (no XSA): 1.017995 +Quantization degradation: +0.009259 +Quantized BPB (XSA-all): 1.023138 + +================================================================================ +FINAL RESULTS — D_GDN_Hybrid seed=2024 + Training: 2241 steps, 590s + EMA BPB: 1.008736 + Quantized BPB: 1.017995 + XSA BPB: 1.023138 + Artifact: /root/pg-repo/records/track_10min_16mb/2026-04-11_JM_GDN_Hybrid_DeltaRule/checkpoints/final_model_D_GDN_Hybrid_seed2024.int6.ptz +================================================================================ diff --git a/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_seed42.log b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_seed42.log new file mode 100644 index 0000000000..a2a0681cd7 --- /dev/null +++ b/records/track_10min_16mb/2026-04-11_GDN_Hybrid_DeltaRule_1.0205/train_seed42.log @@ -0,0 +1,86 @@ +W0411 16:41:03.568000 660711 torch/distributed/run.py:851] +W0411 16:41:03.568000 660711 torch/distributed/run.py:851] ***************************************** +W0411 16:41:03.568000 660711 torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0411 16:41:03.568000 660711 torch/distributed/run.py:851] ***************************************** +=== Direction-5: GDN Hybrid Training === +Arch: D_GDN_Hybrid (ARCH_MODE=D) +Seed: 42, Max steps: 9999, Warmdown: 3000 +Train seq_len: 2048, Wallclock budget: 590.0s +QK_GAIN_INIT: 5.0 +World size: 8, Grad accum: 1 +EMA decay: 0.997, SWA: True (every 50) +Late QAT threshold: 0.15 +MuonEq-R: enabled +Validation tokens: 62,021,632 +Model built in 0.2s +Parameters: {'embedding': 925697, 'recurrent': 13251360, 'attention': 792584, 'mlp': 18880512, 'other': 12800, 'total': 33862953} +Total params: 33,862,953 +Matrix params: 33,263,616 +Scalar params: 75,049 +Embed params: 524,288 +Generated coprime shard order: stride across 80 shards + +================================================================================ +Starting training: max 9999 steps (from step 0) +Wallclock budget: 590.0s +================================================================================ + +step 1/9999 | loss 6.9368 | lr_mul 0.0500 | mom 0.850 | 0.01 steps/s | 97s +step 2/9999 | loss 6.7108 | lr_mul 0.1000 | mom 0.850 | 0.02 steps/s | 98s +step 3/9999 | loss 6.1586 | lr_mul 0.1500 | mom 0.851 | 0.03 steps/s | 98s +step 4/9999 | loss 5.8826 | lr_mul 0.2000 | mom 0.851 | 0.04 steps/s | 98s +step 5/9999 | loss 5.8142 | lr_mul 0.2500 | mom 0.851 | 0.05 steps/s | 98s +step 6/9999 | loss 5.7684 | lr_mul 0.3000 | mom 0.851 | 0.06 steps/s | 99s +step 7/9999 | loss 5.7302 | lr_mul 0.3500 | mom 0.851 | 0.07 steps/s | 99s +step 8/9999 | loss 5.7047 | lr_mul 0.4000 | mom 0.852 | 0.08 steps/s | 99s +step 9/9999 | loss 5.6385 | lr_mul 0.4500 | mom 0.852 | 0.09 steps/s | 100s +step 10/9999 | loss 5.5581 | lr_mul 0.5000 | mom 0.852 | 0.10 steps/s | 100s +step 100/9999 | loss 3.6623 | lr_mul 1.0000 | mom 0.870 | 0.81 steps/s | 123s +step 200/9999 | loss 2.7004 | lr_mul 1.0000 | mom 0.890 | 1.33 steps/s | 150s +step 300/9999 | loss 2.5017 | lr_mul 1.0000 | mom 0.910 | 1.70 steps/s | 176s +step 400/9999 | 
loss 2.4038 | lr_mul 1.0000 | mom 0.930 | 1.97 steps/s | 203s +step 500/9999 | loss 2.3427 | lr_mul 1.0000 | mom 0.950 | 2.18 steps/s | 230s +step 600/9999 | loss 2.3079 | lr_mul 1.0000 | mom 0.950 | 2.34 steps/s | 256s +step 700/9999 | loss 2.2690 | lr_mul 1.0000 | mom 0.950 | 2.47 steps/s | 283s +step 800/9999 | loss 2.2461 | lr_mul 1.0000 | mom 0.950 | 2.59 steps/s | 309s +step 900/9999 | loss 2.2007 | lr_mul 1.0000 | mom 0.950 | 2.68 steps/s | 336s +step 1000/9999 | loss 2.1896 | lr_mul 1.0000 | mom 0.950 | 2.76 steps/s | 362s +step 1100/9999 | loss 2.1650 | lr_mul 1.0000 | mom 0.950 | 2.83 steps/s | 389s +step 1200/9999 | loss 2.1624 | lr_mul 1.0000 | mom 0.950 | 2.89 steps/s | 415s +step 1300/9999 | loss 2.1447 | lr_mul 1.0000 | mom 0.950 | 2.94 steps/s | 442s +step 1400/9999 | loss 2.1363 | lr_mul 1.0000 | mom 0.950 | 2.99 steps/s | 468s +step 1500/9999 | loss 2.1254 | lr_mul 1.0000 | mom 0.950 | 3.03 steps/s | 494s +step 1600/9999 | loss 2.1337 | lr_mul 1.0000 | mom 0.950 | 3.07 steps/s | 521s +step 1700/9999 | loss 2.1216 | lr_mul 1.0000 | mom 0.950 | 3.11 steps/s | 547s +step 1800/9999 | loss 2.1112 | lr_mul 1.0000 | mom 0.950 | 3.14 steps/s | 574s +Wallclock limit reached (590s), will stop after this step +Stopping early at step 1864 (wallclock limit) + +Training complete in 590s (1864 steps) +Peak memory: 35750 MiB +Steps/sec: 3.16 + +=== Applying EMA weights === +EMA BPB (no XSA): 1.017723 +Saved raw EMA model + +=== GPTQ: generating autoregressive calibration data === +GPTQ: generated 64 sequences, collecting hessians... 
+GPTQ: collected hessians for 29 layers + +=== Quantizing to int6 + zstd-22 === +Artifact: 15,313,984 bytes (14.60 MB) + +=== Roundtrip Validation (quantized model) === +Quantized BPB (no XSA): 1.026791 +Quantization degradation: +0.009067 +Quantized BPB (XSA-all): 1.031731 + +================================================================================ +FINAL RESULTS — D_GDN_Hybrid seed=42 + Training: 1864 steps, 590s + EMA BPB: 1.017723 + Quantized BPB: 1.026791 + XSA BPB: 1.031731 + Artifact: /root/pg-repo/records/track_10min_16mb/2026-04-11_JM_GDN_Hybrid_DeltaRule/checkpoints/final_model_D_GDN_Hybrid_seed42.int6.ptz +================================================================================