diff --git a/examples/commons/ops/GroupedLinear_example.py b/examples/commons/ops/GroupedLinear_example.py new file mode 100644 index 000000000..71fe3633c --- /dev/null +++ b/examples/commons/ops/GroupedLinear_example.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Grouped Linear Layer with Strided BMM Optimization + +================================================================================ +Problem +================================================================================ +Apply num_groups different linear transformations to corresponding slices of input: + + Input: x of shape (B * num_groups, input_dim) + Output: y of shape (B * num_groups, output_dim) + + For each group n: y[b, n, :] = x[b, n, :] @ W[n, :, :] + +================================================================================ +Reference Implementation +================================================================================ +The straightforward approach uses a loop over groups: + + x = x.reshape(B, num_groups, D_in) + x_split = torch.split(x, 1, dim=1) + + out_list = [] + for i in range(num_groups): + x_i = x_split[i].squeeze(1) # (B, D_in) + out_i = linear_layers[i](x_i) # (B, D_out) + out_list.append(out_i) + + output = torch.stack(out_list, dim=1).reshape(-1, D_out) + +================================================================================ +Optimized Implementation +================================================================================ +Use torch.bmm with strided output to fuse all GEMMs into one kernel: + + x = x.reshape(B, num_groups, D_in) + output = torch.empty(B, num_groups, D_out, ...) # pre-allocate final layout + torch.bmm(x.permute(1,0,2), weight, + out=output.permute(1,0,2)) # cuBLAS writes to strided memory + return output.view(-1, D_out) # O(1) view, no copy. + +Key feature: cuBLAS strided batched GEMM supports strided output via ldc/strideC +parameters, allowing direct write to the transposed memory layout. + +================================================================================ +Performance Results +================================================================================ +Config: batch_size=2560, num_groups=12, input_dim=1024, output_dim=3072, dtype=bf16 +Device: NVIDIA H100 +BMM_Opt Forward: 1.46x +BMM_Opt Forward+Backward:1.41x + +================================================================================ +""" + +import argparse +from typing import List, Tuple + +import torch +import torch.nn as nn + + +def warmup_gpu(): + """Warmup GPU to get stable timing""" + x = torch.randn(1000, 1000, device="cuda") + for _ in range(10): + _ = x @ x + torch.cuda.synchronize() + + +class ReferenceImpl(nn.Module): + """ + Reference implementation using reshape + split + loop + stack. + Simple but slow due to multiple kernel launches. 
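+    Each group dispatches its own GEMM (plus the surrounding split/squeeze/stack
+    work), so one forward pass costs num_groups separate kernel launches instead
+    of the single batched call used by GroupedLinear below.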
+ """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_groups: int, + device="cuda", + dtype=torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.output_dim = output_dim + + self.linear_layers = nn.ModuleList( + [ + nn.Linear(input_dim, output_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ] + ) + + for layer in self.linear_layers: + nn.init.xavier_normal_(layer.weight, gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # reshape: (B*ns, D) -> (B, ns, D) + x = x.reshape(-1, self.num_groups, self.input_dim) + + # split and loop + x_split = torch.split(x, 1, dim=1) + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) # (B, D) + out_i = self.linear_layers[i](x_i) # (B, D_out) + out_list.append(out_i) + + # stack: ns * (B, D_out) -> (B, ns, D_out) -> (B*ns, D_out) + return torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + + +class StridedBmmFunction(torch.autograd.Function): + """Custom autograd function for BMM with strided output.""" + + @staticmethod + def forward(ctx, x, weight, batch_size, num_groups, output_dim, batch_first): + ctx.save_for_backward(x, weight) + ctx.batch_first = batch_first + + if batch_first: + # x: [B, G, D] -> need permute to [G, B, D] for bmm + output = torch.empty( + batch_size, num_groups, output_dim, device=x.device, dtype=x.dtype + ) + torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + else: + # x: [G, B, D] -> already in correct layout for bmm + output = torch.bmm(x, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight = ctx.saved_tensors + batch_first = ctx.batch_first + grad_x = grad_weight = None + + if batch_first: + # grad_output: [B, G, D_out] + grad_output_t = grad_output.permute(1, 0, 2) # [G, B, D_out] + + if ctx.needs_input_grad[0]: + grad_x = torch.empty_like(x) + torch.bmm( + grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2) + ) + + if ctx.needs_input_grad[1]: + grad_weight = torch.bmm( + x.permute(1, 0, 2).transpose(-1, -2), grad_output_t + ) + else: + # grad_output: [G, B, D_out] + if ctx.needs_input_grad[0]: + grad_x = torch.bmm(grad_output, weight.transpose(-1, -2)) + + if ctx.needs_input_grad[1]: + grad_weight = torch.bmm(x.transpose(-1, -2), grad_output) + + return grad_x, grad_weight, None, None, None, None + + +class GroupedLinear(nn.Module): + """ + Grouped linear layer: applies num_groups different linear transforms in parallel. + Optimized using batched GEMM with strided output for single kernel launch. + + Args: + input_dim: Input feature dimension. + output_dim: Output feature dimension. + num_groups: Number of independent linear transforms. + batch_first: If True, input layout is [B, G, D] (batch-first, default). + If False, input layout is [G, B, D] (group-first). 
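+
+    Usage sketch (shapes follow the Args above; the dimensions are illustrative):
+
+        layer = GroupedLinear(input_dim=1024, output_dim=3072, num_groups=12)
+        x = torch.randn(2560 * 12, 1024, device="cuda", dtype=torch.bfloat16)
+        y = layer(x)  # -> (2560 * 12, 3072), computed by a single strided bmm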
+ """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_groups: int, + device="cuda", + dtype=torch.bfloat16, + batch_first: bool = True, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.output_dim = output_dim + self.batch_first = batch_first + + self.weight = nn.Parameter( + torch.empty(num_groups, input_dim, output_dim, device=device, dtype=dtype) + ) + for i in range(num_groups): + nn.init.xavier_normal_(self.weight[i], gain=1.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size = x.shape[0] // self.num_groups + + if self.batch_first: + # Input flattened from [B, G, D] -> reshape to [B, G, D] + x = x.reshape(batch_size, self.num_groups, self.input_dim) + else: + # Input flattened from [G, B, D] -> reshape to [G, B, D] + x = x.reshape(self.num_groups, batch_size, self.input_dim) + + output = StridedBmmFunction.apply( + x, + self.weight, + batch_size, + self.num_groups, + self.output_dim, + self.batch_first, + ) + return output.view(-1, self.output_dim) + + +def copy_weights(ref_model: ReferenceImpl, opt_model: GroupedLinear): + """Copy weights from reference to optimized model.""" + with torch.no_grad(): + for i in range(ref_model.num_groups): + opt_model.weight[i].copy_(ref_model.linear_layers[i].weight.T) + + +def check_correctness( + ref_model: ReferenceImpl, + opt_model: GroupedLinear, + batch_size: int, + num_groups: int, + input_dim: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[float, float, float]: + """Check forward and backward correctness.""" + # Forward check + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + with torch.no_grad(): + ref_out = ref_model(x) + opt_out = opt_model(x) + fwd_diff = (ref_out - opt_out).abs().max().item() + + # Backward check + x_ref = torch.randn( + batch_size * num_groups, + input_dim, + device="cuda", + dtype=dtype, + requires_grad=True, + ) + x_opt = x_ref.detach().clone().requires_grad_(True) + + ref_out = ref_model(x_ref) + opt_out = opt_model(x_opt) + + grad_output = torch.randn_like(ref_out) + ref_out.backward(grad_output) + opt_out.backward(grad_output) + + # Input gradient + bwd_x_diff = (x_ref.grad - x_opt.grad).abs().max().item() + + # Weight gradient + ref_weight_grad = torch.stack( + [ref_model.linear_layers[i].weight.grad.T for i in range(num_groups)] + ) + bwd_w_diff = (ref_weight_grad - opt_model.weight.grad).abs().max().item() + + return fwd_diff, bwd_x_diff, bwd_w_diff + + +def benchmark( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, + with_backward: bool = False, +) -> float: + """Benchmark forward or forward+backward pass using CUDA events for accurate GPU timing.""" + if with_backward: + x_list = [xi.requires_grad_(True) for xi in x_list] + grad_outputs = [ + torch.randn(xi.shape[0], model.output_dim, device="cuda", dtype=xi.dtype) + for xi in x_list + ] + params = list(model.parameters()) + + # Warmup + for i in range(num_warmup): + xi = x_list[i % len(x_list)] + out = model(xi) + if with_backward: + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + for i in range(num_iterations): + xi = x_list[i % len(x_list)] + out = model(xi) + if with_backward: + torch.autograd.grad(out, [xi] + params, grad_outputs[i % len(x_list)]) + 
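+        # Note: torch.autograd.grad (rather than out.backward()) keeps the timed
+        # loop from accumulating .grad buffers across iterations.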
end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations # ms + + +def main(): + parser = argparse.ArgumentParser( + description="Grouped GEMM: Reference vs Strided BMM" + ) + parser.add_argument( + "--batch-size", type=int, default=2560, help="Batch size to test" + ) + parser.add_argument("--num-groups", type=int, default=12, help="Number of groups") + parser.add_argument("--input-dim", type=int, default=1024, help="Input dimension") + parser.add_argument("--output-dim", type=int, default=3072, help="Output dimension") + parser.add_argument( + "--iterations", type=int, default=100, help="Number of iterations for timing" + ) + args = parser.parse_args() + + torch.cuda.init() + + # Configuration from args + batch_size = args.batch_size + num_groups = args.num_groups + input_dim = args.input_dim + output_dim = args.output_dim + dtype = torch.bfloat16 + num_iterations = args.iterations + + print("=" * 60) + print("Grouped GEMM: Reference vs Strided BMM") + print("=" * 60) + print( + f"\nConfig: B={batch_size}, groups={num_groups}, D_in={input_dim}, D_out={output_dim}" + ) + print(f"Device: {torch.cuda.get_device_name(0)}") + + # Warmup GPU + print("\nWarming up GPU...") + warmup_gpu() + + # Create models + ref_model = ReferenceImpl(input_dim, output_dim, num_groups, dtype=dtype).cuda() + opt_model = GroupedLinear(input_dim, output_dim, num_groups, dtype=dtype).cuda() + copy_weights(ref_model, opt_model) + + # Correctness check + print("\n" + "-" * 40) + print("Correctness Check") + print("-" * 40) + fwd_diff, bwd_x_diff, bwd_w_diff = check_correctness( + ref_model, opt_model, batch_size, num_groups, input_dim, dtype + ) + print(f"Forward max diff: {fwd_diff:.2e} {'✓' if fwd_diff < 1e-3 else '✗'}") + print(f"Backward dL/dx diff: {bwd_x_diff:.2e} {'✓' if bwd_x_diff < 1e-3 else '✗'}") + print(f"Backward dL/dW diff: {bwd_w_diff:.2e} {'✓' if bwd_w_diff < 1e-3 else '✗'}") + + # Benchmark + print("\n" + "-" * 40) + print("Performance Benchmark") + print("-" * 40) + + x_list = [ + torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + for _ in range(10) + ] + + # Forward only + ref_fwd = benchmark(ref_model, x_list, num_iterations, with_backward=False) + opt_fwd = benchmark(opt_model, x_list, num_iterations, with_backward=False) + + print(f"\nForward pass (ms):") + print(f" Reference (loop): {ref_fwd:.4f}") + print(f" GroupedLinear: {opt_fwd:.4f}") + print(f" Speedup: {ref_fwd/opt_fwd:.2f}x") + + # Forward + Backward + ref_fwdbwd = benchmark(ref_model, x_list, num_iterations, with_backward=True) + opt_fwdbwd = benchmark(opt_model, x_list, num_iterations, with_backward=True) + + print(f"\nForward + Backward (ms):") + print(f" Reference (loop): {ref_fwdbwd:.4f}") + print(f" GroupedLinear: {opt_fwdbwd:.4f}") + print(f" Speedup: {ref_fwdbwd/opt_fwdbwd:.2f}x") + + print("\n" + "=" * 60) + print("Done!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/commons/ops/GroupedMLP_example.py b/examples/commons/ops/GroupedMLP_example.py new file mode 100644 index 000000000..8edac2dcc --- /dev/null +++ b/examples/commons/ops/GroupedMLP_example.py @@ -0,0 +1,909 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +""" +Grouped MLP Benchmark: Reference vs Plan A + +================================================================================ +Problem +================================================================================ +Apply num_groups different MLP transformations with GLU gating (SwiGLU/GeGLU): + + For each group n: y[b, n, :] = down(act(gate(x)) * up(x)) + +================================================================================ +Implementations +================================================================================ +Reference: Loop over groups with separate nn.Linear layers +Plan A: 3 independent strided BMMs (gate, up, down) + +================================================================================ +""" +import sys +sys.path.insert(0, '/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu') + +import argparse +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.cuda.nvtx as nvtx +import triton +import triton.language as tl +try: + # @manual=//triton:triton + from triton.language.extra.libdevice import fast_dividef +except ImportError: + try: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import fast_dividef + except ImportError: + # pyre-ignore: Undefined import [21] + # @manual=//triton:triton + from triton.language.math import fast_dividef + +from ops.triton_ops.common import triton_autotune + +def silu_configs(): + configs = [] + for x_block_size in [256, 512, 1024, 2048]: + for num_warps in [2, 4, 8, 16]: + config = triton.Config({"x_block_size": x_block_size}, num_warps) + configs.append(config) + return configs + + + + + +# ============================================================================= +# Fused SiLU * Up (SwiGLU pattern): output = silu(gate) * up +# ============================================================================= + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_forward( + output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """Fused forward: output = silu(gate) * up""" + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + gate = tl.load(gate_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + # silu(gate) = gate * sigmoid(gate) = gate / (1 + exp(-gate)) + silu_gate = fast_dividef(gate, 1.0 + tl.exp(-gate)) + output = (silu_gate * up).to(output_ptr.dtype.element_ty) + + tl.store(output_ptr + x_offset + cols, output, mask=mask) + + +@triton_autotune(silu_configs(), key=["x_size"]) +@triton.jit +def _silu_mul_backward( + grad_gate_ptr: tl.tensor, + grad_up_ptr: tl.tensor, + grad_output_ptr: tl.tensor, + gate_ptr: tl.tensor, + up_ptr: tl.tensor, + x_size: tl.int32, + x_block_size: tl.constexpr, +): + """ + Fused backward for output = silu(gate) * up + + grad_gate = grad_output * up * d(silu)/d(gate) + = grad_output * up * (sigmoid(gate) + gate * sigmoid(gate) * (1 - sigmoid(gate))) + grad_up = grad_output * silu(gate) + """ + x_offset = tl.program_id(0) * x_block_size + mask = x_offset + tl.arange(0, x_block_size) < x_size + cols = tl.arange(0, x_block_size) + + grad_output = tl.load(grad_output_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + gate = tl.load(gate_ptr + 
x_offset + cols, mask=mask, other=0.0).to(tl.float32) + up = tl.load(up_ptr + x_offset + cols, mask=mask, other=0.0).to(tl.float32) + + sigma = tl.sigmoid(gate) + silu_gate = gate * sigma + + # d(silu)/d(gate) = sigma + gate * sigma * (1 - sigma) + dsilu_dgate = sigma + gate * sigma * (1.0 - sigma) + + grad_gate = grad_output * up * dsilu_dgate + grad_up = grad_output * silu_gate + + tl.store(grad_gate_ptr + x_offset + cols, grad_gate.to(grad_gate_ptr.dtype.element_ty), mask=mask) + tl.store(grad_up_ptr + x_offset + cols, grad_up.to(grad_up_ptr.dtype.element_ty), mask=mask) + + +def triton_silu_mul_fwd(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """Forward: output = silu(gate) * up""" + assert gate.shape == up.shape, f"Shape mismatch: gate {gate.shape} vs up {up.shape}" + x_size = gate.numel() + gate_1d = gate.view(-1).contiguous() + up_1d = up.view(-1).contiguous() + output = torch.empty_like(gate_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_forward[grid]( + output, + gate_1d, + up_1d, + x_size, + ) + return output.view(gate.shape) + + +def triton_silu_mul_bwd( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward: returns (grad_gate, grad_up)""" + shape = gate.shape + x_size = gate.numel() + gate_1d = gate.view(-1).contiguous() + up_1d = up.view(-1).contiguous() + grad_output_1d = grad_output.view(-1).contiguous() + grad_gate = torch.empty_like(gate_1d) + grad_up = torch.empty_like(up_1d) + + def grid(meta): + return (triton.cdiv(x_size, meta["x_block_size"]),) + + _silu_mul_backward[grid]( + grad_gate, + grad_up, + grad_output_1d, + gate_1d, + up_1d, + x_size, + ) + return grad_gate.view(shape), grad_up.view(shape) + + +class TritonSiluMul(torch.autograd.Function): + """Autograd function for fused silu(gate) * up""" + + @staticmethod + def forward(ctx, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + output = triton_silu_mul_fwd(gate, up) + ctx.save_for_backward(gate, up) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + gate, up = ctx.saved_tensors + grad_gate, grad_up = triton_silu_mul_bwd(grad_output, gate, up) + return grad_gate, grad_up + + +def triton_silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """ + Fused SiLU multiplication (SwiGLU pattern). 
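+    Thin wrapper around TritonSiluMul: forward and backward each launch a single
+    fused Triton kernel, and the backward pass recomputes sigmoid(gate) from the
+    saved (gate, up) tensors rather than caching intermediate activations.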
+ + Computes: output = silu(gate) * up + + Args: + gate: Input tensor that goes through SiLU activation + up: Input tensor that multiplies with activated gate + + Returns: + output: silu(gate) * up + """ + gate = gate.contiguous() + up = up.contiguous() + return TritonSiluMul.apply(gate, up) + + +def warmup_gpu(): + """Warmup GPU to get stable timing.""" + x = torch.randn(1000, 1000, device="cuda") + for _ in range(10): + _ = x @ x + torch.cuda.synchronize() + + +def get_activation_fn(activation: Optional[str]) -> Optional[Callable]: + """Get activation function by name.""" + if activation is None: + return None + activation_map = { + "silu": F.silu, + "swish": F.silu, + "gelu": F.gelu, + "relu": F.relu, + "tanh": torch.tanh, + "sigmoid": torch.sigmoid, + "swiglu": triton_silu_mul, + } + if activation.lower() not in activation_map: + raise ValueError(f"Unknown activation: {activation}") + return activation_map[activation.lower()] + + +# ============================================================================= +# Reference Implementation +# ============================================================================= + +class ReferenceGroupedMLP(nn.Module): + """Reference implementation using loop over groups.""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + use_gating: bool = True, + activation: Optional[str] = "swiglu", + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.use_gating = use_gating + self.act_fn = get_activation_fn(activation) + + if use_gating: + self.gate_proj = nn.ModuleList([ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + self.up_proj = nn.ModuleList([ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + else: + self.proj = nn.ModuleList([ + nn.Linear(input_dim, hidden_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + + self.down_proj = nn.ModuleList([ + nn.Linear(hidden_dim, output_dim, bias=False, device=device, dtype=dtype) + for _ in range(num_groups) + ]) + + self._init_weights() + + def _init_weights(self): + for module_list in [getattr(self, 'gate_proj', []), + getattr(self, 'up_proj', []), + getattr(self, 'proj', []), + self.down_proj]: + for layer in module_list: + nn.init.xavier_normal_(layer.weight, gain=1.0) + + def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: + if enable_nvtx: + with nvtx.range("Ref_reshape"): + x = x.reshape(-1, self.num_groups, self.input_dim) + + with nvtx.range("Ref_split"): + x_split = torch.split(x, 1, dim=1) + + with nvtx.range("Ref_loop_gemm"): + out_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + if self.use_gating: + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(gate_i) * up_i + else: + hidden_i = gate_i * up_i + else: + hidden_i = self.proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(hidden_i) + out_i = self.down_proj[i](hidden_i) + out_list.append(out_i) + + with nvtx.range("Ref_stack"): + output = torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + else: + x = x.reshape(-1, self.num_groups, self.input_dim) + x_split = torch.split(x, 1, dim=1) + + out_list = [] + for i in 
range(self.num_groups): + x_i = x_split[i].squeeze(1) + if self.use_gating: + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(gate_i) * up_i + else: + hidden_i = gate_i * up_i + else: + hidden_i = self.proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(hidden_i) + out_i = self.down_proj[i](hidden_i) + out_list.append(out_i) + + output = torch.stack(out_list, dim=1).reshape(-1, self.output_dim) + + return output + + def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass up to hidden (excluding down projection).""" + x = x.reshape(-1, self.num_groups, self.input_dim) + x_split = torch.split(x, 1, dim=1) + + hidden_list = [] + for i in range(self.num_groups): + x_i = x_split[i].squeeze(1) + if self.use_gating: + gate_i = self.gate_proj[i](x_i) + up_i = self.up_proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(gate_i) * up_i + else: + hidden_i = gate_i * up_i + else: + hidden_i = self.proj[i](x_i) + if self.act_fn is not None: + hidden_i = self.act_fn(hidden_i) + hidden_list.append(hidden_i) + + return torch.stack(hidden_list, dim=1).reshape(-1, self.hidden_dim) + + +# ============================================================================= +# Strided BMM Function +# ============================================================================= + +class StridedBmmFunction(torch.autograd.Function): + """Custom autograd function for BMM with strided output.""" + + @staticmethod + def forward(ctx, x, weight, batch_size, num_groups, output_dim): + ctx.save_for_backward(x, weight) + + output = torch.empty(batch_size, num_groups, output_dim, + device=x.device, dtype=x.dtype) + torch.bmm(x.permute(1, 0, 2), weight, out=output.permute(1, 0, 2)) + + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight = ctx.saved_tensors + grad_x = grad_weight = None + + grad_output_t = grad_output.permute(1, 0, 2) + + if ctx.needs_input_grad[0]: + grad_x = torch.empty_like(x) + torch.bmm(grad_output_t, weight.transpose(-1, -2), out=grad_x.permute(1, 0, 2)) + + if ctx.needs_input_grad[1]: + x_t = x.permute(1, 0, 2) + grad_weight = torch.bmm(x_t.transpose(-1, -2), grad_output_t) + + return grad_x, grad_weight, None, None, None + + +# ============================================================================= +# Plan A: 3 Independent BMMs +# ============================================================================= + +class GroupedMLP_PlanA(nn.Module): + """Plan A: 3 independent strided BMMs (gate, up, down).""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_groups: int, + use_gating: bool = True, + activation: Optional[str] = "silu", + device: Union[str, torch.device] = "cuda", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_groups = num_groups + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.use_gating = use_gating + self.act_fn = get_activation_fn(activation) + + if use_gating: + self.gate_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + self.up_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + else: + self.proj_weight = nn.Parameter( + torch.empty(num_groups, input_dim, hidden_dim, device=device, dtype=dtype) + ) + + self.down_weight = nn.Parameter( + torch.empty(num_groups, hidden_dim, output_dim, device=device, dtype=dtype) + ) 
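+        # Weights are laid out as (num_groups, in_features, out_features), i.e. the
+        # transpose of nn.Linear's (out_features, in_features), so torch.bmm can
+        # consume them directly without per-call transposes (see copy_weights_to_plan_a).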
+ + self._init_weights() + + def _init_weights(self): + for i in range(self.num_groups): + if self.use_gating: + nn.init.xavier_normal_(self.gate_weight[i], gain=1.0) + nn.init.xavier_normal_(self.up_weight[i], gain=1.0) + else: + nn.init.xavier_normal_(self.proj_weight[i], gain=1.0) + nn.init.xavier_normal_(self.down_weight[i], gain=1.0) + + def forward(self, x: torch.Tensor, enable_nvtx: bool = False) -> torch.Tensor: + batch_size = x.shape[0] // self.num_groups + + if enable_nvtx: + with nvtx.range("PlanA_reshape"): + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + if self.use_gating: + with nvtx.range("PlanA_gate_bmm"): + gate = StridedBmmFunction.apply( + x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim + ) + + with nvtx.range("PlanA_up_bmm"): + up = StridedBmmFunction.apply( + x, self.up_weight, batch_size, self.num_groups, self.hidden_dim + ) + + with nvtx.range("PlanA_activation"): + if self.act_fn is not None: + hidden = self.act_fn(gate) * up + else: + hidden = gate * up + else: + with nvtx.range("PlanA_proj_bmm"): + hidden = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.hidden_dim + ) + with nvtx.range("PlanA_activation"): + if self.act_fn is not None: + hidden = self.act_fn(hidden) + + with nvtx.range("PlanA_down_bmm"): + output = StridedBmmFunction.apply( + hidden, self.down_weight, batch_size, self.num_groups, self.output_dim + ) + + with nvtx.range("PlanA_view"): + return output.view(-1, self.output_dim) + else: + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + if self.use_gating: + gate = StridedBmmFunction.apply( + x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim + ) + up = StridedBmmFunction.apply( + x, self.up_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + # hidden = self.act_fn(gate) * up + hidden = triton_silu_mul(gate, up) + else: + hidden = gate * up + else: + hidden = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + hidden = self.act_fn(hidden) + + output = StridedBmmFunction.apply( + hidden, self.down_weight, batch_size, self.num_groups, self.output_dim + ) + + return output.view(-1, self.output_dim) + + def forward_to_hidden(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass up to hidden (excluding down projection).""" + batch_size = x.shape[0] // self.num_groups + x = x.reshape(batch_size, self.num_groups, self.input_dim) + + if self.use_gating: + gate = StridedBmmFunction.apply( + x, self.gate_weight, batch_size, self.num_groups, self.hidden_dim + ) + up = StridedBmmFunction.apply( + x, self.up_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + hidden = triton_silu_mul(gate, up) + else: + hidden = gate * up + else: + hidden = StridedBmmFunction.apply( + x, self.proj_weight, batch_size, self.num_groups, self.hidden_dim + ) + if self.act_fn is not None: + hidden = self.act_fn(hidden) + + return hidden.view(-1, self.hidden_dim) + + +# ============================================================================= +# Weight Copy Utilities +# ============================================================================= + +def copy_weights_to_plan_a(ref_model: ReferenceGroupedMLP, opt_model: GroupedMLP_PlanA): + with torch.no_grad(): + num_groups = ref_model.num_groups + if ref_model.use_gating: + for i in range(num_groups): + opt_model.gate_weight[i].copy_(ref_model.gate_proj[i].weight.T) + 
opt_model.up_weight[i].copy_(ref_model.up_proj[i].weight.T) + else: + for i in range(num_groups): + opt_model.proj_weight[i].copy_(ref_model.proj[i].weight.T) + for i in range(num_groups): + opt_model.down_weight[i].copy_(ref_model.down_proj[i].weight.T) + + +# ============================================================================= +# Correctness Check +# ============================================================================= + +def check_correctness( + ref_model: nn.Module, + opt_model: nn.Module, + batch_size: int, + num_groups: int, + input_dim: int, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[float, float]: + """Check forward and backward correctness.""" + x = torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + with torch.no_grad(): + ref_out = ref_model(x) + opt_out = opt_model(x) + fwd_diff = (ref_out - opt_out).abs().max().item() + + x_ref = torch.randn( + batch_size * num_groups, input_dim, + device="cuda", dtype=dtype, requires_grad=True + ) + x_opt = x_ref.detach().clone().requires_grad_(True) + + ref_out = ref_model(x_ref) + opt_out = opt_model(x_opt) + + grad_output = torch.randn_like(ref_out) + ref_out.backward(grad_output) + opt_out.backward(grad_output) + + bwd_x_diff = (x_ref.grad - x_opt.grad).abs().max().item() + + return fwd_diff, bwd_x_diff + + +# ============================================================================= +# Benchmark Functions (same as benchmark_batched_gemm.py) +# ============================================================================= + +def benchmark_forward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, + enable_nvtx: bool = False, +) -> float: + """Benchmark forward pass using CUDA events for accurate GPU timing.""" + model_name = model.__class__.__name__ + + # Warmup + for i in range(num_warmup): + _ = model(x_list[i % len(x_list)], enable_nvtx=False) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + if enable_nvtx: + for i in range(num_iterations): + with nvtx.range(f"{model_name}_fwd_iter{i}"): + _ = model(x_list[i % len(x_list)], enable_nvtx=True) + else: + for i in range(num_iterations): + _ = model(x_list[i % len(x_list)], enable_nvtx=False) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations + + +def benchmark_forward_to_hidden( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, +) -> float: + """Benchmark forward pass up to hidden (excluding down projection).""" + # Warmup + for i in range(num_warmup): + _ = model.forward_to_hidden(x_list[i % len(x_list)]) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + for i in range(num_iterations): + _ = model.forward_to_hidden(x_list[i % len(x_list)]) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations + + +def benchmark_forward_backward( + model: nn.Module, + x_list: List[torch.Tensor], + num_iterations: int = 100, + num_warmup: int = 10, + enable_nvtx: bool = False, +) -> float: + """Benchmark forward + backward pass using CUDA events.""" + model_name = model.__class__.__name__ + output_dim = model.output_dim + + grad_outputs = [ + 
torch.randn(xi.shape[0], output_dim, device="cuda", dtype=xi.dtype) + for xi in x_list + ] + + x_with_grad = [xi.requires_grad_(True) for xi in x_list] + params = list(model.parameters()) + + # Warmup + for i in range(num_warmup): + xi = x_with_grad[i % len(x_list)] + out = model(xi, enable_nvtx=False) + grads = torch.autograd.grad( + outputs=out, + inputs=[xi] + params, + grad_outputs=grad_outputs[i % len(x_list)], + ) + torch.cuda.synchronize() + + # Create CUDA events + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Benchmark + start_event.record() + if enable_nvtx: + for i in range(num_iterations): + with nvtx.range(f"{model_name}_fwdbwd_iter{i}"): + xi = x_with_grad[i % len(x_list)] + grad_out = grad_outputs[i % len(x_list)] + with nvtx.range("forward"): + out = model(xi, enable_nvtx=True) + # Separate backward into two parts for clearer profiling + with nvtx.range("backward_activation"): + # dL/dx (activation gradient) + grad_x = torch.autograd.grad( + outputs=out, + inputs=xi, + grad_outputs=grad_out, + retain_graph=True, # Keep graph for weight gradient + ) + with nvtx.range("backward_weight"): + # dL/dW (weight gradient) + grad_w = torch.autograd.grad( + outputs=out, + inputs=params, + grad_outputs=grad_out, + retain_graph=False, + ) + else: + for i in range(num_iterations): + xi = x_with_grad[i % len(x_list)] + out = model(xi, enable_nvtx=False) + grads = torch.autograd.grad( + outputs=out, + inputs=[xi] + params, + grad_outputs=grad_outputs[i % len(x_list)], + ) + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_iterations + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser( + description="Grouped MLP Benchmark: Reference vs Plan A" + ) + parser.add_argument("--batch-size", type=int, default=2560) + parser.add_argument("--num-groups", type=int, default=12) + parser.add_argument("--input-dim", type=int, default=1024) + parser.add_argument("--hidden-dim", type=int, default=3072) + parser.add_argument("--output-dim", type=int, default=1024) + parser.add_argument("--activation", type=str, default="silu", + choices=["silu", "gelu", "relu", "tanh", "none"]) + parser.add_argument("--no-gating", action="store_true") + parser.add_argument("--iterations", type=int, default=100) + parser.add_argument("--enable-nvtx", action="store_true", + help="Enable NVTX markers (use with nsys profile)") + parser.add_argument("--compile", action="store_true", + help="Use torch.compile() to optimize models") + args = parser.parse_args() + + torch.cuda.init() + + batch_size = args.batch_size + num_groups = args.num_groups + input_dim = args.input_dim + hidden_dim = args.hidden_dim + output_dim = args.output_dim + activation = None if args.activation == "none" else args.activation + use_gating = not args.no_gating + dtype = torch.bfloat16 + num_iterations = args.iterations + + print("=" * 80) + print("Grouped MLP Benchmark: Reference vs Plan A") + print("=" * 80) + + if args.enable_nvtx: + print("\n*** NVTX PROFILING MODE ***") + print("Run with: nsys profile -o --trace=cuda,nvtx python ...") + + print(f""" +Config: + Batch size: {batch_size} + Num groups: {num_groups} + Dimensions: {input_dim} -> {hidden_dim} -> {output_dim} + Mode: {"GLU (SwiGLU)" if use_gating else "Simple MLP"} + Activation: {activation if activation else 
"None"} + Dtype: {dtype} + Device: {torch.cuda.get_device_name(0)} + Iterations: {num_iterations} +""") + + print("Warming up GPU...") + warmup_gpu() + + # Create models + print("Creating models...") + ref_model = ReferenceGroupedMLP( + input_dim, hidden_dim, output_dim, num_groups, + use_gating=use_gating, activation=activation, dtype=dtype + ).cuda() + + plan_a_model = GroupedMLP_PlanA( + input_dim, hidden_dim, output_dim, num_groups, + use_gating=use_gating, activation=activation, dtype=dtype + ).cuda() + + copy_weights_to_plan_a(ref_model, plan_a_model) + + # Apply torch.compile() if requested + if args.compile: + print("\nApplying torch.compile() to all models...") + ref_model = torch.compile(ref_model) + plan_a_model = torch.compile(plan_a_model) + print("Compilation complete (will JIT compile on first run).") + + # Correctness check + print("-" * 60) + print("Correctness Check") + print("-" * 60) + + fwd_a, bwd_a = check_correctness(ref_model, plan_a_model, batch_size, num_groups, input_dim, dtype) + print(f"Plan A - Forward diff: {fwd_a:.2e}, Backward diff: {bwd_a:.2e}") + + # Prepare test data + x_list = [ + torch.randn(batch_size * num_groups, input_dim, device="cuda", dtype=dtype) + for _ in range(10) + ] + + # Benchmark (NEVER use NVTX for timing - NVTX adds Python overhead) + print("\n" + "-" * 60) + print("Performance Benchmark (NVTX disabled for accurate timing)") + print("-" * 60) + + # Forward - always benchmark without NVTX + print("\n>>> Forward Pass <<<") + ref_fwd = benchmark_forward(ref_model, x_list, num_iterations, enable_nvtx=False) + plan_a_fwd = benchmark_forward(plan_a_model, x_list, num_iterations, enable_nvtx=False) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwd:<12.4f} {'1.00x':<10}") + print(f"{'Plan A (batched BMM)':<30} {plan_a_fwd:<12.4f} {ref_fwd/plan_a_fwd:<10.2f}x") + + # Forward to Hidden (excluding down projection) + print("\n>>> Forward to Hidden (excluding down_proj) <<<") + ref_hidden = benchmark_forward_to_hidden(ref_model, x_list, num_iterations) + plan_a_hidden = benchmark_forward_to_hidden(plan_a_model, x_list, num_iterations) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_hidden:<12.4f} {'1.00x':<10}") + print(f"{'Plan A (batched BMM)':<30} {plan_a_hidden:<12.4f} {ref_hidden/plan_a_hidden:<10.2f}x") + + # Forward + Backward - always benchmark without NVTX + print("\n>>> Forward + Backward <<<") + ref_fwdbwd = benchmark_forward_backward(ref_model, x_list, num_iterations, enable_nvtx=False) + plan_a_fwdbwd = benchmark_forward_backward(plan_a_model, x_list, num_iterations, enable_nvtx=False) + + print(f"\n{'Model':<30} {'Time (ms)':<12} {'Speedup':<10}") + print("-" * 52) + print(f"{'Reference (loop)':<30} {ref_fwdbwd:<12.4f} {'1.00x':<10}") + print(f"{'Plan A (batched BMM)':<30} {plan_a_fwdbwd:<12.4f} {ref_fwdbwd/plan_a_fwdbwd:<10.2f}x") + + # NVTX profiling run (separate from benchmark) + if args.enable_nvtx: + print("\n" + "-" * 60) + print("NVTX Profiling Run (for nsys analysis only)") + print("-" * 60) + torch.cuda.profiler.start() + + # Run a few iterations with NVTX for profiling + nvtx_iterations = min(10, num_iterations) + _ = benchmark_forward(ref_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward(plan_a_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = benchmark_forward_backward(ref_model, x_list, nvtx_iterations, enable_nvtx=True) + _ = 
benchmark_forward_backward(plan_a_model, x_list, nvtx_iterations, enable_nvtx=True) + + torch.cuda.profiler.stop() + print("NVTX profiling complete.") + + # Summary + print("\n" + "=" * 80) + print("Summary") + print("=" * 80) + print(f""" +Implementation Details: + Reference: Loop over {num_groups} groups, uses nn.Linear (C++ autograd) + Plan A: Batched BMM with custom StridedBmmFunction + +Forward Speedup (full MLP): + Plan A vs Reference: {ref_fwd/plan_a_fwd:.2f}x + +Forward to Hidden (excluding down_proj): + Plan A vs Reference: {ref_hidden/plan_a_hidden:.2f}x + +Fwd+Bwd Speedup: + Plan A vs Reference: {ref_fwdbwd/plan_a_fwdbwd:.2f}x +""") + print("=" * 80) + print("Done!") + + +if __name__ == "__main__": + main()
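+
+# Example invocations (illustrative; defaults match the benchmark config above):
+#   python GroupedMLP_example.py
+#   python GroupedMLP_example.py --batch-size 2560 --num-groups 12 --activation silu
+#   nsys profile --trace=cuda,nvtx python GroupedMLP_example.py --enable-nvtx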