@@ -4933,7 +4933,49 @@ struct test_argsort : public test_case {
49334933 }
49344934};
49354935
4936- struct test_topk_moe : public test_case {
4936+ // GGML_OP_TOP_K
4937+ struct test_top_k : public test_case {
4938+ const ggml_type type;
4939+ const std::array<int64_t , 4 > ne;
4940+ const int k;
4941+
4942+ std::string vars () override {
4943+ return VARS_TO_STR3 (type, ne, k);
4944+ }
4945+
4946+ test_top_k (ggml_type type = GGML_TYPE_F32,
4947+ std::array<int64_t , 4 > ne = {16 , 10 , 10 , 10 },
4948+ int k = 4 )
4949+ : type(type), ne(ne), k(k) {}
4950+
4951+ ggml_tensor * build_graph (ggml_context * ctx) override {
4952+ ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
4953+ ggml_set_name (a, " a" );
4954+
4955+ ggml_tensor * out = ggml_top_k (ctx, a, k);
4956+ ggml_set_name (out, " out" );
4957+
4958+ return out;
4959+ }
4960+
4961+ void initialize_tensors (ggml_context * ctx) override {
4962+ std::random_device rd;
4963+ std::default_random_engine rng (rd ());
4964+ for (ggml_tensor * t = ggml_get_first_tensor (ctx); t != NULL ; t = ggml_get_next_tensor (ctx, t)) {
4965+ // initialize with unique values to avoid ties
4966+ for (int64_t r = 0 ; r < ggml_nrows (t); r++) {
4967+ std::vector<float > data (t->ne [0 ]);
4968+ for (int i = 0 ; i < t->ne [0 ]; i++) {
4969+ data[i] = i;
4970+ }
4971+ std::shuffle (data.begin (), data.end (), rng);
4972+ ggml_backend_tensor_set (t, data.data (), r * t->nb [1 ], t->ne [0 ] * sizeof (float ));
4973+ }
4974+ }
4975+ }
4976+ };
4977+
4978+ struct test_topk_moe : public test_case {
49374979 const std::array<int64_t , 4 > ne;
49384980 const int n_expert_used;
49394981 const bool with_norm;
@@ -7514,6 +7556,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
75147556 test_cases.emplace_back (new test_argsort (GGML_TYPE_F32, {2 , 8 , 8192 , 1 }, order)); // bailingmoe2 (group selection)
75157557 }
75167558
7559+ for (int k : {1 , 2 , 3 , 7 , 15 }) {
7560+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {16 , 10 , 10 , 10 }, k));
7561+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {60 , 10 , 10 , 10 }, k));
7562+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {1023 , 2 , 1 , 3 }, k));
7563+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {1024 , 2 , 1 , 3 }, k));
7564+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {1025 , 2 , 1 , 3 }, k));
7565+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {16384 , 1 , 1 , 1 }, k));
7566+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {2047 , 2 , 1 , 3 }, k));
7567+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {2048 , 2 , 1 , 3 }, k));
7568+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {2049 , 2 , 1 , 3 }, k));
7569+ }
7570+
75177571 for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
75187572 test_cases.emplace_back (new test_upscale (GGML_TYPE_F32, {512 , 512 , 3 , 2 }, 2 , mode));
75197573 test_cases.emplace_back (new test_upscale (GGML_TYPE_F32, {512 , 512 , 3 , 2 }, 2 , mode, true ));
@@ -7886,6 +7940,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
78867940 }
78877941
78887942 test_cases.emplace_back (new test_argsort (GGML_TYPE_F32, {65000 , 16 , 1 , 1 }));
7943+ test_cases.emplace_back (new test_top_k (GGML_TYPE_F32, {65000 , 16 , 1 , 1 }, 40 ));
78897944
78907945 return test_cases;
78917946}
0 commit comments