Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 52 additions & 19 deletions metal_infer/infer.m
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
#include <signal.h>
#include <sys/wait.h>
#include <compression.h>
#include <stdatomic.h>
#include <sched.h>

// ============================================================================
// Model constants
Expand Down Expand Up @@ -902,6 +904,7 @@ static void cpu_conv1d_step(
id<MTLCommandQueue> queue;
id<MTLLibrary> library;
id<MTLComputePipelineState> matvec_v3;
id<MTLComputePipelineState> matvec_v3_small; // 4KB x_shared for down_proj (in_dim<=1024)
id<MTLComputePipelineState> matvec_v5; // LUT dequant variant
id<MTLComputePipelineState> matvec_fast; // for in_dim > 4096
id<MTLComputePipelineState> matvec_2bit; // 2-bit expert dequant kernel
Expand Down Expand Up @@ -1042,6 +1045,7 @@ static void cpu_conv1d_step(
};

ctx->matvec_v3 = makePipe(@"dequant_matvec_4bit_v3");
ctx->matvec_v3_small = makePipe(@"dequant_matvec_4bit_v3_small");
ctx->matvec_v5 = makePipe(@"dequant_matvec_4bit_v5"); // LUT variant (no uint→float conversions)
ctx->matvec_fast = makePipe(@"dequant_matvec_4bit_fast");
ctx->matvec_2bit = makePipe(@"dequant_matvec_2bit");
Expand Down Expand Up @@ -1571,10 +1575,11 @@ static void gpu_encode_expert_forward_slot(
threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
[enc endEncoding];
}
// down_proj: act[k] -> out[k]
// down_proj: act[k] -> out[k] — use v3_small for better GPU occupancy
{
id<MTLComputeCommandEncoder> enc = [cmdbuf computeCommandEncoder];
[enc setComputePipelineState:expert_pipe];
id<MTLComputePipelineState> down_pipe = (!g_use_2bit && ctx->matvec_v3_small) ? ctx->matvec_v3_small : expert_pipe;
[enc setComputePipelineState:down_pipe];
[enc setBuffer:ctx->buf_multi_expert_data[k] offset:down_w_off atIndex:0];
[enc setBuffer:ctx->buf_multi_expert_data[k] offset:down_s_off atIndex:1];
[enc setBuffer:ctx->buf_multi_expert_data[k] offset:down_b_off atIndex:2];
Expand Down Expand Up @@ -1667,10 +1672,11 @@ static void gpu_encode_expert_forward_slot_buf(
threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
[enc endEncoding];
}
// down_proj
// down_proj — v3_small for better GPU occupancy
{
id<MTLComputeCommandEncoder> enc = [cmdbuf computeCommandEncoder];
[enc setComputePipelineState:expert_pipe];
id<MTLComputePipelineState> down_pipe = (!g_use_2bit && ctx->matvec_v3_small) ? ctx->matvec_v3_small : expert_pipe;
[enc setComputePipelineState:down_pipe];
[enc setBuffer:data_buf offset:down_w_off atIndex:0];
[enc setBuffer:data_buf offset:down_s_off atIndex:1];
[enc setBuffer:data_buf offset:down_b_off atIndex:2];
Expand Down Expand Up @@ -1768,8 +1774,11 @@ static void gpu_encode_experts_batched(
[enc setBytes:&gate_up_out length:4 atIndex:3];
[enc dispatchThreadgroups:MTLSizeMake(swiglu_tgs, 1, 1)
threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
// down_proj (same encoder, serialized after SwiGLU)
[enc setComputePipelineState:expert_pipe];
// down_proj (same encoder, serialized after SwiGLU) — v3_small for occupancy
{
id<MTLComputePipelineState> down_pipe = (!g_use_2bit && ctx->matvec_v3_small) ? ctx->matvec_v3_small : expert_pipe;
[enc setComputePipelineState:down_pipe];
}
[enc setBuffer:expert_bufs[k] offset:down_w_off atIndex:0];
[enc setBuffer:expert_bufs[k] offset:down_s_off atIndex:1];
[enc setBuffer:expert_bufs[k] offset:down_b_off atIndex:2];
Expand Down Expand Up @@ -1856,10 +1865,10 @@ static void gpu_encode_expert_forward(
threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
[enc endEncoding];
}
// down_proj
// down_proj — v3_small for better GPU occupancy
{
id<MTLComputeCommandEncoder> enc = [cmdbuf computeCommandEncoder];
[enc setComputePipelineState:ctx->matvec_v3];
[enc setComputePipelineState:ctx->matvec_v3_small ? ctx->matvec_v3_small : ctx->matvec_v3];
[enc setBuffer:ctx->buf_expert_data offset:down_w_off atIndex:0];
[enc setBuffer:ctx->buf_expert_data offset:down_s_off atIndex:1];
[enc setBuffer:ctx->buf_expert_data offset:down_b_off atIndex:2];
Expand Down Expand Up @@ -2707,14 +2716,19 @@ static void moe_forward(
fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4);
}

// Softmax routing scores
cpu_softmax(gate_scores, NUM_EXPERTS);

// Top-K expert selection
// Top-K on raw logits + partial softmax (only K values instead of 512 exp() calls).
// Softmax is monotonic so topK ordering is identical on raw vs softmax'd values.
int expert_indices[64];
float expert_weights[64];
cpu_topk(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights);
cpu_normalize_weights(expert_weights, K);
{
float max_val = expert_weights[0];
for (int k = 1; k < K; k++) if (expert_weights[k] > max_val) max_val = expert_weights[k];
float sum = 0.0f;
for (int k = 0; k < K; k++) { expert_weights[k] = expf(expert_weights[k] - max_val); sum += expert_weights[k]; }
float inv = 1.0f / sum;
for (int k = 0; k < K; k++) expert_weights[k] *= inv;
}

if (moe_dump) {
fprintf(stderr, "[MOE-DUMP] routing: K=%d experts=[", K);
Expand Down Expand Up @@ -2978,6 +2992,7 @@ static void lm_head_forward(WeightFile *wf, const float *hidden, float *logits)
InferPreadTask *tasks;
int num_tasks;
int tasks_completed;
_Atomic int tasks_done; // atomic completion counter for WFE-based wait
int generation; // incremented each dispatch — workers wait for new gen
volatile int shutdown;
} IOThreadPool;
Expand Down Expand Up @@ -3019,6 +3034,9 @@ static void lm_head_forward(WeightFile *wf, const float *hidden, float *logits)
}
}

// Signal completion via atomic (pairs with WFE spin in dispatch)
atomic_fetch_add_explicit(&g_io_pool.tasks_done, 1, memory_order_release);

pthread_mutex_lock(&g_io_pool.mutex);
g_io_pool.tasks_completed++;
if (g_io_pool.tasks_completed == NUM_IO_THREADS)
Expand All @@ -3035,6 +3053,7 @@ static void io_pool_init(void) {
pthread_cond_init(&g_io_pool.work_done, NULL);
g_io_pool.shutdown = 0;
g_io_pool.generation = 0;
atomic_store(&g_io_pool.tasks_done, 0);
g_io_pool.tasks = NULL;
for (int i = 0; i < NUM_IO_THREADS; i++)
pthread_create(&g_io_pool.threads[i], NULL, io_pool_worker, (void*)(intptr_t)i);
Expand All @@ -3045,16 +3064,22 @@ static void io_pool_init(void) {

// Dispatch num_tasks pread tasks to the IO worker pool and block until all
// NUM_IO_THREADS workers report completion for this generation.
//
// Wait strategy: workers signal completion by incrementing tasks_done with
// memory_order_release; we spin on it with acquire semantics instead of
// pthread_cond_wait, avoiding a kernel round-trip on the hot path.
static void io_pool_dispatch(InferPreadTask *tasks, int num_tasks) {
    if (num_tasks == 0) return;
    // Reset before publishing the new generation. Workers only start after
    // observing generation++ (made under the mutex below), so this relaxed
    // store cannot race with increments for the new batch.
    atomic_store_explicit(&g_io_pool.tasks_done, 0, memory_order_relaxed);
    pthread_mutex_lock(&g_io_pool.mutex);
    g_io_pool.tasks = tasks;
    g_io_pool.num_tasks = num_tasks;
    g_io_pool.tasks_completed = 0;  // legacy cond-based counter; no longer waited on
    g_io_pool.generation++;
    pthread_cond_broadcast(&g_io_pool.work_ready);
    pthread_mutex_unlock(&g_io_pool.mutex);
    for (;;) {
#if defined(__aarch64__) || defined(__arm64__)
        // LDAXR (load-acquire-exclusive) arms the exclusive monitor on
        // tasks_done, so the WFE below is guaranteed to wake when any worker
        // stores to it (the monitor clear raises a local event). A bare WFE
        // after a plain atomic load has no armed monitor and no pairing SEV,
        // so it could sleep until an unrelated interrupt or event-stream
        // tick — adding unbounded latency to the wait.
        int done;
        __asm__ volatile("ldaxr %w0, [%1]"
                         : "=r"(done)
                         : "r"(&g_io_pool.tasks_done)
                         : "memory");
        if (done >= NUM_IO_THREADS) break;  // LDAXR supplies the acquire fence
        __asm__ volatile("wfe" ::: "memory");
#else
        if (atomic_load_explicit(&g_io_pool.tasks_done, memory_order_acquire)
                >= NUM_IO_THREADS) break;
        sched_yield();
#endif
    }
}

// ---- Async expert pread pipeline ----
Expand Down Expand Up @@ -5027,13 +5052,21 @@ static void fused_layer_forward(
if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd2_encode += t1 - t0; }
}

// ---- Softmax + top-K (CPU) ----
// ---- Top-K + partial softmax (CPU) ----
// TopK on raw logits (softmax is monotonic — preserves ordering), then
// softmax only the K selected values: 4 exp() instead of 512.
if (g_timing_enabled) { t0 = now_ms(); }
cpu_softmax(gate_scores, NUM_EXPERTS);
int expert_indices[64];
float expert_weights[64];
cpu_topk(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights);
cpu_normalize_weights(expert_weights, K);
{
float max_val = expert_weights[0];
for (int k = 1; k < K; k++) if (expert_weights[k] > max_val) max_val = expert_weights[k];
float sum = 0.0f;
for (int k = 0; k < K; k++) { expert_weights[k] = expf(expert_weights[k] - max_val); sum += expert_weights[k]; }
float inv = 1.0f / sum;
for (int k = 0; k < K; k++) expert_weights[k] *= inv;
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above code is the same as a cpu_softmax, right?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes — it is a softmax, but applied to only K values (typically 4) instead of all 512. The optimization: we run cpu_topk on raw logits first. Since softmax is monotonic (preserves ordering), the top-K indices are identical whether you softmax before or after selection. Then we softmax just those K values to get normalized routing weights. Net: 4 expf() calls instead of 512 — mathematically identical result.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can replace this code with cpu_softmax(expert_weights, K).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call — much cleaner. Will update.

if (g_freq_tracking) {
for (int k = 0; k < K; k++) {
g_expert_freq[layer_idx][expert_indices[k]]++;
Expand Down
153 changes: 117 additions & 36 deletions metal_infer/shaders.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1047,25 +1047,34 @@ kernel void gated_delta_net_step(
uint k_base = kh * 128;
uint v_base = head_id * 128;

// Step 1+2: Decay state row and compute kv_mem = dot(S[vi][:], k[:])
// Load entire state row into registers (1 device memory read)
float S[128];
for (uint ki = 0; ki < 128; ki++) {
S[ki] = state[state_base + ki];
}

// Fused loop 1: decay + kv_mem dot product
float kv_mem = 0.0f;
for (uint ki = 0; ki < 128; ki++) {
float s = state[state_base + ki] * g;
state[state_base + ki] = s;
kv_mem += s * k[k_base + ki];
S[ki] *= g;
kv_mem += S[ki] * k[k_base + ki];
}

// Step 3+4: Delta update — S[vi][ki] += k[ki] * delta
// Compute delta scalar
float delta = (v[v_base + vi] - kv_mem) * beta;

// Fused loop 2: state update + output dot product
float out_val = 0.0f;
for (uint ki = 0; ki < 128; ki++) {
state[state_base + ki] += k[k_base + ki] * delta;
S[ki] += k[k_base + ki] * delta;
out_val += S[ki] * q[k_base + ki];
}

// Step 5: Output — out[vi] = dot(S[vi][:], q[:])
float out_val = 0.0f;
// Write state row back (1 device memory write)
for (uint ki = 0; ki < 128; ki++) {
out_val += state[state_base + ki] * q[k_base + ki];
state[state_base + ki] = S[ki];
}

output[v_base + vi] = out_val;
}

Expand Down Expand Up @@ -1136,40 +1145,34 @@ kernel void rms_norm_qk(
uint tid [[thread_position_in_threadgroup]]
) {
uint base = head * key_dim;
uint simd_lane = tid % 32;
uint simd_group = tid / 32;

// RMS norm for q
threadgroup float q_sum_sq;
if (tid == 0) q_sum_sq = 0;
threadgroup_barrier(mem_flags::mem_threadgroup);

// RMS norm for q — SIMD parallel reduction (4 groups of 32)
float qval = (tid < key_dim) ? q[base + tid] : 0;
// Use threadgroup atomic add for sum of squares
float q_sq_local = qval * qval;
// Simple reduction: thread 0 accumulates (key_dim=128, fits in one pass)
threadgroup float q_partial[128];
q_partial[tid] = q_sq_local;
float q_simd_val = simd_sum(qval * qval);
threadgroup float q_shared[4];
if (simd_lane == 0) q_shared[simd_group] = q_simd_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float q_sum_sq;
if (tid == 0) {
float s = 0;
for (uint i = 0; i < key_dim; i++) s += q_partial[i];
q_sum_sq = s;
q_sum_sq = q_shared[0] + q_shared[1] + q_shared[2] + q_shared[3];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float q_inv_rms = rsqrt(q_sum_sq / float(key_dim) + 1e-6f);
if (tid < key_dim) {
q[base + tid] = qval * q_inv_rms * inv_scale * inv_scale; // q gets extra scale
}

// RMS norm for k
threadgroup float k_sum_sq;
// RMS norm for k — SIMD parallel reduction
float kval = (tid < key_dim) ? k[base + tid] : 0;
threadgroup float k_partial[128];
k_partial[tid] = kval * kval;
float k_simd_val = simd_sum(kval * kval);
threadgroup float k_shared[4];
if (simd_lane == 0) k_shared[simd_group] = k_simd_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float k_sum_sq;
if (tid == 0) {
float s = 0;
for (uint i = 0; i < key_dim; i++) s += k_partial[i];
k_sum_sq = s;
k_sum_sq = k_shared[0] + k_shared[1] + k_shared[2] + k_shared[3];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float k_inv_rms = rsqrt(k_sum_sq / float(key_dim) + 1e-6f);
Expand Down Expand Up @@ -1221,20 +1224,22 @@ kernel void gated_rms_norm(
uint tid [[thread_position_in_threadgroup]]
) {
uint base = head * value_dim;
uint simd_lane = tid % 32;
uint simd_group = tid / 32;

float val = (tid < value_dim) ? values[base + tid] : 0;

// RMS norm reduction
threadgroup float partial[128];
partial[tid] = val * val;
// RMS norm — SIMD parallel reduction (4 groups of 32)
float simd_val = simd_sum(val * val);
threadgroup float shared[4];
if (simd_lane == 0) shared[simd_group] = simd_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float total_sq;
if (tid == 0) {
float s = 0;
for (uint i = 0; i < value_dim; i++) s += partial[i];
partial[0] = s;
total_sq = shared[0] + shared[1] + shared[2] + shared[3];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float inv_rms = rsqrt(partial[0] / float(value_dim) + eps);
float inv_rms = rsqrt(total_sq / float(value_dim) + eps);

if (tid < value_dim) {
float normed = val * inv_rms;
Expand Down Expand Up @@ -1294,3 +1299,79 @@ kernel void moe_combine_residual(

hidden_out[tid] = h_mid[tid] + moe + shared_gate * shared_out[tid];
}


// ============================================================================
// Kernel 1c-small: down_proj variant with 4KB threadgroup memory (vs 16KB)
// ============================================================================
//
// Identical to dequant_matvec_4bit_v3 except x_shared is [1024] instead of
// [4096]. For down_proj (in_dim=1024), this reduces threadgroup memory from
// 16KB to 4KB, allowing ~4x more concurrent threadgroups per GPU core.
// On M1 (8 cores) this significantly improves occupancy and latency hiding.

kernel void dequant_matvec_4bit_v3_small(
    device const uint32_t* W_packed   [[buffer(0)]],
    device const uint16_t* scales     [[buffer(1)]],
    device const uint16_t* biases     [[buffer(2)]],
    device const float*    x          [[buffer(3)]],
    device float*          out        [[buffer(4)]],
    constant uint&         out_dim    [[buffer(5)]],
    constant uint&         in_dim     [[buffer(6)]],
    constant uint&         group_size [[buffer(7)]],
    uint tgid       [[threadgroup_position_in_grid]],
    uint lid        [[thread_position_in_threadgroup]],
    uint simd_lane  [[thread_index_in_simdgroup]],
    uint simd_group [[simdgroup_index_in_threadgroup]]
) {
    // One simdgroup computes one output row; ROWS_PER_TG rows per threadgroup.
    uint row = tgid * ROWS_PER_TG + simd_group;
    uint packed_cols = in_dim / 8;            // 8 4-bit weights per uint32
    uint num_groups  = in_dim / group_size;   // quantization groups per row

    // 4KB staging buffer (vs 16KB in v3) for better occupancy.
    // NOTE(review): assumes in_dim <= 1024 — a larger in_dim would index past
    // x_shared in both loops below; callers must guarantee the bound.
    threadgroup float x_shared[1024];

    // Cooperative copy of the activation vector into threadgroup memory
    // (256 threads per threadgroup, strided).
    for (uint i = lid; i < in_dim; i += 256) {
        x_shared[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Safe to exit here: no further barriers are executed below.
    if (row >= out_dim) return;

    device const uint32_t* w_row = W_packed + row * packed_cols;
    device const uint16_t* s_row = scales + row * num_groups;
    device const uint16_t* b_row = biases + row * num_groups;

    float acc = 0.0f;

    // Each lane handles every 32nd packed word. Dequant is q*scale + bias,
    // folded into the dot product as fma(q, scale*x, bias*x).
    for (uint col = simd_lane; col < packed_cols; col += 32) {
        uint g = col / (group_size / 8);
        float scale = bf16_to_f32(s_row[g]);
        float bias  = bf16_to_f32(b_row[g]);

        uint32_t packed = w_row[col];
        uint x_base = col * 8;

        // Unpack the 8 nibbles in order (compiler unrolls this loop);
        // accumulation order matches the v3 kernel exactly.
        for (uint j = 0; j < 8; j++) {
            float xv = x_shared[x_base + j];
            uint  q  = (packed >> (4 * j)) & 0xF;
            acc += fma(float(q), scale * xv, bias * xv);
        }
    }

    // Reduce partial sums across the simdgroup; lane 0 writes the row result.
    float total = simd_sum(acc);
    if (simd_lane == 0) {
        out[row] = total;
    }
}