diff --git a/metal_infer/infer.m b/metal_infer/infer.m
index 5d2a946..448aab4 100644
--- a/metal_infer/infer.m
+++ b/metal_infer/infer.m
@@ -64,6 +64,8 @@
 #include
 #include
 #include
+#include <stdatomic.h>
+#include <sched.h>
 // ============================================================================
 // Model constants
 // ============================================================================
@@ -79,16 +81,20 @@
 #define NUM_EXPERTS 512
 #define NUM_EXPERTS_PER_TOK 10
 #define MOE_INTERMEDIATE 1024
+_Static_assert(MOE_INTERMEDIATE <= 1024, "v3_small kernel requires MOE_INTERMEDIATE <= 1024");
 #define SHARED_INTERMEDIATE 1024
 #define FULL_ATTN_INTERVAL 4
 #define GROUP_SIZE 64
 #define BITS 4
+#define MAX_K 8 // Maximum active experts per layer
 
 // Linear attention (GatedDeltaNet) constants
 #define LINEAR_NUM_V_HEADS 64
 #define LINEAR_NUM_K_HEADS 16
 #define LINEAR_KEY_DIM 128 // head_k_dim
+_Static_assert(LINEAR_KEY_DIM == 128, "SIMD reductions in rms_norm_qk assume 4 groups of 32");
 #define LINEAR_VALUE_DIM 128 // head_v_dim
+_Static_assert(LINEAR_VALUE_DIM == 128, "SIMD reductions in rms_norm_qk assume 4 groups of 32");
 #define LINEAR_TOTAL_KEY (LINEAR_NUM_K_HEADS * LINEAR_KEY_DIM) // 2048
 #define LINEAR_TOTAL_VALUE (LINEAR_NUM_V_HEADS * LINEAR_VALUE_DIM) // 8192
 #define LINEAR_CONV_DIM (LINEAR_TOTAL_KEY * 2 + LINEAR_TOTAL_VALUE) // 12288
@@ -196,6 +202,67 @@ static double now_ms(void) {
 static int g_cache_telemetry_enabled = 0; // enabled by --cache-telemetry flag
 static int g_think_budget = 2048; // max thinking tokens before force-emitting
+// ============================================================================
+// Cache-Aware Routing (--cache-aware)
+// ============================================================================
+// Modifies expert selection to PREFER experts already in the OS page cache.
+// Instead of pure top-K by gate score, uses:
+//   adjusted_score[e] = original_score[e] + cache_bonus * is_likely_cached(layer, e)
+// with bounded quality degradation via a tolerance threshold.
+// +// Cache state inference: experts accessed within the last N layer-reads are +// likely still in the ~42GB OS page cache. At 6.75MB/expert (4-bit) and +// K=4 experts per layer, each token touches 60*4=240 experts = ~1.6GB. +// With ~42GB page cache, the last ~26 tokens' experts fit (~6240 layer-expert reads). +// We use a per-layer LRU token timestamp. +// +// Quality bound: only substitute a cached expert for an uncached one if the +// cached expert's raw gate score is within `tolerance` of the evicted expert's +// score: score[cached] >= score[original] * (1 - tolerance). +// With tolerance=0.10, we only pick experts at least 90% as strong as the +// original choice -- the 4th expert typically has weight ~0.08 after softmax, +// so the output perturbation is bounded by ~0.008 * ||expert_diff||. +// ============================================================================ + +static int g_cache_aware_enabled = 0; // enabled by --cache-aware flag +static float g_cache_bonus = 0.5f; // (reserved) additive bonus for future scoring experiments +static float g_cache_tolerance = 0.10f; // max score degradation as fraction of top-K range +static int g_cache_aware_window = 25; // tokens within which an expert is "likely cached" + +// Per-layer, per-expert: token number when last accessed. +// 0 = never accessed. Updated after routing, read before topk. 
+static uint64_t g_car_last_access[NUM_LAYERS][NUM_EXPERTS]; +static uint64_t g_car_token_clock = 0; // incremented each token + +// Stats +static uint64_t g_car_substitutions = 0; // times a cached expert replaced an uncached one +static uint64_t g_car_total_selections = 0; // total expert selections (tokens * K * layers) +static uint64_t g_car_estimated_hits = 0; // experts selected that were likely cached +static uint64_t g_car_estimated_misses = 0; // experts selected that were likely NOT cached + +static void cache_aware_reset(void) { + memset(g_car_last_access, 0, sizeof(g_car_last_access)); + g_car_token_clock = 0; + g_car_substitutions = 0; + g_car_total_selections = 0; + g_car_estimated_hits = 0; + g_car_estimated_misses = 0; +} + +static inline int car_is_likely_cached(int layer_idx, int expert_idx) { + uint64_t last = g_car_last_access[layer_idx][expert_idx]; + if (last == 0) return 0; // never accessed + return (g_car_token_clock - last) <= (uint64_t)g_cache_aware_window; +} + +static void car_touch(int layer_idx, const int *expert_indices, int K) { + for (int k = 0; k < K; k++) { + g_car_last_access[layer_idx][expert_indices[k]] = g_car_token_clock; + } +} + +// cpu_topk_cache_aware and car_print_stats defined after cpu_topk (see below) + // Tiered I/O: cold fds (F_NOCACHE) for first reads, warm fds (page cached) for repeats static int *g_layer_fds_cold = NULL; // [NUM_LAYERS] cold fds (set in main) static uint8_t g_expert_seen[NUM_LAYERS][NUM_EXPERTS / 8]; // bitset: seen before? 
@@ -250,6 +317,8 @@ static void cache_telemetry_reset(void) {
 }
 
 static void cache_telemetry_note_token(void) {
+    // Cache-aware routing token clock (always active when --cache-aware)
+    if (g_cache_aware_enabled) g_car_token_clock++;
     if (!g_cache_telemetry_enabled) return;
     g_cache_telemetry.token_clock++;
 }
@@ -811,6 +880,154 @@ static void cpu_normalize_weights(float *weights, int K) {
     }
 }
 
