Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 289 additions & 0 deletions src/layer/vulkan/sdpa_vulkan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
// Copyright 2026 Futz12 <pchar.cn>
// SPDX-License-Identifier: BSD-3-Clause

#include "sdpa_vulkan.h"
#include "layer_shader_type.h"
#include <cmath> // for sqrt

namespace ncnn {

// Construct the layer with no pipelines allocated yet; they are built in
// create_pipeline(). Sets support_vulkan to true and both packing flags to false.
SDPA_vulkan::SDPA_vulkan()
    : pipeline_sdpa(0), pipeline_sdpa_kv_concat(0)
{
    support_vulkan_any_packing = false;
    support_vulkan_packing = false;
    support_vulkan = true;
}

// Load parameters through the base SDPA implementation and return its status.
// When int8_scale_term is set, support_vulkan is cleared for this layer.
int SDPA_vulkan::load_param(const ParamDict& pd)
{
    const int status = SDPA::load_param(pd);

    if (!int8_scale_term)
        return status;

    // int8_scale_term is set: turn the Vulkan path off.
    support_vulkan = false;
    return status;
}

int SDPA_vulkan::create_pipeline(const Option& opt)
{
const Mat& qshape = bottom_shapes.empty() ? Mat() : bottom_shapes[0];
const Mat& vshape = bottom_shapes.size() > 2 ? bottom_shapes[2] : Mat();

int head_dim = 0;
int out_head_dim = 0;

if (qshape.dims == 3) head_dim = qshape.w;
if (vshape.dims == 3) out_head_dim = vshape.w;

// SDPA Pipeline
// Spec constants: 0=head_dim, 1=out_head_dim.
// Scale removed from spec constants as it is passed dynamically via push constants.
std::vector<vk_specialization_type> spec_sdpa(2);
spec_sdpa[0].i = head_dim;
spec_sdpa[1].i = out_head_dim;

pipeline_sdpa = new Pipeline(vkdev);
pipeline_sdpa->set_local_size_xyz(256, 1, 1);
pipeline_sdpa->create(LayerShaderType::sdpa, opt, spec_sdpa);

// KV Concat Pipeline
std::vector<vk_specialization_type> spec_kv(2);
spec_kv[0].i = head_dim;
spec_kv[1].i = out_head_dim;

pipeline_sdpa_kv_concat = new Pipeline(vkdev);
pipeline_sdpa_kv_concat->set_local_size_xyz(64, 1, 1);
pipeline_sdpa_kv_concat->create(LayerShaderType::sdpa_kv_concat, opt, spec_kv);

return 0;
}

// Free both pipelines and reset the pointers to 0. Always returns 0.
// The Option argument is not used.
int SDPA_vulkan::destroy_pipeline(const Option&)
{
    // delete on a null pointer is a no-op, so this is safe if the pipelines
    // were never created.
    delete pipeline_sdpa;
    delete pipeline_sdpa_kv_concat;

    pipeline_sdpa = 0;
    pipeline_sdpa_kv_concat = 0;

    return 0;
}

// Fill in the dispatch extent for the attention kernel: w = tiles_q * 256,
// h = heads, c = 1. Always returns 0.
// The factor 256 equals the x local size given to pipeline_sdpa in
// create_pipeline() — presumably one workgroup per query tile; confirm against
// the shader.
static int sdpa_make_dispatcher(VkMat& dispatcher, int tiles_q, int heads)
{
    const int extent_x = tiles_q * 256;

    dispatcher.c = 1;
    dispatcher.h = heads;
    dispatcher.w = extent_x;

    return 0;
}

// sdpa_vulkan.cpp

int SDPA_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
{
if (bottom_blobs.size() < 3 || top_blobs.empty())
return -100;

// 0 query
// 1 cur_key
// 2 cur_value
// 3 mask (optional, if attn_mask=1)
// 3/4 past_key/value (optional, if kv_cache=1, and depends on attn_mask)

const VkMat& query = bottom_blobs[0];
const VkMat& cur_key = bottom_blobs[1];
const VkMat& cur_value = bottom_blobs[2];

// Mask (only valid if attn_mask flag is set)
VkMat mask;
if (attn_mask)
{
if ((int)bottom_blobs.size() < 4)
return -100;
mask = bottom_blobs[3];
}

// Past KV (only valid if kv_cache flag is set)
VkMat past_key;
VkMat past_value;
if (kv_cache)
{
const int pk_index = attn_mask ? 4 : 3;
const int pv_index = attn_mask ? 5 : 4;

if ((int)bottom_blobs.size() <= pv_index)
return -100;

past_key = bottom_blobs[pk_index];
past_value = bottom_blobs[pv_index];
}

VkMat key = cur_key;
VkMat value = cur_value;

// ---- KV cache concat path (only when kv_cache=1 and have non-empty past with seqlen>0) ----
const int d = query.w;
const int dv = cur_value.w;

if (d <= 0 || dv <= 0 || query.h <= 0 || query.c <= 0)
return -100;

// Only concat if past has actual length
const int past_seqlen = (kv_cache && !past_key.empty()) ? past_key.h : 0;
const int cur_seqlen = cur_key.h;

if (kv_cache && past_seqlen > 0)
{
const int num_group = cur_key.c; // expected groups for K/V

VkMat& out_key = top_blobs.size() >= 2 ? top_blobs[1] : *(VkMat*)0;
VkMat& out_value = top_blobs.size() >= 3 ? top_blobs[2] : *(VkMat*)0;

// kv_cache expects 3 outputs (top[0]=attn, top[1]=key_cache, top[2]=value_cache)
if ((int)top_blobs.size() < 3)
return -100;

out_key.create(d, past_seqlen + cur_seqlen, num_group, cur_key.elemsize, 1, opt.blob_vkallocator);
if (out_key.empty()) return -100;

out_value.create(dv, past_seqlen + cur_seqlen, num_group, cur_value.elemsize, 1, opt.blob_vkallocator);
if (out_value.empty()) return -100;

std::vector<VkMat> bindings_kv(6);
bindings_kv[0] = past_key;
bindings_kv[1] = past_value;
bindings_kv[2] = cur_key;
bindings_kv[3] = cur_value;
bindings_kv[4] = out_key;
bindings_kv[5] = out_value;

std::vector<vk_constant_type> constants_kv(11);
constants_kv[0].i = d;
constants_kv[1].i = dv;
constants_kv[2].i = past_seqlen;
constants_kv[3].i = cur_seqlen;
constants_kv[4].i = num_group;
constants_kv[5].i = past_key.cstep;
constants_kv[6].i = past_value.cstep;
constants_kv[7].i = cur_key.cstep;
constants_kv[8].i = cur_value.cstep;
constants_kv[9].i = out_key.cstep;
constants_kv[10].i = out_value.cstep;

VkMat dispatcher_kv;
const int dst_seqlen = past_seqlen + cur_seqlen;
const int maxw = d > dv ? d : dv;
dispatcher_kv.w = maxw;
dispatcher_kv.h = dst_seqlen;
dispatcher_kv.c = num_group;

cmd.record_pipeline(pipeline_sdpa_kv_concat, bindings_kv, constants_kv, dispatcher_kv);

key = out_key;
value = out_value;
}
else if (kv_cache)
{
// kv_cache enabled but no past: CPU behavior is to output current as cache
if ((int)top_blobs.size() < 3)
return -100;

top_blobs[1] = cur_key;
top_blobs[2] = cur_value;
}

// ---- Main SDPA path ----
const int src_seqlen = query.h;
const int num_heads = query.c;
const int dst_seqlen = key.h;

if (src_seqlen <= 0 || dst_seqlen <= 0 || num_heads <= 0)
return -100;

int num_heads_per_group = 1;
if (key.c > 0 && num_heads % key.c == 0)
num_heads_per_group = num_heads / key.c;

VkMat& top_blob = top_blobs[0];
top_blob.create(value.w, src_seqlen, num_heads, query.elemsize, 1, opt.blob_vkallocator);
if (top_blob.empty()) return -100;

// Mask info (keep your existing logic)
int mask_dims = 0;
int mask_w = 0;
int mask_c = 0;
int mask_cstep = 0;
if (!mask.empty())
{
mask_dims = mask.dims;
mask_w = mask.w;
mask_c = mask.c;
mask_cstep = mask.cstep;
if (mask_dims != 2 && mask_dims != 3)
{
mask_dims = 0;
mask_w = 0;
mask_c = 0;
mask_cstep = 0;
}
}

float final_scale = this->scale;
if (final_scale == 0.f)
final_scale = 1.0f / std::sqrt((float)d);

// Strides (keep your existing usage)
const int qw = query.w;
const int kw = key.w;
const int vw = value.w;
const int ow = top_blob.w;

std::vector<VkMat> bindings(5);
bindings[0] = query;
bindings[1] = key;
bindings[2] = value;
bindings[3] = mask;
bindings[4] = top_blob;

std::vector<vk_constant_type> constants(18);
constants[0].i = d;
constants[1].i = value.w; // dv
constants[2].i = src_seqlen;
constants[3].i = dst_seqlen;
constants[4].i = num_heads_per_group;
constants[5].i = query.cstep;
constants[6].i = key.cstep;
constants[7].i = value.cstep;
constants[8].i = top_blob.cstep;
constants[9].i = mask_dims;
constants[10].i = mask_w;
constants[11].i = mask_c;
constants[12].i = mask_cstep;
constants[13].f = final_scale;
constants[14].i = qw;
constants[15].i = kw;
constants[16].i = vw;
constants[17].i = ow;

VkMat dispatcher;
const int tiles_q = (src_seqlen + 16 - 1) / 16;
sdpa_make_dispatcher(dispatcher, tiles_q, num_heads);

cmd.record_pipeline(pipeline_sdpa, bindings, constants, dispatcher);

// If we concatenated, make sure outputs[1/2] are set to the concatenated cache
if (kv_cache && past_seqlen > 0)
{
top_blobs[1] = key;
top_blobs[2] = value;
}

return 0;
}

} // namespace ncnn
30 changes: 30 additions & 0 deletions src/layer/vulkan/sdpa_vulkan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2026 Futz12 <pchar.cn>
// SPDX-License-Identifier: BSD-3-Clause

