Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions openinfer-core/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ pub use openinfer_kernels::ops::{
GEMM_LT_MAX_N, LoraDecodeGroupedProjection, accumulate_bf16_token_scaled_to_f32_into,
add_batch, add_batch_into, argmax, argmax_batch_bf16_into, bf16_hidden_to_f32_into,
copy_hidden_rows_into, copy_hidden_token_range_into, dflash_qk_norm_rope_into,
embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref, extract_vec_ref_into,
f32_to_bf16_hidden_into, fused_add_rms_norm_into, gather_hidden_tokens_into, gemm,
gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked, gemm_into_checked, gemm_lt_tune,
gemm_per_token, gemv, linear, lora_decode_fused_delta_group3_into,
lora_decode_fused_delta_into, pack_lora_b_rows_into,
eagle3_rope_into, embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref,
extract_vec_ref_into, f32_to_bf16_hidden_into, fused_add_rms_norm_into,
gather_hidden_tokens_into, gemm, gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked,
gemm_into_checked, gemm_lt_tune, gemm_per_token, gemv, linear,
lora_decode_fused_delta_group3_into, lora_decode_fused_delta_into, pack_lora_b_rows_into,
qk_norm_partial_rope_batched_decode_hd256_into, rms_norm, rms_norm_batch_offset_into,
rms_norm_gated_batch_into, rms_norm_into, rms_norm_offset_into, scale_f32_in_place,
scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into,
scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into,
single_prefill_nhd_noncausal_into, write_vec_into,
scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, single_decode_nhd_into,
single_prefill_nhd_causal_into, single_prefill_nhd_noncausal_into, write_vec_into,
};
#[cfg(not(feature = "kernel-call-trace"))]
pub use openinfer_kernels::ops::{
Expand Down
35 changes: 35 additions & 0 deletions openinfer-core/src/weight_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,41 @@ pub fn load_tensor_1d_f32(
Ok(gpu_data)
}

/// Load a 1D I64 tensor into a host `Vec<i64>`.
///
/// For small integer lookup tables that live on the host (e.g. EAGLE-3's `d2t`
/// draft→target vocab offset map), not weights destined for a GEMM.
pub fn load_tensor_i64_host(
shards: &[SafeTensors],
weight_map: &HashMap<String, usize>,
name: &str,
) -> Result<Vec<i64>> {
let tensor = find_tensor(shards, weight_map, name)?;
let data = tensor.data();
if data.len() % 8 != 0 {
return Err(anyhow::anyhow!(
"I64 tensor '{}': data length {} not multiple of 8",
name,
data.len()
));
}
Ok(data
.chunks_exact(8)
.map(|b| i64::from_le_bytes(b.try_into().unwrap()))
.collect())
}

/// Load a 1D BOOL tensor into a host `Vec<bool>` (safetensors stores BOOL as one
/// byte per element, `0`/`1`). For mask tables like EAGLE-3's `t2d`.
pub fn load_tensor_bool_host(
shards: &[SafeTensors],
weight_map: &HashMap<String, usize>,
name: &str,
) -> Result<Vec<bool>> {
let tensor = find_tensor(shards, weight_map, name)?;
Ok(tensor.data().iter().map(|&b| b != 0).collect())
}

/// Load shard info with fixup for mismatched shard filenames in index.json.
///
/// Some models (e.g., Qwen3.5) have index.json with shard filenames like
Expand Down
129 changes: 129 additions & 0 deletions openinfer-kernels/csrc/shared/paged_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,135 @@ int single_prefill_nhd_noncausal_cuda(
reinterpret_cast<cudaStream_t>(stream)));
}

