Skip to content

Commit 4d75c05

Browse files
committed
wip
1 parent 9613d65 commit 4d75c05

File tree

9 files changed

+199
-28
lines changed

9 files changed

+199
-28
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7888,7 +7888,7 @@ static void ggml_compute_forward_top_k_f32(
78887888

78897889
const int64_t nr = ggml_nrows(src0);
78907890

7891-
const int k = ggml_get_op_params_i32(dst, 0);
7891+
const int top_k = ne0;
78927892

78937893
int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
78947894

@@ -7899,11 +7899,11 @@ static void ggml_compute_forward_top_k_f32(
78997899
tmp[j] = j;
79007900
}
79017901

7902-
std::partial_sort(tmp, tmp + k, tmp + ne00, cmp_top_k{src_data});
7902+
std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
79037903

79047904
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
79057905

7906-
std::copy(tmp, tmp + k, dst_data);
7906+
std::copy(tmp, tmp + top_k, dst_data);
79077907
}
79087908
}
79097909

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
905905
case GGML_OP_LEAKY_RELU:
906906
return op->src[0]->type == GGML_TYPE_F32;
907907
case GGML_OP_ARGSORT:
908+
case GGML_OP_TOP_K:
908909
case GGML_OP_ARANGE:
909910
return true;
910911
case GGML_OP_FLASH_ATTN_EXT:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,7 @@ typedef struct {
840840
uint64_t nb01;
841841
uint64_t nb02;
842842
uint64_t nb03;
843+
int32_t top_k;
843844
} ggml_metal_kargs_argsort;
844845

