Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions openinfer-core/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ pub use openinfer_kernels::ops::{
GEMM_LT_MAX_N, LoraDecodeGroupedProjection, accumulate_bf16_token_scaled_to_f32_into,
add_batch, add_batch_into, argmax, argmax_batch_bf16_into, bf16_hidden_to_f32_into,
copy_hidden_rows_into, copy_hidden_token_range_into, dflash_qk_norm_rope_into,
embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref, extract_vec_ref_into,
f32_to_bf16_hidden_into, fused_add_rms_norm_into, gather_hidden_tokens_into, gemm,
gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked, gemm_into_checked, gemm_lt_tune,
gemm_per_token, gemv, linear, lora_decode_fused_delta_group3_into,
lora_decode_fused_delta_into, pack_lora_b_rows_into,
eagle3_rope_into, embedding_decode_into, extract_vec, extract_vec_into, extract_vec_ref,
extract_vec_ref_into, f32_to_bf16_hidden_into, fused_add_rms_norm_into,
gather_hidden_tokens_into, gemm, gemm_graphsafe_into_checked, gemm_graphsafe_ref_into_checked,
gemm_into_checked, gemm_lt_tune, gemm_per_token, gemv, linear,
lora_decode_fused_delta_group3_into, lora_decode_fused_delta_into, pack_lora_b_rows_into,
qk_norm_partial_rope_batched_decode_hd256_into, rms_norm, rms_norm_batch_offset_into,
rms_norm_gated_batch_into, rms_norm_into, rms_norm_offset_into, scale_f32_in_place,
scaled_add_batch_into, scaled_add_rows_indexed_into, scaled_add_rows_into,
scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into,
single_prefill_nhd_noncausal_into, write_vec_into,
scaled_add_rows_token_range_into, silu_mul_batch, silu_mul_batch_into, single_decode_nhd_into,
single_prefill_nhd_causal_into, single_prefill_nhd_noncausal_into, write_vec_into,
};
#[cfg(not(feature = "kernel-call-trace"))]
pub use openinfer_kernels::ops::{
Expand Down
35 changes: 35 additions & 0 deletions openinfer-core/src/weight_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,41 @@ pub fn load_tensor_1d_f32(
Ok(gpu_data)
}

/// Load a 1D I64 tensor into a host `Vec<i64>`.
///
/// For small integer lookup tables that live on the host (e.g. EAGLE-3's `d2t`
/// draft→target vocab offset map), not weights destined for a GEMM.
pub fn load_tensor_i64_host(
shards: &[SafeTensors],
weight_map: &HashMap<String, usize>,
name: &str,
) -> Result<Vec<i64>> {
let tensor = find_tensor(shards, weight_map, name)?;
let data = tensor.data();
if data.len() % 8 != 0 {
return Err(anyhow::anyhow!(
"I64 tensor '{}': data length {} not multiple of 8",
name,
data.len()
));
}
Ok(data
.chunks_exact(8)
.map(|b| i64::from_le_bytes(b.try_into().unwrap()))
.collect())
}

/// Load a 1D BOOL tensor into a host `Vec<bool>` (safetensors stores BOOL as one
/// byte per element, `0`/`1`). For mask tables like EAGLE-3's `t2d`.
pub fn load_tensor_bool_host(
shards: &[SafeTensors],
weight_map: &HashMap<String, usize>,
name: &str,
) -> Result<Vec<bool>> {
let tensor = find_tensor(shards, weight_map, name)?;
Ok(tensor.data().iter().map(|&b| b != 0).collect())
}

/// Load shard info with fixup for mismatched shard filenames in index.json.
///
/// Some models (e.g., Qwen3.5) have index.json with shard filenames like
Expand Down
129 changes: 129 additions & 0 deletions openinfer-kernels/csrc/shared/paged_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,135 @@ int single_prefill_nhd_noncausal_cuda(
reinterpret_cast<cudaStream_t>(stream)));
}

