Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
# - fla/ops/gated_delta_rule/wy_fast.py
#
# openinfer-specific changes:
# - fixed Qwen3.5 shapes (batch=1, H=32, K=128, V=128, chunk_size=64)
# - fixed Qwen3.5 dims (batch=1, K=128, V=128, chunk_size=64); head count is a runtime arg
# - Triton AOT-friendly surface and wrapper contracts
# - no backward / varlen / generic autotune surface
# - decode-compatible final-state layout contract [H, V, K]
# - fused prepare stage for q/k expansion, q/k normalization, and g/beta generation


QWEN35_GDR_HEADS = 32
QWEN35_GDR_CHUNK_SIZE = 64
QWEN35_GDR_KEY_DIM = 128
QWEN35_GDR_VALUE_DIM = 128
Expand Down
61 changes: 54 additions & 7 deletions openinfer-qwen35-4b/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ pub(crate) struct Config35 {
pub(crate) layer_types: Vec<LayerType>,
}

/// GDN dims the Triton-AOT kernels are built for; a mismatched model is rejected at load.
/// Head dims baked into the kernels; head counts are runtime parameters.
const GDN_AOT_KEY_HEAD_DIM: usize = 128;
const GDN_AOT_VALUE_HEAD_DIM: usize = 128;
const GDN_AOT_NUM_VALUE_HEADS: usize = 32;
const FULL_ATTN_HEAD_DIM: usize = 256;

