Skip to content

Commit 3f750f8

Browse files
authored
metal: add support for opt_step_sgd (ggml-org#16539)
* metal: add support for opt_step_sgd * add newline to pass EditorConfig check
1 parent c515fc5 commit 3f750f8

File tree

7 files changed

+78
-0
lines changed

7 files changed

+78
-0
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,3 +1519,22 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_
15191519

15201520
return res;
15211521
}
1522+
1523+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
1524+
assert(op->op == GGML_OP_OPT_STEP_SGD);
1525+
1526+
char base[256];
1527+
char name[256];
1528+
1529+
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
1530+
snprintf(name, 256, "%s", base);
1531+
1532+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
1533+
if (res) {
1534+
return res;
1535+
}
1536+
1537+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1538+
1539+
return res;
1540+
}

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
136136
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
137137
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
138138
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
139+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
139140

140141
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
141142
ggml_metal_library_t lib,

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
800800
};
801801
}
802802
case GGML_OP_OPT_STEP_ADAMW:
803+
case GGML_OP_OPT_STEP_SGD:
803804
return has_simdgroup_reduction;
804805
default:
805806
return false;

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,4 +781,8 @@ typedef struct {
781781
int64_t np;
782782
} ggml_metal_kargs_opt_step_adamw;
783783

784+
typedef struct {
785+
int64_t np;
786+
} ggml_metal_kargs_opt_step_sgd;
787+
784788
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
418418
{
419419
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
420420
} break;
421+
case GGML_OP_OPT_STEP_SGD:
422+
{
423+
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
424+
} break;
421425
default:
422426
{
423427
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@@ -3469,3 +3473,37 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
34693473

34703474
return 1;
34713475
}
3476+
3477+
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
3478+
ggml_tensor * op = ctx->node(idx);
3479+
3480+
ggml_metal_library_t lib = ctx->lib;
3481+
ggml_metal_encoder_t enc = ctx->enc;
3482+
3483+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3484+
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3485+
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3486+
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
3487+
3488+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
3489+
3490+
const int64_t np = ggml_nelements(op->src[0]);
3491+
ggml_metal_kargs_opt_step_sgd args = {
3492+
/*.np =*/ np,
3493+
};
3494+
3495+
int ida = 0;
3496+
3497+
ggml_metal_encoder_set_pipeline(enc, pipeline);
3498+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
3499+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
3500+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
3501+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
3502+
3503+
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3504+
const int64_t n = (np + nth - 1) / nth;
3505+
3506+
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
3507+
3508+
return 1;
3509+
}

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
8080
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
8181
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
8282
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
83+
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
8384

8485
#ifdef __cplusplus
8586
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8806,3 +8806,17 @@ kernel void kernel_opt_step_adamw_f32(
88068806

88078807
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
88088808
}
8809+
8810+
kernel void kernel_opt_step_sgd_f32(
8811+
constant ggml_metal_kargs_opt_step_sgd & args,
8812+
device float * x,
8813+
device const float * g,
8814+
device const float * pars,
8815+
uint gid[[thread_position_in_grid]]) {
8816+
8817+
if (gid >= args.np) {
8818+
return;
8819+
}
8820+
8821+
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
8822+
}

0 commit comments

Comments
 (0)