// Causal counterpart of single_prefill_nhd_noncausal_cuda. The buffers use the
// same NHD token-major layout (q/output [seq, q_dim], k/v [max_seq, kv_dim]);
// the only difference is MaskMode::kCausal, so each of the seq_len query tokens
// at the tail of the cache attends only to positions <= its own. Used for
// EAGLE-3's teacher-forced prefill in one batched forward (query i sits at
// absolute position kv_len - seq_len + i, FlashInfer's causal alignment).
//
// Returns the cudaError_t from the FlashInfer dispatch cast to int, or
// cudaErrorInvalidValue when an argument fails validation.
int single_prefill_nhd_causal_cuda(
    void* q,
    void* output,
    void* k_cache,
    void* v_cache,
    int32_t num_qo_heads,
    int32_t num_kv_heads,
    int32_t head_dim,
    int32_t seq_len,
    int32_t kv_len,
    int32_t max_seq_len,
    float sm_scale,
    void* stream)
{
    const bool missing_buffer = !q || !output || !k_cache || !v_cache;
    // Only the HEAD_DIM = 128 instantiation is dispatched below.
    const bool bad_head_config = num_qo_heads <= 0 || num_kv_heads <= 0 || head_dim != 128;
    // Queries must fit in the KV prefix, which must fit in the allocated cache.
    const bool bad_lengths =
        seq_len <= 0 || kv_len <= 0 || max_seq_len < kv_len || seq_len > kv_len;
    if (missing_buffer || bad_head_config || bad_lengths) {
        return static_cast<int>(cudaErrorInvalidValue);
    }

    // NHD strides: one token spans every head, one head spans head_dim elements.
    const uint32_t q_token_stride = num_qo_heads * head_dim;
    const uint32_t q_head_stride = head_dim;
    const uint32_t kv_token_stride = num_kv_heads * head_dim;
    const uint32_t kv_head_stride = head_dim;

    DType* const q_ptr = reinterpret_cast<DType*>(q);
    DType* const k_ptr = reinterpret_cast<DType*>(k_cache);
    DType* const v_ptr = reinterpret_cast<DType*>(v_cache);
    DType* const out_ptr = reinterpret_cast<DType*>(output);

    PrefillParamsT params(
        q_ptr,
        k_ptr,
        v_ptr,
        /*maybe_custom_mask=*/nullptr,
        out_ptr,
        /*lse=*/nullptr,
        /*maybe_alibi_slopes=*/nullptr,
        num_qo_heads,
        num_kv_heads,
        static_cast<uint32_t>(seq_len),
        static_cast<uint32_t>(kv_len),
        q_token_stride,
        q_head_stride,
        kv_token_stride,
        kv_head_stride,
        static_cast<uint32_t>(head_dim),
        /*window_left=*/-1,
        /*logits_soft_cap=*/0.0f,
        sm_scale,
        /*rope_scale=*/1.0f,
        /*rope_theta=*/1e6f);

    cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
    const auto status = SinglePrefillWithKVCacheDispatched<
        /*HEAD_DIM_QK=*/128,
        /*HEAD_DIM_VO=*/128,
        PosEncodingMode::kNone,
        /*USE_FP16_QK_REDUCTION=*/false,
        MaskMode::kCausal,
        Variant,
        PrefillParamsT>(params, /*tmp=*/nullptr, cuda_stream);
    return static_cast<int>(status);
}

// ---------------------------------------------------------------------------
// Single-query DECODE over the EAGLE-3 draft's contiguous NHD KV cache.
//
// The chain drafter emits one token per step, so every draft attention is a
// pure decode: exactly ONE query attends the whole [0, kv_len) prefix. This
// goes through FlashInfer's dedicated single-query decode path (GEMV-style over
// the KV), which is *structurally* single-query — SingleDecodeParams::get_qo_len()
// is hard-wired to 1. Unlike single_prefill_nhd_noncausal_cuda (a prefill
// template forced to qo_len==1), it therefore cannot be silently misused for a
// multi-query batch (a footgun once the draft chain is batched). Layout matches
// the *_nhd_* prefill pair: q/output [1, q_dim], k/v [max_seq_len, kv_dim].
// No RoPE is applied here (the caller applies eagle3_rope first).
//
// Returns the cudaError_t from the FlashInfer dispatch cast to int, or
// cudaErrorInvalidValue when an argument fails validation.
// ---------------------------------------------------------------------------
using DecodeParamsT = SingleDecodeParams<DType, DType, DType>;