impl Config35 {
pub(crate) fn from_file(model_path: &str) -> Result<Self> {
Expand Down Expand Up @@ -116,16 +116,28 @@ impl Config35 {

anyhow::ensure!(
t.linear_key_head_dim == GDN_AOT_KEY_HEAD_DIM
&& t.linear_value_head_dim == GDN_AOT_VALUE_HEAD_DIM
&& t.linear_num_value_heads == GDN_AOT_NUM_VALUE_HEADS,
"Qwen3.5 GDN Triton-AOT kernels are baked for key/value head dim {}/{}, \
{} value heads; config has {}/{}, {}. Rebuild openinfer-kernels to match.",
&& t.linear_value_head_dim == GDN_AOT_VALUE_HEAD_DIM,
"Qwen3.5 GDN Triton-AOT kernels are baked for key/value head dim {}/{}; \
config has {}/{} (dims are baked into the AOT signatures in openinfer-kernels/build.rs).",
GDN_AOT_KEY_HEAD_DIM,
GDN_AOT_VALUE_HEAD_DIM,
GDN_AOT_NUM_VALUE_HEADS,
t.linear_key_head_dim,
t.linear_value_head_dim,
);
anyhow::ensure!(
t.head_dim == FULL_ATTN_HEAD_DIM,
"Qwen3.5 full-attention kernels are baked for head_dim {}; config has {}.",
FULL_ATTN_HEAD_DIM,
t.head_dim,
);
anyhow::ensure!(
t.linear_num_key_heads > 0
&& t.linear_num_value_heads
.is_multiple_of(t.linear_num_key_heads),
"Qwen3.5 GDN kernels require linear_num_value_heads ({}) divisible by \
linear_num_key_heads ({})",
t.linear_num_value_heads,
t.linear_num_key_heads,
);

Ok(Self {
Expand Down Expand Up @@ -186,3 +198,38 @@ impl Config35 {
self.linear_num_value_heads * self.linear_value_head_dim
}
}

#[cfg(test)]
mod tests {
    use super::Config35;

    /// Loading a config with 48 linear value heads and 16 linear key heads must succeed.
    ///
    /// The fixture keeps the linear key/value head dims at 128 and `head_dim` at 256 and
    /// only varies the head counts, then asserts that `Config35::from_file` returns `Ok`.
    #[test]
    fn guard_accepts_48_value_heads() {
        // A per-process directory name keeps concurrent test binaries from sharing a fixture.
        let dir = std::env::temp_dir().join(format!("qwen35-config-guard-{}", std::process::id()));
        std::fs::create_dir_all(&dir).expect("failed to create temp fixture dir");
        let json = r#"{
"max_position_embeddings": 4096,
"tie_word_embeddings": true,
"text_config": {
"hidden_size": 512,
"intermediate_size": 1024,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"head_dim": 256,
"vocab_size": 1000,
"rms_norm_eps": 1e-6,
"layer_types": ["linear_attention", "full_attention"],
"linear_conv_kernel_dim": 4,
"linear_key_head_dim": 128,
"linear_num_key_heads": 16,
"linear_num_value_heads": 48,
"linear_value_head_dim": 128,
"rope_parameters": { "rope_theta": 10000.0, "partial_rotary_factor": 0.25 },
"eos_token_id": 0
}
}"#;
        std::fs::write(dir.join("config.json"), json).expect("failed to write config.json fixture");

        let loaded = Config35::from_file(dir.to_str().expect("temp dir path must be valid UTF-8"));

        // Remove the fixture before asserting so a failing load does not leave it behind.
        // Cleanup is best-effort: a removal error must not mask the real test outcome.
        let _ = std::fs::remove_dir_all(&dir);

        loaded.expect("48 value heads must load");
    }
}
139 changes: 138 additions & 1 deletion openinfer-qwen35-4b/src/recurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,11 @@ mod tests {
use anyhow::Result;
use half::bf16;

use super::conv1d_prefill_batch_into;
use super::{
conv1d_prefill_batch_into, gated_delta_rule_decode_vec_into,
gated_delta_rule_prefill_chunkwise_into,
};
use crate::prefill_buffers::GdrChunkwiseScratch35;
use openinfer_core::tensor::{DeviceContext, DeviceVec, HiddenStates};

fn bf16_vec(data: &[f32]) -> Vec<bf16> {
Expand Down Expand Up @@ -623,4 +627,137 @@ mod tests {
assert!(max_state_diff < 0.02, "state diff {max_state_diff}");
Ok(())
}

#[test]
fn gdn_chunkwise_prefill_matches_stepwise_decode_at_48_value_heads() -> Result<()> {
    /// Deterministic test pattern: element `i` is `((i % modulus) - center) * scale`.
    fn ramp(len: usize, modulus: usize, center: f32, scale: f32) -> Vec<f32> {
        (0..len)
            .map(|i| ((i % modulus) as f32 - center) * scale)
            .collect()
    }

    /// Largest absolute element-wise difference over the zipped pair of slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f32, f32::max)
    }

    let ctx = DeviceContext::new()?;

    // Problem size: 16 key heads, 48 value heads, 128-wide heads, 96 tokens.
    let num_key_heads = 16usize;
    let num_value_heads = 48usize;
    let key_dim = 128usize;
    let val_dim = 128usize;
    let seq_len = 96usize;

    // Per-token row widths and the recurrent state length derived from the sizes above.
    let qkv_dim = 2 * num_key_heads * key_dim + num_value_heads * val_dim;
    let out_dim = num_value_heads * val_dim;
    let state_len = num_value_heads * key_dim * val_dim;

    // Host-side inputs, shared by the chunkwise call and the per-token replay.
    let qkv_host = bf16_vec(&ramp(seq_len * qkv_dim, 73, 36.0, 0.01));
    let b_host = bf16_vec(&ramp(seq_len * num_value_heads, 13, 6.0, 0.05));
    let a_host = bf16_vec(&ramp(seq_len * num_value_heads, 17, 8.0, 0.05));
    let dt_host = bf16_vec(&ramp(num_value_heads, 7, 3.0, 0.1));
    let alog_host: Vec<f32> = ramp(num_value_heads, 5, 2.0, 0.2);

    let dt_bias = DeviceVec::from_host(&ctx, &dt_host)?;
    let a_log = ctx.stream.clone_htod(&alog_host)?;

    // Whole-sequence device inputs for the chunkwise call.
    let qkv_all = HiddenStates {
        data: ctx.stream.clone_htod(&qkv_host)?,
        hidden_dim: qkv_dim,
        seq_len,
    };
    let b_all = HiddenStates {
        data: ctx.stream.clone_htod(&b_host)?,
        hidden_dim: num_value_heads,
        seq_len,
    };
    let a_all = HiddenStates {
        data: ctx.stream.clone_htod(&a_host)?,
        hidden_dim: num_value_heads,
        seq_len,
    };

    // Pass 1: chunkwise prefill over the full sequence, starting from a zeroed state.
    let mut state_chunk: cudarc::driver::CudaSlice<f32> = ctx.stream.alloc_zeros(state_len)?;
    let mut scratch =
        GdrChunkwiseScratch35::from_dims(&ctx, num_value_heads, key_dim, val_dim, seq_len)?;
    let mut out_chunk = HiddenStates::zeros(&ctx, out_dim, seq_len)?;
    gated_delta_rule_prefill_chunkwise_into(
        &ctx,
        &qkv_all,
        &b_all,
        &a_all,
        &dt_bias,
        &a_log,
        &mut state_chunk,
        &mut scratch,
        &mut out_chunk,
        num_key_heads,
        num_value_heads,
        key_dim,
        val_dim,
    )?;

    // Pass 2: replay the same rows one token at a time through the decode entry point,
    // again starting from a zeroed state, and collect each output row on the host.
    let mut state_step: cudarc::driver::CudaSlice<f32> = ctx.stream.alloc_zeros(state_len)?;
    let mut out_step_rows: Vec<f32> = Vec::with_capacity(seq_len * out_dim);
    let zero_row = vec![bf16::ZERO; out_dim];
    let token_rows = qkv_host
        .chunks_exact(qkv_dim)
        .zip(b_host.chunks_exact(num_value_heads))
        .zip(a_host.chunks_exact(num_value_heads));
    for ((qkv_row, b_row), a_row) in token_rows {
        let qkv_t = DeviceVec::from_host(&ctx, qkv_row)?;
        let b_t = DeviceVec::from_host(&ctx, b_row)?;
        let a_t = DeviceVec::from_host(&ctx, a_row)?;
        let mut out_t = DeviceVec::from_host(&ctx, &zero_row)?;
        gated_delta_rule_decode_vec_into(
            &ctx,
            &qkv_t,
            &b_t,
            &a_t,
            &dt_bias,
            &a_log,
            &mut state_step,
            &mut out_t,
            num_key_heads,
            num_value_heads,
            key_dim,
            val_dim,
        );
        let row = out_t.to_host(&ctx)?;
        out_step_rows.extend_from_slice(&row);
    }

    // Copy results back, then sync before reading them on the host.
    let out_chunk_host = ctx.stream.clone_dtoh(&out_chunk.data)?;
    let state_chunk_host = ctx.stream.clone_dtoh(&state_chunk)?;
    let state_step_host = ctx.stream.clone_dtoh(&state_step)?;
    ctx.sync()?;
    let out_chunk_host: Vec<f32> = out_chunk_host.iter().map(|x| x.to_f32()).collect();

    let max_out_diff = max_abs_diff(&out_chunk_host, &out_step_rows);
    let max_state_diff = max_abs_diff(&state_chunk_host, &state_step_host);

    let outputs_finite = out_chunk_host.iter().all(|v| v.is_finite());
    let state_finite = state_chunk_host.iter().all(|v| v.is_finite());
    assert!(
        outputs_finite && state_finite,
        "chunkwise outputs must be finite"
    );
    assert!(max_out_diff < 0.05, "output diff {max_out_diff}");
    assert!(max_state_diff < 0.05, "state diff {max_state_diff}");
    Ok(())
}
}