845846
typedef struct {
@@ -851,6 +852,7 @@ typedef struct {
851852
uint64_t nb01;
852853
uint64_t nb02;
853854
uint64_t nb03;
855+
int32_t top_k;
854856
int32_t len;
855857
} ggml_metal_kargs_argsort_merge;
856858

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 124 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
405405
{
406406
n_fuse = ggml_metal_op_argsort(ctx, idx);
407407
} break;
408+
case GGML_OP_TOP_K:
409+
{
410+
n_fuse = ggml_metal_op_top_k(ctx, idx);
411+
} break;
408412
case GGML_OP_LEAKY_RELU:
409413
{
410414
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
@@ -3677,14 +3681,15 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
36773681
}
36783682

36793683
ggml_metal_kargs_argsort args = {
3680-
/*.ne00 =*/ ne00,
3681-
/*.ne01 =*/ ne01,
3682-
/*.ne02 =*/ ne02,
3683-
/*.ne03 =*/ ne03,
3684-
/*.nb00 =*/ nb00,
3685-
/*.nb01 =*/ nb01,
3686-
/*.nb02 =*/ nb02,
3687-
/*.nb03 =*/ nb03,
3684+
/*.ne00 =*/ ne00,
3685+
/*.ne01 =*/ ne01,
3686+
/*.ne02 =*/ ne02,
3687+
/*.ne03 =*/ ne03,
3688+
/*.nb00 =*/ nb00,
3689+
/*.nb01 =*/ nb01,
3690+
/*.nb02 =*/ nb02,
3691+
/*.nb03 =*/ nb03,
3692+
/*.top_k =*/ ne00,
36883693
};
36893694

36903695
ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -3704,15 +3709,117 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
37043709
ggml_metal_op_concurrency_reset(ctx);
37053710

37063711
ggml_metal_kargs_argsort_merge args_merge = {
3707-
.ne00 = ne00,
3708-
.ne01 = ne01,
3709-
.ne02 = ne02,
3710-
.ne03 = ne03,
3711-
.nb00 = nb00,
3712-
.nb01 = nb01,
3713-
.nb02 = nb02,
3714-
.nb03 = nb03,
3715-
.len = len,
3712+
/*.ne00 =*/ ne00,
3713+
/*.ne01 =*/ ne01,
3714+
/*.ne02 =*/ ne02,
3715+
/*.ne03 =*/ ne03,
3716+
/*.nb00 =*/ nb00,
3717+
/*.nb01 =*/ nb01,
3718+
/*.nb02 =*/ nb02,
3719+
/*.nb03 =*/ nb03,
3720+
/*.top_k =*/ ne00,
3721+
/*.len =*/ len,
3722+
};
3723+
3724+
// merges per row
3725+
const int nm = (ne00 + 2*len - 1) / (2*len);
3726+
3727+
const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
3728+
3729+
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3730+
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3731+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3732+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3733+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3734+
3735+
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3736+
3737+
std::swap(bid_dst, bid_tmp);
3738+
3739+
len <<= 1;
3740+
}
3741+
3742+
return 1;
3743+
}
3744+
3745+
int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
3746+
ggml_tensor * op = ctx->node(idx);
3747+
3748+
ggml_metal_library_t lib = ctx->lib;
3749+
ggml_metal_encoder_t enc = ctx->enc;
3750+
3751+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3752+
3753+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3754+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3755+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3756+
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3757+
3758+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3759+
3760+
// bitonic sort requires the number of elements to be power of 2
3761+
int nth = 1;
3762+
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3763+
nth *= 2;
3764+
}
3765+
3766+
const int npr = (ne00 + nth - 1)/nth;
3767+
3768+
// Metal kernels require the buffer size to be multiple of 16 bytes
3769+
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3770+
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3771+
3772+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3773+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3774+
3775+
ggml_metal_buffer_id bid_tmp = bid_dst;
3776+
bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
3777+
3778+
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3779+
std::swap(bid_dst, bid_tmp);
3780+
}
3781+
3782+
const int top_k = ne0;
3783+
3784+
ggml_metal_kargs_argsort args = {
3785+
/*.ne00 =*/ ne00,
3786+
/*.ne01 =*/ ne01,
3787+
/*.ne02 =*/ ne02,
3788+
/*.ne03 =*/ ne03,
3789+
/*.nb00 =*/ nb00,
3790+
/*.nb01 =*/ nb01,
3791+
/*.nb02 =*/ nb02,
3792+
/*.nb03 =*/ nb03,
3793+
/*.top_k =*/ nth < ne00 ? ne00 : top_k,
3794+
};
3795+
3796+
ggml_metal_encoder_set_pipeline(enc, pipeline);
3797+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3798+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3799+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3800+
3801+
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3802+
3803+
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3804+
3805+
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3806+
3807+
int len = nth;
3808+
3809+
while (len < ne00) {
3810+
ggml_metal_op_concurrency_reset(ctx);
3811+
3812+
ggml_metal_kargs_argsort_merge args_merge = {
3813+
/*.ne00 =*/ ne00,
3814+
/*.ne01 =*/ ne01,
3815+
/*.ne02 =*/ ne02,
3816+
/*.ne03 =*/ ne03,
3817+
/*.nb00 =*/ nb00,
3818+
/*.nb01 =*/ nb01,
3819+
/*.nb02 =*/ nb02,
3820+
/*.nb03 =*/ nb03,
3821+
/*.top_k =*/ 2*len >= ne00 ? top_k : ne00,
3822+
/*.len =*/ len,
37163823
};
37173824

37183825
// merges per row

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
8181
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
8282
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
8383
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
84+
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
8485
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
8586
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
8687
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);

ggml/src/ggml-metal/ggml-metal.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
202202
{
203203
res *= 2;
204204
} break;
205+
case GGML_OP_TOP_K:
206+
{
207+
res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
208+
} break;
205209
default:
206210
break;
207211
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4711,8 +4711,8 @@ kernel void kernel_argsort_f32_i32(
47114711
}
47124712