// Causal variant of single_prefill_nhd_noncausal_cuda: identical NHD token-major
// layout (q/output [seq, q_dim], k/v [max_seq, kv_dim]) but a causal mask, so the
// N query tokens at the tail of the cache attend only positions <= their own.
// Used for EAGLE-3's teacher-forced prefill in one batched forward (query i at
// absolute position kv_len - seq_len + i, FlashInfer's causal alignment).
int single_prefill_nhd_causal_cuda(
void* q,
void* output,
void* k_cache,
void* v_cache,
int32_t num_qo_heads,
int32_t num_kv_heads,
int32_t head_dim,
int32_t seq_len,
int32_t kv_len,
int32_t max_seq_len,
float sm_scale,
void* stream)
{
if (q == nullptr || output == nullptr || k_cache == nullptr || v_cache == nullptr ||
num_qo_heads <= 0 || num_kv_heads <= 0 || head_dim != 128 ||
seq_len <= 0 || kv_len <= 0 || max_seq_len < kv_len || seq_len > kv_len) {
return static_cast<int>(cudaErrorInvalidValue);
}

uint32_t q_stride_n = num_qo_heads * head_dim;
uint32_t q_stride_h = head_dim;
uint32_t kv_stride_n = num_kv_heads * head_dim;
uint32_t kv_stride_h = head_dim;

PrefillParamsT params(
reinterpret_cast<DType*>(q),
reinterpret_cast<DType*>(k_cache),
reinterpret_cast<DType*>(v_cache),
/*maybe_custom_mask=*/nullptr,
reinterpret_cast<DType*>(output),
/*lse=*/nullptr,
/*maybe_alibi_slopes=*/nullptr,
num_qo_heads,
num_kv_heads,
static_cast<uint32_t>(seq_len),
static_cast<uint32_t>(kv_len),
q_stride_n,
q_stride_h,
kv_stride_n,
kv_stride_h,
static_cast<uint32_t>(head_dim),
/*window_left=*/-1,
/*logits_soft_cap=*/0.0f,
sm_scale,
/*rope_scale=*/1.0f,
/*rope_theta=*/1e6f);

return static_cast<int>(
SinglePrefillWithKVCacheDispatched<
/*HEAD_DIM_QK=*/128,
/*HEAD_DIM_VO=*/128,
PosEncodingMode::kNone,
/*USE_FP16_QK_REDUCTION=*/false,
MaskMode::kCausal,
Variant,
PrefillParamsT>(
params,
/*tmp=*/nullptr,
reinterpret_cast<cudaStream_t>(stream)));
}

// ---------------------------------------------------------------------------
// Single-query DECODE over the EAGLE-3 draft's contiguous NHD KV cache.
//
// The chain drafter advances one token per step, so each draft attention is a
// pure decode: exactly ONE query attends the whole [0, kv_len) prefix. This uses
// FlashInfer's dedicated single-query decode path (GEMV-style over the KV),
// which is *structurally* single-query — SingleDecodeParams::get_qo_len() is
// hard-wired to 1 — so, unlike single_prefill_nhd_noncausal_cuda (a prefill
// template forced to qo_len==1), it cannot be silently misused for a multi-query
// batch (a footgun once the draft chain is batched). Same NHD token-major layout
// as the *_nhd_* prefill pair: q/output [1, q_dim], k/v [max_seq_len, kv_dim].
// No RoPE inside (the caller applies eagle3_rope first).
// ---------------------------------------------------------------------------
using DecodeParamsT = SingleDecodeParams<DType, DType, DType>;

int single_decode_nhd_cuda(
void* q, // [1, q_dim] token-major — the single decode query
void* output, // [1, q_dim]
void* k_cache, // [max_seq_len, kv_dim] NHD (k[pos, head, dim])
void* v_cache,
int32_t num_qo_heads,
int32_t num_kv_heads,
int32_t head_dim,
int32_t kv_len, // positions to attend: [0, kv_len)
int32_t max_seq_len, // allocated cache rows (validation parity with the *_nhd_* pair)
float sm_scale,
void* stream)
{
if (q == nullptr || output == nullptr || k_cache == nullptr || v_cache == nullptr ||
num_qo_heads <= 0 || num_kv_heads <= 0 || head_dim != 128 ||
kv_len <= 0 || max_seq_len < kv_len) {
return static_cast<int>(cudaErrorInvalidValue);
}

DecodeParamsT params(
reinterpret_cast<DType*>(q),
reinterpret_cast<DType*>(k_cache),
reinterpret_cast<DType*>(v_cache),
reinterpret_cast<DType*>(output),
/*maybe_alibi_slopes=*/nullptr,
/*seq_len(=kv_len)=*/static_cast<uint32_t>(kv_len),
static_cast<uint32_t>(num_qo_heads),
static_cast<uint32_t>(num_kv_heads),
QKVLayout::kNHD,
static_cast<uint32_t>(head_dim),
/*window_left=*/-1,
/*logits_soft_cap=*/0.0f,
sm_scale,
/*rope_scale=*/1.0f,
/*rope_theta=*/1e6f);

return static_cast<int>(
SingleDecodeWithKVCacheDispatched<
/*HEAD_DIM=*/128,
PosEncodingMode::kNone,
Variant,
DecodeParamsT>(
params,
/*tmp=*/nullptr,
reinterpret_cast<cudaStream_t>(stream)));
}

