@@ -241,15 +241,18 @@ struct vk_device_struct {
241241 vk_pipeline pipeline_norm_f32;
242242 vk_pipeline pipeline_group_norm_f32;
243243 vk_pipeline pipeline_rms_norm_f32;
244+ vk_pipeline pipeline_rms_norm_back_f32;
244245 vk_pipeline pipeline_gelu_f32;
245246 vk_pipeline pipeline_gelu_quick_f32;
246247 vk_pipeline pipeline_silu_f32;
248+ vk_pipeline pipeline_silu_back_f32;
247249 vk_pipeline pipeline_relu_f32;
248250 vk_pipeline pipeline_leaky_relu_f32;
249251 vk_pipeline pipeline_tanh_f32;
250252 vk_pipeline pipeline_diag_mask_inf_f32;
251253 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
252254 vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
255+ vk_pipeline pipeline_soft_max_back_f32;
253256 vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
254257 vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
255258 vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -504,6 +507,7 @@ struct vk_op_rope_push_constants {
504507 uint32_t s1;
505508 uint32_t s2;
506509 int32_t sections[4 ];
510+ uint32_t is_back;
507511};
508512
509513struct vk_op_soft_max_push_constants {
@@ -2121,6 +2125,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21212125 ggml_vk_create_pipeline (device, device->pipeline_norm_f32 , " norm_f32" , norm_f32_len, norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
21222126 ggml_vk_create_pipeline (device, device->pipeline_group_norm_f32 , " group_norm_f32" , group_norm_f32_len, group_norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
21232127 ggml_vk_create_pipeline (device, device->pipeline_rms_norm_f32 , " rms_norm_f32" , rms_norm_f32_len, rms_norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
2128+ ggml_vk_create_pipeline (device, device->pipeline_rms_norm_back_f32 , " rms_norm_back_f32" , rms_norm_back_f32_len, rms_norm_back_f32_data, " main" , 3 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
21242129
21252130 ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_f32 , " cpy_f32_f32" , cpy_f32_f32_len, cpy_f32_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
21262131 ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_f16 , " cpy_f32_f16" , cpy_f32_f16_len, cpy_f32_f16_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
@@ -2180,6 +2185,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21802185 ggml_vk_create_pipeline (device, device->pipeline_gelu_f32 , " gelu_f32" , gelu_f32_len, gelu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21812186 ggml_vk_create_pipeline (device, device->pipeline_gelu_quick_f32 , " gelu_quick_f32" , gelu_quick_f32_len, gelu_quick_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21822187 ggml_vk_create_pipeline (device, device->pipeline_silu_f32 , " silu_f32" , silu_f32_len, silu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
2188+ ggml_vk_create_pipeline (device, device->pipeline_silu_back_f32 , " silu_back_f32" , silu_back_f32_len, silu_back_f32_data, " main" , 3 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21832189 ggml_vk_create_pipeline (device, device->pipeline_relu_f32 , " relu_f32" , relu_f32_len, relu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21842190 ggml_vk_create_pipeline (device, device->pipeline_leaky_relu_f32 , " leaky_relu_f32" , leaky_relu_f32_len, leaky_relu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21852191 ggml_vk_create_pipeline (device, device->pipeline_tanh_f32 , " tanh_f32" , tanh_f32_len, tanh_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
@@ -2190,6 +2196,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21902196 ggml_vk_create_pipeline (device, device->pipeline_soft_max_f32_wg512 , " soft_max_f32_wg512" , soft_max_f32_len, soft_max_f32_data, " main" , 3 , sizeof (vk_op_soft_max_push_constants), {1 , 1 , 1 }, { 512 }, 1 );
21912197 ggml_vk_create_pipeline (device, device->pipeline_soft_max_f32_f16 , " soft_max_f32_f16" , soft_max_f32_f16_len, soft_max_f32_f16_data, " main" , 3 , sizeof (vk_op_soft_max_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
21922198 ggml_vk_create_pipeline (device, device->pipeline_soft_max_f32_f16_wg512 , " soft_max_f32_f16_wg512" , soft_max_f32_f16_len, soft_max_f32_f16_data, " main" , 3 , sizeof (vk_op_soft_max_push_constants), {1 , 1 , 1 }, { 512 }, 1 );
2199+ ggml_vk_create_pipeline (device, device->pipeline_soft_max_back_f32 , " soft_max_back_f32" , soft_max_back_f32_len, soft_max_back_f32_data, " main" , 3 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
21932200
21942201 ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f32 , " rope_norm_f32" , rope_norm_f32_len, rope_norm_f32_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
21952202 ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f32 , " rope_neox_f32" , rope_neox_f32_len, rope_neox_f32_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
@@ -5283,6 +5290,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52835290 case GGML_OP_CONT:
52845291 case GGML_OP_DUP:
52855292 return ggml_vk_get_cpy_pipeline (ctx, src0, dst, dst->type );
5293+ case GGML_OP_SILU_BACK:
5294+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5295+ return ctx->device ->pipeline_silu_back_f32 ;
5296+ }
5297+ return nullptr ;
52865298 case GGML_OP_NORM:
52875299 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
52885300 return ctx->device ->pipeline_norm_f32 ;
@@ -5298,6 +5310,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52985310 return ctx->device ->pipeline_rms_norm_f32 ;
52995311 }
53005312 return nullptr ;
5313+ case GGML_OP_RMS_NORM_BACK:
5314+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5315+ return ctx->device ->pipeline_rms_norm_back_f32 ;
5316+ }
5317+ return nullptr ;
53015318 case GGML_OP_UNARY:
53025319 switch (ggml_get_unary_op (dst)) {
53035320 case GGML_UNARY_OP_SILU:
@@ -5344,7 +5361,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53445361 return src0->ne [0 ] > 1024 ? ctx->device ->pipeline_soft_max_f32_f16_wg512 : ctx->device ->pipeline_soft_max_f32_f16 ;
53455362 }
53465363 return nullptr ;
5364+ case GGML_OP_SOFT_MAX_BACK:
5365+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5366+ return ctx->device ->pipeline_soft_max_back_f32 ;
5367+ }
5368+ return nullptr ;
53475369 case GGML_OP_ROPE:
5370+ case GGML_OP_ROPE_BACK:
53485371 {
53495372 const int mode = ((const int32_t *) dst->op_params )[2 ];
53505373 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@@ -5672,7 +5695,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
56725695 switch (op) {
56735696 case GGML_OP_NORM:
56745697 case GGML_OP_RMS_NORM:
5698+ case GGML_OP_RMS_NORM_BACK:
56755699 case GGML_OP_SOFT_MAX:
5700+ case GGML_OP_SOFT_MAX_BACK:
56765701 case GGML_OP_SUM_ROWS:
56775702 case GGML_OP_ARGMAX:
56785703 {
@@ -5696,6 +5721,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
56965721 } break ;
56975722 case GGML_OP_DIAG_MASK_INF:
56985723 case GGML_OP_ROPE:
5724+ case GGML_OP_ROPE_BACK:
56995725 elements = { (uint32_t )ggml_nrows (src0), (uint32_t )ne00, 1 };
57005726 break ;
57015727 case GGML_OP_GET_ROWS:
@@ -5791,7 +5817,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
57915817
57925818 ggml_vk_sync_buffers (subctx);
57935819 ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof (PC), &pc, elements);
5794- } else if (op == GGML_OP_ROPE) {
5820+ } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK ) {
57955821 // Empty src2 is possible in rope, but the shader needs a buffer
57965822 vk_subbuffer subbuf_z;
57975823 if (use_src2) {
@@ -6313,6 +6339,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
63136339 }, dryrun);
63146340}
63156341
// Records the GGML_OP_SILU_BACK node by forwarding to the generic ggml_vk_op_f32
// dispatcher with two inputs (src0, src1) and no third source.
// Push constants (vk_op_push_constants): the total element count of src0, followed by
// three unused fields that are set to zero.
// The pipeline lookup for this op only returns a pipeline when src0, src1 and dst are
// all F32; the pipeline itself is created with three buffer bindings.
// dryrun is passed through unchanged to ggml_vk_op_f32.
// NOTE(review): which of src0/src1 carries the incoming gradient is not visible here —
// confirm against the ggml_silu_back declaration.
6342+ static void ggml_vk_silu_back (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
6343+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_SILU_BACK, { (uint32_t )ggml_nelements (src0), 0 , 0 .0f , 0 .0f }, dryrun);
6344+ }
6345+
63166346static void ggml_vk_norm (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
63176347 float * op_params = (float *)dst->op_params ;
63186348
@@ -6335,6 +6365,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
63356365 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_RMS_NORM, { (uint32_t )src0->ne [0 ], (uint32_t )src0->ne [1 ], op_params[0 ], 0 .0f }, dryrun);
63366366}
63376367
// Records the GGML_OP_RMS_NORM_BACK node by forwarding to the generic ggml_vk_op_f32
// dispatcher with two inputs (src0, src1) and no third source.
// Push constants (vk_op_push_constants): src0->ne[0], src0->ne[1], the first float in
// dst->op_params (the result-check code in this file reads that same slot as `eps`),
// and one unused float set to zero.
// The pipeline lookup for this op only returns a pipeline when src0, src1 and dst are
// all F32.
// dryrun is passed through unchanged to ggml_vk_op_f32.
6368+ static void ggml_vk_rms_norm_back (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
6369+ float * op_params = (float *)dst->op_params ;
6370+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_RMS_NORM_BACK, { (uint32_t )src0->ne [0 ], (uint32_t )src0->ne [1 ], op_params[0 ], 0 .0f }, dryrun);
6371+ }
6372+
// Records a GGML_OP_UNARY node by forwarding to the generic ggml_vk_op_f32 dispatcher
// with a single input (src0); the second and third sources are passed as nullptr.
// Push constants (vk_op_push_constants): the total element count of src0, followed by
// three unused fields that are set to zero.
// The concrete unary function is not chosen here: the pipeline lookup switches on
// ggml_get_unary_op(dst) for GGML_OP_UNARY.
// dryrun is passed through unchanged to ggml_vk_op_f32.
63386373static void ggml_vk_unary (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
63396374 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_UNARY, { (uint32_t )ggml_nelements (src0), 0 , 0 .0f , 0 .0f }, dryrun);
63406375}
@@ -6370,7 +6405,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
63706405 }, dryrun);
63716406}
63726407
6373- static void ggml_vk_rope (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false ) {
// Records the GGML_OP_SOFT_MAX_BACK node by forwarding to the generic ggml_vk_op_f32
// dispatcher with two inputs (src0, src1) and no third source.
// Push constants (vk_op_push_constants, not the forward pass's
// vk_op_soft_max_push_constants): src0->ne[0], src0->ne[1], and the first two floats
// of dst->op_params.
// NOTE(review): those two floats are the same values the result-check code passes as
// the last two arguments of ggml_soft_max_ext_back — presumably scale and max_bias;
// confirm against the ggml header.
// The pipeline lookup for this op only returns a pipeline when src0, src1 and dst are
// all F32.
// dryrun is passed through unchanged to ggml_vk_op_f32.
6408+ static void ggml_vk_soft_max_back (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
6409+ float * op_params = (float *)dst->op_params ;
6410+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t )src0->ne [0 ], (uint32_t )src0->ne [1 ], op_params[0 ], op_params[1 ] }, dryrun);
6411+ }
6412+
6413+ static void ggml_vk_rope (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false ) {
63746414 const int n_dims = ((int32_t *) dst->op_params )[1 ];
63756415 const int mode = ((int32_t *) dst->op_params )[2 ];
63766416 // const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6398,7 +6438,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
63986438 (uint32_t )src0->ne [0 ], (uint32_t )n_dims, freq_scale, (uint32_t )src0->ne [1 ],
63996439 freq_base, ext_factor, attn_factor, {corr_dims[0 ], corr_dims[1 ]}, theta_scale,
64006440 src2 != nullptr , (uint32_t )src0->ne [2 ], s1, s2,
6401- sections[0 ], sections[1 ], sections[2 ], sections[3 ],
6441+ sections[0 ], sections[1 ], sections[2 ], sections[3 ], backprop
64026442 }, dryrun);
64036443}
64046444
@@ -7319,12 +7359,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73197359 case GGML_OP_CPY:
73207360 case GGML_OP_CONT:
73217361 case GGML_OP_DUP:
7362+ case GGML_OP_SILU_BACK:
73227363 case GGML_OP_NORM:
73237364 case GGML_OP_GROUP_NORM:
73247365 case GGML_OP_RMS_NORM:
7366+ case GGML_OP_RMS_NORM_BACK:
73257367 case GGML_OP_DIAG_MASK_INF:
73267368 case GGML_OP_SOFT_MAX:
7369+ case GGML_OP_SOFT_MAX_BACK:
73277370 case GGML_OP_ROPE:
7371+ case GGML_OP_ROPE_BACK:
73287372 case GGML_OP_MUL_MAT:
73297373 case GGML_OP_MUL_MAT_ID:
73307374 case GGML_OP_ARGSORT:
@@ -7377,13 +7421,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73777421 case GGML_OP_CPY:
73787422 case GGML_OP_CONT:
73797423 case GGML_OP_DUP:
7424+ case GGML_OP_SILU_BACK:
73807425 case GGML_OP_NORM:
73817426 case GGML_OP_GROUP_NORM:
73827427 case GGML_OP_RMS_NORM:
7428+ case GGML_OP_RMS_NORM_BACK:
73837429 case GGML_OP_UNARY:
73847430 case GGML_OP_DIAG_MASK_INF:
73857431 case GGML_OP_SOFT_MAX:
7432+ case GGML_OP_SOFT_MAX_BACK:
73867433 case GGML_OP_ROPE:
7434+ case GGML_OP_ROPE_BACK:
73877435 case GGML_OP_ARGSORT:
73887436 case GGML_OP_SUM:
73897437 case GGML_OP_SUM_ROWS:
@@ -7475,6 +7523,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
74757523 case GGML_OP_DUP:
74767524 ggml_vk_cpy (ctx, compute_ctx, src0, node, dryrun);
74777525
7526+ break ;
7527+ case GGML_OP_SILU_BACK:
7528+ ggml_vk_silu_back (ctx, compute_ctx, src0, src1, node, dryrun);
7529+
74787530 break ;
74797531 case GGML_OP_NORM:
74807532 ggml_vk_norm (ctx, compute_ctx, src0, node, dryrun);
@@ -7487,6 +7539,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
74877539 case GGML_OP_RMS_NORM:
74887540 ggml_vk_rms_norm (ctx, compute_ctx, src0, node, dryrun);
74897541
7542+ break ;
7543+ case GGML_OP_RMS_NORM_BACK:
7544+ ggml_vk_rms_norm_back (ctx, compute_ctx, src0, src1, node, dryrun);
7545+
74907546 break ;
74917547 case GGML_OP_UNARY:
74927548 switch (ggml_get_unary_op (node)) {
@@ -7508,9 +7564,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
75087564 case GGML_OP_SOFT_MAX:
75097565 ggml_vk_soft_max (ctx, compute_ctx, src0, src1, node, dryrun);
75107566
7567+ break ;
7568+ case GGML_OP_SOFT_MAX_BACK:
7569+ ggml_vk_soft_max_back (ctx, compute_ctx, src0, src1, node, dryrun);
7570+
75117571 break ;
75127572 case GGML_OP_ROPE:
7513- ggml_vk_rope (ctx, compute_ctx, src0, src1, src2, node, dryrun);
7573+ ggml_vk_rope (ctx, compute_ctx, src0, src1, src2, node, false , dryrun);
7574+
7575+ break ;
7576+ case GGML_OP_ROPE_BACK:
7577+ ggml_vk_rope (ctx, compute_ctx, src0, src1, src2, node, true , dryrun);
75147578
75157579 break ;
75167580 case GGML_OP_ARGSORT:
@@ -7636,12 +7700,16 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
76367700 case GGML_OP_CPY:
76377701 case GGML_OP_CONT:
76387702 case GGML_OP_DUP:
7703+ case GGML_OP_SILU_BACK:
76397704 case GGML_OP_NORM:
76407705 case GGML_OP_GROUP_NORM:
76417706 case GGML_OP_RMS_NORM:
7707+ case GGML_OP_RMS_NORM_BACK:
76427708 case GGML_OP_DIAG_MASK_INF:
76437709 case GGML_OP_SOFT_MAX:
7710+ case GGML_OP_SOFT_MAX_BACK:
76447711 case GGML_OP_ROPE:
7712+ case GGML_OP_ROPE_BACK:
76457713 case GGML_OP_RESHAPE:
76467714 case GGML_OP_VIEW:
76477715 case GGML_OP_PERMUTE:
@@ -8560,6 +8628,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
85608628 case GGML_OP_REPEAT_BACK:
85618629 return op->type == GGML_TYPE_F32 && op->src [0 ]->type == GGML_TYPE_F32;
85628630 case GGML_OP_ROPE:
8631+ case GGML_OP_ROPE_BACK:
85638632 case GGML_OP_NONE:
85648633 case GGML_OP_RESHAPE:
85658634 case GGML_OP_VIEW:
@@ -8576,6 +8645,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
85768645 case GGML_OP_MUL:
85778646 case GGML_OP_DIV:
85788647 case GGML_OP_CONCAT:
8648+ case GGML_OP_SILU_BACK:
8649+ case GGML_OP_RMS_NORM_BACK:
85798650 case GGML_OP_UPSCALE:
85808651 case GGML_OP_SCALE:
85818652 case GGML_OP_SQR:
@@ -8585,6 +8656,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
85858656 case GGML_OP_PAD:
85868657 case GGML_OP_DIAG_MASK_INF:
85878658 case GGML_OP_SOFT_MAX:
8659+ case GGML_OP_SOFT_MAX_BACK:
85888660 case GGML_OP_ARGSORT:
85898661 case GGML_OP_SUM:
85908662 case GGML_OP_SUM_ROWS:
@@ -8976,15 +9048,22 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
89769048 tensor_clone = ggml_group_norm (ggml_ctx, src_clone[0 ], *(int *)tensor->op_params , ((float *)tensor->op_params )[1 ]);
89779049 } else if (tensor->op == GGML_OP_RMS_NORM) {
89789050 tensor_clone = ggml_rms_norm (ggml_ctx, src_clone[0 ], *(float *)tensor->op_params );
9051+ } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
9052+ const float eps = ((float *) tensor->op_params )[0 ];
9053+ tensor_clone = ggml_rms_norm_back (ggml_ctx, src_clone[0 ], src_clone[1 ], eps);
9054+ } else if (tensor->op == GGML_OP_SILU_BACK) {
9055+ tensor_clone = ggml_silu_back (ggml_ctx, src_clone[0 ], src_clone[1 ]);
89799056 } else if (tensor->op == GGML_OP_SOFT_MAX) {
89809057 if (src1 != nullptr ) {
89819058 tensor_clone = ggml_soft_max_ext (ggml_ctx, src_clone[0 ], src_clone[1 ], ((float *)tensor->op_params )[0 ], ((float *)tensor->op_params )[1 ]);
89829059 } else {
89839060 tensor_clone = ggml_soft_max (ggml_ctx, src_clone[0 ]);
89849061 }
9062+ } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
9063+ tensor_clone = ggml_soft_max_ext_back (ggml_ctx, src_clone[0 ], src_clone[1 ], ((float *)tensor->op_params )[0 ], ((float *)tensor->op_params )[1 ]);
89859064 } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
89869065 tensor_clone = ggml_diag_mask_inf (ggml_ctx, src_clone[0 ], *(int *)tensor->op_params );
8987- } else if (tensor->op == GGML_OP_ROPE) {
9066+ } else if (tensor->op == GGML_OP_ROPE || tensor-> op == GGML_OP_ROPE_BACK ) {
89889067 const int n_dims = ((int32_t *) tensor->op_params )[1 ];
89899068 const int mode = ((int32_t *) tensor->op_params )[2 ];
89909069 // const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -8997,9 +9076,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
89979076 const float beta_slow = ((float *) tensor->op_params )[10 ];
89989077 if (mode & GGML_ROPE_TYPE_MROPE) {
89999078 int32_t *sections = ((int32_t *) tensor->op_params ) + 11 ;
9000- tensor_clone = ggml_rope_multi (ggml_ctx, src_clone[0 ], src_clone[1 ], src_clone[2 ], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9079+ if (tensor->op == GGML_OP_ROPE) {
9080+ tensor_clone = ggml_rope_multi (ggml_ctx, src_clone[0 ], src_clone[1 ], src_clone[2 ], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9081+ } else {
9082+ tensor_clone = ggml_rope_multi_back (ggml_ctx, src_clone[0 ], src_clone[1 ], src_clone[2 ], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9083+ }
90019084 } else {
9002- tensor_clone = ggml_rope_ext (ggml_ctx, src_clone[0 ], src_clone[1 ], src_clone[2 ], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9085+ if (tensor->op == GGML_OP_ROPE) {
9086+ tensor_clone = ggml_rope_ext (ggml_ctx, src_clone[0 ], src_clone[1 ], src_clone[2 ], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9087+ } else {
9088+ tensor_clone = ggml_rope_ext_back (ggml_ctx, src_clone[0 ], src_clone[1 ], src_clone[2 ], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9089+ }
90039090 }
90049091 } else if (tensor->op == GGML_OP_UNARY) {
90059092 switch (ggml_get_unary_op (tensor)) {
0 commit comments