47134713
// copy the result to dst without the padding
4714-
if (i00 + col < args.ne00) {
4715-
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
4714+
if (i00 + col < args.ne00 && col < args.top_k) {
4715+
dst += i00 + args.top_k*i01 + args.top_k*args.ne01*i02 + args.top_k*args.ne01*args.ne02*i03;
47164716

47174717
dst[col] = shmem_i32[col];
47184718
}
@@ -4760,9 +4760,9 @@ kernel void kernel_argsort_merge_f32_i32(
47604760
device const int32_t * tmp1 = tmp0 + args.len;
47614761

47624762
dst += start
4763-
+ i01*args.ne00
4764-
+ i02*args.ne00*args.ne01
4765-
+ i03*args.ne00*args.ne01*args.ne02;
4763+
+ i01*args.top_k
4764+
+ i02*args.top_k*args.ne01
4765+
+ i03*args.top_k*args.ne01*args.ne02;
47664766

47674767
device const float * src0_row = (device const float *)(src0
47684768
+ args.nb01*i01
@@ -4827,7 +4827,7 @@ kernel void kernel_argsort_merge_f32_i32(
48274827
val1 = src0_row[idx1];
48284828
}
48294829

4830-
for (int k = k0; k < k1; ++k) {
4830+
for (int k = k0; k < k1 && k < args.top_k; ++k) {
48314831
int32_t out_idx;
48324832

48334833
if (i >= len0) {

ggml/src/ggml.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5158,7 +5158,8 @@ struct ggml_tensor * ggml_top_k(
51585158

51595159
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
51605160

5161-
ggml_set_op_params_i32(result, 0, (int32_t) k);
5161+
// TODO: tmp
5162+
ggml_set_op_params_i32(result, 0, (int32_t) GGML_SORT_ORDER_DESC);
51625163

51635164
result->op = GGML_OP_TOP_K;
51645165
result->src[0] = a;

tests/test-backend-ops.cpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4933,7 +4933,49 @@ struct test_argsort : public test_case {
49334933
}
49344934
};
49354935

4936-
struct test_topk_moe: public test_case {
4936+
// GGML_OP_TOP_K
4937+
struct test_top_k : public test_case {
4938+
const ggml_type type;
4939+
const std::array<int64_t, 4> ne;
4940+
const int k;
4941+
4942+
std::string vars() override {
4943+
return VARS_TO_STR3(type, ne, k);
4944+
}
4945+
4946+
test_top_k(ggml_type type = GGML_TYPE_F32,
4947+
std::array<int64_t, 4> ne = {16, 10, 10, 10},
4948+
int k = 4)
4949+
: type(type), ne(ne), k(k) {}
4950+
4951+
ggml_tensor * build_graph(ggml_context * ctx) override {
4952+
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4953+
ggml_set_name(a, "a");
4954+
4955+
ggml_tensor * out = ggml_top_k(ctx, a, k);
4956+
ggml_set_name(out, "out");
4957+
4958+
return out;
4959+
}
4960+
4961+
void initialize_tensors(ggml_context * ctx) override {
4962+
std::random_device rd;
4963+
std::default_random_engine rng(rd());
4964+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4965+
// initialize with unique values to avoid ties
4966+
for (int64_t r = 0; r < ggml_nrows(t); r++) {
4967+
std::vector<float> data(t->ne[0]);
4968+
for (int i = 0; i < t->ne[0]; i++) {
4969+
data[i] = i;
4970+
}
4971+
std::shuffle(data.begin(), data.end(), rng);
4972+
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
4973+
}
4974+
}
4975+
}
4976+
};
4977+
4978+
struct test_topk_moe : public test_case {
49374979
const std::array<int64_t, 4> ne;
49384980
const int n_expert_used;
49394981
const bool with_norm;
@@ -7514,6 +7556,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
75147556
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
75157557
}
75167558

7559+
for (int k : {1, 2, 3, 7, 15}) {
7560+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
7561+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
7562+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k));
7563+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k));
7564+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k));
7565+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k));
7566+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k));
7567+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k));
7568+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k));
7569+
}
7570+
75177571
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
75187572
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
75197573
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
@@ -7886,6 +7940,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
78867940
}
78877941

78887942
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
7943+
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
78897944

78907945
return test_cases;
78917946
}

0 commit comments

Comments
 (0)