+// ============================================================================
+// Cache-Aware Top-K (implementation — state variables declared near line 223)
+// ============================================================================
+
+// Cache-aware top-K selection.
+//   1. Compute standard top-K from raw gate scores.
+//   2. Scan the remaining experts for cached ones whose score is within tolerance.
+//   3. For each such cached expert, replace the lowest-scoring uncached expert
+//      in the current top-K set (if any uncached expert exists).
+// This guarantees:
+//   - All K experts have score >= min_topk_score * (1 - tolerance)
+//   - We maximize cache hits without unbounded quality loss
+//   - If all top-K are already cached, no changes are made
+static void cpu_topk_cache_aware(
+    const float *scores, int dim, int K,
+    int *indices, float *values,
+    int layer_idx
+) {
+    // Step 1: Standard top-K
+    cpu_topk(scores, dim, K, indices, values);
+
+    // Step 2: Identify uncached experts in the top-K (candidates for replacement)
+    int uncached_slots[64]; // sized to callers' [64] buffers; MAX_K=8 overflows when K>8
+    int num_uncached = 0;
+    int num_cached_in_topk = 0;
+
+    for (int k = 0; k < K; k++) {
+        if (car_is_likely_cached(layer_idx, indices[k])) {
+            num_cached_in_topk++;
+        } else {
+            uncached_slots[num_uncached++] = k;
+        }
+    }
+
+    // If all top-K are already cached, nothing to do
+    if (num_uncached == 0) {
+        g_car_estimated_hits += K;
+        g_car_total_selections += K;
+        return;
+    }
+
+    // Step 3: Compute the score range of the top-K.
+    // Tolerance is defined as a fraction of the top-K RANGE (max - min logit).
+    // A cached substitute must have score >= evicted_score - tolerance * range.
+    // This works correctly regardless of logit sign.
+    float max_topk_score = values[0], min_topk_score = values[0];
+    for (int k = 1; k < K; k++) {
+        if (values[k] > max_topk_score) max_topk_score = values[k];
+        if (values[k] < min_topk_score) min_topk_score = values[k];
+    }
+    float score_range = max_topk_score - min_topk_score;
+    if (score_range < 1e-6f) score_range = 1e-6f; // avoid zero range
+    float abs_tolerance = g_cache_tolerance * score_range;
+
+    // Global floor: no substitute below this score
+    float global_floor = min_topk_score - abs_tolerance;
+
+    // Build a set of current top-K indices for O(1) membership check
+    uint8_t in_topk[NUM_EXPERTS];
+    memset(in_topk, 0, sizeof(in_topk)); // stack allocation, 512 bytes
+    for (int k = 0; k < K; k++) in_topk[indices[k]] = 1;
+
+    // Step 4: Scan all non-top-K experts for cached ones above the floor
+    // Collect at most `num_uncached` candidates (we only need that many)
+    typedef struct { int idx; float score; } CacheCandidate;
+    CacheCandidate candidates[64 * 4]; // sized for K up to 64 (matches uncached_slots; guard below can emit num_uncached*4 entries)
+    int num_candidates = 0;
+
+    for (int e = 0; e < dim && num_candidates < num_uncached * 4; e++) {
+        if (in_topk[e]) continue; // already in top-K
+        if (!car_is_likely_cached(layer_idx, e)) continue; // not cached
+        if (scores[e] < global_floor) continue; // too weak
+        candidates[num_candidates].idx = e;
+        candidates[num_candidates].score = scores[e];
+        num_candidates++;
+    }
+
+    // Sort candidates descending by score (simple insertion sort for small N)
+    for (int i = 1; i < num_candidates; i++) {
+        CacheCandidate tmp = candidates[i];
+        int j = i - 1;
+        while (j >= 0 && candidates[j].score < tmp.score) {
+            candidates[j + 1] = candidates[j];
+            j--;
+        }
+        candidates[j + 1] = tmp;
+    }
+
+    // Step 5: Replace uncached experts with cached candidates
+    //
Sort uncached_slots by ascending score (replace weakest first) + for (int i = 1; i < num_uncached; i++) { + int tmp = uncached_slots[i]; + float tmp_score = values[tmp]; + int j = i - 1; + while (j >= 0 && values[uncached_slots[j]] > tmp_score) { + uncached_slots[j + 1] = uncached_slots[j]; + j--; + } + uncached_slots[j + 1] = tmp; + } + + int subs = 0; + int ci = 0; // candidate index + for (int u = 0; u < num_uncached && ci < num_candidates; u++) { + int slot = uncached_slots[u]; + float evicted_score = values[slot]; + + // Find best candidate within abs_tolerance of the evicted expert's score. + // candidate.score >= evicted_score - abs_tolerance + while (ci < num_candidates) { + if (candidates[ci].score >= evicted_score - abs_tolerance) { + // Substitute! + indices[slot] = candidates[ci].idx; + values[slot] = candidates[ci].score; + subs++; + ci++; + break; + } + ci++; + } + } + + g_car_substitutions += subs; + g_car_total_selections += K; + g_car_estimated_hits += num_cached_in_topk + subs; + g_car_estimated_misses += num_uncached - subs; +} + +static void car_print_stats(void) { + if (!g_cache_aware_enabled || g_car_total_selections == 0) return; + uint64_t total = g_car_estimated_hits + g_car_estimated_misses; + fprintf(stderr, "\n=== Cache-Aware Routing Stats ===\n"); + fprintf(stderr, "Tokens: %llu\n", g_car_token_clock); + fprintf(stderr, "Total selections:%llu\n", g_car_total_selections); + fprintf(stderr, "Substitutions: %llu (%.2f%%)\n", + g_car_substitutions, + 100.0 * g_car_substitutions / g_car_total_selections); + fprintf(stderr, "Est. cache hits: %llu / %llu (%.1f%%)\n", + g_car_estimated_hits, total, + total > 0 ? 100.0 * g_car_estimated_hits / total : 0.0); + fprintf(stderr, "Est. hit rate: %.1f%% -> %.1f%% (delta: +%.1f%%)\n", + total > 0 ? 100.0 * (g_car_estimated_hits - g_car_substitutions) / total : 0.0, + total > 0 ? 100.0 * g_car_estimated_hits / total : 0.0, + total > 0 ? 
100.0 * g_car_substitutions / total : 0.0); + fprintf(stderr, "Config: bonus=%.2f, tolerance=%.2f, window=%d\n", + g_cache_bonus, g_cache_tolerance, g_cache_aware_window); +} + // Element-wise add: dst += src __attribute__((unused)) static void cpu_vec_add(float *dst, const float *src, int dim) { @@ -902,6 +1119,7 @@ static void cpu_conv1d_step( id queue; id library; id matvec_v3; + id matvec_v3_small; // 4KB x_shared for down_proj (in_dim<=1024) id matvec_v5; // LUT dequant variant id matvec_fast; // for in_dim > 4096 id matvec_2bit; // 2-bit expert dequant kernel @@ -934,7 +1152,7 @@ static void cpu_conv1d_step( // Each expert k uses slot [k]. // Double-buffered: set A (data) for GPU compute, set B (data_B) for background pread. // Gate/up/act/out only need one set (GPU uses them after pread completes). - #define MAX_K 8 + // MAX_K defined in model constants at top of file id buf_multi_expert_data[MAX_K]; // [EXPERT_SIZE bytes] each — buffer set A id buf_multi_expert_data_B[MAX_K]; // [EXPERT_SIZE bytes] each — buffer set B (prefetch) id buf_multi_expert_gate[MAX_K]; // [MOE_INTERMEDIATE floats] @@ -1042,6 +1260,7 @@ static void cpu_conv1d_step( }; ctx->matvec_v3 = makePipe(@"dequant_matvec_4bit_v3"); + ctx->matvec_v3_small = makePipe(@"dequant_matvec_4bit_v3_small"); ctx->matvec_v5 = makePipe(@"dequant_matvec_4bit_v5"); // LUT variant (no uint→float conversions) ctx->matvec_fast = makePipe(@"dequant_matvec_4bit_fast"); ctx->matvec_2bit = makePipe(@"dequant_matvec_2bit"); @@ -1571,10 +1790,11 @@ static void gpu_encode_expert_forward_slot( threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; [enc endEncoding]; } - // down_proj: act[k] -> out[k] + // down_proj: act[k] -> out[k] — use v3_small for better GPU occupancy { id enc = [cmdbuf computeCommandEncoder]; - [enc setComputePipelineState:expert_pipe]; + id down_pipe = (!g_use_2bit && ctx->matvec_v3_small) ? 
ctx->matvec_v3_small : expert_pipe; + [enc setComputePipelineState:down_pipe]; [enc setBuffer:ctx->buf_multi_expert_data[k] offset:down_w_off atIndex:0]; [enc setBuffer:ctx->buf_multi_expert_data[k] offset:down_s_off atIndex:1]; [enc setBuffer:ctx->buf_multi_expert_data[k] offset:down_b_off atIndex:2]; @@ -1667,10 +1887,11 @@ static void gpu_encode_expert_forward_slot_buf( threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; [enc endEncoding]; } - // down_proj + // down_proj — v3_small for better GPU occupancy { id enc = [cmdbuf computeCommandEncoder]; - [enc setComputePipelineState:expert_pipe]; + id down_pipe = (!g_use_2bit && ctx->matvec_v3_small) ? ctx->matvec_v3_small : expert_pipe; + [enc setComputePipelineState:down_pipe]; [enc setBuffer:data_buf offset:down_w_off atIndex:0]; [enc setBuffer:data_buf offset:down_s_off atIndex:1]; [enc setBuffer:data_buf offset:down_b_off atIndex:2]; @@ -1768,8 +1989,11 @@ static void gpu_encode_experts_batched( [enc setBytes:&gate_up_out length:4 atIndex:3]; [enc dispatchThreadgroups:MTLSizeMake(swiglu_tgs, 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; - // down_proj (same encoder, serialized after SwiGLU) - [enc setComputePipelineState:expert_pipe]; + // down_proj (same encoder, serialized after SwiGLU) — v3_small for occupancy + { + id down_pipe = (!g_use_2bit && ctx->matvec_v3_small) ? ctx->matvec_v3_small : expert_pipe; + [enc setComputePipelineState:down_pipe]; + } [enc setBuffer:expert_bufs[k] offset:down_w_off atIndex:0]; [enc setBuffer:expert_bufs[k] offset:down_s_off atIndex:1]; [enc setBuffer:expert_bufs[k] offset:down_b_off atIndex:2]; @@ -1856,10 +2080,10 @@ static void gpu_encode_expert_forward( threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; [enc endEncoding]; } - // down_proj + // down_proj — v3_small for better GPU occupancy { id enc = [cmdbuf computeCommandEncoder]; - [enc setComputePipelineState:ctx->matvec_v3]; + [enc setComputePipelineState:ctx->matvec_v3_small ? 
ctx->matvec_v3_small : ctx->matvec_v3]; [enc setBuffer:ctx->buf_expert_data offset:down_w_off atIndex:0]; [enc setBuffer:ctx->buf_expert_data offset:down_s_off atIndex:1]; [enc setBuffer:ctx->buf_expert_data offset:down_b_off atIndex:2]; @@ -2707,14 +2931,12 @@ static void moe_forward( fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4); } - // Softmax routing scores - cpu_softmax(gate_scores, NUM_EXPERTS); - - // Top-K expert selection + // Top-K on raw logits + partial softmax (only K values instead of 512 exp() calls). + // Softmax is monotonic so topK ordering is identical on raw vs softmax'd values. int expert_indices[64]; float expert_weights[64]; cpu_topk(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights); - cpu_normalize_weights(expert_weights, K); + cpu_softmax(expert_weights, K); // softmax only K values (not all 512) if (moe_dump) { fprintf(stderr, "[MOE-DUMP] routing: K=%d experts=[", K); @@ -2978,6 +3200,7 @@ static void lm_head_forward(WeightFile *wf, const float *hidden, float *logits) InferPreadTask *tasks; int num_tasks; int tasks_completed; + _Atomic int tasks_done; // atomic completion counter for WFE-based wait int generation; // incremented each dispatch — workers wait for new gen volatile int shutdown; } IOThreadPool; @@ -3019,6 +3242,12 @@ static void lm_head_forward(WeightFile *wf, const float *hidden, float *logits) } } + // Signal completion via atomic (pairs with WFE spin in dispatch) + atomic_fetch_add_explicit(&g_io_pool.tasks_done, 1, memory_order_release); +#if defined(__aarch64__) || defined(__arm64__) + __asm__ volatile("sev" ::: "memory"); // wake WFE spinner +#endif + pthread_mutex_lock(&g_io_pool.mutex); g_io_pool.tasks_completed++; if (g_io_pool.tasks_completed == NUM_IO_THREADS) @@ -3035,6 +3264,7 @@ static void io_pool_init(void) { pthread_cond_init(&g_io_pool.work_done, NULL); g_io_pool.shutdown = 0; g_io_pool.generation = 0; + atomic_store(&g_io_pool.tasks_done, 0); g_io_pool.tasks = NULL; for (int i = 0; i < 
NUM_IO_THREADS; i++) pthread_create(&g_io_pool.threads[i], NULL, io_pool_worker, (void*)(intptr_t)i); @@ -3045,16 +3275,22 @@ static void io_pool_init(void) { static void io_pool_dispatch(InferPreadTask *tasks, int num_tasks) { if (num_tasks == 0) return; + atomic_store_explicit(&g_io_pool.tasks_done, 0, memory_order_relaxed); pthread_mutex_lock(&g_io_pool.mutex); g_io_pool.tasks = tasks; g_io_pool.num_tasks = num_tasks; g_io_pool.tasks_completed = 0; g_io_pool.generation++; pthread_cond_broadcast(&g_io_pool.work_ready); - while (g_io_pool.tasks_completed < NUM_IO_THREADS) { - pthread_cond_wait(&g_io_pool.work_done, &g_io_pool.mutex); - } pthread_mutex_unlock(&g_io_pool.mutex); + // Wait via atomic + WFE (avoids pthread_cond_wait kernel transition) + while (atomic_load_explicit(&g_io_pool.tasks_done, memory_order_acquire) < NUM_IO_THREADS) { +#if defined(__aarch64__) || defined(__arm64__) + __asm__ volatile("wfe" ::: "memory"); +#else + sched_yield(); +#endif + } } // ---- Async expert pread pipeline ---- @@ -5027,13 +5263,26 @@ static void fused_layer_forward( if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd2_encode += t1 - t0; } } - // ---- Softmax + top-K (CPU) ---- + // ---- Top-K + partial softmax (CPU) ---- + // TopK on raw logits (softmax is monotonic — preserves ordering), then + // softmax only the K selected values: 4 exp() instead of 512. + // Cache-aware routing: if enabled, prefer experts likely in page cache. 
if (g_timing_enabled) { t0 = now_ms(); } - cpu_softmax(gate_scores, NUM_EXPERTS); int expert_indices[64]; float expert_weights[64]; - cpu_topk(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights); - cpu_normalize_weights(expert_weights, K); + if (g_cache_aware_enabled) { + cpu_topk_cache_aware(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights, layer_idx); + } else { + cpu_topk(gate_scores, NUM_EXPERTS, K, expert_indices, expert_weights); + } + { + float max_val = expert_weights[0]; + for (int k = 1; k < K; k++) if (expert_weights[k] > max_val) max_val = expert_weights[k]; + float sum = 0.0f; + for (int k = 0; k < K; k++) { expert_weights[k] = expf(expert_weights[k] - max_val); sum += expert_weights[k]; } + float inv = 1.0f / sum; + for (int k = 0; k < K; k++) expert_weights[k] *= inv; + } if (g_freq_tracking) { for (int k = 0; k < K; k++) { g_expert_freq[layer_idx][expert_indices[k]]++; @@ -5054,6 +5303,11 @@ static void fused_layer_forward( } } + // Cache-aware routing: update access timestamps after routing + if (g_cache_aware_enabled) { + car_touch(layer_idx, expert_indices, (K > MAX_K) ? 
MAX_K : K); + } + if (g_timing_enabled) { t1 = now_ms(); g_timing.routing_cpu += t1 - t0; } // Log routing data for predictor training @@ -6299,6 +6553,7 @@ static void serve_loop( } } if (g_cache_telemetry_enabled) cache_telemetry_reset(); + if (g_cache_aware_enabled) cache_aware_reset(); // ---- Send SSE headers ---- http_write_str(client_fd, SSE_HEADERS); @@ -6472,6 +6727,7 @@ static void serve_loop( } else if (g_malloc_cache) { cache_telemetry_print(g_malloc_cache->hits, g_malloc_cache->misses); } + car_print_stats(); free(pt->ids); free(pt); @@ -6512,6 +6768,10 @@ static void print_usage(const char *prog) { printf(" --timing Enable per-layer timing breakdown\n"); printf(" --freq Enable expert frequency tracking + analysis\n"); printf(" --cache-telemetry Report cold vs eviction misses and reuse distance\n"); + printf(" --cache-aware Enable cache-aware routing (prefer cached experts)\n"); + printf(" --cache-bonus F Cache-aware bonus on raw logits (default: 0.5)\n"); + printf(" --cache-tolerance F Max relative score degradation (default: 0.10 = 10%%)\n"); + printf(" --cache-window N Tokens within which expert is 'likely cached' (default: 25)\n"); printf(" --2bit Use 2-bit quantized experts (packed_experts_2bit/)\n"); printf(" --gpu-linear Alias for the fused GPU delta-net path (default)\n"); printf(" --predict Enable temporal expert prediction (prefetch during CMD1_wait)\n"); @@ -6535,6 +6795,9 @@ int main(int argc, char **argv) { int malloc_cache_entries = 0; // 0 = disabled (override with --malloc-cache) int serve_port = 0; // 0 = disabled, >0 = HTTP serve mode + // Long-option-only codes (above 256 to avoid single-char conflicts) + enum { OPT_CACHE_AWARE = 300, OPT_CACHE_BONUS, OPT_CACHE_TOLERANCE, OPT_CACHE_WINDOW }; + static struct option long_options[] = { {"model", required_argument, 0, 'm'}, {"weights", required_argument, 0, 'w'}, @@ -6557,6 +6820,10 @@ int main(int argc, char **argv) { {"serve", required_argument, 0, 'R'}, {"predict", no_argument, 0, 
'D'}, {"collect-routing", required_argument, 0, 'Z'}, + {"cache-aware", no_argument, 0, OPT_CACHE_AWARE}, + {"cache-bonus", required_argument, 0, OPT_CACHE_BONUS}, + {"cache-tolerance", required_argument, 0, OPT_CACHE_TOLERANCE}, + {"cache-window", required_argument, 0, OPT_CACHE_WINDOW}, {"help", no_argument, 0, 'h'}, {0, 0, 0, 0} }; @@ -6591,6 +6858,10 @@ int main(int argc, char **argv) { break; case 'B': g_think_budget = atoi(optarg); break; case 'R': serve_port = atoi(optarg); break; + case OPT_CACHE_AWARE: g_cache_aware_enabled = 1; break; + case OPT_CACHE_BONUS: g_cache_bonus = atof(optarg); break; + case OPT_CACHE_TOLERANCE: g_cache_tolerance = atof(optarg); break; + case OPT_CACHE_WINDOW: g_cache_aware_window = atoi(optarg); break; case 'h': print_usage(argv[0]); return 0; default: print_usage(argv[0]); return 1; } @@ -6664,6 +6935,10 @@ int main(int argc, char **argv) { printf("Cache: %d entries%s\n", cache_entries, cache_entries > 0 ? "" : " (disabled)"); } + if (g_cache_aware_enabled) { + printf("CacheAware: ON (bonus=%.2f, tolerance=%.2f, window=%d tokens)\n", + g_cache_bonus, g_cache_tolerance, g_cache_aware_window); + } double t0 = now_ms(); @@ -6872,6 +7147,7 @@ int main(int argc, char **argv) { // ---- Generate tokens ---- reset_delta_net_state(); // zero GPU delta-net state before generation if (g_cache_telemetry_enabled) cache_telemetry_reset(); + if (g_cache_aware_enabled) cache_aware_reset(); printf("--- Generating %d tokens ---\n", max_tokens); int pos = 0; // position counter for RoPE @@ -7115,6 +7391,8 @@ int main(int argc, char **argv) { ? 
100.0 * g_spec_route_hits / g_spec_route_attempts : 0.0); } + car_print_stats(); + if (g_freq_tracking) freq_print_analysis(K); if (g_routing_log) { fclose(g_routing_log); diff --git a/metal_infer/shaders.metal b/metal_infer/shaders.metal index 80a3be6..22d1766 100644 --- a/metal_infer/shaders.metal +++ b/metal_infer/shaders.metal @@ -1047,25 +1047,34 @@ kernel void gated_delta_net_step( uint k_base = kh * 128; uint v_base = head_id * 128; - // Step 1+2: Decay state row and compute kv_mem = dot(S[vi][:], k[:]) + // Load entire state row into registers (1 device memory read) + float S[128]; + for (uint ki = 0; ki < 128; ki++) { + S[ki] = state[state_base + ki]; + } + + // Fused loop 1: decay + kv_mem dot product float kv_mem = 0.0f; for (uint ki = 0; ki < 128; ki++) { - float s = state[state_base + ki] * g; - state[state_base + ki] = s; - kv_mem += s * k[k_base + ki]; + S[ki] *= g; + kv_mem += S[ki] * k[k_base + ki]; } - // Step 3+4: Delta update — S[vi][ki] += k[ki] * delta + // Compute delta scalar float delta = (v[v_base + vi] - kv_mem) * beta; + + // Fused loop 2: state update + output dot product + float out_val = 0.0f; for (uint ki = 0; ki < 128; ki++) { - state[state_base + ki] += k[k_base + ki] * delta; + S[ki] += k[k_base + ki] * delta; + out_val += S[ki] * q[k_base + ki]; } - // Step 5: Output — out[vi] = dot(S[vi][:], q[:]) - float out_val = 0.0f; + // Write state row back (1 device memory write) for (uint ki = 0; ki < 128; ki++) { - out_val += state[state_base + ki] * q[k_base + ki]; + state[state_base + ki] = S[ki]; } + output[v_base + vi] = out_val; } @@ -1136,23 +1145,18 @@ kernel void rms_norm_qk( uint tid [[thread_position_in_threadgroup]] ) { uint base = head * key_dim; + uint simd_lane = tid % 32; + uint simd_group = tid / 32; - // RMS norm for q - threadgroup float q_sum_sq; - if (tid == 0) q_sum_sq = 0; - threadgroup_barrier(mem_flags::mem_threadgroup); - + // RMS norm for q — SIMD parallel reduction (4 groups of 32) float qval = (tid < key_dim) 
? q[base + tid] : 0; - // Use threadgroup atomic add for sum of squares - float q_sq_local = qval * qval; - // Simple reduction: thread 0 accumulates (key_dim=128, fits in one pass) - threadgroup float q_partial[128]; - q_partial[tid] = q_sq_local; + float q_simd_val = simd_sum(qval * qval); + threadgroup float q_shared[4]; + if (simd_lane == 0) q_shared[simd_group] = q_simd_val; threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float q_sum_sq; if (tid == 0) { - float s = 0; - for (uint i = 0; i < key_dim; i++) s += q_partial[i]; - q_sum_sq = s; + q_sum_sq = q_shared[0] + q_shared[1] + q_shared[2] + q_shared[3]; } threadgroup_barrier(mem_flags::mem_threadgroup); float q_inv_rms = rsqrt(q_sum_sq / float(key_dim) + 1e-6f); @@ -1160,16 +1164,15 @@ kernel void rms_norm_qk( q[base + tid] = qval * q_inv_rms * inv_scale * inv_scale; // q gets extra scale } - // RMS norm for k - threadgroup float k_sum_sq; + // RMS norm for k — SIMD parallel reduction float kval = (tid < key_dim) ? k[base + tid] : 0; - threadgroup float k_partial[128]; - k_partial[tid] = kval * kval; + float k_simd_val = simd_sum(kval * kval); + threadgroup float k_shared[4]; + if (simd_lane == 0) k_shared[simd_group] = k_simd_val; threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float k_sum_sq; if (tid == 0) { - float s = 0; - for (uint i = 0; i < key_dim; i++) s += k_partial[i]; - k_sum_sq = s; + k_sum_sq = k_shared[0] + k_shared[1] + k_shared[2] + k_shared[3]; } threadgroup_barrier(mem_flags::mem_threadgroup); float k_inv_rms = rsqrt(k_sum_sq / float(key_dim) + 1e-6f); @@ -1221,20 +1224,22 @@ kernel void gated_rms_norm( uint tid [[thread_position_in_threadgroup]] ) { uint base = head * value_dim; + uint simd_lane = tid % 32; + uint simd_group = tid / 32; float val = (tid < value_dim) ? 
values[base + tid] : 0; - // RMS norm reduction - threadgroup float partial[128]; - partial[tid] = val * val; + // RMS norm — SIMD parallel reduction (4 groups of 32) + float simd_val = simd_sum(val * val); + threadgroup float shared[4]; + if (simd_lane == 0) shared[simd_group] = simd_val; threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float total_sq; if (tid == 0) { - float s = 0; - for (uint i = 0; i < value_dim; i++) s += partial[i]; - partial[0] = s; + total_sq = shared[0] + shared[1] + shared[2] + shared[3]; } threadgroup_barrier(mem_flags::mem_threadgroup); - float inv_rms = rsqrt(partial[0] / float(value_dim) + eps); + float inv_rms = rsqrt(total_sq / float(value_dim) + eps); if (tid < value_dim) { float normed = val * inv_rms; @@ -1294,3 +1299,79 @@ kernel void moe_combine_residual( hidden_out[tid] = h_mid[tid] + moe + shared_gate * shared_out[tid]; } + + +// ============================================================================ +// Kernel 1c-small: down_proj variant with 4KB threadgroup memory (vs 16KB) +// ============================================================================ +// +// Identical to dequant_matvec_4bit_v3 except x_shared is [1024] instead of +// [4096]. For down_proj (in_dim=1024), this reduces threadgroup memory from +// 16KB to 4KB, allowing ~4x more concurrent threadgroups per GPU core. +// On M1 (8 cores) this significantly improves occupancy and latency hiding. 
+ +kernel void dequant_matvec_4bit_v3_small( + device const uint32_t* W_packed [[buffer(0)]], + device const uint16_t* scales [[buffer(1)]], + device const uint16_t* biases [[buffer(2)]], + device const float* x [[buffer(3)]], + device float* out [[buffer(4)]], + constant uint& out_dim [[buffer(5)]], + constant uint& in_dim [[buffer(6)]], + constant uint& group_size [[buffer(7)]], + uint tgid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]], + uint simd_group [[simdgroup_index_in_threadgroup]] +) { + uint row = tgid * ROWS_PER_TG + simd_group; + uint packed_cols = in_dim / 8; + uint num_groups = in_dim / group_size; + + threadgroup float x_shared[1024]; // 4KB vs 16KB in v3 + + for (uint i = lid; i < in_dim; i += 256) { + x_shared[i] = x[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (row >= out_dim) return; + + device const uint32_t* w_row = W_packed + row * packed_cols; + device const uint16_t* s_row = scales + row * num_groups; + device const uint16_t* b_row = biases + row * num_groups; + + float acc = 0.0f; + + for (uint col = simd_lane; col < packed_cols; col += 32) { + uint g = col / (group_size / 8); + float scale = bf16_to_f32(s_row[g]); + float bias = bf16_to_f32(b_row[g]); + + uint32_t packed = w_row[col]; + uint x_base = col * 8; + + float sx0 = scale * x_shared[x_base + 0]; float bx0 = bias * x_shared[x_base + 0]; + float sx1 = scale * x_shared[x_base + 1]; float bx1 = bias * x_shared[x_base + 1]; + float sx2 = scale * x_shared[x_base + 2]; float bx2 = bias * x_shared[x_base + 2]; + float sx3 = scale * x_shared[x_base + 3]; float bx3 = bias * x_shared[x_base + 3]; + float sx4 = scale * x_shared[x_base + 4]; float bx4 = bias * x_shared[x_base + 4]; + float sx5 = scale * x_shared[x_base + 5]; float bx5 = bias * x_shared[x_base + 5]; + float sx6 = scale * x_shared[x_base + 6]; float bx6 = bias * x_shared[x_base + 6]; + float sx7 = scale * x_shared[x_base + 7]; 
float bx7 = bias * x_shared[x_base + 7]; + + acc += fma(float((packed >> 0) & 0xF), sx0, bx0); + acc += fma(float((packed >> 4) & 0xF), sx1, bx1); + acc += fma(float((packed >> 8) & 0xF), sx2, bx2); + acc += fma(float((packed >> 12) & 0xF), sx3, bx3); + acc += fma(float((packed >> 16) & 0xF), sx4, bx4); + acc += fma(float((packed >> 20) & 0xF), sx5, bx5); + acc += fma(float((packed >> 24) & 0xF), sx6, bx6); + acc += fma(float((packed >> 28) & 0xF), sx7, bx7); + } + + float sum = simd_sum(acc); + if (simd_lane == 0) { + out[row] = sum; + } +}