diff --git a/src/layer/vulkan/sdpa_vulkan.cpp b/src/layer/vulkan/sdpa_vulkan.cpp
new file mode 100644
index 000000000000..9a0ec491f51f
--- /dev/null
+++ b/src/layer/vulkan/sdpa_vulkan.cpp
@@ -0,0 +1,289 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "sdpa_vulkan.h"
+#include "layer_shader_type.h"
+#include <cmath> // for sqrt
+
+namespace ncnn {
+
+SDPA_vulkan::SDPA_vulkan()
+{
+    support_vulkan = true;
+    support_vulkan_packing = false;
+    support_vulkan_any_packing = false;
+
+    pipeline_sdpa = 0;
+    pipeline_sdpa_kv_concat = 0;
+}
+
+int SDPA_vulkan::load_param(const ParamDict& pd)
+{
+    int ret = SDPA::load_param(pd);
+
+    if (int8_scale_term)
+    {
+        support_vulkan = false;
+    }
+
+    return ret;
+}
+
+int SDPA_vulkan::create_pipeline(const Option& opt)
+{
+    const Mat& qshape = bottom_shapes.empty() ? Mat() : bottom_shapes[0];
+    const Mat& vshape = bottom_shapes.size() > 2 ? bottom_shapes[2] : Mat();
+
+    int head_dim = 0;
+    int out_head_dim = 0;
+
+    if (qshape.dims == 3) head_dim = qshape.w;
+    if (vshape.dims == 3) out_head_dim = vshape.w;
+
+    // SDPA pipeline
+    // Spec constants: 0=head_dim, 1=out_head_dim.
+    // Scale is not a spec constant; it is passed dynamically via push constants.
+    std::vector<vk_specialization_type> spec_sdpa(2);
+    spec_sdpa[0].i = head_dim;
+    spec_sdpa[1].i = out_head_dim;
+
+    pipeline_sdpa = new Pipeline(vkdev);
+    pipeline_sdpa->set_local_size_xyz(256, 1, 1);
+    pipeline_sdpa->create(LayerShaderType::sdpa, opt, spec_sdpa);
+
+    // KV concat pipeline
+    std::vector<vk_specialization_type> spec_kv(2);
+    spec_kv[0].i = head_dim;
+    spec_kv[1].i = out_head_dim;
+
+    pipeline_sdpa_kv_concat = new Pipeline(vkdev);
+    pipeline_sdpa_kv_concat->set_local_size_xyz(64, 1, 1);
+    pipeline_sdpa_kv_concat->create(LayerShaderType::sdpa_kv_concat, opt, spec_kv);
+
+    return 0;
+}
+
+int SDPA_vulkan::destroy_pipeline(const Option& /*opt*/)
+{
+    delete pipeline_sdpa;
+    pipeline_sdpa = 0;
+
+    delete pipeline_sdpa_kv_concat;
+    pipeline_sdpa_kv_concat = 0;
+
+    return 0;
+}
+
+static int sdpa_make_dispatcher(VkMat& dispatcher, int tiles_q, int heads)
+{
+    // one 256-invocation workgroup per query tile, one row of workgroups per head
+    dispatcher.w = tiles_q * 256;
+    dispatcher.h = heads;
+    dispatcher.c = 1;
+    return 0;
+}
+
+int SDPA_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
+{
+    if (bottom_blobs.size() < 3 || top_blobs.empty())
+        return -100;
+
+    // 0   query
+    // 1   cur_key
+    // 2   cur_value
+    // 3   mask (optional, if attn_mask=1)
+    // 3/4 past_key/value (optional, if kv_cache=1, index depends on attn_mask)
+
+    const VkMat& query = bottom_blobs[0];
+    const VkMat& cur_key = bottom_blobs[1];
+    const VkMat& cur_value = bottom_blobs[2];
+
+    // Mask (only valid if attn_mask flag is set)
+    VkMat mask;
+    if (attn_mask)
+    {
+        if ((int)bottom_blobs.size() < 4)
+            return -100;
+        mask = bottom_blobs[3];
+    }
+
+    // Past KV (only valid if kv_cache flag is set)
+    VkMat past_key;
+    VkMat past_value;
+    if (kv_cache)
+    {
+        const int pk_index = attn_mask ? 4 : 3;
+        const int pv_index = attn_mask ? 5 : 4;
+
+        if ((int)bottom_blobs.size() <= pv_index)
+            return -100;
+
+        past_key = bottom_blobs[pk_index];
+        past_value = bottom_blobs[pv_index];
+    }
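+
+    // Blob layout assumed throughout (matching how the blobs are created and indexed below):
+    //   query:  w = head_dim,     h = src_seqlen, c = num_heads
+    //   key:    w = head_dim,     h = dst_seqlen, c = num_kv_groups
+    //   value:  w = out_head_dim, h = dst_seqlen, c = num_kv_groups
+    //   output: w = out_head_dim, h = src_seqlen, c = num_heads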
+
+    VkMat key = cur_key;
+    VkMat value = cur_value;
+
+    // ---- KV cache concat path (only when kv_cache=1 and a non-empty past with seqlen > 0) ----
+    const int d = query.w;
+    const int dv = cur_value.w;
+
+    if (d <= 0 || dv <= 0 || query.h <= 0 || query.c <= 0)
+        return -100;
+
+    // Only concat if past has actual length
+    const int past_seqlen = (kv_cache && !past_key.empty()) ? past_key.h : 0;
+    const int cur_seqlen = cur_key.h;
+
+    if (kv_cache && past_seqlen > 0)
+    {
+        // kv_cache expects 3 outputs (top[0]=attn, top[1]=key_cache, top[2]=value_cache)
+        if ((int)top_blobs.size() < 3)
+            return -100;
+
+        const int num_group = cur_key.c; // expected groups for K/V
+
+        VkMat& out_key = top_blobs[1];
+        VkMat& out_value = top_blobs[2];
+
+        out_key.create(d, past_seqlen + cur_seqlen, num_group, cur_key.elemsize, 1, opt.blob_vkallocator);
+        if (out_key.empty()) return -100;
+
+        out_value.create(dv, past_seqlen + cur_seqlen, num_group, cur_value.elemsize, 1, opt.blob_vkallocator);
+        if (out_value.empty()) return -100;
+
+        std::vector<VkMat> bindings_kv(6);
+        bindings_kv[0] = past_key;
+        bindings_kv[1] = past_value;
+        bindings_kv[2] = cur_key;
+        bindings_kv[3] = cur_value;
+        bindings_kv[4] = out_key;
+        bindings_kv[5] = out_value;
+
+        std::vector<vk_constant_type> constants_kv(11);
+        constants_kv[0].i = d;
+        constants_kv[1].i = dv;
+        constants_kv[2].i = past_seqlen;
+        constants_kv[3].i = cur_seqlen;
+        constants_kv[4].i = num_group;
+        constants_kv[5].i = past_key.cstep;
+        constants_kv[6].i = past_value.cstep;
+        constants_kv[7].i = cur_key.cstep;
+        constants_kv[8].i = cur_value.cstep;
+        constants_kv[9].i = out_key.cstep;
+        constants_kv[10].i = out_value.cstep;
+
+        VkMat dispatcher_kv;
+        const int dst_seqlen = past_seqlen + cur_seqlen;
+        const int maxw = d > dv ? d : dv;
+        dispatcher_kv.w = maxw;
+        dispatcher_kv.h = dst_seqlen;
+        dispatcher_kv.c = num_group;
+
+        cmd.record_pipeline(pipeline_sdpa_kv_concat, bindings_kv, constants_kv, dispatcher_kv);
+
+        key = out_key;
+        value = out_value;
+    }
+    else if (kv_cache)
+    {
+        // kv_cache enabled but no past: CPU behavior is to output current as cache
+        if ((int)top_blobs.size() < 3)
+            return -100;
+
+        top_blobs[1] = cur_key;
+        top_blobs[2] = cur_value;
+    }
+
+    // ---- Main SDPA path ----
+    const int src_seqlen = query.h;
+    const int num_heads = query.c;
+    const int dst_seqlen = key.h;
+
+    if (src_seqlen <= 0 || dst_seqlen <= 0 || num_heads <= 0)
+        return -100;
+
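+    // Grouped-query / multi-query attention: the query blob carries num_heads channels
+    // while key/value carry key.c KV groups; each group serves num_heads_per_group
+    // query heads, and the shader maps head -> group via head / num_heads_per_group.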
+    int num_heads_per_group = 1;
+    if (key.c > 0 && num_heads % key.c == 0)
+        num_heads_per_group = num_heads / key.c;
+
+    VkMat& top_blob = top_blobs[0];
+    top_blob.create(value.w, src_seqlen, num_heads, query.elemsize, 1, opt.blob_vkallocator);
+    if (top_blob.empty()) return -100;
+
+    // Mask info
+    int mask_dims = 0;
+    int mask_w = 0;
+    int mask_c = 0;
+    int mask_cstep = 0;
+    if (!mask.empty())
+    {
+        mask_dims = mask.dims;
+        mask_w = mask.w;
+        mask_c = mask.c;
+        mask_cstep = mask.cstep;
+        if (mask_dims != 2 && mask_dims != 3)
+        {
+            mask_dims = 0;
+            mask_w = 0;
+            mask_c = 0;
+            mask_cstep = 0;
+        }
+    }
+
+    // scale == 0 means "use the default 1/sqrt(head_dim)"
+    float final_scale = this->scale;
+    if (final_scale == 0.f)
+        final_scale = 1.0f / std::sqrt((float)d);
+
+    // Strides passed to the shader
+    const int qw = query.w;
+    const int kw = key.w;
+    const int vw = value.w;
+    const int ow = top_blob.w;
+
+    std::vector<VkMat> bindings(5);
+    bindings[0] = query;
+    bindings[1] = key;
+    bindings[2] = value;
+    bindings[3] = mask;
+    bindings[4] = top_blob;
+
+    std::vector<vk_constant_type> constants(18);
+    constants[0].i = d;
+    constants[1].i = value.w; // dv
+    constants[2].i = src_seqlen;
+    constants[3].i = dst_seqlen;
+    constants[4].i = num_heads_per_group;
+    constants[5].i = query.cstep;
+    constants[6].i = key.cstep;
+    constants[7].i = value.cstep;
+    constants[8].i = top_blob.cstep;
+    constants[9].i = mask_dims;
+    constants[10].i = mask_w;
+    constants[11].i = mask_c;
+    constants[12].i = mask_cstep;
+    constants[13].f = final_scale;
+    constants[14].i = qw;
+    constants[15].i = kw;
+    constants[16].i = vw;
+    constants[17].i = ow;
+
+    VkMat dispatcher;
+    const int tiles_q = (src_seqlen + 16 - 1) / 16; // br = 16 query rows per workgroup (must match sdpa.comp)
+    sdpa_make_dispatcher(dispatcher, tiles_q, num_heads);
+
+    cmd.record_pipeline(pipeline_sdpa, bindings, constants, dispatcher);
+
+    // If we concatenated, make sure outputs[1/2] are set to the concatenated cache
+    if (kv_cache && past_seqlen > 0)
+    {
+        top_blobs[1] = key;
+        top_blobs[2] = value;
+    }
+
+    return 0;
+}
+
+} // namespace ncnn
\ No newline at end of file
diff --git a/src/layer/vulkan/sdpa_vulkan.h b/src/layer/vulkan/sdpa_vulkan.h
new file mode 100644
index 000000000000..b818d72444bd
--- /dev/null
+++ b/src/layer/vulkan/sdpa_vulkan.h
@@ -0,0 +1,30 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#ifndef LAYER_SDPA_VULKAN_H
+#define LAYER_SDPA_VULKAN_H
+
+#include "sdpa.h"
+
+namespace ncnn {
+
+class SDPA_vulkan : public SDPA
+{
+public:
+    SDPA_vulkan();
+
+    virtual int create_pipeline(const Option& opt);
+    virtual int destroy_pipeline(const Option& opt);
+    virtual int load_param(const ParamDict& pd);
+
+    using SDPA::forward;
+    virtual int forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const;
+
+public:
+    Pipeline* pipeline_sdpa;
+    Pipeline* pipeline_sdpa_kv_concat;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_SDPA_VULKAN_H
\ No newline at end of file
diff --git a/src/layer/vulkan/shader/sdpa.comp b/src/layer/vulkan/shader/sdpa.comp
new file mode 100644
index 000000000000..1ad7f1616d92
--- /dev/null
+++ b/src/layer/vulkan/shader/sdpa.comp
@@ -0,0 +1,309 @@
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+layout(constant_id = 0) const int head_dim = 0;
+layout(constant_id = 1) const int out_head_dim = 0;
+
+layout(binding = 0) readonly buffer query_blob { sfp query_data[]; };
+layout(binding = 1) readonly buffer key_blob { sfp key_data[]; };
+layout(binding = 2) readonly buffer value_blob { sfp value_data[]; };
+layout(binding = 3) readonly buffer mask_blob { sfp mask_data[]; };
+layout(binding = 4) writeonly buffer top_blob { sfp top_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int head_dim;
+    int out_head_dim;
+    int src_seqlen;
+    int dst_seqlen;
+    int num_heads_per_group;
+    int q_cstep;
+    int k_cstep;
+    int v_cstep;
+    int o_cstep;
+    int mask_dims;
+    int mask_w;
+    int mask_c;
+    int mask_cstep;
+    float qk_scale; // Pre-calculated scale
+    int qw; // Pre-calculated strides
+    int kw;
+    int vw;
+    int ow;
+} p;
+
+#define br 16
+#define bc 16
+#define max_d 128
+#define max_dv 128
+
+shared afp s_q[br][max_d];
+shared afp s_k[bc][max_d];
+shared afp s_v[bc][max_dv];
+shared afp s_o[br][max_dv];
+shared afp s_w[bc][br]; // softmax weight
+shared afp s_row_max[br];
+shared afp s_row_sum[br];
+
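+// One workgroup of 256 invocations handles a tile of br = 16 query rows for one head.
+// For each query row i it computes
+//   out[i] = sum_j( softmax_j( dot(Q[i], K[j]) * qk_scale + mask[i][j] ) * V[j] )
+// in two passes over the KV sequence, bc = 16 keys per tile:
+//   pass 1 scans all tiles and keeps only the per-row maximum logit;
+//   pass 2 recomputes the logits, accumulates exp(logit - row_max) into s_row_sum
+//   and the weighted V into s_o, then the store step divides by the row sum.
+// Only the per-row max/sum survive between passes, so K tiles are re-read in pass 2
+// instead of being cached in shared memory.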
+void main()
+{
+    const int lid = int(gl_LocalInvocationIndex);
+    const int block_q = int(gl_WorkGroupID.x);
+    const int head = int(gl_WorkGroupID.y);
+
+    const int d = psc(head_dim);
+    const int dv = psc(out_head_dim);
+
+    if (d <= 0 || dv <= 0 || d > max_d || dv > max_dv)
+        return;
+
+    const int qw = p.qw;
+    const int kw = p.kw;
+    const int vw = p.vw;
+    const int ow = p.ow;
+    const afp qk_scale = afp(p.qk_scale);
+
+    const int q_base = block_q * br;
+    const int group = head / p.num_heads_per_group;
+
+    // 16x16 workgroup
+    const int lx = lid & 15;
+    const int ly = lid >> 4;
+
+    // load Q
+    for (int idx = lid; idx < br * d; idx += 256)
+    {
+        const int y = idx / d;
+        const int x = idx - y * d;
+        const int qi = q_base + y;
+
+        afp v = afp(0.f); // explicit afp() so the literal is also valid when afp is float16_t
+        if (qi < p.src_seqlen)
+        {
+            const int off = head * p.q_cstep + qi * qw + x;
+            v = buffer_ld1(query_data, off);
+        }
+        s_q[y][x] = v;
+    }
+
+    // init O
+    for (int idx = lid; idx < br * dv; idx += 256)
+    {
+        const int y = idx / dv;
+        const int x = idx - y * dv;
+        s_o[y][x] = afp(0.f);
+    }
+
+    // init row_max
+    if (lid < br)
+        s_row_max[lid] = afp(-3.402823466e+38f);
+
+    barrier();
+
+    const int kv_iters = (p.dst_seqlen + bc - 1) / bc;
+
+    for (int it = 0; it < kv_iters; it++)
+    {
+        const int k_base = it * bc;
+
+        // load K
+        for (int idx = lid; idx < bc * d; idx += 256)
+        {
+            const int y = idx / d;
+            const int x = idx - y * d;
+            const int kj = k_base + y;
+
+            afp v = afp(0.f);
+            if (kj < p.dst_seqlen)
+            {
+                const int off = group * p.k_cstep + kj * kw + x;
+                v = buffer_ld1(key_data, off);
+            }
+            s_k[y][x] = v;
+        }
+
+        barrier();
+
+        afp qk = afp(-3.402823466e+38f);
+        const int qi = q_base + ly;
+        const int kj = k_base + lx;
+
+        if (ly < br && lx < bc && qi < p.src_seqlen && kj < p.dst_seqlen)
+        {
+            afp acc = afp(0.f);
+            for (int k = 0; k < d; k++)
+                acc += s_q[ly][k] * s_k[lx][k];
+
+            // Use pre-calculated qk_scale
+            afp v = acc * qk_scale;
+
+            if (p.mask_dims != 0)
+            {
+                const int mh = (p.mask_c > 1) ? head : 0;
+                afp mv = afp(0.f);
+
+                if (p.mask_dims == 2)
+                {
+                    const int moff = qi * p.mask_w + kj;
+                    mv = buffer_ld1(mask_data, moff);
+                }
+                else if (p.mask_dims == 3)
+                {
+                    const int moff = mh * p.mask_cstep + qi * p.mask_w + kj;
+                    mv = buffer_ld1(mask_data, moff);
+                }
+
+                v += mv;
+            }
+
+            qk = v;
+        }
+
+        if (ly < br && lx < bc)
+            s_w[lx][ly] = qk;
+
+        barrier();
+
+        // reduce tile max
+        if (lx == 0 && ly < br)
+        {
+            afp tile_max = afp(-3.402823466e+38f);
+            for (int x = 0; x < bc; x++)
+                tile_max = max(tile_max, s_w[x][ly]);
+
+            s_row_max[ly] = max(s_row_max[ly], tile_max);
+        }
+
+        barrier();
+    }
+
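+    // The second pass recomputes each logit and takes exp(logit - row_max).
+    // Subtracting the per-row maximum keeps the exponent <= 0, which avoids overflow
+    // when afp is a half-precision type; dividing by s_row_sum at the end completes the softmax.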
+    if (lid < br)
+        s_row_sum[lid] = afp(0.f);
+
+    barrier();
+
+    // Re-compute weights based on global max
+    for (int it = 0; it < kv_iters; it++)
+    {
+        const int k_base = it * bc;
+
+        // load K (again) - can't avoid easily without large shared mem
+        for (int idx = lid; idx < bc * d; idx += 256)
+        {
+            const int y = idx / d;
+            const int x = idx - y * d;
+            const int kj = k_base + y;
+
+            afp v = afp(0.f);
+            if (kj < p.dst_seqlen)
+            {
+                const int off = group * p.k_cstep + kj * kw + x;
+                v = buffer_ld1(key_data, off);
+            }
+            s_k[y][x] = v;
+        }
+
+        // load V
+        for (int idx = lid; idx < bc * dv; idx += 256)
+        {
+            const int y = idx / dv;
+            const int x = idx - y * dv;
+            const int vj = k_base + y;
+
+            afp v = afp(0.f);
+            if (vj < p.dst_seqlen)
+            {
+                const int off = group * p.v_cstep + vj * vw + x;
+                v = buffer_ld1(value_data, off);
+            }
+            s_v[y][x] = v;
+        }
+
+        barrier();
+
+        // compute weight = exp(qk - row_max)
+        afp w = afp(0.f);
+        const int qi = q_base + ly;
+        const int kj = k_base + lx;
+
+        if (ly < br && lx < bc && qi < p.src_seqlen && kj < p.dst_seqlen)
+        {
+            afp acc = afp(0.f);
+            for (int k = 0; k < d; k++)
+                acc += s_q[ly][k] * s_k[lx][k];
+
+            afp v = acc * qk_scale;
+
+            if (p.mask_dims != 0)
+            {
+                const int mh = (p.mask_c > 1) ? head : 0;
+                afp mv = afp(0.f);
+
+                if (p.mask_dims == 2)
+                {
+                    const int moff = qi * p.mask_w + kj;
+                    mv = buffer_ld1(mask_data, moff);
+                }
+                else if (p.mask_dims == 3)
+                {
+                    const int moff = mh * p.mask_cstep + qi * p.mask_w + kj;
+                    mv = buffer_ld1(mask_data, moff);
+                }
+                v += mv;
+            }
+
+            w = exp(v - s_row_max[ly]);
+        }
+
+        if (ly < br && lx < bc)
+            s_w[lx][ly] = w;
+
+        barrier();
+
+        // reduce row_sum
+        if (lx == 0 && ly < br)
+        {
+            afp tile_sum = afp(0.f);
+            for (int x = 0; x < bc; x++)
+                tile_sum += s_w[x][ly];
+
+            s_row_sum[ly] += tile_sum;
+        }
+
+        barrier();
+
+        // accumulate O
+        for (int idx = lid; idx < br * dv; idx += 256)
+        {
+            const int y = idx / dv;
+            const int x = idx - y * dv;
+
+            afp addv = afp(0.f);
+            for (int t = 0; t < bc; t++)
+                addv += s_w[t][y] * s_v[t][x];
+
+            s_o[y][x] += addv;
+        }
+
+        barrier();
+    }
+
+    // store O
+    for (int idx = lid; idx < br * dv; idx += 256)
+    {
+        const int y = idx / dv;
+        const int x = idx - y * dv;
+        const int qi = q_base + y;
+
+        if (qi >= p.src_seqlen)
+            continue;
+
+        const afp sum = s_row_sum[y];
+        const afp ov = (sum > afp(0.f)) ? (s_o[y][x] / sum) : afp(0.f);
+
+        const int out_off = head * p.o_cstep + qi * ow + x;
+        buffer_st1(top_data, out_off, afp(ov));
+    }
+}
\ No newline at end of file
diff --git a/src/layer/vulkan/shader/sdpa_kv_concat.comp b/src/layer/vulkan/shader/sdpa_kv_concat.comp
new file mode 100644
index 000000000000..9ac42a7caa2d
--- /dev/null
+++ b/src/layer/vulkan/shader/sdpa_kv_concat.comp
@@ -0,0 +1,81 @@
+// sdpa_kv_concat.comp
+// Copyright 2026 Futz12
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+layout(constant_id = 0) const int head_dim = 0;
+layout(constant_id = 1) const int out_head_dim = 0;
+
+layout(binding = 0) readonly buffer past_key_blob { sfp past_key_data[]; };
+layout(binding = 1) readonly buffer past_value_blob { sfp past_value_data[]; };
+layout(binding = 2) readonly buffer cur_key_blob { sfp cur_key_data[]; };
+layout(binding = 3) readonly buffer cur_value_blob { sfp cur_value_data[]; };
+layout(binding = 4) writeonly buffer out_key_blob { sfp out_key_data[]; };
+layout(binding = 5) writeonly buffer out_value_blob { sfp out_value_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int head_dim;
+    int out_head_dim;
+    int past_seqlen;
+    int cur_seqlen;
+    int num_group;
+    int pastk_cstep;
+    int pastv_cstep;
+    int curk_cstep;
+    int curv_cstep;
+    int outk_cstep;
+    int outv_cstep;
+} p;
+
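+// Dispatched over (max(head_dim, out_head_dim), past_seqlen + cur_seqlen, num_group):
+// each invocation copies one K element (while gx < head_dim) and one V element
+// (while gx < out_head_dim) of row gy in group gz, reading from the past cache for
+// gy < past_seqlen and from the current K/V otherwise.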
+void main()
+{
+    const int gx = int(gl_GlobalInvocationID.x);
+    const int gy = int(gl_GlobalInvocationID.y);
+    const int gz = int(gl_GlobalInvocationID.z);
+
+    const int d = psc(head_dim);
+    const int dv = psc(out_head_dim);
+    const int dst_seqlen = p.past_seqlen + p.cur_seqlen;
+
+    if (gz >= p.num_group || gy >= dst_seqlen)
+        return;
+
+    const int maxw = d > dv ? d : dv;
+    if (gx >= maxw)
+        return;
+
+    const bool from_past = gy < p.past_seqlen;
+    const int sy = from_past ? gy : (gy - p.past_seqlen);
+
+    if (gx < d)
+    {
+        const int src_off = from_past ? (gz * p.pastk_cstep + sy * d + gx) : (gz * p.curk_cstep + sy * d + gx);
+        const int dst_off = gz * p.outk_cstep + gy * d + gx;
+
+        if (from_past)
+        {
+            buffer_cp1(out_key_data, dst_off, past_key_data, src_off);
+        }
+        else
+        {
+            buffer_cp1(out_key_data, dst_off, cur_key_data, src_off);
+        }
+    }
+
+    if (gx < dv)
+    {
+        const int src_off = from_past ? (gz * p.pastv_cstep + sy * dv + gx) : (gz * p.curv_cstep + sy * dv + gx);
+        const int dst_off = gz * p.outv_cstep + gy * dv + gx;
+
+        if (from_past)
+        {
+            buffer_cp1(out_value_data, dst_off, past_value_data, src_off);
+        }
+        else
+        {
+            buffer_cp1(out_value_data, dst_off, cur_value_data, src_off);
+        }
+    }
+}