目标
实现 Lightning Linear Attention 层 (BailingMoeV2_5LinearAttention),集成 tops.ops.simple_gla.chunk_simple_gla_fwd Pallas kernel。
架构
- Fused QKV 投影:
hidden_size(8192) → 3 * num_heads * head_dim
- QK Norm: 每 head 独立 RMSNorm
- Partial RoPE:
rotary_dim=64(head_dim 的前半部分),NeoX style
- Simple GLA:
chunk_simple_gla_fwd(q, k, v, g_gamma=slopes, h0=state, use_ht=True)
- ALiBi 衰减:
g_gamma = -base_slopes * (1 - (layer_idx-1)/(num_layers-1) + 1e-5)
- 输出 gating:
GroupRMSNorm(o) * sigmoid(g_proj(x))
- Recurrent state: 常量大小
[B, H, K, V],跨 decode step 传递
约束
K 和 V 必须是 128 的倍数
seq_len 必须是 chunk_size(64) 的倍数(不足时 pad)
- Prefill 和 decode 均使用 chunk 模式(fused recurrent 为后续优化)
验收标准
目标
实现 Lightning Linear Attention 层 (
BailingMoeV2_5LinearAttention),集成tops.ops.simple_gla.chunk_simple_gla_fwdPallas kernel。架构
hidden_size(8192) → 3 * num_heads * head_dimrotary_dim=64(head_dim 的前半部分),NeoX stylechunk_simple_gla_fwd(q, k, v, g_gamma=slopes, h0=state, use_ht=True)g_gamma = -base_slopes * (1 - (layer_idx-1)/(num_layers-1) + 1e-5)GroupRMSNorm(o) * sigmoid(g_proj(x))[B, H, K, V],跨 decode step 传递约束
K和V必须是 128 的倍数seq_len必须是chunk_size(64)的倍数(不足时 pad)验收标准
BailingMoeV2_5LinearAttention实现并可构造[batch_seq, hidden_size]