int single_decode_nhd_cuda(
    void* q,               // [1, q_dim] token-major — the single decode query
    void* output,          // [1, q_dim]
    void* k_cache,         // [max_seq_len, kv_dim] NHD (k[pos, head, dim])
    void* v_cache,
    int32_t num_qo_heads,
    int32_t num_kv_heads,
    int32_t head_dim,
    int32_t kv_len,        // positions to attend: [0, kv_len)
    int32_t max_seq_len,   // allocated cache rows (validation parity with the *_nhd_* pair)
    float sm_scale,
    void* stream)
{
    const bool missing_buffer = !q || !output || !k_cache || !v_cache;
    // Only the HEAD_DIM = 128 instantiation is dispatched below.
    const bool bad_head_config = num_qo_heads <= 0 || num_kv_heads <= 0 || head_dim != 128;
    // The attended prefix must be non-empty and fit in the allocated cache.
    const bool bad_lengths = kv_len <= 0 || max_seq_len < kv_len;
    if (missing_buffer || bad_head_config || bad_lengths) {
        return static_cast<int>(cudaErrorInvalidValue);
    }

    DType* const q_ptr = reinterpret_cast<DType*>(q);
    DType* const k_ptr = reinterpret_cast<DType*>(k_cache);
    DType* const v_ptr = reinterpret_cast<DType*>(v_cache);
    DType* const out_ptr = reinterpret_cast<DType*>(output);

    DecodeParamsT params(
        q_ptr,
        k_ptr,
        v_ptr,
        out_ptr,
        /*maybe_alibi_slopes=*/nullptr,
        /*seq_len(=kv_len)=*/static_cast<uint32_t>(kv_len),
        static_cast<uint32_t>(num_qo_heads),
        static_cast<uint32_t>(num_kv_heads),
        QKVLayout::kNHD,
        static_cast<uint32_t>(head_dim),
        /*window_left=*/-1,
        /*logits_soft_cap=*/0.0f,
        sm_scale,
        /*rope_scale=*/1.0f,
        /*rope_theta=*/1e6f);

    cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
    const auto status = SingleDecodeWithKVCacheDispatched<
        /*HEAD_DIM=*/128,
        PosEncodingMode::kNone,
        Variant,
        DecodeParamsT>(params, /*tmp=*/nullptr, cuda_stream);
    return static_cast<int>(status);
}

// ---------------------------------------------------------------------------
// Single-request prefill for HEAD_DIM=256 — wraps FlashInfer SinglePrefillWithKVCache.
//
Expand Down
94 changes: 94 additions & 0 deletions openinfer-kernels/csrc/shared/prefill_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,69 @@ __global__ void dflash_qk_norm_rope_kernel(
data[offset] = result;
}

// RoPE-only kernel for EAGLE-3 (its attention has no per-head q/k norm).
//
// Each block handles one (head, token) pair and each thread one element of that
// head. blockIdx.x enumerates the Q heads first and then the K heads;
// blockIdx.y is the token index. Element d in the lower half of the head is
// paired with element d + head_dim/2, and both are rotated in place using
// cos/sin row `start_pos + token`. Every product is rounded to bf16 before the
// final add/subtract.
__global__ void eagle3_rope_kernel(
    __nv_bfloat16* __restrict__ q,
    __nv_bfloat16* __restrict__ k,
    const __nv_bfloat16* __restrict__ cos_cache,
    const __nv_bfloat16* __restrict__ sin_cache,
    int num_q_heads,
    int num_kv_heads,
    int head_dim,
    int q_len,
    int k_len,
    int q_start_pos,
    int k_start_pos,
    int cos_max_pos
) {
    const int head_slot = blockIdx.x;
    const int tok = blockIdx.y;
    const int lane = threadIdx.x;

    const bool on_q = head_slot < num_q_heads;

    // Q and K may have different token counts; surplus blocks exit early.
    const int tokens = on_q ? q_len : k_len;
    if (tok >= tokens) return;

    const int heads = on_q ? num_q_heads : num_kv_heads;
    const int head = on_q ? head_slot : head_slot - num_q_heads;
    if (head >= heads) return;

    // A position outside the cos/sin table aborts the kernel.
    const int pos = (on_q ? q_start_pos : k_start_pos) + tok;
    if (pos < 0 || pos >= cos_max_pos) __trap();

    __nv_bfloat16* const base = on_q ? q : k;
    const int token_stride = heads * head_dim;
    const int offset = tok * token_stride + head * head_dim + lane;

    // Stage the raw head in shared memory so each thread can read its partner
    // element before either one is overwritten.
    __shared__ __nv_bfloat16 smem[HEAD_DIM];
    smem[lane] = base[offset];
    __syncthreads();

    const int half_dim = head_dim / 2;
    const bool lower_half = lane < half_dim;
    // Index of this thread's pair within the lower half; also the cos/sin column.
    const int pair = lower_half ? lane : lane - half_dim;

    const float lo = __bfloat162float(smem[pair]);
    const float hi = __bfloat162float(smem[pair + half_dim]);
    const float c = __bfloat162float(cos_cache[pos * head_dim + pair]);
    const float s = __bfloat162float(sin_cache[pos * head_dim + pair]);

    float rotated;
    if (lower_half) {
        const float lo_cos = __bfloat162float(__float2bfloat16(lo * c));
        const float hi_sin = __bfloat162float(__float2bfloat16(hi * s));
        rotated = lo_cos - hi_sin;
    } else {
        const float lo_sin = __bfloat162float(__float2bfloat16(lo * s));
        const float hi_cos = __bfloat162float(__float2bfloat16(hi * c));
        rotated = lo_sin + hi_cos;
    }

    base[offset] = __float2bfloat16(rotated);
}

