Changes from 2 commits
39 changes: 39 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -701,6 +701,7 @@ struct vk_device_struct {
    vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
    vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
    vk_pipeline pipeline_get_rel_pos_f32, pipeline_get_rel_pos_f16;
    vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
    vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
    vk_pipeline pipeline_sum_rows_f32;
@@ -3933,6 +3934,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
    }

    ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f32, "get_rel_pos_f32", get_rel_pos_f32_len, get_rel_pos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, { 512 }, 1);
    ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f16, "get_rel_pos_f16", get_rel_pos_f16_len, get_rel_pos_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, { 512 }, 1);

    for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
        uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
        if (i <= device->max_workgroup_size_log2 &&
@@ -10019,6 +10023,32 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
        ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
}

static void ggml_vk_get_rel_pos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    vk_pipeline pipeline = nullptr;
    switch (src0->type) {
        case GGML_TYPE_F32: pipeline = ctx->device->pipeline_get_rel_pos_f32; break;
        case GGML_TYPE_F16: pipeline = ctx->device->pipeline_get_rel_pos_f16; break;
        default: GGML_ABORT("fatal error");
    }
    GGML_ASSERT(pipeline != nullptr);

    vk_op_unary_push_constants pc = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
    init_pushconst_fastdiv(pc);

    std::array<uint32_t, 3> elements;

Collaborator: Any reason not to use ggml_vk_op_f32?

Author: It's because incontiguous input is not supported; the GGML_ASSERT check in ggml_vk_op_f32 will fail. I think the behavior should be similar to GGML_OP_ARGSORT (though I'm not sure why GGML_OP_ARGSORT appears in ggml_vk_op_f32 but comes with a GGML_ASSERT(0)).

Collaborator: ARGSORT used to use ggml_vk_op_f32; it switched away from it because it needs custom logic and shader invocations to handle large input tensors.

I think you should be able to use ggml_vk_op_f32 for this op. Can you be more specific about which assertion failed?

Author: Hi @0cc4m, thanks for the clarification.

Previously I encountered an assertion failure for the case GET_REL_POS(type=f32, C=1, qh=1, kh=1, v=1) in ggml_vk_op_f32, specifically on this line:

ggml/src/ggml-vulkan/ggml-vulkan.cpp:8764: GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)) failed

I reviewed the whole flow and realized that ggml_is_contiguous returned true here (src0 ne={1, 1, 1, 1}, viewed from ne={2, 1, 1, 1}). Therefore ggml_backend_vk_device_supports_op(...) returned true, and the later ggml_vk_dim01_contiguous(src0) check in ggml_vk_op_f32 failed.

I have updated ggml_vk_get_rel_pos to use ggml_vk_op_f32, and added a ggml_vk_dim01_contiguous check to ggml_backend_vk_device_supports_op, so GET_REL_POS(type=f32, C=1, qh=1, kh=1, v=1) is no longer reported as supported.

To work with incontiguous inputs (v=1), we can simply use ggml_cont.
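
A minimal sketch of that workaround at graph-build time (hedged; ctx and rel_pos are illustrative names, not from this PR, while ggml_cont and ggml_get_rel_pos are the existing ggml API calls):

    // make a contiguous copy first so the backend's contiguity check passes
    struct ggml_tensor * rel_pos_cont = ggml_cont(ctx, rel_pos);
    struct ggml_tensor * out = ggml_get_rel_pos(ctx, rel_pos_cont, qh, kh);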

    uint32_t ne = ggml_nelements(dst);
    // split the flat element count over up to three dispatch dimensions,
    // capping x and y at 512 (512 * 512 = 262144) for large tensors
    if (ne > 262144) {
        elements = { 512, 512, CEIL_DIV(ne, 262144) };
    } else if (ne > 512) {
        elements = { 512, CEIL_DIV(ne, 512), 1 };
    } else {
        elements = { ne, 1, 1 };
    }

    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { ggml_vk_tensor_subbuffer(ctx, src0), ggml_vk_tensor_subbuffer(ctx, dst) }, pc, elements);
}

static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    const uint32_t * op_params = (const uint32_t *)dst->op_params;

@@ -11488,6 +11518,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_ROPE_BACK:
    case GGML_OP_MUL_MAT:
    case GGML_OP_MUL_MAT_ID:
    case GGML_OP_GET_REL_POS:
    case GGML_OP_ARGSORT:
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
@@ -11817,6 +11848,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_ROPE_BACK:
        ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);

        break;
    case GGML_OP_GET_REL_POS:
        ggml_vk_get_rel_pos(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_ARGSORT:
        if (ctx->num_additional_fused_ops) {
@@ -12006,6 +12041,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
    case GGML_OP_SOFT_MAX_BACK:
    case GGML_OP_ROPE:
    case GGML_OP_ROPE_BACK:
    case GGML_OP_GET_REL_POS:
    case GGML_OP_RESHAPE:
    case GGML_OP_VIEW:
    case GGML_OP_PERMUTE:
@@ -13964,6 +14000,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            return op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_LOG:
            return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
        case GGML_OP_GET_REL_POS:
            return ggml_is_contiguous(op->src[0]) &&
                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
        case GGML_OP_ARGSORT:
            {
                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
35 changes: 35 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/get_rel_pos.comp
@@ -0,0 +1,35 @@
#version 450

#include "types.glsl"
#include "generic_unary_head.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
    const uint idx = get_idx();
    if (idx >= p.ne) {
        return;
    }

    // unravel the flat index into (i13, i12, i11, i10) coordinates using the
    // fastdiv magic constants precomputed by init_pushconst_fastdiv
    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
    const uint i12_offset = i12*p.ne11*p.ne10;
    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;

    const float kh = float(p.ne11);
    const float qh = float(p.ne12);
    const float k_scale = max(qh / kh, 1.0f);
    const float q_scale = max(kh / qh, 1.0f);
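    // note (not from the PR): this appears to be the decomposed relative-position
    // lookup used by SAM-style attention, with i12 as the query index and i11 as
    // the key index:
    //   pos = i12 * max(kh/qh, 1) - i11 * max(qh/kh, 1) + (kh - 1) * max(qh/kh, 1)
    // e.g. qh = 2, kh = 4: q_scale = 2, k_scale = 1, offset = 3, so (i12=1, i11=0) gives pos = 5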

    // Add a small epsilon to avoid floating point precision issues
    const float epsilon = 0.0001f;
    const int pos = int(float(i12) * q_scale - float(i11) * k_scale + (kh - 1.0f) * k_scale + epsilon);

    const uint src_idx = pos*p.nb01 + i10*p.nb00;
    const uint dst_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;

    data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + src_idx]);
}
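
For reference, a host-side sketch of the shader's index computation (a hypothetical standalone helper for sanity-checking the math, not part of this PR):

    #include <algorithm>
    #include <cassert>

    // mirrors the shader: pos = i12 * q_scale - i11 * k_scale + (kh - 1) * k_scale + eps
    static int get_rel_pos_index(int i12, int i11, int qh, int kh) {
        const float k_scale = std::max((float) qh / (float) kh, 1.0f);
        const float q_scale = std::max((float) kh / (float) qh, 1.0f);
        return (int) ((float) i12 * q_scale - (float) i11 * k_scale + ((float) kh - 1.0f) * k_scale + 0.0001f);
    }

    int main() {
        // qh == kh == 4: plain q - k offset shifted into the range [0, 2*kh - 2]
        assert(get_rel_pos_index(0, 0, 4, 4) == 3); // q=0, k=0 -> kh - 1
        assert(get_rel_pos_index(3, 0, 4, 4) == 6); // largest positive offset
        assert(get_rel_pos_index(0, 3, 4, 4) == 0); // largest negative offset
        return 0;
    }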

3 changes: 3 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -910,6 +910,9 @@ void process_shaders() {
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});

string_to_spv("get_rel_pos_f32", "get_rel_pos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("get_rel_pos_f16", "get_rel_pos.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
