@@ -418,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
418418            {
419419                n_fuse = ggml_metal_op_opt_step_adamw (ctx, idx);
420420            } break ;
421+         case  GGML_OP_OPT_STEP_SGD:
422+             {
423+                 n_fuse = ggml_metal_op_opt_step_sgd (ctx, idx);
424+             } break ;
421425       default :
422426            {
423427                GGML_LOG_ERROR (" %s: error: node %3d, op = %8s not implemented\n " ggml_op_name (node->op ));
@@ -3469,3 +3473,37 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
34693473
34703474    return  1 ;
34713475}
3476+ 
3477+ int  ggml_metal_op_opt_step_sgd (ggml_metal_op_t  ctx, int  idx) {
3478+     ggml_tensor * op = ctx->node (idx);
3479+ 
3480+     ggml_metal_library_t  lib = ctx->lib ;
3481+     ggml_metal_encoder_t  enc = ctx->enc ;
3482+ 
3483+     GGML_TENSOR_LOCALS ( int32_t , ne0, op->src [0 ], ne);
3484+     GGML_TENSOR_LOCALS (uint64_t , nb0, op->src [0 ], nb);
3485+     GGML_TENSOR_LOCALS ( int32_t , ne,  op,         ne);
3486+     GGML_TENSOR_LOCALS (uint32_t , nb,  op,         nb);
3487+ 
3488+     ggml_metal_pipeline_t  pipeline = ggml_metal_library_get_pipeline_opt_step_sgd (lib, op);
3489+ 
3490+     const  int64_t  np = ggml_nelements (op->src [0 ]);
3491+     ggml_metal_kargs_opt_step_sgd args = {
3492+         /* .np =*/ 
3493+     };
3494+ 
3495+     int  ida = 0 ;
3496+ 
3497+     ggml_metal_encoder_set_pipeline (enc, pipeline);
3498+     ggml_metal_encoder_set_bytes    (enc, &args, sizeof (args), ida++);
3499+     ggml_metal_encoder_set_buffer   (enc, ggml_metal_get_buffer_id (op->src [0 ]), ida++);
3500+     ggml_metal_encoder_set_buffer   (enc, ggml_metal_get_buffer_id (op->src [1 ]), ida++);
3501+     ggml_metal_encoder_set_buffer   (enc, ggml_metal_get_buffer_id (op->src [2 ]), ida++);
3502+ 
3503+     const  int  nth = std::min (ggml_metal_pipeline_max_theads_per_threadgroup (pipeline), ne0);
3504+     const  int64_t  n = (np + nth - 1 ) / nth;
3505+ 
3506+     ggml_metal_encoder_dispatch_threadgroups (enc, n, 1 , 1 , nth, 1 , 1 );
3507+ 
3508+     return  1 ;
3509+ }
0 commit comments