extern "C" {

// ============================================================================
Expand Down Expand Up @@ -291,4 +354,35 @@ int dflash_qk_norm_rope_cuda(
return static_cast<int>(cudaGetLastError());
}

// Plain RoPE (no QK-norm) for EAGLE-3. Validates the arguments and launches
// eagle3_rope_kernel on `stream`, rotating `q` and `k` in place.
//
// The grid is (num_q_heads + num_kv_heads, max(q_len, k_len)) with head_dim
// threads per block. Token t of q / k uses cos/sin row q_start_pos + t /
// k_start_pos + t, and both ranges must lie inside [0, cos_max_pos).
//
// Returns cudaErrorInvalidValue (as int) when validation fails, otherwise the
// result of cudaGetLastError() after the launch.
int eagle3_rope_cuda(
    __nv_bfloat16* q,
    __nv_bfloat16* k,
    const __nv_bfloat16* cos_cache,
    const __nv_bfloat16* sin_cache,
    int num_q_heads,
    int num_kv_heads,
    int head_dim,
    int q_len,
    int k_len,
    int q_start_pos,
    int k_start_pos,
    int cos_max_pos,
    cudaStream_t stream
) {
    if (q == nullptr || k == nullptr || cos_cache == nullptr || sin_cache == nullptr ||
        num_q_heads <= 0 || num_kv_heads <= 0 || head_dim != HEAD_DIM ||
        q_len <= 0 || k_len <= 0 || q_start_pos < 0 || k_start_pos < 0 ||
        cos_max_pos <= 0) {
        return static_cast<int>(cudaErrorInvalidValue);
    }

    // Range check by subtraction: `start + len` can overflow a signed int for
    // the very inputs this check is meant to reject. With start >= 0 and
    // cos_max_pos > 0 (checked above), `cos_max_pos - start` cannot overflow.
    if (q_len > cos_max_pos - q_start_pos || k_len > cos_max_pos - k_start_pos) {
        return static_cast<int>(cudaErrorInvalidValue);
    }

    // One block per (head, token); the token axis covers the longer of q and k,
    // and the kernel exits early for blocks past the shorter one.
    const int max_tokens = q_len > k_len ? q_len : k_len;
    dim3 grid(num_q_heads + num_kv_heads, max_tokens);
    eagle3_rope_kernel<<<grid, head_dim, 0, stream>>>(
        q, k, cos_cache, sin_cache,
        num_q_heads, num_kv_heads, head_dim, q_len, k_len,
        q_start_pos, k_start_pos, cos_max_pos);
    return static_cast<int>(cudaGetLastError());
}

} // extern "C"
49 changes: 49 additions & 0 deletions openinfer-kernels/src/ffi/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,23 @@ unsafe extern "C" {
stream: CUstream,
) -> i32;