// ---------------------------------------------------------------------------
// Single-request prefill for HEAD_DIM=256 — wraps FlashInfer SinglePrefillWithKVCache.
//
Expand Down
94 changes: 94 additions & 0 deletions openinfer-kernels/csrc/shared/prefill_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,69 @@ __global__ void dflash_qk_norm_rope_kernel(
data[offset] = result;
}

// Plain RoPE (no QK-norm) for EAGLE-3, whose attention has no per-head q/k norm.
__global__ void eagle3_rope_kernel(
__nv_bfloat16* __restrict__ q,
__nv_bfloat16* __restrict__ k,
const __nv_bfloat16* __restrict__ cos_cache,
const __nv_bfloat16* __restrict__ sin_cache,
int num_q_heads,
int num_kv_heads,
int head_dim,
int q_len,
int k_len,
int q_start_pos,
int k_start_pos,
int cos_max_pos
) {
int head_global = blockIdx.x;
int token = blockIdx.y;
int d = threadIdx.x;

bool is_q = (head_global < num_q_heads);
int local_heads = is_q ? num_q_heads : num_kv_heads;
int seq_len = is_q ? q_len : k_len;
if (token >= seq_len) return;

int head_local = is_q ? head_global : (head_global - num_q_heads);
if (head_local >= local_heads) return;

__nv_bfloat16* data = is_q ? q : k;
int dim_stride = local_heads * head_dim;
int pos = (is_q ? q_start_pos : k_start_pos) + token;
if (pos < 0 || pos >= cos_max_pos) __trap();

int offset = token * dim_stride + head_local * head_dim + d;

// No QK-norm: the raw value goes straight into the RoPE rotation.
__shared__ __nv_bfloat16 smem[HEAD_DIM];
smem[d] = data[offset];
__syncthreads();

int half = head_dim / 2;
__nv_bfloat16 result;
if (d < half) {
float lo = __bfloat162float(smem[d]);
float hi = __bfloat162float(smem[d + half]);
float c = __bfloat162float(cos_cache[pos * head_dim + d]);
float s = __bfloat162float(sin_cache[pos * head_dim + d]);
float lo_cos = __bfloat162float(__float2bfloat16(lo * c));
float hi_sin = __bfloat162float(__float2bfloat16(hi * s));
result = __float2bfloat16(lo_cos - hi_sin);
} else {
int pair_d = d - half;
float lo = __bfloat162float(smem[pair_d]);
float hi = __bfloat162float(smem[d]);
float c = __bfloat162float(cos_cache[pos * head_dim + pair_d]);
float s = __bfloat162float(sin_cache[pos * head_dim + pair_d]);
float lo_sin = __bfloat162float(__float2bfloat16(lo * s));
float hi_cos = __bfloat162float(__float2bfloat16(hi * c));
result = __float2bfloat16(lo_sin + hi_cos);
}

data[offset] = result;
}