#ifndef LAYER_SDPA_VULKAN_H
#define LAYER_SDPA_VULKAN_H

#include "sdpa.h"

namespace ncnn {

// Vulkan implementation of the SDPA layer. Derives from SDPA and adds the
// compute pipelines used by the GPU forward pass.
class SDPA_vulkan : public SDPA
{
public:
SDPA_vulkan();

// Allocate and create pipeline_sdpa and pipeline_sdpa_kv_concat.
virtual int create_pipeline(const Option& opt);
// Delete both pipelines and reset the pointers to 0.
virtual int destroy_pipeline(const Option& opt);
// Load parameters via SDPA::load_param; clears support_vulkan when
// int8_scale_term is set.
virtual int load_param(const ParamDict& pd);

// Keep the base-class forward overloads visible; declaring the VkMat
// overload below would otherwise hide them.
using SDPA::forward;
// Record the attention (and, with kv_cache, the cache concat) commands into cmd.
// Returns 0 on success, -100 on invalid input or failed allocation.
virtual int forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const;

public:
// Attention pipeline; 0 until create_pipeline() runs.
Pipeline* pipeline_sdpa;
// KV-cache concatenation pipeline; 0 until create_pipeline() runs.
Pipeline* pipeline_sdpa_kv_concat;
};

} // namespace ncnn

#endif // LAYER_SDPA_VULKAN_H
Loading