// Plain RoPE (no QK-norm) for EAGLE-3 — no norm-weight / eps params.
// `q` and `k` are rotated in place using the `cos_cache` / `sin_cache` tables.
// `q_start_pos` / `k_start_pos` are the table rows used for the first q / k
// token; `cos_max_pos` bounds the table rows, and the CUDA side rejects ranges
// that run past it.
// Returns a CUDA error code as `i32` (the C side returns `cudaErrorInvalidValue`
// when its argument validation fails).
// NOTE(review): the C side declares these pointers as `__nv_bfloat16*`;
// presumably `Half` is layout-compatible with it — confirm.
pub fn eagle3_rope_cuda(
q: *mut Half,
k: *mut Half,
cos_cache: *const Half,
sin_cache: *const Half,
num_q_heads: i32,
num_kv_heads: i32,
head_dim: i32,
q_len: i32,
k_len: i32,
q_start_pos: i32,
k_start_pos: i32,
cos_max_pos: i32,
stream: CUstream,
) -> i32;

// Scatter contiguous KV → paged layout (one layer, FlashInfer prefill append).
pub fn paged_kv_scatter_cuda(
kv_data: *const Half,
Expand Down Expand Up @@ -463,6 +480,38 @@ unsafe extern "C" {
stream: CUstream,
) -> i32;

// Causal NHD single-sequence prefill (same layout, causal mask).
// `seq_len` is the number of query tokens, `kv_len` the number of cache
// positions attended, `max_seq_len` the allocated cache rows.
// The C side requires `head_dim == 128` and `0 < seq_len <= kv_len <= max_seq_len`;
// otherwise it returns `cudaErrorInvalidValue`.
// Returns a CUDA error code as `i32`.
pub fn single_prefill_nhd_causal_cuda(
q: *const Half,
output: *mut Half,
k_cache: *const Half,
v_cache: *const Half,
num_qo_heads: i32,
num_kv_heads: i32,
head_dim: i32,
seq_len: i32,
kv_len: i32,
max_seq_len: i32,
sm_scale: f32,
stream: CUstream,
) -> i32;

// Single-query NHD decode over a contiguous KV cache (FlashInfer SingleDecode,
// no partition-KV). Structurally one query, so there is no `seq_len` parameter.
// The query attends cache positions `[0, kv_len)`; `max_seq_len` is the number
// of allocated cache rows and is only used for validation on the C side.
// The C side requires `head_dim == 128` and `0 < kv_len <= max_seq_len`;
// otherwise it returns `cudaErrorInvalidValue`.
// Returns a CUDA error code as `i32`.
pub fn single_decode_nhd_cuda(
q: *const Half,
output: *mut Half,
k_cache: *const Half,
v_cache: *const Half,
num_qo_heads: i32,
num_kv_heads: i32,
head_dim: i32,
kv_len: i32,
max_seq_len: i32,
sm_scale: f32,
stream: CUstream,
) -> i32;

// Paged attention decode (FlashInfer BatchDecode, no partition-KV).
pub fn paged_attention_decode_cuda(
q: *const Half,
Expand Down
9 changes: 5 additions & 4 deletions openinfer-kernels/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ mod norm;
mod sampling;

pub use attention::{
PrefillPagedPlan, dflash_qk_norm_rope_into, paged_attention_batch_decode_hd256_into,
paged_attention_batch_decode_into, paged_attention_batch_decode_split_kv_into,
prefill_attention_paged_into, qk_norm_partial_rope_batched_decode_hd256_into,
qk_norm_rope_batch_decode_into, single_prefill_nhd_noncausal_into,
PrefillPagedPlan, dflash_qk_norm_rope_into, eagle3_rope_into,
paged_attention_batch_decode_hd256_into, paged_attention_batch_decode_into,
paged_attention_batch_decode_split_kv_into, prefill_attention_paged_into,
qk_norm_partial_rope_batched_decode_hd256_into, qk_norm_rope_batch_decode_into,
single_decode_nhd_into, single_prefill_nhd_causal_into, single_prefill_nhd_noncausal_into,
};
#[cfg(feature = "moe")]
pub use deepep::{
Expand Down
Loading