extern "C" {

// ============================================================================
Expand Down Expand Up @@ -291,4 +354,35 @@ int dflash_qk_norm_rope_cuda(
return static_cast<int>(cudaGetLastError());
}

// Plain RoPE (no QK-norm) for EAGLE-3. Launches eagle3_rope_kernel.
int eagle3_rope_cuda(
__nv_bfloat16* q,
__nv_bfloat16* k,
const __nv_bfloat16* cos_cache,
const __nv_bfloat16* sin_cache,
int num_q_heads,
int num_kv_heads,
int head_dim,
int q_len,
int k_len,
int q_start_pos,
int k_start_pos,
int cos_max_pos,
cudaStream_t stream
) {
if (q == nullptr || k == nullptr || cos_cache == nullptr || sin_cache == nullptr ||
num_q_heads <= 0 || num_kv_heads <= 0 || head_dim != HEAD_DIM ||
q_len <= 0 || k_len <= 0 || q_start_pos < 0 || k_start_pos < 0 ||
q_start_pos + q_len > cos_max_pos || k_start_pos + k_len > cos_max_pos) {
return static_cast<int>(cudaErrorInvalidValue);
}

dim3 grid(num_q_heads + num_kv_heads, q_len > k_len ? q_len : k_len);
eagle3_rope_kernel<<<grid, head_dim, 0, stream>>>(
q, k, cos_cache, sin_cache,
num_q_heads, num_kv_heads, head_dim, q_len, k_len,
q_start_pos, k_start_pos, cos_max_pos);
return static_cast<int>(cudaGetLastError());
}

} // extern "C"
49 changes: 49 additions & 0 deletions openinfer-kernels/src/ffi/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,23 @@ unsafe extern "C" {
stream: CUstream,
) -> i32;

// Plain RoPE (no QK-norm) for EAGLE-3 — no norm-weight / eps params.
pub fn eagle3_rope_cuda(
q: *mut Half,
k: *mut Half,
cos_cache: *const Half,
sin_cache: *const Half,
num_q_heads: i32,
num_kv_heads: i32,
head_dim: i32,
q_len: i32,
k_len: i32,
q_start_pos: i32,
k_start_pos: i32,
cos_max_pos: i32,
stream: CUstream,
) -> i32;

// Scatter contiguous KV → paged layout (one layer, FlashInfer prefill append).
pub fn paged_kv_scatter_cuda(
kv_data: *const Half,
Expand Down Expand Up @@ -463,6 +480,38 @@ unsafe extern "C" {
stream: CUstream,
) -> i32;

// Causal NHD single-sequence prefill (same layout, causal mask).
pub fn single_prefill_nhd_causal_cuda(
q: *const Half,
output: *mut Half,
k_cache: *const Half,
v_cache: *const Half,
num_qo_heads: i32,
num_kv_heads: i32,
head_dim: i32,
seq_len: i32,
kv_len: i32,
max_seq_len: i32,
sm_scale: f32,
stream: CUstream,
) -> i32;

// Single-query NHD decode over a contiguous KV cache (FlashInfer SingleDecode,
// no partition-KV). Structurally one query, so there is no `seq_len` parameter.
pub fn single_decode_nhd_cuda(
q: *const Half,
output: *mut Half,
k_cache: *const Half,
v_cache: *const Half,
num_qo_heads: i32,
num_kv_heads: i32,
head_dim: i32,
kv_len: i32,
max_seq_len: i32,
sm_scale: f32,
stream: CUstream,
) -> i32;

// Paged attention decode (FlashInfer BatchDecode, no partition-KV).
pub fn paged_attention_decode_cuda(
q: *const Half,
Expand Down
9 changes: 5 additions & 4 deletions openinfer-kernels/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ mod norm;
mod sampling;

pub use attention::{
PrefillPagedPlan, dflash_qk_norm_rope_into, paged_attention_batch_decode_hd256_into,
paged_attention_batch_decode_into, paged_attention_batch_decode_split_kv_into,
prefill_attention_paged_into, qk_norm_partial_rope_batched_decode_hd256_into,
qk_norm_rope_batch_decode_into, single_prefill_nhd_noncausal_into,
PrefillPagedPlan, dflash_qk_norm_rope_into, eagle3_rope_into,
paged_attention_batch_decode_hd256_into, paged_attention_batch_decode_into,
paged_attention_batch_decode_split_kv_into, prefill_attention_paged_into,
qk_norm_partial_rope_batched_decode_hd256_into, qk_norm_rope_batch_decode_into,
single_decode_nhd_into, single_prefill_nhd_causal_into, single_prefill_nhd_noncausal_into,
};
#[cfg(feature = "moe")]
pub use deepep::{
Expand Down
Loading