Skip to content

[Feature] Implement BailingMoeV2_5LinearAttention (Simple GLA) #153

@sii-xinglong

Description

@sii-xinglong

目标

实现 Lightning Linear Attention 层 (BailingMoeV2_5LinearAttention),集成 tops.ops.simple_gla.chunk_simple_gla_fwd Pallas kernel。

架构

  • Fused QKV 投影: hidden_size(8192) → 3 * num_heads * head_dim
  • QK Norm: 每 head 独立 RMSNorm
  • Partial RoPE: rotary_dim=64(head_dim 的前半部分),NeoX style
  • Simple GLA: chunk_simple_gla_fwd(q, k, v, g_gamma=slopes, h0=state, use_ht=True)
  • ALiBi 衰减: g_gamma = -base_slopes * (1 - (layer_idx-1)/(num_layers-1) + 1e-5)
  • 输出 gating: GroupRMSNorm(o) * sigmoid(g_proj(x))
  • Recurrent state: 常量大小 [B, H, K, V],跨 decode step 传递

约束

  • KV 必须是 128 的倍数
  • seq_len 必须是 chunk_size(64) 的倍数(不足时 pad)
  • Prefill 和 decode 均使用 chunk 模式(fused recurrent 为后续优化)

验收标准

  • BailingMoeV2_5LinearAttention 实现并可构造
  • 前向输出形状正确 [batch_seq, hidden_size]
  • Recurrent state 正确更新(跨调用 state 不同)
  • 单元测试通过

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions