diff --git a/src/tilegym/ops/cutile/__init__.py b/src/tilegym/ops/cutile/__init__.py
index 79f318b..45ba4f7 100644
--- a/src/tilegym/ops/cutile/__init__.py
+++ b/src/tilegym/ops/cutile/__init__.py
@@ -16,6 +16,7 @@
 from . import attention
 from . import dropout
 from . import flash_decode
+from . import flash_decode_fused
 from . import group_gemm
 from . import matmul
 from . import mla
@@ -32,6 +33,7 @@
 # Import specific functions for direct access
 from .flash_decode import fmha_decode
+from .flash_decode_fused import fmha_decode_fused
 from .moe import fused_moe_kernel as invoke_fused_moe_kernel
 from .moe_align_block import moe_align_block_size
 from .rms_norm import get_rms_norm_module
@@ -47,7 +49,9 @@ __all__ = [
     # NN operations
     "fmha_decode",
+    "fmha_decode_fused",
     "flash_decode",
+    "flash_decode_fused",
     "splitk_reduce",
     "invoke_fused_moe_kernel",
     "moe_align_block_size",
diff --git a/src/tilegym/ops/cutile/flash_decode_fused.py b/src/tilegym/ops/cutile/flash_decode_fused.py
new file mode 100644
index 0000000..dc79520
--- /dev/null
+++ b/src/tilegym/ops/cutile/flash_decode_fused.py
@@ -0,0 +1,325 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+"""
+Experimental fused FMHA decode (split-kv + reduction in a single kernel).
+
+See `flash_decode.py` for the baseline two-kernel approach.
+"""
+
+import math
+
+import cuda.tile as ct
+import torch
+from cuda.tile._numeric_semantics import RoundingMode as RMd
+
+from tilegym.backend import register_impl
+
+from .utils import next_power_of_2
+
+INV_LOG_2 = 1.0 / math.log(2)
+
+# Type aliases for constants
+ConstInt = ct.Constant[int]
+
+
+@ct.kernel
+def attention_decode_fused_kernel(
+    Q,  # [B, H_kv, NUM_Q_HEAD_PER_KV, HEAD_DIM]
+    K,  # [B, H_kv, S_kv, HEAD_DIM]
+    V,  # [B, H_kv, S_kv, HEAD_DIM]
+    Output,  # [B, H_q, HEAD_DIM]
+    Partial_O,  # [B, H_kv, NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV, HEAD_DIM]
+    Partial_LSE,  # [B, H_kv, NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV]
+    Completion_Counter,  # [B, H_kv] int32
+    softmax_scale: float,
+    B: int,
+    H_q: int,
+    H_kv: int,
+    S_kv: int,
+    HEAD_DIM: ConstInt,
+    TILE_N: ConstInt,
+    KV_LEN_PER_SPLIT: ConstInt,
+    NUM_Q_HEAD_PER_KV: ConstInt,
+    NUM_KV_SPLITS: ConstInt,
+):
+    batch_id = ct.bid(0)
+    kv_head_id = ct.bid(1)
+    split_id = ct.bid(2)
+
+    qk_scale = ct.mul(softmax_scale, INV_LOG_2)
+
+    # =========================================
+    # PHASE 1: Standard attention computation (local to this split)
+    # =========================================
+    q = ct.load(
+        Q,
+        index=(batch_id, kv_head_id, 0, 0),
+        shape=(1, 1, NUM_Q_HEAD_PER_KV, HEAD_DIM),
+        order=(0, 1, 2, 3),
+        allow_tma=True,
+    )
+    q = ct.reshape(q, (NUM_Q_HEAD_PER_KV, HEAD_DIM))
+    q = ct.transpose(q)  # (HEAD_DIM, NUM_Q_HEAD_PER_KV)
+
+    m_i = ct.full((NUM_Q_HEAD_PER_KV,), -math.inf, dtype=ct.float32)
+    l_i = ct.full((TILE_N, NUM_Q_HEAD_PER_KV), 1.0, dtype=ct.float32)
+    acc = ct.full((HEAD_DIM, NUM_Q_HEAD_PER_KV), 0.0, dtype=ct.float32)
+
+    start_idx = ct.mul(split_id, KV_LEN_PER_SPLIT)
+    end_idx = ct.minimum(ct.add(start_idx, KV_LEN_PER_SPLIT), S_kv)
+
+    num_tiles = ct.cdiv(KV_LEN_PER_SPLIT, TILE_N)
+    offs_n = ct.arange(TILE_N, dtype=ct.int32)
+
+    for idx in range(num_tiles):
+        cnt = (start_idx // TILE_N) + idx
+        kv_pos = cnt * TILE_N
+
+        if kv_pos >= end_idx:
+            continue
+
+        k = ct.load(
+            K,
+            index=(batch_id, kv_head_id, cnt, 0),
+            shape=(1, 1, TILE_N, HEAD_DIM),
+            order=(0, 1, 2, 3),
+            allow_tma=True,
+        )
+        k = ct.reshape(k, (TILE_N, HEAD_DIM))
+        qk = ct.matmul(k, q)  # (TILE_N, NUM_Q_HEAD_PER_KV)
+
+        # Mask for split end
+        if kv_pos + TILE_N > end_idx:
+            mask = ct.less(ct.add(kv_pos, offs_n[:, None]), end_idx)
+            qk = ct.where(mask, qk, -1.0e6)
+
+        qk_scaled = ct.mul(qk, qk_scale)
+        m_ij = ct.maximum(m_i, ct.max(qk_scaled, 0))
+        qk_shifted = ct.sub(qk_scaled, m_ij[None, :])
+        p = ct.exp2(qk_shifted)
+
+        alpha = ct.exp2(ct.sub(m_i, m_ij))
+        l_i = ct.add(ct.mul(l_i, alpha[None, :]), p)
+        acc = ct.mul(acc, alpha[None, :])
+
+        v = ct.load(
+            V,
+            index=(batch_id, kv_head_id, cnt, 0),
+            shape=(1, 1, TILE_N, HEAD_DIM),
+            order=(0, 1, 2, 3),
+            allow_tma=True,
+        )
+        v = ct.reshape(v, (TILE_N, HEAD_DIM))
+        v = ct.transpose(v)  # (HEAD_DIM, TILE_N)
+        p = ct.astype(p, q.dtype)
+        acc = ct.mma(v, p, acc=acc)
+
+        m_i = m_ij
+
+    # Finalize local results
+    l = ct.sum(l_i, 0)  # (NUM_Q_HEAD_PER_KV,)
+    acc = ct.truediv(acc, l[None, :], flush_to_zero=True, rounding_mode=RMd.APPROX)
+    acc = ct.astype(acc, ct.float32)
+    acc = ct.transpose(acc)  # (NUM_Q_HEAD_PER_KV, HEAD_DIM)
+    acc = ct.astype(acc, Partial_O.dtype)
+    lse = ct.add(m_i, ct.log2(l))  # log2-space LSE per q-head
+
+    # =========================================
+    # PHASE 2: Write partial results
+    # =========================================
+    ct.store(
+        Partial_O,
+        index=(batch_id, kv_head_id, split_id, 0, 0),
+        tile=ct.reshape(acc, (1, 1, 1, NUM_Q_HEAD_PER_KV, HEAD_DIM)),
+        order=(0, 1, 2, 3, 4),
+        # Avoid async TMA stores here: we need the data to be globally visible
+        # before the completion counter is incremented.
+        allow_tma=False,
+    )
+
+    idx_q = ct.arange(NUM_Q_HEAD_PER_KV, dtype=ct.int32)
+    ct.scatter(
+        Partial_LSE,
+        (batch_id, kv_head_id, split_id, idx_q),
+        lse,
+        check_bounds=True,
+        latency=1,
+    )
+
+    # =========================================
+    # PHASE 3: Atomic counter and reduction
+    # =========================================
+    # Publish partials, then increment completion counter.
+    # Use RELEASE to prevent reordering of the stores after the atomic.
+    old_count = ct.atomic_add(
+        Completion_Counter,
+        (batch_id, kv_head_id),
+        1,
+        check_bounds=True,
+        memory_order=ct.MemoryOrder.RELEASE,
+        memory_scope=ct.MemoryScope.DEVICE,
+    )
+
+    if old_count == (NUM_KV_SPLITS - 1):
+        # Acquire fence to ensure we observe all other splits' published partials.
+        # (atomic_add with update=0 acts as an atomic load + acquire barrier.)
+        ct.atomic_add(
+            Completion_Counter,
+            (batch_id, kv_head_id),
+            0,
+            check_bounds=True,
+            memory_order=ct.MemoryOrder.ACQUIRE,
+            memory_scope=ct.MemoryScope.DEVICE,
+        )
+
+        # Reset counter for next iteration
+        ct.atomic_xchg(
+            Completion_Counter,
+            (batch_id, kv_head_id),
+            0,
+            check_bounds=True,
+            memory_order=ct.MemoryOrder.RELAXED,
+            memory_scope=ct.MemoryScope.DEVICE,
+        )
+
+        # Load all partials
+        all_partial_o = ct.load(
+            Partial_O,
+            index=(batch_id, kv_head_id, 0, 0, 0),
+            shape=(1, 1, NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV, HEAD_DIM),
+            order=(0, 1, 2, 3, 4),
+            allow_tma=False,
+        )
+        all_partial_o = ct.reshape(all_partial_o, (NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV, HEAD_DIM))
+
+        all_lse = ct.load(
+            Partial_LSE,
+            index=(batch_id, kv_head_id, 0, 0),
+            shape=(1, 1, NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV),
+            order=(0, 1, 2, 3),
+            allow_tma=False,
+        )
+        all_lse = ct.reshape(all_lse, (NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV))
+
+        # Reduce over splits for all Q heads in this kv-group (vectorized).
+        # Avoid dynamic tile indexing (cuda.tile requires constant subscripts).
+        lse_max = ct.max(all_lse, 0)  # (NUM_Q_HEAD_PER_KV,)
+        weights = ct.exp2(ct.sub(all_lse, lse_max[None, :]))  # (NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV)
+        weights_sum = ct.sum(weights, 0)  # (NUM_Q_HEAD_PER_KV,)
+
+        weights_3d = ct.reshape(weights, (NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV, 1))
+        weighted_sum = ct.sum(weights_3d * all_partial_o, axis=0)  # (NUM_Q_HEAD_PER_KV, HEAD_DIM)
+        final_output = weighted_sum / ct.reshape(weights_sum, (NUM_Q_HEAD_PER_KV, 1))
+
+        # Store all query heads for this kv-head group in one go.
+        # IMPORTANT: `ct.store` indices are tile indices (not element indices).
+        # With a tile shaped (1, NUM_Q_HEAD_PER_KV, HEAD_DIM) on Output[B, H_q, D],
+        # the head dimension is tiled by NUM_Q_HEAD_PER_KV, so we index by `kv_head_id`.
+        ct.store(
+            Output,
+            index=(batch_id, kv_head_id, 0),
+            tile=ct.reshape(ct.astype(final_output, Output.dtype), (1, NUM_Q_HEAD_PER_KV, HEAD_DIM)),
+            order=(0, 1, 2),
+            allow_tma=True,
+        )
+
+
+class _attention_decode_fused(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, Q, K, V, softmax_scale, kv_len_per_split=None):
+        batch_size, num_q_heads = Q.shape[0], Q.shape[1]
+        num_kv_heads = K.shape[1]
+        seq_len, head_dim = V.shape[2], V.shape[3]
+
+        # Reshape for processing
+        Q = Q.view(batch_size, num_q_heads, head_dim)
+        K = K.view(batch_size, num_kv_heads, seq_len, head_dim)
+        V = V.view(batch_size, num_kv_heads, seq_len, head_dim)
+
+        TILE_N = 128
+        if kv_len_per_split is None:
+            NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+            num_kv_splits_est = NUM_SMS // (batch_size * num_kv_heads)
+            KV_LEN_PER_SPLIT = max(
+                TILE_N,
+                next_power_of_2((seq_len + num_kv_splits_est - 1) // num_kv_splits_est),
+            )
+            NUM_KV_SPLITS = (seq_len + KV_LEN_PER_SPLIT - 1) // KV_LEN_PER_SPLIT
+        else:
+            KV_LEN_PER_SPLIT = kv_len_per_split
+            NUM_KV_SPLITS = (seq_len + KV_LEN_PER_SPLIT - 1) // KV_LEN_PER_SPLIT
+
+        KV_LEN_PER_SPLIT = next_power_of_2(KV_LEN_PER_SPLIT)
+        assert KV_LEN_PER_SPLIT >= TILE_N
+
+        # Grouped-query layout (same constraints as existing decode kernel)
+        assert num_q_heads % num_kv_heads == 0
+        num_q_head_per_kv = num_q_heads // num_kv_heads
+        assert head_dim == next_power_of_2(head_dim)
+        assert num_q_head_per_kv == next_power_of_2(num_q_head_per_kv)
+
+        HEAD_DIM = head_dim
+        Q_grouped = Q.view(batch_size, num_kv_heads, num_q_head_per_kv, head_dim)
+
+        # Workspaces + output
+        Partial_O = torch.empty(
+            (batch_size, num_kv_heads, NUM_KV_SPLITS, num_q_head_per_kv, head_dim),
+            device=Q.device,
+            dtype=Q.dtype,
+        )
+        Partial_LSE = torch.empty(
+            (batch_size, num_kv_heads, NUM_KV_SPLITS, num_q_head_per_kv),
+            device=Q.device,
+            dtype=torch.float32,
+        )
+        Completion_Counter = torch.zeros(
+            (batch_size, num_kv_heads),
+            device=Q.device,
+            dtype=torch.int32,
+        )
+        O = torch.empty((batch_size, num_q_heads, head_dim), device=Q.device, dtype=Q.dtype)
+
+        grid = (batch_size, num_kv_heads, NUM_KV_SPLITS)
+        ct.launch(
+            torch.cuda.current_stream(),
+            grid,
+            attention_decode_fused_kernel,
+            (
+                Q_grouped,
+                K,
+                V,
+                O,
+                Partial_O,
+                Partial_LSE,
+                Completion_Counter,
+                softmax_scale,
+                batch_size,
+                num_q_heads,
+                num_kv_heads,
+                seq_len,
+                HEAD_DIM,
+                TILE_N,
+                KV_LEN_PER_SPLIT,
+                num_q_head_per_kv,
+                NUM_KV_SPLITS,
+            ),
+        )
+
+        return O.view(batch_size, num_q_heads, 1, head_dim)
+
+    @staticmethod
+    def backward(ctx, do):
+        raise NotImplementedError("Fused decode backward is not implemented yet")
+
+
+attention_decode_fused = _attention_decode_fused.apply
+
+
+@register_impl("fmha_decode_fused", backend="cutile")
+def fmha_decode_fused(q, k, v, sm_scale, kv_len_per_split=None, **kwargs):
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(q.size(-1))
+    return attention_decode_fused(q, k, v, sm_scale, kv_len_per_split)
+
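
For reference, the PHASE 3 reduction above is the standard split-softmax merge: each split stores a locally normalized output plus a log2-space LSE, and the final output is their weighted average. A minimal PyTorch sketch of the same math (the helper name `merge_splits` is illustrative, not part of this patch):

import torch

def merge_splits(partial_o: torch.Tensor, partial_lse: torch.Tensor) -> torch.Tensor:
    """Combine per-split decode outputs using their log2-space LSEs.

    partial_o:   (NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV, HEAD_DIM), already normalized per split
    partial_lse: (NUM_KV_SPLITS, NUM_Q_HEAD_PER_KV), in log2 space
    """
    lse_max = partial_lse.max(dim=0).values                    # (H,)
    weights = torch.exp2(partial_lse - lse_max)                 # (S, H)
    weighted = (weights.unsqueeze(-1) * partial_o).sum(dim=0)   # (H, D)
    return weighted / weights.sum(dim=0).unsqueeze(-1)          # (H, D)

Subtracting the per-head maximum before exponentiating keeps the weights in range, which is why only relative LSEs matter and the kernel can stay entirely in log2 space.
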
diff --git a/src/tilegym/ops/cutile/group_gemm.py b/src/tilegym/ops/cutile/group_gemm.py
index 00e86b2..4353f83 100644
--- a/src/tilegym/ops/cutile/group_gemm.py
+++ b/src/tilegym/ops/cutile/group_gemm.py
@@ -25,6 +25,7 @@ def group_gemm_kernel(
     TILE_K: ConstInt,
     num_sm: ConstInt,
     transpose_b: ConstBool,
+    GROUP_SIZE_M: ConstInt,
 ):
     tile_idx = ct.bid(0)
     last_problem_end = 0
@@ -49,9 +50,16 @@
         # Process tiles for this group using persistent scheduling
         while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
+            # Convert linear tile id -> (tile_m_idx, tile_n_idx) with a 2D swizzle
+            # to improve L2 locality (Triton-style GROUP_M swizzle).
             tile_idx_in_gemm = tile_idx - last_problem_end
-            tile_m_idx = tile_idx_in_gemm // num_n_tiles
-            tile_n_idx = tile_idx_in_gemm % num_n_tiles
+            num_pid_in_group = GROUP_SIZE_M * num_n_tiles
+            group_id = tile_idx_in_gemm // num_pid_in_group
+            first_pid_m = group_id * GROUP_SIZE_M
+            group_size_m = ct.minimum(num_m_tiles - first_pid_m, GROUP_SIZE_M)
+            pid_in_group = tile_idx_in_gemm % num_pid_in_group
+            tile_m_idx = first_pid_m + (pid_in_group % group_size_m)
+            tile_n_idx = pid_in_group // group_size_m
 
             # Initialize accumulator
             acc = ct.zeros((TILE_M, TILE_N), dtype=ct.float32)
@@ -120,6 +128,8 @@ def group_gemm(
         "TILE_M": 128,
         "TILE_N": 128,
         "TILE_K": 64,
+        # 2D swizzle parameter (Triton-style GROUP_M). Tune for speed.
+        "GROUP_SIZE_M": 8,
         "num_ctas": None,  # Let compiler auto-pick
     }
     user_cfg = kwargs.get("kernel_configs")
@@ -130,6 +140,7 @@
     TILE_M = kernel_configs.get("TILE_M")
     TILE_N = kernel_configs.get("TILE_N")
     TILE_K = kernel_configs.get("TILE_K")
+    GROUP_SIZE_M = kernel_configs.get("GROUP_SIZE_M")
     num_ctas = kernel_configs.get("num_ctas", None)
     occupancy = kernel_configs.get("occupancy", None)
@@ -144,9 +155,11 @@
         group_C.append(C)
 
     kernel = group_gemm_kernel
-    # When num_ctas is specified, adjust grid size to account for multiple CTAs per SM
+    # When num_ctas is specified, treat it as "CTAs per SM" and increase the total
+    # number of programs accordingly. This value must match the stride we pass into
+    # the kernel for persistent scheduling (see `tile_idx += num_sm`).
     num_ctas_for_grid = num_ctas if num_ctas is not None else 1
-    grid_size = NUM_SMS // num_ctas_for_grid
+    grid_size = NUM_SMS * num_ctas_for_grid
     grid = (grid_size,)
 
     logger.debug(f"[cuTile] group_gemm launching with grid={grid}, num_ctas={num_ctas}, NUM_SMS={NUM_SMS}")
@@ -164,6 +177,7 @@
             TILE_K,
             grid_size,  # Use adjusted grid size for persistent scheduling stride
             transpose_b,
+            GROUP_SIZE_M,
         ),
    )
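
The new index math in `group_gemm_kernel` is the usual Triton-style GROUP_M swizzle. A pure-Python sketch of the same mapping, useful for checking the tile walk order (the function name and example sizes are illustrative):

def swizzle_tile(tile_idx_in_gemm: int, num_m_tiles: int, num_n_tiles: int, group_size_m: int):
    """Map a linear tile id to (tile_m_idx, tile_n_idx) with the GROUP_M swizzle."""
    num_pid_in_group = group_size_m * num_n_tiles
    group_id = tile_idx_in_gemm // num_pid_in_group
    first_pid_m = group_id * group_size_m
    group_size_m_eff = min(num_m_tiles - first_pid_m, group_size_m)
    pid_in_group = tile_idx_in_gemm % num_pid_in_group
    tile_m_idx = first_pid_m + (pid_in_group % group_size_m_eff)
    tile_n_idx = pid_in_group // group_size_m_eff
    return tile_m_idx, tile_n_idx

# Consecutive tile ids walk down a GROUP_SIZE_M-tall column of output tiles before
# moving to the next column, so nearby CTAs reuse the same A rows and B columns.
print([swizzle_tile(i, num_m_tiles=8, num_n_tiles=8, group_size_m=4) for i in range(8)])
# [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1), (2, 1), (3, 1)]
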
diff --git a/src/tilegym/ops/ops.py b/src/tilegym/ops/ops.py
index afb86ce..b1bc310 100644
--- a/src/tilegym/ops/ops.py
+++ b/src/tilegym/ops/ops.py
@@ -294,6 +294,34 @@ def fmha_decode(
     raise NotImplementedError(f"fmha_decode is not implemented for {get_current_backend()}")
 
 
+@dispatch(
+    "fmha_decode_fused",
+)
+def fmha_decode_fused(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    sm_scale: Optional[float],
+    kv_len_per_split: Optional[int] = None,
+    **kwargs: Any,
+):
+    """
+    Experimental fused FMHA decode (split-kv + reduction in one kernel).
+
+    Args:
+        q: Query tensor of shape (B, H_q, 1, D)
+        k: Key tensor of shape (B, H_kv, S_kv, D)
+        v: Value tensor of shape (B, H_kv, S_kv, D)
+        sm_scale: Scale factor for attention computation
+        kv_len_per_split: Optional KV length per split for parallelization
+        **kwargs: Additional arguments for backend-specific configurations
+
+    Returns:
+        Output tensor of shape (B, H_q, 1, D)
+    """
+    raise NotImplementedError(f"fmha_decode_fused is not implemented for {get_current_backend()}")
+
+
 @dispatch(
     "mla",
 )
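
Callers go through the dispatcher exactly like the existing `fmha_decode`. A possible call pattern, following the shapes documented above and the kwargs used in the new benchmark and test (assuming a CUDA device and the cutile backend are available):

import math
import torch
import tilegym

tilegym.set_backend("cutile")

B, H_q, H_kv, S_kv, D = 2, 32, 8, 1024, 64
q = torch.randn(B, H_q, 1, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H_kv, S_kv, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H_kv, S_kv, D, device="cuda", dtype=torch.float16)

out = tilegym.ops.fmha_decode_fused(q=q, k=k, v=v, sm_scale=1.0 / math.sqrt(D))
assert out.shape == (B, H_q, 1, D)
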
diff --git a/tests/benchmark/README.md b/tests/benchmark/README.md
index b2e5cec..830f68d 100644
--- a/tests/benchmark/README.md
+++ b/tests/benchmark/README.md
@@ -29,6 +29,8 @@ python bench_matrix_multiplication.py
 Available benchmark scripts:
 - `bench_dropout.py`
 - `bench_fused_attention.py`
+- `bench_flash_decode_fused.py`
+- `bench_group_gemm.py`
 - `bench_matrix_multiplication.py`
 - `bench_mix_triton_cutile.py`
 - `bench_mla.py`
diff --git a/tests/benchmark/bench_flash_decode_fused.py b/tests/benchmark/bench_flash_decode_fused.py
new file mode 100644
index 0000000..c98ac3a
--- /dev/null
+++ b/tests/benchmark/bench_flash_decode_fused.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+import math
+
+import torch
+import triton
+
+import tilegym
+from tilegym.backend import is_backend_available
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
+
+ALL_IMPLS = [
+    ("two_kernel", "CuTile (2-kernel)", ("orange", "-")),
+    ("fused", "CuTile (fused)", ("blue", "-")),
+]
+
+
+def _tflops(ms: float, B: int, H_q: int, S_kv: int, D: int) -> float:
+    # Rough FLOP model for single-token decode:
+    #   - QK: 2 * B * H_q * S_kv * D
+    #   - PV: 2 * B * H_q * S_kv * D
+    # total ~ 4 * B * H_q * S_kv * D
+    flops = 4 * B * H_q * S_kv * D
+    return flops * 1e-12 / (ms * 1e-3)
+
+
+def create_benchmark_config(dtype: torch.dtype, group_size: int):
+    dtype_name = str(dtype).split(".")[-1]
+    return triton.testing.Benchmark(
+        x_names=["S_kv"],
+        x_vals=[9, 119, 256, 512, 1024, 2048, 4096, 8192],
+        line_arg="impl",
+        line_vals=[i for (i, _, _) in ALL_IMPLS],
+        line_names=[n for (_, n, _) in ALL_IMPLS],
+        styles=[s for (_, _, s) in ALL_IMPLS],
+        xlabel="S_kv",
+        ylabel="TFLOPS",
+        plot_name=f"flash-decode-fused-vs-2kernel-g{group_size}-{dtype_name}-TFLOPS",
+        args={"dtype": dtype, "group_size": group_size},
+    )
+
+
+configs = [create_benchmark_config(torch.float16, g) for g in [1, 4, 8]]
+
+
+@triton.testing.perf_report(configs)
+def benchmark(S_kv: int, impl: str, dtype: torch.dtype, group_size: int):
+    if not is_backend_available("cutile"):
+        raise RuntimeError("cutile backend unavailable")
+    tilegym.set_backend("cutile")
+
+    # Match test defaults
+    B = 2
+    H_q = 32
+    D = 64
+    H_kv = H_q // group_size
+    sm_scale = 1.0 / math.sqrt(D)
+
+    torch.manual_seed(0)
+    q = torch.randn(B, H_q, 1, D, device=DEVICE, dtype=dtype)
+    k = torch.randn(B, H_kv, S_kv, D, device=DEVICE, dtype=dtype)
+    v = torch.randn(B, H_kv, S_kv, D, device=DEVICE, dtype=dtype)
+
+    if impl == "two_kernel":
+        fn = lambda: tilegym.ops.fmha_decode(q=q, k=k, v=v, sm_scale=sm_scale)
+    elif impl == "fused":
+        fn = lambda: tilegym.ops.fmha_decode_fused(q=q, k=k, v=v, sm_scale=sm_scale)
+    else:
+        raise ValueError(f"Unknown impl: {impl}")
+
+    # Lightweight correctness check vs baseline for smaller sizes
+    if S_kv <= 512 and impl == "fused":
+        ref = tilegym.ops.fmha_decode(q=q, k=k, v=v, sm_scale=sm_scale)
+        out = fn()
+        torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)
+
+    # NOTE: we use non-cudagraph timing here because these ops allocate internal
+    # workspaces; cudagraph capture may fail depending on allocator behavior.
+    quantiles = [0.5, 0.2, 0.8]
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+    perf = lambda t: _tflops(float(t), B, H_q, S_kv, D)
+    return perf(ms), perf(max_ms), perf(min_ms)
+
+
+def print_speedup_table():
+    if not is_backend_available("cutile"):
+        print("\n[bench_flash_decode_fused] CuTile backend unavailable; skipping speedup table.")
+        return
+    tilegym.set_backend("cutile")
+
+    B, H_q, D = 2, 32, 64
+    sm_scale = 1.0 / math.sqrt(D)
+
+    print("\n" + "=" * 84)
+    print("Speedup: fused / two-kernel (computed from mean-ms timings; higher is better)")
+    print("=" * 84)
+
+    for group_size in [1, 4, 8]:
+        H_kv = H_q // group_size
+        print(f"\n[group_size={group_size}]")
+        header = f"{'S_kv':>6} {'TFLOPS_2k':>12} {'TFLOPS_fused':>13} {'speedup':>9}"
+        print(header)
+        print("-" * len(header))
+
+        for S_kv in [9, 119, 256, 512, 1024, 2048, 4096, 8192]:
+            torch.manual_seed(0)
+            q = torch.randn(B, H_q, 1, D, device=DEVICE, dtype=torch.float16)
+            k = torch.randn(B, H_kv, S_kv, D, device=DEVICE, dtype=torch.float16)
+            v = torch.randn(B, H_kv, S_kv, D, device=DEVICE, dtype=torch.float16)
+
+            fn_2k = lambda: tilegym.ops.fmha_decode(q=q, k=k, v=v, sm_scale=sm_scale)
+            fn_fused = lambda: tilegym.ops.fmha_decode_fused(q=q, k=k, v=v, sm_scale=sm_scale)
+
+            ms_2k = float(triton.testing.do_bench(fn_2k, warmup=5, rep=20, return_mode="mean"))
+            ms_fused = float(triton.testing.do_bench(fn_fused, warmup=5, rep=20, return_mode="mean"))
+
+            t2 = _tflops(ms_2k, B, H_q, S_kv, D)
+            tf = _tflops(ms_fused, B, H_q, S_kv, D)
+            sp = tf / max(t2, 1e-9)
+            print(f"{S_kv:6d} {t2:12.2f} {tf:13.2f} {sp:9.2f}x")
+
+
+if __name__ == "__main__":
+    benchmark.run(print_data=True)
+    print_speedup_table()
+
diff --git a/tests/benchmark/bench_group_gemm.py b/tests/benchmark/bench_group_gemm.py
new file mode 100644
index 0000000..5b9a41f
--- /dev/null
+++ b/tests/benchmark/bench_group_gemm.py
@@ -0,0 +1,203 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+import torch
+import triton
+
+import tilegym
+from tilegym.backend import is_backend_available
+from tilegym.backend import register_impl
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
+
+def reference_group_gemm(
+    group_A,
+    group_B,
+    static_persistent: bool = True,  # Unused - kept for interface compatibility
+    use_tma: bool = True,  # Unused - kept for interface compatibility
+    transpose_b: bool = False,
+    **kwargs,
+):
+    """Reference implementation using PyTorch (loop over groups)."""
+    if transpose_b:
+        return [torch.matmul(A, B.transpose(-2, -1)) for A, B in zip(group_A, group_B)]
+    return [torch.matmul(A, B) for A, B in zip(group_A, group_B)]
+
+
+register_impl("group_gemm", "torch")(reference_group_gemm)
+
+
+# Available backends with their display names and plot styles
+ALL_BACKENDS = [
+    ("cutile", "CuTile", ("orange", "-")) if is_backend_available("cutile") else None,
+    ("torch", "PyTorch", ("green", "-")),
+]
+
+
+def get_supported_backends():
+    return [p for p in ALL_BACKENDS if p is not None]
+
+
+def create_benchmark_config(dtype, group_size: int):
+    available_backends = get_supported_backends()
+    if not available_backends:
+        return None
+
+    backends, names, styles = zip(*available_backends)
+    dtype_name = str(dtype).split(".")[-1]
+
+    # Keep ranges modest to avoid OOM in multi-group runs
+    # (note: each data point allocates G*(A,B,C)).
+    max_range = 14 if group_size >= 8 else 15  # exclusive upper bound on the 2**i sweep
+
+    return triton.testing.Benchmark(
+        x_names=["M", "N", "K"],
+        x_vals=[2**i for i in range(10, max_range)],  # square GEMMs
+        line_arg="backend",
+        line_vals=list(backends),
+        line_names=list(names),
+        styles=list(styles),
+        xlabel="M/N/K",
+        ylabel="TFLOPS",
+        plot_name=f"group-gemm-g{group_size}-{dtype_name}-TFLOPS",
+        args={"dtype": dtype, "group_size": group_size},
+    )
+
+
+configs = []
+for dtype in [torch.float16, torch.bfloat16]:
+    for group_size in [1, 2, 4, 8, 16]:
+        cfg = create_benchmark_config(dtype, group_size)
+        if cfg is not None:
+            configs.append(cfg)
+
+
+def _tflops_from_ms(ms: float, M: int, N: int, K: int, group_size: int) -> float:
+    total_flops = group_size * (2 * M * N * K)
+    return total_flops * 1e-12 / (ms * 1e-3)
+
+
+def _measure_ms(fn, *, warmup: int = 5, rep: int = 20) -> float:
+    # NOTE: do_bench_cudagraph can't be used with cuTile list args today.
+    return float(triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=None, return_mode="mean"))
+
+
+def print_speedup_summary(
+    *,
+    dtypes=(torch.float16, torch.bfloat16),
+    group_sizes=(1, 2, 4, 8, 16),
+    group_size_ms=(1, 2, 4, 8, 16),
+    m_exponents=range(10, 15),
+    warmup: int = 5,
+    rep: int = 20,
+):
+    """
+    Print a lightweight speedup table (CuTile vs PyTorch).
+
+    This is printed in addition to the standard triton perf_report output.
+    """
+    if not is_backend_available("cutile"):
+        print("\n[bench_group_gemm] CuTile backend unavailable; skipping speedup summary.")
+        return
+
+    print("\n" + "=" * 92)
+    print("Speedup summary: CuTile / PyTorch (computed from mean-ms timings; higher is better)")
+    print("=" * 92)
+
+    for dtype in dtypes:
+        dtype_name = str(dtype).split(".")[-1]
+        print(f"\n--- dtype={dtype_name} ---")
+        for g in group_sizes:
+            # Match the perf_report sizing logic (smaller upper bound for larger groups).
+            max_exp = 14 if g >= 8 else 15
+            exps = [e for e in m_exponents if e < max_exp]
+
+            print(f"\n[group_size={g}]")
+            header = (
+                f"{'M=N=K':>10} "
+                f"{'TFLOPS_torch':>12} "
+                f"{'TFLOPS_cutile(best)':>18} "
+                f"{'GROUP_SIZE_M':>12} "
+                f"{'speedup':>9}"
+            )
+            print(header)
+            print("-" * len(header))
+
+            for e in exps:
+                M = N = K = 2**e
+                group_A = [torch.randn((M, K), device=DEVICE, dtype=dtype) for _ in range(g)]
+                group_B = [torch.randn((K, N), device=DEVICE, dtype=dtype) for _ in range(g)]
+
+                fn_torch = lambda: reference_group_gemm(group_A, group_B, transpose_b=False)
+                ms_torch = _measure_ms(fn_torch, warmup=warmup, rep=rep)
+                tflops_torch = _tflops_from_ms(ms_torch, M, N, K, g)
+
+                # Sweep GROUP_SIZE_M and take the best cuTile throughput.
+                best_tflops_cutile = -1.0
+                best_group_size_m = None
+                for gsm in group_size_ms:
+                    fn_cutile = lambda gsm=gsm: tilegym.ops.group_gemm(
+                        group_A,
+                        group_B,
+                        static_persistent=True,
+                        use_tma=True,
+                        transpose_b=False,
+                        backend="cutile",
+                        kernel_configs={"GROUP_SIZE_M": int(gsm)},
+                    )
+                    ms_cutile = _measure_ms(fn_cutile, warmup=warmup, rep=rep)
+                    tflops_cutile = _tflops_from_ms(ms_cutile, M, N, K, g)
+                    if tflops_cutile > best_tflops_cutile:
+                        best_tflops_cutile = tflops_cutile
+                        best_group_size_m = int(gsm)
+
+                speedup = best_tflops_cutile / max(tflops_torch, 1e-9)
+                print(
+                    f"{M:10d} "
+                    f"{tflops_torch:12.2f} "
+                    f"{best_tflops_cutile:18.2f} "
+                    f"{best_group_size_m:12d} "
+                    f"{speedup:9.2f}x"
+                )
+
+
+@triton.testing.perf_report(configs)
+def benchmark(M, N, K, backend, dtype, group_size):
+    # Build a homogeneous group of GEMMs: (M,K) @ (K,N)
+    group_A = [torch.randn((M, K), device=DEVICE, dtype=dtype) for _ in range(group_size)]
+    group_B = [torch.randn((K, N), device=DEVICE, dtype=dtype) for _ in range(group_size)]
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    fn = lambda: tilegym.ops.group_gemm(
+        group_A,
+        group_B,
+        static_persistent=True,
+        use_tma=True,
+        transpose_b=False,
+        backend=backend,
+    )
+
+    # Quick correctness check (small-ish only to avoid spending too much time in validation)
+    if M <= 2048 and backend != "torch":
+        out = fn()
+        ref = reference_group_gemm(group_A, group_B)
+        for o, r in zip(out, ref):
+            torch.testing.assert_close(o, r, atol=2e-2, rtol=2e-2)
+
+    # cuTile group_gemm currently passes Python lists into ct.launch, which is not
+    # compatible with CUDA Graph capture. Use non-cudagraph benchmarking.
+    #
+    # If/when list arguments become graph-capture-friendly, we can switch back to
+    # do_bench_cudagraph for slightly lower measurement noise.
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+
+    perf = lambda t_ms: _tflops_from_ms(float(t_ms), M, N, K, group_size)
+    return perf(ms), perf(max_ms), perf(min_ms)
+
+
+if __name__ == "__main__":
+    benchmark.run(print_data=True)
+    print_speedup_summary()
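
Outside the benchmark, the new GROUP_SIZE_M knob is reached the same way the sweep above reaches it, via `kernel_configs`. A short usage sketch following that call pattern (the matrix shapes and group count are illustrative):

import torch
import tilegym

group_A = [torch.randn(1024, 512, device="cuda", dtype=torch.float16) for _ in range(4)]
group_B = [torch.randn(512, 2048, device="cuda", dtype=torch.float16) for _ in range(4)]

group_C = tilegym.ops.group_gemm(
    group_A,
    group_B,
    static_persistent=True,
    use_tma=True,
    transpose_b=False,
    backend="cutile",
    kernel_configs={"GROUP_SIZE_M": 8},  # omit to use the default of 8 from the kernel config
)
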
diff --git a/tests/ops/test_flash_decode_fused.py b/tests/ops/test_flash_decode_fused.py
new file mode 100644
index 0000000..df57b20
--- /dev/null
+++ b/tests/ops/test_flash_decode_fused.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+import math
+
+import pytest
+import torch
+
+import tilegym
+
+from .. import common
+
+
+class Test_FlashDecodeFused(common.PyTestCase):
+    @staticmethod
+    def reference(q, k, v, sm_scale):
+        torch.backends.cuda.mem_efficient_sdp_enabled()
+        return torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, enable_gqa=True)
+
+    _backends = ["cutile"]
+
+    @pytest.mark.parametrize("seq_len", [9, 119, 256, 2048])
+    @pytest.mark.parametrize("group_size", [1, 4, 8])
+    @pytest.mark.parametrize("dtype", [torch.float16])
+    @pytest.mark.parametrize("backend", _backends)
+    def test_op(self, seq_len, group_size, dtype, backend, arch):
+        if tilegym.is_backend_available(backend):
+            tilegym.set_backend(backend)
+        else:
+            pytest.skip(f"Backend {backend} is not available")
+        self.setUp()
+
+        if not torch.cuda.is_available():
+            pytest.skip("CUDA not available, skipping FMHA fused decode test")
+
+        batch_size = 2
+        num_heads = 32
+        head_dim = 64
+
+        torch.manual_seed(42)
+        q = torch.randn(batch_size, num_heads, 1, head_dim, device="cuda").to(dtype)
+        k = torch.randn(
+            batch_size,
+            num_heads // group_size,
+            seq_len,
+            head_dim,
+            device="cuda",
+        ).to(dtype)
+        v = torch.randn(
+            batch_size,
+            num_heads // group_size,
+            seq_len,
+            head_dim,
+            device="cuda",
+        ).to(dtype)
+
+        sm_scale = 1.0 / math.sqrt(head_dim)
+
+        self.assertCorrectness(
+            tilegym.ops.fmha_decode_fused,
+            self.reference,
+            {"q": q, "k": k, "v": v, "sm_scale": sm_scale},
+            atol=2e-2,
+            rtol=2e-2,
+            check_stride=False,
+        )
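
For intuition, the automatic split sizing in `_attention_decode_fused.forward` amounts to the following standalone sketch. `plan_kv_splits` and the local `next_power_of_2` are illustrative re-implementations, and, like the kernel wrapper, it assumes `batch_size * num_kv_heads` does not exceed the SM count:

def plan_kv_splits(seq_len: int, batch_size: int, num_kv_heads: int, num_sms: int,
                   tile_n: int = 128) -> tuple:
    """Size each split so the (B, H_kv, splits) grid roughly fills the GPU,
    rounding the split length up to a power of two and to at least TILE_N."""
    def next_power_of_2(x: int) -> int:
        return 1 if x <= 1 else 1 << (x - 1).bit_length()

    num_kv_splits_est = num_sms // (batch_size * num_kv_heads)
    kv_len_per_split = max(tile_n, next_power_of_2(-(-seq_len // num_kv_splits_est)))
    num_kv_splits = -(-seq_len // kv_len_per_split)  # ceil division
    return kv_len_per_split, num_kv_splits

# e.g. B=2, H_kv=8 on a 132-SM device with S_kv=8192 -> 8 splits of length 1024
print(plan_kv_splits(8192, batch_size=2, num_kv_heads=8, num_sms=132))  # (1024, 8)
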