diff --git a/CMakeLists.txt b/CMakeLists.txt index e4deecbe..61359a0d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.23) -project(cudnn_frontend VERSION 1.22.1) +project(cudnn_frontend VERSION 1.23.0) option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF) option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON) diff --git a/README.md b/README.md index ba010c00..e93898f1 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,17 @@ -# cuDNN FrontEnd(FE) +# cuDNN Frontend (FE) -**cuDNN FE** is the modern, open-source entry point to the NVIDIA cuDNN library and high performance open-source kernels. It provides a C++ header-only library and a Python interface to access the powerful cuDNN Graph API and open-source kernels. +[![PyPI version](https://img.shields.io/pypi/v/nvidia-cudnn-frontend.svg)](https://pypi.org/project/nvidia-cudnn-frontend/) +[![PyPI downloads](https://img.shields.io/pypi/dm/nvidia-cudnn-frontend.svg)](https://pypi.org/project/nvidia-cudnn-frontend/) +[![Python versions](https://img.shields.io/pypi/pyversions/nvidia-cudnn-frontend.svg)](https://pypi.org/project/nvidia-cudnn-frontend/) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE) +[![Docs](https://img.shields.io/badge/docs-nvidia.github.io-blue.svg)](https://nvidia.github.io/cudnn-frontend/) + +**cuDNN Frontend** is NVIDIA's modern, open-source entry point to the cuDNN library and a growing collection of high-performance open-source kernels — scaled dot-product attention (**SDPA / Flash Attention**), grouped GEMM fusions for **Mixture-of-Experts (MoE)** training, fused normalization + activation, and more. + +It provides a **header-only C++ API** and a **Python interface** (with native PyTorch integration) to the cuDNN Graph API, targeting NVIDIA **Hopper** (H100/H200) and **Blackwell** (B200/GB200/GB300) GPUs across FP16, BF16, FP8, and **MXFP8** precisions. + +**Links:** [Documentation](https://docs.nvidia.com/deeplearning/cudnn/frontend/latest/) · [Blog & Deep Dives](https://nvidia.github.io/cudnn-frontend/) · [PyPI](https://pypi.org/project/nvidia-cudnn-frontend/) · [Release Notes](https://github.com/NVIDIA/cudnn-frontend/releases) · [Samples](samples/) ## 🚀 Latest news: @@ -11,10 +21,15 @@ We are now shipping **OSS kernels**, allowing you to inspect, modify, and contri * **[GEMM + Amax](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/gemm_amax):** Optimized FP8 matrix multiplication with absolute maximum calculation. * **[GEMM + SwiGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/gemm_swiglu):** High-performance implementation of the SwiGLU activation fused with GEMM. +* **[GEMM + sReLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/gemm_srelu):** High-performance implementation of squared-ReLU fused with GEMM. +* **[GEMM + dsReLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/gemm_dsrelu):** High-performance implementation of the squared-ReLU backward pass (dsReLU) fused with GEMM (see the reference sketch after this list). * **[Grouped GEMM + GLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_glu):** Unified grouped GEMM GLU API supporting dense and discrete MoE weight layouts. +* **[Grouped GEMM + GLU + Hadamard](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard):** Dense grouped GEMM GLU forward fusion with a fused Hadamard transform and per-expert AMAX reduction.
* **[Grouped GEMM + dGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_dglu):** Unified grouped GEMM dGLU backward API supporting dense and discrete MoE weight layouts. * **[Grouped GEMM + SwiGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu):** SwiGLU activation fused with Grouped GEMM. * **[Grouped GEMM + dSwiGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_dswiglu):** dSwiGLU (SwiGLU backward) fused with Grouped GEMM. +* **[Grouped GEMM + sReLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_srelu):** Contiguous grouped squared-ReLU GEMM for MoE workloads. +* **[Grouped GEMM + dsReLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_dsrelu):** Contiguous grouped squared-ReLU backward (dsReLU) GEMM for MoE workloads. * **[Discrete Grouped GEMM + SwiGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu):** Per-expert-pointer SwiGLU grouped GEMM for MoE workloads without weight packing. * **[Discrete Grouped GEMM + dSwiGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu):** Per-expert-pointer dSwiGLU backward grouped GEMM for MoE workloads without weight packing. * **[Grouped GEMM + Quant](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_quant):** Legacy dense-only grouped GEMM quant API for MoE FC2/dFC1 workloads.
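For intuition, the activation math these fusions implement can be written out as an unfused PyTorch reference. This is a minimal sketch of the forward and backward semantics only; the function names, shapes, and gating convention here are illustrative assumptions, not the kernels' actual API:

```python
import torch
import torch.nn.functional as F

def swiglu_ref(x, w_gate, w_up):
    # SwiGLU forward: SiLU(x @ W_gate) gated elementwise by (x @ W_up).
    # The fused kernel produces both GEMMs and the activation in one pass.
    return F.silu(x @ w_gate) * (x @ w_up)

def srelu_ref(x, w):
    # Squared-ReLU forward: h = x @ W, then max(h, 0) ** 2.
    return F.relu(x @ w) ** 2

def dsrelu_ref(x, w, grad_out):
    # Squared-ReLU backward w.r.t. h: d/dh [relu(h)**2] = 2 * relu(h),
    # so the incoming gradient is scaled by 2 * relu(h).
    return grad_out * 2.0 * F.relu(x @ w)
```

The grouped variants apply the same math per expert, batching one GEMM per expert group into a single launch.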
@@ -30,13 +45,13 @@ We are now shipping **OSS kernels**, allowing you to inspect, modify, and contri #### Llama 3.1 style Forward and Bprop with causal masking (GB300)

[Image: Llama 3.1 SDPA Benchmark on GB300 (only cuDNN)]

#### Deepseek v3 style Forward and Bprop with causal masking (GB300)

[Image: DSv3 SDPA Benchmark on GB300 (only cuDNN)]
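These SDPA benchmarks are driven through the Python interface: a graph is built once and then executed against device buffers. Below is a minimal sketch in the style of the public Python samples (`samples/python`); treat details such as `tensor_like`, `use_causal_mask`, and the heuristic mode as assumptions to verify against the current release:

```python
import cudnn
import torch

b, h, s, d = 2, 128, 2048, 128  # batch, heads, sequence length, head dim (illustrative)
q = torch.randn(b, h, s, d, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)

graph = cudnn.pygraph(
    io_data_type=cudnn.data_type.BFLOAT16,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)
q_t, k_t, v_t = (graph.tensor_like(t) for t in (q, k, v))
o_t, _ = graph.sdpa(
    q=q_t, k=k_t, v=v_t,
    attn_scale=d ** -0.5,
    use_causal_mask=True,  # "top_left" causal masking, as in the CSVs below
    is_inference=True,
)
o_t.set_output(True).set_dim(o.size()).set_stride(o.stride())

graph.build([cudnn.heur_mode.A])  # validate, plan, and compile in one step
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
graph.execute({q_t: q, k_t: k, v_t: v, o_t: o}, workspace)
```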

## Key Features @@ -56,8 +71,9 @@ pip install nvidia-cudnn-frontend ``` **Requirements:** -* Python 3.8+ +* Python 3.9+ * NVIDIA driver and CUDA Toolkit +* NVIDIA cuDNN (minimum 8.5.0) ### ⚙️ C++ (Header Only) @@ -93,9 +109,12 @@ cmake --build . -j16 ## Documentation & Examples -* **Developer Guide:** [Official NVIDIA Documentation](https://docs.nvidia.com/deeplearning/cudnn/frontend/v1.9.0/developer/overview.html) -* **C++ Samples:** See `samples/cpp` for comprehensive usage examples. -* **Python Samples:** See `samples/python` for pythonic implementations. +* **Developer Guide:** [Official NVIDIA Documentation (latest)](https://docs.nvidia.com/deeplearning/cudnn/frontend/latest/) +* **Blog & Deep Dives:** [nvidia.github.io/cudnn-frontend](https://nvidia.github.io/cudnn-frontend/) — release notes, installation guides, and technical deep-dives (MXFP8 attention, FP8 scale layouts, etc.) +* **C++ Samples:** See [`samples/cpp`](samples/cpp) for end-to-end examples covering convolution, matmul, SDPA / Flash Attention, normalization, and more. +* **Python Samples:** See [`samples/python`](samples/python) for Jupyter notebooks and PyTorch integration patterns. +* **OSS Kernels:** See [`python/cudnn/`](python/cudnn/) for source of SDPA, grouped GEMM + SwiGLU/GLU, RMSNorm + SiLU, Native Sparse Attention, and other open-sourced kernels. +* **PyTorch Custom Ops:** See [`python/cudnn/experimental/ops`](python/cudnn/experimental/ops) for `torch.compile`-compatible wrappers around cuDNN kernels. ## 🤝 Contributing diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_20260424_101009.csv b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_20260424_101009.csv new file mode 100644 index 00000000..abbcaaec --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_20260424_101009.csv @@ -0,0 +1,91 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,fwd,False,50.456,1743.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,bwd,False,212.546,1076.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,bwd,True,210.687,1086.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,fwd,False,111.031,1584.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,bwd,False,421.194,1086.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,bwd,True,411.727,1111.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,fwd,False,35.942,2447.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,bwd,False,123.181,1857.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,bwd,True,122.033,1874.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,fwd,False,75.166,2340.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,bwd,False,236.705,1932.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,bwd,True,234.246,1953.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,fwd,False,37.773,2329.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,bwd,False,149.238,1532.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,bwd,True,152.079,1504.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,fwd,False,81.379,2162.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,bwd,False,293.095,1561.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,bwd,True,289.395,1581.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,fwd,False,13.103,1678.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,bwd,False,50.424,1134.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,bwd,True,50.136,1140.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,fwd,False,24.270,1812.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,bwd,False,105.115,1088.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,bwd,True,102.798,1112.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,fwd,False,8.845,2486.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,bwd,False,29.050,1968.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,bwd,True,29.760,1921.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,fwd,False,16.984,2590.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,bwd,False,57.481,1989.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,bwd,True,57.615,1985.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,fwd,False,9.337,2355.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,bwd,False,36.835,1552.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,bwd,True,36.960,1547.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,fwd,False,17.648,2492.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,bwd,False,69.820,1638.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,bwd,True,71.538,1598.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,fwd,False,3.470,1585.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,bwd,False,11.984,1193.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,bwd,True,13.023,1098.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,fwd,False,6.249,1759.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,bwd,False,25.403,1125.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,bwd,True,25.571,1118.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,fwd,False,2.294,2397.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,bwd,False,7.296,1959.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,bwd,True,7.530,1898.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,fwd,False,4.258,2582.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,bwd,False,13.472,2122.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,bwd,True,13.256,2157.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,fwd,False,2.478,2219.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,bwd,False,9.567,1494.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,bwd,True,9.205,1553.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,fwd,False,4.567,2408.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,bwd,False,17.014,1680.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,bwd,True,16.251,1759.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,fwd,False,0.977,1406.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,bwd,False,3.383,1057.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,bwd,True,3.301,1083.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,fwd,False,1.650,1666.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,bwd,False,6.085,1175.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,bwd,True,6.043,1183.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,fwd,False,0.656,2095.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,bwd,False,2.112,1692.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,bwd,True,2.108,1696.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,fwd,False,1.092,2518.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,bwd,False,3.450,2071.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,bwd,True,3.449,2072.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,fwd,False,0.714,1925.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,bwd,False,2.620,1364.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,bwd,True,2.621,1364.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,fwd,False,1.187,2316.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,bwd,False,4.267,1675.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,bwd,True,4.272,1673.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,fwd,False,0.321,1072.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,bwd,False,1.056,846.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,bwd,True,1.015,881.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.461,1492.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,bwd,False,1.706,1048.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,bwd,True,1.633,1094.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,fwd,False,0.207,1659.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,bwd,False,0.681,1312.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,bwd,True,0.682,1311.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.316,2175.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,bwd,False,0.989,1807.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,bwd,True,0.988,1808.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,fwd,False,0.224,1535.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,bwd,False,0.854,1047.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,bwd,True,0.854,1046.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.341,2018.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,bwd,False,1.188,1504.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,bwd,True,1.206,1482.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 diff --git 
a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask.png new file mode 100644 index 00000000..b9b11572 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask_det_overhead.png new file mode 100644 index 00000000..c2aa14f0 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left.png new file mode 100644 index 00000000..f6c3b22c Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left_det_overhead.png new file mode 100644 index 00000000..04365b83 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_20260424_101002.csv b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_20260424_101002.csv new file mode 100644 index 00000000..c01eb2bb --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_20260424_101002.csv @@ -0,0 +1,91 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,fwd,False,44.189,1991.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,bwd,False,177.494,1289.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,bwd,True,175.497,1303.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,fwd,False,91.788,1917.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,bwd,False,366.969,1246.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,bwd,True,353.235,1295.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,fwd,False,26.146,3364.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,bwd,False,95.057,2406.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,bwd,True,94.596,2418.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,fwd,False,52.327,3362.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,bwd,False,185.459,2466.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,bwd,True,187.470,2440.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,fwd,False,29.731,2959.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,bwd,False,122.768,1863.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,bwd,True,121.830,1877.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,fwd,False,57.427,3063.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,bwd,False,234.862,1948.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,bwd,True,237.119,1929.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,fwd,False,11.178,1967.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,bwd,False,43.911,1302.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,bwd,True,41.569,1375.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,fwd,False,21.640,2032.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,bwd,False,91.189,1254.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,bwd,True,86.044,1329.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,fwd,False,6.653,3305.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,bwd,False,23.257,2459.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,bwd,True,23.761,2406.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,fwd,False,12.794,3438.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,bwd,False,44.344,2579.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,bwd,True,43.948,2602.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,fwd,False,7.336,2998.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,bwd,False,31.002,1844.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,bwd,True,30.575,1870.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,fwd,False,14.110,3117.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,bwd,False,56.977,2007.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,bwd,True,56.908,2009.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,fwd,False,3.013,1825.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,bwd,False,11.074,1291.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,bwd,True,10.887,1313.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,fwd,False,5.439,2021.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,bwd,False,21.973,1301.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,bwd,True,20.598,1388.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,fwd,False,1.758,3128.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,bwd,False,6.188,2310.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,bwd,True,6.181,2313.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,fwd,False,3.233,3400.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,bwd,False,11.395,2509.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,bwd,True,11.339,2521.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,fwd,False,1.944,2829.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,bwd,False,8.030,1780.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,bwd,True,8.027,1781.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,fwd,False,3.538,3108.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,bwd,False,14.104,2027.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,bwd,True,14.079,2030.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,fwd,False,0.864,1590.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,bwd,False,3.127,1143.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,bwd,True,3.017,1185.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,fwd,False,1.419,1937.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,bwd,False,5.641,1267.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,bwd,True,5.603,1276.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,fwd,False,0.493,2788.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,bwd,False,1.803,1982.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,bwd,True,1.800,1985.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,fwd,False,0.832,3304.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,bwd,False,2.971,2406.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,bwd,True,2.969,2407.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,fwd,False,0.546,2517.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,bwd,False,2.348,1522.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,bwd,True,2.351,1520.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,fwd,False,0.912,3015.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,bwd,False,3.720,1921.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,bwd,True,3.721,1921.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,fwd,False,0.268,1285.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,bwd,False,0.977,914.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,bwd,True,0.935,956.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.388,1770.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,bwd,False,1.600,1117.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,bwd,True,1.553,1150.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,fwd,False,0.150,2290.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,bwd,False,0.591,1511.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,bwd,True,0.592,1509.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.225,3058.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,bwd,False,0.851,2099.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,bwd,True,0.852,2098.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,fwd,False,0.166,2069.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,bwd,False,0.772,1158.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,bwd,True,0.772,1158.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.249,2758.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,bwd,False,1.073,1665.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,bwd,True,1.074,1663.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask.png new file mode 100644 index 00000000..dec37f4e Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask.png 
differ diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask_det_overhead.png new file mode 100644 index 00000000..fea637e8 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left.png new file mode 100644 index 00000000..a921070c Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left_det_overhead.png new file mode 100644 index 00000000..7c38409c Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_20260424_100011.csv b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_20260424_100011.csv new file mode 100644 index 00000000..ec97be3a --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_20260424_100011.csv @@ -0,0 +1,46 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,fwd,False,2.634,104.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,bwd,False,3.692,186.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,bwd,True,4.691,146.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,fwd,False,1.529,179.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,bwd,False,90.348,8.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,bwd,True,90.348,8.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,fwd,False,1.637,168.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,bwd,False,108.370,6.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,bwd,True,108.368,6.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,fwd,False,0.822,167.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,bwd,False,1.818,188.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,bwd,True,2.354,145.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,fwd,False,0.766,179.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,bwd,False,23.231,15.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,bwd,True,23.234,15.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,fwd,False,0.826,166.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,bwd,False,27.940,12.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,bwd,True,27.910,12.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,fwd,False,0.414,165.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,bwd,False,0.911,187.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,bwd,True,1.185,144.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,fwd,False,0.385,177.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,bwd,False,6.176,28.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,bwd,True,6.177,28.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,fwd,False,0.417,164.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,bwd,False,7.427,23.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,bwd,True,7.426,23.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,fwd,False,0.210,161.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,bwd,False,0.465,182.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,bwd,True,0.602,140.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,fwd,False,0.195,173.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,bwd,False,1.731,49.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,bwd,True,1.732,49.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,fwd,False,0.210,161.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,bwd,False,2.080,41.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,bwd,True,2.081,41.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,fwd,False,0.109,153.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,bwd,False,0.237,176.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,bwd,True,0.311,134.000,0.000,10,128,True,,NVIDIA 
GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,fwd,False,0.101,165.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,bwd,False,0.536,78.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,bwd,True,0.536,78.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,fwd,False,0.109,153.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,bwd,False,0.642,65.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,bwd,True,0.642,65.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000 diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left.png b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left.png new file mode 100644 index 00000000..27eb0ed0 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left_det_overhead.png new file mode 100644 index 00000000..a1becf02 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_20260424_100022.csv b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_20260424_100022.csv new file mode 100644 index 00000000..60792295 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_20260424_100022.csv @@ -0,0 +1,46 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,fwd,False,2.530,108.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,bwd,False,3.450,199.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,bwd,True,4.376,157.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,fwd,False,1.296,212.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,bwd,False,74.763,9.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,bwd,True,74.760,9.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,fwd,False,1.345,204.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,bwd,False,96.913,7.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,bwd,True,96.937,7.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,fwd,False,0.790,173.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,bwd,False,1.712,200.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,bwd,True,2.194,156.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,fwd,False,0.650,211.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,bwd,False,19.354,18.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,bwd,True,19.347,18.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,fwd,False,0.676,203.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,bwd,False,24.996,14.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,bwd,True,24.995,14.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,fwd,False,0.397,172.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,bwd,False,0.862,198.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,bwd,True,1.106,154.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,fwd,False,0.327,208.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,bwd,False,5.189,33.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,bwd,True,5.189,33.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,fwd,False,0.342,200.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,bwd,False,6.648,26.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,bwd,True,6.649,26.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,fwd,False,0.202,167.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,bwd,False,0.442,191.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,bwd,True,0.565,150.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,fwd,False,0.167,203.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,bwd,False,1.480,57.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,bwd,True,1.481,57.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,fwd,False,0.175,194.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,bwd,False,1.871,45.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,bwd,True,1.871,45.000,0.000,10,128,True,,NVIDIA 
GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,fwd,False,0.104,160.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,bwd,False,0.227,184.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,bwd,True,0.291,143.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,fwd,False,0.086,193.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,bwd,False,0.467,89.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,bwd,True,0.467,89.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,fwd,False,0.092,182.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,bwd,False,0.581,72.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 +gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,bwd,True,0.580,72.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000 diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left.png b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left.png new file mode 100644 index 00000000..dcc9f6f0 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left_det_overhead.png new file mode 100644 index 00000000..7c88aa95 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_20260424_100953.csv b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_20260424_100953.csv new file mode 100644 index 00000000..30b06508 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_20260424_100953.csv @@ -0,0 +1,91 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,fwd,False,25.720,1710.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,bwd,False,105.445,1084.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,bwd,True,102.196,1119.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,fwd,False,51.776,1699.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,bwd,False,213.066,1073.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,bwd,True,202.309,1130.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,fwd,False,18.194,2417.000,0.000,10,,True,,NVIDIA 
GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,bwd,False,58.126,1967.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,bwd,True,60.971,1876.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,fwd,False,36.262,2426.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,bwd,False,113.717,2011.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,bwd,True,112.404,2035.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,fwd,False,18.417,2388.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,bwd,False,72.533,1577.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,bwd,True,70.153,1630.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,fwd,False,37.733,2331.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,bwd,False,141.358,1618.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,bwd,True,142.283,1607.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,fwd,False,6.398,1719.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,bwd,False,24.921,1147.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,bwd,True,25.278,1131.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,fwd,False,12.457,1765.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,bwd,False,50.963,1122.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,bwd,True,47.797,1196.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,fwd,False,4.326,2542.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,bwd,False,14.195,2014.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,bwd,True,13.830,2067.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,fwd,False,8.242,2668.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,bwd,False,27.238,2099.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,bwd,True,27.125,2108.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,fwd,False,4.667,2356.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,bwd,False,17.427,1641.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,bwd,True,17.224,1660.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,fwd,False,9.089,2419.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,bwd,False,34.587,1653.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,bwd,True,32.626,1752.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,fwd,False,1.733,1586.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,bwd,False,6.005,1190.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,bwd,True,5.967,1198.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,fwd,False,3.087,1781.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,bwd,False,11.719,1220.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,bwd,True,12.095,1182.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,fwd,False,1.139,2414.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,bwd,False,3.689,1937.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,bwd,True,3.697,1933.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,fwd,False,2.155,2551.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,bwd,False,6.535,2187.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,bwd,True,6.532,2188.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,fwd,False,1.254,2193.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,bwd,False,4.546,1572.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,bwd,True,4.540,1575.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,fwd,False,2.272,2420.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,bwd,False,8.032,1780.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,bwd,True,7.874,1815.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,fwd,False,0.518,1328.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,bwd,False,1.698,1053.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,bwd,True,1.644,1087.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.841,1635.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,bwd,False,3.009,1187.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,bwd,True,3.004,1189.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,fwd,False,0.338,2034.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,bwd,False,1.082,1652.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,bwd,True,1.082,1651.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.575,2389.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,bwd,False,1.723,2074.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,bwd,True,1.750,2042.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,fwd,False,0.363,1891.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,bwd,False,1.331,1343.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,bwd,True,1.337,1336.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.625,2200.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,bwd,False,2.120,1686.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,bwd,True,2.125,1681.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,fwd,False,0.167,1029.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,bwd,False,0.536,834.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,bwd,True,0.526,850.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.237,1448.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.879,1016.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.863,1036.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,fwd,False,0.110,1557.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,bwd,False,0.355,1260.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,bwd,True,0.356,1256.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.165,2077.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.512,1743.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.512,1746.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,fwd,False,0.119,1447.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,bwd,False,0.443,1009.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,bwd,True,0.443,1008.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.178,1927.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.631,1417.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.631,1416.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask.png new file mode 100644 index 00000000..db3418af Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask_det_overhead.png new file mode 100644 index 00000000..e495d05f Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left.png new file mode 100644 index 00000000..59f87bae Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left_det_overhead.png new file mode 100644 index 00000000..8a973840 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_20260424_100915.csv b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_20260424_100915.csv new file mode 100644 index 00000000..0966bc87 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_20260424_100915.csv @@ -0,0 +1,91 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,fwd,False,21.498,2046.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,bwd,False,88.821,1287.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,bwd,True,86.626,1320.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,fwd,False,43.698,2013.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,bwd,False,182.491,1253.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,bwd,True,175.560,1303.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,fwd,False,12.849,3423.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,bwd,False,45.480,2514.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,bwd,True,45.333,2523.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,fwd,False,25.416,3461.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,bwd,False,90.921,2515.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,bwd,True,90.997,2513.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,fwd,False,14.100,3119.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,bwd,False,59.247,1930.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,bwd,True,58.968,1939.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,fwd,False,28.687,3066.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,bwd,False,115.014,1988.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,bwd,True,114.422,1999.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,fwd,False,5.442,2021.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,bwd,False,21.253,1345.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,bwd,True,20.838,1372.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,fwd,False,10.590,2077.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,bwd,False,43.165,1325.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,bwd,True,40.794,1402.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,fwd,False,3.266,3366.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,bwd,False,11.276,2535.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,bwd,True,11.578,2469.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,fwd,False,6.268,3509.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,bwd,False,21.662,2639.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,bwd,True,21.863,2615.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,fwd,False,3.613,3044.000,0.000,10,,True,,NVIDIA 
GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,bwd,False,14.766,1936.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,bwd,True,14.824,1929.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,fwd,False,6.924,3176.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,bwd,False,27.533,2077.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,bwd,True,27.085,2111.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,fwd,False,1.466,1875.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,bwd,False,5.466,1308.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,bwd,True,5.271,1356.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,fwd,False,2.632,2089.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,bwd,False,10.654,1342.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,bwd,True,10.128,1411.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,fwd,False,0.865,3178.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,bwd,False,3.064,2333.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,bwd,True,3.064,2333.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,fwd,False,1.584,3471.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,bwd,False,5.463,2617.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,bwd,True,5.481,2608.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,fwd,False,0.957,2872.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,bwd,False,3.991,1791.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,bwd,True,3.990,1791.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,fwd,False,1.739,3160.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,bwd,False,6.904,2070.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,bwd,True,6.872,2080.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,fwd,False,0.421,1632.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,bwd,False,1.548,1154.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,bwd,True,1.507,1186.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.701,1960.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,bwd,False,2.809,1272.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,bwd,True,2.802,1275.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,fwd,False,0.244,2822.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,bwd,False,0.904,1976.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,bwd,True,0.906,1973.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.418,3289.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,bwd,False,1.472,2428.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,bwd,True,1.474,2424.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,fwd,False,0.269,2555.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,bwd,False,1.185,1508.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,bwd,True,1.184,1509.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.462,2973.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,bwd,False,1.854,1928.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,bwd,True,1.854,1928.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,fwd,False,0.136,1265.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,bwd,False,0.492,909.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,bwd,True,0.477,937.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.190,1809.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.819,1090.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.799,1117.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,fwd,False,0.078,2193.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,bwd,False,0.304,1470.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,bwd,True,0.304,1470.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.114,3026.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.443,2019.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.441,2027.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,fwd,False,0.086,1991.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,bwd,False,0.397,1125.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,bwd,True,0.396,1127.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.127,2701.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.557,1603.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.558,1601.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask.png new file mode 100644 index 00000000..1c04da20 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask_det_overhead.png new file mode 100644 index 00000000..6cb653e3 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left.png new file mode 100644 index 00000000..0f005ab9 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left_det_overhead.png new file mode 100644 index 00000000..6de346a4 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_20260424_100750.csv b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_20260424_100750.csv new file mode 100644 index 00000000..987e0f0c --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_20260424_100750.csv @@ -0,0 +1,91 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,fwd,False,21.692,1622.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,bwd,False,71.756,1226.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,bwd,True,82.343,1068.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,fwd,False,43.056,1634.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,bwd,False,139.364,1262.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,bwd,True,159.976,1100.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,fwd,False,16.012,2197.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,bwd,False,52.155,1687.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,bwd,True,51.781,1699.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,fwd,False,31.655,2223.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,bwd,False,96.384,1825.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,bwd,True,97.941,1796.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,fwd,False,17.201,2046.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,bwd,False,61.716,1425.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,bwd,True,61.437,1432.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,fwd,False,33.996,2070.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,bwd,False,118.844,1480.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,bwd,True,118.852,1480.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,fwd,False,5.150,1708.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,bwd,False,16.554,1328.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,bwd,True,19.007,1157.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,fwd,False,10.369,1697.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,bwd,False,32.014,1374.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,bwd,True,38.480,1143.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,fwd,False,4.084,2154.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,bwd,False,13.476,1632.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,bwd,True,13.439,1636.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,fwd,False,7.773,2263.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,bwd,False,23.960,1836.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,bwd,True,23.444,1876.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,fwd,False,4.400,1999.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,bwd,False,15.315,1436.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,bwd,True,15.408,1427.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,fwd,False,8.365,2103.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,bwd,False,28.309,1554.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,bwd,True,28.551,1540.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,fwd,False,1.397,1574.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,bwd,False,4.127,1332.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,bwd,True,4.754,1157.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,fwd,False,2.483,1771.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,bwd,False,7.831,1404.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,bwd,True,9.306,1182.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,fwd,False,1.074,2047.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,bwd,False,3.781,1454.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,bwd,True,3.777,1456.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,fwd,False,1.962,2241.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,bwd,False,6.206,1772.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,bwd,True,6.209,1771.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,fwd,False,1.156,1902.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,bwd,False,4.272,1287.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,bwd,True,4.270,1288.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,fwd,False,2.115,2080.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,bwd,False,7.352,1495.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,bwd,True,7.353,1495.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,fwd,False,0.391,1408.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,bwd,False,1.198,1148.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,bwd,True,1.352,1017.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.687,1599.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,bwd,False,2.037,1350.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,bwd,True,2.339,1175.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,fwd,False,0.305,1805.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,bwd,False,1.167,1178.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,bwd,True,1.164,1181.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.525,2093.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,bwd,False,1.743,1577.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,bwd,True,1.743,1577.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,fwd,False,0.326,1685.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,bwd,False,1.305,1054.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,bwd,True,1.312,1048.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.565,1944.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,bwd,False,2.046,1344.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,bwd,True,2.056,1337.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,fwd,False,0.125,1097.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,bwd,False,0.389,885.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,bwd,True,0.431,797.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.192,1433.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.611,1124.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.686,1001.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,fwd,False,0.097,1420.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,bwd,False,0.414,831.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,bwd,True,0.413,832.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.146,1881.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.566,1214.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.565,1216.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,fwd,False,0.103,1330.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,bwd,False,0.461,745.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,bwd,True,0.461,746.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.157,1754.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.643,1069.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.642,1071.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask.png new file mode 100644 index 00000000..c1fbe9b6 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask_det_overhead.png new file mode 100644 index 00000000..f35d16e1 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left.png new file mode 100644 index 00000000..caf946b0 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left_det_overhead.png new file mode 100644 index 00000000..3561cc30 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_20260424_100757.csv b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_20260424_100757.csv new file mode 100644 index 00000000..655b74f1 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_20260424_100757.csv @@ -0,0 +1,91 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,fwd,False,18.014,1953.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,bwd,False,59.024,1490.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,bwd,True,68.488,1284.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,fwd,False,38.317,1836.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,bwd,False,123.272,1427.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,bwd,True,142.081,1238.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,fwd,False,11.319,3108.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,bwd,False,41.328,2128.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,bwd,True,42.165,2086.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,fwd,False,23.072,3050.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,bwd,False,80.754,2178.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,bwd,True,81.705,2153.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,fwd,False,12.632,2785.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,bwd,False,55.340,1590.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,bwd,True,55.204,1593.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,fwd,False,25.005,2814.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,bwd,False,106.358,1654.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,bwd,True,105.826,1662.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,fwd,False,4.593,1915.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,bwd,False,14.593,1507.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,bwd,True,16.462,1336.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,fwd,False,8.734,2014.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,bwd,False,28.638,1536.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,bwd,True,32.449,1355.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,fwd,False,2.880,3054.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,bwd,False,10.805,2035.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,bwd,True,10.823,2032.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,fwd,False,5.529,3182.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,bwd,False,20.203,2177.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,bwd,True,20.296,2167.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,fwd,False,3.225,2727.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,bwd,False,14.072,1563.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,bwd,True,14.074,1563.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,fwd,False,6.089,2889.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,bwd,False,26.041,1689.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,bwd,True,25.719,1710.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,fwd,False,1.185,1856.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,bwd,False,3.832,1435.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,bwd,True,4.298,1279.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,fwd,False,2.181,2017.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,bwd,False,7.181,1531.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,bwd,True,7.932,1386.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,fwd,False,0.759,2897.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,bwd,False,3.087,1781.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,bwd,True,3.088,1780.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,fwd,False,1.392,3160.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,bwd,False,5.209,2111.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,bwd,True,5.198,2115.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,fwd,False,0.856,2570.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,bwd,False,3.917,1404.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,bwd,True,3.915,1404.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,fwd,False,1.532,2870.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,bwd,False,6.723,1636.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,bwd,True,6.716,1637.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,fwd,False,0.323,1703.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,bwd,False,1.101,1248.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,bwd,True,1.220,1127.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.579,1900.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,bwd,False,1.924,1429.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,bwd,True,2.195,1252.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,fwd,False,0.217,2532.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,bwd,False,0.985,1396.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,bwd,True,0.986,1395.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.368,2988.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,bwd,False,1.495,1839.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,bwd,True,1.494,1840.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,fwd,False,0.242,2272.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,bwd,False,1.203,1143.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,bwd,True,1.203,1143.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.408,2693.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,bwd,False,1.871,1469.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,bwd,True,1.871,1469.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,fwd,False,0.100,1373.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,bwd,False,0.356,967.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,bwd,True,0.388,885.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.152,1805.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.567,1213.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.634,1083.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,fwd,False,0.070,1954.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 
+llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,bwd,False,0.364,944.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,bwd,True,0.365,942.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.101,2728.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.494,1390.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.495,1387.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,fwd,False,0.079,1751.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,bwd,False,0.424,811.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,bwd,True,0.424,810.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.113,2443.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.591,1164.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.591,1163.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask.png new file mode 100644 index 00000000..dcad4c50 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask_det_overhead.png new file mode 100644 index 00000000..28da7dce Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left.png new file mode 100644 index 00000000..2746fa0d Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left_det_overhead.png new file mode 100644 index 00000000..ca1b3057 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_20260424_095758.csv b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_20260424_095758.csv new file mode 100644 index 00000000..c53bcf74 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_20260424_095758.csv @@ -0,0 +1,16 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version 
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,fwd,False,0.437,1415.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,bwd,False,1.199,1290.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,bwd,True,1.338,1156.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,fwd,False,1.844,1589.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,bwd,False,5.225,1403.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,bwd,True,6.410,1143.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,fwd,False,2.959,1706.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,bwd,False,8.952,1410.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,bwd,True,11.175,1130.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,fwd,False,8.657,1731.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,bwd,False,28.505,1314.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,bwd,True,35.620,1052.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,fwd,False,14.403,1611.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,bwd,False,43.187,1343.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,bwd,True,53.136,1092.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask.png b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask.png new file mode 100644 index 00000000..8d386e6d Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask_det_overhead.png new file mode 100644 index 00000000..e31f8f26 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_20260424_095719.csv b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_20260424_095719.csv new file mode 100644 index 00000000..e4325010 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_20260424_095719.csv @@ -0,0 +1,16 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,fwd,False,0.363,1704.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,bwd,False,1.089,1420.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,bwd,True,1.213,1275.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,fwd,False,1.551,1890.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,bwd,False,4.796,1528.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,bwd,True,5.838,1255.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,fwd,False,2.569,1966.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,bwd,False,8.245,1531.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,bwd,True,10.058,1255.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,fwd,False,7.623,1965.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,bwd,False,25.013,1497.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,bwd,True,30.447,1230.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,fwd,False,12.044,1926.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,bwd,False,38.088,1523.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,bwd,True,42.066,1379.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask.png b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask.png new file mode 100644 index 00000000..1992f9d0 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask_det_overhead.png new file mode 100644 index 00000000..c0b0fbd3 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_20260424_095249.csv b/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_20260424_095249.csv new file mode 100644 index 00000000..cc7f3cff --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_20260424_095249.csv @@ -0,0 +1,6 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +qwen35,qwen35,cudnn,bfloat16,top_left,1,32768,32768,32,2,256,256,fwd,False,10.148,1734.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +qwen35,qwen35,cudnn,bfloat16,top_left,1,16384,16384,32,2,256,256,fwd,False,2.438,1804.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +qwen35,qwen35,cudnn,bfloat16,top_left,1,8192,8192,32,2,256,256,fwd,False,0.644,1708.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 +qwen35,qwen35,cudnn,bfloat16,top_left,1,4096,4096,32,2,256,256,fwd,False,0.191,1441.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 
+qwen35,qwen35,cudnn,bfloat16,top_left,1,2048,2048,32,2,256,256,fwd,False,0.061,1126.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000 diff --git a/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_top_left.png b/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_top_left.png new file mode 100644 index 00000000..37e45c80 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_20260424_095247.csv b/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_20260424_095247.csv new file mode 100644 index 00000000..63464f51 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_20260424_095247.csv @@ -0,0 +1,6 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +qwen35,qwen35,cudnn,bfloat16,top_left,1,32768,32768,32,2,256,256,fwd,False,8.394,2096.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +qwen35,qwen35,cudnn,bfloat16,top_left,1,16384,16384,32,2,256,256,fwd,False,2.132,2063.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +qwen35,qwen35,cudnn,bfloat16,top_left,1,8192,8192,32,2,256,256,fwd,False,0.560,1965.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +qwen35,qwen35,cudnn,bfloat16,top_left,1,4096,4096,32,2,256,256,fwd,False,0.158,1743.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 +qwen35,qwen35,cudnn,bfloat16,top_left,1,2048,2048,32,2,256,256,fwd,False,0.053,1307.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000 diff --git a/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_top_left.png b/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_top_left.png new file mode 100644 index 00000000..bdab2e02 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_top_left.png differ diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_20260424_095743.csv b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_20260424_095743.csv new file mode 100644 index 00000000..2e064ca9 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_20260424_095743.csv @@ -0,0 +1,16 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,fwd,False,0.803,1552.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,bwd,False,2.236,1393.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,bwd,True,2.785,1118.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,fwd,False,3.383,1783.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,bwd,False,10.638,1417.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,bwd,True,13.730,1098.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,fwd,False,13.191,1666.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,bwd,False,41.653,1319.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,bwd,True,52.261,1051.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,fwd,False,28.902,1657.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,bwd,False,93.715,1278.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,bwd,True,117.248,1021.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,fwd,False,76.372,1533.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,bwd,False,233.302,1254.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,bwd,True,287.864,1017.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200 diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask.png b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask.png new file mode 100644 index 00000000..a77ce366 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask_det_overhead.png new file mode 100644 index 00000000..ad2a883b Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask_det_overhead.png differ diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_20260424_095741.csv b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_20260424_095741.csv new file mode 100644 index 00000000..4b37e548 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_20260424_095741.csv @@ -0,0 +1,16 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,fwd,False,0.678,1836.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,bwd,False,2.139,1456.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,bwd,True,2.631,1184.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,fwd,False,3.002,2009.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,bwd,False,9.883,1525.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,bwd,True,12.048,1251.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,fwd,False,11.213,1960.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,bwd,False,37.545,1464.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,bwd,True,45.146,1217.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,fwd,False,25.686,1865.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,bwd,False,88.194,1358.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,bwd,True,102.918,1163.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,fwd,False,64.212,1823.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,bwd,False,214.474,1364.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 +wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,bwd,True,258.728,1131.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200 diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask.png b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask.png new file mode 100644 index 00000000..80014b77 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask_det_overhead.png new file mode 100644 index 00000000..44df0932 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask_det_overhead.png differ diff --git a/cmake/cuDNN.cmake b/cmake/cuDNN.cmake index 61a3cba9..a6998046 100644 --- a/cmake/cuDNN.cmake +++ b/cmake/cuDNN.cmake @@ -110,6 +110,7 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9) find_cudnn_library(cudnn_adv OPTIONAL) find_cudnn_library(cudnn_engines_precompiled OPTIONAL) find_cudnn_library(cudnn_heuristic OPTIONAL) + find_cudnn_library(cudnn_ext OPTIONAL) target_link_libraries( CUDNN::cudnn_all diff --git a/include/cudnn_frontend/cudnn_interface.h b/include/cudnn_frontend/cudnn_interface.h index 06121ddb..e20c9599 100644 --- a/include/cudnn_frontend/cudnn_interface.h +++ b/include/cudnn_frontend/cudnn_interface.h @@ -69,6 +69,17 @@ create_cudnn_tensor( tensor_builder.setVectorCountAndDimension(props->get_vector_count(), props->get_vector_dimension()); } + // Set compile-time constant value before build (if present) + if (props->get_has_compile_time_constant()) { + auto const_value = props->get_compile_time_constant(); + if (const_value.has_value()) { + std::visit([&tensor_builder](auto&& val) { tensor_builder.setConstValue(val); }, *const_value); + + CUDNN_FE_LOG_LABEL_ENDL("INFO: Compile-time constant value set for tensor '" << props->get_name() + << "'"); + } + } + if (auto ragged_offset_props = props->get_ragged_offset()) { CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(ragged_offset_props, tensors, potential_uid, used_uids)); tensor_builder.setRaggedOffset(tensors.at(ragged_offset_props->get_uid())); diff --git a/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d128_fprop_kernel.h b/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d128_fprop_kernel.h index 175986cb..ee3664a8 100644 --- a/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d128_fprop_kernel.h +++ 
b/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d128_fprop_kernel.h @@ -69,8 +69,6 @@ inline __device__ void fastDivMod(const FastDivisor_t &d, uint32_t val, mod = val - div * d.val; } -__device__ __inline__ void cfence() {} - inline __device__ char *get_smem_loc_epilogue_swizzle_128b( char *smem_addr, int local_block_id, int tid, int local_row, int column, size_t element_size, int block_size, int row_per_tile) { @@ -1765,7 +1763,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn #pragma unroll for (int i = 0; i < BMM1_TILE_N; i += 2) { - cfence(); if (i - kConvertPipeCount == 32) { sttm_32dp32bit_x16(tmem_fp16_S, ®_12_0[0]); } @@ -1780,12 +1777,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn ®_12_0[((i - kConvertPipeCount) / 2) % 32], reinterpret_cast(®_8_0[i - kConvertPipeCount])); } - cfence(); reinterpret_cast(reg_8_0[i + 0]) = exp2f(reinterpret_cast(reg_8_0[i + 0])); - cfence(); if (i + kFmaPipeCount < BMM1_TILE_N) { float2 in = make_float2( @@ -1841,7 +1836,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn named_barrier_arrive(SOFTMAX_BARRIER + 1, 256); named_barrier_wait(SOFTMAX_BARRIER, 256); } - cfence(); named_barrier_wait(SOFTMAX_BARRIER + 2 + softmax_gid, 128); } bmm_mbar_state ^= 1; @@ -1947,7 +1941,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn #pragma unroll for (int i = 0; i < BMM1_TILE_N; i += 2) { - cfence(); if (i - kConvertPipeCount == 32) { sttm_32dp32bit_x16(tmem_fp16_S, ®_12_0[0]); } @@ -1962,12 +1955,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn ®_12_0[((i - kConvertPipeCount) / 2) % 32], reinterpret_cast(®_8_0[i - kConvertPipeCount])); } - cfence(); reinterpret_cast(reg_8_0[i + 0]) = exp2f(reinterpret_cast(reg_8_0[i + 0])); - cfence(); if (i + kFmaPipeCount < BMM1_TILE_N) { float2 in = make_float2( @@ -2023,7 +2014,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn named_barrier_arrive(SOFTMAX_BARRIER + 1, 256); named_barrier_wait(SOFTMAX_BARRIER, 256); } - cfence(); named_barrier_wait(SOFTMAX_BARRIER + 2 + softmax_gid, 128); } bmm_mbar_state ^= 1; @@ -2232,12 +2222,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn sttm_step * 8, &fp32_O[8 * sttm_step]); } - cfence(); } } fence_view_async_tmem_store(); arrive_barrier(cast_smem_ptr_to_uint(bmm_ready_mbar)); - cfence(); } stat_mbar_state ^= 1; bmm_mbar_state ^= 1; @@ -2372,7 +2360,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn sts_128(cast_smem_ptr_to_uint(smem_loc), reinterpret_cast(®_O[i * 4])); } - cfence(); uint64_t *tma_o_full_mbar = sub_tile_id == 0 ? 
&(shared_storage.tma_o_0_full_mbar[block]) @@ -2380,7 +2367,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn fence_view_async_shared(); arrive_barrier(cast_smem_ptr_to_uint(tma_o_full_mbar)); } - cfence(); } stat_mbar_state ^= 1; bmm_mbar_state ^= 1; diff --git a/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d64_fprop_kernel.h b/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d64_fprop_kernel.h index c1aa5278..33b9ef2d 100644 --- a/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d64_fprop_kernel.h +++ b/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d64_fprop_kernel.h @@ -69,8 +69,6 @@ inline __device__ void fastDivMod(const FastDivisor_t &d, uint32_t val, mod = val - div * d.val; } -__device__ __inline__ void cfence() {} - inline __device__ char *get_smem_loc_epilogue_swizzle_128b( char *smem_addr, int local_block_id, int tid, int local_row, int column, size_t element_size, int block_size, int row_per_tile) { @@ -1766,7 +1764,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn #pragma unroll for (int i = 0; i < BMM1_TILE_N; i += 2) { - cfence(); if (i - kConvertPipeCount == 32) { sttm_32dp32bit_x16(tmem_fp16_S, ®_12_0[0]); } @@ -1781,12 +1778,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn ®_12_0[((i - kConvertPipeCount) / 2) % 32], reinterpret_cast(®_8_0[i - kConvertPipeCount])); } - cfence(); reinterpret_cast(reg_8_0[i + 0]) = exp2f(reinterpret_cast(reg_8_0[i + 0])); - cfence(); if (i + kFmaPipeCount < BMM1_TILE_N) { float2 in = make_float2( @@ -1842,7 +1837,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn named_barrier_arrive(SOFTMAX_BARRIER + 1, 256); named_barrier_wait(SOFTMAX_BARRIER, 256); } - cfence(); named_barrier_wait(SOFTMAX_BARRIER + 2 + softmax_gid, 128); } bmm_mbar_state ^= 1; @@ -1948,7 +1942,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn #pragma unroll for (int i = 0; i < BMM1_TILE_N; i += 2) { - cfence(); if (i - kConvertPipeCount == 32) { sttm_32dp32bit_x16(tmem_fp16_S, ®_12_0[0]); } @@ -1963,12 +1956,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn ®_12_0[((i - kConvertPipeCount) / 2) % 32], reinterpret_cast(®_8_0[i - kConvertPipeCount])); } - cfence(); reinterpret_cast(reg_8_0[i + 0]) = exp2f(reinterpret_cast(reg_8_0[i + 0])); - cfence(); if (i + kFmaPipeCount < BMM1_TILE_N) { float2 in = make_float2( @@ -2024,7 +2015,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn named_barrier_arrive(SOFTMAX_BARRIER + 1, 256); named_barrier_wait(SOFTMAX_BARRIER, 256); } - cfence(); named_barrier_wait(SOFTMAX_BARRIER + 2 + softmax_gid, 128); } bmm_mbar_state ^= 1; @@ -2234,12 +2224,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn sttm_step * 8, &fp32_O[8 * sttm_step]); } - cfence(); } } fence_view_async_tmem_store(); arrive_barrier(cast_smem_ptr_to_uint(bmm_ready_mbar)); - cfence(); } stat_mbar_state ^= 1; bmm_mbar_state ^= 1; @@ -2374,7 +2362,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn sts_128(cast_smem_ptr_to_uint(smem_loc), reinterpret_cast(®_O[i * 4])); } - cfence(); uint64_t *tma_o_full_mbar = sub_tile_id == 0 ? 
&(shared_storage.tma_o_0_full_mbar[block]) @@ -2382,7 +2369,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn fence_view_async_shared(); arrive_barrier(cast_smem_ptr_to_uint(tma_o_full_mbar)); } - cfence(); } stat_mbar_state ^= 1; bmm_mbar_state ^= 1; diff --git a/include/cudnn_frontend/graph_interface.h b/include/cudnn_frontend/graph_interface.h index 1a0e3d7e..8ce19101 100644 --- a/include/cudnn_frontend/graph_interface.h +++ b/include/cudnn_frontend/graph_interface.h @@ -22,6 +22,8 @@ #include "node/resample.h" #include "node/reshape.h" #include "node/slice.h" +#include "node/transpose.h" +// #include "node/scaled_dot_product_attention.h" #include "node/scaled_dot_product_flash_attention.h" #include "node/sdpa_fp8_bwd.h" #include "node/block_scale_quantize.h" @@ -194,6 +196,8 @@ class Graph : public ICudnn, public INode { tensor_to_pointer_map.emplace(uid, int64_t_value_ptr); } else if (float *float_value_ptr = std::get_if(&value)) { tensor_to_pointer_map.emplace(uid, float_value_ptr); + } else if (double *double_value_ptr = std::get_if(&value)) { + tensor_to_pointer_map.emplace(uid, double_value_ptr); } else { RETURN_CUDNN_FRONTEND_ERROR_IF( true, error_code_t::INVALID_VARIANT_PACK, "Unexpected type for pass by value tensor."); @@ -1450,12 +1454,30 @@ class Graph : public ICudnn, public INode { std::unordered_map tensor_to_pass_by_value; CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); - j["pass_by_values"] = tensor_to_pass_by_value; + + // Convert pass_by_values to JSON (unordered_map with numeric keys needs manual conversion) + json pass_by_values_json = json::object(); + for (const auto &[uid, variant_value] : tensor_to_pass_by_value) { + json variant_json; + variant_json = variant_value; + pass_by_values_json[std::to_string(uid)] = variant_json; + } + j["pass_by_values"] = pass_by_values_json; std::unordered_map>> workspace_modifications; int64_t workspace_offset = 0; CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset)); - j["workspace_modifications"] = workspace_modifications; + + // Convert workspace_modifications to JSON (nlohmann::json doesn't support std::tuple directly) + json workspace_modifications_json = json::object(); + for (const auto &[uid, tuple_value] : workspace_modifications) { + json tuple_json = json::array(); + tuple_json.push_back(std::get<0>(tuple_value)); + tuple_json.push_back(std::get<1>(tuple_value)); + tuple_json.push_back(std::get<2>(tuple_value)); + workspace_modifications_json[std::to_string(uid)] = tuple_json; + } + j["workspace_modifications"] = workspace_modifications_json; j["variant_pack_replacements"] = variant_pack_replacements; @@ -1500,9 +1522,28 @@ class Graph : public ICudnn, public INode { variant_pack_uids = j["variant_pack_uids"].get>(); - deserialized_pass_by_value = j["pass_by_values"]; + // Deserialize pass_by_values from JSON + if (j.contains("pass_by_values")) { + auto pass_by_values_json = j["pass_by_values"]; + for (auto it = pass_by_values_json.begin(); it != pass_by_values_json.end(); ++it) { + uid_t uid = std::stoll(it.key()); + pass_by_values_t value = it.value().get(); + deserialized_pass_by_value[uid] = value; + } + } - deserialized_workspace_modifications = j["workspace_modifications"]; + // Deserialize workspace_modifications from JSON + if (j.contains("workspace_modifications")) { + auto workspace_modifications_json = j["workspace_modifications"]; + for (auto it = 
workspace_modifications_json.begin(); it != workspace_modifications_json.end(); ++it) { + uid_t uid = std::stoll(it.key()); + auto tuple_json = it.value(); + auto tuple_value = std::make_tuple(tuple_json[0].get(), + tuple_json[1].get(), + tuple_json[2].get>()); + deserialized_workspace_modifications[uid] = tuple_value; + } + } variant_pack_replacements = j["variant_pack_replacements"]; @@ -1572,6 +1613,25 @@ class Graph : public ICudnn, public INode { std::shared_ptr tensor(Tensor_attributes const &tensor); + // Overloaded tensor() methods for compile-time constants + std::shared_ptr + tensor(float const &scalar, ScalarType scalar_type); + + std::shared_ptr + tensor(half const &scalar, ScalarType scalar_type); + + std::shared_ptr + tensor(nv_bfloat16 const &scalar, ScalarType scalar_type); + + std::shared_ptr + tensor(int32_t const &scalar, ScalarType scalar_type); + + std::shared_ptr + tensor(int64_t const &scalar, ScalarType scalar_type); + + std::shared_ptr + tensor(double const &scalar, ScalarType scalar_type); + std::shared_ptr tensor_like(std::shared_ptr const &tensor, std::string const &name = std::string{}); @@ -1736,6 +1796,8 @@ class Graph : public ICudnn, public INode { std::shared_ptr slice(std::shared_ptr, Slice_attributes); + std::shared_ptr transpose(std::shared_ptr, Transpose_attributes); + std::array, 2> block_scale_quantize(std::shared_ptr, Block_scale_quantize_attributes); @@ -2353,6 +2415,12 @@ class Graph : public ICudnn, public INode { CHECK_TENSORS(slice_attributes); FILL_GLOBAL_IO_TENSOR_MAP(slice_attributes); sub_nodes.emplace_back(std::make_unique(std::move(slice_attributes), context)); + } else if (tag == "TRANSPOSE") { + auto transpose_attributes = j_sub_node.get(); + CHECK_TENSORS(transpose_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(transpose_attributes); + sub_nodes.emplace_back( + std::make_unique(std::move(transpose_attributes), context)); } else if (tag == "RESAMPLE") { auto resample_attributes = j_sub_node.get(); CHECK_TENSORS(resample_attributes); @@ -2686,6 +2754,49 @@ Graph::tensor(Tensor_attributes const &tensor) { return tensor_ptr; } +// Overloaded tensor() methods for compile-time constants +inline std::shared_ptr +Graph::tensor(float const &scalar, ScalarType scalar_type) { + auto tensor_ptr = std::make_shared(scalar, scalar_type); + full_graph_inputs.emplace(tensor_ptr); + return tensor_ptr; +} + +inline std::shared_ptr +Graph::tensor(half const &scalar, ScalarType scalar_type) { + auto tensor_ptr = std::make_shared(scalar, scalar_type); + full_graph_inputs.emplace(tensor_ptr); + return tensor_ptr; +} + +inline std::shared_ptr +Graph::tensor(nv_bfloat16 const &scalar, ScalarType scalar_type) { + auto tensor_ptr = std::make_shared(scalar, scalar_type); + full_graph_inputs.emplace(tensor_ptr); + return tensor_ptr; +} + +inline std::shared_ptr +Graph::tensor(int32_t const &scalar, ScalarType scalar_type) { + auto tensor_ptr = std::make_shared(scalar, scalar_type); + full_graph_inputs.emplace(tensor_ptr); + return tensor_ptr; +} + +inline std::shared_ptr +Graph::tensor(int64_t const &scalar, ScalarType scalar_type) { + auto tensor_ptr = std::make_shared(scalar, scalar_type); + full_graph_inputs.emplace(tensor_ptr); + return tensor_ptr; +} + +inline std::shared_ptr +Graph::tensor(double const &scalar, ScalarType scalar_type) { + auto tensor_ptr = std::make_shared(scalar, scalar_type); + full_graph_inputs.emplace(tensor_ptr); + return tensor_ptr; +} + inline error_t Graph::query_tensor_attributes_of_uid(int64_t const uid, Tensor_attributes &tensor) const 
{ for (auto const &o_tensor : full_graph_outputs) { @@ -3363,6 +3474,15 @@ Graph::slice(std::shared_ptr input, Slice_attributes attribut return Y; } +inline std::shared_ptr +Graph::transpose(std::shared_ptr input, Transpose_attributes attributes) { + attributes.inputs[Transpose_attributes::input_names::X] = input; + auto Y = attributes.outputs[Transpose_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return Y; +} + inline std::array, 2> Graph::block_scale_quantize(std::shared_ptr x, Block_scale_quantize_attributes attributes) { // Set outputs diff --git a/include/cudnn_frontend/graph_properties.h b/include/cudnn_frontend/graph_properties.h index dee506a2..71ae8962 100644 --- a/include/cudnn_frontend/graph_properties.h +++ b/include/cudnn_frontend/graph_properties.h @@ -1,6 +1,8 @@ #pragma once +#include +#include #include #include #include @@ -17,8 +19,31 @@ namespace cudnn_frontend { namespace graph { +inline std::optional +get_rescale_threshold_from_env() { + auto const* env_value = std::getenv("CUDNN_RESCALE_THRESHOLD"); + if (env_value == nullptr || env_value[0] == '\0') { + return std::nullopt; + } + + errno = 0; + char* end = nullptr; + auto value = std::strtof(env_value, &end); + if (env_value == end || (end != nullptr && *end != '\0') || errno == ERANGE) { + return std::nullopt; + } + + return value; +} + using managed_backend_descriptor_t = std::vector; +// Enum to distinguish between runtime parameters and compile-time constants +enum class ScalarType { + RUNTIME_PARAM, // Value provided at execution time (can change) + COMPILE_TIME_CONST // Value baked into graph (fixed, optimizable) +}; + // simple structure to hold all properties of a tensor. // Each property has a getter setter. class Tensor_attributes { @@ -31,7 +56,7 @@ class Tensor_attributes { // In approach 1, users provide a value to embed into the graph. // In approach 2, users set is_pass_by_value boolean and then pass a pointer to scalar value with execute() API. // A closed set of types that are allowed to be passed by value. 
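// Editor's illustration (hedged sketch, not part of the original patch), combining the
// ScalarType enum above with the Graph::tensor overloads added elsewhere in this change:
//
//   // Approach 1: the scalar is baked into the graph as a compile-time constant.
//   auto alpha = graph.tensor(2.0f, ScalarType::COMPILE_TIME_CONST);
//
//   // Approach 2: a pass-by-value scalar whose host pointer is bound at execute().
//   auto beta = graph.tensor(Tensor_attributes()
//                                .set_name("beta")
//                                .set_dim({1})
//                                .set_stride({1})
//                                .set_data_type(DataType_t::FLOAT)
//                                .set_is_pass_by_value(true));
//   float beta_value = 0.5f;
//   variant_pack[beta->get_uid()] = &beta_value;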
- using pass_by_values_t = std::variant; + using pass_by_values_t = std::variant; error_t validate() const { @@ -52,6 +77,22 @@ class Tensor_attributes { error_code_t::ATTRIBUTE_NOT_SET, "Tensor '" + name + "' can't be a fused scalar and not a pass_by_value tensor at the same time."); + // Validate compile-time constant constraints + RETURN_CUDNN_FRONTEND_ERROR_IF( + has_compile_time_constant && !is_pass_by_value, + error_code_t::ATTRIBUTE_NOT_SET, + "Tensor '" + name + "' with compile-time constant must have is_pass_by_value=true."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(has_compile_time_constant && is_virtual, + error_code_t::ATTRIBUTE_NOT_SET, + "Tensor '" + name + "' can't be both compile-time constant and virtual."); + + // Can't have both compile-time constant and runtime parameter + RETURN_CUDNN_FRONTEND_ERROR_IF( + has_compile_time_constant && pass_by_value.has_value(), + error_code_t::ATTRIBUTE_NOT_SET, + "Tensor '" + name + "' can't have both compile-time constant and runtime pass_by_value."); + return {error_code_t::OK, ""}; } @@ -68,6 +109,10 @@ class Tensor_attributes { std::optional pass_by_value = std::nullopt; bool is_pass_by_value = false; + // Compile-time constant support (distinct from pass_by_value, which is for runtime parameters) + bool has_compile_time_constant = false; + std::optional compile_time_constant_value = std::nullopt; + TensorReordering_t reordering_type = TensorReordering_t::NONE; uid_t uid = 0; bool uid_assigned = false; @@ -135,6 +180,86 @@ class Tensor_attributes { data_type = DataType_t::INT64; } + Tensor_attributes(double const& scalar) { + pass_by_value = scalar; + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::DOUBLE; + } + + // Constructors with ScalarType for compile-time constant or runtime parameter control + Tensor_attributes(float const& scalar, ScalarType scalar_type) { + if (scalar_type == ScalarType::COMPILE_TIME_CONST) { + compile_time_constant_value = scalar; + has_compile_time_constant = true; + } else { + pass_by_value = scalar; + } + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::FLOAT; + } + + Tensor_attributes(half const& scalar, ScalarType scalar_type) { + if (scalar_type == ScalarType::COMPILE_TIME_CONST) { + compile_time_constant_value = scalar; + has_compile_time_constant = true; + } else { + pass_by_value = scalar; + } + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::HALF; + } + + Tensor_attributes(nv_bfloat16 const& scalar, ScalarType scalar_type) { + if (scalar_type == ScalarType::COMPILE_TIME_CONST) { + compile_time_constant_value = scalar; + has_compile_time_constant = true; + } else { + pass_by_value = scalar; + } + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::BFLOAT16; + } + + Tensor_attributes(int32_t const& scalar, ScalarType scalar_type) { + if (scalar_type == ScalarType::COMPILE_TIME_CONST) { + compile_time_constant_value = scalar; + has_compile_time_constant = true; + } else { + pass_by_value = scalar; + } + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::INT32; + } + + Tensor_attributes(int64_t const& scalar, ScalarType scalar_type) { + if (scalar_type == ScalarType::COMPILE_TIME_CONST) { + compile_time_constant_value = scalar; + has_compile_time_constant = true; + } else { + pass_by_value = scalar; + } + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::INT64; + } + + Tensor_attributes(double const& scalar, ScalarType scalar_type) { + if (scalar_type == 
ScalarType::COMPILE_TIME_CONST) { + compile_time_constant_value = scalar; + has_compile_time_constant = true; + } else { + pass_by_value = scalar; + } + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::DOUBLE; + } + std::string get_name() const { return name; @@ -225,6 +350,35 @@ class Tensor_attributes { return *this; } + // Compile-time constant accessors + bool + get_has_compile_time_constant() const { + return has_compile_time_constant; + } + + std::optional + get_compile_time_constant() const { + return compile_time_constant_value; + } + + auto + set_compile_time_constant(pass_by_values_t const& value) -> Tensor_attributes& { + compile_time_constant_value = value; + has_compile_time_constant = true; + if (!is_pass_by_value) { + is_pass_by_value = true; + } + return *this; + } + + auto + set_as_runtime_parameter() -> Tensor_attributes& { + is_pass_by_value = true; + has_compile_time_constant = false; + compile_time_constant_value = std::nullopt; + return *this; + } + TensorReordering_t get_reordering_type() const { return reordering_type; @@ -315,9 +469,15 @@ class Attributes { get_non_virtual_uids() const { std::vector non_virtual_uids; auto derived = static_cast(this); + + // Compile-time constants are excluded from the variant pack + auto should_be_in_variant_pack = [](std::shared_ptr const& tensor) { + return tensor && tensor->get_is_virtual() == false && !tensor->get_has_compile_time_constant(); + }; + if constexpr (std::is_same_v) { for (auto tensor : derived->inputs) { - if (tensor && tensor->get_is_virtual() == false) { + if (should_be_in_variant_pack(tensor)) { non_virtual_uids.push_back(tensor->get_uid()); if (auto ragged_offset = tensor->get_ragged_offset()) { non_virtual_uids.push_back(ragged_offset->get_uid()); @@ -327,7 +487,7 @@ class Attributes { } else { for (auto& [name, tensor] : derived->inputs) { (void)name; - if (tensor && tensor->get_is_virtual() == false) { + if (should_be_in_variant_pack(tensor)) { non_virtual_uids.push_back(tensor->get_uid()); if (auto ragged_offset = tensor->get_ragged_offset()) { non_virtual_uids.push_back(ragged_offset->get_uid()); @@ -338,7 +498,7 @@ class Attributes { for (auto& [name, tensor] : derived->outputs) { (void)name; - if (tensor && tensor->get_is_virtual() == false) { + if (should_be_in_variant_pack(tensor)) { non_virtual_uids.push_back(tensor->get_uid()); if (auto ragged_offset = tensor->get_ragged_offset()) { non_virtual_uids.push_back(ragged_offset->get_uid()); @@ -350,7 +510,7 @@ class Attributes { if constexpr (std::is_same_v || std::is_same_v) { for (auto& tensor : derived->peer_stats) { - if (tensor && tensor->get_is_virtual() == false) { + if (should_be_in_variant_pack(tensor)) { non_virtual_uids.push_back(tensor->get_uid()); if (auto ragged_offset = tensor->get_ragged_offset()) { non_virtual_uids.push_back(ragged_offset->get_uid()); @@ -423,13 +583,16 @@ class Attributes { set_compute_data_type(context.get_compute_data_type()); } - // Handle shape and stride inferencing for fused scalars. - // Pick number of dimensions from anyone of non-fused-scalar input/output tensors - // In case, all tensors are fused scalars, just keep them 1D. + // Infer shape and stride for fused scalars (runtime params and compile-time constants). + // Fused scalars expand to match the dimensionality of non-scalar tensors; if all are scalars, keep 1D. 
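+        // Editor's illustration: alongside a 4-D non-scalar peer tensor of dim {N, C, H, W},
+        // a fused scalar declared with dim = {1} is expanded below to dim = {1, 1, 1, 1} and
+        // stride = {1, 1, 1, 1}, so it broadcasts across every dimension of the peer tensor.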
+ auto is_fused_scalar = [](std::shared_ptr const& tensor) { + return tensor && (tensor->get_pass_by_value().has_value() || tensor->get_has_compile_time_constant()); + }; + int64_t number_of_dims = 1; if constexpr (std::is_same_v) { for (auto tensor : derived->inputs) { - if (tensor && (tensor->get_pass_by_value().has_value() == false)) { + if (tensor && !is_fused_scalar(tensor)) { number_of_dims = tensor->get_dim().size(); break; } @@ -437,7 +600,7 @@ class Attributes { } else { for (auto [name, tensor] : derived->inputs) { (void)name; - if (tensor && (tensor->get_pass_by_value().has_value() == false)) { + if (tensor && !is_fused_scalar(tensor)) { number_of_dims = tensor->get_dim().size(); break; } @@ -448,16 +611,17 @@ class Attributes { if (number_of_dims == 1) { for (auto [name, tensor] : derived->outputs) { (void)name; - if (tensor && (tensor->get_pass_by_value().has_value() == false)) { + if (tensor && !is_fused_scalar(tensor)) { number_of_dims = tensor->get_dim().size(); break; } } } + // Expand fused scalar dimensions to match the number of dims of non-scalar tensors if constexpr (std::is_same_v) { for (auto tensor : derived->inputs) { - if (tensor && tensor->get_pass_by_value().has_value()) { + if (is_fused_scalar(tensor)) { tensor->set_dim(std::vector(number_of_dims, 1)); tensor->set_stride(std::vector(number_of_dims, 1)); } @@ -465,7 +629,7 @@ class Attributes { } else { for (auto [name, tensor] : derived->inputs) { (void)name; - if (tensor && tensor->get_pass_by_value().has_value()) { + if (is_fused_scalar(tensor)) { tensor->set_dim(std::vector(number_of_dims, 1)); tensor->set_stride(std::vector(number_of_dims, 1)); } @@ -1419,13 +1583,21 @@ class Reshape_attributes : public Attributes { std::vector dim = {}; std::vector stride = {}; + ReshapeMode_t reshape_mode = ReshapeMode_t::VIEW_ONLY; public: enum class input_names { X }; std::unordered_map> inputs; enum class output_names { Y }; std::unordered_map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reshape_attributes, name, compute_data_type, inputs, outputs, dim, stride) + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reshape_attributes, + name, + compute_data_type, + inputs, + outputs, + dim, + stride, + reshape_mode) std::vector get_dim() const { @@ -1448,6 +1620,54 @@ class Reshape_attributes : public Attributes { stride = value; return *this; } + + ReshapeMode_t + get_reshape_mode() const { + return reshape_mode; + } + + auto + set_reshape_mode(ReshapeMode_t const& value) -> Reshape_attributes& { + reshape_mode = value; + return *this; + } +}; + +class Transpose_attributes : public Attributes { + friend class Attributes; + friend class TransposeNode; + friend class Graph; + + std::vector permutation; + + public: + std::string + get_name() const { + return name; + } + + auto + set_name(std::string const& value) -> Transpose_attributes& { + name = value; + return *this; + } + + enum class input_names { X }; + std::unordered_map> inputs; + enum class output_names { Y }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Transpose_attributes, name, compute_data_type, inputs, outputs, permutation) + + std::vector + get_permutation() const { + return permutation; + } + + auto + set_permutation(std::vector const& value) -> Transpose_attributes& { + permutation = value; + return *this; + } }; class Rmsnorm_attributes : public Attributes { @@ -2562,13 +2782,14 @@ class Slice_attributes : public Attributes { friend class INode; std::vector> slices; + std::vector slice_strides = {1}; public: enum class input_names { X }; 
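    // Editor's worked example (annotation, not part of the patch): for a 1-D input of
    // dim {10} with slices = {{2, 9}} and slice_strides = {3}, SliceNode computes the
    // output dim as ceil((9 - 2) / 3) = 3, selecting input indices 2, 5, and 8.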
std::unordered_map> inputs; enum class output_names { Y }; std::unordered_map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Slice_attributes, name, compute_data_type, inputs, outputs, slices) + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Slice_attributes, name, compute_data_type, inputs, outputs, slices, slice_strides) Slice_attributes& set_slices(std::vector> const value) { @@ -2576,6 +2797,12 @@ class Slice_attributes : public Attributes { return *this; } + Slice_attributes& + set_strides(std::vector const value) { + slice_strides = value; + return *this; + } + int64_t get_offset() const { auto& input = inputs.at(input_names::X); diff --git a/include/cudnn_frontend/node/concatenate.h b/include/cudnn_frontend/node/concatenate.h index 0c051186..5f62589b 100644 --- a/include/cudnn_frontend/node/concatenate.h +++ b/include/cudnn_frontend/node/concatenate.h @@ -29,9 +29,6 @@ class ConcatenateNode : public NodeCRTP { RETURN_CUDNN_FRONTEND_ERROR_IF(!attributes.axis.has_value(), error_code_t::ATTRIBUTE_NOT_SET, "Axis not set\n"); - RETURN_CUDNN_FRONTEND_ERROR_IF( - !attributes.in_place_index.has_value(), error_code_t::ATTRIBUTE_NOT_SET, "In-place index not set\n"); - auto X = attributes.inputs; RETURN_CUDNN_FRONTEND_ERROR_IF( @@ -116,12 +113,14 @@ class ConcatenateNode : public NodeCRTP { 1, &axis)); - int64_t in_place_index = attributes.in_place_index.value(); - _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(concatenate_operation->get_backend_descriptor(), - CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX, - CUDNN_TYPE_INT64, - 1, - &in_place_index)); + if (attributes.in_place_index.has_value()) { + int64_t in_place_index = attributes.in_place_index.value(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(concatenate_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX, + CUDNN_TYPE_INT64, + 1, + &in_place_index)); + } _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(concatenate_operation->get_backend_descriptor())); diff --git a/include/cudnn_frontend/node/reshape.h b/include/cudnn_frontend/node/reshape.h index 39b08c79..e127f65e 100644 --- a/include/cudnn_frontend/node/reshape.h +++ b/include/cudnn_frontend/node/reshape.h @@ -48,6 +48,19 @@ class ReshapeNode : public NodeCRTP { return {error_code_t::SHAPE_DEDUCTION_FAILED, "Reshape node output shape deduction failed"}; } + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Reshape_attributes::input_names::X); + auto const& input_data_type = X->second->get_data_type(); + if (y_tensor->get_data_type() == DataType_t::NOT_SET) { + y_tensor->set_data_type(input_data_type); + } else if (attributes.get_reshape_mode() == ReshapeMode_t::LOGICAL) { + // Lexicographic reshape preserves element type; reject inconsistent metadata. + // VIEW_ONLY paths (e.g. SDPA backward) may set Y dtype after reshape to match a consumer. 
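+        // Editor's note: the check below only fires for ReshapeMode_t::LOGICAL; for
+        // example, a LOGICAL reshape of an FP16 X into an FP32 Y is rejected here,
+        // while a VIEW_ONLY reshape may declare a different Y dtype for its consumer.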
+ RETURN_CUDNN_FRONTEND_ERROR_IF( + y_tensor->get_data_type() != input_data_type, + error_code_t::INVALID_VALUE, + "Output and input tensor data types must match for LOGICAL reshape operation."); + } + return {error_code_t::OK, ""}; } @@ -75,7 +88,16 @@ class ReshapeNode : public NodeCRTP { CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &x_desc)); - +#if (CUDNN_VERSION >= 92200) + // Set reshape mode + cudnnBackendReshapeMode_t cudnn_reshape_mode; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.get_reshape_mode(), cudnn_reshape_mode)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reshape_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESHAPE_MODE, + CUDNN_TYPE_RESHAPE_MODE, + 1, + &cudnn_reshape_mode)); +#endif // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Reshape_attributes::output_names::Y); auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index 71edf64d..ee1d0b27 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -512,8 +512,16 @@ class SDPANodeBase : public NodeCRTP { #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { - j = attributes; - if (attributes.mma_core_mode == DataType_t::FP8_E4M3 || attributes.mma_core_mode == DataType_t::FP8_E5M2) { + j = attributes; + j["is_mxfp8"] = is_mxfp8_scaling(); + j["unfuse_fma"] = attributes.unfuse_fma; + if (auto const rescale_threshold = get_rescale_threshold_from_env(); rescale_threshold.has_value()) { + j["rescale_threshold"] = rescale_threshold.value(); + } + if (is_mxfp8_scaling()) { + j.update(R"({"tag": "SDPA_MXFP8_FWD"})"_json); + } else if (attributes.mma_core_mode == DataType_t::FP8_E4M3 || + attributes.mma_core_mode == DataType_t::FP8_E5M2) { j.update(R"({"tag": "SDPA_FP8_FWD"})"_json); } else { j.update(R"({"tag": "SDPA"})"_json); @@ -1296,11 +1304,6 @@ class CompositeSDPABackwardNode : public NodeCRTP { error_code_t::GRAPH_NOT_SUPPORTED, "For cuDNN version below 9.6.0, group-query attention with raggged offset is not supported"); - // TODO add version check once fixed - RETURN_CUDNN_FRONTEND_ERROR_IF(prop_major == 10 && is_rng, - error_code_t::GRAPH_NOT_SUPPORTED, - "Dropout RNG dump is not supported for SM Major version 10"); - // TODO add version check once fixed RETURN_CUDNN_FRONTEND_ERROR_IF(prop_major == 10 && is_ragged && is_dbias, error_code_t::GRAPH_NOT_SUPPORTED, @@ -1952,10 +1955,32 @@ class CompositeSDPABackwardNode : public NodeCRTP { std::pair> override_heuristics_query() const { int32_t const sm_version = context.get_sm_version(); + bool const use_new_knobs = detail::get_backend_version() >= 92300; + // {128,128} bprop: tileM=3, tileN=2, kernelCfg=2(bprop warp), streamK=0, cgaM=0 if (sm_version > 103 && is_deterministic_algorithm_supported_on_blackwell) { - return {17, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + if (use_new_knobs) { + return {17, + {{KnobType_t::TILE_M, 3}, + {KnobType_t::TILE_N, 2}, + {KnobType_t::KERNEL_CFG, 2}, + {KnobType_t::STREAM_K, 0}, + {KnobType_t::TILE_CGA_M, 0}, + {KnobType_t::STAGES, 2}}}; + } else { + return {17, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + } } else if (is_deterministic_algorithm_supported_on_blackwell) { - return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + if (use_new_knobs) { + return {5, + 
{{KnobType_t::TILE_M, 3}, + {KnobType_t::TILE_N, 2}, + {KnobType_t::KERNEL_CFG, 2}, + {KnobType_t::STREAM_K, 0}, + {KnobType_t::TILE_CGA_M, 0}, + {KnobType_t::STAGES, 2}}}; + } else { + return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + } } else { return {-1, {}}; } diff --git a/include/cudnn_frontend/node/sdpa_fp8_bwd.h b/include/cudnn_frontend/node/sdpa_fp8_bwd.h index 12b390c7..29b8f200 100644 --- a/include/cudnn_frontend/node/sdpa_fp8_bwd.h +++ b/include/cudnn_frontend/node/sdpa_fp8_bwd.h @@ -1089,10 +1089,32 @@ class SDPAFP8BackwardNode : public NodeCRTP { std::pair> override_heuristics_query() const { int32_t const sm_version = context.get_sm_version(); + bool const use_new_knobs = detail::get_backend_version() >= 92300; + // {128,128} bprop: tileM=3, tileN=2, kernelCfg=2(bprop warp), streamK=0, cgaM=0 if (sm_version > 103 && is_deterministic_algorithm_supported_on_blackwell) { - return {17, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + if (use_new_knobs) { + return {17, + {{KnobType_t::TILE_M, 3}, + {KnobType_t::TILE_N, 2}, + {KnobType_t::KERNEL_CFG, 2}, + {KnobType_t::STREAM_K, 0}, + {KnobType_t::TILE_CGA_M, 0}, + {KnobType_t::STAGES, 2}}}; + } else { + return {17, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + } } else if (is_deterministic_algorithm_supported_on_blackwell) { - return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + if (use_new_knobs) { + return {5, + {{KnobType_t::TILE_M, 3}, + {KnobType_t::TILE_N, 2}, + {KnobType_t::KERNEL_CFG, 2}, + {KnobType_t::STREAM_K, 0}, + {KnobType_t::TILE_CGA_M, 0}, + {KnobType_t::STAGES, 2}}}; + } else { + return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + } } else { return {-1, {}}; } @@ -1102,7 +1124,15 @@ class SDPAFP8BackwardNode : public NodeCRTP { virtual void serialize(json& j) const override final { j = attributes; - j.update(R"({"tag": "SDPA_FP8_BWD"})"_json); + j["is_mxfp8"] = is_mxfp8_scaling(); + if (auto const rescale_threshold = get_rescale_threshold_from_env(); rescale_threshold.has_value()) { + j["rescale_threshold"] = rescale_threshold.value(); + } + if (is_mxfp8_scaling()) { + j.update(R"({"tag": "SDPA_MXFP8_BWD"})"_json); + } else { + j.update(R"({"tag": "SDPA_FP8_BWD"})"_json); + } } #endif }; diff --git a/include/cudnn_frontend/node/sdpa_support_surface.h b/include/cudnn_frontend/node/sdpa_support_surface.h index 42e40476..223539b3 100644 --- a/include/cudnn_frontend/node/sdpa_support_surface.h +++ b/include/cudnn_frontend/node/sdpa_support_surface.h @@ -350,10 +350,6 @@ SDPA_attributes::validate_sdpa_support_surface(const detail::Context& context, "THD (ragged offset) is only supported in Hopper and above : " + std::to_string(context.get_sm_version())); } - // TODO add version check once fixed - RETURN_CUDNN_FRONTEND_ERROR_IF(prop_major == 10 && is_rng, - error_code_t::GRAPH_NOT_SUPPORTED, - "dropout RNG dump is not supported for Blackwell architecture"); } else { RETURN_CUDNN_FRONTEND_ERROR_IF(true, error_code_t::GRAPH_NOT_SUPPORTED, "Unsupported mma core mode"); } @@ -464,8 +460,13 @@ SDPA_attributes::verify_sdpa_support_surface_for_implementation(const detail::Co "Diagonal alignment for unified SDPA node requires cuDNN 9.21.0 or above"}; } - if (dropout_probability.has_value() && effective_cudnn_ver < 92100) { - return {error_code_t::GRAPH_NOT_SUPPORTED, "Dropout for unified SDPA node requires cuDNN 9.21.0"}; + if (dropout_probability.has_value() && effective_cudnn_ver < 92200) { + return {error_code_t::GRAPH_NOT_SUPPORTED, "Dropout 
for unified SDPA node requires cuDNN 9.22.0"}; + } + + if (dropout_probability.has_value() && generate_stats.value_or(false)) { + return {error_code_t::GRAPH_NOT_SUPPORTED, + "Dropout for unified SDPA node with generated stats is not supported"}; } // Unified engine in cuDNN < 9.15 can't meaningfully support max sequence length, @@ -496,4 +497,4 @@ SDPA_attributes::verify_sdpa_support_surface_for_implementation(const detail::Co return {error_code_t::OK, ""}; } -} // namespace cudnn_frontend::graph \ No newline at end of file +} // namespace cudnn_frontend::graph diff --git a/include/cudnn_frontend/node/slice.h b/include/cudnn_frontend/node/slice.h index e40f5c53..ebcad6a0 100644 --- a/include/cudnn_frontend/node/slice.h +++ b/include/cudnn_frontend/node/slice.h @@ -1,5 +1,8 @@ #pragma once +#include "../graph_helpers.h" +#include "../node_interface.h" + namespace cudnn_frontend::graph { class SliceNode : public NodeCRTP { @@ -21,12 +24,28 @@ class SliceNode : public NodeCRTP { attributes.fill_from_context(context); + for (size_t i = 0; i < attributes.slice_strides.size(); ++i) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.slice_strides[i] <= 0, + error_code_t::INVALID_VALUE, + "Slice slice_strides[" + std::to_string(i) + "] must be strictly positive (got " + + std::to_string(attributes.slice_strides[i]) + + "). Non-positive strides break output dimension calculation and can cause division by zero."); + } + auto output = attributes.outputs.at(Slice_attributes::output_names::Y); auto output_dim = output->get_dim(); if (output_dim.empty()) { for (size_t i = 0; i < attributes.slices.size(); ++i) { - output_dim.push_back(attributes.slices[i].second - attributes.slices[i].first); + int64_t start = attributes.slices[i].first; + int64_t limit = attributes.slices[i].second; + int64_t stride = (!attributes.slice_strides.empty() && i < attributes.slice_strides.size()) + ? attributes.slice_strides[i] + : 1; + // Output dimension = ceil((limit - start) / stride) + int64_t dim = (limit - start + stride - 1) / stride; + output_dim.push_back(dim); } output->set_dim(output_dim); } @@ -44,10 +63,15 @@ class SliceNode : public NodeCRTP { auto const input_stride = input->get_stride(); if (output->get_stride().empty()) { - // For simple slicing without changing the step, the stride remains the same - // std::vector stride_order = - // detail::generate_stride_order_preserving_format(input_stride, output_dim.size()); - output->set_stride(input_stride); + // When slice strides > 1, output strides need to be multiplied accordingly + std::vector output_stride; + for (size_t i = 0; i < input_stride.size(); ++i) { + int64_t stride = (!attributes.slice_strides.empty() && i < attributes.slice_strides.size()) + ? attributes.slice_strides[i] + : 1; + output_stride.push_back(input_stride[i] * stride); + } + output->set_stride(output_stride); } return {error_code_t::OK, ""}; @@ -57,16 +81,27 @@ class SliceNode : public NodeCRTP { create_cudnn_tensors_node(std::unordered_map>& tensors, int64_t& potential_uid, std::unordered_set const& used_uids) const override final { - // Do not make input tensor for backend. 
- // But assign it a uid - auto const input = attributes.inputs.at(Slice_attributes::input_names::X); + getLogger() << "[cudnn_frontend] INFO: Creating cudnn tensors for SliceNode " << attributes.name << std::endl; + + auto const input = attributes.inputs.at(Slice_attributes::input_names::X); + auto const output = attributes.outputs.at(Slice_attributes::output_names::Y); + +#if (CUDNN_VERSION >= 92200) + // For cuDNN >= 9.22.0: Use native slice operation, create both input and output tensors + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(input, tensors, potential_uid, used_uids)); + output->set_is_virtual(false); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors, potential_uid, used_uids)); +#else + // For cuDNN < 9.22.0: Fallback to pointer arithmetic approach + // Only assign UID to input, don't create backend tensor if (input->has_uid() == false) { detail::assign_uid(input.get(), potential_uid, used_uids); } - auto const output = attributes.outputs.at(Slice_attributes::output_names::Y); + // Create output tensor output->set_is_virtual(false); CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors, potential_uid, used_uids)); +#endif return {error_code_t::OK, ""}; } @@ -74,20 +109,98 @@ class SliceNode : public NodeCRTP { error_t create_cudnn_operations( std::unordered_set& uids_involved_in_operations, - std::vector>&, + std::vector>& operations, managed_backend_descriptor_t& raw_operations, - std::unordered_map>&) const override final { + std::unordered_map>& tensors) const override final { + getLogger() << "[cudnn_frontend] INFO: " << "Building SliceNode operations " << attributes.name << std::endl; + +#if (CUDNN_VERSION >= 92200) + // cuDNN >= 9.22.0: Use native backend slice operation + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Slice requires cuDNN v9.22.0"}; + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(92200, cudnn_ver_error); + CUDNN_FRONTEND_UNUSED(operations); + + auto slice_operation = make_shared_backend_pointer(CUDNN_BACKEND_OPERATION_SLICE_DESCRIPTOR); + + // Set input tensor + auto X = attributes.inputs.at(Slice_attributes::input_names::X); + auto backend_x = tensors[X->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SLICE_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_x)); + + // Set output tensor + auto Y = attributes.outputs.at(Slice_attributes::output_names::Y); + auto backend_y = tensors[Y->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SLICE_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_y)); + + // Extract start and limit indices from slices + std::vector start_indices; + std::vector limit_indices; + + for (const auto& slice : attributes.slices) { + start_indices.push_back(slice.first); + limit_indices.push_back(slice.second); + } + + // Per-dimension strides: use user slice_strides[i] when set, else 1 (preserves partial configuration) + std::vector strides(attributes.slices.size()); + for (size_t i = 0; i < strides.size(); ++i) { + strides[i] = (i < attributes.slice_strides.size()) ? 
attributes.slice_strides[i] : 1;
+        }
+
+        // Set start indices
+        _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(),
+                                                       CUDNN_ATTR_OPERATION_SLICE_START_INDICES,
+                                                       CUDNN_TYPE_INT64,
+                                                       start_indices.size(),
+                                                       start_indices.data()));
+
+        // Set limit indices
+        _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(),
+                                                       CUDNN_ATTR_OPERATION_SLICE_LIMIT_INDICES,
+                                                       CUDNN_TYPE_INT64,
+                                                       limit_indices.size(),
+                                                       limit_indices.data()));
+
+        // Set strides
+        _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(),
+                                                       CUDNN_ATTR_OPERATION_SLICE_STRIDES,
+                                                       CUDNN_TYPE_INT64,
+                                                       strides.size(),
+                                                       strides.data()));
+
+        _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(slice_operation->get_backend_descriptor()));
+
+        raw_operations.push_back(slice_operation);
+
+        auto const& non_virtual_uids = attributes.get_non_virtual_uids();
+        uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end());
+#else
+        // cuDNN < 9.22.0: Fallback to pointer arithmetic (no backend operation needed)
+        // The collect_variant_pack_replacements_node method handles the pointer offset mapping
+        CUDNN_FRONTEND_UNUSED(operations);
         CUDNN_FRONTEND_UNUSED(raw_operations);
-        // No corresponding backend operation
+        CUDNN_FRONTEND_UNUSED(tensors);
-        auto const virutal_output = attributes.outputs.at(Slice_attributes::output_names::Y);
-        if (virutal_output && virutal_output->get_is_virtual() == false) {
-            uids_involved_in_operations.insert(virutal_output->get_uid());
-            if (auto ragged_offset = virutal_output->get_ragged_offset()) {
+        getLogger() << "[cudnn_frontend] INFO: " << "Using pointer arithmetic fallback for slice on cuDNN < 9.22.0"
+                    << std::endl;
+
+        // Register the output tensor as involved (input is handled via variant pack replacement)
+        auto const output = attributes.outputs.at(Slice_attributes::output_names::Y);
+        if (output && output->get_is_virtual() == false) {
+            uids_involved_in_operations.insert(output->get_uid());
+            if (auto ragged_offset = output->get_ragged_offset()) {
                 uids_involved_in_operations.insert(ragged_offset->get_uid());
             }
         }
-
+#endif
        return {error_code_t::OK, ""};
    }
 diff --git a/include/cudnn_frontend/node/transpose.h b/include/cudnn_frontend/node/transpose.h
new file mode 100644
index 00000000..7f764be7
--- /dev/null
+++ b/include/cudnn_frontend/node/transpose.h
@@ -0,0 +1,164 @@
+#pragma once
+
+#include "../../cudnn_frontend_Logging.h"
+
+#include "../graph_helpers.h"
+#include "../node_interface.h"
+
+namespace cudnn_frontend::graph {
+
+class TransposeNode : public NodeCRTP<TransposeNode> {
+   public:
+    Transpose_attributes attributes;
+
+    TransposeNode(Transpose_attributes&& attributes_, detail::Context const& context)
+        : NodeCRTP(context), attributes(std::move(attributes_)) {}
+
+    Type
+    getType() override final {
+        return Type::TRANSPOSE;
+    }
+
+    error_t
+    infer_properties_node() override final {
+        getLogger() << "[cudnn_frontend] INFO: Inferring properties for transpose node " << attributes.name
+                    << std::endl;
+
+        attributes.fill_from_context(context);
+
+        auto const& X_tensor = attributes.inputs[Transpose_attributes::input_names::X];
+        auto Y_tensor = attributes.outputs[Transpose_attributes::output_names::Y];
+
+        // Get input properties
+        auto const& input_dim = X_tensor->get_dim();
+        auto const& input_stride = X_tensor->get_stride();
+        auto const& input_data_type = X_tensor->get_data_type();
+        auto const& permutation = attributes.get_permutation();
+
+        // Validate permutation
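+        // (Editor's illustration: permutation = {2, 0, 1} on input_dim = {8, 16, 32}
+        // yields output_dim = {32, 8, 16}; output axis i reads input axis permutation[i],
+        // and the checks below enforce that permutation is a bijection over [0, n-1].)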
+        RETURN_CUDNN_FRONTEND_ERROR_IF(
+            permutation.empty(), error_code_t::ATTRIBUTE_NOT_SET, "Permutation must be set for transpose operation.");
+
+        RETURN_CUDNN_FRONTEND_ERROR_IF(permutation.size() != input_dim.size(),
+                                       error_code_t::INVALID_VALUE,
+                                       "Permutation size must match input tensor dimensionality.");
+
+        // Check that permutation is a valid permutation (contains each index 0 to n-1 exactly once)
+        std::vector<bool> seen(permutation.size(), false);
+        for (auto idx : permutation) {
+            RETURN_CUDNN_FRONTEND_ERROR_IF(idx < 0 || idx >= static_cast<int64_t>(permutation.size()),
+                                           error_code_t::INVALID_VALUE,
+                                           "Permutation indices must be in range [0, n-1].");
+            RETURN_CUDNN_FRONTEND_ERROR_IF(
+                seen[idx], error_code_t::INVALID_VALUE, "Permutation indices must be unique.");
+            seen[idx] = true;
+        }
+
+        // Infer output dimensions by permuting input dimensions
+        std::vector<int64_t> output_dim(input_dim.size());
+        for (size_t i = 0; i < permutation.size(); ++i) {
+            output_dim[i] = input_dim[permutation[i]];
+        }
+
+        // Infer output strides by permuting input strides
+        std::vector<int64_t> output_stride(input_stride.size());
+        for (size_t i = 0; i < permutation.size(); ++i) {
+            output_stride[i] = input_stride[permutation[i]];
+        }
+
+        // Set output tensor properties
+        if (Y_tensor->get_dim().empty()) {
+            Y_tensor->set_dim(output_dim);
+        }
+        if (Y_tensor->get_stride().empty()) {
+            Y_tensor->set_stride(output_stride);
+        }
+        if (Y_tensor->get_data_type() == DataType_t::NOT_SET) {
+            Y_tensor->set_data_type(input_data_type);
+        }
+
+        RETURN_CUDNN_FRONTEND_ERROR_IF(Y_tensor->get_data_type() != input_data_type,
+                                       error_code_t::INVALID_VALUE,
+                                       "Output and input tensor data types must match for transpose operation.");
+
+        return {error_code_t::OK, ""};
+    }
+
+    error_t
+    create_cudnn_operations(
+        std::unordered_set<uid_t>& uids_involved_in_operations,
+        std::vector<std::shared_ptr<cudnn_frontend::Operation>>& operations,
+        managed_backend_descriptor_t& raw_operations,
+        std::unordered_map<int64_t, std::shared_ptr<cudnn_frontend::Tensor>>& tensors) const override final {
+        getLogger() << "[cudnn_frontend] INFO: " << "Building TransposeNode operations " << attributes.name
+                    << std::endl;
+
+#if (CUDNN_VERSION >= 92200)
+        // cuDNN >= 9.22.0: Use native backend transpose operation
+        auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Transpose requires cuDNN v9.22.0"};
+        NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(92200, cudnn_ver_error);
+        CUDNN_FRONTEND_UNUSED(operations);
+
+        auto transpose_operation = make_shared_backend_pointer(CUDNN_BACKEND_OPERATION_TRANSPOSE_DESCRIPTOR);
+
+        // Set input tensor
+        auto X = attributes.inputs.at(Transpose_attributes::input_names::X);
+        auto backend_x = tensors[X->get_uid()]->get_desc()->get_backend_descriptor();
+        _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(transpose_operation->get_backend_descriptor(),
+                                                       CUDNN_ATTR_OPERATION_TRANSPOSE_XDESC,
+                                                       CUDNN_TYPE_BACKEND_DESCRIPTOR,
+                                                       1,
+                                                       &backend_x));
+
+        // Set output tensor
+        auto Y = attributes.outputs.at(Transpose_attributes::output_names::Y);
+        auto backend_y = tensors[Y->get_uid()]->get_desc()->get_backend_descriptor();
+        _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(transpose_operation->get_backend_descriptor(),
+                                                       CUDNN_ATTR_OPERATION_TRANSPOSE_YDESC,
+                                                       CUDNN_TYPE_BACKEND_DESCRIPTOR,
+                                                       1,
+                                                       &backend_y));
+
+        // Set permutation
+        auto permutation = attributes.get_permutation();
+        _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(transpose_operation->get_backend_descriptor(),
+                                                       CUDNN_ATTR_OPERATION_TRANSPOSE_PERMUTATION,
+                                                       CUDNN_TYPE_INT64,
+                                                       permutation.size(),
+                                                       permutation.data()));
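+
+        // Usage sketch (editor's note; mirrors the Graph::transpose wrapper added to
+        // graph_interface.h in this same patch):
+        //   auto Y = graph.transpose(X, Transpose_attributes()
+        //                                   .set_name("transpose_0")
+        //                                   .set_permutation({0, 2, 1, 3}));
+        // The native backend operation requires cuDNN >= 9.22.0, as checked above.
+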
_CUDNN_CHECK_CUDNN_ERROR(detail::finalize(transpose_operation->get_backend_descriptor())); + + raw_operations.push_back(transpose_operation); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); +#else + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FRONTEND_UNUSED(tensors); + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + return {error_code_t::GRAPH_NOT_SUPPORTED, "Transpose operation requires cuDNN version >= 9.22.0"}; +#endif + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "TRANSPOSE"})"_json); + } +#endif +}; + +inline std::shared_ptr +INode::transpose(std::shared_ptr input, Transpose_attributes attributes) { + attributes.inputs[Transpose_attributes::input_names::X] = input; + auto Y = attributes.outputs[Transpose_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return Y; +} + +} // namespace cudnn_frontend::graph diff --git a/include/cudnn_frontend/node_interface.h b/include/cudnn_frontend/node_interface.h index e9109ce3..7bc922ad 100644 --- a/include/cudnn_frontend/node_interface.h +++ b/include/cudnn_frontend/node_interface.h @@ -38,6 +38,8 @@ class CompositeSoftmaxNode; class UnifiedSoftmaxNode; class MoeGroupedMatmulNode; class UnifiedDiagonalBandMaskNode; +class TransposeNode; +class SliceNode; // Interface for all nodes to follow. class INode { @@ -138,6 +140,7 @@ class INode { RMSNORM, RNG, SLICE, + TRANSPOSE, WGRAD, PAGED_CACHE_LOAD, BLOCK_SCALE_QUANTIZE, @@ -253,6 +256,15 @@ class INode { Moe_grouped_matmul_bwd_attributes attributes, std::shared_ptr dweight); + void + transpose(std::shared_ptr input, + Transpose_attributes attributes, + std::shared_ptr output); + + void + slice(std::shared_ptr input, + Slice_attributes attributes, + std::shared_ptr output); error_t validate_subtree() { // pre validate to catch errors early @@ -392,6 +404,8 @@ class INode { std::shared_ptr reduction(std::shared_ptr, Reduction_attributes); std::array, 2> resample(std::shared_ptr, Resample_attributes); std::shared_ptr reshape(std::shared_ptr, Reshape_attributes); + std::shared_ptr transpose(std::shared_ptr, Transpose_attributes); + std::shared_ptr slice(std::shared_ptr, Slice_attributes); std::shared_ptr rng(std::shared_ptr, std::shared_ptr, diff --git a/include/cudnn_frontend/utils/attn_score_modifiers.h b/include/cudnn_frontend/utils/attn_score_modifiers.h index ff0a3877..e0597907 100644 --- a/include/cudnn_frontend/utils/attn_score_modifiers.h +++ b/include/cudnn_frontend/utils/attn_score_modifiers.h @@ -253,7 +253,7 @@ sliding_window_mask(std::shared_ptr graph, [[maybe_unused]] inline error_t build_operation_subgraph(std::shared_ptr graph) { return graph->build_operation_graph(/*handle=*/nullptr); -}; +} class Softcap { private: diff --git a/include/cudnn_frontend/utils/serialize.h b/include/cudnn_frontend/utils/serialize.h index 6cc4590a..1277c5b8 100644 --- a/include/cudnn_frontend/utils/serialize.h +++ b/include/cudnn_frontend/utils/serialize.h @@ -298,6 +298,13 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Reshape_attributes::input_names, NLOHMANN_JSON_SERIALIZE_ENUM(Reshape_attributes::output_names, {{Reshape_attributes::output_names::Y, "Y"}}) +NLOHMANN_JSON_SERIALIZE_ENUM(Transpose_attributes::input_names, + { + 
{Transpose_attributes::input_names::X, "X"}, + }) + +NLOHMANN_JSON_SERIALIZE_ENUM(Transpose_attributes::output_names, {{Transpose_attributes::output_names::Y, "Y"}}) + NLOHMANN_JSON_SERIALIZE_ENUM(Rmsnorm_attributes::input_names, { {Rmsnorm_attributes::input_names::X, "X"}, @@ -457,6 +464,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_backward_attributes::input_names, {SDPA_fp8_backward_attributes::input_names::Scale_dV, "Scale_dV"}, {SDPA_fp8_backward_attributes::input_names::Scale_S, "Scale_S"}, {SDPA_fp8_backward_attributes::input_names::Scale_dP, "Scale_dP"}, + {SDPA_fp8_backward_attributes::input_names::SINK_TOKEN, "SINK_TOKEN"}, }) NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_backward_attributes::output_names, @@ -468,6 +476,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_backward_attributes::output_names, {SDPA_fp8_backward_attributes::output_names::Amax_dK, "Amax_dK"}, {SDPA_fp8_backward_attributes::output_names::Amax_dV, "Amax_dV"}, {SDPA_fp8_backward_attributes::output_names::Amax_dP, "Amax_d"}, + {SDPA_fp8_backward_attributes::output_names::DSINK_TOKEN, "DSINK_TOKEN"}, }) NLOHMANN_JSON_SERIALIZE_ENUM(Block_scale_quantize_attributes::input_names, diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h index 6bf01019..5be62537 100644 --- a/include/cudnn_frontend_Operation.h +++ b/include/cudnn_frontend_Operation.h @@ -236,6 +236,7 @@ class Operation_v8 : public BackendDescriptor { NormFwdPhase_t norm_fwd_phase; NormMode_t norm_mode; + ReshapeMode_t reshape_mode = ReshapeMode_t::VIEW_ONLY; float alpha_s = 1.0f, beta_s = .0f, alpha2_s = 1.0f; double alpha_d = 1.0, beta_d = 0.0, alpha2_d = 1.0; @@ -1790,6 +1791,29 @@ class OperationBuilder_v8 { "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_RESHAPE_YDESC Failed"); return std::move(m_operation); } + +#if (CUDNN_VERSION >= 92200) + // Set reshape mode if it's not NOT_SET + if (m_operation.reshape_mode != ReshapeMode_t::NOT_SET) { + cudnnBackendReshapeMode_t cudnn_reshape_mode; + status = detail::convert_to_cudnn_type(m_operation.reshape_mode, cudnn_reshape_mode); + if (status == CUDNN_STATUS_SUCCESS) { + status = detail::set_attribute(m_operation.pointer->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_RESHAPE_MODE, + CUDNN_TYPE_RESHAPE_MODE, + 1, + &cudnn_reshape_mode); + if (status != CUDNN_STATUS_SUCCESS) { + set_error_and_throw_exception( + &m_operation, + status, + "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_RESHAPE_MODE Failed"); + return std::move(m_operation); + } + } + } +#endif + status = detail::finalize(m_operation.pointer->get_backend_descriptor()); if (status != CUDNN_STATUS_SUCCESS) { set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed"); @@ -2561,6 +2585,21 @@ class OperationBuilder_v8 { return *this; } + auto + setReshapeMode(ReshapeMode_t mode) -> OperationBuilder_v8 & { + m_operation.reshape_mode = mode; + return *this; + } + +#if (CUDNN_VERSION >= 92200) + // To be deprecated. Please use setReshapeMode(cudnn_frontend::ReshapeMode_t mode) instead. 
+ auto + setReshapeMode(cudnnBackendReshapeMode_t mode) -> OperationBuilder_v8 & { + detail::convert_from_cudnn_type(mode, m_operation.reshape_mode); + return *this; + } +#endif + auto setNormalizationMode(NormMode_t mode) -> OperationBuilder_v8 & { m_operation.norm_mode = mode; diff --git a/include/cudnn_frontend_Tensor.h b/include/cudnn_frontend_Tensor.h index 91585003..ba538844 100644 --- a/include/cudnn_frontend_Tensor.h +++ b/include/cudnn_frontend_Tensor.h @@ -24,6 +24,7 @@ #include #include +#include <cstring> #include #include #include @@ -179,6 +180,10 @@ class Tensor_v8 : public BackendDescriptor { int64_t vectorCount = 1; //! What is the vectorization count (4 or 32) bool isVirtual = false; //! Whether it is an intermediate tensor of an op graph bool isByValue = false; //! Whether the tensor is in host memory that needs to be passed to the kernel by value + + bool hasConstValue = false; //! Whether this tensor has a compile-time constant value + std::vector<uint8_t> constValueBytes; //! Raw bytes of the compile-time constant value + cudnn_frontend::TensorReordering_t reorder_type = cudnn_frontend::TensorReordering_t::NONE; //! Type of reordering in the tensor std::shared_ptr<Tensor_v8> raggedOffset; //! Ragged offsets for ragged tensors @@ -242,6 +247,18 @@ class TensorBuilder_v8 { m_tensor.isByValue = isByValue_; return *this; } + + //! Set compile-time constant value for this tensor + template <typename T> + auto + setConstValue(T const &value) -> TensorBuilder_v8 & { + m_tensor.hasConstValue = true; + m_tensor.constValueBytes.resize(sizeof(T)); + std::memcpy(m_tensor.constValueBytes.data(), &value, sizeof(T)); + m_tensor.isByValue = true; + return *this; + } + auto setVectorCountAndDimension(int64_t vectorCount_, int64_t vectorDimension_) -> TensorBuilder_v8 & { m_tensor.vectorCount = vectorCount_; @@ -506,6 +523,38 @@ class TensorBuilder_v8 { return std::move(m_tensor); } } + + // Set compile-time constant value (if present); CUDNN_ATTR_TENSOR_CONSTANT_VALUE is cuDNN 9.22.0+ +#if (CUDNN_VERSION >= 92200) + NV_CUDNN_FE_DYNAMIC_CHECK_BACKEND_DESCRIPTOR(92200, + m_tensor, + "CUDNN_BACKEND_TENSOR_DESCRIPTOR: SetAttribute " + "CUDNN_ATTR_TENSOR_CONSTANT_VALUE requires cudnn version 9.22.0"); + if (m_tensor.hasConstValue) { + void *value_ptr = m_tensor.constValueBytes.data(); + status = detail::set_attribute(m_tensor.pointer->get_backend_descriptor(), + CUDNN_ATTR_TENSOR_CONSTANT_VALUE, + CUDNN_TYPE_VOID_PTR, + 1, + &value_ptr); + if (status != CUDNN_STATUS_SUCCESS) { + set_error_and_throw_exception( + &m_tensor, + status, + "CUDNN_BACKEND_TENSOR_DESCRIPTOR: SetAttribute CUDNN_ATTR_TENSOR_CONSTANT_VALUE Failed"); + return std::move(m_tensor); + } + } +#else + if (m_tensor.hasConstValue) { + set_error_and_throw_exception( + &m_tensor, + CUDNN_STATUS_INVALID_VALUE, + "CUDNN_BACKEND_TENSOR_DESCRIPTOR: CUDNN_ATTR_TENSOR_CONSTANT_VALUE requires cudnn version 9.22.0"); + return std::move(m_tensor); + } +#endif // CUDNN_VERSION >= 92200 + // Finalizing the descriptor status = detail::finalize(m_tensor.pointer->get_backend_descriptor()); if (status != CUDNN_STATUS_SUCCESS) { diff --git a/include/cudnn_frontend_shim.h b/include/cudnn_frontend_shim.h index e07ca484..ea838d33 100644 --- a/include/cudnn_frontend_shim.h +++ b/include/cudnn_frontend_shim.h @@ -427,6 +427,70 @@ destroy_handle(cudnnHandle_t handle) { NV_FE_CALL_TO_BACKEND(destroy_handle, cudnnDestroy, handle); } +#if CUDNN_VERSION >= 92200 +inline cudnnStatus_t +causal_conv1d_forward(cudaStream_t stream, + const void *x, + const void *weight, + const void *bias, + void
*y, + int batch, + int dim, + int seq_len, + int kernel_size, + cudnnDataType_t data_type, + cudnnCausalConv1dActivation_t activation) { + NV_FE_CALL_TO_BACKEND(causal_conv1d_forward, + cudnnCausalConv1dForward, + stream, + x, + weight, + bias, + y, + batch, + dim, + seq_len, + kernel_size, + data_type, + activation); +} + +inline cudnnStatus_t +causal_conv1d_backward(cudaStream_t stream, + const void *x, + const void *weight, + const void *bias, + const void *dy, + void *dx, + void *dweight, + void *dbias, + int batch, + int dim, + int seq_len, + int kernel_size, + cudnnDataType_t data_type, + cudnnDataType_t dw_data_type, + cudnnCausalConv1dActivation_t activation) { + NV_FE_CALL_TO_BACKEND(causal_conv1d_backward, + cudnnCausalConv1dBackward, + stream, + x, + weight, + bias, + dy, + dx, + dweight, + dbias, + batch, + dim, + seq_len, + kernel_size, + data_type, + dw_data_type, + activation); +} +#endif + inline size_t get_backend_version(void) { #if defined NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING diff --git a/include/cudnn_frontend_utils.h b/include/cudnn_frontend_utils.h index 603b9cf1..33e7c3a5 100644 --- a/include/cudnn_frontend_utils.h +++ b/include/cudnn_frontend_utils.h @@ -96,21 +96,21 @@ struct nlohmann::adl_serializer { }; template <> -struct nlohmann::adl_serializer> { +struct nlohmann::adl_serializer> { static void - to_json(nlohmann::json& j, const std::variant& data) { + to_json(nlohmann::json& j, const std::variant& data) { std::visit([&](const auto& v) { j = {{"index", data.index()}, {"value", v}}; }, data); } static void - from_json(const nlohmann::json& j, std::variant& data) { + from_json(const nlohmann::json& j, std::variant& data) { if (!j.is_object() || !j.contains("index") || !j.contains("value")) { return; } size_t type_index = j.at("index").get(); if (type_index == 0) { - data = j.at("value").get(); + data = j.at("value").get(); } else if (type_index == 1) { data = j.at("value").get(); } else if (type_index == 2) { @@ -118,6 +118,8 @@ struct nlohmann::adl_serializer(); } else if (type_index == 4) { + data = j.at("value").get(); + } else if (type_index == 5) { data = j.at("value").get(); } else { return; @@ -386,6 +388,20 @@ enum class PaddingMode_t { ZERO_PAD }; +enum class ReshapeMode_t { + NOT_SET, + + VIEW_ONLY, + LOGICAL +}; + +NLOHMANN_JSON_SERIALIZE_ENUM(ReshapeMode_t, + { + {ReshapeMode_t::NOT_SET, nullptr}, + {ReshapeMode_t::VIEW_ONLY, "VIEW_ONLY"}, + {ReshapeMode_t::LOGICAL, "LOGICAL"}, + }) + enum class ConvolutionMode_t { NOT_SET, @@ -480,6 +496,8 @@ enum class DescriptorType_t { OPERATION_CONCATENATE_DESCRIPTOR, OPERATION_MOE_GROUPED_MATMUL_DESCRIPTOR, OPERATION_MOE_GROUPED_MATMUL_BWD_DESCRIPTOR, + OPERATION_TRANSPOSE_DESCRIPTOR, + OPERATION_SLICE_DESCRIPTOR }; enum class NormMode_t { @@ -962,6 +980,12 @@ operator<<(std::ostream& os, const DescriptorType_t& mode) { case DescriptorType_t::OPERATION_MOE_GROUPED_MATMUL_BWD_DESCRIPTOR: os << "OPERATION_MOE_GROUPED_MATMUL_BWD_DESCRIPTOR"; break; + case DescriptorType_t::OPERATION_TRANSPOSE_DESCRIPTOR: + os << "OPERATION_TRANSPOSE_DESCRIPTOR"; + break; + case DescriptorType_t::OPERATION_SLICE_DESCRIPTOR: + os << "OPERATION_SLICE_DESCRIPTOR"; + break; case DescriptorType_t::NOT_SET: os << "NOT_SET"; break; @@ -1684,6 +1708,22 @@ convert_to_cudnn_type(cudnn_frontend::DescriptorType_t const mode, cudnnBackendD #else return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; #endif + case DescriptorType_t::OPERATION_TRANSPOSE_DESCRIPTOR: +#if (CUDNN_VERSION >= 92200) && (CUDNN_VERSION < 99900) + 
NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(92200, cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE); + cudnn_mode = CUDNN_BACKEND_OPERATION_TRANSPOSE_DESCRIPTOR; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#else + return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; +#endif + case DescriptorType_t::OPERATION_SLICE_DESCRIPTOR: +#if (CUDNN_VERSION >= 92200) && (CUDNN_VERSION < 99900) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(92200, cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE); + cudnn_mode = CUDNN_BACKEND_OPERATION_SLICE_DESCRIPTOR; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#else + return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; +#endif #ifndef NO_DEFAULT_IN_SWITCH default: @@ -1799,6 +1839,26 @@ convert_to_cudnn_type(cudnn_frontend::NormFwdPhase_t const mode, cudnnBackendNor return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; } +#if (CUDNN_VERSION >= 92200) +static inline cudnnStatus_t +convert_to_cudnn_type(cudnn_frontend::ReshapeMode_t const mode, cudnnBackendReshapeMode_t& cudnn_mode) { + switch (mode) { + case ReshapeMode_t::VIEW_ONLY: + cudnn_mode = CUDNN_RESHAPE_VIEW_ONLY; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case ReshapeMode_t::LOGICAL: + cudnn_mode = CUDNN_RESHAPE_LOGICAL; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + +#ifndef NO_DEFAULT_IN_SWITCH + default: + return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; +#endif + } + return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; +} +#endif + // To be deprecated. Only exists as setResampleMode(cudnnPaddingMode_t) requires it. static inline void convert_from_cudnn_type(cudnnPaddingMode_t const cudnn_mode, cudnn_frontend::PaddingMode_t& mode) { @@ -1930,6 +1990,27 @@ convert_from_cudnn_type(cudnnBackendNormFwdPhase_t const cudnn_mode, cudnn_front } } +#if (CUDNN_VERSION >= 92200) +// To be deprecated. Only exists as setReshapeMode(cudnnBackendReshapeMode_t) requires it. 
+static inline void +convert_from_cudnn_type(cudnnBackendReshapeMode_t const cudnn_mode, cudnn_frontend::ReshapeMode_t& mode) { + mode = ReshapeMode_t::NOT_SET; + switch (cudnn_mode) { + case CUDNN_RESHAPE_VIEW_ONLY: + mode = ReshapeMode_t::VIEW_ONLY; + break; + case CUDNN_RESHAPE_LOGICAL: + mode = ReshapeMode_t::LOGICAL; + break; + +#ifndef NO_DEFAULT_IN_SWITCH + default: + break; +#endif + } +} +#endif + static inline cudnnStatus_t convert_to_cudnn_type(cudnn_frontend::TensorReordering_t const mode, cudnnBackendTensorReordering_t& cudnn_mode) { switch (mode) { @@ -2094,6 +2175,13 @@ convert_from_cudnn_type(cudnnBackendDescriptorType_t const cudnn_mode) { return DescriptorType_t::OPERATION_MOE_GROUPED_MATMUL_BWD_DESCRIPTOR; #endif +#if (CUDNN_VERSION >= 92200) && (CUDNN_VERSION < 99900) + case CUDNN_BACKEND_OPERATION_TRANSPOSE_DESCRIPTOR: + return DescriptorType_t::OPERATION_TRANSPOSE_DESCRIPTOR; + case CUDNN_BACKEND_OPERATION_SLICE_DESCRIPTOR: + return DescriptorType_t::OPERATION_SLICE_DESCRIPTOR; +#endif + #ifndef NO_DEFAULT_IN_SWITCH default: return DescriptorType_t::NOT_SET; diff --git a/include/cudnn_frontend_version.h b/include/cudnn_frontend_version.h index b3f38800..bfbaa99f 100644 --- a/include/cudnn_frontend_version.h +++ b/include/cudnn_frontend_version.h @@ -23,7 +23,7 @@ #pragma once #define CUDNN_FRONTEND_MAJOR_VERSION 1 -#define CUDNN_FRONTEND_MINOR_VERSION 22 -#define CUDNN_FRONTEND_PATCH_VERSION 1 +#define CUDNN_FRONTEND_MINOR_VERSION 23 +#define CUDNN_FRONTEND_PATCH_VERSION 0 #define CUDNN_FRONTEND_VERSION \ ((CUDNN_FRONTEND_MAJOR_VERSION * 10000) + (CUDNN_FRONTEND_MINOR_VERSION * 100) + CUDNN_FRONTEND_PATCH_VERSION) diff --git a/pyproject.toml b/pyproject.toml index 0b1d55ad..1ec44cd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,12 +5,48 @@ build-backend = "setuptools.build_meta" [project] name = "nvidia-cudnn-frontend" dynamic = ["version"] -description = "CUDNN FrontEnd python library" +description = "NVIDIA cuDNN Frontend — Python and C++ Graph API with SOTA attention (SDPA / Flash Attention), MoE grouped GEMM fusions, and FP8/MXFP8 kernels for Hopper and Blackwell GPUs." 
readme = "README.md" requires-python = ">=3.9" -license = {text = "NVIDIA Proprietary Software"} +license = {text = "MIT"} +keywords = [ + "cudnn", + "cuda", + "gpu", + "nvidia", + "deep-learning", + "attention", + "sdpa", + "flash-attention", + "transformer", + "moe", + "mixture-of-experts", + "grouped-gemm", + "fp8", + "mxfp8", + "blackwell", + "hopper", + "pytorch", + "kernel", + "graph-api", +] classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: POSIX :: Linux", + "Operating System :: Microsoft :: Windows", + "Programming Language :: C++", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", ] [tool.setuptools] @@ -19,8 +55,12 @@ package-dir = {"" = "python", "include" = "include"} include-package-data = true [project.urls] -"Homepage" = "https://github.com/nvidia/cudnn-frontend" -"Bug Tracker" = "https://github.com/nvidia/cudnn-frontend/issues" +"Homepage" = "https://github.com/NVIDIA/cudnn-frontend" +"Documentation" = "https://docs.nvidia.com/deeplearning/cudnn/frontend/latest/" +"Blog" = "https://nvidia.github.io/cudnn-frontend/" +"Repository" = "https://github.com/NVIDIA/cudnn-frontend" +"Bug Tracker" = "https://github.com/NVIDIA/cudnn-frontend/issues" +"Release Notes" = "https://github.com/NVIDIA/cudnn-frontend/releases" [tool.setuptools.dynamic] version = {attr = "cudnn.__version__"} diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 80836a14..f0de8773 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -37,6 +37,7 @@ find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) include(${PROJECT_SOURCE_DIR}/cmake/cuDNN.cmake) option(CUDNN_FRONTEND_FETCH_PYBINDS_IN_CMAKE "Whether cmake build system should fetch pybinds." 
ON) +set(PYBIND11_FINDPYTHON ON) if(CUDNN_FRONTEND_FETCH_PYBINDS_IN_CMAKE) diff --git a/python/cudnn/README.md b/python/cudnn/README.md index 1d766455..f574cd66 100644 --- a/python/cudnn/README.md +++ b/python/cudnn/README.md @@ -38,11 +38,17 @@ To add a new frontend-only API, follow these steps: **Currently implemented frontend-only APIs**: - `GEMM + Amax` +- `RMSNorm + RHT + Amax` - `GEMM + SwiGLU` +- `GEMM + sReLU` +- `GEMM + dsReLU` - `Grouped Gemm + GLU (Unified)` +- `Grouped Gemm + GLU + Hadamard` - `Grouped Gemm + dGLU (Unified)` - `Grouped Gemm + SwiGLU (Legacy, Contiguous-only)` - `Grouped Gemm + dSwiglu (Legacy, Contiguous-only)` +- `Grouped Gemm + sReLU (Contiguous-only)` +- `Grouped Gemm + dsReLU (Contiguous-only)` - `Discrete Grouped Gemm + SwiGLU` - `Discrete Grouped Gemm + dSwiglu` - `Grouped Gemm + Quant (Legacy, Dense-only)` diff --git a/python/cudnn/__init__.py b/python/cudnn/__init__.py index a1125b2d..b9b9d39d 100644 --- a/python/cudnn/__init__.py +++ b/python/cudnn/__init__.py @@ -40,14 +40,20 @@ def is_windows(): "diagonal_alignment", "attention_implementation", "moe_grouped_matmul_mode", + "scalar_type", + "reshape_mode", ] for symbol_name in symbols_to_import: globals()[symbol_name] = getattr(_pybind_module, symbol_name) +for _optional_symbol in ["causal_conv1d_forward", "causal_conv1d_backward"]: + if hasattr(_pybind_module, _optional_symbol): + globals()[_optional_symbol] = getattr(_pybind_module, _optional_symbol) + from .datatypes import _library_type, _is_torch_tensor -__version__ = "1.22.1" +__version__ = "1.23.0" def _tensor( @@ -230,287 +236,83 @@ def _dlopen_cudnn(): from typing import Any +_OPTIONAL_DEPENDENCY_INSTALL_HINT = "Install with 'pip install nvidia-cudnn-frontend[cutedsl]'" + +_LAZY_OPTIONAL_IMPORTS = { + "NSA": (".native_sparse_attention", "NSA"), + "GemmSwigluSm100": (".gemm_swiglu", "GemmSwigluSm100"), + "gemm_swiglu_wrapper_sm100": (".gemm_swiglu", "gemm_swiglu_wrapper_sm100"), + "GemmSreluSm100": (".gemm_srelu", "GemmSreluSm100"), + "gemm_srelu_wrapper_sm100": (".gemm_srelu", "gemm_srelu_wrapper_sm100"), + "GemmDsreluSm100": (".gemm_dsrelu", "GemmDsreluSm100"), + "gemm_dsrelu_wrapper_sm100": (".gemm_dsrelu", "gemm_dsrelu_wrapper_sm100"), + "GemmAmaxSm100": (".gemm_amax", "GemmAmaxSm100"), + "gemm_amax_wrapper_sm100": (".gemm_amax", "gemm_amax_wrapper_sm100"), + "RmsNormRhtAmaxSm100": (".rmsnorm_rht_amax", "RmsNormRhtAmaxSm100"), + "rmsnorm_rht_amax_wrapper_sm100": (".rmsnorm_rht_amax", "rmsnorm_rht_amax_wrapper_sm100"), + "grouped_gemm": (".grouped_gemm", None), + "GroupedGemmSwigluSm100": (".grouped_gemm", "GroupedGemmSwigluSm100"), + "grouped_gemm_swiglu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_swiglu_wrapper_sm100"), + "GroupedGemmDswigluSm100": (".grouped_gemm", "GroupedGemmDswigluSm100"), + "grouped_gemm_dswiglu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_dswiglu_wrapper_sm100"), + "GroupedGemmSreluSm100": (".grouped_gemm", "GroupedGemmSreluSm100"), + "grouped_gemm_srelu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_srelu_wrapper_sm100"), + "GroupedGemmDsreluSm100": (".grouped_gemm", "GroupedGemmDsreluSm100"), + "grouped_gemm_dsrelu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_dsrelu_wrapper_sm100"), + "SdpafwdSm100D256": (".sdpa", "SdpafwdSm100D256"), + "sdpa_fwd_wrapper_sm100_d256": (".sdpa", "sdpa_fwd_wrapper_sm100_d256"), + "SdpabwdSm100D256": (".sdpa", "SdpabwdSm100D256"), + "sdpa_bwd_wrapper_sm100_d256": (".sdpa", "sdpa_bwd_wrapper_sm100_d256"), + "GroupedGemmQuantSm100": (".grouped_gemm", "GroupedGemmQuantSm100"), 
+ "grouped_gemm_quant_wrapper_sm100": (".grouped_gemm", "grouped_gemm_quant_wrapper_sm100"), + "GroupedGemmGluSm100": (".grouped_gemm", "GroupedGemmGluSm100"), + "grouped_gemm_glu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_glu_wrapper_sm100"), + "GroupedGemmGluHadamardSm100": (".grouped_gemm", "GroupedGemmGluHadamardSm100"), + "grouped_gemm_glu_hadamard_wrapper_sm100": (".grouped_gemm", "grouped_gemm_glu_hadamard_wrapper_sm100"), + "GroupedGemmDgluSm100": (".grouped_gemm", "GroupedGemmDgluSm100"), + "grouped_gemm_dglu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_dglu_wrapper_sm100"), + "GroupedGemmWgradSm100": (".grouped_gemm", "GroupedGemmWgradSm100"), + "grouped_gemm_wgrad_wrapper_sm100": (".grouped_gemm", "grouped_gemm_wgrad_wrapper_sm100"), + "discrete_grouped_gemm": (".discrete_grouped_gemm", None), + "DiscreteGroupedGemmSwigluSm100": (".discrete_grouped_gemm", "DiscreteGroupedGemmSwigluSm100"), + "discrete_grouped_gemm_swiglu_wrapper_sm100": (".discrete_grouped_gemm", "discrete_grouped_gemm_swiglu_wrapper_sm100"), + "DiscreteGroupedGemmDswigluSm100": (".discrete_grouped_gemm", "DiscreteGroupedGemmDswigluSm100"), + "discrete_grouped_gemm_dswiglu_wrapper_sm100": (".discrete_grouped_gemm", "discrete_grouped_gemm_dswiglu_wrapper_sm100"), +} + + +def _load_optional_symbol(name: str) -> Any: + module_name, attr_name = _LAZY_OPTIONAL_IMPORTS[name] + try: + module = importlib.import_module(module_name, package=__name__) + value = module if attr_name is None else getattr(module, attr_name) + except Exception as e: + raise ImportError(f"{name} requires optional dependencies. {_OPTIONAL_DEPENDENCY_INSTALL_HINT}: {e}") from e + + globals()[name] = value + return value -def __getattr__(name: str) -> Any: - if name == "NSA": - try: - from .native_sparse_attention import NSA as _NSA - - return _NSA - except Exception as e: - raise ImportError(f"NSA requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "GemmSwigluSm100": - try: - from .gemm_swiglu import GemmSwigluSm100 as _GemmSwigluSm100 - - return _GemmSwigluSm100 - except Exception as e: - raise ImportError(f"GemmSwigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "gemm_swiglu_wrapper_sm100": - try: - from .gemm_swiglu import ( - gemm_swiglu_wrapper_sm100 as _gemm_swiglu_wrapper_sm100, - ) - - return _gemm_swiglu_wrapper_sm100 - except Exception as e: - raise ImportError( - f"gemm_swiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - elif name == "GemmAmaxSm100": - try: - from .gemm_amax import GemmAmaxSm100 as _GemmAmaxSm100 - - return _GemmAmaxSm100 - except Exception as e: - raise ImportError(f"GemmAmaxSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "gemm_amax_wrapper_sm100": - try: - from .gemm_amax import ( - gemm_amax_wrapper_sm100 as _gemm_amax_wrapper_sm100, - ) - - return _gemm_amax_wrapper_sm100 - except Exception as e: - raise ImportError(f"gemm_amax_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - # Grouped GEMM module - elif name == "grouped_gemm": - try: - from . import grouped_gemm as _grouped_gemm - - return _grouped_gemm - except Exception as e: - raise ImportError(f"grouped_gemm requires optional dependencies. 
Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "GroupedGemmSwigluSm100": - try: - from .grouped_gemm import GroupedGemmSwigluSm100 as _GroupedGemmSwigluSm100 - - return _GroupedGemmSwigluSm100 - except Exception as e: - raise ImportError(f"GroupedGemmSwigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "grouped_gemm_swiglu_wrapper_sm100": - try: - from .grouped_gemm import ( - grouped_gemm_swiglu_wrapper_sm100 as _grouped_gemm_swiglu_wrapper_sm100, - ) - - return _grouped_gemm_swiglu_wrapper_sm100 - except Exception as e: - raise ImportError( - f"grouped_gemm_swiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - elif name == "GroupedGemmDswigluSm100": - try: - from .grouped_gemm import ( - GroupedGemmDswigluSm100 as _GroupedGemmDswigluSm100, - ) - - return _GroupedGemmDswigluSm100 - except Exception as e: - raise ImportError(f"GroupedGemmDswigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "grouped_gemm_dswiglu_wrapper_sm100": - try: - from .grouped_gemm import ( - grouped_gemm_dswiglu_wrapper_sm100 as _grouped_gemm_dswiglu_wrapper_sm100, - ) - - return _grouped_gemm_dswiglu_wrapper_sm100 - except Exception as e: - raise ImportError( - f"grouped_gemm_dswiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - elif name == "SdpafwdSm100D256": - try: - from .sdpa import SdpafwdSm100D256 as _SdpafwdSm100D256 - - return _SdpafwdSm100D256 - except Exception as e: - raise ImportError(f"SdpafwdSm100D256 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "sdpa_fwd_wrapper_sm100_d256": - try: - from .sdpa import ( - sdpa_fwd_wrapper_sm100_d256 as _sdpa_fwd_wrapper_sm100_d256, - ) - - return _sdpa_fwd_wrapper_sm100_d256 - except Exception as e: - raise ImportError( - f"sdpa_fwd_wrapper_sm100_d256 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - elif name == "SdpabwdSm100D256": - try: - from .sdpa import SdpabwdSm100D256 as _SdpabwdSm100D256 - - return _SdpabwdSm100D256 - except Exception as e: - raise ImportError(f"SdpabwdSm100D256 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "sdpa_bwd_wrapper_sm100_d256": - try: - from .sdpa import ( - sdpa_bwd_wrapper_sm100_d256 as _sdpa_bwd_wrapper_sm100_d256, - ) - - return _sdpa_bwd_wrapper_sm100_d256 - except Exception as e: - raise ImportError( - f"sdpa_bwd_wrapper_sm100_d256 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - elif name == "GroupedGemmQuantSm100": - try: - from .grouped_gemm import GroupedGemmQuantSm100 as _GroupedGemmQuantSm100 - - return _GroupedGemmQuantSm100 - except Exception as e: - raise ImportError(f"GroupedGemmQuantSm100 requires optional dependencies. 
Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "grouped_gemm_quant_wrapper_sm100": - try: - from .grouped_gemm import ( - grouped_gemm_quant_wrapper_sm100 as _grouped_gemm_quant_wrapper_sm100, - ) - - return _grouped_gemm_quant_wrapper_sm100 - except Exception as e: - raise ImportError( - f"grouped_gemm_quant_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - # Unified Grouped GEMM GLU (forward) - elif name == "GroupedGemmGluSm100": - try: - from .grouped_gemm import GroupedGemmGluSm100 as _GroupedGemmGluSm100 - - return _GroupedGemmGluSm100 - except Exception as e: - raise ImportError(f"GroupedGemmGluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "grouped_gemm_glu_wrapper_sm100": - try: - from .grouped_gemm import ( - grouped_gemm_glu_wrapper_sm100 as _grouped_gemm_glu_wrapper_sm100, - ) - - return _grouped_gemm_glu_wrapper_sm100 - except Exception as e: - raise ImportError( - f"grouped_gemm_glu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - # Unified Grouped GEMM dGLU (backward) - elif name == "GroupedGemmDgluSm100": - try: - from .grouped_gemm import GroupedGemmDgluSm100 as _GroupedGemmDgluSm100 - - return _GroupedGemmDgluSm100 - except Exception as e: - raise ImportError(f"GroupedGemmDgluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "grouped_gemm_dglu_wrapper_sm100": - try: - from .grouped_gemm import ( - grouped_gemm_dglu_wrapper_sm100 as _grouped_gemm_dglu_wrapper_sm100, - ) - return _grouped_gemm_dglu_wrapper_sm100 - except Exception as e: - raise ImportError( - f"grouped_gemm_dglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - elif name == "GroupedGemmWgradSm100": - try: - from .grouped_gemm import GroupedGemmWgradSm100 as _GroupedGemmWgradSm100 - - return _GroupedGemmWgradSm100 - except Exception as e: - raise ImportError(f"GroupedGemmWgradSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "grouped_gemm_wgrad_wrapper_sm100": - try: - from .grouped_gemm import ( - grouped_gemm_wgrad_wrapper_sm100 as _grouped_gemm_wgrad_wrapper_sm100, - ) - - return _grouped_gemm_wgrad_wrapper_sm100 - except Exception as e: - raise ImportError( - f"grouped_gemm_wgrad_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - # Discrete-weight Grouped GEMM GLU module - elif name == "discrete_grouped_gemm": - try: - from . import discrete_grouped_gemm as _discrete_grouped_gemm - - return _discrete_grouped_gemm - except Exception as e: - raise ImportError(f"discrete_grouped_gemm requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e - - elif name == "DiscreteGroupedGemmSwigluSm100": - try: - from .discrete_grouped_gemm import ( - DiscreteGroupedGemmSwigluSm100 as _DiscreteGroupedGemmSwigluSm100, - ) - - return _DiscreteGroupedGemmSwigluSm100 - except Exception as e: - raise ImportError( - f"DiscreteGroupedGemmSwigluSm100 requires optional dependencies. 
Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - elif name == "discrete_grouped_gemm_swiglu_wrapper_sm100": - try: - from .discrete_grouped_gemm import ( - discrete_grouped_gemm_swiglu_wrapper_sm100 as _discrete_grouped_gemm_swiglu_wrapper_sm100, - ) - - return _discrete_grouped_gemm_swiglu_wrapper_sm100 - except Exception as e: - raise ImportError( - f"discrete_grouped_gemm_swiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - elif name == "DiscreteGroupedGemmDswigluSm100": - try: - from .discrete_grouped_gemm import ( - DiscreteGroupedGemmDswigluSm100 as _DiscreteGroupedGemmDswigluSm100, - ) - - return _DiscreteGroupedGemmDswigluSm100 - except Exception as e: - raise ImportError( - f"DiscreteGroupedGemmDswigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - elif name == "discrete_grouped_gemm_dswiglu_wrapper_sm100": - try: - from .discrete_grouped_gemm import ( - discrete_grouped_gemm_dswiglu_wrapper_sm100 as _discrete_grouped_gemm_dswiglu_wrapper_sm100, - ) - - return _discrete_grouped_gemm_dswiglu_wrapper_sm100 - except Exception as e: - raise ImportError( - f"discrete_grouped_gemm_dswiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e - - elif name == "experimental": +def __getattr__(name: str) -> Any: + if name == "ops": + # Use importlib rather than "from . import ops" to avoid infinite + # recursion. The cycle: + # 1. cudnn.ops accessed → __getattr__("ops") fires + # 2. "from . import ops" → _handle_fromlist(cudnn, ["ops"], ...) + # 3. _handle_fromlist calls hasattr(cudnn, "ops") + # 4. "ops" not in __dict__ yet → __getattr__("ops") again → goto 1 + # importlib.import_module bypasses _handle_fromlist entirely. + _ops = importlib.import_module(".ops", __name__) + globals()["ops"] = _ops + return _ops + + if name == "experimental": from . 
import experimental as _experimental globals()["experimental"] = _experimental return _experimental - else: - raise AttributeError(name) + + if name in _LAZY_OPTIONAL_IMPORTS: + return _load_optional_symbol(name) + + raise AttributeError(name) diff --git a/python/cudnn/api_base.py b/python/cudnn/api_base.py index 82fb161f..e60a52e1 100644 --- a/python/cudnn/api_base.py +++ b/python/cudnn/api_base.py @@ -14,6 +14,7 @@ from dataclasses import dataclass, field from typing import Any, List, Tuple, Optional import logging +import threading import cuda.bindings.driver as cuda import cutlass import torch @@ -31,6 +32,26 @@ def is_power_of_2(n: int) -> bool: return n > 0 and (n & (n - 1)) == 0 +_experimental_api_warnings_emitted = set() +_experimental_api_warnings_lock = threading.Lock() + + +def warn_experimental_api_once(logger: logging.Logger, api_name: str) -> None: + """Emit the experimental API warning once per API class per process.""" + with _experimental_api_warnings_lock: + if api_name in _experimental_api_warnings_emitted: + return + _experimental_api_warnings_emitted.add(api_name) + + logger.warning("%s is an experimental API", api_name) + + +def _reset_experimental_api_warning_registry() -> None: + """Reset experimental API warning state for tests.""" + with _experimental_api_warnings_lock: + _experimental_api_warnings_emitted.clear() + + @dataclass(frozen=True) class TensorDesc: """Metadata needed to validate/compile tensor signatures without storage.""" @@ -40,6 +61,7 @@ class TensorDesc: stride: Tuple[int, ...] stride_order: Tuple[int, ...] device: torch.device + interpret_uint8_as_fp4x2: bool = False ndim: int = field(init=False) name: str = "" @@ -156,6 +178,7 @@ def _with_layout(self, shape: Tuple[int, ...], stride: Tuple[int, ...]) -> "Tens stride=stride, stride_order=self._compute_stride_order(shape, stride), device=self.device, + interpret_uint8_as_fp4x2=self.interpret_uint8_as_fp4x2, name=self.name, ) @@ -363,6 +386,9 @@ def __init__(self): self._interpret_uint8_as_fp4x2 = False self._logger = logging.getLogger(self.__class__.__name__) + def _warn_experimental_api(self) -> None: + warn_experimental_api_once(self._logger, self.__class__.__name__) + @abstractmethod def check_support(self) -> bool: """Check if the current configuration is supported by the kernel. @@ -571,8 +597,16 @@ def _is_fp4x2(self, tensor_or_dtype: torch.Tensor | torch.dtype | TensorDesc) -> """ if tensor_or_dtype is None: return False - dtype = tensor_or_dtype.dtype if isinstance(tensor_or_dtype, (torch.Tensor, TensorDesc)) else tensor_or_dtype - return (dtype == torch.float4_e2m1fn_x2) or (self._interpret_uint8_as_fp4x2 and dtype == torch.uint8) + if isinstance(tensor_or_dtype, TensorDesc): + dtype = tensor_or_dtype.dtype + interpret_uint8_as_fp4x2 = tensor_or_dtype.interpret_uint8_as_fp4x2 + elif isinstance(tensor_or_dtype, torch.Tensor): + dtype = tensor_or_dtype.dtype + interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2 + else: + dtype = tensor_or_dtype + interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2 + return (dtype == torch.float4_e2m1fn_x2) or (interpret_uint8_as_fp4x2 and dtype == torch.uint8) def _is_fp8(self, tensor_or_dtype: torch.Tensor | torch.dtype | TensorDesc) -> bool: """Check if tensor or dtype is an FP8 datatype. 
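The table-driven loader above replaces the long chain of hand-written `__getattr__` branches removed in this hunk; from the caller's side the behavior is unchanged. A minimal sketch of the new resolution path, assuming the optional CuTe-DSL extras may or may not be installed:

```python
import cudnn

try:
    # First access falls through to cudnn.__getattr__, which looks the name up
    # in _LAZY_OPTIONAL_IMPORTS, imports ".gemm_dsrelu", and caches the symbol
    # in globals() so later accesses never hit __getattr__ again.
    kernel_cls = cudnn.GemmDsreluSm100
except ImportError as err:
    # Raised by _load_optional_symbol when the extras are missing; the message
    # carries the "pip install nvidia-cudnn-frontend[cutedsl]" hint.
    print(err)
```

Because resolved names are written back into `globals()`, the optional dependency graph is only imported for symbols a program actually touches.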
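The api_base.py hunk beginning here also introduces `warn_experimental_api_once`, which the per-kernel `__init__` methods below switch to, so constructing many kernel objects warns once per API class per process instead of once per instance. A small behavioral sketch:

```python
import logging
from cudnn.api_base import warn_experimental_api_once

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("GemmDsreluSm100")

warn_experimental_api_once(log, "GemmDsreluSm100")  # emits the warning
warn_experimental_api_once(log, "GemmDsreluSm100")  # deduplicated: the lock-guarded
                                                    # registry already holds this name
```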
@@ -875,12 +909,24 @@ def _make_fake_cute_tensor_like( assumed_align=assumed_align, ) - def _make_tensor_desc(self, tensor: Optional[torch.Tensor], name: str = "") -> Optional[TensorDesc]: + def _make_tensor_desc( + self, + tensor: Optional[torch.Tensor], + name: str = "", + interpret_uint8_as_fp4x2: Optional[bool] = None, + ) -> Optional[TensorDesc]: """Capture logical tensor metadata that is sufficient for validation/compile.""" if tensor is None: return None - tensor_shape = self._tensor_shape(tensor, name=name) - tensor_stride = self._tensor_stride(tensor, name=name) + if interpret_uint8_as_fp4x2 is None: + interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2 + prev_interpret = self._interpret_uint8_as_fp4x2 + self._interpret_uint8_as_fp4x2 = interpret_uint8_as_fp4x2 + try: + tensor_shape = self._tensor_shape(tensor, name=name) + tensor_stride = self._tensor_stride(tensor, name=name) + finally: + self._interpret_uint8_as_fp4x2 = prev_interpret tensor_stride_order = tuple(i for i, s in sorted(enumerate(tensor_stride), key=lambda x: (x[1], tensor_shape[x[0]]))) return TensorDesc( dtype=tensor.dtype, @@ -888,6 +934,7 @@ def _make_tensor_desc(self, tensor: Optional[torch.Tensor], name: str = "") -> O stride=tensor_stride, stride_order=tensor_stride_order, device=tensor.device, + interpret_uint8_as_fp4x2=interpret_uint8_as_fp4x2, name=name, ) @@ -904,6 +951,7 @@ def _make_fake_cute_tensor_from_desc( shape=tensor_desc.shape, stride=tensor_desc.stride, assumed_align=assumed_align, + interpret_uint8_as_fp4x2=tensor_desc.interpret_uint8_as_fp4x2, ) def _make_fake_cute_tensor( @@ -912,6 +960,7 @@ def _make_fake_cute_tensor( shape: Tuple[int, ...], stride: Tuple[int, ...], assumed_align: int = 16, + interpret_uint8_as_fp4x2: Optional[bool] = None, ) -> cute.Pointer: """Make a fake tensor. @@ -926,8 +975,10 @@ def _make_fake_cute_tensor( :return: A fake tensor :rtype: cute.Pointer """ + if interpret_uint8_as_fp4x2 is None: + interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2 return cute.runtime.make_fake_tensor( - dtype=_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2), + dtype=_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=interpret_uint8_as_fp4x2), shape=shape, stride=stride, assumed_align=assumed_align, @@ -941,6 +992,7 @@ def _make_fake_cute_compact_tensor( assumed_align: int = 16, dynamic_mode: Optional[int] = None, divisibility: int = 16, + interpret_uint8_as_fp4x2: Optional[bool] = None, ) -> cute.Pointer: """Make a fake compact tensor. 
:param dtype: The dtype of the tensor @@ -958,8 +1010,10 @@ def _make_fake_cute_compact_tensor( dynamic_dim = cute.sym_int(divisibility=divisibility) shape = shape[:dynamic_mode] + (dynamic_dim,) + shape[dynamic_mode + 1 :] + if interpret_uint8_as_fp4x2 is None: + interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2 return cute.runtime.make_fake_compact_tensor( - dtype=_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2), + dtype=_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=interpret_uint8_as_fp4x2), shape=shape, stride_order=stride_order, assumed_align=assumed_align, diff --git a/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py b/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py index 4c110e17..dc52c846 100644 --- a/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py +++ b/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py @@ -133,7 +133,7 @@ def __init__( """ super().__init__() - self._logger.warning("DiscreteGroupedGemmDswigluSm100 is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") self._value_error_if(num_experts == 0, "num_experts must be > 0") @@ -196,7 +196,7 @@ def __init__( self._kernel = BlockScaledDiscreteWeightDgluDbiasGroupedGemmKernel self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) - print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") self._workspace = None self._logger.debug("__init__ completed") diff --git a/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py b/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py index d3e736ca..a3d3c0fd 100644 --- a/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py +++ b/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py @@ -141,7 +141,7 @@ def __init__( """ super().__init__() - self._logger.warning("DiscreteGroupedGemmSwigluSm100 is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") self._value_error_if(num_experts == 0, "num_experts must be > 0") @@ -194,7 +194,7 @@ def __init__( self._kernel = BlockScaledDiscreteWeightGroupedGemmBiasKernel self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) - print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") self._workspace = None diff --git a/python/cudnn/gemm_amax/api.py b/python/cudnn/gemm_amax/api.py index eea7d816..62fc475b 100644 --- a/python/cudnn/gemm_amax/api.py +++ b/python/cudnn/gemm_amax/api.py @@ -31,7 +31,7 @@ def __init__( ): super().__init__() - self._logger.warning("GemmAmaxSm100 is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") self.a_desc = self._make_tensor_desc(sample_a, name="sample_a") @@ -47,7 +47,7 @@ def __init__( self.sf_vec_size = sf_vec_size self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) - print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") # used to reshape sfa/sfb tensors to atom layout self.atom_m = (32, 4) diff --git 
a/python/cudnn/gemm_dsrelu/__init__.py b/python/cudnn/gemm_dsrelu/__init__.py new file mode 100644 index 00000000..ade25458 --- /dev/null +++ b/python/cudnn/gemm_dsrelu/__init__.py @@ -0,0 +1,9 @@ +from .api import ( + GemmDsreluSm100, + gemm_dsrelu_wrapper_sm100, +) + +__all__ = [ + "GemmDsreluSm100", + "gemm_dsrelu_wrapper_sm100", +] diff --git a/python/cudnn/gemm_dsrelu/api.py b/python/cudnn/gemm_dsrelu/api.py new file mode 100644 index 00000000..fb02fa02 --- /dev/null +++ b/python/cudnn/gemm_dsrelu/api.py @@ -0,0 +1,612 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import logging +import os +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +import torch +from cuda.bindings import driver as cuda +from cutlass.cute.runtime import make_fake_stream + +from cudnn.api_base import APIBase, TupleDict, ceil_div, is_power_of_2 +from cudnn.datatypes import _convert_to_cutlass_data_type + +from .dense_blockscaled_gemm_persistent_dsrelu_quant import ( + Sm100BlockScaledPersistentDenseGemmKernel, +) + + +def _major_from_stride_order(stride_order: Tuple[int, ...], mode0_label: str, mode1_label: str) -> str: + if stride_order == (0, 1, 2): + return mode0_label + if stride_order == (1, 0, 2): + return mode1_label + raise ValueError(f"Unsupported stride order {stride_order}") + + +class GemmDsreluSm100(APIBase): + def __init__( + self, + sample_a: torch.Tensor, + sample_b: torch.Tensor, + sample_c: torch.Tensor, + sample_d: torch.Tensor, + sample_dprob: torch.Tensor, + sample_sfa: torch.Tensor, + sample_sfb: torch.Tensor, + sample_prob: torch.Tensor, + sample_sfd: Optional[torch.Tensor] = None, + sample_amax: Optional[torch.Tensor] = None, + sample_norm_const: Optional[torch.Tensor] = None, + alpha: float = 1.0, + acc_dtype: torch.dtype = torch.float32, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + ): + super().__init__() + + self._warn_experimental_api() + + self.a_desc = self._make_tensor_desc(sample_a, name="sample_a") + self.b_desc = self._make_tensor_desc(sample_b, name="sample_b") + self.c_desc = self._make_tensor_desc(sample_c, name="sample_c") + self.d_desc = self._make_tensor_desc(sample_d, name="sample_d") + self.dprob_desc = self._make_tensor_desc(sample_dprob, name="sample_dprob") + self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa") + self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb") + self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob") + self.sfd_desc = self._make_tensor_desc(sample_sfd, name="sample_sfd") + self.amax_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_amax, name="sample_amax"), 1, "amax") + self.norm_const_desc = self._unpad_tensor_to_ndim( + self._make_tensor_desc(sample_norm_const, name="sample_norm_const"), + 1, + "norm_const", + ) + + self.alpha = alpha + self.acc_dtype = acc_dtype + self.mma_tiler_mn = mma_tiler_mn + self.cluster_shape_mn = cluster_shape_mn if cluster_shape_mn is not None else ((2, 1) if mma_tiler_mn[0] == 256 else (1, 1)) + self.sf_vec_size = sf_vec_size + self.vector_f32 = vector_f32 + self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) + + self._interpret_uint8_as_fp4x2 = True + self._kernel = Sm100BlockScaledPersistentDenseGemmKernel + + def check_support(self) -> bool: + m, k, l = 
self._tensor_shape(self.a_desc, name="sample_a") + n, b_k, b_l = self._tensor_shape(self.b_desc, name="sample_b") + + self._value_error_if((b_k, b_l) != (k, l), f"B shape mismatch: expected (*, {k}, {l}), got {(n, b_k, b_l)}") + self._check_tensor_shape(self.c_desc, (m, n, l), "C") + self._check_tensor_shape(self.d_desc, (m, n, l), "D") + self._check_tensor_shape(self.prob_desc, (m, 1, l), "prob") + self._check_tensor_shape(self.dprob_desc, (m, 1, l), "dprob") + + rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4) + self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(m, 128), 4, rest_k, l), "SFA") + self._check_tensor_shape(self.sfb_desc, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB") + + if self.sfd_desc is not None: + rest_n = ceil_div(ceil_div(n, self.sf_vec_size), 4) + self._check_tensor_shape(self.sfd_desc, (32, 4, ceil_div(m, 128), 4, rest_n, l), "SFD") + + self._check_tensor_shape(self.amax_desc, (1,), "amax") + self._check_tensor_shape(self.norm_const_desc, (1,), "norm_const") + + self.ab_dtype = self._check_dtype( + self.a_desc, + dtype=[torch.float4_e2m1fn_x2, torch.uint8, torch.float8_e4m3fn, torch.float8_e5m2], + name="A", + ) + self._check_dtype(self.b_desc, dtype=self.ab_dtype, name="B", extra_error_msg="A and B must have the same dtype") + self.c_dtype = self._check_dtype( + self.c_desc, + dtype=[torch.float16, torch.bfloat16, torch.float32], + name="C", + ) + self.d_dtype = self._check_dtype( + self.d_desc, + dtype=[torch.float16, torch.bfloat16, torch.float32, torch.float8_e4m3fn, torch.float8_e5m2], + name="D", + ) + self._check_dtype(self.prob_desc, dtype=torch.float32, name="prob") + self._check_dtype(self.dprob_desc, dtype=torch.float32, name="dprob") + + self.sf_dtype = self._check_dtype(self.sfa_desc, dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn], name="SFA") + self._check_dtype(self.sfb_desc, dtype=self.sf_dtype, name="SFB", extra_error_msg="SFB must have the same dtype as SFA") + self._check_dtype(self.sfd_desc, dtype=self.sf_dtype, name="SFD", extra_error_msg="SFD must have the same dtype as SFA") + + self._check_dtype(self.acc_dtype, dtype=torch.float32, name="Accumulator") + + self._value_error_if(self.sf_vec_size not in {16, 32}, f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}") + self._value_error_if( + self._is_fp8(self.d_desc) and (self.sfd_desc is None or self.norm_const_desc is None), "sfd and norm_const are required when D is FP8" + ) + self._value_error_if( + self._is_fp4x2(self.ab_dtype) and self.d_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}, "FP4 input with FP8 output is not supported" + ) + + a_major = _major_from_stride_order(self.a_desc.stride_order, "m", "k") + b_major = _major_from_stride_order(self.b_desc.stride_order, "n", "k") + c_major = _major_from_stride_order(self.c_desc.stride_order, "m", "n") + d_major = _major_from_stride_order(self.d_desc.stride_order, "m", "n") + self._value_error_if(c_major != d_major, f"C and D must share the same layout, got {c_major} and {d_major}") + + self._value_error_if( + self.mma_tiler_mn[0] not in {128, 256} or self.mma_tiler_mn[1] not in {64, 128, 192, 256}, + f"Unsupported mma_tiler_mn {self.mma_tiler_mn}", + ) + self._value_error_if( + not ( + self.cluster_shape_mn[0] > 0 + and self.cluster_shape_mn[1] > 0 + and self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16 + and is_power_of_2(self.cluster_shape_mn[0]) + and is_power_of_2(self.cluster_shape_mn[1]) + ), + f"Invalid cluster shape {self.cluster_shape_mn}", + ) + + self._runtime_error_if(not torch.cuda.is_available(), "CUDA 
is not available") + major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) + self._runtime_error_if(major * 10 + minor < 100, f"GemmDsreluSm100 requires SM100+, found SM{major}{minor}") + + self._value_error_if( + not self._kernel.can_implement( + ab_dtype=_convert_to_cutlass_data_type(self.ab_dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2), + sf_dtype=_convert_to_cutlass_data_type(self.sf_dtype), + sf_vec_size=self.sf_vec_size, + d_dtype=_convert_to_cutlass_data_type(self.d_dtype), + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + m=m, + n=n, + k=k, + l=l, + a_major=a_major, + b_major=b_major, + d_major=d_major, + ), + "Unsupported configuration for dense dsReLU kernel", + ) + + self._is_supported = True + return True + + def compile(self) -> None: + self._ensure_support_checked() + if self._compiled_kernel is not None: + return + + gemm = self._kernel( + sf_vec_size=self.sf_vec_size, + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + vector_f32=self.vector_f32, + ) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) + max_active_clusters -= self.num_cluster_overlap_margin + self._value_error_if(max_active_clusters <= 0, "max_active_clusters must be > 0 after overlap margin") + + fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False) + epilogue_op = lambda x, y: cute.where(x > 0, x, cute.full_like(x, 0)) * 2 * y + use_full_dynamic = os.environ.get(_DENSE_GEMM_DYNAMIC_MNKL_ENV) is not None + use_dynamic_m = not use_full_dynamic and os.environ.get(_DENSE_GEMM_DYNAMIC_M_ENV) is not None + + if use_dynamic_m: + valid_m = cute.sym_int() + + a_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, *self.a_desc.shape[1:]), + stride_order=self.a_desc.stride_order, + ) + b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16) + c_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, *self.c_desc.shape[1:]), + stride_order=self.c_desc.stride_order, + ) + d_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_desc.dtype, + shape=(valid_m, *self.d_desc.shape[1:]), + stride_order=self.d_desc.stride_order, + ) + prob_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, *self.prob_desc.shape[1:]), + stride_order=self.prob_desc.stride_order, + ) + dprob_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.dprob_desc.dtype, + shape=(valid_m, *self.dprob_desc.shape[1:]), + stride_order=self.dprob_desc.stride_order, + ) + + tensor_m_128 = cute.sym_int() + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], self.sfa_desc.shape[5]), + stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128), + ) + sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16) + + sfd_cute_fake = None + if self.sfd_desc is not None: + stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfd_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfd_desc.shape[4], self.sfd_desc.shape[5]), + stride=(16, 4, self.sfd_desc.stride[2], 1, 512, stride_sfd_m), + ) + elif use_full_dynamic: + valid_m = cute.sym_int() + n_sym = 
cute.sym_int() + k_sym = cute.sym_int() + l_sym = cute.sym_int() + + a_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, k_sym, l_sym), + stride_order=self.a_desc.stride_order, + dynamic_mode=self.a_desc.stride_order[0], + divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, + ) + b_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.b_desc.dtype, + shape=(n_sym, k_sym, l_sym), + stride_order=self.b_desc.stride_order, + dynamic_mode=self.b_desc.stride_order[0], + divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, + ) + c_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, n_sym, l_sym), + stride_order=self.c_desc.stride_order, + dynamic_mode=self.c_desc.stride_order[0], + divisibility=8 if self._is_f16(self.c_desc.dtype) else 16, + ) + d_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_desc.dtype, + shape=(valid_m, n_sym, l_sym), + stride_order=self.d_desc.stride_order, + dynamic_mode=self.d_desc.stride_order[0], + divisibility=8 if self._is_f16(self.d_desc.dtype) else 16, + ) + prob_cute_fake = self._make_fake_cute_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, *self.prob_desc.shape[1:-1], l_sym), + stride=(1, 1, valid_m), + ) + dprob_cute_fake = self._make_fake_cute_tensor( + dtype=self.dprob_desc.dtype, + shape=(valid_m, *self.dprob_desc.shape[1:-1], l_sym), + stride=(l_sym, l_sym, 1), + ) + + tensor_m_128 = cute.sym_int() + rest_k = cute.sym_int() + stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4) + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_shape = list(self.sfa_desc.shape) + sfa_shape[2] = tensor_m_128 + sfa_shape[4] = rest_k + sfa_shape[5] = l_sym + sfa_stride = list(self.sfa_desc.stride) + sfa_stride[2] = stride_rest_k + sfa_stride[5] = stride_tensor_m_128 + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=tuple(sfa_shape), + stride=tuple(sfa_stride), + ) + + tensor_n_128 = cute.sym_int() + stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfb_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfb_desc.dtype, + shape=(32, 4, tensor_n_128, 4, rest_k, l_sym), + stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k), + ) + + sfd_cute_fake = None + if self.sfd_desc is not None: + rest_n = cute.sym_int() + stride_sfd_rest_n = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfd_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfd_shape = list(self.sfd_desc.shape) + sfd_shape[2] = tensor_m_128 + sfd_shape[4] = rest_n + sfd_shape[5] = l_sym + sfd_stride = list(self.sfd_desc.stride) + sfd_stride[2] = stride_sfd_rest_n + sfd_stride[5] = stride_sfd_tensor_m_128 + sfd_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfd_desc.dtype, + shape=tuple(sfd_shape), + stride=tuple(sfd_stride), + ) + else: + a_cute_fake = self._make_fake_cute_tensor_from_desc(self.a_desc, assumed_align=16) + b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16) + c_cute_fake = self._make_fake_cute_tensor_from_desc(self.c_desc, assumed_align=16) + d_cute_fake = self._make_fake_cute_tensor_from_desc(self.d_desc, assumed_align=16) + prob_cute_fake = self._make_fake_cute_tensor_from_desc(self.prob_desc, assumed_align=16) + dprob_cute_fake = self._make_fake_cute_tensor_from_desc(self.dprob_desc, assumed_align=16) + sfa_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfa_desc, assumed_align=16) + 
sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16) + sfd_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfd_desc, assumed_align=16) + + compiled = cute.compile( + gemm, + a_tensor=a_cute_fake, + b_tensor=b_cute_fake, + sfa_tensor=sfa_cute_fake, + sfb_tensor=sfb_cute_fake, + c_tensor=c_cute_fake, + d_tensor=d_cute_fake, + prob_tensor=prob_cute_fake, + dprob_tensor=dprob_cute_fake, + amax_tensor=self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16), + sfd_tensor=sfd_cute_fake, + norm_const_tensor=self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16), + alpha=self.alpha, + max_active_clusters=max_active_clusters, + stream=fake_stream, + epilogue_op=epilogue_op, + options="--enable-tvm-ffi", + ) + + def tensor_api( + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + dprob_tensor: torch.Tensor, + amax_tensor: Optional[torch.Tensor], + sfd_tensor: Optional[torch.Tensor], + norm_const_tensor: Optional[torch.Tensor], + alpha: float, + stream: cuda.CUstream, + ) -> None: + compiled( + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + c_tensor, + d_tensor, + prob_tensor, + dprob_tensor, + self._unpad_tensor_to_ndim(amax_tensor, 1, "amax"), + sfd_tensor, + self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const"), + alpha, + stream, + ) + + self._compiled_kernel = tensor_api + + def execute( + self, + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_tensor: torch.Tensor, + dprob_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + sfd_tensor: Optional[torch.Tensor] = None, + amax_tensor: Optional[torch.Tensor] = None, + norm_const_tensor: Optional[torch.Tensor] = None, + alpha: Optional[float] = None, + current_stream: Optional[cuda.CUstream] = None, + ) -> None: + self._runtime_error_if(self._compiled_kernel is None, "GemmDsreluSm100 kernel not compiled; call compile() first") + current_stream = self._get_default_stream(current_stream) + self._compiled_kernel( + a_tensor=a_tensor, + b_tensor=b_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + prob_tensor=prob_tensor, + dprob_tensor=dprob_tensor, + amax_tensor=amax_tensor, + sfd_tensor=sfd_tensor, + norm_const_tensor=norm_const_tensor, + alpha=self.alpha if alpha is None else alpha, + stream=current_stream, + ) + + +_logger = logging.getLogger(__name__) +_cache_of_GemmDsreluSm100Objects = {} +_DENSE_GEMM_DYNAMIC_M_ENV = "CUDNN_FE_GEMM_DYNAMIC_M" +_DENSE_GEMM_DYNAMIC_MNKL_ENV = "CUDNN_FE_GEMM_DYNAMIC_MNKL" + + +def _allocate_dense_output(shape: Tuple[int, int, int], major: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + m, n, l = shape + if major == "m": + return torch.empty_strided((m, n, l), (1, m, m * n), dtype=dtype, device=device) + if major == "n": + return torch.empty_strided((m, n, l), (n, 1, m * n), dtype=dtype, device=device) + raise ValueError(f"major must be 'm' or 'n', got {major}") + + +def gemm_dsrelu_wrapper_sm100( + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + c_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + alpha: float = 1.0, + d_major: str = "n", + d_dtype: torch.dtype = torch.bfloat16, + acc_dtype: torch.dtype = torch.float32, + mma_tiler_mn: Tuple[int, int] = (256, 256), + 
cluster_shape_mn: Optional[Tuple[int, int]] = None, + norm_const_tensor: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + stream: Optional[cuda.CUstream] = None, +) -> TupleDict: + m, k, l = a_tensor.shape + n, _, _ = b_tensor.shape + + d_tensor = _allocate_dense_output((m, n, l), d_major, d_dtype, a_tensor.device) + dprob_tensor = torch.zeros((m, 1, l), dtype=torch.float32, device=a_tensor.device) + + sfd_tensor = None + if d_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + sf_k = ceil_div(n, sf_vec_size) + mma_shape = ( + l, + ceil_div(m, 128), + ceil_div(sf_k, 4), + 32, + 4, + 4, + ) + sfd_tensor = torch.empty(mma_shape, dtype=sfa_tensor.dtype, device=a_tensor.device).permute(3, 4, 1, 5, 2, 0) + + amax_tensor = None + if a_tensor.dtype in {torch.float4_e2m1fn_x2, torch.uint8} and d_dtype in {torch.bfloat16, torch.float16, torch.float32}: + amax_tensor = torch.full((1,), float("-inf"), dtype=torch.float32, device=a_tensor.device) + + use_full_dynamic = os.environ.get(_DENSE_GEMM_DYNAMIC_MNKL_ENV) is not None + use_dynamic_m = not use_full_dynamic and os.environ.get(_DENSE_GEMM_DYNAMIC_M_ENV) is not None + + def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: + return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1])) + + def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype + + def dynamic_compact_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return tuple(tensor.shape[1:]), stride_order(tensor), tensor.dtype + + def dynamic_tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return None, stride_order(tensor), tensor.dtype + + def dynamic_m_tensor_signature( + tensor: Optional[torch.Tensor], static_shape_suffix: Optional[Tuple[int, ...]], dynamic_stride_dims: Tuple[int, ...] 
= () + ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride())) + return static_shape_suffix, stride_signature, tensor.dtype + + cache_key = ( + use_full_dynamic, + use_dynamic_m, + *(dynamic_tensor_signature(a_tensor) if use_full_dynamic else dynamic_compact_signature(a_tensor) if use_dynamic_m else tensor_signature(a_tensor)), + *(dynamic_tensor_signature(b_tensor) if use_full_dynamic else tensor_signature(b_tensor)), + *(dynamic_tensor_signature(c_tensor) if use_full_dynamic else dynamic_compact_signature(c_tensor) if use_dynamic_m else tensor_signature(c_tensor)), + *(dynamic_tensor_signature(d_tensor) if use_full_dynamic else dynamic_compact_signature(d_tensor) if use_dynamic_m else tensor_signature(d_tensor)), + *( + dynamic_tensor_signature(dprob_tensor) + if use_full_dynamic + else dynamic_compact_signature(dprob_tensor) if use_dynamic_m else tensor_signature(dprob_tensor) + ), + d_dtype, + *( + dynamic_tensor_signature(sfa_tensor) + if use_full_dynamic + else ( + dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], sfa_tensor.shape[5]), dynamic_stride_dims=(5,)) + if use_dynamic_m + else tensor_signature(sfa_tensor) + ) + ), + *(dynamic_tensor_signature(sfb_tensor) if use_full_dynamic else tensor_signature(sfb_tensor)), + *( + dynamic_tensor_signature(prob_tensor) + if use_full_dynamic + else dynamic_compact_signature(prob_tensor) if use_dynamic_m else tensor_signature(prob_tensor) + ), + norm_const_tensor.shape if norm_const_tensor is not None else None, + norm_const_tensor.stride() if norm_const_tensor is not None else None, + norm_const_tensor.dtype if norm_const_tensor is not None else None, + alpha, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + d_major, + sf_vec_size, + vector_f32, + ) + + op = _cache_of_GemmDsreluSm100Objects.get(cache_key) + if op is None: + op = GemmDsreluSm100( + sample_a=a_tensor, + sample_b=b_tensor, + sample_c=c_tensor, + sample_d=d_tensor, + sample_dprob=dprob_tensor, + sample_sfa=sfa_tensor, + sample_sfb=sfb_tensor, + sample_prob=prob_tensor, + sample_sfd=sfd_tensor, + sample_amax=amax_tensor, + sample_norm_const=norm_const_tensor, + alpha=alpha, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + vector_f32=vector_f32, + ) + assert op.check_support(), "Unsupported testcase" + op.compile() + _cache_of_GemmDsreluSm100Objects[cache_key] = op + + op.execute( + a_tensor=a_tensor, + b_tensor=b_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + dprob_tensor=dprob_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + prob_tensor=prob_tensor, + sfd_tensor=sfd_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + alpha=alpha, + current_stream=stream, + ) + + return TupleDict( + d_tensor=d_tensor, + dprob_tensor=dprob_tensor, + amax_tensor=amax_tensor, + sfd_tensor=sfd_tensor, + ) diff --git a/python/cudnn/gemm_dsrelu/dense_blockscaled_gemm_persistent_dsrelu_quant.py b/python/cudnn/gemm_dsrelu/dense_blockscaled_gemm_persistent_dsrelu_quant.py new file mode 100644 index 00000000..5b197014 --- /dev/null +++ b/python/cudnn/gemm_dsrelu/dense_blockscaled_gemm_persistent_dsrelu_quant.py @@ -0,0 +1,2174 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Type, Tuple, Union, Optional + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils + +import cutlass.cute.math as math +from cutlass.cute.typing import Float32 +from cutlass._mlir.dialects import llvm, nvvm +from cutlass._mlir.dialects.nvvm import AtomicOpKind +from cutlass.cutlass_dsl import T + + +def atomic_add_float32( + ptr, + value: Float32, + *, + loc=None, + ip=None, +) -> Float32: + old_value = nvvm.atomicrmw( + AtomicOpKind.FADD, + ptr, + value.ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + return Float32(llvm.bitcast(T.f32(), old_value, loc=loc, ip=ip)) + + +class Sm100BlockScaledPersistentDenseGemmKernel: + """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param sf_vec_size: Scalefactor vector size. 
+    :type sf_vec_size: int
+    :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M, N)
+    :type mma_tiler_mn: Tuple[int, int]
+    :param cluster_shape_mn: Cluster dimensions (M, N) for parallel processing
+    :type cluster_shape_mn: Tuple[int, int]
+
+    :note: In the current version, A and B tensors must have the same data type
+        - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
+
+    :note: Supported combinations of A/B data types, SF data types and SF vector size:
+        - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32
+        - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16
+
+    :note: Supported accumulator data types:
+        - Float32
+
+    :note: Supported C data types:
+        - Float32
+        - Float16/BFloat16
+        - Float8E4M3FN/Float8E5M2
+    # {$nv-internal-release begin}
+    # Note: We don't have SFD generation support in this example for now, so Float4E2M1FN output is only for internal testing and will not be released.
+        - Float4E2M1FN
+    # {$nv-internal-release end}
+
+    :note: Constraints:
+        - MMA tiler M must be 128 or 256 (use_2cta_instrs)
+        - MMA tiler N must be 64/128/192/256
+        - Cluster shape M must be a multiple of 2 if MMA tiler M is 256
+        - Cluster shape M/N must be positive and a power of 2, with total cluster size <= 16
+        - Cluster shape M/N must also be <= 4 for scale factor multicasts, due to the limited size of the scale factors
+
+    Example:
+        >>> gemm = Sm100BlockScaledPersistentDenseGemmKernel(
+        ...     sf_vec_size=16,
+        ...     mma_tiler_mn=(256, 128),
+        ...     cluster_shape_mn=(2, 1),
+        ...     vector_f32=False,
+        ... )
+        >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, d_tensor, prob_tensor, dprob_tensor, amax_tensor, sfd_tensor, norm_const_tensor, alpha, max_active_clusters, stream)
+    """
+
+    def __init__(
+        self,
+        sf_vec_size: int,
+        mma_tiler_mn: Tuple[int, int],
+        cluster_shape_mn: Tuple[int, int],
+        vector_f32: bool,
+    ):
+        """Initializes the configuration for a Blackwell dense GEMM kernel.
+
+        This configuration includes several key aspects:
+
+        1. MMA Instruction Settings (tcgen05):
+            - acc_dtype: Data type for the MMA accumulator, always set to Float32
+            - sf_vec_size: Scale factor A/B vector size.
+            - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
+
+        2. Cluster Shape:
+            - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
+
+        :param sf_vec_size: Scale factor vector size.
+        :type sf_vec_size: int
+        :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
+        :type mma_tiler_mn: Tuple[int, int]
+        :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
+ :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.epilog_load_tma_id = 6 + self.threads_per_cta = 32 * len((self.mma_warp_id, self.tma_warp_id, self.epilog_load_tma_id, *self.epilog_warp_id)) + # Set barrier id for epilogue sync and tmem ptr sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + # Generate sfd output by 2xf32 + self.vector_f32 = vector_f32 + + # Amax reduction configuration + self.num_epilog_warps = len(self.epilog_warp_id) + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + """ + # Compute mma instruction shapes + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mn = ( + self.mma_tiler[0], + self.mma_tiler[1], + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + self.mma_tiler_c = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk_c = ( + self.mma_tiler_c[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_c[1], + self.mma_tiler_c[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 
1)),
+            (tiled_mma.thr_id.shape,),
+        )
+        self.cluster_layout_sfb_vmnk = cute.tiled_divide(
+            cute.make_layout((*self.cluster_shape_mn, 1)),
+            (tiled_mma_sfb.thr_id.shape,),
+        )
+
+        # Compute the number of multicast CTAs for A/B/SFB
+        self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+        self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
+        self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
+        self.is_a_mcast = self.num_mcast_ctas_a > 1
+        self.is_b_mcast = self.num_mcast_ctas_b > 1
+        self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
+
+        # Always use subtile (128, 32)
+        self.epi_tile = (cute.make_layout(128), cute.make_layout(32))
+        self.epi_tile_cnt = (
+            self.cta_tile_shape_mnk[0] // cute.size(self.epi_tile[0]),
+            self.cta_tile_shape_mnk[1] // cute.size(self.epi_tile[1]),
+        )
+        # Set up A/B/C/D stage counts in shared memory and the ACC stage count in tensor memory
+        self.num_acc_stage, self.num_ab_stage, self.num_c_stage, self.num_d_stage = self._compute_stages(
+            tiled_mma,
+            self.mma_tiler,
+            self.a_dtype,
+            self.b_dtype,
+            self.epi_tile,
+            self.c_dtype,
+            self.c_layout,
+            self.sf_dtype,
+            self.sf_vec_size,
+            self.d_dtype,
+            self.d_layout,
+            self.smem_capacity,
+            self.occupancy,
+        )
+
+        # Compute A/B/SFA/SFB/C/D shared memory layouts
+        self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
+            tiled_mma,
+            self.mma_tiler,
+            self.a_dtype,
+            self.num_ab_stage,
+        )
+        self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
+            tiled_mma,
+            self.mma_tiler,
+            self.b_dtype,
+            self.num_ab_stage,
+        )
+        self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
+            tiled_mma,
+            self.mma_tiler,
+            self.sf_vec_size,
+            self.num_ab_stage,
+        )
+        self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
+            tiled_mma,
+            self.mma_tiler,
+            self.sf_vec_size,
+            self.num_ab_stage,
+        )
+        self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+            self.c_dtype,
+            self.c_layout,
+            self.epi_tile,
+            self.num_c_stage,
+        )
+        self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+            self.d_dtype,
+            self.d_layout,
+            self.epi_tile,
+            self.num_d_stage,
+        )
+
+        # Overlap and double buffer accumulator when num_acc_stage == 1 for cta_tile_n = 256 case
+        self.overlapping_accum = False  # self.num_acc_stage == 1
+
+    @cute.jit
+    def __call__(
+        self,
+        a_tensor: cute.Tensor,
+        b_tensor: cute.Tensor,
+        sfa_tensor: cute.Tensor,
+        sfb_tensor: cute.Tensor,
+        c_tensor: cute.Tensor,
+        d_tensor: cute.Tensor,
+        prob_tensor: cute.Tensor,
+        dprob_tensor: cute.Tensor,
+        amax_tensor: Optional[cute.Tensor],
+        sfd_tensor: Optional[cute.Tensor],
+        norm_const_tensor: Optional[cute.Tensor],
+        alpha: cutlass.Float32,
+        max_active_clusters: cutlass.Constexpr,
+        stream: cuda.CUstream,
+        epilogue_op: cutlass.Constexpr = lambda x, y: x,
+    ):
+        """Execute the GEMM operation in steps:
+        - Set up static attributes before smem/grid/tma computation
+        - Set up TMA load/store atoms and tensors
+        - Compute grid size with regard to hardware constraints
+        - Define shared storage for the kernel
+        - Launch the kernel synchronously
+
+        :param a_tensor: Input tensor A
+        :type a_tensor: cute.Tensor
+        :param b_tensor: Input tensor B
+        :type b_tensor: cute.Tensor
+        :param sfa_tensor: Scale factor tensor A
+        :type sfa_tensor: cute.Tensor
+        :param sfb_tensor: Scale factor tensor B
+        :type sfb_tensor: cute.Tensor
+        :param c_tensor: Input tensor C
+        :type c_tensor: cute.Tensor
+        :param d_tensor: Output tensor D
+        :type d_tensor: cute.Tensor
+        :param prob_tensor: Probability tensor
+        :type prob_tensor: cute.Tensor
+        :param dprob_tensor: Derivative of probability tensor
+        :type dprob_tensor: cute.Tensor
+        :param sfd_tensor: Scale factor tensor for output D
+        :type sfd_tensor: cute.Tensor
+        :param norm_const_tensor: Normalization constant tensor for quantization
+        :type norm_const_tensor: cute.Tensor
+        :param amax_tensor: Output tensor for the absolute maximum value
+        :type amax_tensor: cute.Tensor
+        :param max_active_clusters: Maximum number of active clusters
+        :type max_active_clusters: cutlass.Constexpr
+        :param stream: CUDA stream for asynchronous execution
+        :type stream: cuda.CUstream
+        :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
+        :type epilogue_op: cutlass.Constexpr
+        :raises TypeError: If input data types are incompatible with the MMA instruction.
+        """
+        # Set up static attributes before smem/grid/tma computation
+        self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type
+        self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type
+        self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type
+        self.c_dtype: Type[cutlass.Numeric] = c_tensor.element_type
+        self.d_dtype: Type[cutlass.Numeric] = d_tensor.element_type
+        self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode()
+        self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode()
+        self.c_layout = utils.LayoutEnum.from_tensor(c_tensor)
+        self.d_layout = utils.LayoutEnum.from_tensor(d_tensor)
+
+        # Check if input data types are compatible with the MMA instruction
+        if cutlass.const_expr(self.a_dtype != self.b_dtype):
+            raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
+
+        # Set up attributes that depend on the GEMM inputs
+        self._setup_attributes()
+
+        # Set up the SFA/SFB tensors by tiling the A/B shapes to the scale factor atom layout
+        # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL)
+        sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, self.sf_vec_size)
+        sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout)
+
+        # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
+        sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b_tensor.shape, self.sf_vec_size)
+        sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout)
+
+        tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+            self.a_dtype,
+            self.a_major_mode,
+            self.b_major_mode,
+            self.sf_dtype,
+            self.sf_vec_size,
+            self.cta_group,
+            self.mma_inst_shape_mn,
+        )
+
+        # For 2CTA blockscaled kernels, SFB needs to be replicated across peer CTAs.
# {$nv-internal-release} + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_tensor, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa_tensor, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # {$nv-internal-release begin} + # This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF)) + # logical blocks for SFB when cta_tile_shape_n=192. 
+ # {$nv-internal-release end} + if cutlass.const_expr(self.cta_tile_shape_mnk_c[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + + new_shape = ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ) + tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size + + # Setup TMA load for C + epi_c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + self.tma_c_load_bytes = cute.size_in_bytes(self.c_dtype, epi_c_smem_layout) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + c_tensor, + epi_c_smem_layout, + self.epi_tile, + ) + epi_d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0)) + tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d_tensor, + epi_d_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c_tensor, + self.cta_tile_shape_mnk_c, + self.cluster_shape_mn, + max_active_clusters, + ) + + self.buffer_align_bytes = 1024 + + self.generate_sfd = sfd_tensor is not None and norm_const_tensor is not None + if cutlass.const_expr(self.generate_sfd): + sfd_layout = blockscaled_utils.tile_atom_to_shape_SF(c_tensor.shape, self.sf_vec_size) + sfd_tensor = cute.make_tensor(sfd_tensor.iterator, sfd_layout) + + self.generate_amax = amax_tensor is not None + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + c_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage] + c_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + sD: cute.struct.Align[ + cute.struct.MemRange[ + self.d_dtype, + cute.cosize(self.d_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, 
cute.cosize(self.sfa_smem_layout_staged)],
+                self.buffer_align_bytes,
+            ]
+            # (MMA, MMA_N, MMA_K, STAGE)
+            sSFB: cute.struct.Align[
+                cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)],
+                self.buffer_align_bytes,
+            ]
+            # Amax reduction shared memory (one FP32 per epilogue warp)
+            # Aligned like the other buffers for simplicity, even though only num_epilog_warps floats are needed
+            sAmax: cute.struct.Align[
+                cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps],
+                self.buffer_align_bytes,
+            ]
+
+        self.shared_storage = SharedStorage
+
+        # Launch the kernel synchronously
+        self.kernel(
+            tiled_mma,
+            tiled_mma_sfb,
+            tma_atom_a,
+            tma_tensor_a,
+            tma_atom_b,
+            tma_tensor_b,
+            tma_atom_sfa,
+            tma_tensor_sfa,
+            tma_atom_sfb,
+            tma_tensor_sfb,
+            tma_atom_c,
+            tma_tensor_c,
+            tma_atom_d,
+            tma_tensor_d,
+            prob_tensor,
+            dprob_tensor,
+            amax_tensor,
+            sfd_tensor,
+            norm_const_tensor,
+            self.cluster_layout_vmnk,
+            self.cluster_layout_sfb_vmnk,
+            self.a_smem_layout_staged,
+            self.b_smem_layout_staged,
+            self.sfa_smem_layout_staged,
+            self.sfb_smem_layout_staged,
+            self.c_smem_layout_staged,
+            self.d_smem_layout_staged,
+            self.epi_tile,
+            self.tile_sched_params,
+            epilogue_op,
+            alpha,
+        ).launch(
+            grid=grid,
+            block=[self.threads_per_cta, 1, 1],
+            cluster=(*self.cluster_shape_mn, 1),
+            stream=stream,
+        )
+        return
+
+    # GPU device kernel
+    @cute.kernel
+    def kernel(
+        self,
+        tiled_mma: cute.TiledMma,
+        tiled_mma_sfb: cute.TiledMma,
+        tma_atom_a: cute.CopyAtom,
+        mA_mkl: cute.Tensor,
+        tma_atom_b: cute.CopyAtom,
+        mB_nkl: cute.Tensor,
+        tma_atom_sfa: cute.CopyAtom,
+        mSFA_mkl: cute.Tensor,
+        tma_atom_sfb: cute.CopyAtom,
+        mSFB_nkl: cute.Tensor,
+        tma_atom_c: cute.CopyAtom,
+        mC_mnl: cute.Tensor,
+        tma_atom_d: cute.CopyAtom,
+        mD_mnl: cute.Tensor,
+        mProb_mnl: cute.Tensor,
+        mDProb_mnl: cute.Tensor,
+        mAmax_tensor: Optional[cute.Tensor],
+        mSFD_mnl: Optional[cute.Tensor],
+        norm_const_tensor: Optional[cute.Tensor],
+        cluster_layout_vmnk: cute.Layout,
+        cluster_layout_sfb_vmnk: cute.Layout,
+        a_smem_layout_staged: cute.ComposedLayout,
+        b_smem_layout_staged: cute.ComposedLayout,
+        sfa_smem_layout_staged: cute.Layout,
+        sfb_smem_layout_staged: cute.Layout,
+        c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
+        d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
+        epi_tile: cute.Tile,
+        tile_sched_params: utils.PersistentTileSchedulerParams,
+        epilogue_op: cutlass.Constexpr,
+        alpha: cutlass.Float32,
+    ):
+        """
+        GPU device kernel performing the persistent batched GEMM computation.
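+
+        Warp specialization, following the warp ids set in __init__: warps 0-3 run the
+        epilogue, warp 4 issues the MMA instructions, warp 5 performs the mainloop TMA
+        loads of A/B/SFA/SFB, and warp 6 prefetches C tiles into shared memory for the
+        epilogue to consume.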
+ """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + cpasync.prefetch_descriptor(tma_atom_d) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + c_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.c_full_mbar_ptr.data_ptr(), + num_stages=self.num_c_stage, + producer_group=c_producer_group, + consumer_group=c_consumer_group, + tx_count=self.tma_c_load_bytes, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/SFA/SFB/C + # + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, 
swizzle=b_smem_layout_staged.inner)
+        # (MMA, MMA_M, MMA_K, STAGE)
+        sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
+        # (MMA, MMA_N, MMA_K, STAGE)
+        sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
+        # (EPI_TILE_M, EPI_TILE_N, STAGE)
+        sC = storage.sC.get_tensor(
+            c_smem_layout_staged.outer,
+            swizzle=c_smem_layout_staged.inner,
+            dtype=self.c_dtype,
+        )
+        # (EPI_TILE_M, EPI_TILE_N, STAGE)
+        sD = storage.sD.get_tensor(
+            d_smem_layout_staged.outer,
+            swizzle=d_smem_layout_staged.inner,
+            dtype=self.d_dtype,
+        )
+
+        # Shared memory for amax reduction (one FP32 per epilogue warp)
+        # Simple 1D layout. The buffer is always allocated, even when no amax is generated,
+        # since the overhead is minimal and it keeps the code simple.
+        amax_layout = cute.make_layout((self.num_epilog_warps,))
+        sAmax = storage.sAmax.get_tensor(amax_layout)
+
+        #
+        # Compute multicast mask for A/B/SFA/SFB buffer full
+        #
+        a_full_mcast_mask = None
+        b_full_mcast_mask = None
+        sfa_full_mcast_mask = None
+        sfb_full_mcast_mask = None
+        if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
+            a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+            b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1)
+            sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+            sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1)
+
+        #
+        # Local_tile partition global tensors
+        #
+        # (bM, bK, RestM, RestK, RestL)
+        gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+        # (bN, bK, RestN, RestK, RestL)
+        gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
+        # (bM, bK, RestM, RestK, RestL)
+        gSFA_mkl = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+        # (bN, bK, RestN, RestK, RestL)
+        gSFB_nkl = cute.local_tile(
+            mSFB_nkl,
+            cute.slice_(self.mma_tiler_sfb, (0, None, None)),
+            (None, None, None),
+        )
+        # (bM, bN, RestM, RestN, RestL)
+        gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler_c, (None, None, 0)), (None, None, None))
+        # (bM, bN, RestM, RestN, RestL)
+        gD_mnl = cute.local_tile(
+            mD_mnl,
+            cute.slice_(self.mma_tiler_c, (None, None, 0)),
+            (None, None, None),
+        )
+        k_tile_cnt = cute.size(gA_mkl, mode=[3])
+
+        #
+        # Partition global tensor for TiledMMA_A/B/C
+        #
+        thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
+        thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v)
+        # (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
+        tCgA = thr_mma.partition_A(gA_mkl)
+        # (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
+        tCgB = thr_mma.partition_B(gB_nkl)
+        # (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
+        tCgSFA = thr_mma.partition_A(gSFA_mkl)
+        # (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
+        tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
+        # (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
+        tCgC = thr_mma.partition_C(gC_mnl)
+        # (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
+        tCgD = thr_mma.partition_C(gD_mnl)
+
+        #
+        # Partition global/shared tensor for TMA load A/B
+        #
+        # TMA load A partition_S/D
+        a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
+        # ((atom_v, rest_v), STAGE)
+        # ((atom_v, rest_v), RestM, RestK, RestL)
+        tAsA, tAgA = cpasync.tma_partition(
+            tma_atom_a,
+            block_in_cluster_coord_vmnk[2],
a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA load SFA partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMA load SFB partition_S/D + sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + + # Apply SFB slicing hack when cta_tile_shape_n=64 # {$nv-internal-release} + slice_n = mma_tile_coord_mnl[1] + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_mnl[1] // 2 + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait 
for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + # TMA load A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, ab_producer_state.count)], + tAsSFA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfa_full_mcast_mask, + ) + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfb_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # Make accumulator tmem tensor + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # + # Partition for S2T copy of SFA/SFB + # + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) + acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + + while work_tile.is_valid_tile: + # Get 
tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # Apply TMEM pointer offset hack when cta_tile_shape_n=192 or cta_tile_shape_n=64 # {$nv-internal-release} + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + # If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB) + offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + # Move in increments of 64 columns of SFB + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB_mma[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # Async 
arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgD, epi_tile, use_2cta_instrs) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_s2r, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rC, epi_tidx, sC) + tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype) + _, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rD, epi_tidx, sD) + ( + bSG_sD, + bSG_gD_mnl, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_d, tCgD, epi_tile, sD) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + + # Threads/warps participating in tma store pipeline + d_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + d_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_d_stage, + producer_group=d_producer_group, + ) + # Load C pipeline + c_pipeline_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_c_stage) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gD = bSG_gD_mnl[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)] + + # Initialize thread-local amax accumulator for this tile + # Use 0.0 as initial value since we're computing absolute maximum + if cutlass.const_expr(self.generate_amax): + thread_tile_amax = cutlass.Float32(0.0) + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) ## tTR_tAcc.shape: (((32, 32), 1), 1, 1, (1, 8)) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + # + # Get PROB + # Note, it always assumes T2R_M/EPI_M is 1, otherwise it will break the result. 
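+                # For reference, the backward math realized below, assuming the forward
+                # pass computed y = prob * relu(acc)^2 (a squared ReLU scaled by a
+                # per-row probability), with the C tile holding the incoming gradient dY:
+                #   d(acc)  = dY * prob * 2 * relu(acc)   (produced via epilogue_op on c_values * mProb)
+                #   d(prob) = sum_n dY * relu(acc)^2      (accumulated into dProb)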
+ # + mPosition = cur_tile_coord[0] * self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape) + tidx + mProb = mProb_mnl[mPosition, 0, cur_tile_coord[2]] + dProb = cutlass.Float32(0.0) + for subtile_idx in cutlass.range(0, subtile_cnt, 1): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_subtile = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_subtile, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_s2r.retile(tTR_rAcc) + + # + # Apply alpha + # + acc_vec_ = acc_vec.load() + acc_vec.store(acc_vec_ * alpha) + + # + # Load C from shared memory + # + c_pipeline.consumer_wait(c_pipeline_consumer_state) + cute.copy( + tiled_copy_s2r, + tRS_sC[(None, None, None, c_pipeline_consumer_state.index)], + tRS_rC, + ) + cute.arch.fence_proxy("async.shared", space="cta") + c_pipeline.consumer_release(c_pipeline_consumer_state) + c_pipeline_consumer_state.advance() + c_vec = tiled_copy_s2r.retile(tRS_rC) + + # + # Generate dSquared ReLu + # + acc_values = acc_vec.load() + c_values = c_vec.load() + dsquared_relu = epilogue_op(acc_values, c_values * mProb) + + # + # Generate dProb + # + acc_relu = cute.where(acc_values > 0, acc_values, cute.full_like(acc_values, 0)) + dProb += (acc_relu**2 * c_values).reduce( + cute.ReductionOp.ADD, + cutlass.Float32(0.0), + 0, + ) + + # Write dsquared_relu values back for store + acc_vec.store(dsquared_relu) + + # + # Generate amax + # + if cutlass.const_expr(self.generate_amax): + # Apply element-wise absolute value using math.absf (supports vectors) + dsquared_relu_ir = cutlass._mlir.dialects.math.absf(dsquared_relu.ir_value()) # operand (positional) + dsquared_relu_values = type(acc_values)(dsquared_relu_ir, acc_values.shape, acc_values.dtype) + subtile_amax = dsquared_relu_values.reduce( + cute.ReductionOp.MAX, + cutlass.Float32(0.0), + 0, # Use 0.0 as init for abs values + ) + thread_tile_amax = cute.arch.fmax(thread_tile_amax, subtile_amax) + + # + # Generate sfd + # + if cutlass.const_expr(self.generate_sfd): + cute.printf("SFD not implemented\n") + else: + # + # Convert to D type directly + # + acc_vec = tiled_copy_s2r.retile(acc_vec).load() + tRS_rD.store(acc_vec.to(self.d_dtype)) + + # + # Store D to shared memory + # + d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage + cute.copy( + tiled_copy_s2r, + tRS_rD, + tRS_sD[(None, None, None, d_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier.arrive_and_wait() + # + # TMA store D to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_d, + bSG_sD[(None, d_buffer)], + bSG_gD[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + d_pipeline.producer_commit() + d_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # Perform amax reduction after all subtiles are processed + if cutlass.const_expr(self.generate_amax): + # Warp-level reduction using wrapper function + warp_amax = cute.arch.warp_redux_sync( + value=thread_tile_amax, + kind="fmax", + mask_and_clamp=0xFFFFFFFF, + nan=True, + ) + # Each epilogue warp's lane 0 writes warp amax to shared memory + if cute.arch.lane_idx() == 0: + sAmax[warp_idx] = cutlass.Float32(warp_amax) + + # Ensure all epilogue warps complete their writes before block reduction + self.epilog_sync_barrier.arrive_and_wait() + + # Block-level reduction: only first epilogue warp's 
lane 0 handles this + if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0: + block_amax = cutlass.Float32(0.0) # Initial value for absolute maximum + for i in cutlass.range(self.num_epilog_warps): + warp_amax_val = sAmax[i] + block_amax = cute.arch.fmax(block_amax, warp_amax_val) + + # Global atomic max (accumulates across all tiles for final tensor amax) + # Since we compute absolute values, all values are non-negative + # Use wrapper function for atomic max operation + _ = cute.arch.atomic_max_float32(ptr=mAmax_tensor.iterator.llvm_ptr, value=block_amax) + + # write dProb result to global memory + _ = atomic_add_float32( + ptr=mDProb_mnl[(mPosition, None, cur_tile_coord[2])].iterator.llvm_ptr, + value=dProb, + ) + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(acc_tmem_ptr) + # + # Wait for C store complete + # + d_pipeline.producer_tail() + + # + # Specialized epilog load warp + # + if warp_idx == self.epilog_load_tma_id: + ## M 1024, N 512 + ## tCgC (((32,128),1),1,8,4,2,1) : (((1@0,1@1),0),0,32@0,256@1,512@0,1@2) + ## bGS_gC_mnl (((32,128),1),1,8,EPI_M,EPI_N,L) : (((1@0,1@1),0),0,32@0,256@1,512@0,1@2) + ## bGS_sC ((4096, 1), (1, 4)) : ((1, 0), (0, 4096)) + ( + bGS_sC, + bGS_gC_mnl, + ) = self.epilog_gmem_copy_and_partition(tidx, tma_atom_c, tCgC, epi_tile, sC) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + work_tile = tile_sched.initial_work_tile_info() + + c_pipeline_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_c_stage) + is_reverse = True + while work_tile.is_valid_tile: + # if it needs to be reversed + if cutlass.const_expr(self.overlapping_accum): + reverse_subtile = is_reverse + is_reverse = not is_reverse + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + bGS_gC = bGS_gC_mnl[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC)) + subtile_cnt = cute.size(bGS_gC.shape, mode=[1]) + for subtile_idx in cutlass.range(subtile_cnt): + # Check real subtile index + real_subtile_idx = subtile_idx + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + # Subtile always iterates on N dimension as we only have 4x1DP tmem load pattern for cta_tile_m = 128 cases. 
# {$nv-internal-release}
+                            real_subtile_idx = subtile_cnt - 1 - subtile_idx
+                    # Load C from global memory to shared memory using TMALDG
+                    c_pipeline.producer_acquire(c_pipeline_producer_state)
+                    cute.copy(
+                        tma_atom_c,
+                        bGS_gC[(None, real_subtile_idx)],
+                        bGS_sC[(None, c_pipeline_producer_state.index)],
+                        tma_bar_ptr=c_pipeline.producer_get_barrier(c_pipeline_producer_state),
+                    )
+                    c_pipeline_producer_state.advance()
+
+                #
+                # Advance to next tile
+                #
+                tile_sched.advance_to_next_work()
+                work_tile = tile_sched.get_current_work()
+
+            #
+            # Wait C buffer empty
+            #
+            c_pipeline.producer_tail(c_pipeline_producer_state)
+
+    def mainloop_s2t_copy_and_partition(
+        self,
+        sSF: cute.Tensor,
+        tSF: cute.Tensor,
+    ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+        """
+        Make a tiledCopy for the smem-to-tmem load of a scale factor tensor, then use it to partition shared memory (source) and tensor memory (destination).
+
+        :param sSF: The scale factor tensor in smem
+        :type sSF: cute.Tensor
+        :param tSF: The scale factor tensor in tmem
+        :type tSF: cute.Tensor
+
+        :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where:
+            - tiled_copy_s2t: The tiled copy operation for the smem-to-tmem load of scale factors (s2t)
+            - tCsSF_compact_s2t: The partitioned scale factor tensor in smem
+            - tCtSF_compact_s2t: The partitioned scale factor tensor in tmem
+        :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
+        """
+        # (MMA, MMA_MN, MMA_K, STAGE)
+        tCsSF_compact = cute.filter_zeros(sSF)
+        # (MMA, MMA_MN, MMA_K)
+        tCtSF_compact = cute.filter_zeros(tSF)
+
+        # Make S2T CopyAtom and tiledCopy
+        copy_atom_s2t = cute.make_copy_atom(
+            tcgen05.Cp4x32x128bOp(self.cta_group),
+            self.sf_dtype,
+        )
+        tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
+        thr_copy_s2t = tiled_copy_s2t.get_slice(0)
+
+        # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
+        tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
+        # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
+        tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_)
+        # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K)
+        tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
+
+        return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
+
+    def epilog_tmem_copy_and_partition(
+        self,
+        tidx: cutlass.Int32,
+        tAcc: cute.Tensor,
+        gD_mnl: cute.Tensor,
+        epi_tile: cute.Tile,
+        use_2cta_instrs: Union[cutlass.Boolean, bool],
+    ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+        """
+        Make a tiledCopy for the tensor memory load, then use it to partition tensor memory (source) and register array (destination).
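+
+        The accumulator lives in tensor memory (tmem), so the epilogue first builds a
+        tmem-to-register (t2r) tiled copy over one epilogue subtile ((128, 32), as set
+        in _setup_attributes) and then partitions both the staged tmem accumulator and
+        a matching register fragment with it.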
+
+        :param tidx: The thread index in epilogue warp groups
+        :type tidx: cutlass.Int32
+        :param tAcc: The accumulator tensor to be copied and partitioned
+        :type tAcc: cute.Tensor
+        :param gD_mnl: The global output tensor D
+        :type gD_mnl: cute.Tensor
+        :param epi_tile: The epilogue tiler
+        :type epi_tile: cute.Tile
+        :param use_2cta_instrs: Whether 2-CTA MMA instructions are used
+        :type use_2cta_instrs: bool
+
+        :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where:
+            - tiled_copy_t2r: The tiled copy operation for tmem to register copy (t2r)
+            - tTR_tAcc: The partitioned accumulator tensor in tensor memory (source)
+            - tTR_rAcc: The partitioned accumulator fragment in registers (destination)
+        :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
+        """
+        # Make tiledCopy for tensor memory load
+        copy_atom_t2r = sm100_utils.get_tmem_load_op(
+            self.cta_tile_shape_mnk,
+            self.c_layout,
+            self.c_dtype,
+            self.acc_dtype,
+            epi_tile,
+            use_2cta_instrs,
+        )
+        # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
+        tAcc_epi = cute.flat_divide(
+            tAcc[((None, None), 0, 0, None)],
+            epi_tile,
+        )
+        # (EPI_TILE_M, EPI_TILE_N)
+        tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
+
+        thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
+        # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
+        tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
+
+        # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
+        gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+        # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
+        tTR_gD = thr_copy_t2r.partition_D(gD_mnl_epi)
+        # (T2R, T2R_M, T2R_N)
+        tTR_rAcc = cute.make_rmem_tensor(tTR_gD[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype)
+        return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
+
+    def epilog_smem_copy_and_partition(
+        self,
+        tiled_copy_t2r: cute.TiledCopy,
+        tTR_rC: cute.Tensor,
+        tidx: cutlass.Int32,
+        sC: cute.Tensor,
+    ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+        """
+        Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
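+
+        Example (illustrative only; assumes tiled_copy_t2r and tTR_rC from the t2r partitioning step):
+
+        >>> tiled_copy_s2r, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
+        ...     tiled_copy_t2r, tTR_rC, tidx, sC
+        ... )
+        >>> cute.copy(tiled_copy_s2r, tRS_rC, tRS_sC[(None, None, None, 0)])  # registers -> smem stage 0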
+
+        :param tiled_copy_t2r: The tiled copy operation for tmem to register copy (t2r)
+        :type tiled_copy_t2r: cute.TiledCopy
+        :param tTR_rC: The partitioned accumulator tensor
+        :type tTR_rC: cute.Tensor
+        :param tidx: The thread index in epilogue warp groups
+        :type tidx: cutlass.Int32
+        :param sC: The shared memory tensor to be copied and partitioned
+        :type sC: cute.Tensor
+
+        :return: A tuple containing (tiled_copy_s2r, tRS_rC, tRS_sC) where:
+            - tiled_copy_s2r: The tiled copy operation for register to smem store (r2s)
+            - tRS_rC: The partitioned tensor C (register source)
+            - tRS_sC: The partitioned tensor C (smem destination)
+        :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
+        """
+        copy_atom_r2s = sm100_utils.get_smem_store_op(self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r)
+        tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
+        thr_copy_r2s = tiled_copy_s2r.get_slice(tidx)
+        # (R2S, R2S_M, R2S_N, PIPE_D)
+        tRS_sC = thr_copy_r2s.partition_D(sC)
+        # (R2S, R2S_M, R2S_N)
+        tRS_rC = tiled_copy_s2r.retile(tTR_rC)
+        return tiled_copy_s2r, tRS_rC, tRS_sC
+
+    def epilog_gmem_copy_and_partition(
+        self,
+        tidx: cutlass.Int32,
+        atom: Union[cute.CopyAtom, cute.TiledCopy],
+        gC_mnl: cute.Tensor,
+        epi_tile: cute.Tile,
+        sC: cute.Tensor,
+    ) -> Tuple[cute.Tensor, cute.Tensor]:
+        """Make tiledCopy for global memory store, then use it to
+        partition shared memory (source) and global memory (destination) for the TMA store version.
+
+        :param tidx: The thread index in epilogue warp groups
+        :type tidx: cutlass.Int32
+        :param atom: The copy_atom_c to be used for the TMA store version, or tiled_copy_t2r for the non-TMA store version
+        :type atom: cute.CopyAtom or cute.TiledCopy
+        :param gC_mnl: The global tensor C
+        :type gC_mnl: cute.Tensor
+        :param epi_tile: The epilogue tiler
+        :type epi_tile: cute.Tile
+        :param sC: The shared memory tensor to be copied and partitioned
+        :type sC: cute.Tensor
+
+        :return: A tuple containing (bSG_sC, bSG_gC) where:
+            - bSG_sC: The partitioned shared memory tensor C
+            - bSG_gC: The partitioned global tensor C
+        :rtype: Tuple[cute.Tensor, cute.Tensor]
+        """
+        # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
+        gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+        # bSG_sC: ((ATOM_V, REST_V), EPI_M, EPI_N)
+        # bSG_gC: ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
+        bSG_sC, bSG_gC = cpasync.tma_partition(
+            atom,
+            0,
+            cute.make_layout(1),
+            cute.group_modes(sC, 0, 2),
+            cute.group_modes(gC_epi, 0, 2),
+        )
+        return bSG_sC, bSG_gC
+
+    @staticmethod
+    def _compute_stages(
+        tiled_mma: cute.TiledMma,
+        mma_tiler_mnk: Tuple[int, int, int],
+        a_dtype: Type[cutlass.Numeric],
+        b_dtype: Type[cutlass.Numeric],
+        epi_tile: cute.Tile,
+        c_dtype: Type[cutlass.Numeric],
+        c_layout: utils.LayoutEnum,
+        sf_dtype: Type[cutlass.Numeric],
+        sf_vec_size: int,
+        d_dtype: Type[cutlass.Numeric],
+        d_layout: utils.LayoutEnum,
+        smem_capacity: int,
+        occupancy: int,
+    ) -> Tuple[int, int, int, int]:
+        """Computes the number of stages for A/B/C/D operands based on heuristics.
+
+        :param tiled_mma: The tiled MMA object defining the core computation.
+        :type tiled_mma: cute.TiledMma
+        :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
+        :type mma_tiler_mnk: tuple[int, int, int]
+        :param a_dtype: Data type of operand A.
+        :type a_dtype: type[cutlass.Numeric]
+        :param b_dtype: Data type of operand B.
+        :type b_dtype: type[cutlass.Numeric]
+        :param epi_tile: The epilogue tile shape.
+        :type epi_tile: cute.Tile
+        :param c_dtype: Data type of operand C (output).
+        :type c_dtype: type[cutlass.Numeric]
+        :param c_layout: Layout enum of operand C.
+        :type c_layout: utils.LayoutEnum
+        :param sf_dtype: Data type of the scale factor.
+        :type sf_dtype: type[cutlass.Numeric]
+        :param sf_vec_size: Scale factor vector size.
+        :type sf_vec_size: int
+        :param d_dtype: Data type of operand D.
+        :type d_dtype: type[cutlass.Numeric]
+        :param d_layout: Layout enum of operand D.
+        :type d_layout: utils.LayoutEnum
+        :param smem_capacity: Total available shared memory capacity in bytes.
+        :type smem_capacity: int
+        :param occupancy: Target number of CTAs per SM (occupancy).
+        :type occupancy: int
+
+        :return: A tuple containing the computed number of stages for:
+            (ACC stages, A/B operand stages, C stages, D stages)
+        :rtype: tuple[int, int, int, int]
+        """
+        # ACC stages
+        num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
+
+        # Default C/D stages
+        num_c_stage = 2
+        num_d_stage = num_c_stage
+
+        # Calculate smem layout and size for one stage of A, B, SFA, SFB, C and D
+        a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(
+            tiled_mma,
+            mma_tiler_mnk,
+            a_dtype,
+            1,  # temporary 1-stage layout, used only to size one stage
+        )
+        b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
+            tiled_mma,
+            mma_tiler_mnk,
+            b_dtype,
+            1,  # temporary 1-stage layout, used only to size one stage
+        )
+        sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(
+            tiled_mma,
+            mma_tiler_mnk,
+            sf_vec_size,
+            1,  # temporary 1-stage layout, used only to size one stage
+        )
+        sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(
+            tiled_mma,
+            mma_tiler_mnk,
+            sf_vec_size,
+            1,  # temporary 1-stage layout, used only to size one stage
+        )
+
+        c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
+            c_dtype,
+            c_layout,
+            epi_tile,
+            1,
+        )
+
+        d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
+            d_dtype,
+            d_layout,
+            epi_tile,
+            1,
+        )
+
+        ab_bytes_per_stage = (
+            cute.size_in_bytes(a_dtype, a_smem_layout_stage_one)
+            + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
+            + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one)
+            + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
+        )
+        mbar_helpers_bytes = 1024
+        c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
+        d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
+        amax_bytes = Sm100BlockScaledPersistentDenseGemmKernel.get_amax_smem_size()
+        epi_bytes = c_bytes_per_stage * num_c_stage + d_bytes_per_stage * num_d_stage + amax_bytes
+
+        # Calculate A/B/SFA/SFB stages:
+        # Start with total smem per CTA (capacity / occupancy)
+        # Subtract reserved bytes and initial C/D stages bytes
+        # Divide remaining by bytes needed per A/B/SFA/SFB stage
+        num_ab_stage = (smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)) // ab_bytes_per_stage
+
+        return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage
+
+    @staticmethod
+    def _compute_grid(
+        c: cute.Tensor,
+        cta_tile_shape_mnk: Tuple[int, int, int],
+        cluster_shape_mn: Tuple[int, int],
+        max_active_clusters: cutlass.Constexpr,
+    ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
+        """Use persistent tile scheduler to compute the grid size for the output tensor C.
+
+        :param c: The output tensor C
+        :type c: cute.Tensor
+        :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
+        :type cta_tile_shape_mnk: tuple[int, int, int]
+        :param cluster_shape_mn: Shape of each cluster in M, N dimensions.
+        :type cluster_shape_mn: tuple[int, int]
+        :param max_active_clusters: Maximum number of active clusters.
+        :type max_active_clusters: cutlass.Constexpr
+
+        :return: A tuple containing:
+            - tile_sched_params: Parameters for the persistent tile scheduler.
+            - grid: Grid shape for kernel launch.
+        :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
+        """
+        c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
+        gc = cute.zipped_divide(c, tiler=c_shape)
+        num_ctas_mnl = gc[(0, (None, None, None))].shape
+        cluster_shape_mnl = (*cluster_shape_mn, 1)
+
+        tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl)
+        grid = utils.StaticPersistentTileScheduler.get_grid_shape(tile_sched_params, max_active_clusters)
+
+        return tile_sched_params, grid
+
+    @staticmethod
+    def get_dtype_rcp_limits(dtype: Type[cutlass.Numeric]) -> float:
+        """
+        Calculates the reciprocal of the maximum absolute value for a given data type.
+
+        :param dtype: Data type
+        :type dtype: Type[cutlass.Numeric]
+
+        :return: A float representing the reciprocal of the maximum absolute value
+        :rtype: float
+        """
+        if dtype == cutlass.Float4E2M1FN:
+            return 1 / 6.0
+        if dtype == cutlass.Float8E4M3FN:
+            return 1 / 448.0
+        if dtype == cutlass.Float8E5M2:
+            return 1 / 128.0
+        return 1.0
+
+    @staticmethod
+    def is_valid_dtypes_and_scale_factor_vec_size(
+        ab_dtype: Type[cutlass.Numeric],
+        sf_dtype: Type[cutlass.Numeric],
+        sf_vec_size: int,
+        d_dtype: Type[cutlass.Numeric],
+    ) -> bool:
+        """
+        Check if the dtypes and sf_vec_size are valid combinations
+
+        :param ab_dtype: The data type of the A and B operands
+        :type ab_dtype: Type[cutlass.Numeric]
+        :param sf_dtype: The data type of the scale factor
+        :type sf_dtype: Type[cutlass.Numeric]
+        :param sf_vec_size: The vector size of the scale factor
+        :type sf_vec_size: int
+        :param d_dtype: The data type of the output tensor
+        :type d_dtype: Type[cutlass.Numeric]
+
+        :return: True if the dtypes and sf_vec_size are valid, False otherwise
+        :rtype: bool
+        """
+        is_valid = True
+
+        # Check valid ab_dtype
+        if ab_dtype not in {
+            cutlass.Float4E2M1FN,
+            cutlass.Float8E5M2,
+            cutlass.Float8E4M3FN,
+        }:
+            is_valid = False
+
+        if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
+            # Check valid sf_vec_size
+            if sf_vec_size not in {16, 32}:
+                is_valid = False
+            # Check valid sf_dtype
+            if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}:
+                is_valid = False
+            # Check valid sf_dtype and sf_vec_size combinations
+            if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32:
+                is_valid = False
+            if sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 16:
+                is_valid = False
+
+        # Check valid d_dtype
+        if d_dtype not in {
+            cutlass.Float32,
+            cutlass.Float16,
+            cutlass.BFloat16,
+            cutlass.Float8E5M2,
+            cutlass.Float8E4M3FN,
+            cutlass.Float4E2M1FN,  # {$nv-internal-release}
+        }:
+            is_valid = False
+
+        return is_valid
+
+    @staticmethod
+    def is_valid_layouts(
+        ab_dtype: Type[cutlass.Numeric],
+        c_dtype: Type[cutlass.Numeric],
+        a_major: str,
+        b_major: str,
+        c_major: str,
+    ) -> bool:
+        """
+        Check if layouts and dtypes are valid combinations
+
+        :param ab_dtype: The data type of the A and B operands
+        :type ab_dtype: Type[cutlass.Numeric]
+        :param c_dtype: The data type of the output tensor
+        :type c_dtype: Type[cutlass.Numeric]
+        :param a_major: The major dimension of the A tensor
+        :type a_major: str
+        :param b_major: The major dimension of the B tensor
+        :type b_major: str
+        :param c_major: The major dimension of the C tensor
+        :type c_major: str
+
+        :return: True if the layouts are valid, False otherwise
+        :rtype: bool
+        """
+        is_valid = True
+
+        # {$nv-internal-release begin}
+        if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
+            is_valid = False
+        # TODO: Currently we don't support m-major output for Float4E2M1FN
+        if c_dtype is cutlass.Float4E2M1FN and c_major == "m":
+            is_valid = False
+        # {$nv-internal-release end}
+
+        return is_valid
+
+    @staticmethod
+    def is_valid_mma_tiler_and_cluster_shape(
+        mma_tiler_mn: Tuple[int, int],
+        cluster_shape_mn: Tuple[int, int],
+    ) -> bool:
+        """
+        Check if the mma tiler and cluster shape are valid
+
+        :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
+        :type mma_tiler_mn: Tuple[int, int]
+        :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
+        :type cluster_shape_mn: Tuple[int, int]
+
+        :return: True if the mma tiler and cluster shape are valid, False otherwise
+        :rtype: bool
+        """
+        is_valid = True
+        # Skip invalid mma tile shape
+        if mma_tiler_mn[0] not in [128, 256]:
+            is_valid = False
+        if mma_tiler_mn[1] not in [64, 128, 192, 256]:
+            is_valid = False
+        # Skip illegal cluster shape
+        if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0:
+            is_valid = False
+        # Skip invalid cluster shape
+        is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
+        if (
+            cluster_shape_mn[0] * cluster_shape_mn[1] > 16
+            or cluster_shape_mn[0] <= 0
+            or cluster_shape_mn[1] <= 0
+            # Special cluster shape check for scale factor multicasts.
+            # Due to limited size of scale factors, we can't multicast among more than 4 CTAs.
+            or cluster_shape_mn[0] > 4
+            or cluster_shape_mn[1] > 4
+            or not is_power_of_2(cluster_shape_mn[0])
+            or not is_power_of_2(cluster_shape_mn[1])
+        ):
+            is_valid = False
+        return is_valid
+
+    @staticmethod
+    def is_valid_tensor_alignment(
+        m: int,
+        n: int,
+        k: int,
+        l: int,
+        ab_dtype: Type[cutlass.Numeric],
+        c_dtype: Type[cutlass.Numeric],
+        a_major: str,
+        b_major: str,
+        c_major: str,
+    ) -> bool:
+        """
+        Check if the tensor alignment is valid
+
+        :param m: The number of rows in the A tensor
+        :type m: int
+        :param n: The number of columns in the B tensor
+        :type n: int
+        :param k: The number of columns in the A tensor
+        :type k: int
+        :param l: The batch size (number of matrices)
+        :type l: int
+        :param ab_dtype: The data type of the A and B operands
+        :type ab_dtype: Type[cutlass.Numeric]
+        :param c_dtype: The data type of the output tensor
+        :type c_dtype: Type[cutlass.Numeric]
+        :param a_major: The major axis of the A tensor
+        :type a_major: str
+        :param b_major: The major axis of the B tensor
+        :type b_major: str
+        :param c_major: The major axis of the C tensor
+        :type c_major: str
+
+        :return: True if the problem shape is valid, False otherwise
+        :rtype: bool
+        """
+        is_valid = True
+
+        def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape):
+            major_mode_idx = 0 if is_mode0_major else 1
+            num_major_elements = tensor_shape[major_mode_idx]
+            num_contiguous_elements = 16 * 8 // dtype.width
+            return num_major_elements % num_contiguous_elements == 0
+
+        if (
+            not check_contiguous_16B_alignment(ab_dtype, a_major == "m", (m, k, l))
+            or not check_contiguous_16B_alignment(ab_dtype, b_major == "n", (n, k, l))
+            or not check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n, l))
+        ):
+            is_valid = False
+        return is_valid
+
+    @staticmethod
+    def can_implement(
+        ab_dtype: Type[cutlass.Numeric],
+        sf_dtype: Type[cutlass.Numeric],
+        sf_vec_size: int,
+        d_dtype: Type[cutlass.Numeric],
+        mma_tiler_mn: Tuple[int, int],
+        cluster_shape_mn: Tuple[int, int],
+        m: int,
+        n: int,
+        k: int,
+        l: int,
+        a_major: str,
+        b_major: str,
+        d_major: str,
+    ) -> bool:
+        """
+        Check if the gemm can be implemented
+
+        :param ab_dtype: The data type of the A and B operands
+        :type ab_dtype: Type[cutlass.Numeric]
+        :param sf_dtype: The data type of the scale factor tensor
+        :type sf_dtype: Type[cutlass.Numeric]
+        :param sf_vec_size: The scale factor vector size
+        :type sf_vec_size: int
+        :param d_dtype: The data type of the output tensor
+        :type d_dtype: Type[cutlass.Numeric]
+        :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
+        :type mma_tiler_mn: Tuple[int, int]
+        :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
+        :type cluster_shape_mn: Tuple[int, int]
+        :param m: The number of rows in the A tensor
+        :type m: int
+        :param n: The number of columns in the B tensor
+        :type n: int
+        :param k: The number of columns in the A tensor
+        :type k: int
+        :param l: The batch size (number of matrices)
+        :type l: int
+        :param a_major: The major axis of the A tensor
+        :type a_major: str
+        :param b_major: The major axis of the B tensor
+        :type b_major: str
+        :param d_major: The major axis of the D tensor
+        :type d_major: str
+
+        :return: True if the gemm can be implemented, False otherwise
+        :rtype: bool
+        """
+        can_implement = True
+        # Skip unsupported types
+        if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(ab_dtype, sf_dtype, sf_vec_size, d_dtype):
+            can_implement = False
+        # Skip unsupported layouts
+        if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts(ab_dtype, d_dtype, a_major, b_major, d_major):
+            can_implement = False
+        # Skip invalid mma tile shape and cluster shape
+        if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn):
+            can_implement = False
+        # Skip illegal problem shape for load/store alignment
+        if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment(m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major):
+            can_implement = False
+        return can_implement
+
+    @staticmethod
+    def get_amax_smem_size():
+        # Note: 4 is hardcoded for num_epilog_warps
+        return 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,)))
diff --git a/python/cudnn/gemm_srelu/__init__.py b/python/cudnn/gemm_srelu/__init__.py
new file mode 100644
index 00000000..29324f36
--- /dev/null
+++ b/python/cudnn/gemm_srelu/__init__.py
@@ -0,0 +1,9 @@
+from .api import (
+    GemmSreluSm100,
+    gemm_srelu_wrapper_sm100,
+)
+
+__all__ = [
+    "GemmSreluSm100",
+    "gemm_srelu_wrapper_sm100",
+]
diff --git a/python/cudnn/gemm_srelu/api.py b/python/cudnn/gemm_srelu/api.py
new file mode 100644
index 00000000..3bf21f95
--- /dev/null
+++ b/python/cudnn/gemm_srelu/api.py
@@ -0,0 +1,587 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import logging +import os +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +import torch +from cuda.bindings import driver as cuda +from cutlass.cute.runtime import make_fake_stream + +from cudnn.api_base import APIBase, TupleDict, ceil_div, is_power_of_2 +from cudnn.datatypes import _convert_to_cutlass_data_type + +from .dense_blockscaled_gemm_persistent_srelu_quant import ( + Sm100BlockScaledPersistentDenseGemmKernel, +) + + +def _major_from_stride_order(stride_order: Tuple[int, ...], mode0_label: str, mode1_label: str) -> str: + if stride_order == (0, 1, 2): + return mode0_label + if stride_order == (1, 0, 2): + return mode1_label + raise ValueError(f"Unsupported stride order {stride_order}") + + +class GemmSreluSm100(APIBase): + def __init__( + self, + sample_a: torch.Tensor, + sample_b: torch.Tensor, + sample_c: torch.Tensor, + sample_d: torch.Tensor, + sample_sfa: torch.Tensor, + sample_sfb: torch.Tensor, + sample_prob: torch.Tensor, + sample_sfd: Optional[torch.Tensor] = None, + sample_amax: Optional[torch.Tensor] = None, + sample_norm_const: Optional[torch.Tensor] = None, + alpha: float = 1.0, + acc_dtype: torch.dtype = torch.float32, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + ): + super().__init__() + + self._warn_experimental_api() + + self.a_desc = self._make_tensor_desc(sample_a, name="sample_a") + self.b_desc = self._make_tensor_desc(sample_b, name="sample_b") + self.c_desc = self._make_tensor_desc(sample_c, name="sample_c") + self.d_desc = self._make_tensor_desc(sample_d, name="sample_d") + self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa") + self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb") + self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob") + self.sfd_desc = self._make_tensor_desc(sample_sfd, name="sample_sfd") + self.amax_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_amax, name="sample_amax"), 1, "amax") + self.norm_const_desc = self._unpad_tensor_to_ndim( + self._make_tensor_desc(sample_norm_const, name="sample_norm_const"), + 1, + "norm_const", + ) + + self.alpha = alpha + self.acc_dtype = acc_dtype + self.mma_tiler_mn = mma_tiler_mn + self.cluster_shape_mn = cluster_shape_mn if cluster_shape_mn is not None else ((2, 1) if mma_tiler_mn[0] == 256 else (1, 1)) + self.sf_vec_size = sf_vec_size + self.vector_f32 = vector_f32 + self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) + + self._interpret_uint8_as_fp4x2 = True + self._kernel = Sm100BlockScaledPersistentDenseGemmKernel + + def check_support(self) -> bool: + m, k, l = self._tensor_shape(self.a_desc, name="sample_a") + n, b_k, b_l = self._tensor_shape(self.b_desc, name="sample_b") + c_m, c_n, c_l = self._tensor_shape(self.c_desc, name="sample_c") + d_m, d_n, d_l = self._tensor_shape(self.d_desc, name="sample_d") + + self._value_error_if((b_k, b_l) != (k, l), f"B shape mismatch: expected (*, {k}, {l}), got {(n, b_k, b_l)}") + self._check_tensor_shape(self.c_desc, (m, n, l), "C") + self._check_tensor_shape(self.d_desc, (m, n, l), "D") + self._check_tensor_shape(self.prob_desc, (m, 1, l), "prob") + + rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4) + self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(m, 128), 4, rest_k, l), "SFA") + self._check_tensor_shape(self.sfb_desc, (32, 4, 
ceil_div(n, 128), 4, rest_k, l), "SFB") + + if self.sfd_desc is not None: + rest_n = ceil_div(ceil_div(n, self.sf_vec_size), 4) + self._check_tensor_shape(self.sfd_desc, (32, 4, ceil_div(m, 128), 4, rest_n, l), "SFD") + + self._check_tensor_shape(self.amax_desc, (1,), "amax") + self._check_tensor_shape(self.norm_const_desc, (1,), "norm_const") + + self.ab_dtype = self._check_dtype( + self.a_desc, + dtype=[torch.float4_e2m1fn_x2, torch.uint8, torch.float8_e4m3fn, torch.float8_e5m2], + name="A", + ) + self._check_dtype(self.b_desc, dtype=self.ab_dtype, name="B", extra_error_msg="A and B must have the same dtype") + self.c_dtype = self._check_dtype( + self.c_desc, + dtype=[torch.float16, torch.bfloat16, torch.float32, torch.float8_e4m3fn, torch.float8_e5m2], + name="C", + ) + self.d_dtype = self._check_dtype( + self.d_desc, + dtype=[torch.float16, torch.bfloat16, torch.float32, torch.float8_e4m3fn, torch.float8_e5m2], + name="D", + ) + self._check_dtype(self.prob_desc, dtype=torch.float32, name="prob") + + self.sf_dtype = self._check_dtype(self.sfa_desc, dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn], name="SFA") + self._check_dtype(self.sfb_desc, dtype=self.sf_dtype, name="SFB", extra_error_msg="SFB must have the same dtype as SFA") + self._check_dtype(self.sfd_desc, dtype=self.sf_dtype, name="SFD", extra_error_msg="SFD must have the same dtype as SFA") + + self._check_dtype(self.acc_dtype, dtype=torch.float32, name="Accumulator") + + self._value_error_if(self.sf_vec_size not in {16, 32}, f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}") + self._value_error_if( + self._is_fp8(self.d_desc) and (self.sfd_desc is None or self.norm_const_desc is None), "sfd and norm_const are required when D is FP8" + ) + self._value_error_if( + self._is_fp4x2(self.ab_dtype) and self.d_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}, "FP4 input with FP8 output is not supported" + ) + + a_major = _major_from_stride_order(self.a_desc.stride_order, "m", "k") + b_major = _major_from_stride_order(self.b_desc.stride_order, "n", "k") + c_major = _major_from_stride_order(self.c_desc.stride_order, "m", "n") + d_major = _major_from_stride_order(self.d_desc.stride_order, "m", "n") + self._value_error_if(c_major != d_major, f"C and D must share the same layout, got {c_major} and {d_major}") + + self._value_error_if( + self.mma_tiler_mn[0] not in {128, 256} or self.mma_tiler_mn[1] not in {64, 128, 192, 256}, + f"Unsupported mma_tiler_mn {self.mma_tiler_mn}", + ) + self._value_error_if( + not ( + self.cluster_shape_mn[0] > 0 + and self.cluster_shape_mn[1] > 0 + and self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16 + and is_power_of_2(self.cluster_shape_mn[0]) + and is_power_of_2(self.cluster_shape_mn[1]) + ), + f"Invalid cluster shape {self.cluster_shape_mn}", + ) + + self._runtime_error_if(not torch.cuda.is_available(), "CUDA is not available") + major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) + self._runtime_error_if(major * 10 + minor < 100, f"GemmSreluSm100 requires SM100+, found SM{major}{minor}") + + self._value_error_if( + not self._kernel.can_implement( + ab_dtype=_convert_to_cutlass_data_type(self.ab_dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2), + sf_dtype=_convert_to_cutlass_data_type(self.sf_dtype), + sf_vec_size=self.sf_vec_size, + d_dtype=_convert_to_cutlass_data_type(self.d_dtype), + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + m=m, + n=n, + k=k, + l=l, + a_major=a_major, + b_major=b_major, + d_major=d_major, + ), 
+ "Unsupported configuration for dense sReLU kernel", + ) + + self._is_supported = True + return True + + def compile(self) -> None: + self._ensure_support_checked() + if self._compiled_kernel is not None: + return + + gemm = self._kernel( + sf_vec_size=self.sf_vec_size, + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + vector_f32=self.vector_f32, + ) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) + max_active_clusters -= self.num_cluster_overlap_margin + self._value_error_if(max_active_clusters <= 0, "max_active_clusters must be > 0 after overlap margin") + + fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False) + epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) ** 2 + use_full_dynamic = os.environ.get(_DENSE_GEMM_DYNAMIC_MNKL_ENV) is not None + use_dynamic_m = not use_full_dynamic and os.environ.get(_DENSE_GEMM_DYNAMIC_M_ENV) is not None + + if use_dynamic_m: + valid_m = cute.sym_int() + + a_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, *self.a_desc.shape[1:]), + stride_order=self.a_desc.stride_order, + ) + b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16) + c_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, *self.c_desc.shape[1:]), + stride_order=self.c_desc.stride_order, + ) + d_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_desc.dtype, + shape=(valid_m, *self.d_desc.shape[1:]), + stride_order=self.d_desc.stride_order, + ) + prob_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, *self.prob_desc.shape[1:]), + stride_order=self.prob_desc.stride_order, + ) + + tensor_m_128 = cute.sym_int() + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], self.sfa_desc.shape[5]), + stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128), + ) + sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16) + + sfd_cute_fake = None + if self.sfd_desc is not None: + stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfd_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfd_desc.shape[4], self.sfd_desc.shape[5]), + stride=(16, 4, self.sfd_desc.stride[2], 1, 512, stride_sfd_m), + ) + elif use_full_dynamic: + valid_m = cute.sym_int() + n_sym = cute.sym_int() + k_sym = cute.sym_int() + l_sym = cute.sym_int() + + a_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, k_sym, l_sym), + stride_order=self.a_desc.stride_order, + dynamic_mode=self.a_desc.stride_order[0], + divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, + ) + b_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.b_desc.dtype, + shape=(n_sym, k_sym, l_sym), + stride_order=self.b_desc.stride_order, + dynamic_mode=self.b_desc.stride_order[0], + divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, + ) + c_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, n_sym, l_sym), + stride_order=self.c_desc.stride_order, + dynamic_mode=self.c_desc.stride_order[0], + divisibility=8 if self._is_f16(self.c_desc.dtype) else 16, + ) + d_cute_fake = 
self._make_fake_cute_compact_tensor( + dtype=self.d_desc.dtype, + shape=(valid_m, n_sym, l_sym), + stride_order=self.d_desc.stride_order, + dynamic_mode=self.d_desc.stride_order[0], + divisibility=8 if self._is_f16(self.d_desc.dtype) else 16, + ) + prob_cute_fake = self._make_fake_cute_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, *self.prob_desc.shape[1:-1], l_sym), + stride=(1, 1, valid_m), + ) + + tensor_m_128 = cute.sym_int() + rest_k = cute.sym_int() + stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4) + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_shape = list(self.sfa_desc.shape) + sfa_shape[2] = tensor_m_128 + sfa_shape[4] = rest_k + sfa_shape[5] = l_sym + sfa_stride = list(self.sfa_desc.stride) + sfa_stride[2] = stride_rest_k + sfa_stride[5] = stride_tensor_m_128 + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=tuple(sfa_shape), + stride=tuple(sfa_stride), + ) + + tensor_n_128 = cute.sym_int() + stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfb_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfb_desc.dtype, + shape=(32, 4, tensor_n_128, 4, rest_k, l_sym), + stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k), + ) + + sfd_cute_fake = None + if self.sfd_desc is not None: + rest_n = cute.sym_int() + stride_sfd_rest_n = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfd_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfd_shape = list(self.sfd_desc.shape) + sfd_shape[2] = tensor_m_128 + sfd_shape[4] = rest_n + sfd_shape[5] = l_sym + sfd_stride = list(self.sfd_desc.stride) + sfd_stride[2] = stride_sfd_rest_n + sfd_stride[5] = stride_sfd_tensor_m_128 + sfd_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfd_desc.dtype, + shape=tuple(sfd_shape), + stride=tuple(sfd_stride), + ) + else: + a_cute_fake = self._make_fake_cute_tensor_from_desc(self.a_desc, assumed_align=16) + b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16) + c_cute_fake = self._make_fake_cute_tensor_from_desc(self.c_desc, assumed_align=16) + d_cute_fake = self._make_fake_cute_tensor_from_desc(self.d_desc, assumed_align=16) + prob_cute_fake = self._make_fake_cute_tensor_from_desc(self.prob_desc, assumed_align=16) + sfa_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfa_desc, assumed_align=16) + sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16) + sfd_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfd_desc, assumed_align=16) + + compiled = cute.compile( + gemm, + a_tensor=a_cute_fake, + b_tensor=b_cute_fake, + sfa_tensor=sfa_cute_fake, + sfb_tensor=sfb_cute_fake, + c_tensor=c_cute_fake, + d_tensor=d_cute_fake, + prob_tensor=prob_cute_fake, + amax_tensor=self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16), + sfd_tensor=sfd_cute_fake, + norm_const_tensor=self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16), + alpha=self.alpha, + max_active_clusters=max_active_clusters, + stream=fake_stream, + epilogue_op=epilogue_op, + options="--enable-tvm-ffi", + ) + + def tensor_api( + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + amax_tensor: Optional[torch.Tensor], + sfd_tensor: Optional[torch.Tensor], + norm_const_tensor: Optional[torch.Tensor], + alpha: float, + stream: cuda.CUstream, + ) -> 
None: + compiled( + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + c_tensor, + d_tensor, + prob_tensor, + self._unpad_tensor_to_ndim(amax_tensor, 1, "amax"), + sfd_tensor, + self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const"), + alpha, + stream, + ) + + self._compiled_kernel = tensor_api + + def execute( + self, + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + sfd_tensor: Optional[torch.Tensor] = None, + amax_tensor: Optional[torch.Tensor] = None, + norm_const_tensor: Optional[torch.Tensor] = None, + alpha: Optional[float] = None, + current_stream: Optional[cuda.CUstream] = None, + ) -> None: + self._runtime_error_if(self._compiled_kernel is None, "GemmSreluSm100 kernel not compiled; call compile() first") + current_stream = self._get_default_stream(current_stream) + self._compiled_kernel( + a_tensor=a_tensor, + b_tensor=b_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + prob_tensor=prob_tensor, + amax_tensor=amax_tensor, + sfd_tensor=sfd_tensor, + norm_const_tensor=norm_const_tensor, + alpha=self.alpha if alpha is None else alpha, + stream=current_stream, + ) + + +_logger = logging.getLogger(__name__) +_cache_of_GemmSreluSm100Objects = {} +_DENSE_GEMM_DYNAMIC_M_ENV = "CUDNN_FE_GEMM_DYNAMIC_M" +_DENSE_GEMM_DYNAMIC_MNKL_ENV = "CUDNN_FE_GEMM_DYNAMIC_MNKL" + + +def _allocate_dense_output(shape: Tuple[int, int, int], major: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + m, n, l = shape + if major == "m": + return torch.empty_strided((m, n, l), (1, m, m * n), dtype=dtype, device=device) + if major == "n": + return torch.empty_strided((m, n, l), (n, 1, m * n), dtype=dtype, device=device) + raise ValueError(f"major must be 'm' or 'n', got {major}") + + +def gemm_srelu_wrapper_sm100( + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + alpha: float = 1.0, + c_major: str = "n", + c_dtype: torch.dtype = torch.bfloat16, + d_dtype: torch.dtype = torch.bfloat16, + acc_dtype: torch.dtype = torch.float32, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + norm_const_tensor: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + stream: Optional[cuda.CUstream] = None, +) -> TupleDict: + m, k, l = a_tensor.shape + n, _, _ = b_tensor.shape + + c_tensor = _allocate_dense_output((m, n, l), c_major, c_dtype, a_tensor.device) + d_tensor = _allocate_dense_output((m, n, l), c_major, d_dtype, a_tensor.device) + + sfd_tensor = None + if d_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + sf_k = ceil_div(n, sf_vec_size) + mma_shape = ( + l, + ceil_div(m, 128), + ceil_div(sf_k, 4), + 32, + 4, + 4, + ) + sfd_tensor = torch.empty(mma_shape, dtype=sfa_tensor.dtype, device=a_tensor.device).permute(3, 4, 1, 5, 2, 0) + + amax_tensor = None + if a_tensor.dtype in {torch.float4_e2m1fn_x2, torch.uint8} and d_dtype in {torch.bfloat16, torch.float16, torch.float32}: + amax_tensor = torch.full((1,), float("-inf"), dtype=torch.float32, device=a_tensor.device) + + use_full_dynamic = os.environ.get(_DENSE_GEMM_DYNAMIC_MNKL_ENV) is not None + use_dynamic_m = not use_full_dynamic and os.environ.get(_DENSE_GEMM_DYNAMIC_M_ENV) is not None + + def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: + return tuple(i for i, s in 
sorted(enumerate(tensor.stride()), key=lambda x: x[1])) + + def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype + + def dynamic_compact_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return tuple(tensor.shape[1:]), stride_order(tensor), tensor.dtype + + def dynamic_tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return None, stride_order(tensor), tensor.dtype + + def dynamic_m_tensor_signature( + tensor: Optional[torch.Tensor], static_shape_suffix: Optional[Tuple[int, ...]], dynamic_stride_dims: Tuple[int, ...] = () + ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride())) + return static_shape_suffix, stride_signature, tensor.dtype + + cache_key = ( + use_full_dynamic, + use_dynamic_m, + *(dynamic_tensor_signature(a_tensor) if use_full_dynamic else dynamic_compact_signature(a_tensor) if use_dynamic_m else tensor_signature(a_tensor)), + *(dynamic_tensor_signature(b_tensor) if use_full_dynamic else tensor_signature(b_tensor)), + *(dynamic_tensor_signature(c_tensor) if use_full_dynamic else dynamic_compact_signature(c_tensor) if use_dynamic_m else tensor_signature(c_tensor)), + c_dtype, + d_dtype, + *( + dynamic_tensor_signature(sfa_tensor) + if use_full_dynamic + else ( + dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], sfa_tensor.shape[5]), dynamic_stride_dims=(5,)) + if use_dynamic_m + else tensor_signature(sfa_tensor) + ) + ), + *(dynamic_tensor_signature(sfb_tensor) if use_full_dynamic else tensor_signature(sfb_tensor)), + *( + dynamic_tensor_signature(prob_tensor) + if use_full_dynamic + else dynamic_compact_signature(prob_tensor) if use_dynamic_m else tensor_signature(prob_tensor) + ), + norm_const_tensor.shape if norm_const_tensor is not None else None, + norm_const_tensor.stride() if norm_const_tensor is not None else None, + norm_const_tensor.dtype if norm_const_tensor is not None else None, + alpha, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + c_major, + sf_vec_size, + vector_f32, + ) + + op = _cache_of_GemmSreluSm100Objects.get(cache_key) + if op is None: + op = GemmSreluSm100( + sample_a=a_tensor, + sample_b=b_tensor, + sample_c=c_tensor, + sample_d=d_tensor, + sample_sfa=sfa_tensor, + sample_sfb=sfb_tensor, + sample_prob=prob_tensor, + sample_sfd=sfd_tensor, + sample_amax=amax_tensor, + sample_norm_const=norm_const_tensor, + alpha=alpha, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + vector_f32=vector_f32, + ) + assert op.check_support(), "Unsupported testcase" + op.compile() + _cache_of_GemmSreluSm100Objects[cache_key] = op + + op.execute( + a_tensor=a_tensor, + b_tensor=b_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + prob_tensor=prob_tensor, + sfd_tensor=sfd_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + alpha=alpha, + current_stream=stream, + ) + + return 
TupleDict( + c_tensor=c_tensor, + d_tensor=d_tensor, + amax_tensor=amax_tensor, + sfd_tensor=sfd_tensor, + ) diff --git a/python/cudnn/gemm_srelu/dense_blockscaled_gemm_persistent_srelu_quant.py b/python/cudnn/gemm_srelu/dense_blockscaled_gemm_persistent_srelu_quant.py new file mode 100644 index 00000000..cac870d6 --- /dev/null +++ b/python/cudnn/gemm_srelu/dense_blockscaled_gemm_persistent_srelu_quant.py @@ -0,0 +1,2094 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Type, Tuple, Union, Optional + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils + +import cutlass.cute.math as math +from cutlass.cute.typing import Float32 + + +def get_divisibility(dtype) -> int: + """ + Get the divisibility value based on the data type. + """ + if dtype in [cutlass.Float4E2M1FN]: + divisibility = 32 + elif dtype in [cutlass.Float8E4M3FN, cutlass.Float8E5M2, cutlass.Float8E8M0FNU]: + divisibility = 16 + elif dtype in [cutlass.Float16, cutlass.BFloat16]: + divisibility = 8 + elif dtype == cutlass.Float32: + divisibility = 4 + else: + raise ValueError(f"Unsupported data type: {dtype}") + return divisibility + + +class Sm100BlockScaledPersistentDenseGemmKernel: + """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param sf_vec_size: Scalefactor vector size. 
+    :type sf_vec_size: int
+    :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
+    :type mma_tiler_mn: Tuple[int, int]
+    :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
+    :type cluster_shape_mn: Tuple[int, int]
+
+    :note: In the current version, A and B tensors must have the same data type
+        - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
+
+    :note: Supported combinations of A/B data types, SF data types and SF vector size:
+        - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32
+        - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16
+
+    :note: Supported accumulator data types:
+        - Float32
+
+    :note: Supported C data types:
+        - Float32
+        - Float16/BFloat16
+        - Float8E4M3FN/Float8E5M2
+        # {$nv-internal-release begin}
+        # Note: We don't have SFD generation support in this example for now, so Float4E2M1FN output is only for internal testing and will not be released.
+        - Float4E2M1FN
+        # {$nv-internal-release end}
+
+    :note: Constraints:
+        - MMA tiler M must be 128 or 256 (use_2cta_instrs)
+        - MMA tiler N must be 64/128/192/256
+        - Cluster shape M must be a multiple of 2 if MMA tiler M is 256
+        - Cluster shape M/N must be positive and a power of 2, total cluster size <= 16
+        - Also, cluster shape M/N must be <= 4 for scale factor multicasts due to the limited size of scale factors
+
+    Example:
+        >>> gemm = Sm100BlockScaledPersistentDenseGemmKernel(
+        ...     sf_vec_size=16,
+        ...     mma_tiler_mn=(256, 128),
+        ...     cluster_shape_mn=(2, 1),
+        ...     vector_f32=False,
+        ... )
+        >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, d_tensor, prob_tensor, amax_tensor, sfd_tensor, norm_const_tensor, alpha, max_active_clusters, stream)
+    """
+
+    def __init__(
+        self,
+        sf_vec_size: int,
+        mma_tiler_mn: Tuple[int, int],
+        cluster_shape_mn: Tuple[int, int],
+        vector_f32: bool,
+    ):
+        """Initializes the configuration for a Blackwell dense GEMM kernel.
+
+        This configuration includes several key aspects:
+
+        1. MMA Instruction Settings (tcgen05):
+            - acc_dtype: Data type of the MMA accumulator, always set to Float32
+            - sf_vec_size: Scale factor A/B vector size.
+            - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
+
+        2. Cluster Shape:
+            - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
+
+        :param sf_vec_size: Scale factor vector size.
+        :type sf_vec_size: int
+        :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
+        :type mma_tiler_mn: Tuple[int, int]
+        :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
+ :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len((self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id)) + # Set barrier id for epilogue sync and tmem ptr sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + # Generate sfd output by 2xf32 + self.vector_f32 = vector_f32 + + # Amax reduction configuration + self.num_epilog_warps = len(self.epilog_warp_id) + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + """ + # Compute mma instruction shapes + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mn = ( + self.mma_tiler[0], + self.mma_tiler[1], + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + self.mma_tiler_c = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk_c = ( + self.mma_tiler_c[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_c[1], + self.mma_tiler_c[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + 
self.cluster_layout_sfb_vmnk = cute.tiled_divide(
+            cute.make_layout((*self.cluster_shape_mn, 1)),
+            (tiled_mma_sfb.thr_id.shape,),
+        )
+
+        # Compute number of multicast CTAs for A/B
+        self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+        self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
+        self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
+        self.is_a_mcast = self.num_mcast_ctas_a > 1
+        self.is_b_mcast = self.num_mcast_ctas_b > 1
+        self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
+
+        # Always use subtile (128,32)
+        self.epi_tile = (cute.make_layout(128), cute.make_layout(32))
+        self.epi_tile_cnt = (
+            self.cta_tile_shape_mnk[0] // cute.size(self.epi_tile[0]),
+            self.cta_tile_shape_mnk[1] // cute.size(self.epi_tile[1]),
+        )
+        # Setup A/B/C/D stage counts in shared memory and ACC stage count in tensor memory
+        self.num_acc_stage, self.num_ab_stage, self.num_c_stage, self.num_d_stage = self._compute_stages(
+            tiled_mma,
+            self.mma_tiler,
+            self.a_dtype,
+            self.b_dtype,
+            self.epi_tile,
+            self.c_dtype,
+            self.c_layout,
+            self.sf_dtype,
+            self.sf_vec_size,
+            self.d_dtype,
+            self.d_layout,
+            self.smem_capacity,
+            self.occupancy,
+        )
+
+        # Compute A/B/SFA/SFB/C/D shared memory layouts
+        self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
+            tiled_mma,
+            self.mma_tiler,
+            self.a_dtype,
+            self.num_ab_stage,
+        )
+        self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
+            tiled_mma,
+            self.mma_tiler,
+            self.b_dtype,
+            self.num_ab_stage,
+        )
+        self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
+            tiled_mma,
+            self.mma_tiler,
+            self.sf_vec_size,
+            self.num_ab_stage,
+        )
+        self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
+            tiled_mma,
+            self.mma_tiler,
+            self.sf_vec_size,
+            self.num_ab_stage,
+        )
+        self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+            self.c_dtype,
+            self.c_layout,
+            self.epi_tile,
+            self.num_c_stage,
+        )
+        self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+            self.d_dtype,
+            self.d_layout,
+            self.epi_tile,
+            self.num_d_stage,
+        )
+
+    @cute.jit
+    def __call__(
+        self,
+        a_tensor: cute.Tensor,
+        b_tensor: cute.Tensor,
+        sfa_tensor: cute.Tensor,
+        sfb_tensor: cute.Tensor,
+        c_tensor: cute.Tensor,
+        d_tensor: cute.Tensor,
+        prob_tensor: cute.Tensor,
+        amax_tensor: Optional[cute.Tensor],
+        sfd_tensor: Optional[cute.Tensor],
+        norm_const_tensor: Optional[cute.Tensor],
+        alpha: cutlass.Float32,
+        max_active_clusters: cutlass.Constexpr,
+        stream: cuda.CUstream,
+        epilogue_op: cutlass.Constexpr = lambda x: x,
+    ):
+        """Execute the GEMM operation in steps:
+        - Setup static attributes before smem/grid/tma computation
+        - Setup TMA load/store atoms and tensors
+        - Compute grid size with regard to hardware constraints
+        - Define shared storage for kernel
+        - Launch the kernel synchronously
+
+        :param a_tensor: Input tensor A
+        :type a_tensor: cute.Tensor
+        :param b_tensor: Input tensor B
+        :type b_tensor: cute.Tensor
+        :param sfa_tensor: Scale factor tensor A
+        :type sfa_tensor: cute.Tensor
+        :param sfb_tensor: Scale factor tensor B
+        :type sfb_tensor: cute.Tensor
+        :param c_tensor: Output tensor C
+        :type c_tensor: cute.Tensor
+        :param d_tensor: Output tensor D
+        :type d_tensor: cute.Tensor
+        :param prob_tensor: Output tensor Prob (per-row result accumulated by the epilogue)
+        :type prob_tensor: cute.Tensor
+        :param amax_tensor: Output tensor for absolute maximum value
+        :type amax_tensor: cute.Tensor
+        :param sfd_tensor: Scale factor tensor for output D
+        :type sfd_tensor: cute.Tensor
+        :param norm_const_tensor: Normalization constant tensor for quantization
+        :type norm_const_tensor: cute.Tensor
+        :param alpha: Alpha scaling factor applied in the epilogue
+        :type alpha: cutlass.Float32
+        :param max_active_clusters: Maximum number of active clusters
+        :type max_active_clusters: cutlass.Constexpr
+        :param stream: CUDA stream for asynchronous execution
+        :type stream: cuda.CUstream
+        :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
+        :type epilogue_op: cutlass.Constexpr
+        :raises TypeError: If input data types are incompatible with the MMA instruction.
+        """
+        # Setup static attributes before smem/grid/tma computation
+        self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type
+        self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type
+        self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type
+        self.c_dtype: Type[cutlass.Numeric] = c_tensor.element_type
+        self.d_dtype: Type[cutlass.Numeric] = d_tensor.element_type
+        self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode()
+        self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode()
+        self.c_layout = utils.LayoutEnum.from_tensor(c_tensor)
+        self.d_layout = utils.LayoutEnum.from_tensor(d_tensor)
+
+        # Check if input data types are compatible with MMA instruction
+        if cutlass.const_expr(self.a_dtype != self.b_dtype):
+            raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
+
+        # Setup attributes that depend on gemm inputs
+        self._setup_attributes()
+
+        # Set up the SFA/SFB tensors by tiling the scale factor atom layout to the A/B tensor shapes
+        # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL)
+        sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, self.sf_vec_size)
+        sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout)
+
+        # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
+        sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b_tensor.shape, self.sf_vec_size)
+        sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout)
+
+        tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+            self.a_dtype,
+            self.a_major_mode,
+            self.b_major_mode,
+            self.sf_dtype,
+            self.sf_vec_size,
+            self.cta_group,
+            self.mma_inst_shape_mn,
+        )
+
+        # For 2CTA blockscaled kernels, SFB needs to be replicated across peer CTAs.
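+        # Illustrative sketch (hypothetical config): with mma_tiler_mn = (256, 128) the two
+        # peer CTAs of one 2-CTA MMA each need the full SFB tile, so SFB uses a CtaGroup.ONE
+        # tiled MMA with N rounded up to 128 (mma_inst_shape_mn_sfb) rather than the 2-CTA
+        # atom used for A/B above.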
# {$nv-internal-release} + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_tensor, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa_tensor, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # {$nv-internal-release begin} + # This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF)) + # logical blocks for SFB when cta_tile_shape_n=192. 
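+        # Worked sketch (hedged, derived from the remap below): with x the inner stride of
+        # the second mode-0 factor and y = ceil_div(shape[0][1], 4), that factor becomes
+        # ((2, 2), y) with strides ((x, x), 3 * x); each group of four indices then covers
+        # offsets {0, x, x, 2x}, so the middle two 128-wide SFB blocks intentionally alias
+        # and three distinct blocks back one 192-wide CTA tile.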
+ # {$nv-internal-release end} + if cutlass.const_expr(self.cta_tile_shape_mnk_c[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + + new_shape = ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ) + tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size + + # Setup TMA store for C + epi_c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c_tensor, + epi_c_smem_layout, + self.epi_tile, + ) + epi_d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0)) + tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d_tensor, + epi_d_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c_tensor, + self.cta_tile_shape_mnk_c, + self.cluster_shape_mn, + max_active_clusters, + ) + + self.buffer_align_bytes = 1024 + + self.generate_sfd = sfd_tensor is not None and norm_const_tensor is not None + if cutlass.const_expr(self.generate_sfd): + sfd_layout = blockscaled_utils.tile_atom_to_shape_SF(c_tensor.shape, self.sf_vec_size) + sfd_tensor = cute.make_tensor(sfd_tensor.iterator, sfd_layout) + + self.generate_amax = amax_tensor is not None + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + sD: cute.struct.Align[ + cute.struct.MemRange[ + self.d_dtype, + cute.cosize(self.d_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + # Amax 
reduction shared memory (one FP32 per epilogue warp) + # The buffer is only num_epilog_warps FP32 words, but it is aligned with buffer_align_bytes like the other smem buffers + sAmax: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel on the given stream + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + tma_atom_d, + tma_tensor_d, + prob_tensor, + amax_tensor, + sfd_tensor, + norm_const_tensor, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.d_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + alpha, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tma_atom_d: cute.CopyAtom, + mD_mnl: cute.Tensor, + mProb_mnl: cute.Tensor, + mAmax_tensor: Optional[cute.Tensor], + mSFD_mnl: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + alpha: cutlass.Float32, + ): + """ + GPU device kernel performing the persistent batched GEMM computation.
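The kernel is warp-specialized; the summary below simply mirrors the warp_idx branches in the body: + - TMA warp: prefetches the TMA descriptors and produces A/B/SFA/SFB tiles into shared memory through ab_pipeline. + - MMA warp: stages scale factors from smem to tmem and issues cute.gemm on the leader CTA, publishing accumulators through acc_pipeline. + - Epilogue warps: drain accumulators from tmem, apply alpha and the optional epilogue_op, optionally reduce amax, and TMA-store C/D.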
+ """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + cpasync.prefetch_descriptor(tma_atom_d) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/SFA/SFB/C + # + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, + swizzle=c_smem_layout_staged.inner, + dtype=self.c_dtype, + ) + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sD = storage.sD.get_tensor( + 
d_smem_layout_staged.outer, + swizzle=d_smem_layout_staged.inner, + dtype=self.d_dtype, + ) + + # Shared memory for amax reduction (one FP32 per epilogue warp) + # Simple 1D layout. The buffer is allocated even when no amax is generated, + # since the overhead is minimal and it keeps the code simple. + amax_layout = cute.make_layout((self.num_epilog_warps,)) + sAmax = storage.sAmax.get_tensor(amax_layout) + + # + # Compute multicast mask for A/B/SFA/SFB buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler_c, (None, None, 0)), (None, None, None)) + # (bM, bN, RestM, RestN, RestL) + gD_mnl = cute.local_tile( + mD_mnl, + cute.slice_(self.mma_tiler_c, (None, None, 0)), + (None, None, None), + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgD = thr_mma.partition_C(gD_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + )
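# Note: the SFA/SFB partitions produced below are additionally passed through cute.filter_zeros; the scale factor layouts carry stride-0 (broadcast) modes, and filtering them out keeps only the unique elements for the TMA copy.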
+ + # TMA load SFA partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMA load SFB partition_S/D + sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + + # Apply SFB slicing hack when cta_tile_shape_n=64 # {$nv-internal-release} + slice_n = mma_tile_coord_mnl[1] + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_mnl[1] // 2 + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + # TMA load A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, 
ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, ab_producer_state.count)], + tAsSFA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfa_full_mcast_mask, + ) + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfb_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # Make accumulator tmem tensor + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # + # Partition for S2T copy of SFA/SFB + # + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) + acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + 
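# (Peek/wait split: consumer_try_wait returns a status token so that the blocking consumer_wait in the loop below can skip waiting when the buffer is already full.)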
peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # Apply TMEM pointer offset hack when cta_tile_shape_n=192 or cta_tile_shape_n=64 # {$nv-internal-release} + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + # If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB) + offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + # Move in increments of 64 columns of SFB + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB_mma[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # 
+ # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgD, epi_tile, use_2cta_instrs) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rC, epi_tidx, sC) + ( + bSG_sC, + bSG_gC_mnl, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, epi_tile, sC) + tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype) + _, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rD, epi_tidx, sD) + ( + bSG_sD, + bSG_gD_mnl, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_d, tCgD, epi_tile, sD) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + d_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + d_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_d_stage, + producer_group=d_producer_group, + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_mnl[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + bSG_gD = bSG_gD_mnl[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)] + + # Initialize thread-local amax accumulator for this tile + # Use 0.0 as initial value since we're computing absolute maximum + if cutlass.const_expr(self.generate_amax): + thread_tile_amax = cutlass.Float32(0.0) + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) ## tTR_tAcc.shape: (((32, 32), 1), 1, 1, (1, 8)) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + # + # Get PROB + # Note, it always assumes T2R_M/EPI_M is 1, otherwise it will break the result. 
+ # + mPosition = cur_tile_coord[0] * self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape) + tidx + mProb = mProb_mnl[mPosition, 0, cur_tile_coord[2]] + for subtile_idx in cutlass.range(0, subtile_cnt, 1): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_subtile = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_subtile, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc) + + # + # Apply alpha + # + acc_vec_ = acc_vec.load() + acc_vec.store(acc_vec_ * alpha) + + # + # Store C to shared memory for bprop + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + # Convert to d_dtype before storing + # Load, convert type, and store back to temporary register tensor + # tTR_rD_up = cute.make_rmem_tensor(tTR_tAcc_mn_up.shape, self.d_dtype) + # tTR_rD_gate = cute.make_rmem_tensor(tTR_tAcc_mn_gate.shape, self.d_dtype) + tRS_rC.store(acc_vec.load().to(self.c_dtype)) + cute.copy( + tiled_copy_r2s, + tRS_rC[(None, None, 0)], + tRS_sC[(None, None, 0, c_buffer)], + ) + + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier.arrive_and_wait() + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], # ((8192, 1), (1, 4)), ((1, 0), (0, 8192)) + bSG_gC[(None, subtile_idx)], # (((64, 128), 1), (1, 4)) : (((1@0,1@1),0),(0,64@0)) + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + + self.epilog_sync_barrier.arrive_and_wait() + + # + # Generate amax + # + if cutlass.const_expr(self.generate_amax): + acc_values = acc_vec.load() + acc_values = epilogue_op(acc_values) + acc_values = acc_values * mProb + acc_vec.store(acc_values) + + # Apply element-wise absolute value using math.absf (supports vectors) + abs_acc_values_ir = cutlass._mlir.dialects.math.absf(acc_values.ir_value()) # operand (positional) + abs_acc_values = type(acc_values)(abs_acc_values_ir, acc_values.shape, acc_values.dtype) + subtile_amax = abs_acc_values.reduce( + cute.ReductionOp.MAX, + cutlass.Float32(0.0), + 0, # Use 0.0 as init for abs values + ) + thread_tile_amax = cute.arch.fmax(thread_tile_amax, subtile_amax) + + # + # Generate sfd + # + if cutlass.const_expr(self.generate_sfd): + cute.printf("SFD not implemented\n") + else: + # + # Convert to D type directly + # + acc_vec = tiled_copy_r2s.retile(acc_vec).load() + tRS_rD.store(acc_vec.to(self.d_dtype)) + + # + # Store D to shared memory + # + d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage + cute.copy( + tiled_copy_r2s, + tRS_rD, + tRS_sD[(None, None, None, d_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier.arrive_and_wait() + # + # TMA store D to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_d, + bSG_sD[(None, d_buffer)], + bSG_gD[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + d_pipeline.producer_commit() + d_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # Perform amax reduction after all subtiles are processed + if cutlass.const_expr(self.generate_amax): + # Warp-level reduction using wrapper function + warp_amax = cute.arch.warp_redux_sync( 
+ value=thread_tile_amax, + kind="fmax", + mask_and_clamp=0xFFFFFFFF, + nan=True, + ) + # Each epilogue warp's lane 0 writes warp amax to shared memory + if cute.arch.lane_idx() == 0: + sAmax[warp_idx] = cutlass.Float32(warp_amax) + + # Ensure all epilogue warps complete their writes before block reduction + self.epilog_sync_barrier.arrive_and_wait() + + # Block-level reduction: only first epilogue warp's lane 0 handles this + if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0: + block_amax = cutlass.Float32(0.0) # Initial value for absolute maximum + for i in cutlass.range(self.num_epilog_warps): + warp_amax_val = sAmax[i] + block_amax = cute.arch.fmax(block_amax, warp_amax_val) + + # Global atomic max (accumulates across all tiles for final tensor amax) + # Since we compute absolute values, all values are non-negative + # Use wrapper function for atomic max operation + _ = cute.arch.atomic_max_float32(ptr=mAmax_tensor.iterator.llvm_ptr, value=block_amax) + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(acc_tmem_ptr) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + d_pipeline.producer_tail() + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). 
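The copy is built from the tcgen05 Cp4x32x128bOp shared-to-tensor-memory atom, and both sides are first compacted with cute.filter_zeros to drop the replicated (stride-0) modes of the scale factor layouts.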
+ + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tCtSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gD_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gD_mnl: The global tensor D + :type gD_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether 2-CTA MMA instructions are used + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor in tmem + - tTR_rAcc: The register fragment for one accumulator subtile + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gD = thr_copy_t2r.partition_D(gD_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor(tTR_gD[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) + return
tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op(self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to partition shared memory (source) and global memory (destination) for the TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for the TMA store version, or tiled_copy_t2r for the non-TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (bSG_sC, bSG_gC) where: + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + atom, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(gC_epi, 0, 2), + ) + return bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + d_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int, int]: + """Computes the number of stages for the accumulator and the A/B, C, and D operands based on heuristics.
+ + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param sf_dtype: Data type of the scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Scale factor vector size. + :type sf_vec_size: int + :param d_dtype: Data type of operand D. + :type d_dtype: type[cutlass.Numeric] + :param d_layout: Layout enum of operand D. + :type d_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages, D stages) + :rtype: tuple[int, int, int, int] + """ + # ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C/D stages + num_c_stage = 2 # mma_tiler_mnk[1] // cute.cosize(epi_tile[1]) + num_d_stage = num_c_stage + + # Calculate smem layout and size for one stage of A, B, SFA, SFB, C and D + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # temporary single stage, used only to size one stage + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # temporary single stage, used only to size one stage + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # temporary single stage, used only to size one stage + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # temporary single stage, used only to size one stage + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + d_dtype, + d_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) + amax_bytes = Sm100BlockScaledPersistentDenseGemmKernel.get_amax_smem_size() + epi_bytes = c_bytes_per_stage * num_c_stage + d_bytes_per_stage * num_d_stage + amax_bytes + + # Calculate A/B/SFA/SFB stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C/D stages bytes + # Divide remaining by bytes needed per A/B/SFA/SFB stage + num_ab_stage = (smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)) // ab_bytes_per_stage + + return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams,
Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl) + grid = utils.StaticPersistentTileScheduler.get_grid_shape(tile_sched_params, max_active_clusters) + + return tile_sched_params, grid + + @staticmethod + def get_dtype_rcp_limits(dtype: Type[cutlass.Numeric]) -> float: + """ + Calculates the reciprocal of the maximum absolute value for a given data type. + + :param dtype: Data type + :type dtype: Type[cutlass.Numeric] + + :return: A float representing the reciprocal of the maximum absolute value + :rtype: float + """ + if dtype == cutlass.Float4E2M1FN: + return 1 / 6.0 + if dtype == cutlass.Float8E4M3FN: + return 1 / 448.0 + if dtype == cutlass.Float8E5M2: + return 1 / 128.0 + return 1.0 + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param d_dtype: The data type of the output tensor + :type d_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes and sf_vec_size are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 16: + is_valid = False + + # Check valid d_dtype + if d_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + cutlass.Float4E2M1FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if layouts and dtypes are valid combinations + + :param ab_dtype: The data type of the A
and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major dimension of the A tensor + :type a_major: str + :param b_major: The major dimension of the B tensor + :type b_major: str + :param c_major: The major dimension of the C tensor + :type c_major: str + + :return: True if the layouts are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # {$nv-internal-release begin} + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + # TODO: Currently we don't support m major output for Float4E2M1FN + if c_dtype is cutlass.Float4E2M1FN and c_major == "m": + is_valid = False + # {$nv-internal-release end} + + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if mma_tiler_mn[0] not in [128, 256]: + is_valid = False + if mma_tiler_mn[1] not in [64, 128, 192, 256]: + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. 
+ or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + d_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size + :type sf_vec_size: int + :param d_dtype: The data type of the output tensor + :type d_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param d_major: The major axis of the C tensor + :type d_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(ab_dtype, sf_dtype, sf_vec_size, d_dtype): + can_implement = False + # Skip unsupported layouts + if not 
Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts(ab_dtype, d_dtype, a_major, b_major, d_major): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment(m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major): + can_implement = False + return can_implement + + @staticmethod + def get_amax_smem_size(): + # Note: 4 is hardcoded for num_epilog_warps + return 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,))) diff --git a/python/cudnn/gemm_swiglu/api.py b/python/cudnn/gemm_swiglu/api.py index 270c026a..2daeaa74 100644 --- a/python/cudnn/gemm_swiglu/api.py +++ b/python/cudnn/gemm_swiglu/api.py @@ -69,7 +69,7 @@ def __init__( ): super().__init__() - self._logger.warning("GemmSwigluSm100 is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") self.a_desc = self._make_tensor_desc(sample_a, name="sample_a") @@ -94,7 +94,7 @@ def __init__( self.vector_f32 = vector_f32 self.ab12_stages = ab12_stages self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) - print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") # Kernel selection if self.sfa_desc is None and self.sfb_desc is None and self.amax_desc is None and self.sfc_desc is None and self.norm_const_desc is None: diff --git a/python/cudnn/grouped_gemm/__init__.py b/python/cudnn/grouped_gemm/__init__.py index a19f6c2f..818a1ab9 100644 --- a/python/cudnn/grouped_gemm/__init__.py +++ b/python/cudnn/grouped_gemm/__init__.py @@ -16,11 +16,26 @@ grouped_gemm_quant_wrapper_sm100, ) +from .grouped_gemm_srelu.api import ( + GroupedGemmSreluSm100, + grouped_gemm_srelu_wrapper_sm100, +) + +from .grouped_gemm_dsrelu.api import ( + GroupedGemmDsreluSm100, + grouped_gemm_dsrelu_wrapper_sm100, +) + from .grouped_gemm_glu.api import ( GroupedGemmGluSm100, grouped_gemm_glu_wrapper_sm100, ) +from .grouped_gemm_glu_hadamard.api import ( + GroupedGemmGluHadamardSm100, + grouped_gemm_glu_hadamard_wrapper_sm100, +) + from .grouped_gemm_dglu.api import ( GroupedGemmDgluSm100, grouped_gemm_dglu_wrapper_sm100, @@ -38,8 +53,14 @@ "grouped_gemm_dswiglu_wrapper_sm100", "GroupedGemmQuantSm100", "grouped_gemm_quant_wrapper_sm100", + "GroupedGemmSreluSm100", + "grouped_gemm_srelu_wrapper_sm100", + "GroupedGemmDsreluSm100", + "grouped_gemm_dsrelu_wrapper_sm100", "GroupedGemmGluSm100", "grouped_gemm_glu_wrapper_sm100", + "GroupedGemmGluHadamardSm100", + "grouped_gemm_glu_hadamard_wrapper_sm100", "GroupedGemmDgluSm100", "grouped_gemm_dglu_wrapper_sm100", "GroupedGemmWgradSm100", diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py index 843b3412..d7a19666 100644 --- a/python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py +++ b/python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py @@ -174,7 +174,7 @@ def __init__( """ super().__init__() - self._logger.warning("GroupedGemmDgluSm100 is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") # ---- Weight mode auto-detection ---- @@ -260,7 +260,7 @@ def __init__( self._kernel = BlockScaledMoEGroupedGemmDgluDbiasKernel 
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) - print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") self._workspace = None @@ -676,7 +676,7 @@ def compile(self) -> None: def _compile_dense(self, gemm_dglu, max_active_clusters, fake_stream) -> None: """Compile for dense (contiguous) weight mode.""" - use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None + use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" fake_workspace_ptr = cute.runtime.nullptr( dtype=cutlass.Uint8, @@ -1446,7 +1446,7 @@ def dynamic_m_tensor_signature( stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride())) return static_shape_suffix, stride_signature, tensor.dtype - use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None + use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" if is_dense: cache_key = ( diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/__init__.py b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/__init__.py new file mode 100644 index 00000000..e1513779 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +from .api import ( + GroupedGemmDsreluSm100, + grouped_gemm_dsrelu_wrapper_sm100, +) + +__all__ = [ + "GroupedGemmDsreluSm100", + "grouped_gemm_dsrelu_wrapper_sm100", +] diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/api.py new file mode 100644 index 00000000..2226a627 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/api.py @@ -0,0 +1,1581 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Unified API for Grouped GEMM dSReLU Backward Kernel (SM100+) + +This module provides a single API class that supports both contiguous (dense) +and discrete weight modes for block-scaled grouped GEMM with dSReLU activation +gradient in MoE (Mixture of Experts) workloads. + +Dense mode + All expert weights are packed contiguously in a 3-D tensor (N, K, L). + Callers supply ``sample_b`` and ``sample_sfb``. + +Discrete mode + Each expert has its own memory allocation. Callers supply + ``num_experts``, ``b_shape``, ``b_dtype``, and per-expert pointer arrays + at execution time. +""" + +from .moe_blockscaled_grouped_gemm_dsrelu_quant import ( + BlockScaledMoEGroupedGemmQuantBwdKernel, + EpilogueType, +) +from ..moe_utils import MoEWeightMode +from cuda.bindings import driver as cuda +import logging +import os +import torch +from typing import Tuple, Optional + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +from cutlass.cute.runtime import from_dlpack, make_fake_stream + +from cudnn.datatypes import _convert_to_cutlass_data_type +from cudnn.api_base import APIBase, TupleDict, ceil_div, is_power_of_2 + + +def _reinterpret_raw_grouped_fp4_tensor(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.uint8: + cute_tensor = from_dlpack(tensor, assumed_align=16, enable_tvm_ffi=True).mark_layout_dynamic(leading_dim=1) + cute_tensor.element_type = cutlass.Float4E2M1FN + return cute_tensor + return tensor + + +class GroupedGemmDsreluSm100(APIBase): + """Unified API for grouped GEMM dSReLU backward operation on SM100+ GPUs. + + This kernel performs block-scaled grouped GEMM with dSReLU activation + gradient (dSReLU), designed for MoE workloads. It supports + both dense (contiguous) and discrete (per-expert pointer) weight layouts + through the ``BlockScaledMoEGroupedGemmQuantBwdKernel``. + + Weight mode is auto-detected from the constructor arguments: + + - **Dense**: provide ``sample_b`` and ``sample_sfb``. + - **Discrete**: provide ``num_experts``, ``b_shape``, and ``b_dtype``. + + Example:: + + # Dense mode + api = GroupedGemmDsreluSm100( + sample_a=a, sample_c=c, + sample_d_row=d_row, sample_d_col=d_col, + sample_sfa=sfa, sample_padded_offsets=offsets, + sample_alpha=alpha, + sample_prob=prob, sample_dprob=dprob, + sample_b=b, sample_sfb=sfb, + ) + + # Discrete mode + api = GroupedGemmDsreluSm100( + sample_a=a, sample_c=c, + sample_d_row=d_row, sample_d_col=d_col, + sample_sfa=sfa, sample_padded_offsets=offsets, + sample_alpha=alpha, + sample_prob=prob, sample_dprob=dprob, + num_experts=8, b_shape=(n, k), b_dtype=torch.uint8, + ) + + api.check_support() + api.compile() + api.execute(...) + """ + + def __init__( + self, + sample_a: torch.Tensor, + # Dense mode (contiguous) -- provide these. 
sample_dbias is optional: + sample_b: Optional[torch.Tensor] = None, + sample_c: Optional[torch.Tensor] = None, + sample_d_row: Optional[torch.Tensor] = None, + sample_d_col: Optional[torch.Tensor] = None, + sample_sfa: Optional[torch.Tensor] = None, + sample_sfb: Optional[torch.Tensor] = None, + sample_padded_offsets: Optional[torch.Tensor] = None, + sample_alpha: Optional[torch.Tensor] = None, + sample_prob: Optional[torch.Tensor] = None, + sample_dprob: Optional[torch.Tensor] = None, + sample_dbias: Optional[torch.Tensor] = None, + # Discrete mode -- provide these instead: + num_experts: Optional[int] = None, + b_shape: Optional[Tuple[int, ...]] = None, + b_dtype: Optional[torch.dtype] = None, + # Optional quantization output arguments + sample_sfd_row: Optional[torch.Tensor] = None, + sample_sfd_col: Optional[torch.Tensor] = None, + sample_amax: Optional[torch.Tensor] = None, + sample_norm_const: Optional[torch.Tensor] = None, + # Configuration + acc_dtype: torch.dtype = torch.float32, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + m_aligned: int = 256, + discrete_col_sfd: bool = False, + b_major: str = "k", + use_dynamic_sched: bool = False, + ): + """Initialize the GroupedGemmDsreluSm100 API. + + :param sample_a: Sample A tensor (valid_m, k, 1) + :param sample_c: Sample C tensor -- forward activations (valid_m, n, 1) + :param sample_d_row: Sample D row output tensor (valid_m, n, 1) + :param sample_d_col: Sample D col output tensor (valid_m, n, 1) + :param sample_sfa: Sample scale factor A tensor + :param sample_padded_offsets: End offset for each expert after padding + :param sample_alpha: Per-group alpha scaling factors + :param sample_prob: Per-row probability tensor (valid_m, 1, 1) + :param sample_dprob: Gradient of probability tensor (valid_m, 1, 1), must be zero-initialized + :param sample_b: (Dense) Sample B tensor (n, k, l) + :param sample_sfb: (Dense) Sample scale factor B tensor + :param sample_dbias: Optional dbias output tensor (expert_cnt, n, 1) + :param num_experts: (Discrete) Number of experts + :param b_shape: (Discrete) Shape of a single expert B tensor, e.g. 
(n, k) + :param b_dtype: (Discrete) Data type of B tensors + :param sample_sfd_row: Optional row scale factor for D + :param sample_sfd_col: Optional column scale factor for D + :param sample_amax: Optional amax tensor for quantization, shape (expert_cnt, 1) + :param sample_norm_const: Optional normalization constant + :param acc_dtype: Accumulator data type + :param mma_tiler_mn: MMA tiler shape (M, N) + :param cluster_shape_mn: Cluster shape (M, N) + :param sf_vec_size: Scale factor vector size + :param vector_f32: Use vectorized f32 operations + :param m_aligned: Alignment for group M dimension + :param discrete_col_sfd: Generate discrete col-major scale factor tensor + :param b_major: Major dimension for B tensor, one of "k" or "n" + :param use_dynamic_sched: Enable dynamic tile scheduling for load balancing + """ + super().__init__() + + self._warn_experimental_api() + self._logger.debug("Entering __init__") + + # ---- Weight mode auto-detection ---- + if sample_b is not None and num_experts is None: + self.weight_mode = MoEWeightMode.DENSE + if sample_sfb is None: + raise ValueError("sample_sfb is required when sample_b is provided (dense mode)") + elif num_experts is not None and sample_b is None: + self.weight_mode = MoEWeightMode.DISCRETE + if b_shape is None or b_dtype is None: + raise ValueError("b_shape and b_dtype are required in discrete mode") + else: + raise ValueError("Provide either (sample_b, sample_sfb) for dense mode " "or (num_experts, b_shape, b_dtype) for discrete mode, but not both.") + + self._sample_a_tensor = sample_a + self._sample_b_tensor = sample_b + + # ---- Common tensor descriptors ---- + self.a_desc = self._make_tensor_desc(sample_a, name="sample_a", interpret_uint8_as_fp4x2=False) + self.c_desc = self._make_tensor_desc(sample_c, name="sample_c") + self.d_row_desc = self._make_tensor_desc(sample_d_row, name="sample_d_row") + self.d_col_desc = self._make_tensor_desc(sample_d_col, name="sample_d_col") + self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa") + self.padded_offsets_desc = self._make_tensor_desc(sample_padded_offsets, name="sample_padded_offsets") + self.alpha_desc = self._make_tensor_desc(sample_alpha, name="sample_alpha") + self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob") + self.dprob_desc = self._make_tensor_desc(sample_dprob, name="sample_dprob") + self.dbias_desc = self._make_tensor_desc(sample_dbias, name="sample_dbias") + + self.sfd_row_desc = self._make_tensor_desc(sample_sfd_row, name="sample_sfd_row") + self.sfd_col_desc = self._make_tensor_desc(sample_sfd_col, name="sample_sfd_col") + self.amax_desc = self._make_tensor_desc(sample_amax, name="sample_amax") + self.norm_const_desc = self._unpad_tensor_to_ndim( + self._make_tensor_desc(sample_norm_const, name="sample_norm_const"), + 1, + "norm_const", + ) + + # ---- Mode-specific state ---- + if self.weight_mode == MoEWeightMode.DENSE: + self.b_desc = self._make_tensor_desc(sample_b, name="sample_b", interpret_uint8_as_fp4x2=False) + self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb") + self.expert_cnt = self.padded_offsets_desc.shape[0] + else: + self._value_error_if(num_experts == 0, "num_experts must be > 0") + self.expert_cnt = num_experts + self.b_shape = b_shape + self.b_dtype = b_dtype + self.b_major = b_major + self._value_error_if( + self.padded_offsets_desc.shape[0] != self.expert_cnt, + f"padded_offsets length ({self.padded_offsets_desc.shape[0]}) " f"must equal num_experts ({self.expert_cnt})", + ) + + # ---- 
Configuration ---- + self.acc_dtype = acc_dtype + self.mma_tiler_mn = mma_tiler_mn + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + if cluster_shape_mn is None: + self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1) + else: + self.cluster_shape_mn = cluster_shape_mn + self.sf_vec_size = sf_vec_size + self.vector_f32 = vector_f32 + self.m_aligned = m_aligned + self.discrete_col_sfd = discrete_col_sfd + if self.weight_mode == MoEWeightMode.DENSE: + self.b_major = b_major # stored for both modes + + self.use_dynamic_sched = use_dynamic_sched + + self._interpret_uint8_as_fp4x2 = True + self._has_dbias = self.dbias_desc is not None + self._kernel = BlockScaledMoEGroupedGemmQuantBwdKernel + + self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + + self._workspace = None + + self._logger.debug("__init__ completed") + + def check_support(self) -> bool: + """Check if the kernel configuration is supported. + + :return: True if supported, raises exception otherwise + """ + self._logger.debug("Entering check_support") + + # ---- SFD group validation ---- + all_none = all(x is None for x in [self.sfd_row_desc, self.sfd_col_desc, self.norm_const_desc]) + all_provided = all(x is not None for x in [self.sfd_row_desc, self.sfd_col_desc, self.norm_const_desc]) + self._value_error_if( + not (all_none or all_provided), + "sfd_row_desc, sfd_col_desc, and norm_const_desc must be all None or all not None", + ) + self._user_requested_sfd = all_provided + + # ---- Shapes and strides ---- + self._logger.debug("Checking tensor shapes and strides") + tensor_m, k, _one = self._tensor_shape(self.a_desc, name="sample_a") + + if self.weight_mode == MoEWeightMode.DENSE: + n, _, l = self._tensor_shape(self.b_desc, name="sample_b") + else: + # Discrete: extract n, k from b_shape + if len(self.b_shape) == 2: + n, b_k = self.b_shape + else: + n, b_k, _ = self.b_shape + self._value_error_if(b_k != k, f"B K dimension ({b_k}) must match A K dimension ({k})") + l = self.expert_cnt # for shape checks that use l + + n_out = n + + self._check_tensor_shape(self.a_desc, (tensor_m, k, 1), "A") + if self.weight_mode == MoEWeightMode.DENSE: + self._check_tensor_shape(self.b_desc, (n, k, l), "B") + self._check_tensor_shape(self.c_desc, (tensor_m, n_out, 1), "C") + self._check_tensor_shape(self.d_row_desc, (tensor_m, n_out, 1), "D_row") + self._check_tensor_shape(self.d_col_desc, (tensor_m, n_out, 1), "D_col") + + rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4) + self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(tensor_m, 128), 4, rest_k, 1), "SFA") + if self.weight_mode == MoEWeightMode.DENSE: + self._check_tensor_shape(self.sfb_desc, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB") + + rest_n_out = ceil_div(ceil_div(n_out, self.sf_vec_size), 4) + self._check_tensor_shape( + self.sfd_row_desc, + (32, 4, ceil_div(tensor_m, 128), 4, rest_n_out, 1), + "SFD_row", + ) + rest_m = ceil_div(ceil_div(tensor_m, self.sf_vec_size), 4) + self._check_tensor_shape(self.sfd_col_desc, (32, 4, ceil_div(n_out, 128), 4, rest_m, 1), "SFD_col") + + self._check_tensor_shape(self.alpha_desc, (self.expert_cnt,), "alpha") + self._check_tensor_shape(self.prob_desc, (tensor_m, 1, 1), "prob") + self._check_tensor_shape(self.dprob_desc, (tensor_m, 1, 1), "dprob") + self._check_tensor_shape(self.dbias_desc, (self.expert_cnt, n_out, 1), "dbias") + self._check_tensor_shape(self.amax_desc, (self.expert_cnt, 1), "amax") 
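The scale-factor shape checks above all derive from the same 128-row block-scaled layout: the k dimension is folded as `ceil_div(ceil_div(k, sf_vec_size), 4)` and the token dimension as `ceil_div(m, 128)`. A worked example with illustrative sizes, not values from this patch::

    def ceil_div(a, b):
        return (a + b - 1) // b

    m, n, k, l, sf_vec_size = 512, 256, 7168, 8, 16
    rest_k = ceil_div(ceil_div(k, sf_vec_size), 4)       # 448 k-vectors -> 112 groups of 4
    sfa_shape = (32, 4, ceil_div(m, 128), 4, rest_k, 1)  # (32, 4, 4, 4, 112, 1)
    sfb_shape = (32, 4, ceil_div(n, 128), 4, rest_k, l)  # (32, 4, 2, 4, 112, 8)
    # Reading of the checks: the (32, 4) modes cover the 128 rows of one panel
    # (32 * 4 = 128), and the (4, rest_k) modes tile the per-row scale-factor
    # vectors along k, each vector scaling sf_vec_size elements.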
+ self._check_tensor_shape(self.norm_const_desc, (1,), "norm_const") + self._check_tensor_shape(self.padded_offsets_desc, (self.expert_cnt,), "padded_offsets") + + # Strides + _ = self._check_tensor_stride( + self.a_desc, + stride=[(k, 1, tensor_m * k)], + extra_error_msg="A must have k-major layout", + ) + if self.weight_mode == MoEWeightMode.DENSE: + if self._is_fp8(self.a_desc): + _ = self._check_tensor_stride( + self.b_desc, + stride=[(k, 1, n * k), (1, n, n * k)], + extra_error_msg="For fp8 ab_dtype, B must have k- or n-major layout", + ) + else: + _ = self._check_tensor_stride( + self.b_desc, + stride=[(k, 1, n * k)], + extra_error_msg="For fp4 ab_dtype, B must have k-major layout", + ) + _ = self._check_tensor_stride( + self.c_desc, + stride=[(n_out, 1, tensor_m * n_out)], + extra_error_msg="C must have n-major layout", + ) + _ = self._check_tensor_stride( + self.d_row_desc, + stride=[(n_out, 1, tensor_m * n_out)], + extra_error_msg="D_row must have n-major layout", + ) + _ = self._check_tensor_stride( + self.d_col_desc, + stride=[(n_out, 1, tensor_m * n_out)], + extra_error_msg="D_col must have n-major layout", + ) + + # ---- Data types ---- + self._logger.debug("Checking data types") + self.ab_dtype = self._check_dtype( + self.a_desc, + dtype=[ + torch.float4_e2m1fn_x2, + torch.uint8, + torch.float8_e5m2, + torch.float8_e4m3fn, + ], + name="A/B", + ) + if self.weight_mode == MoEWeightMode.DENSE: + self._check_dtype( + self.b_desc, + dtype=self.ab_dtype, + name="B", + extra_error_msg="B must have the same dtype as A", + ) + else: + self._value_error_if( + self.b_dtype != self.ab_dtype, + f"b_dtype ({self.b_dtype}) must match A dtype ({self.ab_dtype})", + ) + + self.sf_dtype = self._check_dtype( + self.sfa_desc, + dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn], + name="SFA/SFB/SFD", + ) + if self.weight_mode == MoEWeightMode.DENSE: + self._check_dtype( + self.sfb_desc, + dtype=self.sf_dtype, + name="SFB", + extra_error_msg="SFB must have the same dtype as SFA", + ) + self._check_dtype( + self.sfd_row_desc, + dtype=self.sf_dtype, + name="SFD_row", + extra_error_msg="SFD_row must have the same dtype as SFA", + ) + self._check_dtype( + self.sfd_col_desc, + dtype=self.sf_dtype, + name="SFD_col", + extra_error_msg="SFD_col must have the same dtype as SFA", + ) + + self._value_error_if( + self.sf_vec_size not in [16, 32], + f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}", + ) + self._value_error_if( + self.sf_dtype in [torch.float8_e4m3fn] and self.sf_vec_size == 32, + f"sf_dtype {self.sf_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported", + ) + self._value_error_if( + self._is_fp8(self.ab_dtype) and self.sf_vec_size == 16, + f"ab_dtype {self.ab_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported", + ) + + self._check_dtype( + self.acc_dtype, + dtype=torch.float32, + name="Accumulator", + extra_error_msg="Accumulator must be float32", + ) + self._check_dtype( + self.prob_desc, + dtype=torch.float32, + name="Prob", + extra_error_msg="Prob must be float32", + ) + self._check_dtype( + self.dprob_desc, + dtype=torch.float32, + name="Dprob", + extra_error_msg="Dprob must be float32", + ) + self._check_dtype( + self.dbias_desc, + dtype=torch.bfloat16, + name="Dbias", + extra_error_msg="dbias must be bfloat16", + ) + self.c_dtype = self._check_dtype( + self.c_desc, + dtype=[torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2], + name="C", + ) + if self._is_fp8(self.c_dtype) and self.vector_f32: + raise 
ValueError("Invalid configuration: fp8 c_dtype and vector_f32 is not supported. " "Please use vector_f32=False or c_dtype=bfloat16 instead") + + if self._is_fp4x2(self.ab_dtype): + self.d_dtype = self._check_dtype( + self.d_row_desc, + dtype=[torch.float16, torch.bfloat16, torch.float32], + name="D_row", + extra_error_msg="D_row must be fp16, bf16, or float32 when ab_dtype is fp4", + ) + elif self._is_fp8(self.ab_dtype): + self.d_dtype = self._check_dtype( + self.d_row_desc, + dtype=[ + torch.float8_e4m3fn, + torch.float8_e5m2, + ], + name="D_row", + extra_error_msg="D_row must be fp8 dtype when ab_dtype is fp8", + ) + else: + raise NotImplementedError(f"Invalid ab_dtype: {self.ab_dtype}, expected fp4 or fp8") + self._check_dtype( + self.d_col_desc, + dtype=self.d_dtype, + name="D_col", + extra_error_msg="D_col must have the same dtype as D_row", + ) + + # ---- SFD generation logic ---- + kernel_generate_sfd = self._is_fp8(self.ab_dtype) and self.sf_dtype == torch.float8_e8m0fnu and self._is_fp8(self.d_dtype) + self._value_error_if( + kernel_generate_sfd and not self._user_requested_sfd, + "sfd_row, sfd_col, and norm_const are required for FP8 input/FP8 output with sf_dtype=torch.float8_e8m0fnu", + ) + if not kernel_generate_sfd and self._user_requested_sfd: + self._logger.warning( + "sfd_row/sfd_col/norm_const were provided, but this configuration does not generate SFD outputs; " "the tensors will be ignored by the kernel", + ) + self.generate_sfd = kernel_generate_sfd + if self.discrete_col_sfd and not self.generate_sfd: + self._logger.warning("discrete_col_sfd is True but generate_sfd is False, discrete_col_sfd will be ignored") + self.discrete_col_sfd = False + + # ---- Activation function validation ---- + # ---- Discrete-mode-specific validation ---- + if self.weight_mode == MoEWeightMode.DISCRETE: + self._value_error_if( + self.b_major not in ["k", "n"], + f"b_major must be 'k' or 'n', got {self.b_major}", + ) + self._value_error_if( + self._is_fp4x2(self.ab_dtype) and self.b_major != "k", + "b_major must be 'k' when ab_dtype is fp4", + ) + + # ---- MMA tile / cluster shape ---- + self._logger.debug("Checking MMA tile shape and cluster shape") + self._value_error_if( + not self.use_2cta_instrs and self.mma_tiler_mn[0] != 128, + f"MMA tiler M must be 128 when use_2cta_instrs=False, got {self.mma_tiler_mn[0]}", + ) + self._value_error_if( + self.use_2cta_instrs and self.mma_tiler_mn[0] != 256, + f"MMA tiler M must be 256 when use_2cta_instrs=True, got {self.mma_tiler_mn[0]}", + ) + self._value_error_if( + self.mma_tiler_mn[1] != 256, + f"MMA tiler N must be 256, got {self.mma_tiler_mn[1]}", + ) + self._value_error_if( + self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0, + f"cluster_shape_mn[0] must be divisible by 2 when use_2cta_instrs=True, got {self.cluster_shape_mn[0]}", + ) + self._value_error_if( + not ( + self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16 + and self.cluster_shape_mn[0] > 0 + and self.cluster_shape_mn[1] > 0 + and self.cluster_shape_mn[0] <= 4 + and self.cluster_shape_mn[1] <= 4 + and is_power_of_2(self.cluster_shape_mn[0]) + and is_power_of_2(self.cluster_shape_mn[1]) + ), + f"Invalid cluster shape: expected values to be powers of 2 and product <= 16, got {self.cluster_shape_mn}", + ) + cluster_tiler_m = (self.cluster_shape_mn[0] // (2 if self.use_2cta_instrs else 1)) * self.mma_tiler_mn[0] + self._value_error_if( + cluster_tiler_m not in [128, 256], + f"Invalid cluster tiler shape: expected cluster_tiler_m in {{128, 256}}, got 
{cluster_tiler_m}", + ) + self._value_error_if( + self.m_aligned % self.mma_tiler_mn[0] != 0, + f"m_aligned must be divisible by mma_tiler_mn[0], got {self.m_aligned} % {self.mma_tiler_mn[0]} != 0", + ) + self._value_error_if( + self.m_aligned != BlockScaledMoEGroupedGemmQuantBwdKernel.FIX_PAD_SIZE, + f"m_aligned must be {BlockScaledMoEGroupedGemmQuantBwdKernel.FIX_PAD_SIZE} (FIX_PAD_SIZE), got {self.m_aligned}", + ) + + # ---- Tensor alignment ---- + self._logger.debug("Checking tensor alignment") + + def check_contiguous_16B_alignment(dtype, stride_order, tensor_shape): + is_mode0_major = stride_order == (0, 1, 2) + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // (_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2).width) + return num_major_elements % num_contiguous_elements == 0 + + if self.weight_mode == MoEWeightMode.DENSE: + b_stride_order_for_check = self.b_desc.stride_order + b_shape_for_check = (n, k, l) + else: + b_stride_order_for_check = (0, 1, 2) if self.b_major == "n" else (1, 0, 2) + b_shape_for_check = (n, k, 1) + + self._value_error_if( + not ( + check_contiguous_16B_alignment(self.ab_dtype, self.a_desc.stride_order, (tensor_m, k, l)) + and check_contiguous_16B_alignment(self.ab_dtype, b_stride_order_for_check, b_shape_for_check) + and check_contiguous_16B_alignment(self.d_dtype, self.d_row_desc.stride_order, (tensor_m, n_out, 1)) + ), + "Invalid tensor alignment: tensors must be 16B aligned", + ) + + # ---- Expert count limit ---- + self._value_error_if( + self.expert_cnt > 1024, + f"expert_cnt must be <= 1024, got {self.expert_cnt}", + ) + + # ---- Disabled configurations ---- + self._not_implemented_error_if( + self.dbias_desc is None and self._is_fp4x2(self.ab_dtype) and self.sf_vec_size == 16 and self.d_dtype == torch.float32, + "Invalid configuration: fp4 ab_dtype, sf_vec_size 16, d_dtype float32 is not supported. 
" "Please use sf_vec_size 32 or d_dtype bf16 instead", + ) + + # ---- SM100+ check ---- + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + major, minor = torch.cuda.get_device_capability(device) + compute_capability = major * 10 + minor + if compute_capability < 100: + raise RuntimeError(f"GroupedGemmDsrelu requires SM100+ compute capability, " f"but found SM{compute_capability} on device {device}") + + self._is_supported = True + self._logger.debug("check_support completed successfully") + return True + + def compile(self) -> None: + """Compile the kernel.""" + self._logger.debug("Entering compile") + self._ensure_support_checked() + if self._compiled_kernel is not None: + self._logger.debug("Kernel already compiled; skipping recompilation") + return + if self.a_desc.shape[0] == 0: + self._logger.debug("sample valid_m is zero, skipping kernel compilation") + return + + gemm_dsrelu = self._kernel( + sf_vec_size=self.sf_vec_size, + acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype), + use_2cta_instrs=self.use_2cta_instrs, + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + vectorized_f32=self.vector_f32, + generate_sfd=self.generate_sfd, + discrete_col_sfd=self.discrete_col_sfd, + expert_cnt=self.expert_cnt, + weight_mode=self.weight_mode, + use_dynamic_sched=self.use_dynamic_sched, + epilogue_type=EpilogueType.DSRELU.value, + generate_dbias=self._has_dbias, + ) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) + max_active_clusters -= self.num_cluster_overlap_margin + self._value_error_if( + max_active_clusters <= 0, + "max_active_clusters must be > 0 after applying overlap margin; reduce CUDNNFE_CLUSTER_OVERLAP_MARGIN", + ) + fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False) + + self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" + + workspace_bytes = gemm_dsrelu.get_workspace_bytes() + self._workspace = torch.empty(max(workspace_bytes, 1), dtype=torch.uint8, device="cuda") + + if self.weight_mode == MoEWeightMode.DENSE: + self._compile_dense(gemm_dsrelu, max_active_clusters, fake_stream) + else: + self._compile_discrete(gemm_dsrelu, max_active_clusters, fake_stream) + + self._logger.debug("Kernel compiled successfully") + + def _compile_dense(self, gemm_dsrelu, max_active_clusters, fake_stream) -> None: + """Compile for dense (contiguous) weight mode.""" + self._logger.debug("Compiling grouped_gemm_dsrelu kernel") + use_full_dynamic = self._use_full_dynamic_mnkl + + fake_workspace_ptr = cute.runtime.nullptr( + dtype=cutlass.Uint8, + assumed_align=128, + ) + + if not use_full_dynamic: + valid_m = cute.sym_int(divisibility=256) + + a_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, *self.a_desc.shape[1:]), + stride_order=self.a_desc.stride_order, + ) + b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16) + c_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, *self.c_desc.shape[1:]), + stride_order=self.c_desc.stride_order, + ) + d_row_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_row_desc.dtype, + shape=(valid_m, *self.d_row_desc.shape[1:]), + stride_order=self.d_row_desc.stride_order, + ) + d_col_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_col_desc.dtype, + 
shape=(valid_m, *self.d_col_desc.shape[1:]), + stride_order=self.d_col_desc.stride_order, + ) + + tensor_m_128 = cute.sym_int() + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], 1), + stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128), + ) + + sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16) + + prob_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, 1, 1), + stride_order=self.prob_desc.stride_order, + ) + dprob_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.dprob_desc.dtype, + shape=(valid_m, 1, 1), + stride_order=self.dprob_desc.stride_order, + ) + + sfd_row_fake = None + sfd_col_fake = None + if self.sfd_row_desc is not None: + stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_row_fake = self._make_fake_cute_tensor( + dtype=self.sfd_row_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfd_row_desc.shape[4], 1), + stride=(16, 4, self.sfd_row_desc.stride[2], 1, 512, stride_sfd_m), + ) + if self.sfd_col_desc is not None: + rest_m = cute.sym_int(divisibility=1) + stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4) + stride_rest_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_col_fake = self._make_fake_cute_tensor( + dtype=self.sfd_col_desc.dtype, + shape=(32, 4, self.sfd_col_desc.shape[2], 4, rest_m, 1), + stride=(16, 4, stride_rest_m, 1, 512, stride_sfd_n), + ) + else: + valid_m = cute.sym_int(divisibility=256) + n_sym = cute.sym_int() + n_out_sym = cute.sym_int() + k_sym = cute.sym_int() + l_sym = cute.sym_int() + + a_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, k_sym, 1), + stride_order=self.a_desc.stride_order, + dynamic_mode=self.a_desc.stride_order[0], + divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, + ) + b_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.b_desc.dtype, + shape=(n_sym, k_sym, l_sym), + stride_order=self.b_desc.stride_order, + dynamic_mode=self.b_desc.stride_order[0], + divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, + ) + + c_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, n_out_sym, 1), + stride_order=self.c_desc.stride_order, + dynamic_mode=self.c_desc.stride_order[0], + divisibility=8 if self._is_f16(self.c_desc.dtype) else 16, + ) + + d_row_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_row_desc.dtype, + shape=(valid_m, n_out_sym, 1), + stride_order=self.d_row_desc.stride_order, + dynamic_mode=self.d_row_desc.stride_order[0], + divisibility=8 if self._is_f16(self.d_row_desc.dtype) else 16, + ) + + d_col_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_col_desc.dtype, + shape=(valid_m, n_out_sym, 1), + stride_order=self.d_col_desc.stride_order, + dynamic_mode=self.d_col_desc.stride_order[0], + divisibility=8 if self._is_f16(self.d_col_desc.dtype) else 16, + ) + + tensor_m_128 = cute.sym_int() + rest_k = cute.sym_int() + stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4) + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_shape = list(self.sfa_desc.shape) + sfa_shape[2] = tensor_m_128 + sfa_shape[4] = rest_k + sfa_stride = list(self.sfa_desc.stride) + sfa_stride[2] = stride_rest_k + sfa_stride[5] = stride_tensor_m_128 + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=tuple(sfa_shape), + 
stride=tuple(sfa_stride), + ) + + tensor_n_128 = cute.sym_int() + stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfb_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfb_desc.dtype, + shape=(32, 4, tensor_n_128, 4, rest_k, l_sym), + stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k), + ) + + prob_cute_fake = self._make_fake_cute_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, *self.prob_desc.shape[1:]), + stride=self.prob_desc.stride, + ) + dprob_cute_fake = self._make_fake_cute_tensor( + dtype=self.dprob_desc.dtype, + shape=(valid_m, *self.dprob_desc.shape[1:]), + stride=self.dprob_desc.stride, + ) + + sfd_row_fake = None + sfd_col_fake = None + if self.sfd_row_desc is not None: + rest_n_out = cute.sym_int() + stride_sfd_rest_n_out = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfd_rest_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfd_row_fake = self._make_fake_cute_tensor( + dtype=self.sfd_row_desc.dtype, + shape=(32, 4, tensor_m_128, 4, rest_n_out, 1), + stride=(16, 4, stride_sfd_rest_n_out, 1, 512, stride_sfd_rest_tensor_m_128), + ) + if self.sfd_col_desc is not None: + tensor_n_out_128 = cute.sym_int() + rest_m_dyn = cute.sym_int() + stride_sfd_rest_m = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfd_n_out = cute.sym_int(divisibility=32 * 4 * 4) + sfd_col_fake = self._make_fake_cute_tensor( + dtype=self.sfd_col_desc.dtype, + shape=(32, 4, tensor_n_out_128, 4, rest_m_dyn, 1), + stride=(16, 4, stride_sfd_rest_m, 1, 512, stride_sfd_n_out), + ) + + dbias_fake = self._make_fake_cute_tensor_from_desc(self.dbias_desc, assumed_align=16) + + _compiled_kernel = cute.compile( + gemm_dsrelu, + a=_reinterpret_raw_grouped_fp4_tensor(self._sample_a_tensor) if self.a_desc.dtype == torch.uint8 else a_cute_fake, + b=_reinterpret_raw_grouped_fp4_tensor(self._sample_b_tensor) if self.b_desc.dtype == torch.uint8 else b_cute_fake, + sfb=sfb_cute_fake, + n=cutlass.Int32(0), + k=cutlass.Int32(0), + b_stride_size=cutlass.Int64(0), + b_major_mode=OperandMajorMode.K, + workspace_ptr=fake_workspace_ptr, + c=c_cute_fake, + d=d_row_cute_fake, + d_col=d_col_cute_fake, + sfa=sfa_cute_fake, + sfd_row_tensor=sfd_row_fake, + sfd_col_tensor=sfd_col_fake, + amax_tensor=self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16), + norm_const_tensor=self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16), + padded_offsets=self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16), + alpha=self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16), + prob=prob_cute_fake, + dprob=dprob_cute_fake, + dbias_tensor=dbias_fake, + max_active_clusters=max_active_clusters, + stream=fake_stream, + options="--enable-tvm-ffi", + ) + + cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator + + def tensor_api( + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_row_tensor: torch.Tensor, + d_col_tensor: Optional[torch.Tensor], + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + sfd_row_tensor: Optional[torch.Tensor], + sfd_col_tensor: Optional[torch.Tensor], + amax_tensor: Optional[torch.Tensor], + norm_const_tensor: Optional[torch.Tensor], + padded_offsets: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + dprob_tensor: torch.Tensor, + dbias_tensor: Optional[torch.Tensor], + stream: cuda.CUstream, + ) -> None: + norm_const_tensor = 
self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const") + _compiled_kernel( + _reinterpret_raw_grouped_fp4_tensor(a_tensor), + _reinterpret_raw_grouped_fp4_tensor(b_tensor), + sfb_tensor, + cutlass.Int32(0), + cutlass.Int32(0), + cutlass.Int64(0), + cached_workspace_ptr, + c_tensor, + d_row_tensor, + d_col_tensor, + sfa_tensor, + sfd_row_tensor, + sfd_col_tensor, + amax_tensor, + norm_const_tensor, + padded_offsets, + alpha_tensor, + prob_tensor, + dprob_tensor, + dbias_tensor, + stream, + ) + + self._compiled_kernel = tensor_api + + def _compile_discrete(self, gemm_dsrelu, max_active_clusters, fake_stream) -> None: + """Compile for discrete (per-expert pointer) weight mode.""" + if len(self.b_shape) == 2: + n, k = self.b_shape + else: + n, k, _ = self.b_shape + + b_major_mode = OperandMajorMode.K if self.b_major == "k" else OperandMajorMode.MN + if self.b_major == "k": + b_stride_size = k + else: + b_stride_size = n + + ab_cutlass_dtype = _convert_to_cutlass_data_type(self.a_desc.dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2) + align = 32 if ab_cutlass_dtype.width == 4 else 16 + + valid_m = cute.sym_int(divisibility=256) + a_tensor = self._make_fake_cute_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, *self.a_desc.shape[1:]), + stride=(self.a_desc.stride[0], *self.a_desc.stride[1:]), + assumed_align=align, + ) + c_tensor = self._make_fake_cute_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, *self.c_desc.shape[1:]), + stride=(self.c_desc.stride[0], *self.c_desc.stride[1:]), + ) + d_row_tensor = self._make_fake_cute_compact_tensor( + dtype=self.d_row_desc.dtype, + shape=(valid_m, *self.d_row_desc.shape[1:]), + stride_order=self.d_row_desc.stride_order, + ) + d_col_tensor = self._make_fake_cute_compact_tensor( + dtype=self.d_col_desc.dtype, + shape=(valid_m, *self.d_col_desc.shape[1:]), + stride_order=self.d_col_desc.stride_order, + ) + + tensor_m_128 = cute.sym_int() + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_shape = list(self.sfa_desc.shape) + sfa_shape[2] = tensor_m_128 + sfa_stride = list(self.sfa_desc.stride) + sfa_stride[5] = stride_tensor_m_128 + sfa_tensor = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=tuple(sfa_shape), + stride=tuple(sfa_stride), + assumed_align=16, + ) + sfd_row_tensor = None + if self.sfd_row_desc is not None: + stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_row_tensor = self._make_fake_cute_tensor( + dtype=self.sfd_row_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfd_row_desc.shape[4], 1), + stride=(16, 4, self.sfd_row_desc.stride[2], 1, 512, stride_sfd_m), + assumed_align=16, + ) + sfd_col_tensor = None + if self.sfd_col_desc is not None: + rest_m = cute.sym_int(divisibility=1) + stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4) + stride_rest_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_col_tensor = self._make_fake_cute_tensor( + dtype=self.sfd_col_desc.dtype, + shape=(32, 4, self.sfd_col_desc.shape[2], 4, rest_m, 1), + stride=(16, 4, stride_rest_m, 1, 512, stride_sfd_n), + assumed_align=16, + ) + amax_tensor = self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16) + norm_const_tensor_cute = self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16) + padded_offsets_tensor = self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16) + alpha_tensor = self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16) + prob_tensor = self._make_fake_cute_tensor( + dtype=self.prob_desc.dtype, + 
shape=(valid_m, *self.prob_desc.shape[1:]), + stride=self.prob_desc.stride, + assumed_align=16, + ) + dprob_tensor = self._make_fake_cute_tensor( + dtype=self.dprob_desc.dtype, + shape=(valid_m, *self.dprob_desc.shape[1:]), + stride=self.dprob_desc.stride, + assumed_align=16, + ) + dbias_tensor = self._make_fake_cute_tensor_from_desc(self.dbias_desc, assumed_align=16) + + b_ptrs_placeholder = torch.empty((self.expert_cnt,), dtype=torch.int64, device="cuda") + sfb_ptrs_placeholder = torch.empty((self.expert_cnt,), dtype=torch.int64, device="cuda") + b_ptrs_cute = from_dlpack(b_ptrs_placeholder, assumed_align=8).iterator + sfb_ptrs_cute = from_dlpack(sfb_ptrs_placeholder, assumed_align=8).iterator + + workspace_ptr_cute = from_dlpack(self._workspace, assumed_align=128).iterator + + self._logger.debug("Compiling discrete grouped GEMM dSReLU kernel") + _compiled_kernel = cute.compile( + gemm_dsrelu, + a_tensor, + b_ptrs_cute, + sfb_ptrs_cute, + cutlass.Int32(n), + cutlass.Int32(k), + cutlass.Int64(b_stride_size), + b_major_mode, + workspace_ptr_cute, + c_tensor, + d_row_tensor, + d_col_tensor, + sfa_tensor, + sfd_row_tensor, + sfd_col_tensor, + amax_tensor, + norm_const_tensor_cute, + padded_offsets_tensor, + alpha_tensor, + prob_tensor, + dprob_tensor, + dbias_tensor, + max_active_clusters, + fake_stream, + options="--enable-tvm-ffi", + ) + + self._n = n + self._k = k + self._b_stride_size = b_stride_size + + cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator + cached_n = cutlass.Int32(self._n) + cached_k = cutlass.Int32(self._k) + cached_b_stride = cutlass.Int64(self._b_stride_size) + + def tensor_api( + a_tensor: torch.Tensor, + b_ptrs_device: torch.Tensor, + sfb_ptrs_device: torch.Tensor, + c_tensor: torch.Tensor, + d_row_tensor: torch.Tensor, + d_col_tensor: Optional[torch.Tensor], + sfa_tensor: torch.Tensor, + sfd_row_tensor: Optional[torch.Tensor], + sfd_col_tensor: Optional[torch.Tensor], + amax_tensor: Optional[torch.Tensor], + norm_const_tensor: Optional[torch.Tensor], + padded_offsets: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + dprob_tensor: torch.Tensor, + dbias_tensor: Optional[torch.Tensor], + stream: cuda.CUstream, + ) -> None: + norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const") + b_ptrs_addr = int(b_ptrs_device.data_ptr()) + sfb_ptrs_addr = int(sfb_ptrs_device.data_ptr()) + + _compiled_kernel( + a_tensor, + b_ptrs_addr, + sfb_ptrs_addr, + cached_n, + cached_k, + cached_b_stride, + cached_workspace_ptr, + c_tensor, + d_row_tensor, + d_col_tensor, + sfa_tensor, + sfd_row_tensor, + sfd_col_tensor, + amax_tensor, + norm_const_tensor, + padded_offsets, + alpha_tensor, + prob_tensor, + dprob_tensor, + dbias_tensor, + stream, + ) + + self._compiled_kernel = tensor_api + + def execute( + self, + a_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_row_tensor: torch.Tensor, + d_col_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + padded_offsets: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + dprob_tensor: torch.Tensor, + # Dense mode: + b_tensor: Optional[torch.Tensor] = None, + sfb_tensor: Optional[torch.Tensor] = None, + dbias_tensor: Optional[torch.Tensor] = None, + # Discrete mode: + b_ptrs: Optional[torch.Tensor] = None, + sfb_ptrs: Optional[torch.Tensor] = None, + # Optional: + sfd_row_tensor: Optional[torch.Tensor] = None, + sfd_col_tensor: Optional[torch.Tensor] = None, + amax_tensor: Optional[torch.Tensor] = None, + norm_const_tensor: 
Optional[torch.Tensor] = None, + current_stream: Optional[cuda.CUstream] = None, + ) -> None: + """Execute the compiled kernel. + + For dense mode, supply ``b_tensor`` and ``sfb_tensor``. + For discrete mode, supply ``b_ptrs`` and ``sfb_ptrs``. + + :param a_tensor: Input A tensor (gradient input) + :param c_tensor: Forward activations input + :param d_row_tensor: Output D row tensor + :param d_col_tensor: Output D column tensor + :param sfa_tensor: Scale factor A + :param padded_offsets: End offset per expert after padding + :param alpha_tensor: Per-group alpha scaling factors + :param prob_tensor: Per-row probability (from forward) + :param dprob_tensor: Gradient of probability (output, must be zero-initialized) + :param b_tensor: (Dense) Input B tensor (weights) + :param sfb_tensor: (Dense) Scale factor B + :param dbias_tensor: Optional dbias output tensor. + :param b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers + :param sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers + :param sfd_row_tensor: Optional row scale factor D + :param sfd_col_tensor: Optional column scale factor D + :param amax_tensor: Optional amax tensor + :param norm_const_tensor: Optional normalization constant + :param current_stream: CUDA stream + """ + self._logger.debug("Entering execute") + current_stream = self._get_default_stream(current_stream) + + if a_tensor.shape[0] == 0: + self._logger.debug("execute: valid_m is zero, skipping kernel execution") + return + self._runtime_error_if( + self._compiled_kernel is None, + "Kernel not compiled; call compile() first", + ) + + self._logger.debug("Executing grouped GEMM dSReLU kernel") + if self._has_dbias: + self._value_error_if( + dbias_tensor is None, + "dbias_tensor is required when GroupedGemmDsreluSm100 is configured with sample_dbias", + ) + + if self.weight_mode == MoEWeightMode.DENSE: + self._compiled_kernel( + a_tensor=a_tensor, + b_tensor=b_tensor, + c_tensor=c_tensor, + d_row_tensor=d_row_tensor, + d_col_tensor=d_col_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + prob_tensor=prob_tensor, + dprob_tensor=dprob_tensor, + dbias_tensor=dbias_tensor, + stream=current_stream, + ) + else: + self._compiled_kernel( + a_tensor=a_tensor, + b_ptrs_device=b_ptrs, + sfb_ptrs_device=sfb_ptrs, + c_tensor=c_tensor, + d_row_tensor=d_row_tensor, + d_col_tensor=d_col_tensor, + sfa_tensor=sfa_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + prob_tensor=prob_tensor, + dprob_tensor=dprob_tensor, + dbias_tensor=dbias_tensor, + stream=current_stream, + ) + + self._logger.debug("Execute completed") + + +_logger = logging.getLogger(__name__) +_cache_of_GroupedGemmDsreluSm100Objects = {} + + +def grouped_gemm_dsrelu_wrapper_sm100( + a_tensor: torch.Tensor, + b_tensor: Optional[torch.Tensor] = None, + c_tensor: Optional[torch.Tensor] = None, + sfa_tensor: Optional[torch.Tensor] = None, + sfb_tensor: Optional[torch.Tensor] = None, + padded_offsets: Optional[torch.Tensor] = None, + alpha_tensor: Optional[torch.Tensor] = None, + prob_tensor: Optional[torch.Tensor] = None, + dprob_tensor: Optional[torch.Tensor] = None, + # generate_dbias is optional in both modes: + generate_dbias: 
bool = False, + # Discrete mode: + b_ptrs: Optional[torch.Tensor] = None, + sfb_ptrs: Optional[torch.Tensor] = None, + n: Optional[int] = None, + b_dtype: Optional[torch.dtype] = None, + b_major: str = "k", + # Common: + norm_const_tensor: Optional[torch.Tensor] = None, + acc_dtype: torch.dtype = torch.float32, + d_dtype: torch.dtype = torch.bfloat16, + cd_major: str = "n", + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + m_aligned: int = 256, + discrete_col_sfd: bool = False, + use_dynamic_sched: bool = False, + current_stream: Optional[cuda.CUstream] = None, +) -> TupleDict: + """Convenience wrapper for grouped GEMM dSReLU backward operation. + + Auto-detects dense vs. discrete mode based on which weight arguments + are provided. + + Dense mode: provide ``b_tensor`` and ``sfb_tensor``. + Discrete mode: provide ``b_ptrs``, ``sfb_ptrs``, ``n``, and ``b_dtype``. + + Compiled kernels are cached for reuse when called with the same configuration. + + Args: + a_tensor: Input A tensor (valid_m, k, 1) -- gradient input + c_tensor: Forward activations input (valid_m, n_out, 1) + sfa_tensor: Scale factor A + padded_offsets: End offset per expert after padding + alpha_tensor: Per-group alpha scaling + prob_tensor: Per-row probability (from forward) + dprob_tensor: Gradient of probability (output, must be zero-initialized) + b_tensor: (Dense) Weight B tensor (n, k, l) + sfb_tensor: (Dense) Scale factor B + generate_dbias: Optional flag to allocate and return dbias output + b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers + sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers + n: (Discrete) B weight N dimension + b_dtype: (Discrete) B weight data type + b_major: (Discrete) B tensor major dimension ("k" or "n") + norm_const_tensor: Optional normalization constant + acc_dtype: Accumulator data type + d_dtype: Output D tensor data type + cd_major: CD major dimension (only "n" supported) + mma_tiler_mn: MMA tiler shape + cluster_shape_mn: Cluster shape + sf_vec_size: Scale factor vector size + vector_f32: Use vectorized f32 + m_aligned: M alignment (must be 256) + discrete_col_sfd: Generate discrete col-major scale factor tensor + use_dynamic_sched: Enable dynamic tile scheduling for load balancing + current_stream: CUDA stream + + Returns: + TupleDict with keys: d_row_tensor, d_col_tensor, dprob_tensor, + dbias_tensor, amax_tensor, sfd_row_tensor, sfd_col_tensor + """ + from cudnn.discrete_grouped_gemm.discrete_kernel_utils import _require_pointer_tensor + + is_dense = b_tensor is not None + is_discrete = b_ptrs is not None + + if is_dense and is_discrete: + raise ValueError("Provide either (b_tensor, sfb_tensor) or (b_ptrs, sfb_ptrs), not both") + if not is_dense and not is_discrete: + raise ValueError("Must provide either (b_tensor, sfb_tensor) or (b_ptrs, sfb_ptrs)") + + valid_m, k_physical, _ = a_tensor.shape + + if is_dense: + weight_mode = MoEWeightMode.DENSE + n_weight, _, l = b_tensor.shape + else: + weight_mode = MoEWeightMode.DISCRETE + _require_pointer_tensor(b_ptrs, "b_ptrs") + num_experts = b_ptrs.shape[0] + _require_pointer_tensor(sfb_ptrs, "sfb_ptrs", num_experts) + if n is None or b_dtype is None: + raise ValueError("n and b_dtype are required for discrete mode") + n_weight = n + k_logical = k_physical * 2 if b_dtype in (torch.float4_e2m1fn_x2, torch.uint8) else k_physical + b_shape = (n_weight, k_logical) + l = num_experts + + n_out = 
n_weight + + _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Creating output tensors") + + if cd_major == "n": + d_row_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) + d_col_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) + else: + raise ValueError(f"cd_major must be 'n', got {cd_major}") + + sfd_row_tensor = None + sfd_col_tensor = None + amax_tensor = None + dbias_tensor = None + + if dprob_tensor is None: + dprob_tensor = torch.zeros((valid_m, 1, 1), dtype=torch.float32, device=a_tensor.device) + + if a_tensor.dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ] and sfa_tensor.dtype in [torch.float8_e8m0fnu, torch.float8_e4m3fn]: + _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Detected fp8 config, constructing sfd tensors") + + sf_dtype = sfa_tensor.dtype + mma_permute_order = (3, 4, 1, 5, 2, 0) + + sf_k_row = ceil_div(n_out, sf_vec_size) + mma_shape_row = (1, ceil_div(valid_m, 128), ceil_div(sf_k_row, 4), 32, 4, 4) + sfd_row_tensor = torch.empty(mma_shape_row, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) + + sf_k_col = ceil_div(valid_m, sf_vec_size) + mma_shape_col = (1, ceil_div(n_out, 128), ceil_div(sf_k_col, 4), 32, 4, 4) + sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) + + if d_dtype in [torch.bfloat16, torch.float16]: + _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Constructing amax_tensor") + amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device) + if generate_dbias: + dbias_tensor = torch.zeros((l, n_out, 1), dtype=torch.bfloat16, device=a_tensor.device) + + if valid_m == 0: + _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: valid_m is zero, skipping kernel execution") + return TupleDict( + d_row_tensor=d_row_tensor, + d_col_tensor=d_col_tensor, + dprob_tensor=dprob_tensor, + dbias_tensor=dbias_tensor, + amax_tensor=amax_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + ) + + # ---- Build cache key ---- + def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: + return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1])) + + def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype + + def dynamic_tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return None, stride_order(tensor), tensor.dtype + + def dynamic_m_tensor_signature( + tensor: Optional[torch.Tensor], static_shape_suffix: Optional[Tuple[int, ...]], dynamic_stride_dims: Tuple[int, ...] 
= () + ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride())) + return static_shape_suffix, stride_signature, tensor.dtype + + use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" + + if is_dense: + cache_key = ( + weight_mode, + use_full_dynamic, + a_tensor.shape[1:] if not use_full_dynamic else None, + b_tensor.shape[2] if use_full_dynamic else tuple(b_tensor.shape), + c_tensor.shape[1:] if not use_full_dynamic else None, + a_tensor.dtype, + b_tensor.dtype, + c_tensor.dtype, + stride_order(a_tensor), + stride_order(b_tensor), + stride_order(c_tensor), + *( + dynamic_tensor_signature(sfa_tensor) + if use_full_dynamic + else dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,)) + ), + *tensor_signature(alpha_tensor), + *(dynamic_m_tensor_signature(prob_tensor, (1, 1)) if not use_full_dynamic else dynamic_tensor_signature(prob_tensor)), + *(dynamic_m_tensor_signature(dprob_tensor, (1, 1)) if not use_full_dynamic else dynamic_tensor_signature(dprob_tensor)), + *(dynamic_tensor_signature(dbias_tensor) if use_full_dynamic else tensor_signature(dbias_tensor)), + *(dynamic_tensor_signature(sfb_tensor) if use_full_dynamic else tensor_signature(sfb_tensor)), + norm_const_tensor.shape if norm_const_tensor is not None else None, + norm_const_tensor.stride() if norm_const_tensor is not None else None, + norm_const_tensor.dtype if norm_const_tensor is not None else None, + tuple(padded_offsets.shape), + tuple(padded_offsets.stride()), + padded_offsets.dtype, + acc_dtype, + d_dtype, + cd_major, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + vector_f32, + m_aligned, + discrete_col_sfd, + use_dynamic_sched, + ) + else: + cache_key = ( + weight_mode, + *dynamic_m_tensor_signature(a_tensor, tuple(a_tensor.shape[1:]), dynamic_stride_dims=(2,)), + b_shape, + b_dtype, + *dynamic_m_tensor_signature(c_tensor, tuple(c_tensor.shape[1:]), dynamic_stride_dims=(2,)), + *dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,)), + *tensor_signature(alpha_tensor), + *dynamic_m_tensor_signature(prob_tensor, (1, 1)), + *dynamic_m_tensor_signature(dprob_tensor, (1, 1)), + *tensor_signature(dbias_tensor), + *tensor_signature(norm_const_tensor), + tuple(b_ptrs.shape), + tuple(b_ptrs.stride()), + b_ptrs.dtype, + tuple(sfb_ptrs.shape), + tuple(sfb_ptrs.stride()), + sfb_ptrs.dtype, + tuple(padded_offsets.shape), + tuple(padded_offsets.stride()), + padded_offsets.dtype, + acc_dtype, + d_dtype, + cd_major, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + vector_f32, + m_aligned, + discrete_col_sfd, + use_dynamic_sched, + b_major, + num_experts, + ) + + # ---- Cache lookup or create + compile ---- + if cache_key in _cache_of_GroupedGemmDsreluSm100Objects: + _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Using cached object") + api = _cache_of_GroupedGemmDsreluSm100Objects[cache_key] + else: + _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Creating new object") + if is_dense: + api = GroupedGemmDsreluSm100( + sample_a=a_tensor, + sample_c=c_tensor, + sample_d_row=d_row_tensor, + sample_d_col=d_col_tensor, + sample_sfa=sfa_tensor, + sample_padded_offsets=padded_offsets, + sample_alpha=alpha_tensor, + sample_prob=prob_tensor, + 
sample_dprob=dprob_tensor, + sample_dbias=dbias_tensor, + sample_b=b_tensor, + sample_sfb=sfb_tensor, + sample_sfd_row=sfd_row_tensor, + sample_sfd_col=sfd_col_tensor, + sample_amax=amax_tensor, + sample_norm_const=norm_const_tensor, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + vector_f32=vector_f32, + m_aligned=m_aligned, + discrete_col_sfd=discrete_col_sfd, + use_dynamic_sched=use_dynamic_sched, + ) + else: + api = GroupedGemmDsreluSm100( + sample_a=a_tensor, + sample_c=c_tensor, + sample_d_row=d_row_tensor, + sample_d_col=d_col_tensor, + sample_sfa=sfa_tensor, + sample_padded_offsets=padded_offsets, + sample_alpha=alpha_tensor, + sample_prob=prob_tensor, + sample_dprob=dprob_tensor, + sample_dbias=dbias_tensor, + num_experts=num_experts, + b_shape=b_shape, + b_dtype=b_dtype, + sample_sfd_row=sfd_row_tensor, + sample_sfd_col=sfd_col_tensor, + sample_amax=amax_tensor, + sample_norm_const=norm_const_tensor, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + vector_f32=vector_f32, + m_aligned=m_aligned, + discrete_col_sfd=discrete_col_sfd, + b_major=b_major, + use_dynamic_sched=use_dynamic_sched, + ) + + if not api.check_support(): + raise RuntimeError("Unsupported configuration") + api.compile() + _cache_of_GroupedGemmDsreluSm100Objects[cache_key] = api + + # ---- Execute ---- + if is_dense: + api.execute( + a_tensor=a_tensor, + c_tensor=c_tensor, + d_row_tensor=d_row_tensor, + d_col_tensor=d_col_tensor, + sfa_tensor=sfa_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + prob_tensor=prob_tensor, + dprob_tensor=dprob_tensor, + dbias_tensor=dbias_tensor, + b_tensor=b_tensor, + sfb_tensor=sfb_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + current_stream=current_stream, + ) + else: + api.execute( + a_tensor=a_tensor, + c_tensor=c_tensor, + d_row_tensor=d_row_tensor, + d_col_tensor=d_col_tensor, + sfa_tensor=sfa_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + prob_tensor=prob_tensor, + dprob_tensor=dprob_tensor, + dbias_tensor=dbias_tensor, + b_ptrs=b_ptrs, + sfb_ptrs=sfb_ptrs, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + current_stream=current_stream, + ) + + return TupleDict( + d_row_tensor=d_row_tensor, + d_col_tensor=d_col_tensor, + dprob_tensor=dprob_tensor, + dbias_tensor=dbias_tensor, + amax_tensor=amax_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + ) diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/moe_blockscaled_grouped_gemm_dsrelu_quant.py b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/moe_blockscaled_grouped_gemm_dsrelu_quant.py new file mode 100644 index 00000000..d6db56a7 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/moe_blockscaled_grouped_gemm_dsrelu_quant.py @@ -0,0 +1,2249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +""" +MoE Block-Scaled Grouped GEMM Backward Kernel with DSRELU Support. 
+ +Supports: + - Static / Dynamic persistent tile scheduling (MoEPersistentTileScheduler) + - Dense (contiguous 3-D B) / Discrete (per-expert pointer array B) weight layout + - FP8/FP4 output quantization with row/column scale factors (SFD) + - Optional routing-probability (prob) fusion + - DSRELU backward epilogue: dA = relu(acc)·C·2·prob; dprob = Σ_n(relu(acc)²·C) + - NONE epilogue for testing (identity, no SReLU gate) + +EpilogueType.NONE: + out[m,n] = alpha · (SFA·A ★ SFB·B)[m,n] · prob[m] + +EpilogueType.DSRELU (backward through srelu): + out[m,n] = relu(alpha · acc[m,n]) · C[m,n] · 2 · prob[m] + dprob[m] += Σ_n( relu(alpha · acc[m,n])² · C[m,n] ) + +C is the upstream gradient (same shape as forward SReLU output D). +""" + +from enum import Enum +from typing import Type, Tuple, Union, Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass._mlir.dialects.nvvm import ReduxKind +from cutlass._mlir.dialects import llvm +from cutlass.cute.typing import Float32, Int32, AddressSpace + + +def atomic_add_bf16x2(ptr, val_fp32_lo, val_fp32_hi, *, loc=None, ip=None): + """Packed BF16x2 atomic reduction to global memory (same impl as dglu_dbias ref).""" + lo_ir = val_fp32_lo.ir_value(loc=loc, ip=ip) + hi_ir = val_fp32_hi.ir_value(loc=loc, ip=ip) + llvm.inline_asm( + None, + [ptr, hi_ir, lo_ir], + "{ .reg .b32 packed; cvt.rn.bf16x2.f32 packed, $1, $2; red.global.add.noftz.bf16x2 [$0], packed; }", + "l,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +from ..moe_persistent_scheduler import ( + MoEPersistentTileScheduler, + MoESchedulerParams, + MoEWorkTileInfo, +) +from ..moe_utils import ( + compute_expert_token_range, + MoEWeightMode, + TensormapWorkspace, + store_tma_desc, +) +from ..moe_sched_extension import ( + DiscreteWeightScaledGemmSchedExtension, + ContiguousAndConsistentGroupedGemmSchedExtension, +) +from ..moe_kernel_helpers import ( + fmin, + fmax, + warp_redux_sync, + atomic_add_float32, + atomic_max_float32, + compute_stages, + compute_grid, + can_implement, + amax_reduction_per_thread, + epilog_gmem_copy_and_partition, + get_dtype_rcp_limits, +) + + +class EpilogueType(Enum): + NONE = 0 + DSRELU = 1 + + +class BlockScaledMoEGroupedGemmQuantBwdKernel: + """Block-scaled grouped GEMM backward kernel with MoE tile scheduling and DSRELU. + + Computes the backward pass through the SReLU epilogue: + out[m,n] = relu(alpha * acc[m,n]) * C[m,n] * 2 * prob[m] (DSRELU) + dprob[m] += sum_n( relu(alpha * acc[m,n])^2 * C[m,n] ) (DSRELU) + or identity (NONE epilogue): + out[m,n] = alpha * acc[m,n] * prob[m] + + Supports both dense and discrete weight layouts, static and dynamic + scheduling, and quantized output with row/column scale factors. + + :param sf_vec_size: Scale-factor vector size (16 or 32). + :param acc_dtype: Accumulator data type (Float32). + :param use_2cta_instrs: Use 2-CTA MMA instructions. + :param mma_tiler_mn: MMA tile shape (M, N). + :param cluster_shape_mn: Cluster shape (M, N). + :param vectorized_f32: Use packed FP32 arithmetic. + :param generate_sfd: Generate output scale factors. + :param discrete_col_sfd: Use discrete column SFD layout. + :param expert_cnt: Number of experts. 
+ :param weight_mode: ``MoEWeightMode.DENSE`` or ``MoEWeightMode.DISCRETE``. + :param use_dynamic_sched: Enable dynamic tile scheduling. + :param epilogue_type: Epilogue type (``EpilogueType.NONE`` or ``EpilogueType.DSRELU``). + """ + + FIX_PAD_SIZE = 256 + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + cd_major: str, + m_aligned: int, + ) -> bool: + return can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + acc_dtype, + d_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + cd_major, + m_aligned, + fix_pad_size=BlockScaledMoEGroupedGemmQuantBwdKernel.FIX_PAD_SIZE, + ) + + def __init__( + self, + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + vectorized_f32: bool, + generate_sfd: bool, + discrete_col_sfd: bool, + expert_cnt: int, + weight_mode: MoEWeightMode = MoEWeightMode.DENSE, + use_dynamic_sched: bool = False, + epilogue_type: int = EpilogueType.NONE.value, + generate_dbias: bool = False, + ): + mma_tile_m = mma_tiler_mn[0] + if self.FIX_PAD_SIZE % mma_tile_m != 0: + raise ValueError( + f"FIX_PAD_SIZE ({self.FIX_PAD_SIZE}) must be divisible by " f"mma_tiler_mn[0] ({mma_tile_m}). " f"Supported mma_tiler_mn[0] values: 128, 256." + ) + if expert_cnt > 1024: + raise ValueError("Expert count > 1024 is not supported.") + if not isinstance(weight_mode, MoEWeightMode): + raise TypeError(f"weight_mode must be a MoEWeightMode, got {type(weight_mode)}") + + self.sf_vec_size = sf_vec_size + self.expert_cnt = expert_cnt + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + self.epilog_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.epilog_load_tma_id = 6 # new warp: TMA-loads C subtiles into sC + self.sched_warp_id = 7 # shifted from 6 in forward kernel + self.threads_per_warp = 32 + all_warps = [ + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.epilog_load_tma_id, + self.sched_warp_id, + ] + warps_wo_sched = [ + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.epilog_load_tma_id, + ] + self.threads_per_cta = self.threads_per_warp * len(all_warps) + self.threads_wo_sched = self.threads_per_warp * len(warps_wo_sched) + + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.vectorized_f32 = vectorized_f32 + self.generate_sfd = generate_sfd + 
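# Illustrative reference (an annotation, not part of this patch): the DSRELU
# backward math from the class docstring can be sanity-checked against plain
# NumPy. The helper name `dsrelu_backward_reference` and its tensor arguments
# (acc, c, alpha, prob) are assumptions made for this sketch only.
import numpy as np

def dsrelu_backward_reference(acc, c, alpha, prob):
    """acc: (M, N) fp32 accumulator; c: (M, N) upstream grad; prob: (M,)."""
    r = np.maximum(alpha * acc, 0.0)       # relu(alpha * acc)
    d_a = r * c * 2.0 * prob[:, None]      # out[m,n] = relu(.) * C[m,n] * 2 * prob[m]
    dprob = ((r ** 2) * c).sum(axis=1)     # dprob[m] = sum_n relu(.)^2 * C[m,n]
    return d_a, dprob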
self.discrete_col_sfd = discrete_col_sfd + self.generate_dbias = generate_dbias + self.dbias_cross_warp_reduce = generate_dbias + + self.weight_mode = weight_mode + self.use_dynamic_sched = use_dynamic_sched + + self.epilogue_use_functor = False + self.epilogue_type = epilogue_type + + self.num_epilog_warps = len(self.epilog_warp_id) + + # ------------------------------------------------------------------ + # _setup_attributes + # ------------------------------------------------------------------ + + def _setup_attributes(self): + """Configure MMA / tile / stage / SMEM layouts from GEMM inputs.""" + + self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1]) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], + ) + + self.mma_tiler_d = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk_d = ( + self.mma_tiler_d[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_d[1], + self.mma_tiler_d[2], + ) + + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + self.epi_tile = (128, 32) + + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_d_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.d_dtype, + self.d_layout, + self.sf_dtype, + self.sf_vec_size, + self.num_smem_capacity, + self.occupancy, + self.generate_sfd, + self.generate_dbias, + ) + + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = 
blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.d_dtype, + self.d_layout, + self.epi_tile, + self.num_d_stage, + ) + + self.overlapping_accum = self.num_acc_stage == 1 and self.mma_tiler[1] == 256 + self.epilogue_prefetch_more = False + + sf_atom_mn = 32 + self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k + self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = ( + self.cta_tile_shape_mnk[1] * self.num_acc_stage if not self.overlapping_accum else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols + ) + + self.epi_tile_n_required = cute.size(self.epi_tile[1]) + self.iter_acc_early_release_in_epilogue = (self.num_sf_tmem_cols + self.epi_tile_n_required - 1) // self.epi_tile_n_required - 1 + + # ------------------------------------------------------------------ + # _compute_stages + # ------------------------------------------------------------------ + + @staticmethod + def _compute_stages( + tiled_mma, + mma_tiler_mnk, + a_dtype, + b_dtype, + epi_tile, + c_dtype, + c_layout, + d_dtype, + d_layout, + sf_dtype, + sf_vec_size, + num_smem_capacity, + occupancy, + generate_sfd, + generate_dbias=False, + ): + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + # Always 2 c stages for c_pipeline double-buffering + num_c_stage = 2 + num_d_stage = 2 if generate_sfd else 1 + num_tile_stage = 2 + + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler_mnk, a_dtype, 1) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, b_dtype, 1) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(tiled_mma, mma_tiler_mnk, sf_vec_size, 1) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(tiled_mma, mma_tiler_mnk, sf_vec_size, 1) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1) + d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + sinfo_bytes = 4 * 4 * num_tile_stage + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) + d_bytes = d_bytes_per_stage * num_d_stage * (2 if generate_sfd else 1) + amax_bytes = 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,))) if d_dtype == cutlass.BFloat16 else 0 + # dBias SMEM transpose buffer: 128 M rows × epi_tile_n N cols × FP32 + # (same formula as SharedStorage.sDbias field). Must be subtracted from + # the AB-stage budget or num_ab_stage gets over-allocated and the module + # fails to serialize. 
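# Worked example of the budget above (assumed numbers, for intuition only):
# with epi_tile = (128, 32) and FP32 partials, the transpose buffer costs
#     dbias_bytes = 128 * 32 * 4 = 16384 bytes (16 KiB),
# and since
#     num_ab_stage = (smem_capacity // occupancy - (mbar + epi + sinfo)) // ab_bytes_per_stage,
# epi_bytes must include dbias_bytes whenever generate_dbias is set, or those
# 16 KiB get handed to the AB pipeline twice.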
+ dbias_bytes = 128 * cute.size(epi_tile[1]) * 4 if generate_dbias else 0 + + epi_bytes = c_bytes + d_bytes + amax_bytes + dbias_bytes + num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes + sinfo_bytes)) // ab_bytes_per_stage + + return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage + + # ------------------------------------------------------------------ + # Workspace helpers + # ------------------------------------------------------------------ + + def get_desc_workspace_bytes(self) -> int: + if self.weight_mode == MoEWeightMode.DISCRETE: + from ..moe_utils import DiscreteWeightTensormapConstructor + + return DiscreteWeightTensormapConstructor.get_workspace_size(self.expert_cnt) + return 0 + + def get_workspace_bytes(self) -> int: + desc_workspace_bytes = self.get_desc_workspace_bytes() + dynamic_sched_bytes = 4 if self.use_dynamic_sched else 0 + return desc_workspace_bytes + dynamic_sched_bytes + + @cute.jit + def _get_sched_counter_ptr(self, workspace_ptr): + counter_addr = workspace_ptr.toint() + self.get_desc_workspace_bytes() + return cute.make_ptr( + cutlass.Int32, + counter_addr, + AddressSpace.gmem, + assumed_align=4, + ) + + # ------------------------------------------------------------------ + # helper_kernel: pre-main-kernel initialization + # ------------------------------------------------------------------ + + @cute.kernel + def helper_kernel( + self, + ptrs_b: cute.Pointer, + ptrs_sfb: cute.Pointer, + n: Int32, + k: Int32, + b_stride_size: cutlass.Int64, + b_major_mode: cutlass.Constexpr, + workspace_ptr, + tiled_mma_arg: cute.TiledMma, + tiled_mma_sfb_arg: cute.TiledMma, + b_smem_layout_arg, + sfb_smem_layout_arg, + cluster_layout_vmnk_shape_arg: cutlass.Constexpr, + cluster_layout_sfb_vmnk_shape_arg: cutlass.Constexpr, + ): + """Pre-main-kernel initialization (discrete TMA desc + dynamic sched counter).""" + expert_idx = cute.arch.block_idx()[0] + + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE): + b_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_arg.thr_id) + sfb_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma_arg.thr_id) + + b_ptr_tensor = cute.make_tensor( + cute.make_ptr(cutlass.Int64, ptrs_b.toint(), AddressSpace.gmem, assumed_align=8), cute.make_layout((self.expert_cnt,)) + ) + sfb_ptr_tensor = cute.make_tensor( + cute.make_ptr(cutlass.Int64, ptrs_sfb.toint(), AddressSpace.gmem, assumed_align=8), cute.make_layout((self.expert_cnt,)) + ) + + c1 = cutlass.Int32(1) + c0 = cutlass.Int64(0) + c1_64 = 1 + if cutlass.const_expr(b_major_mode == OperandMajorMode.K): + stride_n = b_stride_size + stride_k = c1_64 + else: + stride_n = c1_64 + stride_k = b_stride_size + + b_ptr_val = b_ptr_tensor[expert_idx] + b_ptr = cute.make_ptr(self.b_dtype, b_ptr_val, AddressSpace.gmem) + b_tensor_i = cute.make_tensor( + b_ptr, + cute.make_layout((n, k, c1), stride=(stride_n, stride_k, c0)), + ) + tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B( + b_tma_op_arg, + b_tensor_i, + b_smem_layout_arg, + self.mma_tiler, + tiled_mma_arg, + cluster_layout_vmnk_shape_arg, + ) + workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"]) + store_tma_desc(tma_atom_b, workspace.get_ptr("b", expert_idx)) + + sfb_ptr_val = sfb_ptr_tensor[expert_idx] + sfb_ptr = cute.make_ptr(self.sf_dtype, sfb_ptr_val, AddressSpace.gmem) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size) + sfb_tensor_i = cute.make_tensor(sfb_ptr, 
sfb_layout) + tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B( + sfb_tma_op_arg, + sfb_tensor_i, + sfb_smem_layout_arg, + self.mma_tiler_sfb, + tiled_mma_sfb_arg, + cluster_layout_sfb_vmnk_shape_arg, + internal_type=cutlass.Uint64, + ) + store_tma_desc(tma_atom_sfb, workspace.get_ptr("sfb", expert_idx)) + + if cutlass.const_expr(self.use_dynamic_sched): + if expert_idx == cutlass.Int32(0): + sched_counter = cute.make_tensor( + self._get_sched_counter_ptr(workspace_ptr), + cute.make_layout(1), + ) + sched_counter[0] = cutlass.Int32(0) + + # ------------------------------------------------------------------ + # __call__ + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b, # Dense: cute.Tensor (N,K,L) | Discrete: cute.Pointer to int64[] + sfb, # Dense: cute.Tensor | Discrete: cute.Pointer to int64[] + n: Int32, # Ignored for dense mode + k: Int32, # Ignored for dense mode + b_stride_size: cutlass.Int64, # Ignored for dense mode + b_major_mode: cutlass.Constexpr, # Ignored for dense mode + workspace_ptr, + c: cute.Tensor, # INPUT: upstream gradient (forward output), shape (M,N,L) + d: cute.Tensor, # OUTPUT: dA gradient + d_col: Optional[cute.Tensor], + sfa: cute.Tensor, + sfd_row_tensor: Optional[cute.Tensor], + sfd_col_tensor: Optional[cute.Tensor], + amax_tensor: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + padded_offsets: cute.Tensor, + alpha: cute.Tensor, + prob: cute.Tensor, + dprob: Optional[cute.Tensor], # OUTPUT: dL/d(prob), shape (M,1,1), Float32 + dbias_tensor: Optional[cute.Tensor], # OUTPUT: dL/d(bias), shape (L,N), BF16 (accumulated via atomic) + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + """Execute the backward GEMM. + + Dense mode: ``b`` and ``sfb`` are 3-D cute.Tensor (N, K, L). + Discrete mode: ``b`` and ``sfb`` are cute.Pointer to device int64[] + arrays of per-expert base addresses. + + ``c`` is the upstream gradient tensor (same shape as forward output D), + loaded G2S by the specialized epilog_load_tma warp and consumed by epilog warps. + + ``dprob`` (optional) receives per-token routing probability gradients, + accumulated via atomic FP32 add. 
+ """ + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = a.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.d_dtype: Type[cutlass.Numeric] = d.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + self.d_layout = utils.LayoutEnum.from_tensor(d) + + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + else: + self.b_major_mode = b_major_mode + + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"A/B dtype must match: {self.a_dtype} != {self.b_dtype}") + + self._setup_attributes() + + # ---- SFA layout ---- + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, self.sf_vec_size) + sfa = cute.make_tensor(sfa.iterator, sfa_layout) + + # ---- B / SFB setup (mode-dependent) ---- + b_from_call_arg = b + sfb_from_call_arg = sfb + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) + sfb = cute.make_tensor(sfb.iterator, sfb_layout) + else: + c1 = cutlass.Int32(1) + c0 = cutlass.Int64(0) + c1_64 = 1 + if cutlass.const_expr(b_major_mode == OperandMajorMode.K): + b_template_stride = (b_stride_size, c1_64, c0) + else: + b_template_stride = (c1_64, b_stride_size, c0) + b_template_layout = cute.make_layout((n, k, c1), stride=b_template_stride) + b_ptr_typed = cute.make_ptr(self.b_dtype, b.toint(), AddressSpace.gmem, assumed_align=16) + b = cute.make_tensor(b_ptr_typed, b_template_layout) + + sfb_ptr_typed = cute.make_ptr(self.sf_dtype, sfb.toint(), AddressSpace.gmem, assumed_align=16) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size) + sfb = cute.make_tensor(sfb_ptr_typed, sfb_layout) + + # ---- SFD setup ---- + self.generate_sfd = sfd_row_tensor is not None and norm_const_tensor is not None + if cutlass.const_expr(self.generate_sfd == False): + self.discrete_col_sfd = False + if cutlass.const_expr(self.generate_sfd): + sfd_row_layout = blockscaled_utils.tile_atom_to_shape_SF(d.shape, self.sf_vec_size) + sfd_row_tensor = cute.make_tensor(sfd_row_tensor.iterator, sfd_row_layout) + sfd_col_layout = cute.tile_to_shape( + blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, OperandMajorMode.MN).layout, + d.shape, + (1, 2, 3), + ) + if cutlass.const_expr(self.discrete_col_sfd): + sfd_col_layout = sfd_row_layout + sfd_col_tensor = cute.make_tensor(sfd_col_tensor.iterator, sfd_col_layout) + + self.generate_amax = amax_tensor is not None + # self.generate_dbias was set in __init__ (needed at _compute_stages time); + # assert consistency with the dbias_tensor passed here. 
+ assert self.generate_dbias == (dbias_tensor is not None), "dbias_tensor presence must match generate_dbias set at construction" + + # ---- TMA atoms ---- + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Uint64, + ) + + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + new_shape = ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ) + x_times_3 = 3 * x + new_stride = ( + (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ) + tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size + + # C is an INPUT (upstream gradient): use G2S TMA atom + c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + self.tma_c_load_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + c, + c_smem_layout, + self.epi_tile, + ) + + # D is the output (dA gradient): S2G TMA atom + d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0)) + tma_atom_d, tma_tensor_d = 
cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d, + d_smem_layout, + self.epi_tile, + ) + tma_atom_d_col, tma_tensor_d_col = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d_col, + d_smem_layout, + self.epi_tile, + ) + + # ---- Helper kernel ---- + _need_helper = cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE or self.use_dynamic_sched) + if cutlass.const_expr(_need_helper): + _helper_grid_x = self.expert_cnt if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else 1 + _helper_args = ( + b_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem), + sfb_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem), + n if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0), + k if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0), + b_stride_size if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int64(0), + b_major_mode if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else self.b_major_mode, + workspace_ptr, + tiled_mma, + tiled_mma_sfb, + b_smem_layout, + sfb_smem_layout, + self.cluster_layout_vmnk.shape, + self.cluster_layout_sfb_vmnk.shape, + ) + self.helper_kernel(*_helper_args).launch( + grid=(_helper_grid_x, 1, 1), + block=(1, 1, 1), + stream=stream, + min_blocks_per_mp=1, + ) + + # ---- Grid computation via MoE scheduler ---- + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + b_n, b_k, b_l = cute.shape(b) + sched_expert_shape = (self.expert_cnt, b_n, b_k) + else: + sched_expert_shape = (self.expert_cnt, n, k) + + sched_params = MoESchedulerParams( + scenario="2Dx3D", + expert_shape=sched_expert_shape, + cta_tile_shape_mnk=self.cta_tile_shape_mnk, + cluster_shape_mn=self.cluster_shape_mn, + use_dynamic_sched=self.use_dynamic_sched, + ) + self.sched_params, grid = compute_grid(sched_params, max_active_clusters, self.use_2cta_instrs) + + self.buffer_align_bytes = 1024 + + # ---- Shared storage ---- + sD_col_size = cute.cosize(self.d_smem_layout_staged.outer) if self.generate_sfd else 0 + SchedulerStorage = MoEPersistentTileScheduler.make_storage_struct(self.num_tile_stage, self.use_dynamic_sched) + + @cute.struct + class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + c_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage * 2] + scheduler: SchedulerStorage + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # sC: staging buffer for loading C (upstream gradient) from global memory + sC: cute.struct.Align[ + cute.struct.MemRange[self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sD: cute.struct.Align[ + cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sD_col: cute.struct.Align[ + cute.struct.MemRange[self.d_dtype, sD_col_size], + self.buffer_align_bytes, + ] + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sSFA: cute.struct.Align[ + 
cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + sAmax: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps], + 4, + ] + # dBias SMEM transpose buffer: 128 × epi_tile_n FP32 (single-vec dA, vs + # reference's 128 × epi_tile_n × 2 that held both d1+d2 for DGLU). + sDbias: cute.struct.Align[ + cute.struct.MemRange[ + cutlass.Float32, + 128 * self.epi_tile[1] if self.generate_dbias else 1, + ], + 128 if self.generate_dbias else 4, + ] + + self.shared_storage = SharedStorage + + # ---- Launch ---- + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + tma_atom_d, + tma_tensor_d, + tma_atom_d_col, + tma_tensor_d_col, + sfd_row_tensor, + sfd_col_tensor, + norm_const_tensor, + amax_tensor, + padded_offsets, + alpha, + prob, + dprob, + dbias_tensor, + workspace_ptr, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.d_smem_layout_staged, + self.epi_tile, + self.sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + max_number_threads=[self.threads_per_cta, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # ------------------------------------------------------------------ + # Helper methods + # ------------------------------------------------------------------ + + def mainloop_s2t_copy_and_partition(self, sSF, tSF): + tCsSF_compact = cute.filter_zeros(sSF) + tCtSF_compact = cute.filter_zeros(tSF) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + @cute.jit + def amax_reduction_per_warp_and_cta(self, amax_fp32, warp_idx, amax_smem, amax_gmem): + warp_amax = warp_redux_sync( + value=amax_fp32, + kind=ReduxKind.MAX, + mask_and_clamp=0xFFFFFFFF, + nan=True, + ) + if cute.arch.lane_idx() == 0: + amax_smem[warp_idx] = cutlass.Float32(warp_amax) + self.epilog_sync_barrier.arrive_and_wait() + if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0: + block_amax = cutlass.Float32(0.0) + for i in cutlass.range(self.num_epilog_warps): + warp_amax_val = amax_smem[i] + block_amax = cute.arch.fmax(block_amax, warp_amax_val) + _ = atomic_max_float32(ptr=amax_gmem, value=block_amax) + + @cute.jit + def dbias_reduction( + self, + dA_vec, + warp_idx, + sDbias, + dbias_gmem_2d, + expert_idx, + n_base, + dbias_n_total, + ) -> None: + """Sum dA across M within this subtile and atomic-add to dbias[expert, n]. + + Adapted from moe_blockscaled_grouped_gemm_dglu_dbias.py's dbias_reduction, + which handled two interleaved vectors (d1, d2). 
Here we have a single + vector dA, so 16 lanes handle the 32 N columns (2 cols each via bf16x2). + Lanes 16-31 remain idle in the atomic phase but still contribute to the + SMEM write below. + + SMEM layout: sDbias[n, lane_idx, warp_local] with shape (epi_n, 32, num_warps), + holding fp32 values. Each thread writes its dA_vec (epi_n values) at its + (lane, warp) slot — M is spread across (lane_idx × num_warps). + """ + epi_n = self.epi_tile[1] + lane_idx = cute.arch.lane_idx() + warp_local = warp_idx - self.epilog_warp_id[0] + + for n in cutlass.range(epi_n, unroll_full=True): + sDbias[(n, lane_idx, warp_local)] = dA_vec[n] + + self.epilog_sync_barrier.arrive_and_wait() + + # 16 active lanes, each handles 2 consecutive N columns (col_a, col_b) + # → 16 lanes × 2 cols = 32 cols = epi_n. + col_a = 2 * lane_idx if lane_idx < 16 else 0 + col_b = col_a + 1 + + copy_128bit_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128) + # Each warp owns a contiguous slice of sDbias of size epi_n * 32 + warp_base_ptr = sDbias.iterator + warp_local * epi_n * 32 + swizzle_a = ((col_a >> 1) & 0x7) << 2 + swizzle_b = ((col_b >> 1) & 0x7) << 2 + + sum_a = cutlass.Float32(0.0) + sum_b = cutlass.Float32(0.0) + rDst_a = cute.make_rmem_tensor(cute.make_layout((4,)), cutlass.Float32) + rDst_b = cute.make_rmem_tensor(cute.make_layout((4,)), cutlass.Float32) + # 8 groups × 4 M = 32 M rows per warp + for g in cutlass.range(8, unroll_full=True): + m_base = g * 4 + sw_offset_a = col_a * 32 + (m_base ^ swizzle_a) + sSrc_a = cute.make_tensor(warp_base_ptr + sw_offset_a, cute.make_layout((4,))) + cute.copy_atom_call(copy_128bit_atom, sSrc_a, rDst_a) + + sw_offset_b = col_b * 32 + (m_base ^ swizzle_b) + sSrc_b = cute.make_tensor(warp_base_ptr + sw_offset_b, cute.make_layout((4,))) + cute.copy_atom_call(copy_128bit_atom, sSrc_b, rDst_b) + + for i in cutlass.range(4, unroll_full=True): + sum_a = sum_a + rDst_a[i] + sum_b = sum_b + rDst_b[i] + + n_offset = n_base + 2 * lane_idx # lanes 0..15 cover N offsets 0,2,...,30 + + if cutlass.const_expr(self.dbias_cross_warp_reduce): + # Cross-warp reduction: each warp writes its partial (sum_a, sum_b) to + # a per-warp slot in SMEM, then warp 0 sums across all warps. 
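# Shape intuition for the cross-warp step below (a sketch, not patch code):
# the partials logically form a (num_epilog_warps, 16, 2) array, one
# (sum_a, sum_b) pair per active lane per warp, and the CTA result is its
# sum over the warp axis -- in NumPy terms:
#     cta_sums = partials.sum(axis=0)   # -> (16, 2)
# which warp 0 then atomically adds into dbias[expert, n_base + 2*lane] and
# the adjacent column via the packed bf16x2 reduction.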
+ reduce_base = sDbias.iterator + copy_64bit_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=64) + + self.epilog_sync_barrier.arrive_and_wait() + if lane_idx < 16: + rSrc_partial = cute.make_rmem_tensor(cute.make_layout((2,)), cutlass.Float32) + rSrc_partial[0] = sum_a + rSrc_partial[1] = sum_b + sDst_partial = cute.make_tensor(reduce_base + warp_local * 32 + lane_idx * 2, cute.make_layout((2,))) + cute.copy_atom_call(copy_64bit_atom, rSrc_partial, sDst_partial) + self.epilog_sync_barrier.arrive_and_wait() + + if warp_idx == self.epilog_warp_id[0] and lane_idx < 16: + cta_sum_a = cutlass.Float32(0.0) + cta_sum_b = cutlass.Float32(0.0) + rDst_w = cute.make_rmem_tensor(cute.make_layout((2,)), cutlass.Float32) + for w in cutlass.range(self.num_epilog_warps): + sSrc_w = cute.make_tensor(reduce_base + w * 32 + lane_idx * 2, cute.make_layout((2,))) + cute.copy_atom_call(copy_64bit_atom, sSrc_w, rDst_w) + cta_sum_a = cta_sum_a + rDst_w[0] + cta_sum_b = cta_sum_b + rDst_w[1] + if n_offset < dbias_n_total: + gmem_ptr = dbias_gmem_2d[(expert_idx, n_offset, None)].iterator.llvm_ptr + atomic_add_bf16x2(gmem_ptr, cta_sum_a, cta_sum_b) + else: + if lane_idx < 16 and n_offset < dbias_n_total: + gmem_ptr = dbias_gmem_2d[(expert_idx, n_offset, None)].iterator.llvm_ptr + atomic_add_bf16x2(gmem_ptr, sum_a, sum_b) + + @cute.jit + def quant_sfd_row(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD): + tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size)) + acc_frg = tTR_rAcc_frg.load() + abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value()) + abs_acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype) + pvscale_f32x4 = cute.make_rmem_tensor(4, cutlass.Float32) + sfd_f8x4 = cute.make_rmem_tensor(4, self.sf_dtype) + tmp_f32 = abs_acc_frg[None, 0].reduce(cute.ReductionOp.MAX, cutlass.Float32(0.0), 0) * rcp_limit * norm_const + if tile_idx == 0: + pvscale[0] = tmp_f32 + elif tile_idx == 1: + pvscale[1] = tmp_f32 + elif tile_idx == 2: + pvscale[2] = tmp_f32 + elif tile_idx == 3: + pvscale[3] = tmp_f32 + pvscale_f32x4[0] = tmp_f32 + sfd_f8x4.store(pvscale_f32x4.load().to(self.sf_dtype)) + pvscale_f32x4.store(sfd_f8x4.load().to(cutlass.Float32)) + qpvscale_up = pvscale_f32x4[0] + fp32_max = cutlass.Float32(3.40282346638528859812e38) + acc_scale = norm_const * cute.arch.rcp_approx(qpvscale_up) + acc_scale = fmin(acc_scale, fp32_max, nan=True) + if cutlass.const_expr(self.vectorized_f32): + vec = tTR_rAcc_frg[None, 0] + for ei in cutlass.range_constexpr(0, self.sf_vec_size, 2): + vec[ei], vec[ei + 1] = cute.arch.mul_packed_f32x2( + (vec[ei], vec[ei + 1]), + (acc_scale, acc_scale), + rnd="rn", + ftz=False, + ) + else: + vec = tTR_rAcc_frg[None, 0] + for ei in cutlass.range_constexpr(self.sf_vec_size): + vec[ei] = vec[ei] * acc_scale + acc_vec = tiled_copy_r2s.retile(src).load() + tRSrD.store(acc_vec.to(self.d_dtype)) + + @cute.jit + def quant_sfd_col(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD): + tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size)) + acc_frg = tTR_rAcc_frg.load() + abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value()) + acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype) + tmp_f32 = cutlass.Float32(0.0) + for vi in cutlass.range_constexpr(acc_frg.shape[0]): + max_value_original = ( + cutlass.Float32( + warp_redux_sync( + value=acc_frg[vi, 0], + kind=ReduxKind.MAX, + mask_and_clamp=0xFFFFFFFF, + 
nan=True, + ) + ) + * rcp_limit + * norm_const + ) + max_value_vec = cute.full(4, max_value_original, dtype=cutlass.Float32) + max_value_vec_f8 = max_value_vec.to(cutlass.Float8E8M0FNU) + max_value_vec_f32_chunked = max_value_vec_f8.to(cutlass.Float32) + max_value = max_value_vec_f32_chunked[0] + tidx = cute.arch.thread_idx()[0] + if tidx % 32 == vi: + tmp_f32 = max_value + acc_scale_col = cutlass.Float32(0.0) + if max_value_vec_f32_chunked[0] == 0.000000: + acc_scale_col = cutlass.Float32(0.0) + else: + acc_scale_col = norm_const * cute.arch.rcp_approx(max_value_vec_f32_chunked[0]) + fp32_max = cutlass.Float32(3.40282346638528859812e38) + acc_scale_col = fmin(acc_scale_col, fp32_max) + tTR_rAcc_frg[vi] = tTR_rAcc_frg[vi] * acc_scale_col + pvscale[None, None, tile_idx][0] = tmp_f32 + acc_vec = tiled_copy_r2s.retile(src).load() + tRSrD.store(acc_vec.to(self.d_dtype)) + + @cute.jit + def tile_info_to_mn_idx(self, tile_info: cute.Tensor): + m_idx = tile_info[1] * cute.size(self.cta_tile_shape_mnk[0]) + n_idx = tile_info[2] * cute.size(self.cta_tile_shape_mnk[1]) + return m_idx, n_idx + + @cute.jit + def create_and_partition_new_SFDCol(self, tile_info, mSFDCol_mnl, padded_offsets): + m_idx, n_idx = self.tile_info_to_mn_idx(tile_info) + expert_idx = tile_info[0] + cumsum_tokens, tokens_this_group = compute_expert_token_range(padded_offsets, expert_idx) + n_total = cute.size(mSFDCol_mnl.shape[1]) + + sf_tile_idx_begin = cumsum_tokens // cute.size(mSFDCol_mnl.shape[0][0]) + mSFDCol_mnl_new_ptr = mSFDCol_mnl[(None, sf_tile_idx_begin), None, 0].iterator + + sfd_col_quant_layout = cute.tile_to_shape( + blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, OperandMajorMode.MN).layout, + (tokens_this_group, n_total, mSFDCol_mnl.shape[2]), + (1, 2, 3), + ) + regPerSubtile = 4 + sfd_tile = (cute.make_layout(128), cute.make_layout(32 * regPerSubtile)) + mSFDCol_mnl_new = cute.make_tensor(mSFDCol_mnl_new_ptr, sfd_col_quant_layout) + gSFDCol_mnl_new = cute.local_tile(mSFDCol_mnl_new, sfd_tile, (None, None, None)) + + thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0)) + val_layout = cute.make_ordered_layout((1,), order=(0,)) + copy_atom_sfd_col_quant = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gSFDCol_mnl_new.element_type, + num_bits_per_copy=8, + ) + tiled_copy_sfd_col_quant = cute.make_tiled_copy_tv( + copy_atom_sfd_col_quant, + thr_layout, + val_layout, + ) + tidx = cute.arch.thread_idx()[0] + thr_copy_sfd_col_quant = tiled_copy_sfd_col_quant.get_slice(tidx) + tCgSFDCol_mnl = thr_copy_sfd_col_quant.partition_D(cute.filter_zeros(gSFDCol_mnl_new)) + tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl) + return tCgSFDCol_mnl + + def epilog_tmem_copy_and_partition(self, tidx, tAcc, gD_mnl, epi_tile, use_2cta_instrs): + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.d_layout, + self.d_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + tTR_gC = thr_copy_t2r.partition_D(gD_mnl_epi) + tTR_rAcc = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition_load(self, tiled_copy_t2r, tTR_rC, tidx, 
sC): + """Partition sC (smem) for S2R copy — used by epilog warps consuming the c_pipeline.""" + copy_atom_s2r = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r) + thr_copy_s2r = tiled_copy_s2r.get_slice(tidx) + tRS_sC = thr_copy_s2r.partition_D(sC) + tRS_rC = tiled_copy_s2r.retile(tTR_rC) + return tiled_copy_s2r, tRS_rC, tRS_sC + + def epilog_smem_copy_and_partition(self, tiled_copy_t2r, tTR_rD, tidx, sD): + copy_atom_r2s = sm100_utils.get_smem_store_op(self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sD) + tRS_rD = tiled_copy_r2s.retile(tTR_rD) + return tiled_copy_r2s, tRS_rD, tRS_sD + + # ------------------------------------------------------------------ + # GPU device kernel + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, # G2S atom for loading upstream gradient C + mC_mnl: cute.Tensor, # upstream gradient tensor (input) + tma_atom_d: cute.CopyAtom, + mD_mnl: cute.Tensor, # dA output tensor + tma_atom_d_col: cute.CopyAtom, + mD_col_mnl: cute.Tensor, + mSFDRow_mnl: Optional[cute.Tensor], + mSFDCol_mnl: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + mAmax_tensor: Optional[cute.Tensor], + padded_offsets: cute.Tensor, + alpha: cute.Tensor, + prob: cute.Tensor, + dprob: Optional[cute.Tensor], # dL/d(prob) output, shape (M,1,1), Float32 + mDbias_tensor: Optional[cute.Tensor], # dL/d(bias) output, shape (L,N), BF16 + workspace_ptr, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + sched_params: MoESchedulerParams, + ): + """GPU device kernel for persistent MoE grouped GEMM backward with DSRELU.""" + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_sfa) + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_d) + if cutlass.const_expr(self.generate_sfd): + cpasync.prefetch_descriptor(tma_atom_d_col) + + if warp_idx == self.epilog_load_tma_id: + if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + total_token = padded_offsets[self.expert_cnt - 1] + + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = 
cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster) + tidx, _, _ = cute.arch.thread_idx() + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sched_storage = storage.scheduler + + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # C pipeline: epilog_load_tma warp produces, epilog warps consume + c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + c_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.c_full_mbar_ptr.data_ptr(), + num_stages=self.num_c_stage, + producer_group=c_producer_group, + consumer_group=c_consumer_group, + tx_count=self.tma_c_load_bytes, + ) + + tile_info_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_per_warp * 1) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_wo_sched) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=sched_storage.tile_info_mbar.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + ) + + scheduler = MoEPersistentTileScheduler.create( + sched_params, + padded_offsets, + cute.arch.block_idx(), + cute.arch.grid_dim(), + counter_ptr=self._get_sched_counter_ptr(workspace_ptr), + sched_storage=sched_storage, + ) + scheduler.internal_init() + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) + sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner) + sD_col = sD + if cutlass.const_expr(self.generate_sfd): + sD_col = storage.sD_col.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, 
swizzle=b_smem_layout_staged.inner) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + amax_layout = cute.make_layout((self.num_epilog_warps,)) + sAmax = storage.sAmax.get_tensor(amax_layout) + sDbias = None + if cutlass.const_expr(self.generate_dbias): + sDbias = storage.sDbias.get_tensor( + cute.make_layout( + (self.epi_tile[1], 32, self.num_epilog_warps), + stride=(32, 1, self.epi_tile[1] * 32), + ) + ) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = sched_storage.sInfo.get_tensor(info_layout) + + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1) + + thr_mma_common = tiled_mma.get_slice(0) + tCsA_common = thr_mma_common.partition_A(sA) + tCsB_common = thr_mma_common.partition_B(sB) + tCsA_common = cute.filter_zeros(tCsA_common) + tCsB_common = cute.filter_zeros(tCsB_common) + + tCrA = tiled_mma.make_fragment_A(sA) + tCrB = tiled_mma.make_fragment_B(sB) + + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + if cutlass.const_expr(self.overlapping_accum): + num_acc_stage_overlapped = 2 + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage_overlapped)) + tCtAcc_fake = cute.make_tensor( + tCtAcc_fake.iterator, + cute.make_layout( + tCtAcc_fake.shape, + stride=( + tCtAcc_fake.stride[0], + tCtAcc_fake.stride[1], + tCtAcc_fake.stride[2], + (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1], + ), + ), + ) + else: + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + if total_token <= 0: + cute.arch.nvvm.exit() + + # ============================================================== + # Scheduler warp (MoE Persistent Tile Scheduler) + # ============================================================== + if warp_idx == self.sched_warp_id: + work_tile_info = scheduler.initial_work_tile_info() + tile_info_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_tile_stage) + while work_tile_info.is_valid_tile: + tile_info_pipeline.producer_acquire(tile_info_producer_state) + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = work_tile_info.expert_idx + sInfo[(1, tile_info_producer_state.index)] = work_tile_info.tile_m_idx + sInfo[(2, tile_info_producer_state.index)] = work_tile_info.tile_n_idx + sInfo[(3, tile_info_producer_state.index)] = work_tile_info.k_tile_cnt + cute.arch.fence_proxy("async.shared", space="cta") + self.sched_sync_barrier.arrive_and_wait() + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + work_tile_info = scheduler.advance_to_next_work() + + tile_info_pipeline.producer_acquire(tile_info_producer_state) + with cute.arch.elect_one(): + sInfo[(0, 
tile_info_producer_state.index)] = cutlass.Int32(-1) + sInfo[(1, tile_info_producer_state.index)] = cutlass.Int32(0) + sInfo[(2, tile_info_producer_state.index)] = cutlass.Int32(0) + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + self.sched_sync_barrier.arrive_and_wait() + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # ============================================================== + # DMA / TMA load warp (A, B, SFA, SFB) + # ============================================================== + if warp_idx == self.tma_warp_id: + ext = self._make_extension(workspace_ptr) + ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + work_tile_info = MoEWorkTileInfo( + expert_idx=tile_info[0], + tile_m_idx=tile_info[1], + tile_n_idx=tile_info[2], + k_tile_cnt=tile_info[3], + ) + k_tile_cnt = work_tile_info.k_tile_cnt + ext.update_expert_info(padded_offsets, work_tile_info.expert_idx) + + real_a, _ = ext.get_gmem_tensor("a", mA_mkl, padded_offsets, work_tile_info) + real_b, desc_ptr_b = ext.get_gmem_tensor("b", mB_nkl, padded_offsets, work_tile_info) + real_sfa, _ = ext.get_gmem_tensor("sfa", mSFA_mkl, padded_offsets, work_tile_info) + real_sfb, desc_ptr_sfb = ext.get_gmem_tensor("sfb", mSFB_nkl, padded_offsets, work_tile_info) + + gA_mkl = cute.local_tile(real_a, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + gB_nkl = cute.local_tile(real_b, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) + gSFA_mkl = cute.local_tile(real_sfa, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + gSFB_nkl = cute.local_tile(real_sfb, cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None)) + + thr_mma_dma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb_dma = tiled_mma_sfb.get_slice(mma_tile_coord_v) + tCgA = thr_mma_dma.partition_A(gA_mkl) + tCgB = thr_mma_dma.partition_B(gB_nkl) + tCgSFA = thr_mma_dma.partition_A(gSFA_mkl) + tCgSFB = thr_mma_sfb_dma.partition_B(gSFB_nkl) + + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + sfa_cta_layout = a_cta_layout + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = 
cute.filter_zeros(tAgSFA) + sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + mma_tile_coord_m = work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape) + mma_tile_coord_n = work_tile_info.tile_n_idx + tAgA_slice = tAgA[(None, mma_tile_coord_m, None, 0)] + tBgB_slice = tBgB[(None, mma_tile_coord_n, None, 0)] + tAgSFA_slice = tAgSFA[(None, mma_tile_coord_m, None, 0)] + slice_n = mma_tile_coord_n + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_n // 2 + tBgSFB_slice = tBgSFB[(None, slice_n, None, 0)] + + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)] + tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + tAsSFA_pipe = tAsSFA[(None, ab_producer_state.index)] + tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + cute.copy(tma_atom_a, tAgA_k, tAsA_pipe, tma_bar_ptr=tma_bar, mcast_mask=a_full_mcast_mask) + cute.copy(tma_atom_b, tBgB_k, tBsB_pipe, tma_bar_ptr=tma_bar, mcast_mask=b_full_mcast_mask, tma_desc_ptr=desc_ptr_b) + cute.copy(tma_atom_sfa, tAgSFA_k, tAsSFA_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfa_full_mcast_mask) + cute.copy(tma_atom_sfb, tBgSFB_k, tBsSFB_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfb_full_mcast_mask, tma_desc_ptr=desc_ptr_sfb) + + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + ab_pipeline.producer_tail(ab_producer_state) + + # ============================================================== + # MMA warp + # ============================================================== + if warp_idx == self.mma_warp_id: + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols, + dtype=self.sf_dtype, + ) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols, + 
dtype=self.sf_dtype, + ) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) + acd_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + k_tile_cnt = tile_info[3] + + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + acd_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state) + + mma_tile_coord_mnl = ( + tile_info[1] // cute.size(tiled_mma.thr_id.shape), + tile_info[2], + tile_info[0], + ) + + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acd_producer_state.phase ^ 1 + else: + acc_stage_index = acd_producer_state.index + + tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)] + + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + + if is_leader_cta: + acc_pipeline.producer_acquire(acd_producer_state, peek_acc_empty_status) + + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + if is_leader_cta: + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + s2t_stage_coord = (None, None, None, None, ab_consumer_state.index) + cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t[s2t_stage_coord], tCtSFA_compact_s2t) + cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t[s2t_stage_coord], tCtSFB_compact_s2t) + + num_kblocks = cute.size(tCrA, mode=[2]) + ab_consumer_state_next = ab_consumer_state.clone() + ab_consumer_state_next.advance() + if ab_consumer_state_next.count < k_tile_cnt: + peek_ab_full_status = 
ab_pipeline.consumer_try_wait(ab_consumer_state_next) + + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator) + tiled_mma.set(tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator) + cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + ab_pipeline.consumer_release(ab_consumer_state) + ab_consumer_state = ab_consumer_state_next + + if is_leader_cta: + acc_pipeline.producer_commit(acd_producer_state) + + acd_producer_state.advance() + if acd_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state) + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + acc_pipeline.producer_tail(acd_producer_state) + + # ============================================================== + # Epilog load TMA warp: streams upstream gradient C into sC + # Must always consume from tile_info_pipeline (it is part of the + # 224-thread consumer group), but only loads C when DSRELU. + # ============================================================== + if warp_idx == self.epilog_load_tma_id: + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value): + c_ext = self._make_extension(workspace_ptr) + c_pipeline_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_c_stage) + # Toggled per-tile to mirror the epilog consumer's + # reverse_subtile = (acc_consumer_state.phase == 0). Starts True so + # tile 0 (consumer phase=0) gets C in reverse order. 
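An aside on the handshake just described: the epilog-load warp never sees the consumer's acc pipeline state, so it reconstructs the consumer's subtile order with a local Boolean flipped once per tile. A minimal standalone model of that invariant (illustrative only; it assumes, as the comment above states, that the consumer's accumulator phase flips exactly once per tile and starts at 0):

def consumer_reverse_flags(num_tiles):
    # Consumer side: reverse_subtile = (acc_consumer_state.phase == 0).
    phase, flags = 0, []
    for _ in range(num_tiles):
        flags.append(phase == 0)
        phase ^= 1  # one phase flip per tile
    return flags

def producer_reverse_flags(num_tiles):
    # Loader side: a local toggle, starting True so tile 0 matches phase 0.
    is_reverse, flags = True, []
    for _ in range(num_tiles):
        flags.append(is_reverse)
        is_reverse = not is_reverse
    return flags

assert consumer_reverse_flags(8) == producer_reverse_flags(8)

If the two sequences ever diverged, the loader would stage C subtiles in the opposite order from the accumulator reads, and the DSRELU epilogue would multiply mismatched subtiles.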
+ c_is_reverse = cutlass.Boolean(True) + + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value): + c_work_tile_info = MoEWorkTileInfo( + expert_idx=tile_info[0], + tile_m_idx=tile_info[1], + tile_n_idx=tile_info[2], + k_tile_cnt=tile_info[3], + ) + c_ext.update_expert_info(padded_offsets, c_work_tile_info.expert_idx) + + real_c, _ = c_ext.get_gmem_tensor("c", mC_mnl, padded_offsets, c_work_tile_info) + + thr_mma_c = tiled_mma.get_slice(mma_tile_coord_v) + gC_mnl_loop = cute.local_tile(real_c, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)) + tCgC_loop = thr_mma_c.partition_C(gC_mnl_loop) + + _, bGS_sC, bGS_gC_partitioned = epilog_gmem_copy_and_partition( + tidx, + tma_atom_c, + tCgC_loop, + epi_tile, + sC, + ) + + epi_mma_tile_coord = ( + c_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape), + c_work_tile_info.tile_n_idx, + 0, + ) + bGS_gC = bGS_gC_partitioned[(None, None, None, *epi_mma_tile_coord)] + bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC)) + subtile_cnt = cute.size(bGS_gC.shape, mode=[1]) + + # Toggle per tile to mirror consumer's phase-based reverse_subtile. + c_reverse_this_tile = cutlass.Boolean(False) + if cutlass.const_expr(self.overlapping_accum): + c_reverse_this_tile = c_is_reverse + c_is_reverse = not c_is_reverse + for subtile_idx in cutlass.range(subtile_cnt, unroll=1): + real_c_subtile_idx = subtile_idx + if cutlass.const_expr(self.overlapping_accum): + if c_reverse_this_tile: + real_c_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx + c_pipeline.producer_acquire(c_pipeline_producer_state) + cute.copy( + tma_atom_c, + bGS_gC[(None, real_c_subtile_idx)], + bGS_sC[(None, c_pipeline_producer_state.index)], + tma_bar_ptr=c_pipeline.producer_get_barrier(c_pipeline_producer_state), + ) + c_pipeline_producer_state.advance() + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value): + c_pipeline.producer_tail(c_pipeline_producer_state) + + # ============================================================== + # Epilogue warps: compute dA (and dprob for DSRELU) + # ============================================================== + if warp_idx < self.mma_warp_id: + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + epi_tidx = tidx + thr_mma_epi = tiled_mma.get_slice(mma_tile_coord_v) + + gD_mnl_shape = cute.local_tile(mD_mnl, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + tCgD_shape = thr_mma_epi.partition_C(gD_mnl_shape) + + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = 
self.epilog_tmem_copy_and_partition( + epi_tidx, + tCtAcc_base, + tCgD_shape, + epi_tile, + use_2cta_instrs, + ) + + # Register buffer for reading C from sC (c_pipeline consumer) + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_s2r, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition_load( + tiled_copy_t2r, + tTR_rC, + epi_tidx, + sC, + ) + + tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype) + tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, + tTR_rD, + epi_tidx, + sD, + ) + + if cutlass.const_expr(self.generate_sfd): + tTR_rD_col = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype) + tiled_copy_r2s, tRS_rD_col, tRS_sD_col = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, + tTR_rD_col, + epi_tidx, + sD_col, + ) + norm_const = norm_const_tensor[0] + regPerSubtile = 4 + sfd_row_tile = (cute.make_layout(128), cute.make_layout(32 * regPerSubtile)) + gSFDRow_mnl = cute.local_tile(mSFDRow_mnl, sfd_row_tile, (None, None, None)) + thr_copy_t2r_local = tiled_copy_t2r.get_slice(tidx) + tCgSFDRow_mnl = thr_copy_t2r_local.partition_D(gSFDRow_mnl) + tCgSFDRow_mnl = cute.filter_zeros(tCgSFDRow_mnl) + tCrSFDRow = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype) + tCrSFDRow_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32) + d_rcp_limits = get_dtype_rcp_limits(self.d_dtype) + + sfd_col_tile = sfd_row_tile + gSFDCol_mnl = cute.local_tile(mSFDCol_mnl, sfd_col_tile, (None, None, None)) + thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0)) + val_layout = cute.make_ordered_layout((1,), order=(0,)) + copy_atom_sfd_col = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gSFDCol_mnl.element_type, num_bits_per_copy=8) + tiled_copy_sfd_col = cute.make_tiled_copy_tv(copy_atom_sfd_col, thr_layout, val_layout) + thr_copy_sfd_col = tiled_copy_sfd_col.get_slice(tidx) + tCgSFDCol_mnl = thr_copy_sfd_col.partition_D(cute.filter_zeros(gSFDCol_mnl)) + tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl) + tCrSFDCol = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].shape, self.sf_dtype) + tCrSFDCol_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32) + + epi_ext = self._make_extension(workspace_ptr) + + acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + c_pipeline_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_c_stage) + d_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id)) + d_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_d_stage, producer_group=d_producer_group) + + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + num_prev_subtiles = cutlass.Int32(0) + while is_valid_tile: + epi_work_tile_info = MoEWorkTileInfo( + expert_idx=tile_info[0], + tile_m_idx=tile_info[1], + tile_n_idx=tile_info[2], + k_tile_cnt=tile_info[3], + ) + expert_idx = epi_work_tile_info.expert_idx + 
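For readers tracing the scheduler protocol here and in the other warp sections: every consumer group reads the same four-Int32 message out of sInfo, and a negative expert index is the shutdown sentinel that terminates each while loop. A small restatement of the decode convention (the dataclass is a hypothetical stand-in; the field order matches MoEWorkTileInfo):

from dataclasses import dataclass

@dataclass
class TileMsg:
    expert_idx: int   # tile_info[0]; < 0 means "no more work"
    tile_m_idx: int   # tile_info[1]
    tile_n_idx: int   # tile_info[2]
    k_tile_cnt: int   # tile_info[3]

def decode(tile_info):
    msg = TileMsg(*tile_info)
    return msg, msg.expert_idx >= 0  # (message, is_valid_tile)

assert decode([-1, 0, 0, 0])[1] is False  # sentinel ends the loop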
epi_ext.update_expert_info(padded_offsets, expert_idx) + + alpha_val = alpha[expert_idx] + + real_d, _ = epi_ext.get_gmem_tensor("d", mD_mnl, padded_offsets, epi_work_tile_info) + + thr_mma_epi_loop = tiled_mma.get_slice(mma_tile_coord_v) + + gD_mnl_loop = cute.local_tile(real_d, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + tCgD_loop = thr_mma_epi_loop.partition_C(gD_mnl_loop) + _, bSG_sD, bSG_gD_partitioned = epilog_gmem_copy_and_partition( + epi_tidx, + tma_atom_d, + tCgD_loop, + epi_tile, + sD, + ) + + real_d_col = real_d + if cutlass.const_expr(self.generate_sfd): + real_d_col, _ = epi_ext.get_gmem_tensor("d_col", mD_col_mnl, padded_offsets, epi_work_tile_info) + + gD_col_mnl_loop = gD_mnl_loop + tCgD_col_loop = tCgD_loop + if cutlass.const_expr(self.generate_sfd): + gD_col_mnl_loop = cute.local_tile(real_d_col, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + tCgD_col_loop = thr_mma_epi_loop.partition_C(gD_col_mnl_loop) + _, bSG_sD_col, bSG_gD_col_partitioned = epilog_gmem_copy_and_partition( + epi_tidx, + tma_atom_d_col, + tCgD_col_loop, + epi_tile, + sD_col, + ) + + epi_mma_tile_coord = ( + epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape), + epi_work_tile_info.tile_n_idx, + 0, + ) + bSG_gD = bSG_gD_partitioned[(None, None, None, *epi_mma_tile_coord)] + bSG_gD_col = bSG_gD_col_partitioned[(None, None, None, *epi_mma_tile_coord)] + bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD)) + bSG_gD_col = cute.group_modes(bSG_gD_col, 1, cute.rank(bSG_gD_col)) + + if cutlass.const_expr(self.generate_sfd): + tCgSFDRow_mn = tCgSFDRow_mnl[(None, None, None, None, None, 0)] + tCgSFDCol_mnl_new = tCgSFDCol_mnl + if cutlass.const_expr(self.discrete_col_sfd): + tCgSFDCol_mnl_new = self.create_and_partition_new_SFDCol(tile_info, mSFDCol_mnl, padded_offsets) + tCgSFDCol_mn = tCgSFDCol_mnl_new[(None, None, None, None, None, 0)] + + if cutlass.const_expr(self.generate_amax): + thread_tile_amax = cutlass.Float32(0.0) + + mPosition = epi_work_tile_info.tile_m_idx * self.cta_tile_shape_mnk[0] + tidx + real_prob, _ = epi_ext.get_gmem_tensor("prob", prob, padded_offsets, epi_work_tile_info) + mProb = real_prob[mPosition, 0, 0] + + # Accumulator for dprob: summed over N subtiles, written once per tile + dProbVal = cutlass.Float32(0.0) + + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_consumer_state.phase + reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False) + else: + acc_stage_index = acc_consumer_state.index + + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)] + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + acc_pipeline.consumer_wait(acc_consumer_state) + + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + + for subtile_idx in cutlass.range(0, subtile_cnt, 1, unroll=1): + real_subtile_idx = subtile_idx + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + real_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx + + if cutlass.const_expr(self.overlapping_accum): + if subtile_idx == self.iter_acc_early_release_in_epilogue: + cute.arch.fence_view_async_tmem_load() + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # Apply alpha scaling + if cutlass.const_expr(self.vectorized_f32): + for i in 
cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rAcc[i], tTR_rAcc[i + 1]), + (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tTR_rAcc[i] = tTR_rAcc[i] * cutlass.Float32(alpha_val) + + acc_vec = tTR_rAcc.load() + + # Consume one C subtile from c_pipeline + if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value): + c_pipeline.consumer_wait(c_pipeline_consumer_state) + cute.copy( + tiled_copy_s2r, + tRS_sC[(None, None, None, c_pipeline_consumer_state.index)], + tRS_rC, + ) + cute.arch.fence_proxy("async.shared", space="cta") + c_pipeline.consumer_release(c_pipeline_consumer_state) + c_pipeline_consumer_state.advance() + c_vec = tiled_copy_s2r.retile(tRS_rC) + + tCompute = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype) + + if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value): + # DSRELU backward: dA = relu(acc) * C * 2 * prob + acc_relu = cute.where(acc_vec > 0, acc_vec, cute.full_like(acc_vec, 0)) + tRelu = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype) + tRelu.store(acc_relu) + probx2 = 2 * mProb + if cutlass.const_expr(self.vectorized_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (tRelu[i], tRelu[i + 1]), + (c_vec[i].to(self.acc_dtype), c_vec[i + 1].to(self.acc_dtype)), + rnd="rn", + ftz=False, + ) + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (cutlass.Float32(probx2), cutlass.Float32(probx2)), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tCompute[i] = tRelu[i] * c_vec[i].to(self.acc_dtype) * cutlass.Float32(probx2) + + # Accumulate dprob: dprob[m] += sum_n(relu(acc)^2 * C) + if cutlass.const_expr(dprob is not None): + tDprob = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype) + if cutlass.const_expr(self.vectorized_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tDprob[i], tDprob[i + 1] = cute.arch.mul_packed_f32x2( + (tRelu[i], tRelu[i + 1]), + (tRelu[i], tRelu[i + 1]), + rnd="rn", + ftz=False, + ) + tDprob[i], tDprob[i + 1] = cute.arch.mul_packed_f32x2( + (tDprob[i], tDprob[i + 1]), + (c_vec[i].to(self.acc_dtype), c_vec[i + 1].to(self.acc_dtype)), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tDprob[i] = tRelu[i] * tRelu[i] * c_vec[i].to(self.acc_dtype) + dProbVal = dProbVal + tDprob.load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0) + else: + # NONE epilogue: dA = acc * prob (identity, no SReLU gate) + if cutlass.const_expr(self.vectorized_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (acc_vec[i], acc_vec[i + 1]), + (mProb, mProb), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tCompute[i] = acc_vec[i] * mProb + + # dbias reduction: sum dA across M for each N, atomic-add to dbias[expert, n] + if cutlass.const_expr(self.generate_dbias): + dA_vec = tCompute.load() + n_base = epi_work_tile_info.tile_n_idx * self.mma_tiler[1] + real_subtile_idx * self.epi_tile[1] + dbias_n_total = cute.size(mDbias_tensor, mode=[1]) + self.dbias_reduction( + dA_vec, + warp_idx, + sDbias, + mDbias_tensor, + expert_idx, + n_base, + dbias_n_total, + ) + + if 
cutlass.const_expr(self.generate_amax): + thread_tile_amax = amax_reduction_per_thread(tCompute, thread_tile_amax) + + if cutlass.const_expr(self.generate_sfd): + tCompute_col = cute.make_rmem_tensor(tCompute.layout, tCompute.element_type) + tCompute_col.store(tCompute.load()) + self.quant_sfd_row( + real_subtile_idx % 4, + tiled_copy_r2s, + tCompute, + tCrSFDRow_pvscale, + norm_const, + d_rcp_limits, + tRS_rD, + ) + self.quant_sfd_col( + real_subtile_idx % 4, + tiled_copy_r2s, + tCompute_col, + tCrSFDCol_pvscale, + norm_const, + d_rcp_limits, + tRS_rD_col, + ) + global_sfd_m = epi_work_tile_info.tile_m_idx + epi_ext.token_offset // self.cta_tile_shape_mnk[0] + if cutlass.const_expr(self.mma_tiler[1] == 256): + sfd_n = epi_work_tile_info.tile_n_idx * 2 + (real_subtile_idx >> 2) + else: + sfd_n = epi_work_tile_info.tile_n_idx + sfd_row_idx_mn = (global_sfd_m, sfd_n) + sfd_col_idx_mn = sfd_row_idx_mn + if cutlass.const_expr(self.discrete_col_sfd): + sfd_col_idx_mn = ( + epi_work_tile_info.tile_m_idx, + sfd_n, + ) + tCgSFDRow = tCgSFDRow_mn[(None, None, None, *sfd_row_idx_mn)] + tCgSFDCol = tCgSFDCol_mn[(None, None, None, *sfd_col_idx_mn)] + if subtile_idx == 3 or subtile_idx == 7: + if sfd_row_idx_mn[1] * 32 * regPerSubtile < cute.size(cute.shape(mSFDRow_mnl.layout, mode=[1])): + tCrSFDRow.store(tCrSFDRow_pvscale.load().to(self.sf_dtype)) + cute.autovec_copy(tCrSFDRow, tCgSFDRow) + if sfd_col_idx_mn[1] * 32 * regPerSubtile < cute.size(cute.shape(mSFDCol_mnl.layout, mode=[1])): + tCrSFDCol.store(tCrSFDCol_pvscale.load().to(self.sf_dtype)) + cute.autovec_copy(tCrSFDCol, tCgSFDCol) + else: + acc_vec = tiled_copy_r2s.retile(tCompute).load() + tRS_rD.store(acc_vec.to(self.d_dtype)) + + d_buffer = num_prev_subtiles % self.num_d_stage + num_prev_subtiles = num_prev_subtiles + 1 + cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)]) + if cutlass.const_expr(self.generate_sfd): + cute.copy(tiled_copy_r2s, tRS_rD_col, tRS_sD_col[(None, None, None, d_buffer)]) + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier.arrive_and_wait() + if warp_idx == self.epilog_warp_id[0]: + cute.copy(tma_atom_d, bSG_sD[(None, d_buffer)], bSG_gD[(None, real_subtile_idx)]) + if cutlass.const_expr(self.generate_sfd): + cute.copy(tma_atom_d_col, bSG_sD_col[(None, d_buffer)], bSG_gD_col[(None, real_subtile_idx)]) + d_pipeline.producer_commit() + d_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + if cutlass.const_expr(not self.overlapping_accum): + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # Write accumulated dprob gradient atomically + if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value): + if cutlass.const_expr(dprob is not None): + real_dprob, _ = epi_ext.get_gmem_tensor("dprob", dprob, padded_offsets, epi_work_tile_info) + _ = atomic_add_float32( + ptr=real_dprob[(mPosition, None, None)].iterator.llvm_ptr, + value=dProbVal, + ) + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + if cutlass.const_expr(self.generate_amax): + gAmax = mAmax_tensor[(expert_idx, None)].iterator.llvm_ptr + 
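The call that follows completes a three-level amax reduction: thread-local partials accumulated in the subtile loop above, a warp-then-CTA reduction staged through sAmax, and one atomic update into the per-expert gAmax slot just computed. A CPU-side reference of the same reduction tree (a sketch only; the shared-memory staging and the atomic are abstracted away):

def amax_reference(vals_per_thread, lanes_per_warp=32):
    # Stage 1: per-thread running amax (what thread_tile_amax accumulates).
    per_thread = [max((abs(v) for v in vals), default=0.0) for vals in vals_per_thread]
    # Stage 2: warp-level maxima, then a CTA-level max across warps (via sAmax).
    per_warp = [max(per_thread[i:i + lanes_per_warp])
                for i in range(0, len(per_thread), lanes_per_warp)]
    # Stage 3: the CTA result is atomically max-combined into gAmax[expert_idx].
    return max(per_warp)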
self.amax_reduction_per_warp_and_cta(thread_tile_amax, warp_idx, sAmax, gAmax) + + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + d_pipeline.producer_tail() + + # ------------------------------------------------------------------ + # Internal: create extension based on weight_mode + # ------------------------------------------------------------------ + + @cute.jit + def _make_extension(self, workspace_ptr): + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE): + desc_workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"]) + return DiscreteWeightScaledGemmSchedExtension( + tensormap_ctor=desc_workspace, + sf_vec_size=self.sf_vec_size, + ) + else: + return ContiguousAndConsistentGroupedGemmSchedExtension( + sf_vec_size=self.sf_vec_size, + ) diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py index 5da7b868..f6b51822 100644 --- a/python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py +++ b/python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py @@ -36,9 +36,7 @@ from .grouped_gemm_dswiglu_quant import ( BlockScaledContiguousGroupedGemmKernel, ) -from ..utils import logical_shape_fp4x2_aware from cuda.bindings import driver as cuda -import os import torch from typing import Tuple, Optional @@ -130,7 +128,7 @@ def __init__( """ super().__init__() - self._logger.warning("GroupedGemmDswigluSm100 is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") # Store sample tensor descriptors @@ -182,7 +180,7 @@ def __init__( self._interpret_uint8_as_fp4x2 = True self._kernel = BlockScaledContiguousGroupedGemmKernel self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) - print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") self._logger.debug(f"__init__ completed") def check_support(self) -> bool: @@ -374,9 +372,6 @@ def compile(self) -> None: if self._compiled_kernel is not None: self._logger.debug("Kernel already compiled; skipping recompilation") return - if self.a_desc.shape[0] == 0: - self._logger.debug("sample valid_m is zero, skipping kernel compilation") - return gemm_dswiglu = self._kernel( sf_vec_size=self.sf_vec_size, @@ -400,8 +395,8 @@ def compile(self) -> None: fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False) self._logger.debug("Compiling grouped_gemm_dswiglu kernel") - use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None - if not use_full_dynamic: # only mark the m dimension as dynamic + use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" + if not use_full_dynamic: valid_m = cute.sym_int(divisibility=256) a_cute_fake = self._make_fake_cute_compact_tensor( @@ -425,27 +420,26 @@ def compile(self) -> None: shape=(valid_m, *self.d_col_desc.shape[1:]), stride_order=self.d_col_desc.stride_order, ) - - tensor_m_128 = cute.sym_int() - stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4) - stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) - sfa_cute_fake = self._make_fake_cute_tensor( - dtype=self.sfa_desc.dtype, - shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], 1), - stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128), - ) - - sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16) prob_cute_fake = 
self._make_fake_cute_compact_tensor( dtype=self.prob_desc.dtype, shape=(valid_m, 1, 1), stride_order=self.prob_desc.stride_order, ) - d_prob_fake = self._make_fake_cute_compact_tensor( + dprob_cute_fake = self._make_fake_cute_compact_tensor( dtype=self.dprob_desc.dtype, shape=(valid_m, 1, 1), stride_order=self.dprob_desc.stride_order, ) + + tensor_m_128 = cute.sym_int() + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], 1), + stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128), + ) + sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16) + sfd_row_fake = None sfd_col_fake = None if self.sfd_row_desc is not None: @@ -466,7 +460,6 @@ def compile(self) -> None: ) else: valid_m = cute.sym_int(divisibility=256) - n = cute.sym_int() n_2 = cute.sym_int() k = cute.sym_int() l = cute.sym_int() @@ -480,12 +473,11 @@ def compile(self) -> None: ) b_cute_fake = self._make_fake_cute_compact_tensor( dtype=self.b_desc.dtype, - shape=(n, k, l), + shape=(cute.sym_int(), k, l), stride_order=self.b_desc.stride_order, dynamic_mode=self.b_desc.stride_order[0], divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, ) - c_cute_fake = self._make_fake_cute_compact_tensor( dtype=self.c_desc.dtype, shape=(valid_m, n_2, 1), @@ -493,7 +485,6 @@ def compile(self) -> None: dynamic_mode=self.c_desc.stride_order[0], divisibility=8 if self._is_f16(self.c_desc.dtype) else 16, ) - d_row_cute_fake = self._make_fake_cute_compact_tensor( dtype=self.d_row_desc.dtype, shape=(valid_m, n_2, 1), @@ -501,7 +492,6 @@ def compile(self) -> None: dynamic_mode=self.d_row_desc.stride_order[0], divisibility=8 if self._is_f16(self.d_row_desc.dtype) else 16, ) - d_col_cute_fake = self._make_fake_cute_compact_tensor( dtype=self.d_col_desc.dtype, shape=(valid_m, n_2, 1), @@ -509,8 +499,17 @@ def compile(self) -> None: dynamic_mode=self.d_col_desc.stride_order[0], divisibility=8 if self._is_f16(self.d_col_desc.dtype) else 16, ) + prob_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, 1, 1), + stride_order=self.prob_desc.stride_order, + ) + dprob_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.dprob_desc.dtype, + shape=(valid_m, 1, 1), + stride_order=self.dprob_desc.stride_order, + ) - # 32, 4, tensor_m // 128, 4, rest_k, 1) tensor_m_128 = cute.sym_int() rest_k = cute.sym_int() stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4) @@ -520,7 +519,6 @@ def compile(self) -> None: shape=(32, 4, tensor_m_128, 4, rest_k, 1), stride=(16, 4, stride_rest_k, 1, 512, stride_tensor_m_128), ) - tensor_n_128 = cute.sym_int() stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4) stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4) @@ -530,18 +528,6 @@ def compile(self) -> None: stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k), ) - prob_cute_fake = self._make_fake_cute_compact_tensor( - dtype=self.prob_desc.dtype, - shape=(valid_m, 1, 1), - stride_order=self.prob_desc.stride_order, - ) - - d_prob_fake = self._make_fake_cute_compact_tensor( - dtype=self.dprob_desc.dtype, - shape=(valid_m, 1, 1), - stride_order=self.dprob_desc.stride_order, - ) - sfd_row_fake = None sfd_col_fake = None if self.sfd_row_desc is not None: @@ -581,7 +567,7 @@ def compile(self) -> None: alpha=self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16), 
beta=self._make_fake_cute_tensor_from_desc(self.beta_desc, assumed_align=16), prob=prob_cute_fake, - dprob=d_prob_fake, + dprob=dprob_cute_fake, max_active_clusters=max_active_clusters, epilogue_op=self.epilogue_op, stream=fake_stream, @@ -675,9 +661,6 @@ def execute( self._logger.debug("Entering execute") current_stream = self._get_default_stream(current_stream) - if a_tensor.shape[0] == 0: - self._logger.debug("execute: valid_m is zero, skipping kernel execution") - return if self._compiled_kernel is None: raise RuntimeError("Kernel not compiled; call compile() first") self._logger.debug("Executing grouped_gemm_dswiglu kernel") @@ -705,6 +688,7 @@ def execute( import logging +import os _logger = logging.getLogger(__name__) _cache_of_GroupedGemmDswigluSm100Objects = {} @@ -718,7 +702,7 @@ def grouped_gemm_dswiglu_wrapper_sm100( sfb_tensor: torch.Tensor, padded_offsets: torch.Tensor, alpha_tensor: torch.Tensor, - beta_tensor: torch.Tensor, + beta_tensor: Optional[torch.Tensor], prob_tensor: torch.Tensor, norm_const_tensor: Optional[torch.Tensor] = None, acc_dtype: torch.dtype = torch.float32, @@ -732,6 +716,8 @@ def grouped_gemm_dswiglu_wrapper_sm100( discrete_col_sfd: bool = False, epilogue_op: Optional[str] = None, current_stream: Optional[cuda.CUstream] = None, + dprob_tensor_buf: Optional[torch.Tensor] = None, + amax_tensor_buf: Optional[torch.Tensor] = None, ) -> TupleDict: """Convenience wrapper for grouped GEMM dSwiGLU backward operation. @@ -771,67 +757,13 @@ def grouped_gemm_dswiglu_wrapper_sm100( - **sfd_row_tensor** (torch.Tensor or None): Row-wise scale factors for D - **sfd_col_tensor** (torch.Tensor or None): Column-wise scale factors for D """ + valid_m = a_tensor.shape[0] + n, _, l = b_tensor.shape - valid_m, _, _ = logical_shape_fp4x2_aware(a_tensor) - n, _, l = logical_shape_fp4x2_aware(b_tensor) - - _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Creating output tensors d_row_tensor, d_col_tensor, dprob_tensor") - - if cd_major == "n": - d_row_tensor = torch.empty_strided((valid_m, n * 2, 1), (n * 2, 1, valid_m * n * 2), dtype=d_dtype, device=a_tensor.device) - d_col_tensor = torch.empty_strided((valid_m, n * 2, 1), (n * 2, 1, valid_m * n * 2), dtype=d_dtype, device=a_tensor.device) - dprob_tensor = torch.zeros((valid_m, 1, 1), dtype=torch.float32, device=a_tensor.device) - else: + if cd_major != "n": raise ValueError(f"cd_major must be 'n', got {cd_major}") - sfd_row_tensor = None - sfd_col_tensor = None - amax_tensor = None - - if a_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sfa_tensor.dtype in [torch.float8_e8m0fnu, torch.float8_e4m3fn]: - _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Detected fp8 a_dtype and sfa_dtype, constructing sfd_row_tensor and sfd_col_tensor") - - sf_dtype = sfa_tensor.dtype - mma_permute_order = (3, 4, 1, 5, 2, 0) - - sf_k_row = ceil_div(n * 2, sf_vec_size) - mma_shape_row = ( - 1, - ceil_div(valid_m, 128), - ceil_div(sf_k_row, 4), - 32, - 4, - 4, - ) - sfd_row_tensor = torch.empty(mma_shape_row, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) - - sf_k_col = ceil_div(valid_m, sf_vec_size) - mma_shape_col = ( - 1, - ceil_div(n * 2, 128), - ceil_div(sf_k_col, 4), - 32, - 4, - 4, - ) - sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) - - if d_dtype in [torch.bfloat16, torch.float16]: - _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Detected bf16/float16 d_dtype, constructing amax_tensor") - amax_tensor = torch.full((l, 2, 1), 
float("-inf"), dtype=torch.float32, device=a_tensor.device) - - if valid_m == 0: - _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: valid_m is zero, skipping kernel execution") - return TupleDict( - d_row_tensor=d_row_tensor, - d_col_tensor=d_col_tensor, - dprob_tensor=dprob_tensor, - amax_tensor=amax_tensor, - sfd_row_tensor=sfd_row_tensor, - sfd_col_tensor=sfd_col_tensor, - ) - - use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None + use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1])) @@ -847,12 +779,14 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: stride_order(a_tensor), stride_order(b_tensor), stride_order(c_tensor), - norm_const_tensor.shape if norm_const_tensor is not None else None, - norm_const_tensor.stride() if norm_const_tensor is not None else None, - norm_const_tensor.dtype if norm_const_tensor is not None else None, + sfa_tensor.dtype, + sfb_tensor.dtype, padded_offsets.shape if not use_full_dynamic else None, padded_offsets.stride() if not use_full_dynamic else None, padded_offsets.dtype, + norm_const_tensor.shape if norm_const_tensor is not None else None, + norm_const_tensor.stride() if norm_const_tensor is not None else None, + norm_const_tensor.dtype if norm_const_tensor is not None else None, acc_dtype, d_dtype, cd_major, @@ -865,32 +799,69 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: epilogue_op, ) - if cache_key in _cache_of_GroupedGemmDswigluSm100Objects: - _logger.debug("group_gemm_dswiglu_wrapper_sm100: Using previously cached GroupedGemmDswigluSm100 object") - grouped_gemm_dswiglu = _cache_of_GroupedGemmDswigluSm100Objects[cache_key] - grouped_gemm_dswiglu.execute( - a_tensor=a_tensor, - b_tensor=b_tensor, - c_tensor=c_tensor, + # Allocate M-dependent output tensors fresh every call (M varies across MoE steps). + # Only M-independent tensors (amax, beta) are cached to avoid repeated allocation. 
+ _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Allocating M-dependent output tensors") + d_row_tensor = torch.empty_strided((valid_m, n * 2, 1), (n * 2, 1, valid_m * n * 2), dtype=d_dtype, device=a_tensor.device) + d_col_tensor = torch.empty_strided((valid_m, n * 2, 1), (n * 2, 1, valid_m * n * 2), dtype=d_dtype, device=a_tensor.device) + dprob_tensor = dprob_tensor_buf.zero_() if dprob_tensor_buf is not None else torch.zeros((valid_m, 1, 1), dtype=torch.float32, device=a_tensor.device) + + if valid_m == 0: + amax_tensor = None + if d_dtype in [torch.bfloat16, torch.float16]: + amax_tensor = torch.full((l, 2, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device) + _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: valid_m is zero, skipping kernel execution") + return TupleDict( d_row_tensor=d_row_tensor, d_col_tensor=d_col_tensor, - sfa_tensor=sfa_tensor, - sfb_tensor=sfb_tensor, - padded_offsets=padded_offsets, - alpha_tensor=alpha_tensor, - beta_tensor=beta_tensor, - prob_tensor=prob_tensor, dprob_tensor=dprob_tensor, - sfd_row_tensor=sfd_row_tensor, - sfd_col_tensor=sfd_col_tensor, amax_tensor=amax_tensor, - norm_const_tensor=norm_const_tensor, - current_stream=current_stream, + sfd_row_tensor=None, + sfd_col_tensor=None, ) + + sfd_row_tensor = None + sfd_col_tensor = None + if a_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sfa_tensor.dtype in [torch.float8_e8m0fnu, torch.float8_e4m3fn]: + _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Detected fp8 a_dtype and sfa_dtype, constructing sfd_row_tensor and sfd_col_tensor") + sf_dtype = sfa_tensor.dtype + mma_permute_order = (3, 4, 1, 5, 2, 0) + sf_k_row = ceil_div(n * 2, sf_vec_size) + mma_shape_row = (1, ceil_div(valid_m, 128), ceil_div(sf_k_row, 4), 32, 4, 4) + sfd_row_tensor = torch.empty(mma_shape_row, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) + sf_k_col = ceil_div(valid_m, sf_vec_size) + mma_shape_col = (1, ceil_div(n * 2, 128), ceil_div(sf_k_col, 4), 32, 4, 4) + sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) + + if cache_key in _cache_of_GroupedGemmDswigluSm100Objects: + _logger.debug("group_gemm_dswiglu_wrapper_sm100: Using previously cached GroupedGemmDswigluSm100 object") + grouped_gemm_dswiglu, cached_amax_tensor, cached_beta_tensor = _cache_of_GroupedGemmDswigluSm100Objects[cache_key] + amax_tensor = amax_tensor_buf if amax_tensor_buf is not None else cached_amax_tensor + if beta_tensor is not None: + effective_beta = beta_tensor + elif cached_beta_tensor is not None: + effective_beta = cached_beta_tensor + else: + # Fallback: cache was populated without beta caching (non-NVFP4 path), + # but caller now passes None (NVFP4 path). Create ones tensor on-the-fly. + effective_beta = torch.ones(l, dtype=torch.float32, device=a_tensor.device) else: _logger.debug( "group_gemm_dswiglu_wrapper_sm100: No previously cached GroupedGemmDswigluSm100 object found, creating new GroupedGemmDswigluSm100 object" ) + + # For NVFP4 (beta_tensor=None): create and cache a ones tensor — avoids FillFunctor on every step. + # For non-NVFP4 (beta_tensor provided): use caller's value directly; don't cache (it changes each step). 
+ cached_beta_tensor = torch.ones(l, dtype=torch.float32, device=a_tensor.device) if beta_tensor is None else None + effective_beta = cached_beta_tensor if beta_tensor is None else beta_tensor + + cached_amax_tensor = None + amax_tensor = None + if d_dtype in [torch.bfloat16, torch.float16]: + _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Detected bf16/float16 d_dtype, constructing amax_tensor") + cached_amax_tensor = torch.empty((l, 2, 1), dtype=torch.float32, device=a_tensor.device) + amax_tensor = amax_tensor_buf if amax_tensor_buf is not None else cached_amax_tensor + grouped_gemm_dswiglu = GroupedGemmDswigluSm100( sample_a=a_tensor, sample_b=b_tensor, @@ -901,10 +872,10 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: sample_sfb=sfb_tensor, sample_padded_offsets=padded_offsets, sample_alpha=alpha_tensor, - sample_beta=beta_tensor, + sample_beta=effective_beta, sample_prob=prob_tensor, sample_dprob=dprob_tensor, - sample_amax=amax_tensor, + sample_amax=cached_amax_tensor, sample_sfd_row=sfd_row_tensor, sample_sfd_col=sfd_col_tensor, sample_norm_const=norm_const_tensor, @@ -920,26 +891,27 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: assert grouped_gemm_dswiglu.check_support(), "Unsupported configuration" grouped_gemm_dswiglu.compile() - grouped_gemm_dswiglu.execute( - a_tensor=a_tensor, - b_tensor=b_tensor, - c_tensor=c_tensor, - d_row_tensor=d_row_tensor, - d_col_tensor=d_col_tensor, - sfa_tensor=sfa_tensor, - sfb_tensor=sfb_tensor, - padded_offsets=padded_offsets, - alpha_tensor=alpha_tensor, - beta_tensor=beta_tensor, - prob_tensor=prob_tensor, - dprob_tensor=dprob_tensor, - sfd_row_tensor=sfd_row_tensor, - sfd_col_tensor=sfd_col_tensor, - amax_tensor=amax_tensor, - norm_const_tensor=norm_const_tensor, - current_stream=current_stream, - ) - _cache_of_GroupedGemmDswigluSm100Objects[cache_key] = grouped_gemm_dswiglu + _cache_of_GroupedGemmDswigluSm100Objects[cache_key] = (grouped_gemm_dswiglu, cached_amax_tensor, cached_beta_tensor) + + grouped_gemm_dswiglu.execute( + a_tensor=a_tensor, + b_tensor=b_tensor, + c_tensor=c_tensor, + d_row_tensor=d_row_tensor, + d_col_tensor=d_col_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + beta_tensor=effective_beta, + prob_tensor=prob_tensor, + dprob_tensor=dprob_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + current_stream=current_stream, + ) return TupleDict( d_row_tensor=d_row_tensor, diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_glu/api.py index b60eac72..cecaee3a 100644 --- a/python/cudnn/grouped_gemm/grouped_gemm_glu/api.py +++ b/python/cudnn/grouped_gemm/grouped_gemm_glu/api.py @@ -165,7 +165,7 @@ def __init__( """ super().__init__() - self._logger.warning("GroupedGemmGluSm100 is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") # ---- Weight mode auto-detection ---- @@ -239,7 +239,7 @@ def __init__( self._kernel = BlockScaledMoEGroupedGemmGluBiasKernel self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) - print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") self._workspace = None @@ -648,7 +648,7 @@ def compile(self) -> None: def _compile_dense(self, gemm_glu, max_active_clusters, 
fake_stream) -> None: """Compile for dense (contiguous) weight mode.""" - use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None + use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" fake_workspace_ptr = cute.runtime.nullptr( dtype=cutlass.Uint8, @@ -1385,7 +1385,7 @@ def dynamic_m_tensor_signature( stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride())) return static_shape_suffix, stride_signature, tensor.dtype - use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None + use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" if is_dense: cache_key = ( diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py b/python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py index 7c2f2647..72b728f4 100644 --- a/python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py +++ b/python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py @@ -2313,6 +2313,19 @@ def kernel( expert_idx = tile_info[0] gBias_tile = gBias_nl[(None, mma_n_coord, expert_idx)] + + # For dynamic MNKL, cuteDSL drops the 128bit alignment requirement + # but we know that during runtime the alignment is always 16 bytes. + gBias_tile = cute.make_tensor( + cute.make_ptr( + gBias_tile.element_type, + gBias_tile.iterator.toint(), + AddressSpace.gmem, + assumed_align=16, + ), + gBias_tile.layout, + ) + tBs_gBias = thr_bias_g2s.partition_S(gBias_tile) # Predicate: check if this thread's chunk is within N diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/__init__.py b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/__init__.py new file mode 100644 index 00000000..2ef06aa0 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +from .api import ( + GroupedGemmGluHadamardSm100, + grouped_gemm_glu_hadamard_wrapper_sm100, +) + +__all__ = [ + "GroupedGemmGluHadamardSm100", + "grouped_gemm_glu_hadamard_wrapper_sm100", +] diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/api.py b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/api.py new file mode 100644 index 00000000..73539492 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/api.py @@ -0,0 +1,540 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: MIT + +"""FE API for grouped GEMM GLU + Hadamard forward fusion.""" + +import logging +import os +from typing import Optional, Tuple + +from cuda.bindings import driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +from cutlass.cute.runtime import from_dlpack, make_fake_stream + +from cudnn.api_base import APIBase, TupleDict, ceil_div, is_power_of_2 +from cudnn.datatypes import _convert_to_cutlass_data_type + +from ..moe_utils import MoEWeightMode +from .hadamard_utils import HADAMARD_SIZE, hadamard_matrix +from .moe_blockscaled_grouped_gemm_glu_hadamard import BlockScaledMoEGroupedGemmGluHadamardKernel + + +def _reinterpret_raw_grouped_fp4_tensor(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.uint8: + cute_tensor = from_dlpack(tensor, assumed_align=16, enable_tvm_ffi=True).mark_layout_dynamic(leading_dim=1) + cute_tensor.element_type = cutlass.Float4E2M1FN + return cute_tensor + return tensor + + +class GroupedGemmGluHadamardSm100(APIBase): + """Dense grouped GEMM GLU forward kernel with fused Hadamard transform.""" + + def __init__( + self, + sample_a: torch.Tensor, + sample_b: torch.Tensor, + sample_c: torch.Tensor, + sample_d: torch.Tensor, + sample_sfa: torch.Tensor, + sample_sfb: torch.Tensor, + sample_padded_offsets: torch.Tensor, + sample_alpha: torch.Tensor, + sample_prob: torch.Tensor, + sample_amax: Optional[torch.Tensor] = None, + sample_bias: Optional[torch.Tensor] = None, + sample_hadamard: Optional[torch.Tensor] = None, + acc_dtype: torch.dtype = torch.float32, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + m_aligned: int = 256, + act_func: str = "swiglu", + use_dynamic_sched: bool = False, + ): + super().__init__() + + self._warn_experimental_api() + self._interpret_uint8_as_fp4x2 = True + self._sample_a_tensor = sample_a + self._sample_b_tensor = sample_b + + self.a_desc = self._make_tensor_desc(sample_a, name="sample_a", interpret_uint8_as_fp4x2=False) + self.b_desc = self._make_tensor_desc(sample_b, name="sample_b", interpret_uint8_as_fp4x2=False) + self.c_desc = self._make_tensor_desc(sample_c, name="sample_c") + self.d_desc = self._make_tensor_desc(sample_d, name="sample_d") + self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa") + self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb") + self.padded_offsets_desc = self._make_tensor_desc(sample_padded_offsets, name="sample_padded_offsets") + self.alpha_desc = self._make_tensor_desc(sample_alpha, name="sample_alpha") + self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob") + self.bias_desc = self._make_tensor_desc(sample_bias, name="sample_bias") + self.expert_cnt = self.padded_offsets_desc.shape[0] + if sample_amax is None: + sample_amax = torch.empty((self.expert_cnt, 1), dtype=torch.float32, device=sample_a.device) + self.amax_desc = self._make_tensor_desc(sample_amax, name="sample_amax") + if sample_hadamard is None: + self.hadamard_tensor = self._make_hadamard_tensor(sample_a.device) + else: + self.hadamard_tensor = self._normalize_hadamard_tensor( + sample_hadamard, + device=sample_a.device, + name="sample_hadamard", + ) + + self.acc_dtype = acc_dtype + self.mma_tiler_mn = mma_tiler_mn + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn if cluster_shape_mn is not None else ((2, 1) if 
self.use_2cta_instrs else (1, 1)) + self.sf_vec_size = sf_vec_size + self.vector_f32 = vector_f32 + self.m_aligned = m_aligned + self.act_func = act_func + self.use_dynamic_sched = use_dynamic_sched + self.weight_mode = MoEWeightMode.DENSE + self._kernel = BlockScaledMoEGroupedGemmGluHadamardKernel + self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) + self._workspace = None + + @staticmethod + def _make_hadamard_tensor(device: torch.device) -> torch.Tensor: + return hadamard_matrix(HADAMARD_SIZE, dtype=torch.bfloat16, device=device).t().contiguous() + + @classmethod + def _normalize_hadamard_tensor( + cls, + hadamard_tensor: torch.Tensor, + *, + device: torch.device, + name: str, + ) -> torch.Tensor: + expected_shape = (HADAMARD_SIZE, HADAMARD_SIZE) + if tuple(hadamard_tensor.shape) != expected_shape: + raise ValueError(f"{name} tensor shape mismatch: expected {expected_shape}, got {tuple(hadamard_tensor.shape)}") + if hadamard_tensor.dtype != torch.bfloat16 or hadamard_tensor.device != device: + hadamard_tensor = hadamard_tensor.to(device=device, dtype=torch.bfloat16) + if not hadamard_tensor.is_contiguous(): + hadamard_tensor = hadamard_tensor.contiguous() + return hadamard_tensor + + def check_support(self) -> bool: + tensor_m, k, _ = self._tensor_shape(self.a_desc, name="sample_a") + n, _, l = self._tensor_shape(self.b_desc, name="sample_b") + _, n_c, _ = self._tensor_shape(self.c_desc, name="sample_c") + _, n_d, _ = self._tensor_shape(self.d_desc, name="sample_d") + + self._value_error_if(l != self.expert_cnt, f"B L dimension ({l}) must match expert_cnt ({self.expert_cnt})") + self._value_error_if(n % 64 != 0, f"N must be divisible by 64, got {n}") + self._value_error_if((n // 2) % HADAMARD_SIZE != 0, f"N/2 must be divisible by {HADAMARD_SIZE}, got {n // 2}") + + self._check_tensor_shape(self.a_desc, (tensor_m, k, 1), "A") + self._check_tensor_shape(self.b_desc, (n, k, l), "B") + self._check_tensor_shape(self.c_desc, (tensor_m, n, 1), "C") + self._check_tensor_shape(self.d_desc, (tensor_m, n // 2, 1), "D") + self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(tensor_m, 128), 4, ceil_div(ceil_div(k, self.sf_vec_size), 4), 1), "SFA") + self._check_tensor_shape(self.sfb_desc, (32, 4, ceil_div(n, 128), 4, ceil_div(ceil_div(k, self.sf_vec_size), 4), l), "SFB") + self._check_tensor_shape(self.padded_offsets_desc, (l,), "padded_offsets") + self._check_tensor_shape(self.alpha_desc, (l,), "alpha") + self._check_tensor_shape(self.prob_desc, (tensor_m, 1, 1), "prob") + self._check_tensor_shape(self.bias_desc, (n, l), "bias") + self._check_tensor_shape(self.amax_desc, (l, 1), "amax") + self._check_tensor_shape(self.hadamard_tensor, (HADAMARD_SIZE, HADAMARD_SIZE), "hadamard") + + self._check_tensor_stride(self.a_desc, stride=[(k, 1, tensor_m * k)], name="A", extra_error_msg="A must have k-major layout") + self._check_tensor_stride(self.b_desc, stride=[(k, 1, n * k)], name="B", extra_error_msg="B must have k-major layout") + self._check_tensor_stride(self.c_desc, stride=[(n_c, 1, tensor_m * n_c)], name="C", extra_error_msg="C must have n-major layout") + self._check_tensor_stride(self.d_desc, stride=[(n_d, 1, tensor_m * n_d)], name="D", extra_error_msg="D must have n-major layout") + self._check_tensor_stride(self.bias_desc, stride=[(1, n)], name="bias") + + self.ab_dtype = self._check_dtype( + self.a_desc, + dtype=[torch.float4_e2m1fn_x2, torch.uint8], + name="A", + ) + self._check_dtype(self.b_desc, dtype=self.ab_dtype, name="B", extra_error_msg="B must 
match A dtype") + self.sf_dtype = self._check_dtype(self.sfa_desc, dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn], name="SFA") + self._check_dtype(self.sfb_desc, dtype=self.sf_dtype, name="SFB", extra_error_msg="SFB must match SFA dtype") + self.c_dtype = self._check_dtype(self.c_desc, dtype=[torch.float16, torch.bfloat16], name="C") + self.d_dtype = self._check_dtype(self.d_desc, dtype=[torch.float16, torch.bfloat16], name="D") + self._check_dtype(self.alpha_desc, dtype=torch.float32, name="alpha") + self._check_dtype(self.prob_desc, dtype=torch.float32, name="prob") + self._check_dtype(self.bias_desc, dtype=[torch.float16, torch.bfloat16, torch.float32], name="bias") + self._check_dtype(self.amax_desc, dtype=torch.float32, name="amax") + self._check_dtype(self.hadamard_tensor, dtype=torch.bfloat16, name="hadamard") + self._check_dtype(self.acc_dtype, dtype=torch.float32, name="acc_dtype") + + self._value_error_if(self.sf_vec_size not in [16, 32], f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}") + self._value_error_if(self.act_func not in ["swiglu", "geglu"], f"act_func must be 'swiglu' or 'geglu', got {self.act_func}") + self._value_error_if( + not self.use_2cta_instrs or self.mma_tiler_mn != (256, 256), f"Hadamard fusion requires mma_tiler_mn=(256, 256), got {self.mma_tiler_mn}" + ) + self._value_error_if(self.cluster_shape_mn[0] % 2 != 0, f"cluster_shape_mn[0] must be divisible by 2, got {self.cluster_shape_mn[0]}") + self._value_error_if( + not ( + self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16 + and self.cluster_shape_mn[0] > 0 + and self.cluster_shape_mn[1] > 0 + and self.cluster_shape_mn[0] <= 4 + and self.cluster_shape_mn[1] <= 4 + and is_power_of_2(self.cluster_shape_mn[0]) + and is_power_of_2(self.cluster_shape_mn[1]) + ), + f"Invalid cluster shape: {self.cluster_shape_mn}", + ) + self._value_error_if( + self.m_aligned != BlockScaledMoEGroupedGemmGluHadamardKernel.FIX_PAD_SIZE, + f"m_aligned must be {BlockScaledMoEGroupedGemmGluHadamardKernel.FIX_PAD_SIZE}, got {self.m_aligned}", + ) + self._value_error_if(self.expert_cnt > 1024, f"expert_cnt must be <= 1024, got {self.expert_cnt}") + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) + compute_capability = major * 10 + minor + if compute_capability < 100: + raise RuntimeError(f"GroupedGemmGluHadamardSm100 requires SM100+, found SM{compute_capability}") + + if not self._kernel.can_implement( + _convert_to_cutlass_data_type(self.ab_dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2), + _convert_to_cutlass_data_type(self.sf_dtype), + self.sf_vec_size, + _convert_to_cutlass_data_type(self.acc_dtype), + _convert_to_cutlass_data_type(self.d_dtype), + self.use_2cta_instrs, + self.mma_tiler_mn, + self.cluster_shape_mn, + self.m_aligned, + n, + k, + l, + "k", + "k", + "n", + self.m_aligned, + ): + raise RuntimeError("Unsupported grouped GEMM GLU hadamard configuration") + + self._is_supported = True + return True + + def compile(self) -> None: + self._ensure_support_checked() + if self._compiled_kernel is not None: + return + if self.a_desc.shape[0] == 0: + return + + kernel = self._kernel( + sf_vec_size=self.sf_vec_size, + acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype), + use_2cta_instrs=self.use_2cta_instrs, + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + vectorized_f32=self.vector_f32, + expert_cnt=self.expert_cnt, + weight_mode=MoEWeightMode.DENSE, + 
use_dynamic_sched=self.use_dynamic_sched, + act_func=self.act_func, + enable_bias=self.bias_desc is not None, + ) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) + max_active_clusters -= self.num_cluster_overlap_margin + self._value_error_if(max_active_clusters <= 0, "max_active_clusters must be > 0 after overlap margin") + self._workspace = torch.empty(max(kernel.get_workspace_bytes(), 1), dtype=torch.uint8, device="cuda") + fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False) + fake_workspace_ptr = cute.runtime.nullptr(dtype=cutlass.Uint8, assumed_align=128) + + valid_m = cute.sym_int(divisibility=self.m_aligned) + tensor_m_128 = cute.sym_int() + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + + a_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, self.a_desc.shape[1], 1), + stride_order=self.a_desc.stride_order, + dynamic_mode=self.a_desc.stride_order[0], + divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, + ) + c_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, self.c_desc.shape[1], 1), + stride_order=self.c_desc.stride_order, + dynamic_mode=self.c_desc.stride_order[0], + divisibility=8 if self._is_f16(self.c_desc) else 16, + ) + d_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_desc.dtype, + shape=(valid_m, self.d_desc.shape[1], 1), + stride_order=self.d_desc.stride_order, + dynamic_mode=self.d_desc.stride_order[0], + divisibility=8 if self._is_f16(self.d_desc) else 16, + ) + prob_cute_fake = self._make_fake_cute_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, 1, 1), + stride=self.prob_desc.stride, + ) + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], 1), + stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128), + ) + b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16) + sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16) + alpha_cute_fake = self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16) + padded_offsets_cute_fake = self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16) + amax_cute_fake = self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16) + bias_cute_fake = self._make_fake_cute_tensor_from_desc(self.bias_desc, assumed_align=16) + hadamard_cute_fake = self._make_fake_cute_tensor_like(self.hadamard_tensor, assumed_align=16, name="sample_hadamard") + cached_linear_offset = cutlass.Float32(1.0 if self.act_func == "geglu" else 0.0) + + compiled_kernel = cute.compile( + kernel, + _reinterpret_raw_grouped_fp4_tensor(self._sample_a_tensor) if self.a_desc.dtype == torch.uint8 else a_cute_fake, + _reinterpret_raw_grouped_fp4_tensor(self._sample_b_tensor) if self.b_desc.dtype == torch.uint8 else b_cute_fake, + sfa_cute_fake, + sfb_cute_fake, + cutlass.Int32(0), + cutlass.Int32(0), + cutlass.Int64(0), + OperandMajorMode.K, + fake_workspace_ptr, + c_cute_fake, + d_cute_fake, + amax_cute_fake, + padded_offsets_cute_fake, + alpha_cute_fake, + prob_cute_fake, + hadamard_cute_fake, + bias_cute_fake, + max_active_clusters, + fake_stream, + cached_linear_offset, + options="--enable-tvm-ffi", + ) + + cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator + + def tensor_api( + a_tensor: 
torch.Tensor, + b_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + padded_offsets: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + hadamard_tensor: torch.Tensor, + amax_tensor: Optional[torch.Tensor], + bias_tensor: Optional[torch.Tensor], + stream: cuda.CUstream, + ) -> None: + compiled_kernel( + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + cutlass.Int32(0), + cutlass.Int32(0), + cutlass.Int64(0), + cached_workspace_ptr, + c_tensor, + d_tensor, + amax_tensor, + padded_offsets, + alpha_tensor, + prob_tensor, + hadamard_tensor, + bias_tensor, + stream, + cached_linear_offset, + ) + + self._compiled_kernel = tensor_api + + def execute( + self, + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + padded_offsets: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + hadamard_tensor: Optional[torch.Tensor] = None, + amax_tensor: Optional[torch.Tensor] = None, + bias_tensor: Optional[torch.Tensor] = None, + current_stream: Optional[cuda.CUstream] = None, + ) -> None: + self._ensure_support_checked() + if self._compiled_kernel is None: + raise RuntimeError("Kernel has not been compiled") + if a_tensor.shape[0] == 0: + return + if current_stream is None: + current_stream = cuda.CUstream(torch.cuda.current_stream(a_tensor.device).cuda_stream) + if hadamard_tensor is None: + hadamard_tensor = self.hadamard_tensor + else: + hadamard_tensor = self._normalize_hadamard_tensor( + hadamard_tensor, + device=a_tensor.device, + name="hadamard", + ) + + self._compiled_kernel( + _reinterpret_raw_grouped_fp4_tensor(a_tensor), + _reinterpret_raw_grouped_fp4_tensor(b_tensor), + c_tensor, + d_tensor, + sfa_tensor, + sfb_tensor, + padded_offsets, + alpha_tensor, + prob_tensor, + hadamard_tensor, + amax_tensor, + bias_tensor, + current_stream, + ) + + +_logger = logging.getLogger(__name__) +_cache_of_GroupedGemmGluHadamardSm100Objects = {} + + +def grouped_gemm_glu_hadamard_wrapper_sm100( + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + padded_offsets: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + bias_tensor: Optional[torch.Tensor] = None, + acc_dtype: torch.dtype = torch.float32, + c_dtype: torch.dtype = torch.bfloat16, + d_dtype: torch.dtype = torch.bfloat16, + cd_major: str = "n", + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + m_aligned: int = 256, + act_func: str = "swiglu", + use_dynamic_sched: bool = False, + current_stream: Optional[cuda.CUstream] = None, +) -> TupleDict: + """High-level wrapper for grouped GEMM GLU + Hadamard forward fusion.""" + + valid_m = a_tensor.shape[0] + n_full, _, l = b_tensor.shape + n_out = n_full // 2 + + if cd_major != "n": + raise ValueError(f"cd_major must be 'n', got {cd_major}") + + c_tensor = torch.empty_strided((valid_m, n_full, 1), (n_full, 1, valid_m * n_full), dtype=c_dtype, device=a_tensor.device) + d_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) + amax_tensor = None + if d_dtype in [torch.bfloat16, torch.float16]: + amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device) + + if valid_m == 0: + return 
TupleDict(c_tensor=c_tensor, d_tensor=d_tensor, amax_tensor=amax_tensor) + + def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: + return tuple(i for i, _ in sorted(enumerate(tensor.stride()), key=lambda item: item[1])) + + def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype + + def dynamic_m_tensor_signature( + tensor: Optional[torch.Tensor], static_shape_suffix: Optional[Tuple[int, ...]], dynamic_stride_dims: Tuple[int, ...] = () + ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + stride_signature = tuple(None if idx in dynamic_stride_dims else value for idx, value in enumerate(tensor.stride())) + return static_shape_suffix, stride_signature, tensor.dtype + + cache_key = ( + act_func, + a_tensor.shape[1:], + tuple(b_tensor.shape), + c_tensor.shape[1:], + a_tensor.dtype, + b_tensor.dtype, + c_tensor.dtype, + d_tensor.dtype, + stride_order(a_tensor), + stride_order(b_tensor), + stride_order(c_tensor), + *dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1), dynamic_stride_dims=(5,)), + *tensor_signature(sfb_tensor), + *tensor_signature(alpha_tensor), + *dynamic_m_tensor_signature(prob_tensor, (1, 1)), + *tensor_signature(bias_tensor), + *tensor_signature(padded_offsets), + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + vector_f32, + m_aligned, + use_dynamic_sched, + ) + + if cache_key in _cache_of_GroupedGemmGluHadamardSm100Objects: + api = _cache_of_GroupedGemmGluHadamardSm100Objects[cache_key] + else: + api = GroupedGemmGluHadamardSm100( + sample_a=a_tensor, + sample_b=b_tensor, + sample_c=c_tensor, + sample_d=d_tensor, + sample_sfa=sfa_tensor, + sample_sfb=sfb_tensor, + sample_padded_offsets=padded_offsets, + sample_alpha=alpha_tensor, + sample_prob=prob_tensor, + sample_amax=amax_tensor, + sample_bias=bias_tensor, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + vector_f32=vector_f32, + m_aligned=m_aligned, + act_func=act_func, + use_dynamic_sched=use_dynamic_sched, + ) + api.check_support() + api.compile() + _cache_of_GroupedGemmGluHadamardSm100Objects[cache_key] = api + + api.execute( + a_tensor=a_tensor, + b_tensor=b_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + prob_tensor=prob_tensor, + amax_tensor=amax_tensor, + bias_tensor=bias_tensor, + current_stream=current_stream, + ) + return TupleDict(c_tensor=c_tensor, d_tensor=d_tensor, amax_tensor=amax_tensor) diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/hadamard_utils.py b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/hadamard_utils.py new file mode 100644 index 00000000..205ef21d --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/hadamard_utils.py @@ -0,0 +1,140 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: MIT + +"""Local Hadamard helpers for the grouped GEMM GLU hadamard kernel.""" + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.nvgpu import tcgen05 +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode, OperandSource +import torch + +HADAMARD_SIZE = 16 +TMEM_ROW_STRIDE = 1 << 16 +M_PER_CLUSTER = 256 + + +@cute.jit +def hadamard_setup(g_hadamard, s_hadamard, tidx): + tiled_hmma = sm100_utils.make_trivial_tiled_mma( + cutlass.BFloat16, + OperandMajorMode.K, + OperandMajorMode.K, + cutlass.Float32, + tcgen05.CtaGroup.TWO, + (M_PER_CLUSTER, HADAMARD_SIZE), + OperandSource.TMEM, + ) + s_hadamard[tidx] = g_hadamard[tidx if tidx < 64 else (tidx ^ 8)] + return tiled_hmma + + +@cute.jit +def hadamard_compute(tiled_hmma, tmem_a_ptr, tmem_acc_ptr, s_hadamard, epi_tile, tidx, pipeline_producer): + n = epi_tile[1] + mma_tiler = (M_PER_CLUSTER, HADAMARD_SIZE, HADAMARD_SIZE) + cta_rank = cute.arch.block_in_cluster_idx()[0] + thr = tiled_hmma.get_slice(cta_rank) + b_layout = sm100_utils.make_smem_layout_b(tiled_hmma, mma_tiler, cutlass.BFloat16, 1) + s_bh = s_hadamard.get_tensor(b_layout.outer, swizzle=b_layout.inner) + t_bs_b = tiled_hmma.make_fragment_B(s_bh) + + t_a_subtile = cute.make_tensor(0, cute.make_layout((M_PER_CLUSTER, n), stride=(n, 1))) + t_ht_a_frg = thr.partition_A(t_a_subtile) + t_ht_a = cute.make_tensor( + cute.recast_ptr(tmem_a_ptr, dtype=cutlass.BFloat16), + tiled_hmma.make_fragment_A(t_ht_a_frg.layout).layout, + ) + + t_ht_c_frg = thr.partition_C(t_a_subtile) + t_ht_c = cute.make_tensor( + tmem_acc_ptr, + tiled_hmma.make_fragment_C(t_ht_c_frg.layout).layout, + ) + + if cta_rank == 0 and tidx < 32: + hadamard_empty = pipeline_producer.acquire_and_advance() + for i in cutlass.range_constexpr(cute.size(t_ht_c.shape, mode=[2]), unroll_full=True): + cute.gemm( + tiled_hmma, + cute.append_ones(t_ht_c[None, None, i], up_to_rank=3), + cute.append_ones(t_ht_a[(None, None, i)], up_to_rank=3), + cute.append_ones(t_bs_b[(None, None, 0, 0)], up_to_rank=3), + cute.append_ones(t_ht_c[None, None, i], up_to_rank=3), + ) + hadamard_empty.commit() + + +@cute.jit +def hadamard_in(rmem_src: cute.Tensor, cols: cutlass.Constexpr, tmem_ptr, tidx): + tmem_ptr = cute.make_ptr( + cutlass.Float32, + tmem_ptr.toint(), + cute.AddressSpace.tmem, + assumed_align=8, + ) + tmem_tensor = cute.make_tensor( + tmem_ptr, + cute.make_layout((128, cols), stride=(TMEM_ROW_STRIDE, 1)), + ) + + if cutlass.const_expr(cols == 32): + st_atom = cute.make_copy_atom(tcgen05.St32x32bOp(tcgen05.Repetition.x32), cutlass.Float32) + else: + st_atom = cute.make_copy_atom(tcgen05.St32x32bOp(tcgen05.Repetition.x16), cutlass.Float32) + tiled_st = tcgen05.make_tmem_copy(st_atom, tmem_tensor) + thr_st = tiled_st.get_slice(tidx) + t_d_st = thr_st.partition_D(tmem_tensor) + cute.copy(thr_st, cute.recast_tensor(rmem_src, cutlass.Float32), t_d_st) + + +@cute.jit +def hadamard_out(rmem_dst: cute.Tensor, cols: cutlass.Constexpr, tmem_ptr, tidx): + tmem_ptr = cute.make_ptr( + cutlass.Float32, + tmem_ptr.toint(), + cute.AddressSpace.tmem, + assumed_align=8, + ) + tmem_tensor = cute.make_tensor( + tmem_ptr, + cute.make_layout((128, cols), stride=(TMEM_ROW_STRIDE, 1)), + ) + + if cutlass.const_expr(cols == 32): + ld_atom = cute.make_copy_atom(tcgen05.Ld32x32bOp(tcgen05.Repetition.x32), cutlass.Float32) + else: + ld_atom = cute.make_copy_atom(tcgen05.Ld32x32bOp(tcgen05.Repetition.x16), cutlass.Float32) + tiled_ld = 
tcgen05.make_tmem_copy(ld_atom, tmem_tensor) + thr_ld = tiled_ld.get_slice(tidx) + t_d_ld_ = thr_ld.partition_D(tmem_tensor) + t_d_ld = cute.make_tensor( + cute.make_ptr( + cutlass.Float32, + t_d_ld_.iterator.toint(), + cute.AddressSpace.tmem, + assumed_align=8, + ), + t_d_ld_.layout, + ) + cute.copy(thr_ld, t_d_ld, cute.recast_tensor(rmem_dst, cutlass.Float32)) + + +def hadamard_matrix(n, dtype=None, device=None): + if dtype is None: + dtype = torch.float32 + if n < 1: + raise ValueError("n must be a positive integer") + if n & (n - 1): + raise ValueError("n must be a power of 2") + + kwargs = {"dtype": dtype} + if device is not None: + kwargs["device"] = device + matrix = torch.tensor([[1]], **kwargs) + base = torch.tensor([[1, 1], [1, -1]], **kwargs) + while matrix.shape[0] < n: + matrix = torch.kron(matrix, base) + return matrix diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/moe_blockscaled_grouped_gemm_glu_hadamard.py b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/moe_blockscaled_grouped_gemm_glu_hadamard.py new file mode 100644 index 00000000..2cfc81e1 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/moe_blockscaled_grouped_gemm_glu_hadamard.py @@ -0,0 +1,2513 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +MoE Block-Scaled Grouped GEMM Kernel with GLU (SwiGLU/GeGLU) + Hadamard Transform Fusion. 
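+
+The transform multiplies each 16-column block of the activated output by a
+16x16 Hadamard matrix (``hadamard_utils.HADAMARD_SIZE``); the AMAX reduction
+is performed by the RHT store warps, i.e. over the transformed output.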
+ +Supports: + - Static / Dynamic persistent tile scheduling (MoEPersistentTileScheduler) + - Dense (contiguous 3-D B) / Discrete (per-expert pointer array B) weight layout + - BF16/F16 D output with Hadamard transform (pingpong epilogue) + - Optional C output (pre-activation GLU output) + - AMAX reduction for calibration + - GLU activation fusion (SwiGLU / GeGLU) + +Warp assignment (8 epilogue warps, pingpong): + warps 0-3 : ACT warps — TMEM→reg, alpha scale, GLU activation, C store, hadamard_in + warps 4-7 : RHT store warps — hadamard_compute, D store + warp 8 : MMA warp + warp 9 : TMA load warp + warp 10 : Scheduler warp (MoEPersistentTileScheduler) + warp 11 : Bias load warp (optional) + +sInfo format: (expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) + Validity: tile_info[0] >= 0 (expert_idx == -1 signals end) +""" + +from typing import Type, Tuple, Union, Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass._mlir.dialects.nvvm import ReduxKind +from cutlass.cute.typing import Float32, Int32, AddressSpace +from ..moe_persistent_scheduler import ( + MoEPersistentTileScheduler, + MoESchedulerParams, + MoEWorkTileInfo, +) +from ..moe_utils import ( + compute_expert_token_range, + MoEWeightMode, + TensormapWorkspace, + store_tma_desc, +) +from .hadamard_utils import ( + hadamard_setup, + hadamard_compute, + hadamard_in, + hadamard_out, + HADAMARD_SIZE, +) +from ..moe_sched_extension import ( + DiscreteWeightScaledGemmSchedExtension, + ContiguousAndConsistentGroupedGemmSchedExtension, +) +from ..moe_kernel_helpers import ( + fmin, + fmax, + warp_redux_sync, + atomic_max_float32, + silu_f32, + silu_f32_geglu_scaled, + compute_grid, + get_dtype_rcp_limits, + can_implement, + amax_reduction_per_thread, +) + + +class BlockScaledMoEGroupedGemmGluHadamardKernel: + """Block-scaled MoE grouped GEMM with GLU activation and Hadamard transform fusion. + + Always uses pingpong epilogue (8 epilogue warps: 4 ACT + 4 RHT-store). + D output is BF16 or F16 only (no FP8/FP4, no SFD). + + :param sf_vec_size: Scalefactor vector size. + :param mma_tiler_mn: Shape of MMA tile (M, N). + :param cluster_shape_mn: Cluster dimensions (M, N). + :param expert_cnt: Number of experts (compile-time constant). + :param weight_mode: Dense or Discrete weight layout. + :param use_dynamic_sched: Use dynamic tile scheduling. + :param act_func: Activation function ('swiglu' or 'geglu'). + :param enable_bias: Enable bias addition. 
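+
+    Example (illustrative sketch; values shown are typical, not prescriptive —
+    the 2-CTA / (256, 256) tile requirement comes from ``can_implement``):
+
+    .. code-block:: python
+
+        kernel = BlockScaledMoEGroupedGemmGluHadamardKernel(
+            sf_vec_size=16,
+            acc_dtype=cutlass.Float32,
+            use_2cta_instrs=True,
+            mma_tiler_mn=(256, 256),
+            cluster_shape_mn=(2, 1),
+            vectorized_f32=True,
+            expert_cnt=8,
+            weight_mode=MoEWeightMode.DENSE,
+            act_func="swiglu",
+        )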
+ """ + + FIX_PAD_SIZE = 256 + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + cd_major: str, + m_aligned: int, + ) -> bool: + # speical requirements for hadamard fusion + if not use_2cta_instrs or mma_tiler_mn[0] != 256 or mma_tiler_mn[1] != 256: + return False + return can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + acc_dtype, + d_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + cd_major, + m_aligned, + fix_pad_size=BlockScaledMoEGroupedGemmGluHadamardKernel.FIX_PAD_SIZE, + ) + + def __init__( + self, + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + vectorized_f32: bool, + expert_cnt: int, + weight_mode: MoEWeightMode = MoEWeightMode.DISCRETE, + use_dynamic_sched: bool = False, + act_func: str = "swiglu", + enable_bias: bool = False, + ): + mma_tile_m = mma_tiler_mn[0] + if self.FIX_PAD_SIZE % mma_tile_m != 0: + raise ValueError(f"FIX_PAD_SIZE ({self.FIX_PAD_SIZE}) must be divisible by " f"mma_tiler_mn[0] ({mma_tile_m}).") + if expert_cnt > 1024: + raise ValueError("Expert count > 1024 is not supported.") + if not isinstance(weight_mode, MoEWeightMode): + raise TypeError(f"weight_mode must be a MoEWeightMode, got {type(weight_mode)}") + + self.sf_vec_size = sf_vec_size + self.expert_cnt = expert_cnt + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + self.mma_tiler = (*mma_tiler_mn, 1) + self.weight_mode = weight_mode + self.use_dynamic_sched = use_dynamic_sched + self.enable_bias = enable_bias + + # Always use pingpong epilogue for Hadamard + self.epilogue_pingpong = True + # Always delay TMA store acquire sync for Hadamard + self.delay_tma_store_acquire_sync = True + + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + self.threads_per_warp = 32 + + # Warp assignments: 8 epilogue warps (4 ACT + 4 RHT-store) + self.epilog_warp_id = (0, 1, 2, 3, 4, 5, 6, 7) + self.epilog_act_warp_id = (0, 1, 2, 3) + self.epilog_rht_store_warp_id = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.tma_warp_id = 9 + self.sched_warp_id = 10 + self.bias_load_warp_id = 11 if enable_bias else None + + self.epilogue_warp_group_size = len(self.epilog_act_warp_id) # = 4 + + all_warps = [*self.epilog_warp_id, self.mma_warp_id, self.tma_warp_id, self.sched_warp_id] + warps_wo_sched = [*self.epilog_warp_id, self.mma_warp_id, self.tma_warp_id] + if enable_bias: + all_warps.append(self.bias_load_warp_id) + warps_wo_sched.append(self.bias_load_warp_id) + self.threads_per_cta = self.threads_per_warp * len(all_warps) + self.threads_wo_sched = self.threads_per_warp * len(warps_wo_sched) + + # Named barriers + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + 
barrier_id=4, + num_threads=self.threads_per_warp, + ) + # Pingpong barriers (group 0 = ACT warps, group 1 = RHT store warps) + self.epilog_sync_barrier_group0 = pipeline.NamedBarrier( + barrier_id=5, + num_threads=32 * self.epilogue_warp_group_size, + ) + self.epilog_sync_barrier_group1 = pipeline.NamedBarrier( + barrier_id=6, + num_threads=32 * self.epilogue_warp_group_size, + ) + + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.vectorized_f32 = vectorized_f32 + + # Amax: only RHT store warps (4) do reduction + self.num_epilog_warps = len(self.epilog_rht_store_warp_id) # = 4 + + self.act_func = act_func + if act_func not in ["swiglu", "geglu"]: + raise ValueError(f"Invalid activation function: {act_func}") + + def _setup_attributes(self): + """Set up configurations dependent on GEMM inputs (called inside __call__).""" + + self.mma_inst_shape_mn = ( + self.mma_tiler[0], + self.mma_tiler[1], + ) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], + ) + + self.mma_tiler_d = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1] // 2, + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk_d = ( + self.mma_tiler_d[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_d[1], + self.mma_tiler_d[2], + ) + + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + self.epi_tile = (128, 32) + self.epi_tile_cnt = ( + self.cta_tile_shape_mnk_d[0] // self.epi_tile[0], + self.cta_tile_shape_mnk_d[1] // self.epi_tile[1], + ) + self.epi_tile_c = (128, 64) + + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_d_stage, + self.num_tile_stage, + self.num_bias_stage, + self.num_pingpong_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.epi_tile_c, + self.c_dtype, + self.c_layout, + self.d_dtype, + 
self.d_layout, + self.sf_dtype, + self.sf_vec_size, + self.num_smem_capacity, + self.occupancy, + self.bias_dtype if self.enable_bias else None, + ) + + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile_c, + self.num_c_stage, + ) + self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.d_dtype, + self.d_layout, + self.epi_tile, + self.num_d_stage, + ) + + if self.enable_bias: + self.bias_smem_layout_staged = cute.make_layout( + (self.mma_tiler[1], self.num_bias_stage), + stride=(1, self.mma_tiler[1]), + ) + else: + self.bias_smem_layout_staged = cute.make_layout((1, 1)) + + self.overlapping_accum = self.num_acc_stage == 1 and self.mma_tiler[1] == 256 + + sf_atom_mn = 32 + self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k + self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = ( + self.cta_tile_shape_mnk[1] * self.num_acc_stage if not self.overlapping_accum else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols + ) + + self.epi_tile_n_required = 2 * cute.size(self.epi_tile[1]) + self.iter_acc_early_release_in_epilogue = ((self.num_sf_tmem_cols + self.epi_tile_n_required - 1) // self.epi_tile_n_required - 1) * 2 + + def get_desc_workspace_bytes(self) -> int: + """Return descriptor workspace size in bytes.""" + if self.weight_mode == MoEWeightMode.DISCRETE: + from ..moe_utils import DiscreteWeightTensormapConstructor + + return DiscreteWeightTensormapConstructor.get_workspace_size(self.expert_cnt) + return 0 + + def get_workspace_bytes(self) -> int: + """Return total workspace size in bytes.""" + desc_workspace_bytes = self.get_desc_workspace_bytes() + dynamic_sched_bytes = 4 if self.use_dynamic_sched else 0 + return desc_workspace_bytes + dynamic_sched_bytes + + @cute.jit + def _get_sched_counter_ptr(self, workspace_ptr): + counter_addr = workspace_ptr.toint() + self.get_desc_workspace_bytes() + return cute.make_ptr( + cutlass.Int32, + counter_addr, + AddressSpace.gmem, + assumed_align=4, + ) + + @cute.kernel + def helper_kernel( + self, + ptrs_b: cute.Pointer, + ptrs_sfb: cute.Pointer, + n: Int32, + k: Int32, + b_stride_size: cutlass.Int64, + b_major_mode: cutlass.Constexpr, + workspace_ptr, + tiled_mma_arg: cute.TiledMma, + tiled_mma_sfb_arg: cute.TiledMma, + b_smem_layout_arg, + sfb_smem_layout_arg, + cluster_layout_vmnk_shape_arg: cutlass.Constexpr, + cluster_layout_sfb_vmnk_shape_arg: cutlass.Constexpr, + ): + """Pre-main-kernel: build per-expert TMA descriptors (discrete mode) and/or reset sched counter.""" + expert_idx = cute.arch.block_idx()[0] + + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE): + b_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_arg.thr_id) + sfb_tma_op_arg = 
sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma_arg.thr_id) + + # Read per-expert base addresses from the pointer arrays + b_ptr_tensor = cute.make_tensor( + cute.make_ptr(cutlass.Int64, ptrs_b.toint(), AddressSpace.gmem, assumed_align=8), + cute.make_layout((self.expert_cnt,)), + ) + sfb_ptr_tensor = cute.make_tensor( + cute.make_ptr(cutlass.Int64, ptrs_sfb.toint(), AddressSpace.gmem, assumed_align=8), + cute.make_layout((self.expert_cnt,)), + ) + + c1 = cutlass.Int32(1) + c0 = cutlass.Int64(0) + c1_64 = 1 + if cutlass.const_expr(b_major_mode == OperandMajorMode.K): + stride_n = b_stride_size + stride_k = c1_64 + else: + stride_n = c1_64 + stride_k = b_stride_size + + b_ptr_val = b_ptr_tensor[expert_idx] + b_ptr = cute.make_ptr(self.b_dtype, b_ptr_val, AddressSpace.gmem) + b_expert = cute.make_tensor( + b_ptr, + cute.make_layout((n, k, c1), stride=(stride_n, stride_k, c0)), + ) + tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B( + b_tma_op_arg, + b_expert, + b_smem_layout_arg, + self.mma_tiler, + tiled_mma_arg, + cluster_layout_vmnk_shape_arg, + ) + + workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"]) + store_tma_desc(tma_atom_b, workspace.get_ptr("b", expert_idx)) + + sfb_ptr_val = sfb_ptr_tensor[expert_idx] + sfb_ptr = cute.make_ptr(self.sf_dtype, sfb_ptr_val, AddressSpace.gmem) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size) + sfb_expert = cute.make_tensor(sfb_ptr, sfb_layout) + tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B( + sfb_tma_op_arg, + sfb_expert, + sfb_smem_layout_arg, + self.mma_tiler_sfb, + tiled_mma_sfb_arg, + cluster_layout_sfb_vmnk_shape_arg, + internal_type=cutlass.Uint64, + ) + store_tma_desc(tma_atom_sfb, workspace.get_ptr("sfb", expert_idx)) + + if cutlass.const_expr(self.use_dynamic_sched): + if expert_idx == cutlass.Int32(0): + sched_counter = cute.make_tensor( + self._get_sched_counter_ptr(workspace_ptr), + cute.make_layout(1), + ) + sched_counter[0] = cutlass.Int32(0) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b, # Dense: cute.Tensor (N,K,L) | Discrete: cute.Pointer to int64[] + sfa: cute.Tensor, + sfb, # Dense: cute.Tensor | Discrete: cute.Pointer to int64[] + n: Int32, # Ignored for dense mode + k: Int32, # Ignored for dense mode + b_stride_size: cutlass.Int64, # Ignored for dense mode + b_major_mode: cutlass.Constexpr, # Ignored for dense mode + workspace_ptr, + c: cute.Tensor, + d: cute.Tensor, + amax_tensor: Optional[cute.Tensor], + padded_offsets: cute.Tensor, + alpha: cute.Tensor, + prob: cute.Tensor, + hadamard_tensor: Optional[cute.Tensor], + bias: Optional[cute.Tensor], + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + linear_offset: cutlass.Float32 = 0.0, + ): + """Execute the MoE GEMM + GLU + Hadamard kernel. + + Dense mode: ``b`` and ``sfb`` are 3-D cute.Tensor (N, K, L). + Discrete mode: ``b`` and ``sfb`` are cute.Pointer to device int64[] + arrays of per-expert base addresses. 
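+
+        For contiguous K-major per-expert weights of shape ``(N, K)`` this
+        means ``b_stride_size == K`` and ``b_major_mode == OperandMajorMode.K``
+        (an illustrative reading of the layout construction below, which builds
+        a ``(n, k, 1)`` layout with strides ``(b_stride_size, 1, 0)``).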
+ """ + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = a.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.d_dtype: Type[cutlass.Numeric] = d.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type + self.bias_dtype = bias.element_type if cutlass.const_expr(self.enable_bias) else cutlass.BFloat16 + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + self.d_layout = utils.LayoutEnum.from_tensor(d) + + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + else: + self.b_major_mode = b_major_mode + + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + self._setup_attributes() + + # ---- B / SFB setup (mode-dependent) ---- + b_from_call_arg = b + sfb_from_call_arg = sfb + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) + sfb = cute.make_tensor(sfb.iterator, sfb_layout) + else: + c1 = cutlass.Int32(1) + c0 = cutlass.Int64(0) + c1_64 = 1 + if cutlass.const_expr(b_major_mode == OperandMajorMode.K): + b_template_stride = (b_stride_size, c1_64, c0) + else: + b_template_stride = (c1_64, b_stride_size, c0) + b_template_layout = cute.make_layout((n, k, c1), stride=b_template_stride) + b_ptr_typed = cute.make_ptr(self.b_dtype, b.toint(), AddressSpace.gmem, assumed_align=16) + b = cute.make_tensor(b_ptr_typed, b_template_layout) + + sfb_ptr_typed = cute.make_ptr(self.sf_dtype, sfb.toint(), AddressSpace.gmem, assumed_align=16) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size) + sfb = cute.make_tensor(sfb_ptr_typed, sfb_layout) + + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, self.sf_vec_size) + sfa = cute.make_tensor(sfa.iterator, sfa_layout) + + self.generate_amax = amax_tensor is not None + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # TMA load A + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA load B + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA load SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa, + sfa_smem_layout, + 
self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # TMA load SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Uint64, + ) + + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + new_shape = ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ) + x_times_3 = 3 * x + new_stride = ( + (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ) + tma_tensor_sfb = cute.make_tensor( + tma_tensor_sfb.iterator, + cute.make_layout(new_shape, stride=new_stride), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size + + # TMA store C + c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + c_smem_layout, + self.epi_tile_c, + ) + + # TMA store D + d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0)) + tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d, + d_smem_layout, + self.epi_tile, + ) + + # ---- Helper kernel (discrete TMA desc init + dynamic sched counter reset) ---- + _need_helper = cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE or self.use_dynamic_sched) + if cutlass.const_expr(_need_helper): + _helper_grid_x = self.expert_cnt if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else 1 + _helper_args = ( + b_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem), + sfb_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem), + n if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0), + k if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0), + b_stride_size if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int64(0), + b_major_mode if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else self.b_major_mode, + workspace_ptr, + tiled_mma, + tiled_mma_sfb, + b_smem_layout, + sfb_smem_layout, + self.cluster_layout_vmnk.shape, + self.cluster_layout_sfb_vmnk.shape, + ) + self.helper_kernel(*_helper_args).launch( + grid=(_helper_grid_x, 1, 1), + block=(1, 1, 1), + stream=stream, + min_blocks_per_mp=1, + ) + + # ---- Grid computation via MoE scheduler ---- + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + b_n, b_k, b_l = cute.shape(b) + sched_expert_shape = (self.expert_cnt, b_n, b_k) + else: + sched_expert_shape = (self.expert_cnt, n, k) + + sched_params = MoESchedulerParams( + scenario="2Dx3D", + 
expert_shape=sched_expert_shape, + cta_tile_shape_mnk=self.cta_tile_shape_mnk, + cluster_shape_mn=self.cluster_shape_mn, + use_dynamic_sched=self.use_dynamic_sched, + ) + self.sched_params, grid = compute_grid( + sched_params, + max_active_clusters, + self.use_2cta_instrs, + ) + + self.buffer_align_bytes = 1024 + + # ---- Shared storage ---- + SchedulerStorage = MoEPersistentTileScheduler.make_storage_struct(self.num_tile_stage, self.use_dynamic_sched) + + @cute.struct + class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + hadamard_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] + hadamard_prerequisite_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] + pingpong_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_pingpong_stage * 2] + if cutlass.const_expr(self.enable_bias): + bias_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_bias_stage * 2] + scheduler: SchedulerStorage + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + sC: cute.struct.Align[ + cute.struct.MemRange[self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sD: cute.struct.Align[ + cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + if cutlass.const_expr(self.enable_bias): + sBias: cute.struct.Align[ + cute.struct.MemRange[self.bias_dtype, cute.cosize(self.bias_smem_layout_staged)], + 16, + ] + sHadamard: cute.struct.Align[cute.struct.MemRange[cutlass.BFloat16, HADAMARD_SIZE * HADAMARD_SIZE], 16] + sBH: cute.struct.Align[cute.struct.MemRange[cutlass.BFloat16, HADAMARD_SIZE * HADAMARD_SIZE], 16] + sAmax: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps], + 1, + ] + + self.shared_storage = SharedStorage + + # Launch main kernel + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + tma_atom_d, + tma_tensor_d, + amax_tensor, + padded_offsets, + alpha, + bias, + prob, + hadamard_tensor, + workspace_ptr, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.d_smem_layout_staged, + self.bias_smem_layout_staged, + self.epi_tile, + self.sched_params, + epilogue_op, + linear_offset, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + max_number_threads=[self.threads_per_cta, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ 
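+    # Reference semantics of the fused activations further below (a plain
+    # PyTorch sketch for readability; these names are illustrative and not
+    # part of the API):
+    #
+    #     def swiglu_ref(up, gate, prob):
+    #         return up * (gate * torch.sigmoid(gate)) * prob
+    #
+    #     def geglu_ref(up, gate, prob, linear_offset=1.0):
+    #         # GELU approximated here as x * sigmoid(1.702 * x)
+    #         return (up + linear_offset) * (gate * torch.sigmoid(1.702 * gate)) * prob
+    # ------------------------------------------------------------------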
+ + @cute.jit + def _make_extension(self, workspace_ptr): + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE): + desc_workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"]) + return DiscreteWeightScaledGemmSchedExtension( + tensormap_ctor=desc_workspace, + sf_vec_size=self.sf_vec_size, + ) + else: + return ContiguousAndConsistentGroupedGemmSchedExtension( + sf_vec_size=self.sf_vec_size, + ) + + def mainloop_s2t_copy_and_partition(self, sSF, tSF): + tCsSF_compact = cute.filter_zeros(sSF) + tCtSF_compact = cute.filter_zeros(tSF) + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + @cute.jit + def amax_reduction_per_thread(self, vec_fp32, amax_fp32): + vec_fp32_ssa = vec_fp32.load() + import cutlass._mlir.dialects.math as _math + + abs_acc_values_ir = _math.absf(vec_fp32_ssa.ir_value()) + abs_acc_values = type(vec_fp32_ssa)(abs_acc_values_ir, vec_fp32_ssa.shape, vec_fp32_ssa.dtype) + subtile_amax = abs_acc_values.reduce(cute.ReductionOp.MAX, cutlass.Float32(0.0), 0) + return cute.arch.fmax(amax_fp32, subtile_amax) + + @cute.jit + def amax_reduction_per_warp_and_cta(self, amax_fp32, warp_idx, amax_smem, amax_gmem): + warp_amax = warp_redux_sync(value=amax_fp32, kind=ReduxKind.MAX, mask_and_clamp=0xFFFFFFFF, nan=True) + if cute.arch.lane_idx() == 0: + amax_smem[warp_idx & 0x3] = cutlass.Float32(warp_amax) + self.epilog_sync_barrier_group1.arrive_and_wait() + if warp_idx == self.epilog_rht_store_warp_id[0] and cute.arch.lane_idx() == 0: + block_amax = cutlass.Float32(0.0) + for i in cutlass.range(self.num_epilog_warps): + block_amax = cute.arch.fmax(block_amax, amax_smem[i]) + _ = atomic_max_float32(ptr=amax_gmem, value=block_amax) + + @cute.jit + def store_c( + self, + tiled_copy_r2s, + tma_atom_c, + warp_idx, + tTR_rAcc, + tTR_rAcc_up, + tRS_rC, + tRS_sC, + bSG_gC, + bSG_sC, + c_pipeline, + prev_subtile_idx, + real_subtile_idx, + ): + c_buffer = prev_subtile_idx % self.num_c_stage + tRS_rC.store(tTR_rAcc.load().to(self.c_dtype)) + cute.copy(tiled_copy_r2s, tRS_rC[(None, None, 0)], tRS_sC[(None, None, 0, c_buffer)]) + tRS_rC.store(tTR_rAcc_up.load().to(self.c_dtype)) + cute.copy(tiled_copy_r2s, tRS_rC[(None, None, 0)], tRS_sC[(None, None, 1, c_buffer)]) + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier_group0.arrive_and_wait() + if warp_idx == self.epilog_act_warp_id[0]: + cute.copy(tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, real_subtile_idx)]) + c_pipeline.producer_commit() + if not cutlass.const_expr(self.delay_tma_store_acquire_sync): + c_pipeline.producer_acquire() + if not cutlass.const_expr(self.delay_tma_store_acquire_sync): + self.epilog_sync_barrier_group0.arrive_and_wait() + + @cute.jit + def geglu_act(self, tCompute, acc_vec_up, acc_vec_gate, mProb, linear_offset=1.0): + if cutlass.const_expr(self.vectorized_f32): + LOG2_E = cutlass.Float32(1.4426950408889634) + for i in cutlass.range_constexpr(0, cute.size(tCompute), 2): + scaled_gate_0, scaled_gate_1 = cute.arch.mul_packed_f32x2( + (acc_vec_gate[i], acc_vec_gate[i + 1]), + (1.702, 1.702), + rnd="rn", + ftz=False, + ) + tCompute_log2e = 
cute.arch.mul_packed_f32x2( + (scaled_gate_0, scaled_gate_1), + (-LOG2_E, -LOG2_E), + rnd="rn", + ftz=False, + ) + tCompute[i], tCompute[i + 1] = cute.arch.add_packed_f32x2( + (cute.math.exp2(tCompute_log2e[0], fastmath=True), cute.math.exp2(tCompute_log2e[1], fastmath=True)), + (1.0, 1.0), + ) + tCompute[i] = cute.arch.rcp_approx(tCompute[i]) + tCompute[i + 1] = cute.arch.rcp_approx(tCompute[i + 1]) + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (acc_vec_gate[i], acc_vec_gate[i + 1]), + rnd="rn", + ftz=False, + ) + up0, up1 = cute.arch.add_packed_f32x2( + (linear_offset, linear_offset), + (acc_vec_up[i], acc_vec_up[i + 1]), + rnd="rn", + ftz=False, + ) + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (up0, up1), + rnd="rn", + ftz=False, + ) + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (mProb, mProb), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tCompute)): + tCompute[i] = (acc_vec_up[i] + linear_offset) * silu_f32_geglu_scaled(acc_vec_gate[i], fastmath=True) + tCompute[i] = tCompute[i] * mProb + + @cute.jit + def swiglu_act(self, tCompute, acc_vec_up, acc_vec_gate, mProb): + if cutlass.const_expr(self.vectorized_f32): + LOG2_E = cutlass.Float32(1.4426950408889634) + for i in cutlass.range_constexpr(0, cute.size(tCompute), 2): + tCompute_log2e = cute.arch.mul_packed_f32x2( + (acc_vec_gate[i], acc_vec_gate[i + 1]), + (-LOG2_E, -LOG2_E), + rnd="rn", + ftz=False, + ) + tCompute[i], tCompute[i + 1] = cute.arch.add_packed_f32x2( + (cute.math.exp2(tCompute_log2e[0], fastmath=True), cute.math.exp2(tCompute_log2e[1], fastmath=True)), + (1.0, 1.0), + ) + tCompute[i] = cute.arch.rcp_approx(tCompute[i]) + tCompute[i + 1] = cute.arch.rcp_approx(tCompute[i + 1]) + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (acc_vec_gate[i], acc_vec_gate[i + 1]), + rnd="rn", + ftz=False, + ) + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (acc_vec_up[i], acc_vec_up[i + 1]), + rnd="rn", + ftz=False, + ) + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (mProb, mProb), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tCompute)): + tCompute[i] = acc_vec_up[i] * silu_f32(acc_vec_gate[i], fastmath=True) + tCompute[i] = tCompute[i] * mProb + + @cute.jit + def query_hadamard_tmem_a_ptr(self, tile_idx, reverse_subtile, tmem_ptr): + hadamard_tmem_offset = 256 + (HADAMARD_SIZE + self.epi_tile[1]) * (tile_idx // 2) + if reverse_subtile: + hadamard_tmem_offset = 256 - self.num_sf_tmem_cols - HADAMARD_SIZE - (HADAMARD_SIZE + self.epi_tile[1]) * (tile_idx // 2) + return cute.recast_ptr(tmem_ptr + hadamard_tmem_offset, dtype=self.acc_dtype) + + @cute.jit + def query_hadamard_tmem_acc_ptr(self, tile_idx, reverse_subtile, tmem_ptr): + hadamard_tmem_offset = 256 + (HADAMARD_SIZE + self.epi_tile[1]) * (tile_idx // 2) + HADAMARD_SIZE + if reverse_subtile: + hadamard_tmem_offset = 256 - self.num_sf_tmem_cols - HADAMARD_SIZE - (HADAMARD_SIZE + self.epi_tile[1]) * (tile_idx // 2) - self.epi_tile[1] + return cute.recast_ptr(tmem_ptr + hadamard_tmem_offset, dtype=self.acc_dtype) + + def epilog_tmem_copy_and_partition(self, tidx, tAcc, gD_mnl, epi_tile, use_2cta_instrs): + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.d_layout, + self.d_dtype, + 
self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + tTR_gC = thr_copy_t2r.partition_D(gD_mnl_epi) + tTR_rAcc_gate = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) + tTR_rAcc_up = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc_gate, tTR_rAcc_up + + def epilog_smem_copy_and_partition(self, tiled_copy_t2r, tTR_rC, tidx, sD): + copy_atom_r2s = sm100_utils.get_smem_store_op(self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sD) + tRS_rD = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rD, tRS_sD + + def epilog_gmem_copy_and_partition(self, tidx, atom, gD_mnl, epi_tile, sD): + gD_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + tma_atom_d = atom + sD_for_tma_partition = cute.group_modes(sD, 0, 2) + gD_for_tma_partition = cute.group_modes(gD_epi, 0, 2) + bSG_sD, bSG_gD = cpasync.tma_partition( + tma_atom_d, + 0, + cute.make_layout(1), + sD_for_tma_partition, + gD_for_tma_partition, + ) + return tma_atom_d, bSG_sD, bSG_gD + + @staticmethod + def _compute_stages( + tiled_mma, + mma_tiler_mnk, + a_dtype, + b_dtype, + epi_tile, + epi_tile_c, + c_dtype, + c_layout, + d_dtype, + d_layout, + sf_dtype, + sf_vec_size, + num_smem_capacity, + occupancy, + bias_dtype, + ): + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + num_c_stage = 1 + num_d_stage = 1 + num_tile_stage = 2 + num_pingpong_stage = mma_tiler_mnk[1] // epi_tile_c[1] + + a_smem_layout_one = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler_mnk, a_dtype, 1) + b_smem_layout_one = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, b_dtype, 1) + sfa_smem_layout_one = blockscaled_utils.make_smem_layout_sfa(tiled_mma, mma_tiler_mnk, sf_vec_size, 1) + sfb_smem_layout_one = blockscaled_utils.make_smem_layout_sfb(tiled_mma, mma_tiler_mnk, sf_vec_size, 1) + c_smem_layout_one = sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile_c, 1) + d_smem_layout_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_one) + ) + mbar_helpers_bytes = 1024 + # sInfo is in SchedulerStorage, not here, so use 4-int sInfo + sinfo_bytes = 4 * 4 * num_tile_stage + c_bytes = cute.size_in_bytes(c_dtype, c_smem_layout_one) * num_c_stage + d_bytes = cute.size_in_bytes(d_dtype, d_smem_layout_one) * num_d_stage + hadamard_bytes = (cutlass.BFloat16.width // 8) * HADAMARD_SIZE * HADAMARD_SIZE if d_dtype == cutlass.BFloat16 else 0 + # sAmax for 4 RHT store warps + amax_bytes = 4 * 4 # 4 Float32 values + + if bias_dtype is not None: + num_bias_stage = 2 + bias_bytes = mma_tiler_mnk[1] * num_bias_stage * (bias_dtype.width // 8) + else: + num_bias_stage = 0 + bias_bytes = 0 + + epi_bytes = c_bytes + d_bytes + hadamard_bytes + amax_bytes + bias_bytes 
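+
+        # AB stage count below: whatever smem remains per occupancy slice after
+        # the fixed costs (mbar helpers, epilogue C/D/Hadamard/amax/bias
+        # buffers, scheduler sInfo) is divided by the per-stage A+B+SFA+SFB
+        # footprint — every remaining byte buys mainloop pipelining depth.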
+ + num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes + sinfo_bytes)) // ab_bytes_per_stage + + return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage, num_bias_stage, num_pingpong_stage + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tma_atom_d: cute.CopyAtom, + mD_mnl: cute.Tensor, + mAmax_tensor: Optional[cute.Tensor], + padded_offsets: cute.Tensor, + alpha: cute.Tensor, + mBias_nl: Optional[cute.Tensor], + prob: cute.Tensor, + hadamard_tensor: Optional[cute.Tensor], + workspace_ptr, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + bias_smem_layout_staged: cute.Layout, + epi_tile: cute.Tile, + sched_params: MoESchedulerParams, + epilogue_op: cutlass.Constexpr, + linear_offset: cutlass.Float32 = 0.0, + ): + """GPU device kernel: MoE persistent GEMM + GLU + Hadamard (pingpong epilogue).""" + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + cta_rank = cute.arch.block_in_cluster_idx()[0] + + total_token = padded_offsets[self.expert_cnt - 1] + + # Prefetch TMA descriptors + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_sfa) + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + cpasync.prefetch_descriptor(tma_atom_d) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # CTA coordinates + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster) + tidx, _, _ = cute.arch.thread_idx() + + # Shared memory allocation + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sched_storage = storage.scheduler + + # AB pipeline + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # ACC pipeline + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_act_warp_id) * 
(2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Hadamard SMEM tensor + hadamard_b_layout_staged = cute.make_layout((16, 8), stride=(1, 16)) + sHadamard = storage.sHadamard.get_tensor(hadamard_b_layout_staged) + hadamard_h_iter = hadamard_tensor.iterator + (cutlass.Int32(128) if cta_rank == 1 else cutlass.Int32(0)) + hadamard_tensor_local = cute.make_tensor( + cute.make_ptr( + cutlass.BFloat16, + hadamard_h_iter.toint(), + hadamard_h_iter.memspace, + assumed_align=16, + ), + hadamard_b_layout_staged, + ) + + # Hadamard UMMA pipeline (1 stage) + hadamard_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_hadamard_consumer_threads = 128 * (2 if use_2cta_instrs else 1) + hadamard_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_hadamard_consumer_threads) + hadamard_producer, hadamard_consumer = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.hadamard_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=hadamard_pipeline_producer_group, + consumer_group=hadamard_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ).make_participants() + + # Hadamard prerequisite pipeline (cross-CTA sync, 1 stage) + num_hadamard_prerequisite_threads = 128 + hadamard_prerequisite_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_hadamard_prerequisite_threads) + hadamard_prerequisite_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_hadamard_prerequisite_threads) + peer_rank = cutlass.Int32(1) - cta_rank + hadamard_prerequisite_producer, hadamard_prerequisite_consumer = pipeline.PipelineAsync.create( + barrier_storage=storage.hadamard_prerequisite_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=hadamard_prerequisite_pipeline_producer_group, + consumer_group=hadamard_prerequisite_pipeline_consumer_group, + producer_mask=peer_rank, + defer_sync=True, + ).make_participants() + + # Pingpong pipeline + pingpong_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilog_act_warp_id) * self.threads_per_warp) + pingpong_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilog_rht_store_warp_id) * self.threads_per_warp) + pingpong_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.pingpong_mbar_ptr.data_ptr(), + num_stages=self.num_pingpong_stage, + producer_group=pingpong_producer_group, + consumer_group=pingpong_consumer_group, + ) + + # Tile info pipeline (uses SchedulerStorage's barrier) + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_wo_sched, + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=sched_storage.tile_info_mbar.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + ) + + # MoE persistent tile scheduler + scheduler = MoEPersistentTileScheduler.create( + sched_params, + padded_offsets, + cute.arch.block_idx(), + cute.arch.grid_dim(), + 
counter_ptr=self._get_sched_counter_ptr(workspace_ptr), + sched_storage=sched_storage, + ) + scheduler.internal_init() + + # Bias pipeline + if cutlass.const_expr(self.enable_bias): + bias_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp, + ) + bias_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_act_warp_id), + ) + bias_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.bias_mbar_ptr.data_ptr(), + num_stages=self.num_bias_stage, + producer_group=bias_pipeline_producer_group, + consumer_group=bias_pipeline_consumer_group, + ) + sBias = storage.sBias.get_tensor(bias_smem_layout_staged) + gBias_nl = cute.local_tile(mBias_nl, cute.slice_(self.mma_tiler[:2], (0, None)), (None, None)) + + # TMEM allocator + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_act_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # SMEM tensors + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) + sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + amax_layout = cute.make_layout((self.num_epilog_warps,)) + sAmax = storage.sAmax.get_tensor(amax_layout) + + # sInfo from SchedulerStorage + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = sched_storage.sInfo.get_tensor(info_layout) + + # Multicast masks + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1) + + # MMA fragments + tCrA = tiled_mma.make_fragment_A(sA) + tCrB = tiled_mma.make_fragment_B(sB) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + if cutlass.const_expr(self.overlapping_accum): + num_acc_stage_overlapped = 2 + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage_overlapped)) + tCtAcc_fake = cute.make_tensor( + tCtAcc_fake.iterator, + cute.make_layout( + tCtAcc_fake.shape, + stride=( + tCtAcc_fake.stride[0], + tCtAcc_fake.stride[1], + tCtAcc_fake.stride[2], + (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1], + ), + ), + ) + else: + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # Cluster wait / CTA sync + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + if 
total_token <= 0: + cute.arch.nvvm.exit() + + # --------------------------------------------------------------- + # Specialized Scheduler warp (MoEPersistentTileScheduler) + # --------------------------------------------------------------- + if warp_idx == self.sched_warp_id: + work_tile_info = scheduler.initial_work_tile_info() + tile_info_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_tile_stage) + + while work_tile_info.is_valid_tile: + tile_info_pipeline.producer_acquire(tile_info_producer_state) + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = work_tile_info.expert_idx + sInfo[(1, tile_info_producer_state.index)] = work_tile_info.tile_m_idx + sInfo[(2, tile_info_producer_state.index)] = work_tile_info.tile_n_idx + sInfo[(3, tile_info_producer_state.index)] = work_tile_info.k_tile_cnt + cute.arch.fence_proxy("async.shared", space="cta") + self.sched_sync_barrier.arrive_and_wait() + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + work_tile_info = scheduler.advance_to_next_work() + + # Send invalid signal: expert_idx = -1 + tile_info_pipeline.producer_acquire(tile_info_producer_state) + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cutlass.Int32(-1) + sInfo[(1, tile_info_producer_state.index)] = cutlass.Int32(0) + sInfo[(2, tile_info_producer_state.index)] = cutlass.Int32(0) + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + self.sched_sync_barrier.arrive_and_wait() + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # --------------------------------------------------------------- + # Specialized TMA load warp + # --------------------------------------------------------------- + if warp_idx == self.tma_warp_id: + ext = self._make_extension(workspace_ptr) + + ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + work_tile_info = MoEWorkTileInfo( + expert_idx=tile_info[0], + tile_m_idx=tile_info[1], + tile_n_idx=tile_info[2], + k_tile_cnt=tile_info[3], + ) + k_tile_cnt = work_tile_info.k_tile_cnt + ext.update_expert_info(padded_offsets, work_tile_info.expert_idx) + + real_a, _ = ext.get_gmem_tensor("a", mA_mkl, padded_offsets, work_tile_info) + real_b, desc_ptr_b = ext.get_gmem_tensor("b", mB_nkl, padded_offsets, work_tile_info) + real_sfa, _ = ext.get_gmem_tensor("sfa", mSFA_mkl, padded_offsets, work_tile_info) + real_sfb, desc_ptr_sfb = ext.get_gmem_tensor("sfb", mSFB_nkl, padded_offsets, work_tile_info) + + gA_mkl = cute.local_tile(real_a, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + gB_nkl = cute.local_tile(real_b, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) + gSFA_mkl = 
cute.local_tile(real_sfa, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + gSFB_nkl = cute.local_tile(real_sfb, cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None)) + + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + tCgA = thr_mma.partition_A(gA_mkl) + tCgB = thr_mma.partition_B(gB_nkl) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + sfa_cta_layout = a_cta_layout + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + mma_tile_coord_m = work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape) + mma_tile_coord_n = work_tile_info.tile_n_idx + tAgA_slice = tAgA[(None, mma_tile_coord_m, None, 0)] + tBgB_slice = tBgB[(None, mma_tile_coord_n, None, 0)] + tAgSFA_slice = tAgSFA[(None, mma_tile_coord_m, None, 0)] + slice_n = mma_tile_coord_n + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_n // 2 + tBgSFB_slice = tBgSFB[(None, slice_n, None, 0)] + + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)] + tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + tAsSFA_pipe = tAsSFA[(None, ab_producer_state.index)] + tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)] + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + ab_producer_state_next = ab_producer_state.clone() + ab_producer_state_next.advance() + if ab_producer_state_next.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state_next) + else: + peek_ab_empty_status = cutlass.Boolean(1) + + cute.copy(tma_atom_a, tAgA_k, tAsA_pipe, tma_bar_ptr=tma_bar, mcast_mask=a_full_mcast_mask) + cute.copy(tma_atom_b, tBgB_k, tBsB_pipe, tma_bar_ptr=tma_bar, mcast_mask=b_full_mcast_mask, tma_desc_ptr=desc_ptr_b) + cute.copy(tma_atom_sfa, tAgSFA_k, tAsSFA_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfa_full_mcast_mask) + 
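# ----------------------------------------------------------------------
# [Editor's note -- illustrative sketch, not part of this diff]
# The four TMA copies issued here (A, B, SFA above, and the SFB copy that
# follows) all signal the same per-stage mbarrier obtained from
# ab_pipeline.producer_get_barrier(), while the producer has already
# "peeked" the next stage via producer_try_acquire() so the barrier wait
# for stage k+1 overlaps with issuing stage k's copies. A toy
# single-threaded model of that bounded-buffer pattern (plain Python,
# not the CuTe DSL API):
from collections import deque

class ToyPipeline:
    def __init__(self, num_stages):
        self.free = deque(range(num_stages))   # stages a producer may fill
        self.full = deque()                    # stages a consumer may read

    def producer_try_acquire(self):
        return len(self.free) > 0              # non-blocking "peek"

    def producer_acquire(self, peeked=False):
        if not peeked:
            while not self.free:               # real kernel: mbarrier wait
                pass
        return self.free.popleft()

    def producer_commit(self, stage):
        self.full.append(stage)                # real kernel: TMA arrival on mbarrier
        self.free.append(self.full.popleft())  # toy: infinitely fast consumer

def toy_producer_loop(pipe, k_tile_cnt, issue_copies):
    peeked = pipe.producer_try_acquire()
    for k_tile in range(k_tile_cnt):
        stage = pipe.producer_acquire(peeked)
        # Peek the next stage *before* issuing this tile's copies,
        # mirroring producer_try_acquire(ab_producer_state_next) above.
        peeked = pipe.producer_try_acquire() if k_tile + 1 < k_tile_cnt else True
        issue_copies(k_tile, stage)            # stands in for the cute.copy calls
        pipe.producer_commit(stage)

# e.g. toy_producer_loop(ToyPipeline(3), 8, lambda k, s: print(k, s))
# ----------------------------------------------------------------------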
cute.copy(tma_atom_sfb, tBgSFB_k, tBsSFB_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfb_full_mcast_mask, tma_desc_ptr=desc_ptr_sfb) + + ab_producer_state.advance() + + # Advance to next tile + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + ab_pipeline.producer_tail(ab_producer_state) + + # --------------------------------------------------------------- + # Specialized MMA warp + # --------------------------------------------------------------- + if warp_idx == self.mma_warp_id: + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + sfa_tmem_ptr = cute.recast_ptr(acc_tmem_ptr + self.num_accumulator_tmem_cols, dtype=self.sf_dtype) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols, + dtype=self.sf_dtype, + ) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) + acd_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + k_tile_cnt = tile_info[3] + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + acd_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state) + + mma_tile_coord_mnl = ( + tile_info[1] // cute.size(tiled_mma.thr_id.shape), + tile_info[2], + tile_info[0], + ) + + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acd_producer_state.phase ^ 1 + else: + acc_stage_index = acd_producer_state.index + + tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)] + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 
192): + offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + + if is_leader_cta: + acc_pipeline.producer_acquire(acd_producer_state, peek_acc_empty_status) + + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + if is_leader_cta: + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + s2t_stage_coord = (None, None, None, None, ab_consumer_state.index) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t_staged, tCtSFA_compact_s2t) + cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t) + + num_kblocks = cute.size(tCrA, mode=[2]) + ab_consumer_state_next = ab_consumer_state.clone() + ab_consumer_state_next.advance() + if ab_consumer_state_next.count < k_tile_cnt: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state_next) + + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator) + tiled_mma.set(tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator) + cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + ab_pipeline.consumer_release(ab_consumer_state) + ab_consumer_state = ab_consumer_state_next + + if is_leader_cta: + acc_pipeline.producer_commit(acd_producer_state) + + acd_producer_state.advance() + if acd_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state) + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + acc_pipeline.producer_tail(acd_producer_state) + + # --------------------------------------------------------------- + # Specialized bias load warp + # --------------------------------------------------------------- + if cutlass.const_expr(self.enable_bias): + if warp_idx == self.bias_load_warp_id and total_token > 0: + bias_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_bias_stage) + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + bias_elems_per_thread = 128 // self.bias_dtype.width + bias_g2s_atom = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + self.bias_dtype, + num_bits_per_copy=128, + ) + bias_g2s_tiled = cute.make_tiled_copy_tv( + bias_g2s_atom, + cute.make_layout((self.threads_per_warp,)), + 
cute.make_layout((bias_elems_per_thread,)), + ) + thr_bias_g2s = bias_g2s_tiled.get_slice(cute.arch.lane_idx()) + tBs_sBias = thr_bias_g2s.partition_D(sBias) + + bias_n_total = mBias_nl.shape[0] + tBpBias = cute.make_rmem_tensor(cute.make_layout((1,)), cutlass.Boolean) + + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + bias_producer_state.reset_count() + mma_n_coord = tile_info[2] + expert_idx = tile_info[0] + gBias_tile = gBias_nl[(None, mma_n_coord, expert_idx)] + tBs_gBias = thr_bias_g2s.partition_S(gBias_tile) + tBpBias[0] = mma_n_coord * self.mma_tiler[1] + cute.arch.lane_idx() * bias_elems_per_thread < bias_n_total + bias_pipeline.producer_acquire(bias_producer_state) + cute.copy( + bias_g2s_tiled, + tBs_gBias[(None, 0)], + tBs_sBias[(None, 0, bias_producer_state.index)], + pred=tBpBias, + ) + bias_pipeline.producer_commit(bias_producer_state) + bias_producer_state.advance() + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + bias_pipeline.producer_tail(bias_producer_state) + + # --------------------------------------------------------------- + # Specialized ACT epilogue warps (0-3): TMEM→regs, alpha, GLU activation, + # C store, hadamard_in + # --------------------------------------------------------------- + if warp_idx < self.epilog_rht_store_warp_id[0] and total_token > 0: + epi_tidx = tidx + + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue (shape-only via mD_mnl for invariant setup) + # + thr_mma_epi = tiled_mma.get_slice(mma_tile_coord_v) + gD_mnl_shape = cute.local_tile(mD_mnl, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + tCgD_shape = thr_mma_epi.partition_C(gD_mnl_shape) + + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc_gate, + tTR_rAcc_up, + ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgD_shape, epi_tile, use_2cta_instrs) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc_gate.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rC, epi_tidx, sC) + + # + # Create per-expert extension (for C/prob tensors inside tile loop) + # + epi_ext = self._make_extension(workspace_ptr) + + # + # Persistent tile scheduling state + # + acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + + # + # Pingpong producer state + # + pingpong_act_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 
self.num_pingpong_stage) + + # Threads/warps participating in TMA store pipeline for C + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_act_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + if cutlass.const_expr(self.enable_bias): + bias_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_bias_stage) + bias_s2r_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.bias_dtype, num_bits_per_copy=128) + tTR_rBias_gate = cute.make_rmem_tensor(cute.make_layout(self.epi_tile[1]), self.bias_dtype) + tTR_rBias_up = cute.make_rmem_tensor(cute.make_layout(self.epi_tile[1]), self.bias_dtype) + + # Get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + num_prev_subtiles = cutlass.Int32(0) + while is_valid_tile: + # sInfo format: (expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) + epi_work_tile_info = MoEWorkTileInfo( + expert_idx=tile_info[0], + tile_m_idx=tile_info[1], + tile_n_idx=tile_info[2], + k_tile_cnt=tile_info[3], + ) + mma_tile_coord_mnl = ( + epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape), + epi_work_tile_info.tile_n_idx, + cutlass.Int32(0), + ) + + expert_idx = epi_work_tile_info.expert_idx + alpha_val = alpha[expert_idx] + epi_ext.update_expert_info(padded_offsets, epi_work_tile_info.expert_idx) + + if cutlass.const_expr(self.enable_bias): + bias_consumer_state.reset_count() + bias_pipeline.consumer_wait(bias_consumer_state) + sBias_stage = sBias[(None, bias_consumer_state.index)] + sBias_subtiles = cute.flat_divide(sBias_stage, cute.make_layout(2 * self.epi_tile[1])) + + # + # Get per-expert C tensor inside tile loop + # + real_c, _ = epi_ext.get_gmem_tensor("c", mC_mnl, padded_offsets, epi_work_tile_info) + gC_mnl = cute.local_tile(real_c, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)) + thr_mma_epi_loop = tiled_mma.get_slice(mma_tile_coord_v) + tCgC = thr_mma_epi_loop.partition_C(gC_mnl) + _, bSG_sC, bSG_gC_partitioned = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, self.epi_tile_c, sC) + bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Get per-expert prob tensor inside tile loop + # + real_prob, _ = epi_ext.get_gmem_tensor("prob", prob, padded_offsets, epi_work_tile_info) + mPosition = ( + (epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape)) * self.mma_tiler[0] + + mma_tile_coord_v * (self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape)) + + tidx + ) + mProb = real_prob[mPosition, 0, 0] + + # + # Get accumulator stage index + # + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_consumer_state.phase + reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False) + else: + acc_stage_index = acc_consumer_state.index + + # Set tensor memory buffer for current tile 
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(0, subtile_cnt, 2, unroll=1): + real_subtile_idx = subtile_idx // 2 + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + real_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx // 2 + + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, real_subtile_idx * 2)] + tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, real_subtile_idx * 2 + 1)] + + cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate) + cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up) + + # + # Async arrive accumulator buffer empty earlier when overlapping_accum is enabled + # + if cutlass.const_expr(self.overlapping_accum): + if subtile_idx == self.iter_acc_early_release_in_epilogue: + cute.arch.fence_view_async_tmem_load() + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Notify pingpong consumer for subtile > 0 + # + if subtile_idx != 0: + pingpong_pipeline.producer_commit(pingpong_act_producer_state) + pingpong_act_producer_state.advance() + + # + # Apply alpha (+ bias when enabled) + # + if cutlass.const_expr(self.enable_bias): + sBias_pair = sBias_subtiles[(None, real_subtile_idx)] + sBias_sub = cute.flat_divide(sBias_pair, cute.make_layout(self.epi_tile[1])) + cute.copy(bias_s2r_atom, sBias_sub[(None, 0)], tTR_rBias_gate) + bias_vec_gate = tTR_rBias_gate.load() + cute.copy(bias_s2r_atom, sBias_sub[(None, 1)], tTR_rBias_up) + bias_vec_up = tTR_rBias_up.load() + + if cutlass.const_expr(self.vectorized_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_gate), 2): + bias_gate_f32_0 = bias_vec_gate[i].to(cutlass.Float32) + bias_gate_f32_1 = bias_vec_gate[i + 1].to(cutlass.Float32) + bias_up_f32_0 = bias_vec_up[i].to(cutlass.Float32) + bias_up_f32_1 = bias_vec_up[i + 1].to(cutlass.Float32) + tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1] = cute.arch.fma_packed_f32x2( + (tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1]), + ( + cutlass.Float32(alpha_val), + cutlass.Float32(alpha_val), + ), + (bias_gate_f32_0, bias_gate_f32_1), + rnd="rn", + ftz=False, + ) + tTR_rAcc_up[i], tTR_rAcc_up[i + 1] = cute.arch.fma_packed_f32x2( + (tTR_rAcc_up[i], tTR_rAcc_up[i + 1]), + ( + cutlass.Float32(alpha_val), + cutlass.Float32(alpha_val), + ), + (bias_up_f32_0, bias_up_f32_1), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc_gate)): + tTR_rAcc_gate[i] = tTR_rAcc_gate[i] * cutlass.Float32(alpha_val) + bias_vec_gate[i].to(cutlass.Float32) + tTR_rAcc_up[i] = tTR_rAcc_up[i] * cutlass.Float32(alpha_val) + bias_vec_up[i].to(cutlass.Float32) + + if subtile_idx == subtile_cnt - 2: + bias_pipeline.consumer_release(bias_consumer_state) + bias_consumer_state.advance() + else: + if cutlass.const_expr(self.vectorized_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_gate), 2): + tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1]), + ( + cutlass.Float32(alpha_val), + cutlass.Float32(alpha_val), + ), + rnd="rn", + ftz=False, + ) + 
tTR_rAcc_up[i], tTR_rAcc_up[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rAcc_up[i], tTR_rAcc_up[i + 1]), + ( + cutlass.Float32(alpha_val), + cutlass.Float32(alpha_val), + ), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc_gate)): + tTR_rAcc_gate[i] = tTR_rAcc_gate[i] * cutlass.Float32(alpha_val) + tTR_rAcc_up[i] = tTR_rAcc_up[i] * cutlass.Float32(alpha_val) + + # + # Store gate+up to C tensor (pre-activation for residual) + # + self.store_c( + tiled_copy_r2s, + tma_atom_c, + warp_idx, + tTR_rAcc_gate, + tTR_rAcc_up, + tRS_rC, + tRS_sC, + bSG_gC, + bSG_sC, + c_pipeline, + num_prev_subtiles, + real_subtile_idx, + ) + num_prev_subtiles = num_prev_subtiles + 1 + + # + # GeGLU clamp before activation (C was stored pre-clamp above) + # + if cutlass.const_expr(self.act_func == "geglu"): + geglu_max_val = cutlass.Float32(7.0) + geglu_min_val = cutlass.Float32(-7.0) + for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)): + tTR_rAcc_gate[i] = fmin(tTR_rAcc_gate[i], geglu_max_val) + tTR_rAcc_up[i] = fmin(tTR_rAcc_up[i], geglu_max_val) + tTR_rAcc_up[i] = fmax(tTR_rAcc_up[i], geglu_min_val) + + acc_vec_gate = tTR_rAcc_gate.load() + acc_vec_up = tTR_rAcc_up.load() + + # + # Compute GLU activation (SwiGLU or GeGLU) + # + tCompute = cute.make_rmem_tensor(acc_vec_gate.shape, self.acc_dtype) + if cutlass.const_expr(self.act_func == "geglu"): + self.geglu_act(tCompute, acc_vec_up, acc_vec_gate, mProb, linear_offset) + elif cutlass.const_expr(self.act_func == "swiglu"): + self.swiglu_act(tCompute, acc_vec_up, acc_vec_gate, mProb) + + # + # Convert to BF16 and write to TMEM for Hadamard + # + tCompute_hadamard = cute.make_rmem_tensor(tCompute.layout, cutlass.BFloat16) + tCompute_hadamard.store(tCompute.load().to(tCompute_hadamard.element_type)) + + # + # Pingpong producer acquire for current subtile + # + pingpong_pipeline.producer_acquire(pingpong_act_producer_state) + hadamard_in( + tCompute_hadamard, + HADAMARD_SIZE, + self.query_hadamard_tmem_a_ptr(subtile_idx, reverse_subtile, tmem_ptr), + epi_tidx, + ) + + # + # Delayed TMA store acquire + group sync (always enabled) + # + if cutlass.const_expr(self.delay_tma_store_acquire_sync): + if warp_idx == self.epilog_act_warp_id[0]: + c_pipeline.producer_acquire() + self.epilog_sync_barrier_group0.arrive_and_wait() + + # + # Pingpong producer commit last subtile + # + pingpong_pipeline.producer_commit(pingpong_act_producer_state) + pingpong_act_producer_state.advance() + + # + # Full epilogue barrier (ACT + RHT must both arrive) + # + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + if cutlass.const_expr(not self.overlapping_accum): + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier_group0.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for C store / pingpong complete + # + c_pipeline.producer_tail() + pingpong_pipeline.producer_tail(pingpong_act_producer_state) + + # 
--------------------------------------------------------------- + # Specialized RHT store warps (4-7): hadamard_compute + D store + # --------------------------------------------------------------- + if warp_idx < self.mma_warp_id and warp_idx >= self.epilog_rht_store_warp_id[0] and total_token > 0: + epi_tidx = tidx % 128 + + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + # + # Hadamard setup (loads Hadamard matrix from GMEM to SMEM) + # + tiled_hmma = hadamard_setup(hadamard_tensor_local, sHadamard, epi_tidx) + + # + # Partition for epilogue (shape-only via mD_mnl) + # + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + thr_mma_epi = tiled_mma.get_slice(mma_tile_coord_v) + gD_mnl_shape = cute.local_tile(mD_mnl, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + tCgD_shape = thr_mma_epi.partition_C(gD_mnl_shape) + + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc_gate, + tTR_rAcc_up, + ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgD_shape, epi_tile, use_2cta_instrs) + + # + # Pingpong consumer state + # + pingpong_rht_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_pingpong_stage) + + # + # D register and smem partition + # + tTR_rD = cute.make_rmem_tensor(tTR_rAcc_gate.shape, self.d_dtype) + tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rD, epi_tidx, sD) + + # + # Create per-expert extension (for D tensor inside tile loop) + # + epi_ext = self._make_extension(workspace_ptr) + + # Threads/warps participating in TMA store pipeline for D + d_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_rht_store_warp_id), + ) + d_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_d_stage, + producer_group=d_producer_group, + ) + + # Get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Accumulator stage index (for overlapping_accum) + # + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = 0 + reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False) + else: + acc_stage_index = 0 + + num_prev_subtiles = cutlass.Int32(0) + while is_valid_tile: + # sInfo format: (expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) + epi_work_tile_info = MoEWorkTileInfo( + expert_idx=tile_info[0], + tile_m_idx=tile_info[1], + tile_n_idx=tile_info[2], + k_tile_cnt=tile_info[3], + ) + mma_tile_coord_mnl = ( + epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape), + epi_work_tile_info.tile_n_idx, + cutlass.Int32(0), + ) + expert_idx = epi_work_tile_info.expert_idx + epi_ext.update_expert_info(padded_offsets, epi_work_tile_info.expert_idx) + + if cutlass.const_expr(self.generate_amax): + thread_tile_amax = 
cutlass.Float32(0.0) + + # + # Get per-expert D tensor inside tile loop + # + real_d, _ = epi_ext.get_gmem_tensor("d", mD_mnl, padded_offsets, epi_work_tile_info) + gD_mnl_loop = cute.local_tile(real_d, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + thr_mma_epi_loop = tiled_mma.get_slice(mma_tile_coord_v) + tCgD_loop = thr_mma_epi_loop.partition_C(gD_mnl_loop) + _, bSG_sD, bSG_gD_partitioned = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_d, tCgD_loop, epi_tile, sD) + bSG_gD = bSG_gD_partitioned[(None, None, None, *mma_tile_coord_mnl)] + bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD)) + + # + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE) + # + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)] + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(0, subtile_cnt, 2, unroll=1): + real_subtile_idx = subtile_idx // 2 + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + real_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx // 2 + + # + # Get Hadamard TMEM pointers for this subtile + # + hadamard_tmem_a_ptr = self.query_hadamard_tmem_a_ptr(subtile_idx, reverse_subtile, tmem_ptr) + hadamard_tmem_acc_ptr = self.query_hadamard_tmem_acc_ptr(subtile_idx, reverse_subtile, tmem_ptr) + tCompute = cute.make_rmem_tensor(tTR_rAcc_gate.layout, cutlass.Float32) + + # + # Cross-CTA sync (prerequisite for Hadamard: both CTAs must have + # written hadamard_in before either starts hadamard_compute) + # + hadamard_prerequisite_empty = hadamard_prerequisite_producer.acquire_and_advance() + hadamard_prerequisite_empty.commit() + hadamard_prerequisite_empty = hadamard_prerequisite_consumer.wait_and_advance() + hadamard_prerequisite_consumer.release(hadamard_prerequisite_empty) + + # + # Wait for pingpong producer (ACT warp) ready + # + pingpong_pipeline.consumer_wait(pingpong_rht_consumer_state) + + # + # Apply Hadamard transform (reads from TMEM a_ptr, writes to TMEM acc_ptr) + # + hadamard_compute( + tiled_hmma, + hadamard_tmem_a_ptr, + hadamard_tmem_acc_ptr, + storage.sHadamard, + epi_tile, + epi_tidx, + hadamard_producer, + ) + hadamard_consumer.wait_and_advance() + + # + # Get Hadamard result from TMEM to registers + # + hadamard_out(tCompute, epi_tile[1], hadamard_tmem_acc_ptr, epi_tidx) + + # + # Release pingpong consumer slot + # + pingpong_pipeline.consumer_release(pingpong_rht_consumer_state) + pingpong_rht_consumer_state.advance() + + # + # Amax accumulation per subtile + # + if cutlass.const_expr(self.generate_amax): + thread_tile_amax = self.amax_reduction_per_thread(tCompute, thread_tile_amax) + + # + # Convert to D dtype and store D to shared memory + # + acc_vec = tiled_copy_r2s.retile(tCompute).load() + tRS_rD.store(acc_vec.to(self.d_dtype)) + + d_buffer = num_prev_subtiles % self.num_d_stage + num_prev_subtiles = num_prev_subtiles + 1 + cute.copy( + tiled_copy_r2s, + tRS_rD, + tRS_sD[(None, None, None, d_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier_group1.arrive_and_wait() + # + # TMA store D to global memory + # + if warp_idx == self.epilog_rht_store_warp_id[0]: + cute.copy( + tma_atom_d, + bSG_sD[(None, d_buffer)], + bSG_gD[(None, real_subtile_idx)], + ) + 
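# ----------------------------------------------------------------------
# [Editor's note -- illustrative sketch, not part of this diff]
# At this point the leader warp has issued the TMA store for this D
# subtile. As a functional reference for what these RHT warps produce:
# the ACT warps wrote BF16 activations into TMEM, hadamard_compute applies
# a Hadamard transform, and the result is amax-reduced (per thread, then
# per warp/CTA via sAmax and an atomic max into gAmax) and stored as D.
# The emulation below assumes a power-of-two Hadamard block size h applied
# along contiguous groups of N columns with an unnormalized Sylvester
# matrix; the kernel's actual HADAMARD_SIZE, grouping, and scaling are not
# visible in this diff.
import torch

def sylvester_hadamard(h: int) -> torch.Tensor:
    m = torch.ones(1, 1)
    while m.shape[0] < h:
        m = torch.cat([torch.cat([m, m], 1), torch.cat([m, -m], 1)], 0)
    return m

def rht_epilogue_reference(act_out: torch.Tensor, h: int = 16):
    # act_out: (tile_m, tile_n) fp32 GLU activation output for one subtile
    x = act_out.to(torch.bfloat16).float()   # ACT warps convert to BF16 first
    hm = sylvester_hadamard(h).to(x.dtype)
    d = (x.view(x.shape[0], -1, h) @ hm).view_as(x)
    amax = d.abs().amax()                    # reduced per expert via atomics
    return d, amax
# ----------------------------------------------------------------------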
d_pipeline.producer_commit() + d_pipeline.producer_acquire() + self.epilog_sync_barrier_group1.arrive_and_wait() + + # + # Full epilogue barrier (ACT + RHT must both arrive) + # + self.epilog_sync_barrier.arrive_and_wait() + + # + # Update overlapping_accum stage index for next tile + # + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_stage_index ^ 1 + reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False) + else: + acc_stage_index = 0 + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Amax reduction per tile (across warps and CTAs) + # + if cutlass.const_expr(self.generate_amax): + gAmax = mAmax_tensor[(expert_idx, None)].iterator.llvm_ptr # First element + self.amax_reduction_per_warp_and_cta(thread_tile_amax, warp_idx, sAmax, gAmax) + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier_group1.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for D store complete + # + d_pipeline.producer_tail() + + # END OF KERNEL + + +class BlockScaledMoEGroupedGemmGluHadamardCompatKernel(BlockScaledMoEGroupedGemmGluHadamardKernel): + """Compatibility adapter for the existing grouped GLU frontend wrapper shape.""" + + def __init__( + self, + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + vectorized_f32: bool, + generate_sfd: bool, + discrete_col_sfd: bool, + expert_cnt: int, + weight_mode: MoEWeightMode = MoEWeightMode.DISCRETE, + use_dynamic_sched: bool = False, + act_func: str = "swiglu", + enable_bias: bool = False, + ): + del generate_sfd, discrete_col_sfd + super().__init__( + sf_vec_size=sf_vec_size, + acc_dtype=acc_dtype, + use_2cta_instrs=use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + vectorized_f32=vectorized_f32, + expert_cnt=expert_cnt, + weight_mode=weight_mode, + use_dynamic_sched=use_dynamic_sched, + act_func=act_func, + enable_bias=enable_bias, + ) + + def __call__( + self, + a: cute.Tensor, + b, + sfb, + n: Int32, + k: Int32, + b_stride_size: cutlass.Int64, + b_major_mode: cutlass.Constexpr, + workspace_ptr, + c: cute.Tensor, + d: cute.Tensor, + d_col: cute.Tensor, + sfa: cute.Tensor, + sfd_row_tensor: Optional[cute.Tensor], + sfd_col_tensor: Optional[cute.Tensor], + amax_tensor: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + padded_offsets: cute.Tensor, + alpha: cute.Tensor, + prob: cute.Tensor, + bias: Optional[cute.Tensor], + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + linear_offset: cutlass.Float32 = 0.0, + ): + del d_col, sfd_row_tensor, sfd_col_tensor, norm_const_tensor + return super().__call__( + a=a, + b=b, + sfa=sfa, + sfb=sfb, + n=n, + k=k, + b_stride_size=b_stride_size, + b_major_mode=b_major_mode, + workspace_ptr=workspace_ptr, + c=c, + d=d, + amax_tensor=amax_tensor, + padded_offsets=padded_offsets, + alpha=alpha, + prob=prob, + hadamard_tensor=None, + bias=bias, + max_active_clusters=max_active_clusters, + stream=stream, + epilogue_op=epilogue_op, + 
linear_offset=linear_offset, + ) diff --git a/python/cudnn/grouped_gemm/grouped_gemm_quant/api.py b/python/cudnn/grouped_gemm/grouped_gemm_quant/api.py index be30580f..f2db2cfa 100644 --- a/python/cudnn/grouped_gemm/grouped_gemm_quant/api.py +++ b/python/cudnn/grouped_gemm/grouped_gemm_quant/api.py @@ -75,7 +75,7 @@ def __init__( sample_padded_offsets: torch.Tensor, sample_alpha: torch.Tensor, sample_d: torch.Tensor, - sample_d_col: torch.Tensor, + sample_d_col: Optional[torch.Tensor] = None, # Dense mode (contiguous) -- provide these: sample_b: Optional[torch.Tensor] = None, sample_sfb: Optional[torch.Tensor] = None, @@ -90,8 +90,6 @@ def __init__( sample_amax: Optional[torch.Tensor] = None, sample_norm_const: Optional[torch.Tensor] = None, sample_prob: Optional[torch.Tensor] = None, - # Internal: C tensor placeholder (kernel compilation requires it) - sample_c: Optional[torch.Tensor] = None, # Configuration acc_dtype: torch.dtype = torch.float32, mma_tiler_mn: Tuple[int, int] = (256, 256), @@ -110,7 +108,7 @@ def __init__( :param sample_padded_offsets: End offset for each expert after padding, shape (expert_cnt,) :param sample_alpha: Per-group alpha scaling factors :param sample_d: Sample D output tensor (valid_m, n, 1) - :param sample_d_col: Column-quantized D tensor (required for quant kernel) + :param sample_d_col: Optional column-quantized D tensor. Required only when SFD outputs are generated. :param sample_b: (Dense) Sample B tensor (n, k, l) :param sample_sfb: (Dense) Sample scale factor B tensor :param sample_bias: Optional bias tensor with shape (n, l) or (n, expert_cnt), stride (1, n). @@ -123,7 +121,6 @@ def __init__( :param sample_amax: Optional amax tensor for quantization :param sample_norm_const: Optional normalization constant :param sample_prob: Optional probability tensor for gating - :param sample_c: Internal C tensor placeholder (kernel requires it for dtype inference) :param acc_dtype: Accumulator data type :param mma_tiler_mn: MMA tiler shape (M, N) :param cluster_shape_mn: Cluster shape (M, N) @@ -136,7 +133,7 @@ def __init__( """ super().__init__() - self._logger.warning("GroupedGemmQuantSm100 is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") # ---- Weight mode auto-detection ---- @@ -157,7 +154,17 @@ def __init__( self.padded_offsets_desc = self._make_tensor_desc(sample_padded_offsets, name="sample_padded_offsets") self.alpha_desc = self._make_tensor_desc(sample_alpha, name="sample_alpha") + self._has_d_col = sample_d_col is not None self.d_col_desc = self._make_tensor_desc(sample_d_col, name="sample_d_col") + if self.d_col_desc is None: + self.d_col_desc = TensorDesc( + dtype=self.d_desc.dtype, + shape=self.d_desc.shape, + stride=self.d_desc.stride, + stride_order=self.d_desc.stride_order, + device=self.d_desc.device, + name="sample_d_col", + ) self.sfd_row_desc = self._make_tensor_desc(sample_sfd_row, name="sample_sfd_row") self.sfd_col_desc = self._make_tensor_desc(sample_sfd_col, name="sample_sfd_col") self.amax_desc = self._make_tensor_desc(sample_amax, name="sample_amax") @@ -169,20 +176,6 @@ def __init__( self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob") self.bias_desc = self._make_tensor_desc(sample_bias, name="sample_bias") - # C tensor: required by kernel for dtype inference but never written to (generate_c=False). - # If not provided, derive from D descriptor with bfloat16 dtype. 
- if sample_c is not None: - self.c_desc = self._make_tensor_desc(sample_c, name="sample_c") - else: - self.c_desc = TensorDesc( - dtype=torch.bfloat16, - shape=self.d_desc.shape, - stride=self.d_desc.stride, - stride_order=self.d_desc.stride_order, - device=self.d_desc.device, - name="sample_c", - ) - if self.weight_mode == MoEWeightMode.DENSE: self.b_desc = self._make_tensor_desc(sample_b, name="sample_b") self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb") @@ -218,8 +211,9 @@ def __init__( self._kernel = BlockScaledMoEGroupedGemmQuantKernel self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) - print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") self._workspace = None + self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" self._logger.debug("__init__ completed") def check_support(self) -> bool: @@ -236,6 +230,10 @@ def check_support(self) -> bool: "sfd_row_desc, sfd_col_desc, and norm_const_desc must be all None or all not None", ) self.generate_sfd = all_provided + self._value_error_if( + self.generate_sfd and not self._has_d_col, + "sample_d_col is required when SFD outputs are generated", + ) if self.discrete_col_sfd and not self.generate_sfd: self._logger.warning("discrete_col_sfd is True but generate_sfd is False, discrete_col_sfd will be ignored") self.discrete_col_sfd = False @@ -253,13 +251,11 @@ def check_support(self) -> bool: self._value_error_if(b_k != k, f"B K dimension ({b_k}) must match A K dimension ({k})") l = self.expert_cnt - _, _, _one = self._tensor_shape(self.c_desc, name="sample_c") _, _, _one = self._tensor_shape(self.d_desc, name="sample_d") self._check_tensor_shape(self.a_desc, (tensor_m, k, 1), "A") if self.weight_mode == MoEWeightMode.DENSE: self._check_tensor_shape(self.b_desc, (n, k, l), "B") - self._check_tensor_shape(self.c_desc, (tensor_m, n, 1), "C") self._check_tensor_shape(self.d_desc, (tensor_m, n, 1), "D") self._check_tensor_shape(self.d_col_desc, (tensor_m, n, 1), "D_col") @@ -302,11 +298,6 @@ def check_support(self) -> bool: stride=[(k, 1, n * k)], extra_error_msg="For fp4 ab_dtype, B must have k-major layout", ) - _ = self._check_tensor_stride( - self.c_desc, - stride=[(n, 1, tensor_m * n)], - extra_error_msg="C must have n-major layout", - ) _ = self._check_tensor_stride( self.d_desc, stride=[(n, 1, tensor_m * n)], @@ -402,19 +393,6 @@ def check_support(self) -> bool: name="Accumulator", extra_error_msg="Accumulator must be float32", ) - self.c_dtype = self._check_dtype( - self.c_desc, - dtype=[ - torch.float32, - torch.float16, - torch.bfloat16, - torch.float8_e4m3fn, - torch.float8_e5m2, - torch.float4_e2m1fn_x2, - ], - name="C", - ) - if self._is_fp4x2(self.ab_dtype): self.d_dtype = self._check_dtype( self.d_desc, @@ -536,11 +514,6 @@ def check_contigous_16B_alignment(dtype, stride_order, tensor_shape): "Invalid configuration: fp8 ab_dtype and sf_vec_size 32 with mma_tiler_mn[1] == 128 and fp8 d_dtype is not supported. 
" "Please use mma_tiler_mn[1] == 256 instead", ) - self._not_implemented_error_if( - self._is_fp4x2(self.ab_dtype) and (self.c_dtype not in [torch.float16, torch.bfloat16]), - f"Invalid configuration: for fp4 ab_dtype, c_dtype must be float16 or bfloat16, got {self.c_dtype}", - ) - if not torch.cuda.is_available(): raise RuntimeError("CUDA is not available") device = torch.cuda.current_device() @@ -564,6 +537,8 @@ def compile(self) -> None: self._logger.debug("sample valid_m is zero, skipping kernel compilation") return + self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" + gemm_quant = self._kernel( sf_vec_size=self.sf_vec_size, acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype), @@ -573,7 +548,6 @@ def compile(self) -> None: vectorized_f32=self.vector_f32, generate_sfd=self.generate_sfd, discrete_col_sfd=self.discrete_col_sfd, - generate_c=False, enable_bias=self._has_bias, expert_cnt=self.expert_cnt, weight_mode=self.weight_mode, @@ -607,7 +581,7 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None: ) self._logger.debug("Compiling grouped_gemm_quant kernel") - use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None + use_full_dynamic = self._use_full_dynamic_mnkl if not use_full_dynamic: valid_m = cute.sym_int(divisibility=256) @@ -618,11 +592,6 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None: stride_order=self.a_desc.stride_order, ) b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16) - c_cute_fake = self._make_fake_cute_compact_tensor( - dtype=self.c_desc.dtype, - shape=(valid_m, *self.c_desc.shape[1:]), - stride_order=self.c_desc.stride_order, - ) d_cute_fake = self._make_fake_cute_compact_tensor( dtype=self.d_desc.dtype, shape=(valid_m, *self.d_desc.shape[1:]), @@ -677,7 +646,8 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None: bias_cute_fake = self._make_fake_cute_tensor_from_desc(self.bias_desc, assumed_align=16) else: valid_m = cute.sym_int(divisibility=256) - n_sym = cute.sym_int() + n_sym_divisibility = 128 // _convert_to_cutlass_data_type(self.bias_desc.dtype).width if self.bias_desc is not None else 1 + n_sym = cute.sym_int(divisibility=n_sym_divisibility) k_sym = cute.sym_int() l_sym = cute.sym_int() @@ -695,13 +665,6 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None: dynamic_mode=self.b_desc.stride_order[0], divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, ) - c_cute_fake = self._make_fake_cute_compact_tensor( - dtype=self.c_desc.dtype, - shape=(valid_m, n_sym, 1), - stride_order=self.c_desc.stride_order, - dynamic_mode=self.c_desc.stride_order[0], - divisibility=8 if self._is_f16(self.c_desc.dtype) else 16, - ) d_cute_fake = self._make_fake_cute_compact_tensor( dtype=self.d_desc.dtype, shape=(valid_m, n_sym, 1), @@ -790,7 +753,6 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None: b_stride_size=cutlass.Int64(0), b_major_mode=OperandMajorMode.K, workspace_ptr=fake_workspace_ptr, - c=c_cute_fake, d=d_cute_fake, d_col=d_col_cute_fake, sfa=sfa_cute_fake, @@ -812,7 +774,6 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None: def tensor_api( a_tensor: torch.Tensor, b_tensor: torch.Tensor, - c_tensor: torch.Tensor, d_tensor: torch.Tensor, d_col_tensor: Optional[torch.Tensor], sfa_tensor: torch.Tensor, @@ -836,7 +797,6 @@ def tensor_api( cutlass.Int32(0), cutlass.Int64(0), 
cached_workspace_ptr, - c_tensor, d_tensor, d_col_tensor, sfa_tensor, @@ -873,11 +833,6 @@ def _compile_discrete(self, gemm_quant, max_active_clusters, fake_stream) -> Non stride_order=self.a_desc.stride_order, assumed_align=align, ) - c_tensor = self._make_fake_cute_compact_tensor( - dtype=self.c_desc.dtype, - shape=(valid_m, *self.c_desc.shape[1:]), - stride_order=self.c_desc.stride_order, - ) d_tensor = self._make_fake_cute_compact_tensor( dtype=self.d_desc.dtype, shape=(valid_m, *self.d_desc.shape[1:]), @@ -942,29 +897,28 @@ def _compile_discrete(self, gemm_quant, max_active_clusters, fake_stream) -> Non self._logger.debug("Compiling discrete grouped_gemm_quant kernel") _compiled_kernel = cute.compile( gemm_quant, - a_tensor, - b_ptrs_cute, - sfb_ptrs_cute, - cutlass.Int32(n), - cutlass.Int32(k), - cutlass.Int64(b_stride_size), - b_major_mode, - workspace_ptr_cute, - c_tensor, - d_tensor, - d_col_tensor, - sfa_tensor, - sfd_row_tensor, - sfd_col_tensor, - amax_tensor, - norm_const_tensor_cute, - padded_offsets_tensor, - alpha_tensor, - bias_cute_fake, - prob_tensor, - max_active_clusters, - fake_stream, - lambda x: x, + a=a_tensor, + b=b_ptrs_cute, + sfb=sfb_ptrs_cute, + n=cutlass.Int32(n), + k=cutlass.Int32(k), + b_stride_size=cutlass.Int64(b_stride_size), + b_major_mode=b_major_mode, + workspace_ptr=workspace_ptr_cute, + d=d_tensor, + d_col=d_col_tensor, + sfa=sfa_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor_cute, + padded_offsets=padded_offsets_tensor, + alpha=alpha_tensor, + bias=bias_cute_fake, + prob=prob_tensor, + max_active_clusters=max_active_clusters, + stream=fake_stream, + epilogue_op=lambda x: x, options="--enable-tvm-ffi", ) @@ -977,7 +931,6 @@ def tensor_api( a_tensor: torch.Tensor, b_ptrs_device: torch.Tensor, sfb_ptrs_device: torch.Tensor, - c_tensor: torch.Tensor, d_tensor: torch.Tensor, d_col_tensor: Optional[torch.Tensor], sfa_tensor: torch.Tensor, @@ -1002,7 +955,6 @@ def tensor_api( cached_k, cached_b_stride, cached_workspace_ptr, - c_tensor, d_tensor, d_col_tensor, sfa_tensor, @@ -1033,7 +985,6 @@ def execute( # Discrete mode: b_ptrs: Optional[torch.Tensor] = None, sfb_ptrs: Optional[torch.Tensor] = None, - c_tensor: Optional[torch.Tensor] = None, d_col_tensor: Optional[torch.Tensor] = None, sfd_row_tensor: Optional[torch.Tensor] = None, sfd_col_tensor: Optional[torch.Tensor] = None, @@ -1056,8 +1007,6 @@ def execute( at construction, ``bias_tensor`` must also be omitted at execute time. 
:param b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers :param sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers - :param c_tensor: Optional C tensor placeholder (kernel requires it but never writes to it; - a minimal dummy is created automatically if not provided) :param d_col_tensor: Optional column-quantized output :param sfd_row_tensor: Optional row scale factor D :param sfd_col_tensor: Optional column scale factor D @@ -1077,13 +1026,12 @@ def execute( "Kernel not compiled; call compile() first", ) - if c_tensor is None: - c_tensor = torch.empty_strided( - self.c_desc.shape, - self.c_desc.stride, - dtype=self.c_desc.dtype, - device=d_tensor.device, + if d_col_tensor is None: + self._value_error_if( + self.generate_sfd, + "d_col_tensor is required when SFD outputs are generated", ) + d_col_tensor = d_tensor self._value_error_if( prob_tensor is None, "prob_tensor is required: the kernel unconditionally multiplies output by per-row gating probability. " @@ -1105,7 +1053,6 @@ def execute( self._compiled_kernel( a_tensor=a_tensor, b_tensor=b_tensor, - c_tensor=c_tensor, d_tensor=d_tensor, d_col_tensor=d_col_tensor, sfa_tensor=sfa_tensor, @@ -1125,7 +1072,6 @@ def execute( a_tensor=a_tensor, b_ptrs_device=b_ptrs, sfb_ptrs_device=sfb_ptrs, - c_tensor=c_tensor, d_tensor=d_tensor, d_col_tensor=d_col_tensor, sfa_tensor=sfa_tensor, @@ -1165,7 +1111,6 @@ def grouped_gemm_quant_wrapper_sm100( norm_const_tensor: Optional[torch.Tensor] = None, prob_tensor: Optional[torch.Tensor] = None, acc_dtype: torch.dtype = torch.float32, - c_dtype: torch.dtype = torch.bfloat16, d_dtype: torch.dtype = torch.bfloat16, cd_major: str = "n", mma_tiler_mn: Tuple[int, int] = (256, 256), @@ -1202,7 +1147,6 @@ def grouped_gemm_quant_wrapper_sm100( prob_tensor: Probability tensor for per-row gating (shape `(valid_m, 1, 1)`). This argument is required. Pass a tensor of ones when no gating is needed. acc_dtype: Accumulator data type - c_dtype: Internal C tensor data type (not user-visible) d_dtype: Output D tensor data type cd_major: CD major dimension (only "n"-major layout is supported) mma_tiler_mn: MMA tiler shape @@ -1217,7 +1161,7 @@ def grouped_gemm_quant_wrapper_sm100( TupleDict: A dictionary-like object containing output tensors that can also be unpacked as a tuple. 
Dictionary keys (also the unpacking order): - **d_tensor** (torch.Tensor): Final output tensor - - **d_col_tensor** (torch.Tensor): Column-wise output tensor + - **d_col_tensor** (torch.Tensor or None): Column-wise output tensor for low-precision D output - **amax_tensor** (torch.Tensor or None): Absolute maximum values (for quantization) - **sfd_row_tensor** (torch.Tensor or None): Row-wise scale factors for D (FP8 only) - **sfd_col_tensor** (torch.Tensor or None): Column-wise scale factors for D (FP8 only) @@ -1264,19 +1208,6 @@ def grouped_gemm_quant_wrapper_sm100( if bias_tensor is not None and tuple(bias_tensor.shape) != (n_out, num_experts): raise ValueError(f"bias_tensor must have shape {(n_out, num_experts)}, got {tuple(bias_tensor.shape)}") - _logger.debug("grouped_gemm_quant_wrapper_sm100: Creating output tensors d_tensor, d_col_tensor") - - if cd_major == "n": - c_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=c_dtype, device=a_tensor.device) - d_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) - d_col_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) - else: - raise ValueError(f"cd_major must be 'n', got {cd_major}") - - sfd_row_tensor = None - sfd_col_tensor = None - amax_tensor = None - is_fp8_input_config = a_tensor.dtype in [ torch.float8_e4m3fn, torch.float8_e5m2, @@ -1284,23 +1215,39 @@ def grouped_gemm_quant_wrapper_sm100( torch.float8_e8m0fnu, torch.float8_e4m3fn, ] - is_fp8_output_config = d_dtype in [ + is_low_precision_output_config = d_dtype in [ torch.float8_e4m3fn, torch.float8_e5m2, torch.float4_e2m1fn_x2, ] - if is_fp8_input_config and is_fp8_output_config and norm_const_tensor is None: + _logger.debug("grouped_gemm_quant_wrapper_sm100: Creating output tensors") + + if cd_major == "n": + d_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) + d_col_tensor = ( + torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) + if is_low_precision_output_config + else None + ) + else: + raise ValueError(f"cd_major must be 'n', got {cd_major}") + + sfd_row_tensor = None + sfd_col_tensor = None + amax_tensor = None + + if is_fp8_input_config and is_low_precision_output_config and norm_const_tensor is None: raise ValueError( "norm_const_tensor is required when FP8 inputs are used with FP8 output " "(a_tensor is FP8 and sfa_tensor is FP8 and d_dtype is FP8). " "Pass a tensor with shape (1,), e.g. torch.tensor([0.01], dtype=torch.float32, device=a_tensor.device)." 
) - if not is_fp8_output_config: + if not is_low_precision_output_config: norm_const_tensor = None - if is_fp8_input_config and is_fp8_output_config: + if is_fp8_input_config and is_low_precision_output_config: _logger.debug("grouped_gemm_quant_wrapper_sm100: Detected fp8 a_dtype and sfa_dtype, constructing sfd_row_tensor and sfd_col_tensor") sf_dtype = sfa_tensor.dtype @@ -1369,7 +1316,7 @@ def dynamic_m_tensor_signature( stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride())) return static_shape_suffix, stride_signature, tensor.dtype - use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None + use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" if is_dense: cache_key = ( @@ -1381,9 +1328,8 @@ def dynamic_m_tensor_signature( b_tensor.dtype, stride_order(a_tensor), stride_order(b_tensor), - c_tensor.shape[1:] if not use_full_dynamic else None, - stride_order(c_tensor), - c_tensor.dtype, + d_tensor.shape[1:] if not use_full_dynamic else None, + stride_order(d_tensor), *( dynamic_tensor_signature(sfa_tensor) if use_full_dynamic @@ -1398,7 +1344,6 @@ def dynamic_m_tensor_signature( tuple(padded_offsets.stride()), padded_offsets.dtype, acc_dtype, - c_dtype, d_dtype, cd_major, mma_tiler_mn, @@ -1417,9 +1362,8 @@ def dynamic_m_tensor_signature( a_tensor.dtype, b_shape, b_dtype, - c_tensor.shape[1:], - stride_order(c_tensor), - c_tensor.dtype, + d_tensor.shape[1:], + stride_order(d_tensor), *dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,)), *tensor_signature(bias_tensor), *tensor_signature(alpha_tensor), @@ -1435,7 +1379,6 @@ def dynamic_m_tensor_signature( tuple(padded_offsets.stride()), padded_offsets.dtype, acc_dtype, - c_dtype, d_dtype, cd_major, mma_tiler_mn, @@ -1470,7 +1413,6 @@ def dynamic_m_tensor_signature( sample_sfd_col=sfd_col_tensor, sample_norm_const=norm_const_tensor, sample_prob=prob_tensor, - sample_c=c_tensor, acc_dtype=acc_dtype, mma_tiler_mn=mma_tiler_mn, cluster_shape_mn=cluster_shape_mn, @@ -1497,7 +1439,6 @@ def dynamic_m_tensor_signature( sample_sfd_col=sfd_col_tensor, sample_norm_const=norm_const_tensor, sample_prob=prob_tensor, - sample_c=c_tensor, acc_dtype=acc_dtype, mma_tiler_mn=mma_tiler_mn, cluster_shape_mn=cluster_shape_mn, @@ -1522,7 +1463,6 @@ def dynamic_m_tensor_signature( d_tensor=d_tensor, b_tensor=b_tensor, sfb_tensor=sfb_tensor, - c_tensor=c_tensor, d_col_tensor=d_col_tensor, sfd_row_tensor=sfd_row_tensor, sfd_col_tensor=sfd_col_tensor, @@ -1541,7 +1481,6 @@ def dynamic_m_tensor_signature( d_tensor=d_tensor, b_ptrs=b_ptrs, sfb_ptrs=sfb_ptrs, - c_tensor=c_tensor, d_col_tensor=d_col_tensor, sfd_row_tensor=sfd_row_tensor, sfd_col_tensor=sfd_col_tensor, diff --git a/python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py b/python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py index 6619baa8..8773accf 100644 --- a/python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py +++ b/python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py @@ -9,7 +9,6 @@ - Dense (contiguous 3-D B) / Discrete (per-expert pointer array B) weight layout - FP8/FP4 output quantization with row/column scale factors (SFD) - Optional bias and routing-probability (prob) fusion - - Optional C output (generate_c) - AMAX reduction for FP8 calibration This module contains only the kernel class. 
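# ----------------------------------------------------------------------
# [Editor's note -- illustrative sketch, not part of this diff]
# Caller-side effect of the api.py changes above: the internal C
# placeholder is gone, and d_col_tensor is allocated only for
# low-precision (FP8/FP4) d_dtype. Keyword names below are identifiers
# visible in this diff; the import path is inferred from the file
# location, scale-factor construction is left to the caller (its layout
# is not shown here), and everything else is an illustrative assumption.
import torch
from cudnn.grouped_gemm.grouped_gemm_quant import grouped_gemm_quant_wrapper_sm100

def run_moe_fc2(a, b, sfa, sfb, padded_offsets, alpha, prob):
    # a: (valid_m, k, 1) FP8; b: dense (n, k, expert_cnt) FP8
    # padded_offsets: (expert_cnt,) per-expert end offsets after padding
    # prob: (valid_m, 1, 1) per-row gating -- required, pass ones if unused
    out = grouped_gemm_quant_wrapper_sm100(
        a_tensor=a, b_tensor=b, sfa_tensor=sfa, sfb_tensor=sfb,
        padded_offsets=padded_offsets, alpha_tensor=alpha, prob_tensor=prob,
        d_dtype=torch.bfloat16,   # high-precision D output
    )
    # No c_tensor to pass or receive anymore, and with bf16 D there is
    # no column-quantized companion output:
    assert out["d_col_tensor"] is None
    return out["d_tensor"]
# ----------------------------------------------------------------------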
@@ -73,7 +72,6 @@ class BlockScaledMoEGroupedGemmQuantKernel: :param vectorized_f32: Use packed FP32 arithmetic. :param generate_sfd: Generate output scale factors. :param discrete_col_sfd: Use discrete column SFD layout. - :param generate_c: Generate C output tensor. :param enable_bias: Fuse bias addition. :param expert_cnt: Number of experts. :param weight_mode: ``MoEWeightMode.DENSE`` or ``MoEWeightMode.DISCRETE``. @@ -131,7 +129,6 @@ def __init__( vectorized_f32: bool, generate_sfd: bool, discrete_col_sfd: bool, - generate_c: bool, enable_bias: bool, expert_cnt: int, weight_mode: MoEWeightMode = MoEWeightMode.DENSE, @@ -199,7 +196,6 @@ def __init__( self.vectorized_f32 = vectorized_f32 self.generate_sfd = generate_sfd self.discrete_col_sfd = discrete_col_sfd - self.generate_c = generate_c self.enable_bias = enable_bias self.weight_mode = weight_mode @@ -295,7 +291,6 @@ def _setup_attributes(self): ( self.num_acc_stage, self.num_ab_stage, - self.num_c_stage, self.num_d_stage, self.num_tile_stage, self.num_bias_stage, @@ -305,8 +300,6 @@ def _setup_attributes(self): self.a_dtype, self.b_dtype, self.epi_tile, - self.c_dtype, - self.c_layout, self.d_dtype, self.d_layout, self.sf_dtype, @@ -314,7 +307,6 @@ def _setup_attributes(self): self.num_smem_capacity, self.occupancy, self.generate_sfd, - self.generate_c, self.bias_dtype if self.enable_bias else None, ) @@ -342,12 +334,6 @@ def _setup_attributes(self): self.sf_vec_size, self.num_ab_stage, ) - self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( - self.c_dtype, - self.c_layout, - self.epi_tile, - self.num_c_stage, - ) self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi( self.d_dtype, self.d_layout, @@ -388,8 +374,6 @@ def _compute_stages( a_dtype, b_dtype, epi_tile, - c_dtype, - c_layout, d_dtype, d_layout, sf_dtype, @@ -397,11 +381,9 @@ def _compute_stages( num_smem_capacity, occupancy, generate_sfd, - generate_c, bias_dtype, ): num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 - num_c_stage = 2 if generate_sfd else 1 num_d_stage = 2 if generate_sfd else 1 num_tile_stage = 2 @@ -409,7 +391,6 @@ def _compute_stages( b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, b_dtype, 1) sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(tiled_mma, mma_tiler_mnk, sf_vec_size, 1) sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(tiled_mma, mma_tiler_mnk, sf_vec_size, 1) - c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1) d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1) ab_bytes_per_stage = ( @@ -420,8 +401,6 @@ def _compute_stages( ) mbar_helpers_bytes = 1024 sinfo_bytes = 4 * 4 * num_tile_stage - c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) - c_bytes = c_bytes_per_stage * num_c_stage d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) d_bytes = d_bytes_per_stage * num_d_stage * (2 if generate_sfd else 1) amax_bytes = 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,))) if d_dtype == cutlass.BFloat16 else 0 @@ -434,10 +413,10 @@ def _compute_stages( num_bias_stage = 0 bias_bytes = 0 - epi_bytes = c_bytes + d_bytes + amax_bytes + bias_bytes + epi_bytes = d_bytes + amax_bytes + bias_bytes num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes + sinfo_bytes)) // ab_bytes_per_stage - return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage, num_bias_stage + return 
num_acc_stage, num_ab_stage, num_d_stage, num_tile_stage, num_bias_stage # ------------------------------------------------------------------ # Workspace helpers @@ -575,7 +554,6 @@ def __call__( b_stride_size: cutlass.Int64, # Ignored for dense mode b_major_mode: cutlass.Constexpr, # Ignored for dense mode workspace_ptr, - c: cute.Tensor, d: cute.Tensor, d_col: Optional[cute.Tensor], sfa: cute.Tensor, @@ -600,11 +578,9 @@ def __call__( """ self.a_dtype: Type[cutlass.Numeric] = a.element_type self.b_dtype: Type[cutlass.Numeric] = a.element_type - self.c_dtype: Type[cutlass.Numeric] = c.element_type self.d_dtype: Type[cutlass.Numeric] = d.element_type self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() - self.c_layout = utils.LayoutEnum.from_tensor(c) self.d_layout = utils.LayoutEnum.from_tensor(d) self.bias_dtype = bias.element_type if cutlass.const_expr(self.enable_bias) else cutlass.BFloat16 @@ -754,13 +730,6 @@ def __call__( sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size - c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) - tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( - cpasync.CopyBulkTensorTileS2GOp(), - c, - c_smem_layout, - self.epi_tile, - ) d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0)) tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), @@ -832,10 +801,6 @@ class SharedStorage: bias_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_bias_stage * 2] tmem_dealloc_mbar_ptr: cutlass.Int64 tmem_holding_buf: cutlass.Int32 - sC: cute.struct.Align[ - cute.struct.MemRange[self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer)], - self.buffer_align_bytes, - ] sD: cute.struct.Align[ cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)], self.buffer_align_bytes, @@ -884,8 +849,6 @@ class SharedStorage: tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb, - tma_atom_c, - tma_tensor_c, tma_atom_d, tma_tensor_d, tma_atom_d_col, @@ -905,7 +868,6 @@ class SharedStorage: self.b_smem_layout_staged, self.sfa_smem_layout_staged, self.sfb_smem_layout_staged, - self.c_smem_layout_staged, self.d_smem_layout_staged, self.bias_smem_layout_staged, self.epi_tile, @@ -955,32 +917,6 @@ def amax_reduction_per_warp_and_cta(self, amax_fp32, warp_idx, amax_smem, amax_g block_amax = cute.arch.fmax(block_amax, warp_amax_val) _ = atomic_max_float32(ptr=amax_gmem, value=block_amax) - @cute.jit - def store_c( - self, - tiled_copy_r2s, - tma_atom_c, - warp_idx, - tTR_rAcc, - tRS_rC, - tRS_sC, - bSG_gC, - bSG_sC, - c_pipeline, - prev_subtile_idx, - real_subtile_idx, - ): - c_buffer = prev_subtile_idx % self.num_c_stage - tRS_rC.store(tTR_rAcc.load().to(self.c_dtype)) - cute.copy(tiled_copy_r2s, tRS_rC[(None, None, 0)], tRS_sC[(None, None, 0, c_buffer)]) - cute.arch.fence_proxy("async.shared", space="cta") - self.epilog_sync_barrier.arrive_and_wait() - if warp_idx == self.epilog_warp_id[0]: - cute.copy(tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, real_subtile_idx)]) - c_pipeline.producer_commit() - c_pipeline.producer_acquire() - self.epilog_sync_barrier.arrive_and_wait() - @cute.jit def quant_sfd_row(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD): tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size)) @@ -1147,8 +1083,6 @@ def kernel( mSFA_mkl: 
cute.Tensor, tma_atom_sfb: cute.CopyAtom, mSFB_nkl: cute.Tensor, - tma_atom_c: cute.CopyAtom, - mC_mnl: cute.Tensor, tma_atom_d: cute.CopyAtom, mD_mnl: cute.Tensor, tma_atom_d_col: cute.CopyAtom, @@ -1168,7 +1102,6 @@ def kernel( b_smem_layout_staged: cute.ComposedLayout, sfa_smem_layout_staged: cute.Layout, sfb_smem_layout_staged: cute.Layout, - c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], bias_smem_layout_staged: Optional[cute.Layout], epi_tile: cute.Tile, @@ -1189,8 +1122,6 @@ def kernel( cpasync.prefetch_descriptor(tma_atom_d) if cutlass.const_expr(self.generate_sfd): cpasync.prefetch_descriptor(tma_atom_d_col) - if cutlass.const_expr(self.generate_c): - cpasync.prefetch_descriptor(tma_atom_c) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 total_token = padded_offsets[self.expert_cnt - 1] @@ -1274,7 +1205,6 @@ def kernel( if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_arrive_relaxed() - sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner) sD_col = sD if cutlass.const_expr(self.generate_sfd): @@ -1744,13 +1674,6 @@ def kernel( use_2cta_instrs, ) - tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) - tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( - tiled_copy_t2r, - tTR_rC, - epi_tidx, - sC, - ) tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype) tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition( tiled_copy_t2r, @@ -1793,8 +1716,6 @@ def kernel( epi_ext = self._make_extension(workspace_ptr) acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) - c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id)) - c_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_producer_group) d_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id)) d_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_d_stage, producer_group=d_producer_group) @@ -1833,7 +1754,6 @@ def kernel( sBias_subtiles = cute.flat_divide(sBias_stage, cute.make_layout(self.epi_tile[1])) real_d, _ = epi_ext.get_gmem_tensor("d", mD_mnl, padded_offsets, epi_work_tile_info) - real_c, _ = epi_ext.get_gmem_tensor("c", mC_mnl, padded_offsets, epi_work_tile_info) real_d_col = real_d if cutlass.const_expr(self.generate_sfd): real_d_col, _ = epi_ext.get_gmem_tensor("d_col", mD_col_mnl, padded_offsets, epi_work_tile_info) @@ -1850,16 +1770,6 @@ def kernel( sD, ) - gC_mnl_loop = cute.local_tile(real_c, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)) - tCgC_loop = thr_mma_epi_loop.partition_C(gC_mnl_loop) - _, bSG_sC, bSG_gC_partitioned = epilog_gmem_copy_and_partition( - epi_tidx, - tma_atom_c, - tCgC_loop, - epi_tile, - sC, - ) - gD_col_mnl_loop = gD_mnl_loop tCgD_col_loop = tCgD_loop if cutlass.const_expr(self.generate_sfd): @@ -1878,10 +1788,8 @@ def kernel( epi_work_tile_info.tile_n_idx, 0, ) - bSG_gC = bSG_gC_partitioned[(None, None, None, *epi_mma_tile_coord)] bSG_gD = bSG_gD_partitioned[(None, None, None, *epi_mma_tile_coord)] bSG_gD_col = bSG_gD_col_partitioned[(None, None, None, *epi_mma_tile_coord)] - bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD)) bSG_gD_col = 
cute.group_modes(bSG_gD_col, 1, cute.rank(bSG_gD_col)) @@ -1968,21 +1876,6 @@ def kernel( for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): tTR_rAcc[i] = tTR_rAcc[i] * cutlass.Float32(alpha_val) - if cutlass.const_expr(self.generate_c): - self.store_c( - tiled_copy_r2s, - tma_atom_c, - warp_idx, - tTR_rAcc, - tRS_rC, - tRS_sC, - bSG_gC, - bSG_sC, - c_pipeline, - num_prev_subtiles, - real_subtile_idx, - ) - acc_vec = tTR_rAcc.load() if cutlass.const_expr(not self.enable_bias): tCompute = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype) @@ -2089,8 +1982,6 @@ def kernel( tmem.relinquish_alloc_permit() self.epilog_sync_barrier.arrive_and_wait() tmem.free(tmem_ptr) - if cutlass.const_expr(self.generate_c): - c_pipeline.producer_tail() d_pipeline.producer_tail() # ------------------------------------------------------------------ diff --git a/python/cudnn/grouped_gemm/grouped_gemm_srelu/__init__.py b/python/cudnn/grouped_gemm/grouped_gemm_srelu/__init__.py new file mode 100644 index 00000000..70b2bfe7 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_srelu/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +""" +Grouped GEMM SReLU Kernel Module + +This module provides the forward grouped GEMM with SReLU activation +for MoE (Mixture of Experts) workloads on SM100+ GPUs. +""" + +from .api import ( + GroupedGemmSreluSm100, + grouped_gemm_srelu_wrapper_sm100, +) + +__all__ = [ + "GroupedGemmSreluSm100", + "grouped_gemm_srelu_wrapper_sm100", +] diff --git a/python/cudnn/grouped_gemm/grouped_gemm_srelu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_srelu/api.py new file mode 100644 index 00000000..02f49939 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_srelu/api.py @@ -0,0 +1,1572 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
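For orientation, the two public entry points exported by the new package's `__init__.py` above can be imported as follows (a hedged editorial sketch, not part of the patch; the full argument lists are documented in `api.py` below):

```python
# Import surface of the new grouped_gemm_srelu package, mirroring the
# __all__ list in its __init__.py above.
from cudnn.grouped_gemm.grouped_gemm_srelu import (
    GroupedGemmSreluSm100,             # compile-once / execute-many API class
    grouped_gemm_srelu_wrapper_sm100,  # one-shot convenience wrapper (cached)
)
```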
+ +""" +Unified API for Grouped GEMM SReLU Kernel (SM100+) + +This module provides a single API class that supports both dense (contiguous) +and discrete weight modes for grouped block-scaled GEMM with output +SReLU output quantization in MoE (Mixture of Experts) workloads. +""" + +import os +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +import torch +from cuda.bindings import driver as cuda +from cutlass.cute.runtime import make_fake_stream + +from cudnn.api_base import APIBase, TensorDesc, TupleDict, ceil_div, is_power_of_2 +from cudnn.datatypes import _convert_to_cutlass_data_type + +from .moe_blockscaled_grouped_gemm_srelu_quant import ( + BlockScaledMoEGroupedGemmQuantKernel, + EpilogueType, +) +from ..moe_utils import MoEWeightMode +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +from cutlass.cute.runtime import from_dlpack + + +def _reinterpret_raw_grouped_fp4_tensor(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.uint8: + cute_tensor = from_dlpack(tensor, assumed_align=16, enable_tvm_ffi=True).mark_layout_dynamic(leading_dim=1) + cute_tensor.element_type = cutlass.Float4E2M1FN + return cute_tensor + return tensor + + +class GroupedGemmSreluSm100(APIBase): + """Unified API for grouped GEMM SReLU operation on SM100+ GPUs. + + This kernel performs block-scaled grouped GEMM with output SReLU output quantization + (D = srelu(alpha * A @ B)), designed for MoE workloads. It supports both + dense (contiguous) and discrete (per-expert pointer) weight layouts + through ``BlockScaledMoEGroupedGemmQuantKernel``. + + Weight mode is auto-detected from the constructor arguments: + + - Dense: provide ``sample_b`` and ``sample_sfb``. + - Discrete: provide ``num_experts``, ``b_shape``, and ``b_dtype``. + """ + + def __init__( + self, + sample_a: torch.Tensor, + # Dense mode (contiguous) -- provide these: + sample_b: Optional[torch.Tensor] = None, + sample_c: Optional[torch.Tensor] = None, + sample_d: Optional[torch.Tensor] = None, + sample_sfa: Optional[torch.Tensor] = None, + sample_sfb: Optional[torch.Tensor] = None, + sample_padded_offsets: Optional[torch.Tensor] = None, + sample_alpha: Optional[torch.Tensor] = None, + sample_d_col: Optional[torch.Tensor] = None, + sample_bias: Optional[torch.Tensor] = None, + # Discrete mode -- provide these instead: + num_experts: Optional[int] = None, + b_shape: Optional[Tuple[int, ...]] = None, + b_dtype: Optional[torch.dtype] = None, + # Optional SReLU output quantization output arguments + sample_sfd_row: Optional[torch.Tensor] = None, + sample_sfd_col: Optional[torch.Tensor] = None, + sample_amax: Optional[torch.Tensor] = None, + sample_norm_const: Optional[torch.Tensor] = None, + sample_prob: Optional[torch.Tensor] = None, + # Configuration + acc_dtype: torch.dtype = torch.float32, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + m_aligned: int = 256, + discrete_col_sfd: bool = False, + b_major: str = "k", + use_dynamic_sched: bool = False, + ): + """Initialize the GroupedGemmSreluSm100 API. 
+
+        :param sample_a: Sample A tensor (valid_m, k, 1)
+        :param sample_sfa: Sample scale factor A tensor
+        :param sample_padded_offsets: End offset for each expert after padding, shape (expert_cnt,)
+        :param sample_alpha: Per-group alpha scaling factors
+        :param sample_c: Sample C output tensor (valid_m, n, 1) before SReLU
+        :param sample_d: Sample D output tensor (valid_m, n, 1)
+        :param sample_d_col: Optional column-quantized D tensor. Required only when SFD outputs are generated.
+        :param sample_b: (Dense) Sample B tensor (n, k, l)
+        :param sample_sfb: (Dense) Sample scale factor B tensor
+        :param sample_bias: Optional bias tensor with shape (n, l) or (n, expert_cnt), stride (1, n).
+            Dense mode supports fp16/bfloat16/float32 bias; discrete mode supports fp16/bfloat16 bias.
+        :param num_experts: (Discrete) Number of experts
+        :param b_shape: (Discrete) Shape of a single expert B tensor, e.g. (n, k)
+        :param b_dtype: (Discrete) Data type of B tensors
+        :param sample_sfd_row: Optional row scale factor for D
+        :param sample_sfd_col: Optional column scale factor for D
+        :param sample_amax: Optional amax tensor for SReLU output quantization
+        :param sample_norm_const: Optional normalization constant
+        :param sample_prob: Optional probability tensor for gating
+        :param acc_dtype: Accumulator data type
+        :param mma_tiler_mn: MMA tiler shape (M, N)
+        :param cluster_shape_mn: Cluster shape (M, N)
+        :param sf_vec_size: Scale factor vector size
+        :param vector_f32: Use vectorized f32 operations
+        :param m_aligned: Alignment for group M dimension
+        :param discrete_col_sfd: Enable discrete col-major scale factor tensor
+        :param b_major: Major dimension for B tensor, one of "k" or "n"
+        :param use_dynamic_sched: Enable dynamic tile scheduling for load balancing
+        """
+        super().__init__()
+
+        self._warn_experimental_api()
+        self._logger.debug("Entering __init__")
+
+        # ---- Weight mode auto-detection ----
+        if sample_b is not None and num_experts is None:
+            self.weight_mode = MoEWeightMode.DENSE
+            if sample_sfb is None:
+                raise ValueError("sample_sfb is required when sample_b is provided (dense mode)")
+        elif num_experts is not None and sample_b is None:
+            self.weight_mode = MoEWeightMode.DISCRETE
+            if b_shape is None or b_dtype is None:
+                raise ValueError("b_shape and b_dtype are required in discrete mode")
+        else:
+            raise ValueError("Provide either (sample_b, sample_sfb) for dense mode " "or (num_experts, b_shape, b_dtype) for discrete mode, but not both.")
+
+        self._sample_a_tensor = sample_a
+        self._sample_b_tensor = sample_b
+
+        self.a_desc = self._make_tensor_desc(sample_a, name="sample_a", interpret_uint8_as_fp4x2=False)
+        self.c_desc = self._make_tensor_desc(sample_c, name="sample_c")
+        self.d_desc = self._make_tensor_desc(sample_d, name="sample_d")
+        self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa")
+        self.padded_offsets_desc = self._make_tensor_desc(sample_padded_offsets, name="sample_padded_offsets")
+        self.alpha_desc = self._make_tensor_desc(sample_alpha, name="sample_alpha")
+
+        self._has_d_col = sample_d_col is not None
+        self.d_col_desc = self._make_tensor_desc(sample_d_col, name="sample_d_col")
+        if self.d_col_desc is None:
+            self.d_col_desc = TensorDesc(
+                dtype=self.d_desc.dtype,
+                shape=self.d_desc.shape,
+                stride=self.d_desc.stride,
+                stride_order=self.d_desc.stride_order,
+                device=self.d_desc.device,
+                name="sample_d_col",
+            )
+        self.sfd_row_desc = self._make_tensor_desc(sample_sfd_row, name="sample_sfd_row")
+        self.sfd_col_desc = self._make_tensor_desc(sample_sfd_col, name="sample_sfd_col")
+        self.amax_desc = self._make_tensor_desc(sample_amax, name="sample_amax")
+        self.norm_const_desc = self._unpad_tensor_to_ndim(
+            self._make_tensor_desc(sample_norm_const, name="sample_norm_const"),
+            1,
+            "norm_const",
+        )
+        self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob")
+        self.bias_desc = self._make_tensor_desc(sample_bias, name="sample_bias")
+
+        if self.weight_mode == MoEWeightMode.DENSE:
+            self.b_desc = self._make_tensor_desc(sample_b, name="sample_b", interpret_uint8_as_fp4x2=False)
+            self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb")
+            self.expert_cnt = self.padded_offsets_desc.shape[0]
+        else:
+            self._value_error_if(num_experts <= 0, "num_experts must be > 0")
+            self.expert_cnt = num_experts
+            self.b_shape = b_shape
+            self.b_dtype = b_dtype
+            self.b_major = b_major
+            self._value_error_if(
+                self.padded_offsets_desc.shape[0] != self.expert_cnt,
+                f"padded_offsets length ({self.padded_offsets_desc.shape[0]}) " f"must equal num_experts ({self.expert_cnt})",
+            )
+
+        self.acc_dtype = acc_dtype
+        self.mma_tiler_mn = mma_tiler_mn
+        self.use_2cta_instrs = mma_tiler_mn[0] == 256
+        if cluster_shape_mn is None:
+            self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1)
+        else:
+            self.cluster_shape_mn = cluster_shape_mn
+        self.sf_vec_size = sf_vec_size
+        self.vector_f32 = vector_f32
+        self.m_aligned = m_aligned
+        self.discrete_col_sfd = discrete_col_sfd
+        self.use_dynamic_sched = use_dynamic_sched
+        if self.weight_mode == MoEWeightMode.DENSE:
+            self.b_major = b_major
+
+        self._interpret_uint8_as_fp4x2 = True
+        self._has_bias = self.bias_desc is not None
+        self._kernel = BlockScaledMoEGroupedGemmQuantKernel
+
+        self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
+        self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+        self._workspace = None
+        self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
+        self._logger.debug("__init__ completed")
+
+    def check_support(self) -> bool:
+        """Check if the kernel configuration is supported.
+ + :return: True if supported, raises exception otherwise + """ + self._logger.debug("Entering check_support") + + all_none = all(x is None for x in [self.sfd_row_desc, self.sfd_col_desc, self.norm_const_desc]) + all_provided = all(x is not None for x in [self.sfd_row_desc, self.sfd_col_desc, self.norm_const_desc]) + self._value_error_if( + not (all_none or all_provided), + "sfd_row_desc, sfd_col_desc, and norm_const_desc must be all None or all not None", + ) + self.generate_sfd = all_provided + self._value_error_if( + self.generate_sfd and not self._has_d_col, + "sample_d_col is required when SFD outputs are generated", + ) + if self.discrete_col_sfd and not self.generate_sfd: + self._logger.warning("discrete_col_sfd is True but generate_sfd is False, discrete_col_sfd will be ignored") + self.discrete_col_sfd = False + + self._logger.debug("Checking tensor shapes and strides") + tensor_m, k, _one = self._tensor_shape(self.a_desc, name="sample_a") + + if self.weight_mode == MoEWeightMode.DENSE: + n, _, l = self._tensor_shape(self.b_desc, name="sample_b") + else: + if len(self.b_shape) == 2: + n, b_k = self.b_shape + else: + n, b_k, _ = self.b_shape + self._value_error_if(b_k != k, f"B K dimension ({b_k}) must match A K dimension ({k})") + l = self.expert_cnt + + _, _, _one = self._tensor_shape(self.d_desc, name="sample_d") + + self._check_tensor_shape(self.a_desc, (tensor_m, k, 1), "A") + if self.weight_mode == MoEWeightMode.DENSE: + self._check_tensor_shape(self.b_desc, (n, k, l), "B") + self._check_tensor_shape(self.c_desc, (tensor_m, n, 1), "C") + self._check_tensor_shape(self.d_desc, (tensor_m, n, 1), "D") + self._check_tensor_shape(self.d_col_desc, (tensor_m, n, 1), "D_col") + + rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4) + self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(tensor_m, 128), 4, rest_k, 1), "SFA") + if self.weight_mode == MoEWeightMode.DENSE: + self._check_tensor_shape(self.sfb_desc, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB") + rest_n = ceil_div(ceil_div(n, self.sf_vec_size), 4) + self._check_tensor_shape(self.sfd_row_desc, (32, 4, ceil_div(tensor_m, 128), 4, rest_n, 1), "SFD_row") + rest_m = ceil_div(ceil_div(tensor_m, self.sf_vec_size), 4) + self._check_tensor_shape(self.sfd_col_desc, (32, 4, ceil_div(n, 128), 4, rest_m, 1), "SFD_col") + + self._check_tensor_shape(self.alpha_desc, (self.expert_cnt,), "alpha") + self._value_error_if( + self.prob_desc is None, + "prob_tensor is required: the kernel unconditionally multiplies output by per-row gating probability. 
" + "Pass a tensor of ones with shape (valid_m, 1, 1) if no gating is needed.", + ) + self._check_tensor_shape(self.prob_desc, (tensor_m, 1, 1), "prob") + self._check_tensor_shape(self.bias_desc, (n, l), "bias") + self._check_tensor_shape(self.amax_desc, (self.expert_cnt, 1), "amax") + self._check_tensor_shape(self.norm_const_desc, (1,), "norm_const") + self._check_tensor_shape(self.padded_offsets_desc, (self.expert_cnt,), "padded_offsets") + + _ = self._check_tensor_stride( + self.a_desc, + stride=[(k, 1, tensor_m * k)], + extra_error_msg="A must have k-major layout", + ) + if self.weight_mode == MoEWeightMode.DENSE: + if self._is_fp8(self.a_desc): + _ = self._check_tensor_stride( + self.b_desc, + stride=[(k, 1, n * k), (1, n, n * k)], + extra_error_msg="For fp8 ab_dtype, B must have k- or n-major layout", + ) + else: + _ = self._check_tensor_stride( + self.b_desc, + stride=[(k, 1, n * k)], + extra_error_msg="For fp4 ab_dtype, B must have k-major layout", + ) + _ = self._check_tensor_stride( + self.c_desc, + stride=[(n, 1, tensor_m * n)], + extra_error_msg="C must have n-major layout", + ) + _ = self._check_tensor_stride( + self.d_desc, + stride=[(n, 1, tensor_m * n)], + extra_error_msg="D must have n-major layout", + ) + _ = self._check_tensor_stride( + self.d_col_desc, + stride=[(n, 1, tensor_m * n)], + extra_error_msg="D_col must have n-major layout", + ) + _ = self._check_tensor_stride( + self.bias_desc, + stride=[(1, n)], + ) + + self._logger.debug("Checking data types") + self.ab_dtype = self._check_dtype( + self.a_desc, + dtype=[ + torch.float4_e2m1fn_x2, + torch.uint8, + torch.float8_e5m2, + torch.float8_e4m3fn, + ], + name="A/B", + ) + if self.weight_mode == MoEWeightMode.DENSE: + self._check_dtype( + self.b_desc, + dtype=self.ab_dtype, + name="B", + extra_error_msg="B must have the same dtype as A", + ) + self._check_dtype( + self.bias_desc, + dtype=[torch.bfloat16, torch.float16, torch.float32], + name="bias", + extra_error_msg="bias must be fp16, bfloat16, or float32", + ) + else: + self._value_error_if( + self.b_dtype != self.ab_dtype, + f"b_dtype ({self.b_dtype}) must match A dtype ({self.ab_dtype})", + ) + self._check_dtype( + self.bias_desc, + dtype=[torch.bfloat16, torch.float16], + name="bias", + extra_error_msg="bias must be fp16 or bfloat16 in discrete mode", + ) + + self.sf_dtype = self._check_dtype( + self.sfa_desc, + dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn], + name="SFA/SFB/SFD_row/SFD_col", + ) + if self.weight_mode == MoEWeightMode.DENSE: + self._check_dtype( + self.sfb_desc, + dtype=self.sf_dtype, + name="SFB", + extra_error_msg="SFB must have the same dtype as SFA", + ) + self._check_dtype( + self.sfd_row_desc, + dtype=self.sf_dtype, + name="SFD_row", + extra_error_msg="SFD_row must have the same dtype as SFA", + ) + self._check_dtype( + self.sfd_col_desc, + dtype=self.sf_dtype, + name="SFD_col", + extra_error_msg="SFD_col must have the same dtype as SFA", + ) + + self._value_error_if( + self.sf_vec_size not in [16, 32], + f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}", + ) + self._value_error_if( + self.sf_dtype in [torch.float8_e4m3fn] and self.sf_vec_size == 32, + f"sf_dtype {self.sf_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported", + ) + self._value_error_if( + self._is_fp8(self.ab_dtype) and self.sf_vec_size == 16, + f"ab_dtype {self.ab_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported", + ) + + self._check_dtype( + self.acc_dtype, + dtype=torch.float32, + name="Accumulator", + 
extra_error_msg="Accumulator must be float32", + ) + self.c_dtype = self._check_dtype( + self.c_desc, + dtype=[torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2], + name="C", + ) + if self._is_fp4x2(self.ab_dtype): + self.d_dtype = self._check_dtype( + self.d_desc, + dtype=[torch.float16, torch.bfloat16, torch.float32], + name="D", + extra_error_msg="D must be fp16, bf16, or float32 when ab_dtype is fp4", + ) + else: + self.d_dtype = self._check_dtype( + self.d_desc, + dtype=[ + torch.float16, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float4_e2m1fn_x2, + ], + name="D", + ) + self._check_dtype( + self.d_col_desc, + dtype=self.d_dtype, + name="D_col", + extra_error_msg="D_col must have the same dtype as D", + ) + + self._not_implemented_error_if( + self._is_fp4x2(self.ab_dtype) and self.sf_vec_size == 16 and self.d_dtype == torch.float32, + "Invalid configuration: fp4 ab_dtype, sf_vec_size 16, d_dtype float32 is not supported. Please use sf_vec_size 32 or d_dtype bf16 instead", + ) + + if self.weight_mode == MoEWeightMode.DISCRETE: + self._value_error_if( + self.b_major not in ["k", "n"], + f"b_major must be 'k' or 'n', got {self.b_major}", + ) + self._value_error_if( + self._is_fp4x2(self.ab_dtype) and self.b_major != "k", + "b_major must be 'k' when ab_dtype is fp4", + ) + + self._logger.debug("Checking MMA tile shape and cluster shape") + self._value_error_if( + not self.use_2cta_instrs and self.mma_tiler_mn[0] != 128, + f"MMA tiler M must be 128 when use_2cta_instrs=False, got {self.mma_tiler_mn[0]}", + ) + self._value_error_if( + self.use_2cta_instrs and self.mma_tiler_mn[0] != 256, + f"MMA tiler M must be 256 when use_2cta_instrs=True, got {self.mma_tiler_mn[0]}", + ) + self._value_error_if( + self.mma_tiler_mn[1] != 256, + f"MMA tiler N must be 256, got {self.mma_tiler_mn[1]}", + ) + self._value_error_if( + self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0, + f"cluster_shape_mn[0] must be divisible by 2 when use_2cta_instrs=True, got {self.cluster_shape_mn[0]}", + ) + self._value_error_if( + not ( + self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16 + and self.cluster_shape_mn[0] > 0 + and self.cluster_shape_mn[1] > 0 + and self.cluster_shape_mn[0] <= 4 + and self.cluster_shape_mn[1] <= 4 + and is_power_of_2(self.cluster_shape_mn[0]) + and is_power_of_2(self.cluster_shape_mn[1]) + ), + "Invalid cluster shape: expected values to be powers of 2 and " f"cluster_shape_mn[0] * cluster_shape_mn[1] <= 16, got {self.cluster_shape_mn}", + ) + cluster_tiler_m = (self.cluster_shape_mn[0] // (2 if self.use_2cta_instrs else 1)) * self.mma_tiler_mn[0] + self._value_error_if( + cluster_tiler_m not in [128, 256], + f"Invalid cluster tiler shape: expected cluster_tiler_m in {{128, 256}}, got {cluster_tiler_m}", + ) + self._value_error_if( + self.m_aligned % self.mma_tiler_mn[0] != 0, + f"Invalid m_aligned: expected m_aligned to be divisible by mma_tiler_mn[0], got {self.m_aligned} % {self.mma_tiler_mn[0]} != 0", + ) + self._value_error_if( + self.m_aligned != BlockScaledMoEGroupedGemmQuantKernel.FIX_PAD_SIZE, + f"m_aligned must be {BlockScaledMoEGroupedGemmQuantKernel.FIX_PAD_SIZE} (FIX_PAD_SIZE), got {self.m_aligned}", + ) + + self._logger.debug("Checking tensor alignment") + + def check_contigous_16B_alignment(dtype, stride_order, tensor_shape): + is_mode0_major = stride_order == (0, 1, 2) + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 
16 * 8 // (_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2).width) + return num_major_elements % num_contiguous_elements == 0 + + if self.weight_mode == MoEWeightMode.DENSE: + b_stride_order_for_check = self.b_desc.stride_order + b_shape_for_check = (n, k, l) + else: + b_stride_order_for_check = (0, 1, 2) if self.b_major == "n" else (1, 0, 2) + b_shape_for_check = (n, k, 1) + + self._value_error_if( + not ( + check_contigous_16B_alignment(self.ab_dtype, self.a_desc.stride_order, (tensor_m, k, l)) + and check_contigous_16B_alignment(self.ab_dtype, b_stride_order_for_check, b_shape_for_check) + and check_contigous_16B_alignment(self.d_dtype, self.d_desc.stride_order, (tensor_m, n, 1)) + ), + "Invalid tensor alignment: tensors must be 16B aligned", + ) + + self._value_error_if( + self.expert_cnt > 1024, + f"expert_cnt must be <= 1024, got {self.expert_cnt}", + ) + + self._not_implemented_error_if(self._has_bias and self.mma_tiler_mn[1] != 256, "Bias fusion currently requires mma_tiler_mn[1] == 256") + + self._not_implemented_error_if( + (self._is_fp8(self.ab_dtype)) and (self.mma_tiler_mn[1] == 128) and (self._is_fp8(self.d_dtype)), + "Invalid configuration: fp8 ab_dtype and sf_vec_size 32 with mma_tiler_mn[1] == 128 and fp8 d_dtype is not supported. " + "Please use mma_tiler_mn[1] == 256 instead", + ) + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + major, minor = torch.cuda.get_device_capability(device) + compute_capability = major * 10 + minor + if compute_capability < 100: + raise RuntimeError(f"GroupedGemmSrelu requires SM100+ compute capability, but found SM{compute_capability} on device {device}") + + self._is_supported = True + self._logger.debug("check_support completed successfully") + return True + + def compile(self) -> None: + """Compile the kernel.""" + self._logger.debug("Entering compile") + self._ensure_support_checked() + if self._compiled_kernel is not None: + self._logger.debug("Kernel already compiled; skipping recompilation") + return + if self.a_desc.shape[0] == 0: + self._logger.debug("sample valid_m is zero, skipping kernel compilation") + return + + self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" + + gemm_srelu = self._kernel( + sf_vec_size=self.sf_vec_size, + acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype), + use_2cta_instrs=self.use_2cta_instrs, + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + vectorized_f32=self.vector_f32, + generate_sfd=self.generate_sfd, + discrete_col_sfd=self.discrete_col_sfd, + generate_c=True, + enable_bias=self._has_bias, + expert_cnt=self.expert_cnt, + weight_mode=self.weight_mode, + use_dynamic_sched=self.use_dynamic_sched, + epilogue_type=EpilogueType.SRELU.value, + ) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) + max_active_clusters -= self.num_cluster_overlap_margin + self._value_error_if( + max_active_clusters <= 0, + "max_active_clusters must be > 0 after applying overlap margin; reduce CUDNNFE_CLUSTER_OVERLAP_MARGIN", + ) + fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False) + + workspace_bytes = gemm_srelu.get_workspace_bytes() + self._workspace = torch.empty(max(workspace_bytes, 1), dtype=torch.uint8, device="cuda") + + if self.weight_mode == MoEWeightMode.DENSE: + self._compile_dense(gemm_srelu, 
max_active_clusters, fake_stream) + else: + self._compile_discrete(gemm_srelu, max_active_clusters, fake_stream) + + self._logger.debug("Kernel compiled successfully") + + def _compile_dense(self, gemm_srelu, max_active_clusters, fake_stream) -> None: + """Compile for dense (contiguous) weight mode.""" + fake_workspace_ptr = cute.runtime.nullptr( + dtype=cutlass.Uint8, + assumed_align=128, + ) + + self._logger.debug("Compiling grouped_gemm_srelu kernel") + use_full_dynamic = self._use_full_dynamic_mnkl + + if not use_full_dynamic: + valid_m = cute.sym_int(divisibility=256) + + a_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, *self.a_desc.shape[1:]), + stride_order=self.a_desc.stride_order, + ) + b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16) + c_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, *self.c_desc.shape[1:]), + stride_order=self.c_desc.stride_order, + ) + d_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_desc.dtype, + shape=(valid_m, *self.d_desc.shape[1:]), + stride_order=self.d_desc.stride_order, + ) + d_col_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_col_desc.dtype, + shape=(valid_m, *self.d_col_desc.shape[1:]), + stride_order=self.d_col_desc.stride_order, + ) + + tensor_m_128 = cute.sym_int() + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_shape = list(self.sfa_desc.shape) + sfa_shape[2] = tensor_m_128 + sfa_stride = list(self.sfa_desc.stride) + sfa_stride[5] = stride_tensor_m_128 + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=tuple(sfa_shape), + stride=tuple(sfa_stride), + ) + + sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16) + + prob_cute_fake = None + if self.prob_desc is not None: + prob_cute_fake = self._make_fake_cute_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, *self.prob_desc.shape[1:]), + stride=self.prob_desc.stride, + ) + + sfd_row_fake = None + sfd_col_fake = None + if self.sfd_row_desc is not None: + stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_row_fake = self._make_fake_cute_tensor( + dtype=self.sfd_row_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfd_row_desc.shape[4], 1), + stride=(16, 4, self.sfd_row_desc.stride[2], 1, 512, stride_sfd_m), + ) + if self.sfd_col_desc is not None: + rest_m = cute.sym_int(divisibility=1) + stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4) + stride_rest_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_col_fake = self._make_fake_cute_tensor( + dtype=self.sfd_col_desc.dtype, + shape=(32, 4, self.sfd_col_desc.shape[2], 4, rest_m, 1), + stride=(16, 4, stride_rest_m, 1, 512, stride_sfd_n), + ) + bias_cute_fake = self._make_fake_cute_tensor_from_desc(self.bias_desc, assumed_align=16) + else: + valid_m = cute.sym_int(divisibility=256) + n_sym = cute.sym_int() + k_sym = cute.sym_int() + l_sym = cute.sym_int() + + a_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, k_sym, 1), + stride_order=self.a_desc.stride_order, + dynamic_mode=self.a_desc.stride_order[0], + divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, + ) + b_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.b_desc.dtype, + shape=(n_sym, k_sym, l_sym), + stride_order=self.b_desc.stride_order, + dynamic_mode=self.b_desc.stride_order[0], + divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16, + ) + c_cute_fake = 
self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, n_sym, 1), + stride_order=self.c_desc.stride_order, + dynamic_mode=self.c_desc.stride_order[0], + divisibility=8 if self._is_f16(self.c_desc.dtype) else 16, + ) + d_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_desc.dtype, + shape=(valid_m, n_sym, 1), + stride_order=self.d_desc.stride_order, + dynamic_mode=self.d_desc.stride_order[0], + divisibility=8 if self._is_f16(self.d_desc.dtype) else 16, + ) + d_col_cute_fake = self._make_fake_cute_compact_tensor( + dtype=self.d_col_desc.dtype, + shape=(valid_m, n_sym, 1), + stride_order=self.d_col_desc.stride_order, + dynamic_mode=self.d_col_desc.stride_order[0], + divisibility=8 if self._is_f16(self.d_col_desc.dtype) else 16, + ) + + tensor_m_128 = cute.sym_int() + rest_k = cute.sym_int() + stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4) + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_shape = list(self.sfa_desc.shape) + sfa_shape[2] = tensor_m_128 + sfa_shape[4] = rest_k + sfa_stride = list(self.sfa_desc.stride) + sfa_stride[2] = stride_rest_k + sfa_stride[5] = stride_tensor_m_128 + sfa_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=tuple(sfa_shape), + stride=tuple(sfa_stride), + ) + + tensor_n_128 = cute.sym_int() + stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfb_cute_fake = self._make_fake_cute_tensor( + dtype=self.sfb_desc.dtype, + shape=(32, 4, tensor_n_128, 4, rest_k, l_sym), + stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k), + ) + + prob_cute_fake = None + if self.prob_desc is not None: + prob_cute_fake = self._make_fake_cute_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, *self.prob_desc.shape[1:]), + stride=self.prob_desc.stride, + ) + + sfd_row_fake = None + sfd_col_fake = None + if self.sfd_row_desc is not None: + rest_n = cute.sym_int() + stride_sfd_rest_n = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfd_rest_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfd_row_fake = self._make_fake_cute_tensor( + dtype=self.sfd_row_desc.dtype, + shape=(32, 4, tensor_m_128, 4, rest_n, 1), + stride=(16, 4, stride_sfd_rest_n, 1, 512, stride_sfd_rest_tensor_m_128), + ) + if self.sfd_col_desc is not None: + tensor_n_128 = cute.sym_int() + rest_m_dyn = cute.sym_int() + stride_sfd_rest_m = cute.sym_int(divisibility=32 * 4 * 4) + stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4) + sfd_col_fake = self._make_fake_cute_tensor( + dtype=self.sfd_col_desc.dtype, + shape=(32, 4, tensor_n_128, 4, rest_m_dyn, 1), + stride=(16, 4, stride_sfd_rest_m, 1, 512, stride_sfd_n), + ) + + bias_cute_fake = None + if self.bias_desc is not None: + bias_cute_fake = self._make_fake_cute_tensor( + dtype=self.bias_desc.dtype, + shape=(n_sym, l_sym), + stride=(1, n_sym), + ) + + _compiled_kernel = cute.compile( + gemm_srelu, + a=_reinterpret_raw_grouped_fp4_tensor(self._sample_a_tensor) if self.a_desc.dtype == torch.uint8 else a_cute_fake, + b=_reinterpret_raw_grouped_fp4_tensor(self._sample_b_tensor) if self.b_desc.dtype == torch.uint8 else b_cute_fake, + sfb=sfb_cute_fake, + n=cutlass.Int32(0), + k=cutlass.Int32(0), + b_stride_size=cutlass.Int64(0), + b_major_mode=OperandMajorMode.K, + workspace_ptr=fake_workspace_ptr, + c=c_cute_fake, + d=d_cute_fake, + d_col=d_col_cute_fake, + sfa=sfa_cute_fake, + sfd_row_tensor=sfd_row_fake, + sfd_col_tensor=sfd_col_fake, + 
amax_tensor=self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16), + norm_const_tensor=self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16), + padded_offsets=self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16), + alpha=self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16), + bias=bias_cute_fake, + prob=prob_cute_fake, + max_active_clusters=max_active_clusters, + stream=fake_stream, + options="--enable-tvm-ffi", + ) + + cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator + + def tensor_api( + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_tensor: torch.Tensor, + d_col_tensor: Optional[torch.Tensor], + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + sfd_row_tensor: Optional[torch.Tensor], + sfd_col_tensor: Optional[torch.Tensor], + amax_tensor: Optional[torch.Tensor], + norm_const_tensor: Optional[torch.Tensor], + padded_offsets: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: Optional[torch.Tensor], + bias_tensor: Optional[torch.Tensor], + stream: cuda.CUstream, + ) -> None: + norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const") + _compiled_kernel( + _reinterpret_raw_grouped_fp4_tensor(a_tensor), + _reinterpret_raw_grouped_fp4_tensor(b_tensor), + sfb_tensor, + cutlass.Int32(0), + cutlass.Int32(0), + cutlass.Int64(0), + cached_workspace_ptr, + c_tensor, + d_tensor, + d_col_tensor, + sfa_tensor, + sfd_row_tensor, + sfd_col_tensor, + amax_tensor, + norm_const_tensor, + padded_offsets, + alpha_tensor, + bias_tensor, + prob_tensor, + stream, + ) + + self._compiled_kernel = tensor_api + + def _compile_discrete(self, gemm_srelu, max_active_clusters, fake_stream) -> None: + """Compile for discrete (per-expert pointer) weight mode.""" + if len(self.b_shape) == 2: + n, k = self.b_shape + else: + n, k, _ = self.b_shape + + b_major_mode = OperandMajorMode.K if self.b_major == "k" else OperandMajorMode.MN + b_stride_size = k if self.b_major == "k" else n + + ab_cutlass_dtype = _convert_to_cutlass_data_type(self.a_desc.dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2) + align = 32 if ab_cutlass_dtype.width == 4 else 16 + + valid_m = cute.sym_int(divisibility=256) + a_tensor = self._make_fake_cute_compact_tensor( + dtype=self.a_desc.dtype, + shape=(valid_m, *self.a_desc.shape[1:]), + stride_order=self.a_desc.stride_order, + assumed_align=align, + ) + c_tensor = self._make_fake_cute_compact_tensor( + dtype=self.c_desc.dtype, + shape=(valid_m, *self.c_desc.shape[1:]), + stride_order=self.c_desc.stride_order, + ) + d_tensor = self._make_fake_cute_compact_tensor( + dtype=self.d_desc.dtype, + shape=(valid_m, *self.d_desc.shape[1:]), + stride_order=self.d_desc.stride_order, + ) + d_col_tensor = self._make_fake_cute_compact_tensor( + dtype=self.d_col_desc.dtype, + shape=(valid_m, *self.d_col_desc.shape[1:]), + stride_order=self.d_col_desc.stride_order, + ) + + tensor_m_128 = cute.sym_int() + stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4) + sfa_shape = list(self.sfa_desc.shape) + sfa_shape[2] = tensor_m_128 + sfa_stride = list(self.sfa_desc.stride) + sfa_stride[5] = stride_tensor_m_128 + sfa_tensor = self._make_fake_cute_tensor( + dtype=self.sfa_desc.dtype, + shape=tuple(sfa_shape), + stride=tuple(sfa_stride), + assumed_align=16, + ) + sfd_row_tensor = None + if self.sfd_row_desc is not None: + stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_row_tensor = 
self._make_fake_cute_tensor( + dtype=self.sfd_row_desc.dtype, + shape=(32, 4, tensor_m_128, 4, self.sfd_row_desc.shape[4], 1), + stride=(16, 4, self.sfd_row_desc.stride[2], 1, 512, stride_sfd_m), + assumed_align=16, + ) + sfd_col_tensor = None + if self.sfd_col_desc is not None: + rest_m = cute.sym_int(divisibility=1) + stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4) + stride_rest_m = cute.sym_int(divisibility=32 * 4 * 4) + sfd_col_tensor = self._make_fake_cute_tensor( + dtype=self.sfd_col_desc.dtype, + shape=(32, 4, self.sfd_col_desc.shape[2], 4, rest_m, 1), + stride=(16, 4, stride_rest_m, 1, 512, stride_sfd_n), + assumed_align=16, + ) + amax_tensor = self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16) + norm_const_tensor_cute = self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16) + padded_offsets_tensor = self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16) + alpha_tensor = self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16) + prob_tensor = self._make_fake_cute_tensor( + dtype=self.prob_desc.dtype, + shape=(valid_m, *self.prob_desc.shape[1:]), + stride=self.prob_desc.stride, + assumed_align=16, + ) + bias_cute_fake = self._make_fake_cute_tensor_from_desc(self.bias_desc, assumed_align=16) + + b_ptrs_placeholder = torch.empty((self.expert_cnt,), dtype=torch.int64, device="cuda") + sfb_ptrs_placeholder = torch.empty((self.expert_cnt,), dtype=torch.int64, device="cuda") + b_ptrs_cute = from_dlpack(b_ptrs_placeholder, assumed_align=8).iterator + sfb_ptrs_cute = from_dlpack(sfb_ptrs_placeholder, assumed_align=8).iterator + workspace_ptr_cute = from_dlpack(self._workspace, assumed_align=128).iterator + + self._logger.debug("Compiling discrete grouped_gemm_srelu kernel") + _compiled_kernel = cute.compile( + gemm_srelu, + a=a_tensor, + b=b_ptrs_cute, + sfb=sfb_ptrs_cute, + n=cutlass.Int32(n), + k=cutlass.Int32(k), + b_stride_size=cutlass.Int64(b_stride_size), + b_major_mode=b_major_mode, + workspace_ptr=workspace_ptr_cute, + c=c_tensor, + d=d_tensor, + d_col=d_col_tensor, + sfa=sfa_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor_cute, + padded_offsets=padded_offsets_tensor, + alpha=alpha_tensor, + bias=bias_cute_fake, + prob=prob_tensor, + max_active_clusters=max_active_clusters, + stream=fake_stream, + epilogue_op=lambda x: x, + options="--enable-tvm-ffi", + ) + + cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator + cached_n = cutlass.Int32(n) + cached_k = cutlass.Int32(k) + cached_b_stride = cutlass.Int64(b_stride_size) + + def tensor_api( + a_tensor: torch.Tensor, + b_ptrs_device: torch.Tensor, + sfb_ptrs_device: torch.Tensor, + c_tensor: torch.Tensor, + d_tensor: torch.Tensor, + d_col_tensor: Optional[torch.Tensor], + sfa_tensor: torch.Tensor, + sfd_row_tensor: Optional[torch.Tensor], + sfd_col_tensor: Optional[torch.Tensor], + amax_tensor: Optional[torch.Tensor], + norm_const_tensor: Optional[torch.Tensor], + padded_offsets: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: Optional[torch.Tensor], + bias_tensor: Optional[torch.Tensor], + stream: cuda.CUstream, + ) -> None: + norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const") + b_ptrs_addr = int(b_ptrs_device.data_ptr()) + sfb_ptrs_addr = int(sfb_ptrs_device.data_ptr()) + _compiled_kernel( + a_tensor, + b_ptrs_addr, + sfb_ptrs_addr, + cached_n, + cached_k, + cached_b_stride, + 
cached_workspace_ptr,
+                c_tensor,
+                d_tensor,
+                d_col_tensor,
+                sfa_tensor,
+                sfd_row_tensor,
+                sfd_col_tensor,
+                amax_tensor,
+                norm_const_tensor,
+                padded_offsets,
+                alpha_tensor,
+                bias_tensor,
+                prob_tensor,
+                stream,
+            )
+
+        self._compiled_kernel = tensor_api
+
+    def execute(
+        self,
+        a_tensor: torch.Tensor,
+        sfa_tensor: torch.Tensor,
+        padded_offsets: torch.Tensor,
+        alpha_tensor: torch.Tensor,
+        c_tensor: torch.Tensor,
+        d_tensor: torch.Tensor,
+        # Dense mode:
+        b_tensor: Optional[torch.Tensor] = None,
+        sfb_tensor: Optional[torch.Tensor] = None,
+        bias_tensor: Optional[torch.Tensor] = None,
+        # Discrete mode:
+        b_ptrs: Optional[torch.Tensor] = None,
+        sfb_ptrs: Optional[torch.Tensor] = None,
+        d_col_tensor: Optional[torch.Tensor] = None,
+        sfd_row_tensor: Optional[torch.Tensor] = None,
+        sfd_col_tensor: Optional[torch.Tensor] = None,
+        amax_tensor: Optional[torch.Tensor] = None,
+        norm_const_tensor: Optional[torch.Tensor] = None,
+        prob_tensor: Optional[torch.Tensor] = None,
+        current_stream: Optional[cuda.CUstream] = None,
+    ) -> None:
+        """Execute the compiled kernel.
+
+        :param a_tensor: Input A tensor
+        :param sfa_tensor: Scale factor A
+        :param padded_offsets: End offset per expert after padding
+        :param alpha_tensor: Per-group scaling factors
+        :param c_tensor: Output C tensor before SReLU
+        :param d_tensor: Output D tensor
+        :param b_tensor: (Dense) Input B tensor (weights)
+        :param sfb_tensor: (Dense) Scale factor B
+        :param bias_tensor: Optional bias tensor with shape (n, l) and stride (1, n).
+            Bias fusion is specialized at compile time: if ``sample_bias`` was omitted
+            at construction, ``bias_tensor`` must also be omitted at execute time.
+        :param b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers
+        :param sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers
+        :param d_col_tensor: Optional column-quantized output
+        :param sfd_row_tensor: Optional row scale factor D
+        :param sfd_col_tensor: Optional column scale factor D
+        :param amax_tensor: Optional amax tensor
+        :param norm_const_tensor: Optional normalization constant
+        :param prob_tensor: Probability tensor for per-row gating. Required.
+        :param current_stream: CUDA stream
+        """
+        self._logger.debug("Entering execute")
+        current_stream = self._get_default_stream(current_stream)
+
+        if a_tensor.shape[0] == 0:
+            self._logger.debug("execute: valid_m is zero, skipping kernel execution")
+            return
+        self._runtime_error_if(
+            self._compiled_kernel is None,
+            "Kernel not compiled; call compile() first",
+        )
+
+        if d_col_tensor is None:
+            self._value_error_if(
+                self.generate_sfd,
+                "d_col_tensor is required when SFD outputs are generated",
+            )
+            d_col_tensor = d_tensor
+        self._value_error_if(
+            prob_tensor is None,
+            "prob_tensor is required: the kernel unconditionally multiplies output by per-row gating probability. 
" + "Pass a tensor of ones with shape (valid_m, 1, 1) if no gating is needed.", + ) + if self._has_bias: + self._value_error_if( + bias_tensor is None, + "bias_tensor must be provided at execute() when the API was compiled with sample_bias", + ) + else: + self._value_error_if( + bias_tensor is not None, + "bias_tensor must be omitted at execute() when the API was compiled without sample_bias", + ) + + self._logger.debug("Executing grouped_gemm_srelu kernel") + if self.weight_mode == MoEWeightMode.DENSE: + self._compiled_kernel( + a_tensor=a_tensor, + b_tensor=b_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + d_col_tensor=d_col_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + prob_tensor=prob_tensor, + bias_tensor=bias_tensor, + stream=current_stream, + ) + else: + self._compiled_kernel( + a_tensor=a_tensor, + b_ptrs_device=b_ptrs, + sfb_ptrs_device=sfb_ptrs, + c_tensor=c_tensor, + d_tensor=d_tensor, + d_col_tensor=d_col_tensor, + sfa_tensor=sfa_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + prob_tensor=prob_tensor, + bias_tensor=bias_tensor, + stream=current_stream, + ) + + self._logger.debug("Execute completed") + + +import logging + +_logger = logging.getLogger(__name__) +_cache_of_GroupedGemmSreluSm100Objects = {} + + +def grouped_gemm_srelu_wrapper_sm100( + a_tensor: torch.Tensor, + b_tensor: Optional[torch.Tensor] = None, + sfa_tensor: Optional[torch.Tensor] = None, + sfb_tensor: Optional[torch.Tensor] = None, + padded_offsets: Optional[torch.Tensor] = None, + alpha_tensor: Optional[torch.Tensor] = None, + bias_tensor: Optional[torch.Tensor] = None, + b_ptrs: Optional[torch.Tensor] = None, + sfb_ptrs: Optional[torch.Tensor] = None, + n: Optional[int] = None, + b_dtype: Optional[torch.dtype] = None, + b_major: str = "k", + norm_const_tensor: Optional[torch.Tensor] = None, + prob_tensor: Optional[torch.Tensor] = None, + acc_dtype: torch.dtype = torch.float32, + c_dtype: torch.dtype = torch.bfloat16, + d_dtype: torch.dtype = torch.bfloat16, + cd_major: str = "n", + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + m_aligned: int = 256, + discrete_col_sfd: bool = False, + use_dynamic_sched: bool = False, + current_stream: Optional[cuda.CUstream] = None, +) -> TupleDict: + """Convenience wrapper for grouped GEMM SReLU operation. + + This function creates the API, compiles, and executes in one call. + Compiled kernels are cached for reuse when called with the same configuration. + + Args: + a_tensor: Input A tensor (valid_m, k, 1) + sfa_tensor: Scale factor A + padded_offsets: End offset per expert after padding (l,) + alpha_tensor: Per-group scaling + b_tensor: (Dense) Weight B tensor (n, k, l) + sfb_tensor: (Dense) Scale factor B + bias_tensor: Optional per-expert bias, shape ``(n, l)`` in dense mode or ``(n, num_experts)`` + in discrete mode, stride ``(1, n)``. Bias fusion requires ``mma_tiler_mn[1] == 256``. 
+        b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers
+        sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers
+        n: (Discrete) B weight N dimension
+        b_dtype: (Discrete) B weight data type
+        b_major: (Discrete) B tensor major dimension ("k" or "n")
+        norm_const_tensor: Optional normalization constant. Required when FP8 inputs
+            (a_tensor.dtype is FP8 and sfa_tensor.dtype is FP8) are combined with a
+            low-precision (FP8/FP4) d_dtype; ignored and reset to None otherwise.
+        prob_tensor: Probability tensor for per-row gating (shape `(valid_m, 1, 1)`).
+            This argument is required. Pass a tensor of ones when no gating is needed.
+        acc_dtype: Accumulator data type
+        c_dtype: Output C tensor data type
+        d_dtype: Output D tensor data type
+        cd_major: CD major dimension (only "n"-major layout is supported)
+        mma_tiler_mn: MMA tiler shape
+        cluster_shape_mn: Cluster shape
+        sf_vec_size: Scale factor vector size
+        vector_f32: Use vectorized f32
+        m_aligned: M alignment (must be 256)
+        discrete_col_sfd: Enable discrete col-major scale factor tensor
+        use_dynamic_sched: Enable dynamic tile scheduling for load balancing
+        current_stream: CUDA stream
+
+    Returns:
+        TupleDict: A dictionary-like object containing output tensors that can also be unpacked as a tuple.
+            Dictionary keys (also the unpacking order):
+            - **c_tensor** (torch.Tensor): Accumulator output tensor
+            - **d_tensor** (torch.Tensor): Final output tensor
+            - **d_col_tensor** (torch.Tensor or None): Column-wise output tensor for low-precision D output
+            - **amax_tensor** (torch.Tensor or None): Absolute maximum values (for SReLU output quantization)
+            - **sfd_row_tensor** (torch.Tensor or None): Row-wise scale factors for D (FP8 only)
+            - **sfd_col_tensor** (torch.Tensor or None): Column-wise scale factors for D (FP8 only)
+
+    Example usage::
+
+        # Dictionary-style access
+        result = grouped_gemm_srelu_wrapper_sm100(...)
+        c = result["c_tensor"]
+        d = result["d_tensor"]
+
+        # Tuple unpacking
+        c, d, d_col, amax, sfd_row, sfd_col = grouped_gemm_srelu_wrapper_sm100(...)
+ + # Integer indexing + c = result[0] # c_tensor + d = result[1] # d_tensor + """ + from cudnn.discrete_grouped_gemm.discrete_kernel_utils import _require_pointer_tensor + + is_dense = b_tensor is not None + is_discrete = b_ptrs is not None + + if is_dense and is_discrete: + raise ValueError("Provide either (b_tensor, sfb_tensor) or (b_ptrs, sfb_ptrs), not both") + if not is_dense and not is_discrete: + raise ValueError("Must provide either (b_tensor, sfb_tensor) or (b_ptrs, sfb_ptrs)") + + valid_m, k_physical, _ = a_tensor.shape + if is_dense: + weight_mode = MoEWeightMode.DENSE + n_out, _, l = b_tensor.shape + if bias_tensor is not None and tuple(bias_tensor.shape) != (n_out, l): + raise ValueError(f"bias_tensor must have shape {(n_out, l)}, got {tuple(bias_tensor.shape)}") + else: + weight_mode = MoEWeightMode.DISCRETE + _require_pointer_tensor(b_ptrs, "b_ptrs") + num_experts = b_ptrs.shape[0] + _require_pointer_tensor(sfb_ptrs, "sfb_ptrs", num_experts) + if n is None or b_dtype is None: + raise ValueError("n and b_dtype are required for discrete mode") + k_logical = k_physical * 2 if b_dtype in (torch.float4_e2m1fn_x2, torch.uint8) else k_physical + b_shape = (n, k_logical) + n_out = n + l = num_experts + if bias_tensor is not None and tuple(bias_tensor.shape) != (n_out, num_experts): + raise ValueError(f"bias_tensor must have shape {(n_out, num_experts)}, got {tuple(bias_tensor.shape)}") + + is_fp8_input_config = a_tensor.dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ] and sfa_tensor.dtype in [ + torch.float8_e8m0fnu, + torch.float8_e4m3fn, + ] + is_low_precision_output_config = d_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float4_e2m1fn_x2, + ] + + _logger.debug("grouped_gemm_srelu_wrapper_sm100: Creating output tensors") + + if cd_major == "n": + c_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=c_dtype, device=a_tensor.device) + d_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) + d_col_tensor = ( + torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) + if is_low_precision_output_config + else None + ) + else: + raise ValueError(f"cd_major must be 'n', got {cd_major}") + + sfd_row_tensor = None + sfd_col_tensor = None + amax_tensor = None + + if is_fp8_input_config and is_low_precision_output_config and norm_const_tensor is None: + raise ValueError( + "norm_const_tensor is required when FP8 inputs are used with FP8 output " + "(a_tensor is FP8 and sfa_tensor is FP8 and d_dtype is FP8). " + "Pass a tensor with shape (1,), e.g. torch.tensor([0.01], dtype=torch.float32, device=a_tensor.device)." 
+ ) + + if not is_low_precision_output_config: + norm_const_tensor = None + + if is_fp8_input_config and is_low_precision_output_config: + _logger.debug("grouped_gemm_srelu_wrapper_sm100: Detected fp8 a_dtype and sfa_dtype, constructing sfd_row_tensor and sfd_col_tensor") + + sf_dtype = sfa_tensor.dtype + mma_permute_order = (3, 4, 1, 5, 2, 0) + + sf_k_row = ceil_div(n_out, sf_vec_size) + mma_shape_row = ( + 1, + ceil_div(valid_m, 128), + ceil_div(sf_k_row, 4), + 32, + 4, + 4, + ) + sfd_row_tensor = torch.empty(mma_shape_row, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) + + sf_k_col = ceil_div(valid_m, sf_vec_size) + mma_shape_col = ( + 1, + ceil_div(n_out, 128), + ceil_div(sf_k_col, 4), + 32, + 4, + 4, + ) + sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) + + if d_dtype in [torch.bfloat16, torch.float16]: + _logger.debug("grouped_gemm_srelu_wrapper_sm100: Detected bf16/float16 d_dtype, constructing amax_tensor") + amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device) + + if prob_tensor is None: + raise ValueError( + "prob_tensor is required: the kernel unconditionally multiplies output by per-row gating probability. " + "Pass a tensor of ones with shape (valid_m, 1, 1) if no gating is needed." + ) + + if valid_m == 0: + _logger.debug("grouped_gemm_srelu_wrapper_sm100: valid_m is zero, skipping kernel execution") + return TupleDict( + c_tensor=c_tensor, + d_tensor=d_tensor, + d_col_tensor=d_col_tensor, + amax_tensor=amax_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + ) + + def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype + + def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: + return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1])) + + def dynamic_tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + return None, stride_order(tensor), tensor.dtype + + def dynamic_m_tensor_signature( + tensor: Optional[torch.Tensor], static_shape_suffix: Tuple[int, ...], dynamic_stride_dims: Tuple[int, ...] 
= () + ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]: + if tensor is None: + return None, None, None + stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride())) + return static_shape_suffix, stride_signature, tensor.dtype + + use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" + + if is_dense: + cache_key = ( + weight_mode, + use_full_dynamic, + a_tensor.shape[1:] if not use_full_dynamic else None, + b_tensor.shape[2] if use_full_dynamic else tuple(b_tensor.shape), + a_tensor.dtype, + b_tensor.dtype, + stride_order(a_tensor), + stride_order(b_tensor), + c_tensor.shape[1:] if not use_full_dynamic else None, + stride_order(c_tensor), + c_tensor.dtype, + d_tensor.shape[1:] if not use_full_dynamic else None, + stride_order(d_tensor), + *( + dynamic_tensor_signature(sfa_tensor) + if use_full_dynamic + else dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,)) + ), + *(dynamic_tensor_signature(sfb_tensor) if use_full_dynamic else tensor_signature(sfb_tensor)), + *(dynamic_tensor_signature(bias_tensor) if use_full_dynamic else tensor_signature(bias_tensor)), + *tensor_signature(alpha_tensor), + *tensor_signature(norm_const_tensor), + *dynamic_m_tensor_signature(prob_tensor, (1, 1)), + tuple(padded_offsets.shape), + tuple(padded_offsets.stride()), + padded_offsets.dtype, + acc_dtype, + d_dtype, + cd_major, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + vector_f32, + m_aligned, + discrete_col_sfd, + use_dynamic_sched, + ) + else: + cache_key = ( + weight_mode, + a_tensor.shape[1:], + stride_order(a_tensor), + a_tensor.dtype, + b_shape, + b_dtype, + c_tensor.shape[1:], + stride_order(c_tensor), + c_tensor.dtype, + d_tensor.shape[1:], + stride_order(d_tensor), + *dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,)), + *tensor_signature(bias_tensor), + *tensor_signature(alpha_tensor), + *tensor_signature(norm_const_tensor), + *dynamic_m_tensor_signature(prob_tensor, (1, 1)), + tuple(b_ptrs.shape), + tuple(b_ptrs.stride()), + b_ptrs.dtype, + tuple(sfb_ptrs.shape), + tuple(sfb_ptrs.stride()), + sfb_ptrs.dtype, + tuple(padded_offsets.shape), + tuple(padded_offsets.stride()), + padded_offsets.dtype, + acc_dtype, + d_dtype, + cd_major, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + vector_f32, + m_aligned, + discrete_col_sfd, + use_dynamic_sched, + b_major, + num_experts, + ) + + if cache_key in _cache_of_GroupedGemmSreluSm100Objects: + _logger.debug("grouped_gemm_srelu_wrapper_sm100: Using previously cached GroupedGemmSreluSm100 object") + grouped_gemm_srelu = _cache_of_GroupedGemmSreluSm100Objects[cache_key] + else: + _logger.debug("grouped_gemm_srelu_wrapper_sm100: No previously cached object found, creating new GroupedGemmSreluSm100 object") + if is_dense: + grouped_gemm_srelu = GroupedGemmSreluSm100( + sample_a=a_tensor, + sample_sfa=sfa_tensor, + sample_padded_offsets=padded_offsets, + sample_alpha=alpha_tensor, + sample_c=c_tensor, + sample_d=d_tensor, + sample_d_col=d_col_tensor, + sample_b=b_tensor, + sample_sfb=sfb_tensor, + sample_bias=bias_tensor, + sample_amax=amax_tensor, + sample_sfd_row=sfd_row_tensor, + sample_sfd_col=sfd_col_tensor, + sample_norm_const=norm_const_tensor, + sample_prob=prob_tensor, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + 
sf_vec_size=sf_vec_size, + vector_f32=vector_f32, + m_aligned=m_aligned, + discrete_col_sfd=discrete_col_sfd, + use_dynamic_sched=use_dynamic_sched, + ) + else: + grouped_gemm_srelu = GroupedGemmSreluSm100( + sample_a=a_tensor, + sample_sfa=sfa_tensor, + sample_padded_offsets=padded_offsets, + sample_alpha=alpha_tensor, + sample_c=c_tensor, + sample_d=d_tensor, + sample_d_col=d_col_tensor, + num_experts=num_experts, + b_shape=b_shape, + b_dtype=b_dtype, + sample_bias=bias_tensor, + sample_amax=amax_tensor, + sample_sfd_row=sfd_row_tensor, + sample_sfd_col=sfd_col_tensor, + sample_norm_const=norm_const_tensor, + sample_prob=prob_tensor, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + vector_f32=vector_f32, + m_aligned=m_aligned, + discrete_col_sfd=discrete_col_sfd, + use_dynamic_sched=use_dynamic_sched, + b_major=b_major, + ) + + assert grouped_gemm_srelu.check_support(), "Unsupported configuration" + grouped_gemm_srelu.compile() + _cache_of_GroupedGemmSreluSm100Objects[cache_key] = grouped_gemm_srelu + + if is_dense: + grouped_gemm_srelu.execute( + a_tensor=a_tensor, + sfa_tensor=sfa_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + b_tensor=b_tensor, + sfb_tensor=sfb_tensor, + d_col_tensor=d_col_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + prob_tensor=prob_tensor, + bias_tensor=bias_tensor, + current_stream=current_stream, + ) + else: + grouped_gemm_srelu.execute( + a_tensor=a_tensor, + sfa_tensor=sfa_tensor, + padded_offsets=padded_offsets, + alpha_tensor=alpha_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + b_ptrs=b_ptrs, + sfb_ptrs=sfb_ptrs, + d_col_tensor=d_col_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + prob_tensor=prob_tensor, + bias_tensor=bias_tensor, + current_stream=current_stream, + ) + + return TupleDict( + c_tensor=c_tensor, + d_tensor=d_tensor, + d_col_tensor=d_col_tensor, + amax_tensor=amax_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + ) diff --git a/python/cudnn/grouped_gemm/grouped_gemm_srelu/moe_blockscaled_grouped_gemm_srelu_quant.py b/python/cudnn/grouped_gemm/grouped_gemm_srelu/moe_blockscaled_grouped_gemm_srelu_quant.py new file mode 100644 index 00000000..e1c3245b --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_srelu/moe_blockscaled_grouped_gemm_srelu_quant.py @@ -0,0 +1,2139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +""" +MoE Block-Scaled Grouped GEMM Kernel with Quantization and SReLU Support. + +Supports: + - Static / Dynamic persistent tile scheduling (MoEPersistentTileScheduler) + - Dense (contiguous 3-D B) / Discrete (per-expert pointer array B) weight layout + - FP8/FP4 output quantization with row/column scale factors (SFD) + - Optional bias and routing-probability (prob) fusion + - Optional C output (generate_c) + - AMAX reduction for FP8 calibration + - SReLU epilogue activation fusion (max(x,0)^2) + +This module contains only the kernel class. +MoE scheduler components live in moe_persistent_scheduler.py / moe_sched_extension.py / moe_utils.py. 
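+
+The SReLU epilogue computes ``y = max(x, 0) ** 2`` elementwise before output
+quantization. A rough reference for the row-wise FP8/FP4 quantization path
+(see ``quant_sfd_row``; ``rcp_limit`` is the reciprocal of the output dtype's
+max representable value; names here are descriptive, not the kernel's own)::
+
+    sf = amax(abs(y[i : i + sf_vec_size])) * rcp_limit * norm_const
+    sf_q = cast(sf, sf_dtype)                # stored into the SFD tensor
+    y_q = cast(y * norm_const / cast(sf_q, f32), d_dtype)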
+""" + +from enum import Enum +from typing import Type, Tuple, Union, Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass._mlir.dialects.nvvm import ReduxKind +from cutlass.cute.typing import Float32, Int32, AddressSpace +from ..moe_persistent_scheduler import ( + MoEPersistentTileScheduler, + MoESchedulerParams, + MoEWorkTileInfo, +) +from ..moe_utils import ( + compute_expert_token_range, + MoEWeightMode, + TensormapWorkspace, + store_tma_desc, +) +from ..moe_sched_extension import ( + DiscreteWeightScaledGemmSchedExtension, + ContiguousAndConsistentGroupedGemmSchedExtension, +) +from ..moe_kernel_helpers import ( + fmin, + fmax, + warp_redux_sync, + atomic_max_float32, + compute_stages, + compute_grid, + can_implement, + amax_reduction_per_thread, + epilog_gmem_copy_and_partition, + get_dtype_rcp_limits, +) + + +class EpilogueType(Enum): + NONE = 0 + SRELU = 1 + + +class BlockScaledMoEGroupedGemmQuantKernel: + """Block-scaled grouped GEMM kernel with MoE tile scheduling and quantization. + + Supports both dense and discrete weight layouts, static and dynamic + scheduling, and quantized output with row/column scale factors. + + :param sf_vec_size: Scale-factor vector size (16 or 32). + :param acc_dtype: Accumulator data type (Float32). + :param use_2cta_instrs: Use 2-CTA MMA instructions. + :param mma_tiler_mn: MMA tile shape (M, N). + :param cluster_shape_mn: Cluster shape (M, N). + :param vectorized_f32: Use packed FP32 arithmetic. + :param generate_sfd: Generate output scale factors. + :param discrete_col_sfd: Use discrete column SFD layout. + :param generate_c: Generate C output tensor. + :param enable_bias: Fuse bias addition. + :param expert_cnt: Number of experts. + :param weight_mode: ``MoEWeightMode.DENSE`` or ``MoEWeightMode.DISCRETE``. + :param use_dynamic_sched: Enable dynamic tile scheduling. + :param epilogue_type: Epilogue activation type (``EpilogueType.NONE`` or ``EpilogueType.SRELU``). 
+ """ + + FIX_PAD_SIZE = 256 + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + cd_major: str, + m_aligned: int, + ) -> bool: + return can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + acc_dtype, + d_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + cd_major, + m_aligned, + fix_pad_size=BlockScaledMoEGroupedGemmQuantKernel.FIX_PAD_SIZE, + ) + + def __init__( + self, + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + vectorized_f32: bool, + generate_sfd: bool, + discrete_col_sfd: bool, + generate_c: bool, + enable_bias: bool, + expert_cnt: int, + weight_mode: MoEWeightMode = MoEWeightMode.DENSE, + use_dynamic_sched: bool = False, + epilogue_type: int = EpilogueType.NONE.value, + ): + mma_tile_m = mma_tiler_mn[0] + if self.FIX_PAD_SIZE % mma_tile_m != 0: + raise ValueError( + f"FIX_PAD_SIZE ({self.FIX_PAD_SIZE}) must be divisible by " f"mma_tiler_mn[0] ({mma_tile_m}). " f"Supported mma_tiler_mn[0] values: 128, 256." + ) + if expert_cnt > 1024: + raise ValueError("Expert count > 1024 is not supported.") + if not isinstance(weight_mode, MoEWeightMode): + raise TypeError(f"weight_mode must be a MoEWeightMode, got {type(weight_mode)}") + + self.sf_vec_size = sf_vec_size + self.expert_cnt = expert_cnt + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + self.epilog_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.sched_warp_id = 6 + self.bias_load_warp_id = 7 if enable_bias else None + self.threads_per_warp = 32 + all_warps = [ + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.sched_warp_id, + ] + warps_wo_sched = [*self.epilog_warp_id, self.mma_warp_id, self.tma_warp_id] + if enable_bias: + all_warps.append(self.bias_load_warp_id) + warps_wo_sched.append(self.bias_load_warp_id) + + self.threads_per_cta = self.threads_per_warp * len(all_warps) + self.threads_wo_sched = self.threads_per_warp * len(warps_wo_sched) + + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.vectorized_f32 = vectorized_f32 + self.generate_sfd = generate_sfd + self.discrete_col_sfd = discrete_col_sfd + self.generate_c = generate_c + self.enable_bias = enable_bias + + self.weight_mode = weight_mode + self.use_dynamic_sched = use_dynamic_sched + + self.epilogue_use_functor = False + self.epilogue_type = 
epilogue_type + + self.num_epilog_warps = len(self.epilog_warp_id) + + # ------------------------------------------------------------------ + # _setup_attributes + # ------------------------------------------------------------------ + + def _setup_attributes(self): + """Configure MMA / tile / stage / SMEM layouts from GEMM inputs.""" + + self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1]) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], + ) + + self.mma_tiler_d = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk_d = ( + self.mma_tiler_d[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_d[1], + self.mma_tiler_d[2], + ) + + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + self.epi_tile = (128, 32) + + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_d_stage, + self.num_tile_stage, + self.num_bias_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.d_dtype, + self.d_layout, + self.sf_dtype, + self.sf_vec_size, + self.num_smem_capacity, + self.occupancy, + self.generate_sfd, + self.generate_c, + self.bias_dtype if self.enable_bias else None, + ) + + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + 
self.epi_tile, + self.num_c_stage, + ) + self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.d_dtype, + self.d_layout, + self.epi_tile, + self.num_d_stage, + ) + + if self.enable_bias: + self.bias_smem_layout_staged = cute.make_layout( + (self.mma_tiler[1], self.num_bias_stage), + stride=(1, self.mma_tiler[1]), + ) + else: + self.bias_smem_layout_staged = cute.make_layout((1, 1)) + + self.overlapping_accum = self.num_acc_stage == 1 and self.mma_tiler[1] == 256 + + sf_atom_mn = 32 + self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k + self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = ( + self.cta_tile_shape_mnk[1] * self.num_acc_stage if not self.overlapping_accum else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols + ) + + self.epi_tile_n_required = cute.size(self.epi_tile[1]) + self.iter_acc_early_release_in_epilogue = (self.num_sf_tmem_cols + self.epi_tile_n_required - 1) // self.epi_tile_n_required - 1 + + # ------------------------------------------------------------------ + # _compute_stages (with bias support) + # ------------------------------------------------------------------ + + @staticmethod + def _compute_stages( + tiled_mma, + mma_tiler_mnk, + a_dtype, + b_dtype, + epi_tile, + c_dtype, + c_layout, + d_dtype, + d_layout, + sf_dtype, + sf_vec_size, + num_smem_capacity, + occupancy, + generate_sfd, + generate_c, + bias_dtype, + ): + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + num_c_stage = 2 if generate_sfd else 1 + num_d_stage = 2 if generate_sfd else 1 + num_tile_stage = 2 + + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler_mnk, a_dtype, 1) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, b_dtype, 1) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(tiled_mma, mma_tiler_mnk, sf_vec_size, 1) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(tiled_mma, mma_tiler_mnk, sf_vec_size, 1) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1) + d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + sinfo_bytes = 4 * 4 * num_tile_stage + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) + d_bytes = d_bytes_per_stage * num_d_stage * (2 if generate_sfd else 1) + amax_bytes = 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,))) if d_dtype == cutlass.BFloat16 else 0 + + if bias_dtype is not None: + num_bias_stage = 2 + bias_epi_tile_n = mma_tiler_mnk[1] + bias_bytes = bias_epi_tile_n * num_bias_stage * (bias_dtype.width // 8) + else: + num_bias_stage = 0 + bias_bytes = 0 + + epi_bytes = c_bytes + d_bytes + amax_bytes + bias_bytes + num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes + sinfo_bytes)) // ab_bytes_per_stage + + # cute.printf("num_acc_stage: %d, num_ab_stage: %d, num_c_stage: %d, num_d_stage: %d, num_tile_stage: 
%d, num_bias_stage: %d\n", num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage, num_bias_stage) + return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage, num_bias_stage + + # ------------------------------------------------------------------ + # Workspace helpers + # ------------------------------------------------------------------ + + def get_desc_workspace_bytes(self) -> int: + if self.weight_mode == MoEWeightMode.DISCRETE: + from ..moe_utils import DiscreteWeightTensormapConstructor + + return DiscreteWeightTensormapConstructor.get_workspace_size(self.expert_cnt) + return 0 + + def get_workspace_bytes(self) -> int: + desc_workspace_bytes = self.get_desc_workspace_bytes() + dynamic_sched_bytes = 4 if self.use_dynamic_sched else 0 + return desc_workspace_bytes + dynamic_sched_bytes + + @cute.jit + def _get_sched_counter_ptr(self, workspace_ptr): + counter_addr = workspace_ptr.toint() + self.get_desc_workspace_bytes() + return cute.make_ptr( + cutlass.Int32, + counter_addr, + AddressSpace.gmem, + assumed_align=4, + ) + + # ------------------------------------------------------------------ + # helper_kernel: pre-main-kernel initialization + # - discrete weight: build per-expert B/SFB TMA descriptors + # - dynamic sched: reset the atomic tile counter + # ------------------------------------------------------------------ + + @cute.kernel + def helper_kernel( + self, + # Discrete-only params (unused in dense mode, but must be present for signature) + ptrs_b: cute.Pointer, + ptrs_sfb: cute.Pointer, + n: Int32, + k: Int32, + b_stride_size: cutlass.Int64, + b_major_mode: cutlass.Constexpr, + workspace_ptr, + tiled_mma_arg: cute.TiledMma, + tiled_mma_sfb_arg: cute.TiledMma, + b_smem_layout_arg, + sfb_smem_layout_arg, + cluster_layout_vmnk_shape_arg: cutlass.Constexpr, + cluster_layout_sfb_vmnk_shape_arg: cutlass.Constexpr, + ): + """Pre-main-kernel initialization. + + Launched with grid=(expert_cnt, 1, 1) for discrete mode, or + grid=(1, 1, 1) for dense+dynamic mode. + + Discrete weight: each block builds B/SFB TMA descriptors for one expert. + Dynamic sched: block 0 resets the atomic tile counter to 0. 
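+
+ Workspace layout (per ``get_workspace_bytes``): the per-expert TMA
+ descriptor region comes first, then the 4-byte dynamic-sched counter,
+ so ``_get_sched_counter_ptr`` points just past the descriptors.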
+ """ + expert_idx = cute.arch.block_idx()[0] + + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE): + b_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_arg.thr_id) + sfb_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma_arg.thr_id) + + b_ptr_tensor = cute.make_tensor( + cute.make_ptr(cutlass.Int64, ptrs_b.toint(), AddressSpace.gmem, assumed_align=8), cute.make_layout((self.expert_cnt,)) + ) + sfb_ptr_tensor = cute.make_tensor( + cute.make_ptr(cutlass.Int64, ptrs_sfb.toint(), AddressSpace.gmem, assumed_align=8), cute.make_layout((self.expert_cnt,)) + ) + + c1 = cutlass.Int32(1) + c0 = cutlass.Int64(0) + c1_64 = 1 + if cutlass.const_expr(b_major_mode == OperandMajorMode.K): + stride_n = b_stride_size + stride_k = c1_64 + else: + stride_n = c1_64 + stride_k = b_stride_size + + b_ptr_val = b_ptr_tensor[expert_idx] + b_ptr = cute.make_ptr(self.b_dtype, b_ptr_val, AddressSpace.gmem) + b_tensor_i = cute.make_tensor( + b_ptr, + cute.make_layout((n, k, c1), stride=(stride_n, stride_k, c0)), + ) + tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B( + b_tma_op_arg, + b_tensor_i, + b_smem_layout_arg, + self.mma_tiler, + tiled_mma_arg, + cluster_layout_vmnk_shape_arg, + ) + workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"]) + store_tma_desc(tma_atom_b, workspace.get_ptr("b", expert_idx)) + + sfb_ptr_val = sfb_ptr_tensor[expert_idx] + sfb_ptr = cute.make_ptr(self.sf_dtype, sfb_ptr_val, AddressSpace.gmem) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size) + sfb_tensor_i = cute.make_tensor(sfb_ptr, sfb_layout) + tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B( + sfb_tma_op_arg, + sfb_tensor_i, + sfb_smem_layout_arg, + self.mma_tiler_sfb, + tiled_mma_sfb_arg, + cluster_layout_sfb_vmnk_shape_arg, + internal_type=cutlass.Uint64, + ) + store_tma_desc(tma_atom_sfb, workspace.get_ptr("sfb", expert_idx)) + + if cutlass.const_expr(self.use_dynamic_sched): + if expert_idx == cutlass.Int32(0): + sched_counter = cute.make_tensor( + self._get_sched_counter_ptr(workspace_ptr), + cute.make_layout(1), + ) + sched_counter[0] = cutlass.Int32(0) + + # ------------------------------------------------------------------ + # __call__ + # ------------------------------------------------------------------ + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b, # Dense: cute.Tensor (N,K,L) | Discrete: cute.Pointer to int64[] + sfb, # Dense: cute.Tensor | Discrete: cute.Pointer to int64[] + n: Int32, # Ignored for dense mode + k: Int32, # Ignored for dense mode + b_stride_size: cutlass.Int64, # Ignored for dense mode + b_major_mode: cutlass.Constexpr, # Ignored for dense mode + workspace_ptr, + c: cute.Tensor, + d: cute.Tensor, + d_col: Optional[cute.Tensor], + sfa: cute.Tensor, + sfd_row_tensor: Optional[cute.Tensor], + sfd_col_tensor: Optional[cute.Tensor], + amax_tensor: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + padded_offsets: cute.Tensor, + alpha: cute.Tensor, + bias: Optional[cute.Tensor], + prob: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM. + + Dense mode: ``b`` and ``sfb`` are 3-D cute.Tensor (N, K, L). + Discrete mode: ``b`` and ``sfb`` are cute.Pointer to device int64[] + arrays of per-expert base addresses; ``n``, ``k``, ``b_stride_size``, + ``b_major_mode`` describe the uniform per-expert layout. 
+ """ + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = a.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.d_dtype: Type[cutlass.Numeric] = d.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + self.d_layout = utils.LayoutEnum.from_tensor(d) + self.bias_dtype = bias.element_type if cutlass.const_expr(self.enable_bias) else cutlass.BFloat16 + + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + else: + self.b_major_mode = b_major_mode + + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"A/B dtype must match: {self.a_dtype} != {self.b_dtype}") + + self._setup_attributes() + + # ---- SFA layout ---- + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, self.sf_vec_size) + sfa = cute.make_tensor(sfa.iterator, sfa_layout) + + # ---- B / SFB setup (mode-dependent) ---- + # Save the call-arg b/sfb before the discrete branch overwrites them + # with template tensors. helper_kernel needs the original Pointers. + b_from_call_arg = b + sfb_from_call_arg = sfb + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) + sfb = cute.make_tensor(sfb.iterator, sfb_layout) + else: + c1 = cutlass.Int32(1) + c0 = cutlass.Int64(0) + c1_64 = 1 + if cutlass.const_expr(b_major_mode == OperandMajorMode.K): + b_template_stride = (b_stride_size, c1_64, c0) + else: + b_template_stride = (c1_64, b_stride_size, c0) + b_template_layout = cute.make_layout((n, k, c1), stride=b_template_stride) + b_ptr_typed = cute.make_ptr(self.b_dtype, b.toint(), AddressSpace.gmem, assumed_align=16) + b = cute.make_tensor(b_ptr_typed, b_template_layout) + + sfb_ptr_typed = cute.make_ptr(self.sf_dtype, sfb.toint(), AddressSpace.gmem, assumed_align=16) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size) + sfb = cute.make_tensor(sfb_ptr_typed, sfb_layout) + + # ---- SFD setup ---- + self.generate_sfd = sfd_row_tensor is not None and norm_const_tensor is not None + if cutlass.const_expr(self.generate_sfd == False): + self.discrete_col_sfd = False + if cutlass.const_expr(self.generate_sfd): + sfd_row_layout = blockscaled_utils.tile_atom_to_shape_SF(d.shape, self.sf_vec_size) + sfd_row_tensor = cute.make_tensor(sfd_row_tensor.iterator, sfd_row_layout) + sfd_col_layout = cute.tile_to_shape( + blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, OperandMajorMode.MN).layout, + d.shape, + (1, 2, 3), + ) + if cutlass.const_expr(self.discrete_col_sfd): + sfd_col_layout = sfd_row_layout + sfd_col_tensor = cute.make_tensor(sfd_col_tensor.iterator, sfd_col_layout) + + self.generate_amax = amax_tensor is not None + + # ---- TMA atoms ---- + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, 
tiled_mma.thr_id) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Uint64, + ) + + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + new_shape = ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ) + x_times_3 = 3 * x + new_stride = ( + (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ) + tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size + + c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + c_smem_layout, + self.epi_tile, + ) + d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0)) + tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d, + d_smem_layout, + self.epi_tile, + ) + tma_atom_d_col, tma_tensor_d_col = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d_col, + d_smem_layout, + self.epi_tile, + ) + + # ---- Helper kernel: TMA desc init (discrete) + sched counter reset (dynamic) ---- + _need_helper = cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE or self.use_dynamic_sched) + if cutlass.const_expr(_need_helper): + _helper_grid_x = self.expert_cnt if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else 1 + _helper_args = ( + b_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem), + sfb_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem), + n if 
cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0), + k if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0), + b_stride_size if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int64(0), + b_major_mode if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else self.b_major_mode, + workspace_ptr, + tiled_mma, + tiled_mma_sfb, + b_smem_layout, + sfb_smem_layout, + self.cluster_layout_vmnk.shape, + self.cluster_layout_sfb_vmnk.shape, + ) + self.helper_kernel(*_helper_args).launch( + grid=(_helper_grid_x, 1, 1), + block=(1, 1, 1), + stream=stream, + min_blocks_per_mp=1, + ) + + # ---- Grid computation via MoE scheduler ---- + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + b_n, b_k, b_l = cute.shape(b) # B is (N, K, L) + sched_expert_shape = (self.expert_cnt, b_n, b_k) + else: + sched_expert_shape = (self.expert_cnt, n, k) + + sched_params = MoESchedulerParams( + scenario="2Dx3D", + expert_shape=sched_expert_shape, + cta_tile_shape_mnk=self.cta_tile_shape_mnk, + cluster_shape_mn=self.cluster_shape_mn, + use_dynamic_sched=self.use_dynamic_sched, + ) + self.sched_params, grid = compute_grid(sched_params, max_active_clusters, self.use_2cta_instrs) + + self.buffer_align_bytes = 1024 + + # ---- Shared storage ---- + sD_col_size = cute.cosize(self.d_smem_layout_staged.outer) if self.generate_sfd else 0 + SchedulerStorage = MoEPersistentTileScheduler.make_storage_struct(self.num_tile_stage, self.use_dynamic_sched) + + @cute.struct + class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + scheduler: SchedulerStorage + if cutlass.const_expr(self.enable_bias): + bias_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_bias_stage * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + sC: cute.struct.Align[ + cute.struct.MemRange[self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sD: cute.struct.Align[ + cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sD_col: cute.struct.Align[ + cute.struct.MemRange[self.d_dtype, sD_col_size], + self.buffer_align_bytes, + ] + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + if cutlass.const_expr(self.enable_bias): + sBias: cute.struct.Align[ + cute.struct.MemRange[self.bias_dtype, cute.cosize(self.bias_smem_layout_staged)], + 16, + ] + sAmax: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps], + 4, + ] + + self.shared_storage = SharedStorage + + # ---- Launch ---- + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + tma_atom_d, + tma_tensor_d, + tma_atom_d_col, + tma_tensor_d_col, + sfd_row_tensor, + sfd_col_tensor, + 
norm_const_tensor, + amax_tensor, + padded_offsets, + alpha, + bias, + prob, + workspace_ptr, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.d_smem_layout_staged, + self.bias_smem_layout_staged, + self.epi_tile, + self.sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + max_number_threads=[self.threads_per_cta, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # ------------------------------------------------------------------ + # Helper methods + # ------------------------------------------------------------------ + + def mainloop_s2t_copy_and_partition(self, sSF, tSF): + tCsSF_compact = cute.filter_zeros(sSF) + tCtSF_compact = cute.filter_zeros(tSF) + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + @cute.jit + def amax_reduction_per_warp_and_cta(self, amax_fp32, warp_idx, amax_smem, amax_gmem): + warp_amax = warp_redux_sync( + value=amax_fp32, + kind=ReduxKind.MAX, + mask_and_clamp=0xFFFFFFFF, + nan=True, + ) + if cute.arch.lane_idx() == 0: + amax_smem[warp_idx] = cutlass.Float32(warp_amax) + self.epilog_sync_barrier.arrive_and_wait() + if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0: + block_amax = cutlass.Float32(0.0) + for i in cutlass.range(self.num_epilog_warps): + warp_amax_val = amax_smem[i] + block_amax = cute.arch.fmax(block_amax, warp_amax_val) + _ = atomic_max_float32(ptr=amax_gmem, value=block_amax) + + @cute.jit + def store_c( + self, + tiled_copy_r2s, + tma_atom_c, + warp_idx, + tTR_rAcc, + tRS_rC, + tRS_sC, + bSG_gC, + bSG_sC, + c_pipeline, + prev_subtile_idx, + real_subtile_idx, + ): + c_buffer = prev_subtile_idx % self.num_c_stage + tRS_rC.store(tTR_rAcc.load().to(self.c_dtype)) + cute.copy(tiled_copy_r2s, tRS_rC[(None, None, 0)], tRS_sC[(None, None, 0, c_buffer)]) + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier.arrive_and_wait() + if warp_idx == self.epilog_warp_id[0]: + cute.copy(tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, real_subtile_idx)]) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + @cute.jit + def quant_sfd_row(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD): + tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size)) + acc_frg = tTR_rAcc_frg.load() + abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value()) + abs_acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype) + pvscale_f32x4 = cute.make_rmem_tensor(4, cutlass.Float32) + sfd_f8x4 = cute.make_rmem_tensor(4, self.sf_dtype) + tmp_f32 = abs_acc_frg[None, 0].reduce(cute.ReductionOp.MAX, cutlass.Float32(0.0), 0) * rcp_limit * norm_const + if tile_idx == 0: + pvscale[0] = tmp_f32 + elif tile_idx == 1: + pvscale[1] = tmp_f32 + elif tile_idx == 2: + pvscale[2] = tmp_f32 + elif 
tile_idx == 3: + pvscale[3] = tmp_f32 + pvscale_f32x4[0] = tmp_f32 + sfd_f8x4.store(pvscale_f32x4.load().to(self.sf_dtype)) + pvscale_f32x4.store(sfd_f8x4.load().to(cutlass.Float32)) + qpvscale_up = pvscale_f32x4[0] + fp32_max = cutlass.Float32(3.40282346638528859812e38) + acc_scale = norm_const * cute.arch.rcp_approx(qpvscale_up) + acc_scale = fmin(acc_scale, fp32_max, nan=True) + if cutlass.const_expr(self.vectorized_f32): + vec = tTR_rAcc_frg[None, 0] + for ei in cutlass.range_constexpr(0, self.sf_vec_size, 2): + vec[ei], vec[ei + 1] = cute.arch.mul_packed_f32x2( + (vec[ei], vec[ei + 1]), + (acc_scale, acc_scale), + rnd="rn", + ftz=False, + ) + else: + vec = tTR_rAcc_frg[None, 0] + for ei in cutlass.range_constexpr(self.sf_vec_size): + vec[ei] = vec[ei] * acc_scale + acc_vec = tiled_copy_r2s.retile(src).load() + tRSrD.store(acc_vec.to(self.d_dtype)) + + @cute.jit + def quant_sfd_col(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD): + tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size)) + acc_frg = tTR_rAcc_frg.load() + abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value()) + acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype) + tmp_f32 = cutlass.Float32(0.0) + for vi in cutlass.range_constexpr(acc_frg.shape[0]): + max_value_original = ( + cutlass.Float32( + warp_redux_sync( + value=acc_frg[vi, 0], + kind=ReduxKind.MAX, + mask_and_clamp=0xFFFFFFFF, + nan=True, + ) + ) + * rcp_limit + * norm_const + ) + max_value_vec = cute.full(4, max_value_original, dtype=cutlass.Float32) + max_value_vec_f8 = max_value_vec.to(cutlass.Float8E8M0FNU) + max_value_vec_f32_chunked = max_value_vec_f8.to(cutlass.Float32) + max_value = max_value_vec_f32_chunked[0] + tidx = cute.arch.thread_idx()[0] + if tidx % 32 == vi: + tmp_f32 = max_value + acc_scale_col = cutlass.Float32(0.0) + if max_value_vec_f32_chunked[0] == 0.000000: + acc_scale_col = cutlass.Float32(0.0) + else: + acc_scale_col = norm_const * cute.arch.rcp_approx(max_value_vec_f32_chunked[0]) + fp32_max = cutlass.Float32(3.40282346638528859812e38) + acc_scale_col = fmin(acc_scale_col, fp32_max) + tTR_rAcc_frg[vi] = tTR_rAcc_frg[vi] * acc_scale_col + pvscale[None, None, tile_idx][0] = tmp_f32 + acc_vec = tiled_copy_r2s.retile(src).load() + tRSrD.store(acc_vec.to(self.d_dtype)) + + @cute.jit + def tile_info_to_mn_idx(self, tile_info: cute.Tensor): + m_idx = tile_info[1] * cute.size(self.cta_tile_shape_mnk[0]) + n_idx = tile_info[2] * cute.size(self.cta_tile_shape_mnk[1]) + return m_idx, n_idx + + @cute.jit + def create_and_partition_new_SFDCol(self, tile_info, mSFDCol_mnl, padded_offsets): + m_idx, n_idx = self.tile_info_to_mn_idx(tile_info) + expert_idx = tile_info[0] + cumsum_tokens, tokens_this_group = compute_expert_token_range(padded_offsets, expert_idx) + n_total = cute.size(mSFDCol_mnl.shape[1]) + + sf_tile_idx_begin = cumsum_tokens // cute.size(mSFDCol_mnl.shape[0][0]) + mSFDCol_mnl_new_ptr = mSFDCol_mnl[(None, sf_tile_idx_begin), None, 0].iterator + + sfd_col_quant_layout = cute.tile_to_shape( + blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, OperandMajorMode.MN).layout, + (tokens_this_group, n_total, mSFDCol_mnl.shape[2]), + (1, 2, 3), + ) + regPerSubtile = 4 + sfd_tile = (cute.make_layout(128), cute.make_layout(32 * regPerSubtile)) + mSFDCol_mnl_new = cute.make_tensor(mSFDCol_mnl_new_ptr, sfd_col_quant_layout) + gSFDCol_mnl_new = cute.local_tile(mSFDCol_mnl_new, sfd_tile, (None, None, None)) + + thr_layout = cute.make_ordered_layout((4, 32), order=(1, 
0)) + val_layout = cute.make_ordered_layout((1,), order=(0,)) + copy_atom_sfd_col_quant = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gSFDCol_mnl_new.element_type, + num_bits_per_copy=8, + ) + tiled_copy_sfd_col_quant = cute.make_tiled_copy_tv( + copy_atom_sfd_col_quant, + thr_layout, + val_layout, + ) + tidx = cute.arch.thread_idx()[0] + thr_copy_sfd_col_quant = tiled_copy_sfd_col_quant.get_slice(tidx) + tCgSFDCol_mnl = thr_copy_sfd_col_quant.partition_D(cute.filter_zeros(gSFDCol_mnl_new)) + tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl) + return tCgSFDCol_mnl + + def epilog_tmem_copy_and_partition(self, tidx, tAcc, gD_mnl, epi_tile, use_2cta_instrs): + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.d_layout, + self.d_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + tTR_gC = thr_copy_t2r.partition_D(gD_mnl_epi) + tTR_rAcc = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition(self, tiled_copy_t2r, tTR_rD, tidx, sD): + copy_atom_r2s = sm100_utils.get_smem_store_op(self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sD) + tRS_rD = tiled_copy_r2s.retile(tTR_rD) + return tiled_copy_r2s, tRS_rD, tRS_sD + + # ------------------------------------------------------------------ + # GPU device kernel + # ------------------------------------------------------------------ + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tma_atom_d: cute.CopyAtom, + mD_mnl: cute.Tensor, + tma_atom_d_col: cute.CopyAtom, + mD_col_mnl: cute.Tensor, + mSFDRow_mnl: Optional[cute.Tensor], + mSFDCol_mnl: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + mAmax_tensor: Optional[cute.Tensor], + padded_offsets: cute.Tensor, + alpha: cute.Tensor, + mBias_nl: Optional[cute.Tensor], + prob: cute.Tensor, + workspace_ptr, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + bias_smem_layout_staged: Optional[cute.Layout], + epi_tile: cute.Tile, + sched_params: MoESchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """GPU device kernel for persistent MoE grouped GEMM with quantization.""" + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + 
cpasync.prefetch_descriptor(tma_atom_sfa) + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE): + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_d) + if cutlass.const_expr(self.generate_sfd): + cpasync.prefetch_descriptor(tma_atom_d_col) + if cutlass.const_expr(self.generate_c): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + total_token = padded_offsets[self.expert_cnt - 1] + + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster) + tidx, _, _ = cute.arch.thread_idx() + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sched_storage = storage.scheduler + + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + tile_info_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_per_warp * 1) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_wo_sched) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=sched_storage.tile_info_mbar.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + ) + + if cutlass.const_expr(self.enable_bias): + bias_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_per_warp) + bias_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + bias_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.bias_mbar_ptr.data_ptr(), + num_stages=self.num_bias_stage, + producer_group=bias_pipeline_producer_group, + consumer_group=bias_pipeline_consumer_group, + ) + sBias = storage.sBias.get_tensor(bias_smem_layout_staged) + # (MMA_N, loopN, loopL) — tiled over mma_tile_n in N, full L; indexed + # directly by expert_idx (ref glu_bias line 1771-1773 pattern). 
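+ # Shape sketch (illustrative): local_tile with the N-only slice of
+ # mma_tiler[:2] turns (N, L) into (MMA_N, ceil(N / MMA_N), L), so
+ # gBias_nl[(None, n_idx, e)] is the (MMA_N,) bias slice for N-tile
+ # n_idx of expert e.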
+ gBias_nl = cute.local_tile(mBias_nl, cute.slice_(self.mma_tiler[:2], (0, None)), (None, None)) + + scheduler = MoEPersistentTileScheduler.create( + sched_params, + padded_offsets, + cute.arch.block_idx(), + cute.arch.grid_dim(), + counter_ptr=self._get_sched_counter_ptr(workspace_ptr), + sched_storage=sched_storage, + ) + scheduler.internal_init() + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) + sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner) + sD_col = sD + if cutlass.const_expr(self.generate_sfd): + sD_col = storage.sD_col.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + amax_layout = cute.make_layout((self.num_epilog_warps,)) + sAmax = storage.sAmax.get_tensor(amax_layout) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = sched_storage.sInfo.get_tensor(info_layout) + + # Multicast masks — must create ALL when any mcast or 2CTA is active + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1) + + # MMA partition (for tCtAcc_fake shape computation only) + thr_mma_common = tiled_mma.get_slice(0) + tCsA_common = thr_mma_common.partition_A(sA) + tCsB_common = thr_mma_common.partition_B(sB) + tCsA_common = cute.filter_zeros(tCsA_common) + tCsB_common = cute.filter_zeros(tCsB_common) + + # SMEM fragments for MMA (used by MMA warp) + tCrA = tiled_mma.make_fragment_A(sA) + tCrB = tiled_mma.make_fragment_B(sB) + + # TMEM accumulator shape + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + if cutlass.const_expr(self.overlapping_accum): + num_acc_stage_overlapped = 2 + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage_overlapped)) + tCtAcc_fake = cute.make_tensor( + tCtAcc_fake.iterator, + cute.make_layout( + tCtAcc_fake.shape, + stride=( + tCtAcc_fake.stride[0], + tCtAcc_fake.stride[1], + tCtAcc_fake.stride[2], + (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1], + ), + ), + ) + else: + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # Cluster sync before warp specialization + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + if total_token <= 0: + 
cute.arch.nvvm.exit() + + # ============================================================== + # Scheduler warp (MoE Persistent Tile Scheduler) + # ============================================================== + if warp_idx == self.sched_warp_id: + work_tile_info = scheduler.initial_work_tile_info() + tile_info_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_tile_stage) + while work_tile_info.is_valid_tile: + tile_info_pipeline.producer_acquire(tile_info_producer_state) + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = work_tile_info.expert_idx + sInfo[(1, tile_info_producer_state.index)] = work_tile_info.tile_m_idx + sInfo[(2, tile_info_producer_state.index)] = work_tile_info.tile_n_idx + sInfo[(3, tile_info_producer_state.index)] = work_tile_info.k_tile_cnt + cute.arch.fence_proxy("async.shared", space="cta") + self.sched_sync_barrier.arrive_and_wait() + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + work_tile_info = scheduler.advance_to_next_work() + + tile_info_pipeline.producer_acquire(tile_info_producer_state) + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cutlass.Int32(-1) + sInfo[(1, tile_info_producer_state.index)] = cutlass.Int32(0) + sInfo[(2, tile_info_producer_state.index)] = cutlass.Int32(0) + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + self.sched_sync_barrier.arrive_and_wait() + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # ============================================================== + # Bias load warp + # ============================================================== + if cutlass.const_expr(self.enable_bias): + if warp_idx == self.bias_load_warp_id: + # Ported from ref glu_bias bias_load_warp (lines 2422-2483). + # No extension needed: mBias_nl has shape (N, L) and we index L + # directly by expert_idx from sInfo — much simpler than the old + # `bias_ext.get_gmem_tensor("bias", ...)` path (which reshaped the + # tensor and was unnecessary for a global, per-expert-sliced bias). 
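+                # Worked sizing (illustrative, assuming bf16 bias and MMA_N == 256):
+                # each lane moves 128 / 16 = 8 elements per cp.async, so one warp
+                # covers 32 * 8 = 256 columns, i.e. one full MMA_N tile.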
+ bias_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_bias_stage) + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + # 128-bit cp.async: threads_per_warp × (128 / bias_dtype bits) = tile_N + bias_elems_per_thread = 128 // self.bias_dtype.width + bias_g2s_atom = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + self.bias_dtype, + num_bits_per_copy=128, + ) + bias_g2s_tiled = cute.make_tiled_copy_tv( + bias_g2s_atom, + cute.make_layout((self.threads_per_warp,)), + cute.make_layout((bias_elems_per_thread,)), + ) + thr_bias_g2s = bias_g2s_tiled.get_slice(cute.arch.lane_idx()) + tBs_sBias = thr_bias_g2s.partition_D(sBias) + + # Per-thread N predicate tensor + bias_n_total = mBias_nl.shape[0] + tBpBias = cute.make_rmem_tensor(cute.make_layout((1,)), cutlass.Boolean) + + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + bias_producer_state.reset_count() + mma_n_coord = tile_info[2] + expert_idx = tile_info[0] + + # Direct L-indexing — no ext, no domain_offset + gBias_tile = gBias_nl[(None, mma_n_coord, expert_idx)] + tBs_gBias = thr_bias_g2s.partition_S(gBias_tile) + + # Predicate: this thread's chunk must be within valid N + tBpBias[0] = mma_n_coord * self.mma_tiler[1] + cute.arch.lane_idx() * bias_elems_per_thread < bias_n_total + + bias_pipeline.producer_acquire(bias_producer_state) + cute.copy( + bias_g2s_tiled, + tBs_gBias[(None, 0)], + tBs_sBias[(None, 0, bias_producer_state.index)], + pred=tBpBias, + ) + bias_pipeline.producer_commit(bias_producer_state) + bias_producer_state.advance() + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + bias_pipeline.producer_tail(bias_producer_state) + + # ============================================================== + # DMA / TMA load warp + # ============================================================== + if warp_idx == self.tma_warp_id: + ext = self._make_extension(workspace_ptr) + ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + work_tile_info = MoEWorkTileInfo( + expert_idx=tile_info[0], + tile_m_idx=tile_info[1], + tile_n_idx=tile_info[2], + k_tile_cnt=tile_info[3], + ) + 
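+                # The scheduler publishes only these four Int32s per tile; base
+                # tensors (and, for DISCRETE weights, fresh TMA descriptors) are
+                # re-derived below from padded_offsets via the extension.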
k_tile_cnt = work_tile_info.k_tile_cnt + ext.update_expert_info(padded_offsets, work_tile_info.expert_idx) + + real_a, _ = ext.get_gmem_tensor("a", mA_mkl, padded_offsets, work_tile_info) + real_b, desc_ptr_b = ext.get_gmem_tensor("b", mB_nkl, padded_offsets, work_tile_info) + real_sfa, _ = ext.get_gmem_tensor("sfa", mSFA_mkl, padded_offsets, work_tile_info) + real_sfb, desc_ptr_sfb = ext.get_gmem_tensor("sfb", mSFB_nkl, padded_offsets, work_tile_info) + + gA_mkl = cute.local_tile(real_a, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + gB_nkl = cute.local_tile(real_b, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) + gSFA_mkl = cute.local_tile(real_sfa, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + gSFB_nkl = cute.local_tile(real_sfb, cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None)) + + # MMA partition on gmem tensors + thr_mma_dma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb_dma = tiled_mma_sfb.get_slice(mma_tile_coord_v) + tCgA = thr_mma_dma.partition_A(gA_mkl) + tCgB = thr_mma_dma.partition_B(gB_nkl) + tCgSFA = thr_mma_dma.partition_A(gSFA_mkl) + tCgSFB = thr_mma_sfb_dma.partition_B(gSFB_nkl) + + # TMA partition A + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA partition B + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + # TMA partition SFA + sfa_cta_layout = a_cta_layout + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + # TMA partition SFB + sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + mma_tile_coord_m = work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape) + mma_tile_coord_n = work_tile_info.tile_n_idx + tAgA_slice = tAgA[(None, mma_tile_coord_m, None, 0)] + tBgB_slice = tBgB[(None, mma_tile_coord_n, None, 0)] + tAgSFA_slice = tAgSFA[(None, mma_tile_coord_m, None, 0)] + slice_n = mma_tile_coord_n + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_n // 2 + tBgSFB_slice = tBgSFB[(None, slice_n, None, 0)] + + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)] + tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + tAsSFA_pipe = tAsSFA[(None, 
ab_producer_state.index)] + tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + cute.copy(tma_atom_a, tAgA_k, tAsA_pipe, tma_bar_ptr=tma_bar, mcast_mask=a_full_mcast_mask) + cute.copy(tma_atom_b, tBgB_k, tBsB_pipe, tma_bar_ptr=tma_bar, mcast_mask=b_full_mcast_mask, tma_desc_ptr=desc_ptr_b) + cute.copy(tma_atom_sfa, tAgSFA_k, tAsSFA_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfa_full_mcast_mask) + cute.copy(tma_atom_sfb, tBgSFB_k, tBsSFB_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfb_full_mcast_mask, tma_desc_ptr=desc_ptr_sfb) + + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + ab_pipeline.producer_tail(ab_producer_state) + + # ============================================================== + # MMA warp + # ============================================================== + if warp_idx == self.mma_warp_id: + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # SFA TMEM tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols, + dtype=self.sf_dtype, + ) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # SFB TMEM tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols, + dtype=self.sf_dtype, + ) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # S2T copy partition for SFA/SFB + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) + acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + k_tile_cnt = tile_info[3] + + # Peek AB buffer full + ab_consumer_state.reset_count() + 
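+                # try_wait / try_acquire are non-blocking probes; their status is
+                # handed to the blocking wait/acquire below so barrier polling can
+                # overlap with SF staging and per-tile setup.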
peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # Peek Acc buffer empty + acc_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acc_producer_state) + + mma_tile_coord_mnl = ( + tile_info[1] // cute.size(tiled_mma.thr_id.shape), + tile_info[2], + tile_info[0], + ) + + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_producer_state.phase ^ 1 + else: + acc_stage_index = acc_producer_state.index + + tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)] + + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state, peek_acc_empty_status) + + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + if is_leader_cta: + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + s2t_stage_coord = (None, None, None, None, ab_consumer_state.index) + cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t[s2t_stage_coord], tCtSFA_compact_s2t) + cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t[s2t_stage_coord], tCtSFB_compact_s2t) + + num_kblocks = cute.size(tCrA, mode=[2]) + ab_consumer_state_next = ab_consumer_state.clone() + ab_consumer_state_next.advance() + if ab_consumer_state_next.count < k_tile_cnt: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state_next) + + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator) + tiled_mma.set(tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator) + cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + ab_pipeline.consumer_release(ab_consumer_state) + ab_consumer_state = ab_consumer_state_next + + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + acc_producer_state.advance() + if acc_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acc_producer_state) + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + acc_pipeline.producer_tail(acc_producer_state) + + # ============================================================== + # 
Epilogue warps + # ============================================================== + if warp_idx < self.mma_warp_id: + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + epi_tidx = tidx + thr_mma_epi = tiled_mma.get_slice(mma_tile_coord_v) + + # Shape-only partition on global tensor (invariant setup for t2r copy atom) + gD_mnl_shape = cute.local_tile(mD_mnl, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + tCgD_shape = thr_mma_epi.partition_C(gD_mnl_shape) + + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition( + epi_tidx, + tCtAcc_base, + tCgD_shape, + epi_tile, + use_2cta_instrs, + ) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, + tTR_rC, + epi_tidx, + sC, + ) + tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype) + tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, + tTR_rD, + epi_tidx, + sD, + ) + tTR_rD_col = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype) + tiled_copy_r2s, tRS_rD_col, tRS_sD_col = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, + tTR_rD_col, + epi_tidx, + sD_col, + ) + + if cutlass.const_expr(self.generate_sfd): + norm_const = norm_const_tensor[0] + regPerSubtile = 4 + sfd_row_tile = (cute.make_layout(128), cute.make_layout(32 * regPerSubtile)) + gSFDRow_mnl = cute.local_tile(mSFDRow_mnl, sfd_row_tile, (None, None, None)) + thr_copy_t2r_local = tiled_copy_t2r.get_slice(tidx) + tCgSFDRow_mnl = thr_copy_t2r_local.partition_D(gSFDRow_mnl) + tCgSFDRow_mnl = cute.filter_zeros(tCgSFDRow_mnl) + tCrSFDRow = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype) + tCrSFDRow_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32) + d_rcp_limits = get_dtype_rcp_limits(self.d_dtype) + + sfd_col_tile = sfd_row_tile + gSFDCol_mnl = cute.local_tile(mSFDCol_mnl, sfd_col_tile, (None, None, None)) + thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0)) + val_layout = cute.make_ordered_layout((1,), order=(0,)) + copy_atom_sfd_col = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gSFDCol_mnl.element_type, num_bits_per_copy=8) + tiled_copy_sfd_col = cute.make_tiled_copy_tv(copy_atom_sfd_col, thr_layout, val_layout) + thr_copy_sfd_col = tiled_copy_sfd_col.get_slice(tidx) + tCgSFDCol_mnl = thr_copy_sfd_col.partition_D(cute.filter_zeros(gSFDCol_mnl)) + tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl) + tCrSFDCol = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].shape, self.sf_dtype) + tCrSFDCol_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32) + + epi_ext = self._make_extension(workspace_ptr) + + acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id)) + c_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_producer_group) + d_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id)) + d_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_d_stage, producer_group=d_producer_group) + + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + tile_info = cute.make_rmem_tensor((4,), 
cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + if cutlass.const_expr(self.enable_bias): + bias_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_bias_stage) + bias_s2r_tom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.bias_dtype, num_bits_per_copy=128) + tTR_rBias = cute.make_rmem_tensor(cute.make_layout(self.epi_tile[1]), self.bias_dtype) + + num_prev_subtiles = cutlass.Int32(0) + while is_valid_tile: + epi_work_tile_info = MoEWorkTileInfo( + expert_idx=tile_info[0], + tile_m_idx=tile_info[1], + tile_n_idx=tile_info[2], + k_tile_cnt=tile_info[3], + ) + expert_idx = epi_work_tile_info.expert_idx + epi_ext.update_expert_info(padded_offsets, expert_idx) + + alpha_val = alpha[expert_idx] + + if cutlass.const_expr(self.enable_bias): + bias_consumer_state.reset_count() + bias_pipeline.consumer_wait(bias_consumer_state) + sBias_stage = sBias[(None, bias_consumer_state.index)] + sBias_subtiles = cute.flat_divide(sBias_stage, cute.make_layout(self.epi_tile[1])) + + real_d, _ = epi_ext.get_gmem_tensor("d", mD_mnl, padded_offsets, epi_work_tile_info) + real_c, _ = epi_ext.get_gmem_tensor("c", mC_mnl, padded_offsets, epi_work_tile_info) + if cutlass.const_expr(self.generate_sfd): + real_d_col, _ = epi_ext.get_gmem_tensor("d_col", mD_col_mnl, padded_offsets, epi_work_tile_info) + + thr_mma_epi_loop = tiled_mma.get_slice(mma_tile_coord_v) + + gD_mnl_loop = cute.local_tile(real_d, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + tCgD_loop = thr_mma_epi_loop.partition_C(gD_mnl_loop) + _, bSG_sD, bSG_gD_partitioned = epilog_gmem_copy_and_partition( + epi_tidx, + tma_atom_d, + tCgD_loop, + epi_tile, + sD, + ) + + gC_mnl_loop = cute.local_tile(real_c, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)) + tCgC_loop = thr_mma_epi_loop.partition_C(gC_mnl_loop) + _, bSG_sC, bSG_gC_partitioned = epilog_gmem_copy_and_partition( + epi_tidx, + tma_atom_c, + tCgC_loop, + epi_tile, + sC, + ) + + if cutlass.const_expr(self.generate_sfd): + gD_col_mnl_loop = cute.local_tile(real_d_col, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + tCgD_col_loop = thr_mma_epi_loop.partition_C(gD_col_mnl_loop) + _, bSG_sD_col, bSG_gD_col_partitioned = epilog_gmem_copy_and_partition( + epi_tidx, + tma_atom_d_col, + tCgD_col_loop, + epi_tile, + sD_col, + ) + + epi_mma_tile_coord = ( + epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape), + epi_work_tile_info.tile_n_idx, + 0, + ) + bSG_gC = bSG_gC_partitioned[(None, None, None, *epi_mma_tile_coord)] + bSG_gD = bSG_gD_partitioned[(None, None, None, *epi_mma_tile_coord)] + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD)) + if cutlass.const_expr(self.generate_sfd): + bSG_gD_col = bSG_gD_col_partitioned[(None, None, None, *epi_mma_tile_coord)] + bSG_gD_col = cute.group_modes(bSG_gD_col, 1, cute.rank(bSG_gD_col)) + + if cutlass.const_expr(self.generate_sfd): + tCgSFDRow_mn = tCgSFDRow_mnl[(None, None, None, None, None, 0)] + tCgSFDCol_mnl_new = tCgSFDCol_mnl + if cutlass.const_expr(self.discrete_col_sfd): + tCgSFDCol_mnl_new = 
self.create_and_partition_new_SFDCol(tile_info, mSFDCol_mnl, padded_offsets) + tCgSFDCol_mn = tCgSFDCol_mnl_new[(None, None, None, None, None, 0)] + + if cutlass.const_expr(self.generate_amax): + thread_tile_amax = cutlass.Float32(0.0) + + mPosition = epi_work_tile_info.tile_m_idx * self.cta_tile_shape_mnk[0] + tidx + real_prob, _ = epi_ext.get_gmem_tensor("prob", prob, padded_offsets, epi_work_tile_info) + mProb = real_prob[mPosition, 0, 0] + + # C1 fix: phase-based acc stage indexing for overlapping_accum + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_consumer_state.phase + reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False) + else: + acc_stage_index = acc_consumer_state.index + + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)] + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + acc_pipeline.consumer_wait(acc_consumer_state) + + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(0, subtile_cnt, 1, unroll=1): + real_subtile_idx = subtile_idx + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + real_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx + + if cutlass.const_expr(self.overlapping_accum): + if subtile_idx == self.iter_acc_early_release_in_epilogue: + cute.arch.fence_view_async_tmem_load() + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + if cutlass.const_expr(self.enable_bias): + # m7 fix: use real_subtile_idx directly (matches contiguous) + sBias_sub = sBias_subtiles[(None, real_subtile_idx)] + cute.copy(bias_s2r_tom, sBias_sub, tTR_rBias) + bias_vec = tTR_rBias.load() + if cutlass.const_expr(self.vectorized_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + bias_f32_0 = bias_vec[i].to(cutlass.Float32) + bias_f32_1 = bias_vec[i + 1].to(cutlass.Float32) + bias_f32_0, bias_f32_1 = cute.arch.mul_packed_f32x2( + (mProb, mProb), + (bias_f32_0, bias_f32_1), + rnd="rn", + ftz=False, + ) + tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.fma_packed_f32x2( + (tTR_rAcc[i], tTR_rAcc[i + 1]), + (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), + (bias_f32_0, bias_f32_1), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tTR_rAcc[i] = tTR_rAcc[i] * cutlass.Float32(alpha_val) + bias_vec[i].to(cutlass.Float32) * mProb + else: + if cutlass.const_expr(self.vectorized_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rAcc[i], tTR_rAcc[i + 1]), + (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tTR_rAcc[i] = tTR_rAcc[i] * cutlass.Float32(alpha_val) + + if cutlass.const_expr(self.generate_c): + self.store_c( + tiled_copy_r2s, + tma_atom_c, + warp_idx, + tTR_rAcc, + tRS_rC, + tRS_sC, + bSG_gC, + bSG_sC, + c_pipeline, + num_prev_subtiles, + real_subtile_idx, + ) + + acc_vec = tTR_rAcc.load() + + if cutlass.const_expr(self.epilogue_type == EpilogueType.SRELU.value): + acc_relu = cute.where(acc_vec > 0, acc_vec, cute.full_like(acc_vec, 0)) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2( + 
(acc_relu[i], acc_relu[i + 1]), + (acc_relu[i], acc_relu[i + 1]), + rnd="rn", + ftz=False, + ) + acc_vec = tTR_rAcc.load() + + if cutlass.const_expr(not self.enable_bias): + tCompute = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype) + if cutlass.const_expr(self.vectorized_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2( + (acc_vec[i], acc_vec[i + 1]), + (mProb, mProb), + rnd="rn", + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tCompute[i] = acc_vec[i] * mProb + else: + tCompute = tTR_rAcc + + if cutlass.const_expr(self.generate_amax): + thread_tile_amax = amax_reduction_per_thread(tCompute, thread_tile_amax) + + if cutlass.const_expr(self.generate_sfd): + tCompute_col = cute.make_rmem_tensor(tCompute.layout, tCompute.element_type) + tCompute_col.store(tCompute.load()) + self.quant_sfd_row( + real_subtile_idx % 4, + tiled_copy_r2s, + tCompute, + tCrSFDRow_pvscale, + norm_const, + d_rcp_limits, + tRS_rD, + ) + self.quant_sfd_col( + real_subtile_idx % 4, + tiled_copy_r2s, + tCompute_col, + tCrSFDCol_pvscale, + norm_const, + d_rcp_limits, + tRS_rD_col, + ) + # SFD M tile = cta_tile_m = 128; tile_m_idx is CTA-level per-expert + global_sfd_m = epi_work_tile_info.tile_m_idx + epi_ext.token_offset // self.cta_tile_shape_mnk[0] + if cutlass.const_expr(self.mma_tiler[1] == 256): + sfd_n = epi_work_tile_info.tile_n_idx * 2 + (real_subtile_idx >> 2) + else: + sfd_n = epi_work_tile_info.tile_n_idx + sfd_row_idx_mn = (global_sfd_m, sfd_n) + sfd_col_idx_mn = sfd_row_idx_mn + if cutlass.const_expr(self.discrete_col_sfd): + sfd_col_idx_mn = ( + epi_work_tile_info.tile_m_idx, + sfd_n, + ) + tCgSFDRow = tCgSFDRow_mn[(None, None, None, *sfd_row_idx_mn)] + tCgSFDCol = tCgSFDCol_mn[(None, None, None, *sfd_col_idx_mn)] + if subtile_idx == 3 or subtile_idx == 7: + if sfd_row_idx_mn[1] * 32 * regPerSubtile < cute.size(cute.shape(mSFDRow_mnl.layout, mode=[1])): + tCrSFDRow.store(tCrSFDRow_pvscale.load().to(self.sf_dtype)) + cute.autovec_copy(tCrSFDRow, tCgSFDRow) + if sfd_col_idx_mn[1] * 32 * regPerSubtile < cute.size(cute.shape(mSFDCol_mnl.layout, mode=[1])): + tCrSFDCol.store(tCrSFDCol_pvscale.load().to(self.sf_dtype)) + cute.autovec_copy(tCrSFDCol, tCgSFDCol) + else: + acc_vec = tiled_copy_r2s.retile(tCompute).load() + tRS_rD.store(acc_vec.to(self.d_dtype)) + + d_buffer = num_prev_subtiles % self.num_d_stage + num_prev_subtiles = num_prev_subtiles + 1 + cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)]) + if cutlass.const_expr(self.generate_sfd): + cute.copy(tiled_copy_r2s, tRS_rD_col, tRS_sD_col[(None, None, None, d_buffer)]) + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier.arrive_and_wait() + if warp_idx == self.epilog_warp_id[0]: + cute.copy(tma_atom_d, bSG_sD[(None, d_buffer)], bSG_gD[(None, real_subtile_idx)]) + if cutlass.const_expr(self.generate_sfd): + cute.copy(tma_atom_d_col, bSG_sD_col[(None, d_buffer)], bSG_gD_col[(None, real_subtile_idx)]) + d_pipeline.producer_commit() + d_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + if cutlass.const_expr(not self.overlapping_accum): + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + if cutlass.const_expr(self.enable_bias): + bias_pipeline.consumer_release(bias_consumer_state) + bias_consumer_state.advance() + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in 
cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[0] >= cutlass.Int32(0) + cute.arch.fence_proxy("async.shared", space="cta") + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + if cutlass.const_expr(self.generate_amax): + gAmax = mAmax_tensor[(expert_idx, None)].iterator.llvm_ptr + self.amax_reduction_per_warp_and_cta(thread_tile_amax, warp_idx, sAmax, gAmax) + + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + if cutlass.const_expr(self.generate_c): + c_pipeline.producer_tail() + d_pipeline.producer_tail() + + # ------------------------------------------------------------------ + # Internal: create extension based on weight_mode + # ------------------------------------------------------------------ + + @cute.jit + def _make_extension(self, workspace_ptr): + if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE): + desc_workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"]) + return DiscreteWeightScaledGemmSchedExtension( + tensormap_ctor=desc_workspace, + sf_vec_size=self.sf_vec_size, + ) + else: + return ContiguousAndConsistentGroupedGemmSchedExtension( + sf_vec_size=self.sf_vec_size, + ) diff --git a/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py index 91f5431a..7f876e1e 100644 --- a/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py +++ b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py @@ -123,7 +123,7 @@ def __init__( """ super().__init__() - self._logger.warning("GroupedGemmSwigluSm100 is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") # Store sample tensor descriptors @@ -168,7 +168,7 @@ def __init__( self._kernel = BlockScaledContiguousGroupedGemmKernel self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0")) - print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") + self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}") self._logger.debug(f"__init__ completed") def check_support(self) -> bool: @@ -482,7 +482,7 @@ def compile(self) -> None: fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False) self._logger.debug("Compiling grouped_gemm_swiglu kernel") - use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None + use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" if not use_full_dynamic: # only mark the m dimension as dynamic valid_m = cute.sym_int(divisibility=256) @@ -910,11 +910,10 @@ def grouped_gemm_swiglu_wrapper_sm100( ) sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) - if d_dtype in [torch.bfloat16, torch.float16]: - _logger.debug("grouped_gemm_swiglu_wrapper_sm100: Detected bf16/float16 d_dtype, constructing amax_tensor") - amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device) - if valid_m == 0: + if d_dtype in [torch.bfloat16, torch.float16]: + amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device) + _logger.debug("grouped_gemm_swiglu_wrapper_sm100: valid_m is zero, skipping kernel execution") return TupleDict( c_tensor=c_tensor, @@ -925,7 +924,7 @@ def grouped_gemm_swiglu_wrapper_sm100( sfd_col_tensor=sfd_col_tensor, ) - use_full_dynamic = 
os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None + use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0" def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1])) @@ -962,7 +961,10 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: if cache_key in _cache_of_GroupedGemmSwigluSm100Objects: _logger.debug("group_gemm_swiglu_wrapper_sm100: Using previously cached GroupedGemmSwigluSm100 object") - grouped_gemm_swiglu = _cache_of_GroupedGemmSwigluSm100Objects[cache_key] + grouped_gemm_swiglu, amax_tensor = _cache_of_GroupedGemmSwigluSm100Objects[cache_key] + # The cuDNN graph API binds data pointers at execute time, not plan-build time. + # During CUDA graph capture, padded_offsets is allocated in the graph pool + # (stable address across replays), so passing it directly is graph-safe. grouped_gemm_swiglu.execute( a_tensor=a_tensor, b_tensor=b_tensor, @@ -982,6 +984,10 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: ) else: _logger.debug("group_gemm_swiglu_wrapper_sm100: No previously cached GroupedGemmSwigluSm100 object found, creating new GroupedGemmSwigluSm100 object") + # Allocate amax_tensor once here; cache-hit calls reuse this buffer so + # the FillFunctor (torch.full) only fires during warmup, not every step. + if d_dtype in [torch.bfloat16, torch.float16]: + amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device) grouped_gemm_swiglu = GroupedGemmSwigluSm100( sample_a=a_tensor, sample_b=b_tensor, @@ -1025,7 +1031,7 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]: prob_tensor=prob_tensor, current_stream=current_stream, ) - _cache_of_GroupedGemmSwigluSm100Objects[cache_key] = grouped_gemm_swiglu + _cache_of_GroupedGemmSwigluSm100Objects[cache_key] = (grouped_gemm_swiglu, amax_tensor) return TupleDict( c_tensor=c_tensor, diff --git a/python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py b/python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py index cb5808f8..c880295e 100644 --- a/python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py +++ b/python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py @@ -44,13 +44,13 @@ def __init__( sample_global_scale_a: Optional[torch.Tensor] = None, sample_global_scale_b: Optional[torch.Tensor] = None, acc_dtype: torch.dtype = torch.float32, - mma_tiler_mn: Tuple[int, int] = (128, 128), + mma_tiler_mn: Tuple[int, int] = (256, 256), cluster_shape_mn: Optional[Tuple[int, int]] = None, sf_vec_size: int = 16, accumulate_on_output: bool = False, ): super().__init__() - self._logger.warning("GroupedGemmWgradSm100 is an experimental API") + self._warn_experimental_api() if sample_wgrad is not None and num_experts is None: self.weight_mode = MoEWeightMode.DENSE @@ -79,7 +79,7 @@ def __init__( f"sample_a and sample_b token dimensions must match, got {tokens_sum_a} and {tokens_sum_b}", ) self._offset_values = self._validate_offsets(sample_offsets, tokens_sum_a, name="sample_offsets") - self._scale_cols = self._compute_scale_cols(self._offset_values) + self._scale_cols = _round_up(ceil_div(tokens_sum_a, self.sf_vec_size), 4) if self.weight_mode == MoEWeightMode.DENSE: self.wgrad_desc = self._make_tensor_desc(sample_wgrad, name="sample_wgrad") @@ -136,23 +136,14 @@ def _validate_offsets(self, offsets_tensor: torch.Tensor, tokens_sum: int, name: if offset_values: self._value_error_if( - offset_values[-1] != tokens_sum, - f"{name} last value must equal total tokens {tokens_sum}, got 
{offset_values[-1]}", + offset_values[-1] > tokens_sum, + f"{name} last value must not exceed total tokens {tokens_sum}, got {offset_values[-1]}", ) else: self._value_error_if(tokens_sum != 0, f"{name} cannot be empty when total tokens is {tokens_sum}") return offset_values - def _compute_scale_cols(self, offset_values: Tuple[int, ...]) -> int: - prev_offset = 0 - scale_cols = 0 - for offset in offset_values: - group_k = offset - prev_offset - scale_cols += _round_up(ceil_div(group_k, self.sf_vec_size), 4) - prev_offset = offset - return scale_cols - def check_support(self) -> bool: m, tokens_sum = self._tensor_shape(self.a_desc, name="sample_a") _, n = self._tensor_shape(self.b_desc, name="sample_b") @@ -572,11 +563,13 @@ def grouped_gemm_wgrad_wrapper_sm100( sfb_tensor: torch.Tensor, offsets_tensor: torch.Tensor, output_mode: str = "dense", + wgrad_tensor: Optional[torch.Tensor] = None, + wgrad_ptrs: Optional[torch.Tensor] = None, global_scale_a: Optional[torch.Tensor] = None, global_scale_b: Optional[torch.Tensor] = None, acc_dtype: torch.dtype = torch.float32, wgrad_dtype: torch.dtype = torch.bfloat16, - mma_tiler_mn: Tuple[int, int] = (128, 128), + mma_tiler_mn: Tuple[int, int] = (256, 256), cluster_shape_mn: Optional[Tuple[int, int]] = None, sf_vec_size: int = 16, accumulate_on_output: bool = False, @@ -585,15 +578,18 @@ def grouped_gemm_wgrad_wrapper_sm100( """Compile and execute grouped GEMM wgrad in one call.""" hidden, _ = a_tensor.shape _, intermediate = b_tensor.shape + wgrad_shape = (hidden, intermediate) expert_cnt = offsets_tensor.shape[0] if output_mode not in {"dense", "discrete"}: raise ValueError(f"output_mode must be 'dense' or 'discrete', got {output_mode}") - if accumulate_on_output: - wgrad_tensor = torch.zeros((expert_cnt, hidden, intermediate), dtype=wgrad_dtype, device=a_tensor.device) - else: - wgrad_tensor = torch.empty((expert_cnt, hidden, intermediate), dtype=wgrad_dtype, device=a_tensor.device) + if wgrad_tensor is None and wgrad_ptrs is None: + # Backward compatibility: Dense mode. 
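+        # zeros when accumulating so the kernel's first read-modify-write sees a
+        # defined buffer; empty otherwise to skip the fill.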
+ if accumulate_on_output: + wgrad_tensor = torch.zeros((expert_cnt, *wgrad_shape), dtype=wgrad_dtype, device=a_tensor.device) + else: + wgrad_tensor = torch.empty((expert_cnt, *wgrad_shape), dtype=wgrad_dtype, device=a_tensor.device) cache_key = ( output_mode, @@ -604,6 +600,7 @@ def grouped_gemm_wgrad_wrapper_sm100( tuple(offsets_tensor.shape), tuple(offsets_tensor.stride()), offsets_tensor.dtype, + *_dynamic_dim_tensor_signature(wgrad_tensor, dynamic_dims=()), tuple(global_scale_a.shape) if global_scale_a is not None else None, global_scale_a.dtype if global_scale_a is not None else None, tuple(global_scale_b.shape) if global_scale_b is not None else None, @@ -636,13 +633,14 @@ def grouped_gemm_wgrad_wrapper_sm100( accumulate_on_output=accumulate_on_output, ) else: + sample_expert = torch.empty(wgrad_shape, dtype=wgrad_dtype, device=a_tensor.device) op = GroupedGemmWgradSm100( sample_a=a_tensor, sample_b=b_tensor, sample_sfa=sfa_tensor, sample_sfb=sfb_tensor, sample_offsets=offsets_tensor, - sample_wgrad_expert=wgrad_tensor[0], + sample_wgrad_expert=sample_expert, num_experts=expert_cnt, wgrad_shape=(hidden, intermediate), wgrad_dtype=wgrad_dtype, @@ -665,6 +663,7 @@ def grouped_gemm_wgrad_wrapper_sm100( sfb_tensor=sfb_tensor, offsets_tensor=offsets_tensor, wgrad_tensor=wgrad_tensor, + wgrad_ptrs=wgrad_ptrs, global_scale_a=global_scale_a, global_scale_b=global_scale_b, current_stream=current_stream, diff --git a/python/cudnn/native_sparse_attention/compression/api.py b/python/cudnn/native_sparse_attention/compression/api.py index 7433eff5..fc89c004 100644 --- a/python/cudnn/native_sparse_attention/compression/api.py +++ b/python/cudnn/native_sparse_attention/compression/api.py @@ -40,7 +40,7 @@ def __init__( super().__init__() self._kernel = BlackwellFusedMultiHeadAttentionForward - self._logger.warning("CompressionAttention is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") self.q_desc = self._make_tensor_desc(sample_q, name="sample_q") diff --git a/python/cudnn/native_sparse_attention/selection/api.py b/python/cudnn/native_sparse_attention/selection/api.py index daf0f7a6..03ca47c2 100644 --- a/python/cudnn/native_sparse_attention/selection/api.py +++ b/python/cudnn/native_sparse_attention/selection/api.py @@ -33,7 +33,7 @@ def __init__( super().__init__() self._kernel = HopperSelectAttentionFwd - self._logger.warning("SelectionAttention is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") self.q_desc = self._make_tensor_desc(sample_q, name="sample_q") diff --git a/python/cudnn/native_sparse_attention/top_k/api.py b/python/cudnn/native_sparse_attention/top_k/api.py index 392afdc5..3d0a5f63 100644 --- a/python/cudnn/native_sparse_attention/top_k/api.py +++ b/python/cudnn/native_sparse_attention/top_k/api.py @@ -47,7 +47,7 @@ def __init__( super().__init__() self._kernel = FineGrainedReductionQK - self._logger.warning("TopKReduction is an experimental API") + self._warn_experimental_api() self._logger.debug("Entering __init__") self.q_desc = self._make_tensor_desc(sample_q, name="sample_q") diff --git a/python/cudnn/ops/__init__.py b/python/cudnn/ops/__init__.py new file mode 100644 index 00000000..54a48c07 --- /dev/null +++ b/python/cudnn/ops/__init__.py @@ -0,0 +1 @@ +from .causal_conv1d import causal_conv1d diff --git a/python/cudnn/ops/causal_conv1d.py b/python/cudnn/ops/causal_conv1d.py new file mode 100644 index 00000000..bab76197 --- /dev/null +++ b/python/cudnn/ops/causal_conv1d.py @@ -0,0 
+1,231 @@ +from typing import List, Optional + +import torch +from torch import Tensor + +_TORCH_DTYPE_TO_CUDNN = { + torch.float32: 0, # CUDNN_DATA_FLOAT + torch.float16: 2, # CUDNN_DATA_HALF + torch.bfloat16: 9, # CUDNN_DATA_BFLOAT16 +} + +_ACTIVATION_TO_INT = { + "identity": 0, # CUDNN_CAUSAL_CONV1D_ACTIVATION_IDENTITY + "silu": 1, # CUDNN_CAUSAL_CONV1D_ACTIVATION_SILU +} + + +def _dtype_to_int(dtype: torch.dtype) -> int: + if dtype not in _TORCH_DTYPE_TO_CUDNN: + raise ValueError(f"Unsupported dtype {dtype}. Supported: float32, float16, bfloat16.") + return _TORCH_DTYPE_TO_CUDNN[dtype] + + +def _activation_to_int(activation: str) -> int: + if activation not in _ACTIVATION_TO_INT: + raise ValueError(f"Unsupported activation '{activation}'. Supported: 'identity', 'silu'.") + return _ACTIVATION_TO_INT[activation] + + +# --------------------------------------------------------------------------- +# Forward primitive +# --------------------------------------------------------------------------- + + +@torch.library.custom_op( + "cudnn::causal_conv1d_fwd_primitive", + mutates_args=(), + device_types="cuda", +) +def _fwd_primitive(x: Tensor, weight: Tensor, bias: Tensor, activation: str) -> Tensor: + if x.dim() != 3 or weight.dim() != 2 or bias.dim() != 1: + raise ValueError(f"Expected x(3D), weight(2D), bias(1D); got {x.shape}, {weight.shape}, {bias.shape}") + + if not (x.is_cuda and weight.is_cuda and bias.is_cuda): + raise ValueError(f"All tensors must be on CUDA: x.device={x.device}, " f"weight.device={weight.device}, bias.device={bias.device}") + if not (x.device == weight.device == bias.device): + raise ValueError(f"All tensors must be on the same device: x.device={x.device}, " f"weight.device={weight.device}, bias.device={bias.device}") + + if not (x.dtype == weight.dtype == bias.dtype): + raise TypeError(f"Dtype mismatch: x.dtype={x.dtype}, weight.dtype={weight.dtype}, " f"bias.dtype={bias.dtype} (all must match)") + + x = x.contiguous() + weight = weight.contiguous() + bias = bias.contiguous() + + batch, dim, seq_len = x.shape + kernel_size = weight.shape[1] + + if weight.shape[0] != dim: + raise ValueError(f"Channel mismatch: x has dim={dim} but weight has shape {weight.shape} " f"(expected weight.shape[0]={dim})") + + if bias.shape[0] != dim: + raise ValueError(f"Bias mismatch: x has dim={dim} but bias has shape {bias.shape} " f"(expected bias.shape[0]={dim})") + + y = torch.empty_like(x) + + import cudnn + + cudnn.causal_conv1d_forward( + torch.cuda.current_stream().cuda_stream, + x.data_ptr(), + weight.data_ptr(), + bias.data_ptr(), + y.data_ptr(), + batch, + dim, + seq_len, + kernel_size, + _dtype_to_int(x.dtype), + _activation_to_int(activation), + ) + return y + + +@torch.library.register_fake("cudnn::causal_conv1d_fwd_primitive") +def _fwd_fake(x: Tensor, weight: Tensor, bias: Tensor, activation: str) -> Tensor: + return torch.empty_like(x) + + +# --------------------------------------------------------------------------- +# Backward primitive +# --------------------------------------------------------------------------- + + +@torch.library.custom_op( + "cudnn::causal_conv1d_bwd_primitive", + mutates_args=(), + device_types="cuda", +) +def _bwd_primitive(grad_out: Tensor, x: Tensor, weight: Tensor, bias: Tensor, activation: str) -> List[Tensor]: + if x.dim() != 3 or weight.dim() != 2 or bias.dim() != 1: + raise ValueError(f"Expected x(3D), weight(2D), bias(1D); got {x.shape}, {weight.shape}, {bias.shape}") + if grad_out.shape != x.shape: + raise ValueError(f"Shape mismatch: dy 
has shape {grad_out.shape} but x has shape {x.shape} " f"(expected dy.shape == x.shape)") + if not grad_out.is_cuda: + raise ValueError(f"grad_out must be on CUDA: grad_out.device={grad_out.device}") + if grad_out.device != x.device: + raise ValueError(f"Device mismatch: grad_out.device={grad_out.device}, x.device={x.device}") + if grad_out.dtype != x.dtype: + raise ValueError(f"Dtype mismatch: grad_out.dtype={grad_out.dtype}, x.dtype={x.dtype}") + + if not (x.is_cuda and weight.is_cuda and bias.is_cuda): + raise ValueError(f"All tensors must be on CUDA: x.device={x.device}, " f"weight.device={weight.device}, bias.device={bias.device}") + if not (x.device == weight.device == bias.device): + raise ValueError(f"All tensors must be on the same device: x.device={x.device}, " f"weight.device={weight.device}, bias.device={bias.device}") + + if not (x.dtype == weight.dtype == bias.dtype): + raise TypeError(f"Dtype mismatch: x.dtype={x.dtype}, weight.dtype={weight.dtype}, " f"bias.dtype={bias.dtype} (all must match)") + + batch, dim, seq_len = x.shape + + if weight.shape[0] != dim: + raise ValueError(f"Channel mismatch: x has dim={dim} but weight has shape {weight.shape} " f"(expected weight.shape[0]={dim})") + + if bias.shape[0] != dim: + raise ValueError(f"Bias mismatch: x has dim={dim} but bias has shape {bias.shape} " f"(expected bias.shape[0]={dim})") + + x = x.contiguous() + weight = weight.contiguous() + bias = bias.contiguous() + grad_out = grad_out.contiguous() + + kernel_size = weight.shape[1] + + dx = torch.empty_like(x) + dweight = torch.zeros(weight.shape, device=x.device, dtype=torch.float32) + dbias = torch.zeros(bias.shape, device=x.device, dtype=torch.float32) + + import cudnn + + cudnn.causal_conv1d_backward( + torch.cuda.current_stream().cuda_stream, + x.data_ptr(), + weight.data_ptr(), + bias.data_ptr(), + grad_out.data_ptr(), + dx.data_ptr(), + dweight.data_ptr(), + dbias.data_ptr(), + batch, + dim, + seq_len, + kernel_size, + _dtype_to_int(x.dtype), + _dtype_to_int(torch.float32), + _activation_to_int(activation), + ) + return [dx, dweight.to(x.dtype), dbias.to(x.dtype)] + + +@torch.library.register_fake("cudnn::causal_conv1d_bwd_primitive") +def _bwd_fake(grad_out: Tensor, x: Tensor, weight: Tensor, bias: Tensor, activation: str) -> List[Tensor]: + return [torch.empty_like(x), torch.empty_like(weight), torch.empty_like(bias)] + + +# --------------------------------------------------------------------------- +# Autograd glue +# --------------------------------------------------------------------------- + + +def _setup_context(ctx, inputs, output): + x, weight, bias, activation = inputs + ctx.save_for_backward(x, weight, bias) + ctx.activation = activation + + +@torch.compiler.allow_in_graph +def _autograd_bwd(ctx, grad_out): + x, weight, bias = ctx.saved_tensors + dx, dw, db = torch.ops.cudnn.causal_conv1d_bwd_primitive(grad_out, x, weight, bias, ctx.activation) + return dx, dw, db, None + + +torch.library.register_autograd( + "cudnn::causal_conv1d_fwd_primitive", + _autograd_bwd, + setup_context=_setup_context, +) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def causal_conv1d( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + activation: str = "identity", +) -> Tensor: + r"""Depthwise causal 1D convolution with optional activation. 
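+
+    A minimal usage sketch (shapes are illustrative only)::
+
+        >>> x = torch.randn(2, 64, 128, device="cuda", dtype=torch.bfloat16)
+        >>> w = torch.randn(64, 4, device="cuda", dtype=torch.bfloat16)
+        >>> y = causal_conv1d(x, w, activation="silu")  # same shape as x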
+ + Computes a depthwise 1D convolution with causal (left-only) padding + and optional fused activation:: + + y = activation(conv1d_causal(x, weight) + bias) + + Causal padding: ``(kernel_size - 1)`` on the left, ``0`` on the right. + Each channel is convolved independently with its own 1D filter. + + Supports ``torch.compile`` and ``torch.autograd`` — backward is handled + automatically when inputs require gradients. + + Args: + x (torch.Tensor): Input tensor of shape ``(batch, dim, seq_len)``. + Must be BF16, FP16, or FP32. Must be contiguous and on CUDA. + weight (torch.Tensor): Filter tensor of shape ``(dim, kernel_size)``. + Same dtype as *x*. + bias (torch.Tensor | None): Optional bias of shape ``(dim,)``. + Same dtype as *x*. Defaults to zeros if ``None``. + activation (str): ``"identity"`` (default) or ``"silu"``. + + Returns: + torch.Tensor: Output of shape ``(batch, dim, seq_len)``, same dtype as *x*. + """ + if activation not in _ACTIVATION_TO_INT: + raise ValueError(f"Unsupported activation '{activation}'. Supported: 'identity', 'silu'.") + if bias is None: + bias = torch.zeros(weight.shape[0], device=x.device, dtype=x.dtype) + return torch.ops.cudnn.causal_conv1d_fwd_primitive(x, weight, bias, activation) diff --git a/python/cudnn/rmsnorm_rht_amax/__init__.py b/python/cudnn/rmsnorm_rht_amax/__init__.py new file mode 100644 index 00000000..9da5cae1 --- /dev/null +++ b/python/cudnn/rmsnorm_rht_amax/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +from .api import ( + RmsNormRhtAmaxSm100, + best_num_threads, + pick_rows_per_cta, + rmsnorm_rht_amax_wrapper_sm100, +) + +__all__ = [ + "RmsNormRhtAmaxSm100", + "best_num_threads", + "pick_rows_per_cta", + "rmsnorm_rht_amax_wrapper_sm100", +] diff --git a/python/cudnn/rmsnorm_rht_amax/api.py b/python/cudnn/rmsnorm_rht_amax/api.py new file mode 100644 index 00000000..0294a3a2 --- /dev/null +++ b/python/cudnn/rmsnorm_rht_amax/api.py @@ -0,0 +1,302 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: MIT + +"""FE API for fused RMSNorm + RHT + per-CTA amax.""" + +import logging +from typing import Optional + +from cuda.bindings import driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32 +from cutlass.cute.runtime import make_fake_stream + +from cudnn.api_base import APIBase, TupleDict + +from .kernel import RMSNormRHTAmaxKernel + +DEFAULT_NUM_THREADS_BY_N = { + 2048: 128, + 4096: 256, + 7168: 128, + 8192: 512, + 16384: 1024, + 32768: 512, +} +RPC_CANDIDATES = (2, 4, 8) +TARGET_MIN_CTAS = 148 + + +def best_num_threads(n: int) -> Optional[int]: + for num_threads in (1024, 512, 256, 128, 64): + if n % num_threads != 0: + continue + ept = n // num_threads + if ept >= 8 and ept % 8 == 0: + return num_threads + return None + + +def pick_rows_per_cta(m: int) -> int: + for rows_per_cta in reversed(RPC_CANDIDATES): + if m % rows_per_cta != 0: + continue + num_ctas = m // rows_per_cta + if num_ctas >= TARGET_MIN_CTAS: + return rows_per_cta + return RPC_CANDIDATES[0] + + +class RmsNormRhtAmaxSm100(APIBase): + """Class API for the RMSNorm + RHT + amax kernel.""" + + def __init__( + self, + sample_x: torch.Tensor, + sample_w: torch.Tensor, + sample_o: torch.Tensor, + sample_amax: torch.Tensor, + eps: float = 1e-5, + num_threads: Optional[int] = None, + rows_per_cta: Optional[int] = None, + ): + super().__init__() + + self._warn_experimental_api() + + self.x_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_x, name="sample_x"), 2, "sample_x") + self.w_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_w, name="sample_w"), 1, "sample_w") + self.o_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_o, name="sample_o"), 2, "sample_o") + self.amax_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_amax, name="sample_amax"), 1, "sample_amax") + + self.eps = eps + self.requested_num_threads = num_threads + self.requested_rows_per_cta = rows_per_cta + self.num_threads = None + self.rows_per_cta = None + self.n = None + + def check_support(self) -> bool: + m, n = self._tensor_shape(self.x_desc, name="sample_x") + w_n = self._tensor_shape(self.w_desc, name="sample_w")[0] + o_m, o_n = self._tensor_shape(self.o_desc, name="sample_o") + + self._check_tensor_shape(self.x_desc, (m, n), "X") + self._check_tensor_shape(self.w_desc, (n,), "W") + self._check_tensor_shape(self.o_desc, (m, n), "O") + self._value_error_if(w_n != n, f"W length must match X hidden dimension, got {w_n} and {n}") + self._value_error_if((n % 16) != 0, f"N must be divisible by 16 for the Hadamard block size, got {n}") + self._value_error_if(o_m != m or o_n != n, f"O shape must match X shape, got {(o_m, o_n)} and {(m, n)}") + + self._check_tensor_stride(self.x_desc, stride=(n, 1), name="X", extra_error_msg="X must be row-major contiguous") + self._check_tensor_stride(self.w_desc, stride=(1,), name="W", extra_error_msg="W must be contiguous") + self._check_tensor_stride(self.o_desc, stride=(n, 1), name="O", extra_error_msg="O must be row-major contiguous") + + self._check_dtype(self.x_desc, dtype=torch.bfloat16, name="X") + self._check_dtype(self.w_desc, dtype=torch.bfloat16, name="W") + self._check_dtype(self.o_desc, dtype=torch.bfloat16, name="O") + self._check_dtype(self.amax_desc, dtype=torch.float32, name="Amax") + + resolved_num_threads = self.requested_num_threads + if resolved_num_threads is None: + resolved_num_threads = DEFAULT_NUM_THREADS_BY_N.get(n, best_num_threads(n)) + 
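+            # (Tuned table first; best_num_threads then falls back to the largest
+            # block size, 1024 down to 64, whose per-thread element count is >= 8
+            # and a multiple of 8.)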
self._value_error_if(resolved_num_threads is None, f"No valid num_threads found for N={n}") + self._value_error_if(resolved_num_threads <= 0, f"num_threads must be positive, got {resolved_num_threads}") + self._value_error_if( + (resolved_num_threads % 32) != 0, + f"num_threads must be warp-aligned, got {resolved_num_threads}", + ) + self._value_error_if( + resolved_num_threads > 1024, + f"num_threads must not exceed the CUDA block size limit, got {resolved_num_threads}", + ) + + resolved_rows_per_cta = self.requested_rows_per_cta + if resolved_rows_per_cta is None: + resolved_rows_per_cta = pick_rows_per_cta(m) + + self._value_error_if(m % resolved_rows_per_cta != 0, f"M must be divisible by rows_per_cta, got M={m}, rows_per_cta={resolved_rows_per_cta}") + self._value_error_if(n % resolved_num_threads != 0, f"N={n} must be divisible by num_threads={resolved_num_threads}") + + ept = n // resolved_num_threads + self._value_error_if(ept < 8 or ept % 8 != 0, f"EPT={ept} must be >= 8 and divisible by 8") + + expected_num_ctas = m // resolved_rows_per_cta + self._check_tensor_shape(self.amax_desc, (expected_num_ctas,), "Amax") + + self._runtime_error_if(not torch.cuda.is_available(), "CUDA is not available") + major, minor = torch.cuda.get_device_capability(self.x_desc.device) + compute_capability = major * 10 + minor + self._runtime_error_if( + compute_capability < 100, + f"RmsNormRhtAmaxSm100 requires SM100+, found SM{compute_capability}", + ) + + self.num_threads = resolved_num_threads + self.rows_per_cta = resolved_rows_per_cta + self.n = n + self._is_supported = True + return True + + def compile(self) -> None: + self._ensure_support_checked() + if self._compiled_kernel is not None: + return + + kernel = RMSNormRHTAmaxKernel( + n=self.n, + num_threads=self.num_threads, + eps=self.eps, + rows_per_cta=self.rows_per_cta, + ) + + valid_m = cute.sym_int(divisibility=self.rows_per_cta) + + fake_x_tensor = self._make_fake_cute_compact_tensor( + dtype=self.x_desc.dtype, + shape=(valid_m, self.n), + stride_order=self.x_desc.stride_order, + dynamic_mode=None, + divisibility=self.rows_per_cta, + ) + fake_w_tensor = self._make_fake_cute_tensor_from_desc(self.w_desc, assumed_align=16) + fake_o_tensor = self._make_fake_cute_compact_tensor( + dtype=self.o_desc.dtype, + shape=(valid_m, self.n), + stride_order=self.o_desc.stride_order, + dynamic_mode=None, + divisibility=self.rows_per_cta, + ) + fake_num_ctas = cute.sym_int() + fake_amax_tensor = self._make_fake_cute_tensor( + dtype=self.amax_desc.dtype, + shape=(fake_num_ctas,), + stride=self.amax_desc.stride, + assumed_align=16, + ) + fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False) + + compiled_kernel = cute.compile( + kernel, + fake_x_tensor, + fake_w_tensor, + fake_o_tensor, + fake_amax_tensor, + Float32(self.eps), + fake_stream, + options="--enable-tvm-ffi", + ) + + def tensor_api( + x_tensor: torch.Tensor, + w_tensor: torch.Tensor, + o_tensor: torch.Tensor, + amax_tensor: torch.Tensor, + stream: cuda.CUstream, + ) -> None: + compiled_kernel( + x_tensor, + w_tensor, + o_tensor, + amax_tensor, + Float32(self.eps), + stream, + ) + + self._compiled_kernel = tensor_api + + def execute( + self, + x_tensor: torch.Tensor, + w_tensor: torch.Tensor, + o_tensor: torch.Tensor, + amax_tensor: torch.Tensor, + current_stream: Optional[cuda.CUstream] = None, + ) -> None: + self._runtime_error_if(self._compiled_kernel is None, "RmsNormRhtAmaxSm100 kernel not compiled; call compile() first") + + x_tensor = self._unpad_tensor_to_ndim(x_tensor, 2, 
"x_tensor") + w_tensor = self._unpad_tensor_to_ndim(w_tensor, 1, "w_tensor") + o_tensor = self._unpad_tensor_to_ndim(o_tensor, 2, "o_tensor") + amax_tensor = self._unpad_tensor_to_ndim(amax_tensor, 1, "amax_tensor") + + if current_stream is None: + current_stream = cuda.CUstream(torch.cuda.current_stream(x_tensor.device).cuda_stream) + + self._compiled_kernel( + x_tensor=x_tensor, + w_tensor=w_tensor, + o_tensor=o_tensor, + amax_tensor=amax_tensor, + stream=current_stream, + ) + + +_logger = logging.getLogger(__name__) +_cache_of_RmsNormRhtAmaxSm100Objects = {} + + +def rmsnorm_rht_amax_wrapper_sm100( + x_tensor: torch.Tensor, + w_tensor: torch.Tensor, + eps: float = 1e-5, + num_threads: Optional[int] = None, + rows_per_cta: Optional[int] = None, + current_stream: Optional[cuda.CUstream] = None, +) -> TupleDict: + """High-level wrapper for the RMSNorm + RHT + per-CTA amax kernel.""" + + x_tensor = x_tensor.squeeze(-1) if x_tensor.ndim == 3 and x_tensor.shape[-1] == 1 else x_tensor + w_tensor = w_tensor.squeeze(-1) if w_tensor.ndim == 2 and w_tensor.shape[-1] == 1 else w_tensor + + m, n = x_tensor.shape + resolved_num_threads = num_threads if num_threads is not None else DEFAULT_NUM_THREADS_BY_N.get(n, best_num_threads(n)) + if resolved_num_threads is None: + raise ValueError(f"No valid num_threads found for N={n}") + resolved_rows_per_cta = rows_per_cta if rows_per_cta is not None else pick_rows_per_cta(m) + if m % resolved_rows_per_cta != 0: + raise ValueError(f"M must be divisible by rows_per_cta, got M={m}, rows_per_cta={resolved_rows_per_cta}") + + o_tensor = torch.empty_like(x_tensor) + amax_tensor = torch.full((m // resolved_rows_per_cta,), float("-inf"), dtype=torch.float32, device=x_tensor.device) + + cache_key = ( + n, + x_tensor.dtype, + w_tensor.dtype, + o_tensor.dtype, + tuple(x_tensor.stride()), + tuple(w_tensor.stride()), + tuple(o_tensor.stride()), + eps, + resolved_num_threads, + resolved_rows_per_cta, + ) + + if cache_key in _cache_of_RmsNormRhtAmaxSm100Objects: + api = _cache_of_RmsNormRhtAmaxSm100Objects[cache_key] + else: + api = RmsNormRhtAmaxSm100( + sample_x=x_tensor, + sample_w=w_tensor, + sample_o=o_tensor, + sample_amax=amax_tensor, + eps=eps, + num_threads=resolved_num_threads, + rows_per_cta=resolved_rows_per_cta, + ) + assert api.check_support(), "Unsupported configuration" + api.compile() + _cache_of_RmsNormRhtAmaxSm100Objects[cache_key] = api + + api.execute( + x_tensor=x_tensor, + w_tensor=w_tensor, + o_tensor=o_tensor, + amax_tensor=amax_tensor, + current_stream=current_stream, + ) + + return TupleDict(o_tensor=o_tensor, amax_tensor=amax_tensor) diff --git a/python/cudnn/rmsnorm_rht_amax/kernel.py b/python/cudnn/rmsnorm_rht_amax/kernel.py new file mode 100644 index 00000000..2445cae9 --- /dev/null +++ b/python/cudnn/rmsnorm_rht_amax/kernel.py @@ -0,0 +1,253 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: MIT + +"""CUTE DSL kernel for fused RMSNorm + RHT + per-CTA amax.""" + +import math +import operator + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +from cutlass import Float32, Int32 +from cutlass._mlir.dialects import llvm +from cutlass.cute.arch import shuffle_sync_bfly +from cutlass.cutlass_dsl import T, dsl_user_op + + +@dsl_user_op +def fabs_f32(val, *, loc=None, ip=None): + val_ir = val.ir_value(loc=loc, ip=ip) + result = llvm.inline_asm( + T.f32(), + [val_ir], + "abs.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Float32(result) + + +@dsl_user_op +def fmax_f32(a, b, *, loc=None, ip=None): + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + result = llvm.inline_asm( + T.f32(), + [a_ir, b_ir], + "max.f32 $0, $1, $2;", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Float32(result) + + +@dsl_user_op +def redux_sync_max_f32(val, *, loc=None, ip=None): + val_ir = val.ir_value(loc=loc, ip=ip) + result = llvm.inline_asm( + T.f32(), + [val_ir], + "redux.sync.max.f32 $0, $1, 0xffffffff;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return Float32(result) + + +class RMSNormRHTAmaxKernel: + """Fused RMSNorm + block-diagonal Hadamard + running per-CTA amax.""" + + COPY_BITS = 128 + HAD_BLOCK = 16 + + def __init__(self, n, num_threads=256, eps=1e-5, rows_per_cta=8): + self.n = n + self.num_threads = num_threads + self.eps = eps + self.rows_per_cta = rows_per_cta + self.vec_size = self.COPY_BITS // 16 + self.ept = n // num_threads + + assert n % num_threads == 0, f"N={n} must be divisible by num_threads={num_threads}" + assert self.ept % self.vec_size == 0, f"EPT={self.ept} must be a multiple of vec_size={self.vec_size}" + assert self.ept >= self.vec_size, f"EPT={self.ept} must be >= vec_size={self.vec_size}" + + self.num_vec_blocks = self.ept // self.vec_size + self.warps_per_row = num_threads // 32 + self.inv_sqrt_had = 1.0 / math.sqrt(self.HAD_BLOCK) + self.num_intra_stages = int(math.log2(self.vec_size)) + self.num_cross_stages = 1 + + self.tv_shape = ((num_threads, 1), (self.vec_size, self.num_vec_blocks)) + self.tv_stride = ((self.vec_size, 1), (1, self.vec_size * num_threads)) + self.tiler_mn = (1, n) + + tile_bytes = n * 2 + reduce_bytes = self.warps_per_row * 4 + amax_bytes = self.warps_per_row * 4 + self.smem_bytes = tile_bytes + reduce_bytes + amax_bytes + 128 + + self.intra_butterfly_pairs = [] + for stage in range(self.num_intra_stages): + delta = 1 << stage + pairs = [] + for pair_idx in range(self.vec_size // 2): + i_idx = (pair_idx // delta) * 2 * delta + (pair_idx % delta) + j_idx = i_idx + delta + pairs.append((i_idx, j_idx)) + self.intra_butterfly_pairs.append(pairs) + + @cute.kernel + def kernel(self, m_x: cute.Tensor, m_w: cute.Tensor, m_o: cute.Tensor, m_amax: cute.Tensor, eps: Float32, tv_layout: cute.Layout, tiler_mn: cute.Shape): + cfg = self + tid = cute.arch.thread_idx()[0] + bid = cute.arch.block_idx()[0] + inv_sqrt_had = cutlass.Float32(cfg.inv_sqrt_had) + + smem = utils.SmemAllocator() + s_x = smem.allocate_tensor( + cutlass.BFloat16, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer = smem.allocate_tensor(Float32, cute.make_layout((1, 
cfg.warps_per_row)), byte_alignment=4) + amax_buffer = smem.allocate_tensor(Float32, cute.make_layout((1, cfg.warps_per_row)), byte_alignment=4) + + copy_atom_g2s = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), cutlass.BFloat16, num_bits_per_copy=cfg.COPY_BITS) + copy_atom_load_w = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.BFloat16, num_bits_per_copy=cfg.COPY_BITS) + copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.BFloat16, num_bits_per_copy=cfg.COPY_BITS) + + tiled_copy_load = cute.make_tiled_copy(copy_atom_g2s, tv_layout, tiler_mn) + tiled_copy_w = cute.make_tiled_copy(copy_atom_load_w, tv_layout, tiler_mn) + tiled_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn) + + thr_load = tiled_copy_load.get_slice(tid) + thr_w = tiled_copy_w.get_slice(tid) + thr_store = tiled_copy_store.get_slice(tid) + + t_xs_x = thr_load.partition_D(s_x) + + m_w_layout = cute.prepend(m_w.layout, cute.make_layout((1,), stride=(0,))) + m_w_2d = cute.make_tensor(m_w.iterator, m_w_layout) + g_w = cute.local_tile(m_w_2d, tiler_mn, (0, 0)) + t_wg_w = thr_w.partition_S(g_w) + t_wr_w = cute.make_fragment_like(t_wg_w) + cute.copy(copy_atom_load_w, t_wg_w, t_wr_w) + t_xr_w = thr_load.retile(t_wr_w) + + row_base = bid * cfg.rows_per_cta + g_x_first = cute.local_tile(m_x, tiler_mn, (row_base, 0)) + t_xg_x_first = thr_load.partition_S(g_x_first) + t_xr_x = cute.make_fragment_like(t_xg_x_first) + + cute.copy(copy_atom_g2s, t_xg_x_first, t_xs_x) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + reg = cute.make_rmem_tensor(cute.make_layout((cfg.ept,)), cutlass.Float32) + lane_id = cute.arch.lane_idx() + warp_id = cute.arch.warp_idx() + running_max = cutlass.Float32(0.0) + + for row_idx in cutlass.range_constexpr(cfg.rows_per_cta): + cute.autovec_copy(t_xs_x, t_xr_x) + + if row_idx < cfg.rows_per_cta - 1: + g_x_next = cute.local_tile(m_x, tiler_mn, (row_base + (row_idx + 1), 0)) + t_xg_x_next = thr_load.partition_S(g_x_next) + cute.copy(copy_atom_g2s, t_xg_x_next, t_xs_x) + cute.arch.cp_async_commit_group() + + x = t_xr_x.load().to(Float32) + x_sq = x * x + local_sum = x_sq.reduce(cute.ReductionOp.ADD, init_val=Float32(0.0), reduction_profile=0) + warp_sum = cute.arch.warp_reduction(local_sum, operator.add) + if lane_id == 0: + reduction_buffer[0, warp_id] = warp_sum + cute.arch.barrier() + + block_val = Float32(0.0) + if lane_id < cfg.warps_per_row: + block_val = reduction_buffer[0, lane_id] + sum_sq = cute.arch.warp_reduction(block_val, operator.add) + + mean_sq = sum_sq / cfg.n + rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) + + w = t_xr_w.load().to(Float32) + y = x * rstd * w + + for elem_idx in cutlass.range_constexpr(cfg.ept): + reg[elem_idx] = y[elem_idx] + + for block_idx in cutlass.range_constexpr(cfg.num_vec_blocks): + block_offset = block_idx * cfg.vec_size + for stage_idx in cutlass.range_constexpr(cfg.num_intra_stages): + for pair_idx in cutlass.range_constexpr(cfg.vec_size // 2): + i_idx = block_offset + cfg.intra_butterfly_pairs[stage_idx][pair_idx][0] + j_idx = block_offset + cfg.intra_butterfly_pairs[stage_idx][pair_idx][1] + a_val = reg[i_idx] + b_val = reg[j_idx] + reg[i_idx] = a_val + b_val + reg[j_idx] = a_val - b_val + + for cross_stage in cutlass.range_constexpr(cfg.num_cross_stages): + xor_mask = cutlass.Int32(1 << cross_stage) + is_lower = (tid & xor_mask) == cutlass.Int32(0) + for elem_idx in cutlass.range_constexpr(cfg.ept): + partner = shuffle_sync_bfly(reg[elem_idx], offset=xor_mask) + if is_lower: 
+                        reg[elem_idx] = reg[elem_idx] + partner
+                    else:
+                        reg[elem_idx] = partner - reg[elem_idx]
+
+            for elem_idx in cutlass.range_constexpr(cfg.ept):
+                scaled = reg[elem_idx] * inv_sqrt_had
+                abs_val = fabs_f32(scaled)
+                running_max = fmax_f32(running_max, abs_val)
+                t_xr_x[elem_idx] = scaled.to(cutlass.BFloat16)
+
+            g_o_r = cute.local_tile(m_o, tiler_mn, (row_base + row_idx, 0))
+            t_xg_o_r = thr_store.partition_D(g_o_r)
+            cute.copy(copy_atom_store, t_xr_x, t_xg_o_r)
+
+            if row_idx < cfg.rows_per_cta - 1:
+                cute.arch.cp_async_wait_group(0)
+
+        warp_max = redux_sync_max_f32(running_max)
+        if lane_id == 0:
+            amax_buffer[0, warp_id] = warp_max
+        cute.arch.barrier()
+
+        amax_val = cutlass.Float32(0.0)
+        if lane_id < cfg.warps_per_row:
+            amax_val = amax_buffer[0, lane_id]
+        cta_max = redux_sync_max_f32(amax_val)
+        if tid == cutlass.Int32(0):
+            m_amax[bid] = cta_max
+
+    @cute.jit
+    def __call__(self, x_tensor: cute.Tensor, w_tensor: cute.Tensor, o_tensor: cute.Tensor, amax_tensor: cute.Tensor, eps: Float32, stream: cuda.CUstream):
+        m = x_tensor.shape[0]
+        num_ctas = m // self.rows_per_cta
+        tv_layout = cute.make_layout(self.tv_shape, stride=self.tv_stride)
+        self.kernel(x_tensor, w_tensor, o_tensor, amax_tensor, eps, tv_layout, self.tiler_mn).launch(
+            grid=(num_ctas, 1, 1),
+            block=(self.num_threads, 1, 1),
+            smem=self.smem_bytes,
+            stream=stream,
+        )
diff --git a/python/cudnn/sdpa/bwd/api.py b/python/cudnn/sdpa/bwd/api.py
index 125f79d7..bd99fe3e 100644
--- a/python/cudnn/sdpa/bwd/api.py
+++ b/python/cudnn/sdpa/bwd/api.py
@@ -76,7 +76,7 @@ def __init__(
         super().__init__()
         self._kernel = BlackwellFusedMultiHeadAttentionBackward
-        self._logger.warning("SdpabwdSm100D256 is an experimental API")
+        self._warn_experimental_api()
         self._logger.debug("Entering __init__")
 
         if sample_cum_seqlen_q is not None:
diff --git a/python/cudnn/sdpa/fwd/api.py b/python/cudnn/sdpa/fwd/api.py
index 421d8745..90f4c76d 100644
--- a/python/cudnn/sdpa/fwd/api.py
+++ b/python/cudnn/sdpa/fwd/api.py
@@ -44,7 +44,7 @@ def __init__(
         super().__init__()
         self._kernel = BlackwellFusedMultiHeadAttentionForward
-        self._logger.warning("SdpafwdSm100D256 is an experimental API")
+        self._warn_experimental_api()
         self._logger.debug("Entering __init__")
 
         if sample_cum_seqlen_q is not None:
diff --git a/python/properties.cpp b/python/properties.cpp
index e81dd480..c3f6c84f 100644
--- a/python/properties.cpp
+++ b/python/properties.cpp
@@ -148,6 +148,14 @@ init_properties(py::module_& m) {
         .value("F16x16", cudnn_frontend::TensorReordering_t::F16x16)
         .value("F8_128x4", cudnn_frontend::TensorReordering_t::F8_128x4);
 
+    py::enum_<cudnn_frontend::graph::ScalarType>(m, "scalar_type")
+        .value("RUNTIME_PARAM", cudnn_frontend::graph::ScalarType::RUNTIME_PARAM)
+        .value("COMPILE_TIME_CONST", cudnn_frontend::graph::ScalarType::COMPILE_TIME_CONST);
+
+    py::enum_<cudnn_frontend::ReshapeMode_t>(m, "reshape_mode")
+        .value("VIEW_ONLY", cudnn_frontend::ReshapeMode_t::VIEW_ONLY)
+        .value("LOGICAL", cudnn_frontend::ReshapeMode_t::LOGICAL);
+
     py::class_<cudnn_frontend::graph::Tensor_attributes, std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>>(
         m, "tensor")
         .def(py::init<>())
@@ -171,6 +179,7 @@ init_properties(py::module_& m) {
             py::return_value_policy::reference)  // NOTICE THATS ITS JUST ANOTHER NAME FOR SET_IS_VIRTUAL
         .def("get_is_pass_by_value", &cudnn_frontend::graph::Tensor_attributes::get_is_pass_by_value)
         .def("set_is_pass_by_value", &cudnn_frontend::graph::Tensor_attributes::set_is_pass_by_value)
+        .def("get_has_compile_time_constant", &cudnn_frontend::graph::Tensor_attributes::get_has_compile_time_constant)
         .def("get_uid", &cudnn_frontend::graph::Tensor_attributes::get_uid)
         .def("set_uid",
&cudnn_frontend::graph::Tensor_attributes::set_uid) .def("get_reordering_type", &cudnn_frontend::graph::Tensor_attributes::get_reordering_type) diff --git a/python/pycudnn.cpp b/python/pycudnn.cpp index a1895607..9655f691 100644 --- a/python/pycudnn.cpp +++ b/python/pycudnn.cpp @@ -92,6 +92,70 @@ PYBIND11_MODULE(_compiled_module, m) { m.def("_set_dlhandle_cudnn", &set_dlhandle_cudnn); py::register_exception(m, "cudnnGraphNotSupportedError"); + +#if CUDNN_VERSION >= 92200 + m.def("causal_conv1d_forward", + [](std::intptr_t stream, + std::intptr_t x_ptr, + std::intptr_t weight_ptr, + std::intptr_t bias_ptr, + std::intptr_t out_ptr, + int batch, + int dim, + int seq_len, + int kernel_size, + int data_type, + int activation) { + auto status = detail::causal_conv1d_forward(reinterpret_cast(stream), + reinterpret_cast(x_ptr), + reinterpret_cast(weight_ptr), + reinterpret_cast(bias_ptr), + reinterpret_cast(out_ptr), + batch, + dim, + seq_len, + kernel_size, + static_cast(data_type), + static_cast(activation)); + if (status != 0) + throw std::runtime_error("cudnnCausalConv1dForward failed with status " + std::to_string(status)); + }); + + m.def("causal_conv1d_backward", + [](std::intptr_t stream, + std::intptr_t x_ptr, + std::intptr_t weight_ptr, + std::intptr_t bias_ptr, + std::intptr_t dy_ptr, + std::intptr_t dx_ptr, + std::intptr_t dweight_ptr, + std::intptr_t dbias_ptr, + int batch, + int dim, + int seq_len, + int kernel_size, + int data_type, + int dw_data_type, + int activation) { + auto status = detail::causal_conv1d_backward(reinterpret_cast(stream), + reinterpret_cast(x_ptr), + reinterpret_cast(weight_ptr), + reinterpret_cast(bias_ptr), + reinterpret_cast(dy_ptr), + reinterpret_cast(dx_ptr), + reinterpret_cast(dweight_ptr), + reinterpret_cast(dbias_ptr), + batch, + dim, + seq_len, + kernel_size, + static_cast(data_type), + static_cast(dw_data_type), + static_cast(activation)); + if (status != 0) + throw std::runtime_error("cudnnCausalConv1dBackward failed with status " + std::to_string(status)); + }); +#endif } } // namespace python_bindings diff --git a/python/pygraph/pygraph.cpp b/python/pygraph/pygraph.cpp index 772f64af..721c9448 100644 --- a/python/pygraph/pygraph.cpp +++ b/python/pygraph/pygraph.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -185,16 +186,21 @@ PyGraph::slice(std::shared_ptr& input, auto input_dim = input->get_dim(); std::vector> start_end_indices; + std::vector steps; + steps.reserve(slices.size()); for (size_t i = 0; i < slices.size(); ++i) { int64_t start, stop, step, length; if (!slices[i].compute(input_dim[i], &start, &stop, &step, &length)) { throw std::runtime_error("Invalid slice"); } + CUDNN_FRONTEND_UNUSED(length); start_end_indices.push_back({start, stop}); + steps.push_back(step); } auto attributes = cudnn_frontend::graph::Slice_attributes() .set_slices(start_end_indices) + .set_strides(steps) .set_compute_data_type(compute_data_type) .set_name(name); @@ -350,13 +356,60 @@ PyGraph::reduction(std::shared_ptr& in } std::shared_ptr -PyGraph::reshape(std::shared_ptr& input, std::string const& name) { - auto attributes = cudnn_frontend::graph::Reshape_attributes().set_name(name); +PyGraph::reshape(std::shared_ptr& input, + std::string const& name, + cudnn_frontend::ReshapeMode_t reshape_mode) { + auto attributes = cudnn_frontend::graph::Reshape_attributes().set_name(name).set_reshape_mode(reshape_mode); auto OUT_0 = graph->reshape(input, attributes); return OUT_0; } +std::shared_ptr +PyGraph::transpose(std::shared_ptr& input, + std::vector const& 
permutation,
+                      cudnn_frontend::DataType_t const& compute_data_type,
+                      std::string const& name) {
+    auto attributes = cudnn_frontend::graph::Transpose_attributes()
+                          .set_name(name)
+                          .set_permutation(permutation)
+                          .set_compute_data_type(compute_data_type);
+
+    return graph->transpose(input, attributes);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::concatenate(std::vector<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>> inputs,
+                     int64_t axis,
+                     std::optional<int64_t> in_place_index,
+                     std::string const& name) {
+    auto attributes = cudnn_frontend::graph::Concatenate_attributes().set_axis(axis).set_name(name);
+    if (in_place_index.has_value()) {
+        attributes.set_in_place_index(in_place_index.value());
+    }
+    return graph->concatenate(std::move(inputs), attributes);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::tensor_scalar(float const& value, cudnn_frontend::graph::ScalarType scalar_type) {
+    return graph->tensor(value, scalar_type);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::tensor_scalar(double const& value, cudnn_frontend::graph::ScalarType scalar_type) {
+    return graph->tensor(value, scalar_type);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::tensor_scalar(int32_t const& value, cudnn_frontend::graph::ScalarType scalar_type) {
+    return graph->tensor(value, scalar_type);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::tensor_scalar(int64_t const& value, cudnn_frontend::graph::ScalarType scalar_type) {
+    return graph->tensor(value, scalar_type);
+}
+
 std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
 PyGraph::moe_grouped_matmul(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& token,
                             std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& weight,
@@ -772,6 +825,8 @@ init_pygraph_submodule(py::module_& m) {
             Args:
                 input (cudnn_tensor): The input tensor to be sliced.
                 slices (List[slice]): A list of Python slice objects, one for each dimension.
+                    Per-axis step comes from each slice's ``step`` (default 1), after normalization
+                    for the tensor shape (same semantics as indexing a sequence of that length).
                 compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET.
                 name (Optional[str]): A name for the slice operation.
@@ -781,7 +836,7 @@ init_pygraph_submodule(py::module_& m) {
             Example:
                 >>> input_tensor = graph.tensor([4, 8, 16])
-                >>> sliced_tensor = graph.slice(input_tensor, [slice(0, 2), slice(1, 5), slice(0, 16)])
+                >>> sliced_tensor = graph.slice(input_tensor, [slice(0, 2), slice(1, 8, 2), slice(0, 16)])
         )pbdoc")
         .def(
             "conv_fprop",
@@ -1014,6 +1069,7 @@ init_pygraph_submodule(py::module_& m) {
             &PyGraph::reshape,
             py::arg("input"),
             py::arg_v("name", ""),
+            py::arg_v("reshape_mode", cudnn_frontend::ReshapeMode_t::VIEW_ONLY),
             R"pbdoc(
         Reshape an input tensor to other dimensions without changing the actual memory layout. These dimensions to reshape to are inferred from output tensor shape.
@@ -1021,10 +1077,73 @@ init_pygraph_submodule(py::module_& m) {
         Args:
             input (cudnn_tensor): The input tensor.
             name (Optional[str]): A name for the operation to be performed.
+            reshape_mode (cudnn.reshape_mode): VIEW_ONLY (default) or LOGICAL for lexicographic logical reshapes.
 
         Returns:
             cudnn_tensor: The result of reshape operation. Please set the dims for the output tensor.
         )pbdoc")
+        .def("transpose",
+             &PyGraph::transpose,
+             py::arg("input"),
+             py::arg("permutation"),
+             py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET),
+             py::arg_v("name", ""),
+             R"pbdoc(
+            Permute tensor dimensions using a permutation vector (output axis i reads input axis permutation[i]).
+
+            Args:
+                input (cudnn_tensor): The input tensor.
+                permutation (List[int]): Permutation of axis indices.
+                compute_data_type (Optional[cudnn.data_type]): Optional compute type; default NOT_SET.
+                name (Optional[str]): Operation name.
+
+            Returns:
+                cudnn_tensor: Transposed tensor (dims/strides inferred when not set).
+        )pbdoc")
+        .def("concatenate",
+             &PyGraph::concatenate,
+             py::arg("inputs"),
+             py::arg("axis"),
+             py::arg_v("in_place_index", std::optional<int64_t>{}),
+             py::arg_v("name", ""),
+             R"pbdoc(
+            Concatenate tensors along an axis.
+
+            Args:
+                inputs (List[cudnn_tensor]): Tensors to concatenate.
+                axis (int): Concatenation axis.
+                in_place_index (Optional[int]): When set, requests in-place concatenation with the input at this index (backend semantics apply).
+                name (Optional[str]): Operation name.
+
+            Returns:
+                cudnn_tensor: Concatenated output tensor.
+        )pbdoc")
+        .def("tensor_scalar",
+             py::overload_cast<float const&, cudnn_frontend::graph::ScalarType>(&PyGraph::tensor_scalar),
+             py::arg("value"),
+             py::arg("scalar_type"),
+             R"pbdoc(
+            Create a rank-1 scalar tensor from a Python float, marked runtime or compile-time.
+
+            Args:
+                value (float): Scalar value.
+                scalar_type (cudnn.scalar_type): RUNTIME_PARAM or COMPILE_TIME_CONST.
+
+            Returns:
+                cudnn_tensor: Scalar tensor (set dim/stride/name as needed for your graph).
+        )pbdoc")
+        .def("tensor_scalar",
+             py::overload_cast<double const&, cudnn_frontend::graph::ScalarType>(&PyGraph::tensor_scalar),
+             py::arg("value"),
+             py::arg("scalar_type"))
+        .def("tensor_scalar",
+             py::overload_cast<int32_t const&, cudnn_frontend::graph::ScalarType>(&PyGraph::tensor_scalar),
+             py::arg("value"),
+             py::arg("scalar_type"))
+        .def("tensor_scalar",
+             py::overload_cast<int64_t const&, cudnn_frontend::graph::ScalarType>(&PyGraph::tensor_scalar),
+             py::arg("value"),
+             py::arg("scalar_type"))
         .def("moe_grouped_matmul",
              &PyGraph::moe_grouped_matmul,
              py::arg("token"),
diff --git a/python/pygraph/pygraph.h b/python/pygraph/pygraph.h
index fd617b62..ae25927f 100644
--- a/python/pygraph/pygraph.h
+++ b/python/pygraph/pygraph.h
@@ -1,3 +1,4 @@
+#include <optional>
 #include 
 #include 
 #include 
@@ -335,7 +336,33 @@ class PyGraph {
                std::string const& name);
 
     std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
-    reshape(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& input, std::string const& name);
+    reshape(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& input,
+            std::string const& name,
+            cudnn_frontend::ReshapeMode_t reshape_mode = cudnn_frontend::ReshapeMode_t::VIEW_ONLY);
+
+    std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+    transpose(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& input,
+              std::vector<int64_t> const& permutation,
+              cudnn_frontend::DataType_t const& compute_data_type,
+              std::string const& name);
+
+    std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+    concatenate(std::vector<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>> inputs,
+                int64_t axis,
+                std::optional<int64_t> in_place_index,
+                std::string const& name);
+
+    std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+    tensor_scalar(float const& value, cudnn_frontend::graph::ScalarType scalar_type);
+
+    std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+    tensor_scalar(double const& value, cudnn_frontend::graph::ScalarType scalar_type);
+
+    std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+    tensor_scalar(int32_t const& value, cudnn_frontend::graph::ScalarType scalar_type);
+
+    std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+    tensor_scalar(int64_t const& value, cudnn_frontend::graph::ScalarType scalar_type);
 
     std::vector<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>>
     rmsnorm(cudnn_frontend::NormFwdPhase_t const forward_phase,
@@ -495,7 +522,8 @@ class PyGraph {
             cudnn_frontend::DataType_t const& compute_data_type,
             std::string const& name,
             py::object const& generate_stats,
-            std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> sink_token);
+            std::shared_ptr<cudnn_frontend::graph::Tensor_attributes> sink_token,
+            bool const unfuse_fma);
 
     // return [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP]
     // dSink_token is an optional output set via set_dsink_token() attribute
@@ -781,4 +809,4 @@ class PyGraph {
              bool const unfuse_fma = false);
 };
 
-}  // namespace cudnn_frontend::python_bindings
\ No newline at end of file
+}  // namespace cudnn_frontend::python_bindings
diff --git a/python/pygraph/sdpa.cpp b/python/pygraph/sdpa.cpp
index ae765e17..13ddaa3d 100644
--- a/python/pygraph/sdpa.cpp
+++ b/python/pygraph/sdpa.cpp
@@ -639,7 +639,8 @@
PyGraph::sdpa_mxfp8(std::shared_ptr& q cudnn_frontend::DataType_t const& compute_data_type, std::string const& name, py::object const& generate_stats, - std::shared_ptr sink_token) { + std::shared_ptr sink_token, + bool const unfuse_fma) { auto attributes = cudnn_frontend::graph::SDPA_fp8_attributes().set_name(name).set_compute_data_type(compute_data_type); @@ -718,6 +719,7 @@ PyGraph::sdpa_mxfp8(std::shared_ptr& q if (sink_token) { attributes.set_sink_token(sink_token); } + attributes.set_unfuse_fma(unfuse_fma); // Call the MXFP8 6-parameter overload of sdpa_fp8 // This uses block scale factors (E8M0 + F8_128x4) instead of regular scalar descales @@ -1310,6 +1312,7 @@ init_pygraph_sdpa_submodule(py::class_& m) { py::arg_v("name", ""), py::arg("generate_stats"), py::arg_v("sink_token", nullptr), + py::arg_v("unfuse_fma", false), R"pbdoc( Perform MXFP8 (Microscaling FP8) scaled dot product attention. @@ -1348,6 +1351,7 @@ init_pygraph_sdpa_submodule(py::class_& m) { name (Optional[str]): The name of the operation. generate_stats (bool): If true, compute and output softmax stats (required for training). sink_token (Optional[cudnn_tensor]): Sink token bias for streaming attention. Shape is (1, h_q, 1, 1), type is float32. Default is None. + unfuse_fma (Optional[bool]): For SM100: use unfused __fmul_rn + __fadd_rn instead of ffma2 in softmax. Default is False. Returns: o (cudnn_tensor): The output data. diff --git a/samples/cpp/CMakeLists.txt b/samples/cpp/CMakeLists.txt index 3c569b08..c64c013e 100644 --- a/samples/cpp/CMakeLists.txt +++ b/samples/cpp/CMakeLists.txt @@ -63,6 +63,16 @@ add_executable( misc/sm_carveout.cpp misc/cudagraphs.cpp misc/deviceless_aot_compilation.cpp + misc/compile_time_constant_example.cpp + + membound/transpose.cpp + membound/reshape.cpp + membound/slice.cpp + membound/concat.cpp + membound/membound_fusion.cpp + membound/boolean_fusion.cpp + + causal_conv1d/causal_conv1d.cpp ) # target flags @@ -100,6 +110,10 @@ target_link_libraries( CUDA::nvrtc ) +if(TARGET CUDNN::cudnn_ext) + target_link_libraries(samples PRIVATE CUDNN::cudnn_ext) +endif() + # target cmake properties set_target_properties( samples PROPERTIES diff --git a/samples/cpp/causal_conv1d/causal_conv1d.cpp b/samples/cpp/causal_conv1d/causal_conv1d.cpp new file mode 100644 index 00000000..d836032e --- /dev/null +++ b/samples/cpp/causal_conv1d/causal_conv1d.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+#include <catch2/catch_all.hpp>
+#include "../utils/helpers.h"
+
+#include <cudnn_frontend.h>
+
+#if defined(CUDNN_SUBQUADRATIC_OPS_H_) || __has_include(<cudnn_subquadratic_ops.h>)
+#if !defined(CUDNN_SUBQUADRATIC_OPS_H_)
+#include <cudnn_subquadratic_ops.h>
+#endif
+#define HAS_SUBQUADRATIC_OPS 1
+#else
+#define HAS_SUBQUADRATIC_OPS 0
+#endif
+
+TEST_CASE("Causal conv1d forward", "[causal_conv1d][forward]") {
+#if !HAS_SUBQUADRATIC_OPS
+    SKIP("cudnn_subquadratic_ops.h not available");
+#else
+    if (is_arch_supported_by_cudnn() == false) {
+        SKIP("Architecture is not supported by current cudnn version");
+    }
+
+    int batch       = 2;
+    int dim         = 64;
+    int seq_len     = 512;
+    int kernel_size = 4;
+
+    cudaStream_t stream = nullptr;
+
+    Surface<half> x_tensor(batch * dim * seq_len, false);
+    Surface<half> w_tensor(dim * kernel_size, false);
+    Surface<half> bias_tensor(dim, false);
+    Surface<half> y_tensor(batch * dim * seq_len, false);
+
+    CUDNN_CHECK(cudnnCausalConv1dForward(stream,
+                                         x_tensor.devPtr,
+                                         w_tensor.devPtr,
+                                         bias_tensor.devPtr,
+                                         y_tensor.devPtr,
+                                         batch,
+                                         dim,
+                                         seq_len,
+                                         kernel_size,
+                                         CUDNN_DATA_HALF,
+                                         CUDNN_CAUSAL_CONV1D_ACTIVATION_SILU));
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+#endif
+}
+
+TEST_CASE("Causal conv1d backward", "[causal_conv1d][backward]") {
+#if !HAS_SUBQUADRATIC_OPS
+    SKIP("cudnn_subquadratic_ops.h not available");
+#else
+    if (is_arch_supported_by_cudnn() == false) {
+        SKIP("Architecture is not supported by current cudnn version");
+    }
+
+    int batch       = 2;
+    int dim         = 64;
+    int seq_len     = 512;
+    int kernel_size = 4;
+
+    cudaStream_t stream = nullptr;
+
+    Surface<half> x_tensor(batch * dim * seq_len, false);
+    Surface<half> w_tensor(dim * kernel_size, false);
+    Surface<half> bias_tensor(dim, false);
+    Surface<half> dy_tensor(batch * dim * seq_len, false);
+    Surface<half> dx_tensor(batch * dim * seq_len, false);
+    Surface<float> dw_tensor(dim * kernel_size, 0.0f);
+    Surface<float> dbias_tensor(dim, 0.0f);
+
+    CUDNN_CHECK(cudnnCausalConv1dBackward(stream,
+                                          x_tensor.devPtr,
+                                          w_tensor.devPtr,
+                                          bias_tensor.devPtr,
+                                          dy_tensor.devPtr,
+                                          dx_tensor.devPtr,
+                                          dw_tensor.devPtr,
+                                          dbias_tensor.devPtr,
+                                          batch,
+                                          dim,
+                                          seq_len,
+                                          kernel_size,
+                                          CUDNN_DATA_HALF,
+                                          CUDNN_DATA_FLOAT,
+                                          CUDNN_CAUSAL_CONV1D_ACTIVATION_SILU));
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+#endif
+}
diff --git a/samples/cpp/matmul/blackwell_nvfp4_mxfp8_block_scale_matmul.cpp b/samples/cpp/matmul/blackwell_nvfp4_mxfp8_block_scale_matmul.cpp
index 5d4ab4d1..7908839a 100644
--- a/samples/cpp/matmul/blackwell_nvfp4_mxfp8_block_scale_matmul.cpp
+++ b/samples/cpp/matmul/blackwell_nvfp4_mxfp8_block_scale_matmul.cpp
@@ -71,6 +71,8 @@ struct TestParams {
     }
 };
 
+// Note: For nvfp4/mxfp8 block scale matmul, scale factors need to be organized in 128x4 layout for best performance.
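+// (Editor's note, an interpretation: "128x4" refers to the scale-factor tensor being stored in tiles of
+// 128 rows by 4 columns rather than plain row-major order, matching the F8_128x4 reordering type used below.)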
+// Details of the layout can be found https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout TEST_CASE("Blackwell Block Scale Matmul", "[matmul][graph][FP4]") { #if (CUDNN_VERSION < 90700) SKIP("Matmul with block scaling is not supported in cudnn versions prior to 9.7.0"); @@ -418,8 +420,8 @@ TEST_CASE("Blackwell Block Scale Matmul Quantize", "[matmul][graph][FP4]") { } auto test_params = GENERATE(TestParams(1, - 256, - 256, + 137, + 272, 256, 16, cudnn_frontend::DataType_t::FP4_E2M1, @@ -521,6 +523,8 @@ TEST_CASE("Blackwell Block Scale Matmul Quantize", "[matmul][graph][FP4]") { tensor_d->set_output(true).set_data_type(datatype_d); block_scale->set_output(true) + .set_dim({b, block_scale_dim_out_m, block_scale_dim_out_n}) + .set_stride({block_scale_dim_out_m * block_scale_dim_out_n, block_scale_dim_out_n, 1}) .set_data_type(datatype_block_scale) .set_reordering_type(cudnn_frontend::TensorReordering_t::F8_128x4); diff --git a/samples/cpp/membound/boolean_fusion.cpp b/samples/cpp/membound/boolean_fusion.cpp new file mode 100644 index 00000000..936addf9 --- /dev/null +++ b/samples/cpp/membound/boolean_fusion.cpp @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ */ + +#include +#include "../utils/helpers.h" + +#include + +TEST_CASE("Boolean CMP_GT and LOGICAL_AND fusion", "[membound][boolean][pointwise][graph]") { + namespace fe = cudnn_frontend; + +#if (CUDNN_VERSION < 92200) + SKIP("Boolean fusion sample requires cuDNN 9.22.0 or newer."); +#endif + if (cudnn_frontend::detail::get_backend_version() < 92200) { + SKIP("Boolean fusion sample requires cuDNN backend 9.22.0 or newer at runtime."); + } + if (!is_blackwell_arch()) { + SKIP("Boolean fusion requires Blackwell (SM100+) architecture."); + } + + constexpr int64_t d0 = 4, d1 = 8, d2 = 16; + constexpr int64_t numel = d0 * d1 * d2; + + // Row-major strides for a 3D tensor + constexpr int64_t s0 = d1 * d2; + constexpr int64_t s1 = d2; + constexpr int64_t s2 = 1; + + fe::graph::Graph graph{}; + graph.set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_dim({d0, d1, d2}) + .set_stride({s0, s1, s2}) + .set_data_type(fe::DataType_t::HALF)); + + auto threshold = graph.tensor(fe::graph::Tensor_attributes() + .set_name("threshold") + .set_dim({d0, d1, d2}) + .set_stride({s0, s1, s2}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto B = graph.tensor(fe::graph::Tensor_attributes() + .set_name("B") + .set_dim({d0, d1, d2}) + .set_stride({s0, s1, s2}) + .set_data_type(fe::DataType_t::BOOLEAN)); + + auto after_cmp = graph.pointwise(X, + threshold, + fe::graph::Pointwise_attributes() + .set_name("cmp_gt") + .set_mode(fe::PointwiseMode_t::CMP_GT) + .set_compute_data_type(fe::DataType_t::FLOAT)); + after_cmp->set_data_type(fe::DataType_t::BOOLEAN); + + auto Y = graph.pointwise(after_cmp, + B, + fe::graph::Pointwise_attributes() + .set_name("logical_and") + .set_mode(fe::PointwiseMode_t::LOGICAL_AND) + .set_compute_data_type(fe::DataType_t::BOOLEAN)); + Y->set_output(true).set_data_type(fe::DataType_t::BOOLEAN); + + REQUIRE(graph.validate().is_good()); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.build_plans(fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(numel); + Surface threshold_gpu(numel); + Surface B_gpu(numel); + Surface Y_gpu(numel, 0); + + std::unordered_map, void*> variant_pack = { + {X, X_gpu.devPtr}, {threshold, threshold_gpu.devPtr}, {B, B_gpu.devPtr}, {Y, Y_gpu.devPtr}}; + + int64_t workspace_size = 0; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size); + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + // Verify against CPU reference + std::vector x_host(numel); + std::vector thresh_host(numel); + std::vector b_host(numel); + std::vector y_host(numel); + + CUDA_CHECK(cudaMemcpy(x_host.data(), X_gpu.devPtr, numel * sizeof(half), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(thresh_host.data(), threshold_gpu.devPtr, numel * sizeof(float), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(b_host.data(), B_gpu.devPtr, numel * sizeof(uint8_t), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(y_host.data(), Y_gpu.devPtr, numel * sizeof(uint8_t), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaDeviceSynchronize()); + + int mismatches = 0; + for (int64_t i = 0; i < numel; i++) { + uint8_t expected = (__half2float(x_host[i]) > thresh_host[i]) && b_host[i] ? 
1 : 0; + if (y_host[i] != expected) { + mismatches++; + } + } + REQUIRE(mismatches == 0); +} diff --git a/samples/cpp/membound/concat.cpp b/samples/cpp/membound/concat.cpp new file mode 100644 index 00000000..b6b84769 --- /dev/null +++ b/samples/cpp/membound/concat.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../utils/helpers.h" + +#include + +#include + +// Out-of-place concat only: do not set in_place_index on Concatenate_attributes (some fusion +// graphs used an in-place index for conv+concat; this sample concatenates into a new Y tensor). +TEST_CASE("Membound concatenate on channel axis (no in-place index)", "[membound][concat][graph]") { + namespace fe = cudnn_frontend; + + if (!check_device_arch_newer_than("blackwell")) { + SKIP("TEST requires device blackwell or newer"); + } + +#if (CUDNN_VERSION < 92200) + SKIP("Membound graph samples require cuDNN 9.22.0 or newer (compiled CUDNN_VERSION >= 92200)."); +#endif + if (cudnn_frontend::detail::get_backend_version() < 92200) { + SKIP("Membound graph samples require cuDNN backend 9.22.0 or newer at runtime."); + } + + if (is_arch_supported_by_cudnn() == false) { + SKIP("Architecture is not supported by current cuDNN version"); + } + + int64_t const n = 2, c = 4, h = 8, w = 8; + + fe::graph::Graph graph{}; + graph.set_io_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT); + + auto X0 = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X0") + .set_dim({n, c, h, w}) + .set_stride({c * h * w, 1, c * w, c}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto X1 = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X1") + .set_dim({n, c, h, w}) + .set_stride({c * h * w, 1, c * w, c}) + .set_data_type(fe::DataType_t::FLOAT)); + + std::vector> inputs = {X0, X1}; + auto concat_opts = fe::graph::Concatenate_attributes().set_name("concat").set_axis(1); // no set_in_place_index + + auto Y = graph.concatenate(inputs, concat_opts); + Y->set_output(true).set_data_type(fe::DataType_t::FLOAT); + + REQUIRE(graph.validate().is_good()); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.check_support().is_good()); + REQUIRE(graph.build_plans(fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface 
X0_gpu(n * c * h * w); + Surface X1_gpu(n * c * h * w); + Surface Y_gpu(n * (2 * c) * h * w); + std::unordered_map, void*> variant_pack = { + {X0, X0_gpu.devPtr}, {X1, X1_gpu.devPtr}, {Y, Y_gpu.devPtr}}; + int64_t workspace_size = 0; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size); + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); +} diff --git a/samples/cpp/membound/membound_fusion.cpp b/samples/cpp/membound/membound_fusion.cpp new file mode 100644 index 00000000..12f98e9a --- /dev/null +++ b/samples/cpp/membound/membound_fusion.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ */ + +#include +#include "../utils/helpers.h" + +#include + +TEST_CASE("Fusion reshape then ReLU", "[membound][fusion][reshape][graph]") { + namespace fe = cudnn_frontend; + + if (!check_device_arch_newer_than("blackwell")) { + SKIP("TEST requires device blackwell or newer"); + } + +#if (CUDNN_VERSION < 92200) + SKIP("Membound graph samples require cuDNN 9.22.0 or newer (compiled CUDNN_VERSION >= 92200)."); +#endif + if (cudnn_frontend::detail::get_backend_version() < 92200) { + SKIP("Membound graph samples require cuDNN backend 9.22.0 or newer at runtime."); + } + + fe::graph::Graph graph{}; + graph.set_io_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph.tensor(fe::graph::Tensor_attributes().set_name("X").set_dim({2, 8}).set_stride({8, 1}).set_data_type( + fe::DataType_t::FLOAT)); + + auto R = + graph.reshape(X, fe::graph::Reshape_attributes().set_name("rs").set_reshape_mode(fe::ReshapeMode_t::LOGICAL)); + R->set_dim({4, 4}).set_stride({4, 1}); + + auto Y = + graph.pointwise(R, fe::graph::Pointwise_attributes().set_name("relu").set_mode(fe::PointwiseMode_t::RELU_FWD)); + Y->set_output(true).set_data_type(fe::DataType_t::FLOAT); + + REQUIRE(graph.validate().is_good()); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.build_plans(fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(2 * 8); + Surface Y_gpu(4 * 4); + std::unordered_map, void*> variant_pack = {{X, X_gpu.devPtr}, + {Y, Y_gpu.devPtr}}; + int64_t workspace_size = 0; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size); + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); +} + +TEST_CASE("Fusion transpose then add bias tensor", "[membound][fusion][transpose][graph]") { + namespace fe = cudnn_frontend; + + if (!check_device_arch_newer_than("blackwell")) { + SKIP("TEST requires device blackwell or newer"); + } + +#if (CUDNN_VERSION < 92200) + SKIP("Membound graph samples require cuDNN 9.22.0 or newer (compiled CUDNN_VERSION >= 92200)."); +#endif + if (cudnn_frontend::detail::get_backend_version() < 92200) { + SKIP("Membound graph samples require cuDNN backend 9.22.0 or newer at runtime."); + } + + fe::graph::Graph graph{}; + graph.set_io_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph.tensor( + fe::graph::Tensor_attributes().set_name("X").set_dim({2, 2, 4}).set_stride({8, 4, 1}).set_data_type( + fe::DataType_t::FLOAT)); + + auto T = graph.transpose( + X, + fe::graph::Transpose_attributes().set_name("perm").set_permutation({2, 0, 1}).set_compute_data_type( + fe::DataType_t::FLOAT)); + // T logical shape [4, 2, 2] matches permuted dims + + auto B = graph.tensor( + fe::graph::Tensor_attributes().set_name("B").set_dim({4, 2, 2}).set_stride({4, 2, 1}).set_data_type( + fe::DataType_t::FLOAT)); + + auto Y = graph.pointwise(T, + B, + fe::graph::Pointwise_attributes() + .set_name("add") + .set_mode(fe::PointwiseMode_t::ADD) + .set_compute_data_type(fe::DataType_t::FLOAT)); + Y->set_output(true).set_data_type(fe::DataType_t::FLOAT); + + REQUIRE(graph.validate().is_good()); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + 
REQUIRE(graph.build_plans(fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(2 * 2 * 4); + Surface B_gpu(4 * 2 * 2); + Surface Y_gpu(4 * 2 * 2); + std::unordered_map, void*> variant_pack = { + {X, X_gpu.devPtr}, {B, B_gpu.devPtr}, {Y, Y_gpu.devPtr}}; + int64_t workspace_size = 0; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size); + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); +} diff --git a/samples/cpp/membound/reshape.cpp b/samples/cpp/membound/reshape.cpp new file mode 100644 index 00000000..763a0801 --- /dev/null +++ b/samples/cpp/membound/reshape.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../utils/helpers.h" + +#include + +// Lexicographic (logical) reshape: same linear memory order as row-major traversal of the +// input shape is reinterpreted as row-major traversal of the output shape. Here (3,4,5) -> (6,10), +// 60 elements each. Mode ReshapeMode_t::LOGICAL selects the backend lexicographic reshape path +// (cuDNN 9.22+); attributes still describe the target layout on Y. 
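+// A worked example of the lexicographic mapping described above: element (i, j, k) of X has rank
+// q = i*20 + j*5 + k, and is read back at position (q / 10, q % 10) of Y; e.g. X(1, 2, 3) -> q = 33 -> Y(3, 3).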
+TEST_CASE("Membound reshape (3,4,5) to (6,10) lexicographic / LOGICAL mode", + "[membound][reshape][graph][lexicographic]") { + namespace fe = cudnn_frontend; + + if (!check_device_arch_newer_than("blackwell")) { + SKIP("TEST requires device blackwell or newer"); + } + +#if (CUDNN_VERSION < 92200) + SKIP("Membound graph samples require cuDNN 9.22.0 or newer (compiled CUDNN_VERSION >= 92200)."); +#endif + if (cudnn_frontend::detail::get_backend_version() < 92200) { + SKIP("Membound graph samples require cuDNN backend 9.22.0 or newer at runtime."); + } + + fe::graph::Graph graph{}; + graph.set_io_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_dim({3, 4, 5}) + .set_stride({20, 1, 4}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto reshape_attrs = + fe::graph::Reshape_attributes().set_name("lex_reshape").set_reshape_mode(fe::ReshapeMode_t::LOGICAL); + + auto Y = graph.reshape(X, reshape_attrs); + Y->set_dim({6, 10}).set_stride({10, 1}).set_output(true).set_data_type(fe::DataType_t::FLOAT); + + REQUIRE(graph.validate().is_good()); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.build_plans(fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(3 * 4 * 5); + Surface Y_gpu(6 * 10); + std::unordered_map, void*> variant_pack = {{X, X_gpu.devPtr}, + {Y, Y_gpu.devPtr}}; + int64_t workspace_size = 0; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size); + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); +} diff --git a/samples/cpp/membound/slice.cpp b/samples/cpp/membound/slice.cpp new file mode 100644 index 00000000..93c66caa --- /dev/null +++ b/samples/cpp/membound/slice.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ */ + +#include +#include "../utils/helpers.h" + +#include + +TEST_CASE("Membound slice window with step", "[membound][slice][graph]") { + namespace fe = cudnn_frontend; + + if (!check_device_arch_newer_than("blackwell")) { + SKIP("TEST requires device blackwell or newer"); + } + +#if (CUDNN_VERSION < 92200) + SKIP("Membound graph samples require cuDNN 9.22.0 or newer (compiled CUDNN_VERSION >= 92200)."); +#endif + if (cudnn_frontend::detail::get_backend_version() < 92200) { + SKIP("Membound graph samples require cuDNN backend 9.22.0 or newer at runtime."); + } + + fe::graph::Graph graph{}; + graph.set_io_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_dim({4, 16, 8}) + .set_stride({16 * 8, 8, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto slice_attrs = + fe::graph::Slice_attributes().set_name("window").set_slices({{1, 3}, {4, 12}, {0, 8}}).set_strides({1, 2, 1}); + + auto Y = graph.slice(X, slice_attrs); + Y->set_output(true).set_data_type(fe::DataType_t::FLOAT); + + REQUIRE(graph.validate().is_good()); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.build_plans(fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(4 * 16 * 8); + // Output: B 2, M (12-4)/2 = 4, K 8 -> 64 elements + Surface Y_gpu(2 * 4 * 8); + std::unordered_map, void*> variant_pack = {{X, X_gpu.devPtr}, + {Y, Y_gpu.devPtr}}; + int64_t workspace_size = 0; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size); + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); +} diff --git a/samples/cpp/membound/transpose.cpp b/samples/cpp/membound/transpose.cpp new file mode 100644 index 00000000..2f61f85c --- /dev/null +++ b/samples/cpp/membound/transpose.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ */ + +#include +#include "../utils/helpers.h" + +#include + +TEST_CASE("Membound transpose permutes dims", "[membound][transpose][graph]") { + namespace fe = cudnn_frontend; + + if (!check_device_arch_newer_than("blackwell")) { + SKIP("TEST requires device blackwell or newer"); + } + +#if (CUDNN_VERSION < 92200) + SKIP("Membound graph samples require cuDNN 9.22.0 or newer (compiled CUDNN_VERSION >= 92200)."); +#endif + if (cudnn_frontend::detail::get_backend_version() < 92200) { + SKIP("Membound graph samples require cuDNN backend 9.22.0 or newer at runtime."); + } + + fe::graph::Graph graph{}; + graph.set_io_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_dim({2, 3, 4}) + .set_stride({12, 4, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto Y = graph.transpose( + X, + fe::graph::Transpose_attributes().set_name("permute").set_permutation({2, 0, 1}).set_compute_data_type( + fe::DataType_t::FLOAT)); + Y->set_output(true).set_data_type(fe::DataType_t::FLOAT); + + REQUIRE(graph.validate().is_good()); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.build_plans(fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(2 * 3 * 4); + Surface Y_gpu(2 * 3 * 4); + std::unordered_map, void*> variant_pack = {{X, X_gpu.devPtr}, + {Y, Y_gpu.devPtr}}; + int64_t workspace_size = 0; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size); + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); +} diff --git a/samples/cpp/misc/compile_time_constant_example.cpp b/samples/cpp/misc/compile_time_constant_example.cpp new file mode 100644 index 00000000..e92cdc7b --- /dev/null +++ b/samples/cpp/misc/compile_time_constant_example.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ */ + +#include +#include "../utils/helpers.h" + +#include + +TEST_CASE("Compile-time constant scalar multiply and add", "[compile_time_const][graph]") { + namespace fe = cudnn_frontend; + using fe::graph::ScalarType; + + if (!check_device_arch_newer_than("blackwell")) { + SKIP("TEST requires device blackwell or newer"); + } + +#if (CUDNN_VERSION < 92200) + SKIP("Compile-time constant graph sample requires cuDNN 9.22.0 or newer (compiled CUDNN_VERSION >= 92200)."); +#endif + if (cudnn_frontend::detail::get_backend_version() < 92200) { + SKIP("Compile-time constant graph sample requires cuDNN backend 9.22.0 or newer at runtime."); + } + + if (is_arch_supported_by_cudnn() == false) { + SKIP("Architecture is not supported by current cuDNN version"); + } + + constexpr int64_t N = 1024; + + fe::graph::Graph graph{}; + + auto X = graph.tensor(fe::graph::Tensor_attributes().set_name("X").set_dim({1, N}).set_stride({N, 1}).set_data_type( + fe::DataType_t::FLOAT)); + + auto scale = graph.tensor(2.5f, ScalarType::COMPILE_TIME_CONST); + scale->set_name("scale").set_dim({1, 1}).set_stride({1, 1}); + + auto scaled = graph.pointwise(X, + scale, + fe::graph::Pointwise_attributes() + .set_name("mul") + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(fe::DataType_t::FLOAT)); + scaled->set_name("scaled").set_data_type(fe::DataType_t::FLOAT); + + auto bias = graph.tensor(1.0f, ScalarType::COMPILE_TIME_CONST); + bias->set_name("bias").set_dim({1, 1}).set_stride({1, 1}); + + auto Y = graph.pointwise(scaled, + bias, + fe::graph::Pointwise_attributes() + .set_name("add") + .set_mode(fe::PointwiseMode_t::ADD) + .set_compute_data_type(fe::DataType_t::FLOAT)); + Y->set_name("Y").set_output(true).set_data_type(fe::DataType_t::FLOAT); + + REQUIRE(graph.validate().is_good()); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.build_plans(fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(N, false); + Surface Y_gpu(N, false); + + std::unordered_map, void *> variant_pack = {{X, X_gpu.devPtr}, + {Y, Y_gpu.devPtr}}; + + int64_t workspace_size = 0; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); +} diff --git a/samples/cpp/misc/serialization.cpp b/samples/cpp/misc/serialization.cpp index 0f986036..2ddf46bb 100644 --- a/samples/cpp/misc/serialization.cpp +++ b/samples/cpp/misc/serialization.cpp @@ -193,8 +193,6 @@ TEST_CASE("SDPA Graph with serialization", "[sdpa][graph][serialization]") { int64_t s_kv = 1024; // k and v tensor is padded to this seq length int64_t d = 128; // hidden dim - SKIP("BAN due to seg fault"); - // Mode of sdpa operation bool generate_stats = false; diff --git a/samples/python/60_causal_conv1d_forward.ipynb b/samples/python/60_causal_conv1d_forward.ipynb new file mode 100644 index 00000000..75573221 --- /dev/null +++ b/samples/python/60_causal_conv1d_forward.ipynb @@ -0,0 +1,499 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Causal Conv1D Forward Operation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook shows how to compute a causal depthwise 1D convolution forward operation using cuDNN.\n", + "\n", + "$$y = \\text{Activation}(\\text{CausalConv1D}(x, w) + b)$$\n", + "\n", + "where 
Activation is Identity or SiLU.\n", + "\n", + "The convolution uses left-only (causal) padding of size `kernel_size - 1`, and each channel is convolved independently (depthwise)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites and Setup\n", + "This notebook requires an NVIDIA GPU (Hopper or later recommended).\n", + "\n", + "**Environment setup** — make sure the cuDNN runtime library and the `cudnn` Python package are discoverable before launching the notebook:\n", + "\n", + "- **Option A – pip install:**\n", + " ```bash\n", + " pip install nvidia-cudnn-frontend\n", + " ```\n", + "- **Option B – set paths manually:**\n", + " ```bash\n", + " export LD_LIBRARY_PATH=/path/to/cudnn/lib:${LD_LIBRARY_PATH}\n", + " export PYTHONPATH=/path/to/cudnn_frontend/build_python:${PYTHONPATH}\n", + " ```\n", + "\n", + "Adjust the paths above to match your local build or installation directory." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:16.218874Z", + "iopub.status.busy": "2026-04-03T17:50:16.218579Z", + "iopub.status.idle": "2026-04-03T17:50:16.222536Z", + "shell.execute_reply": "2026-04-03T17:50:16.221675Z" + } + }, + "outputs": [], + "source": [ + "# !nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:16.225488Z", + "iopub.status.busy": "2026-04-03T17:50:16.225217Z", + "iopub.status.idle": "2026-04-03T17:50:16.228276Z", + "shell.execute_reply": "2026-04-03T17:50:16.227445Z" + } + }, + "outputs": [], + "source": [ + "# get_ipython().system('pip install nvidia-cudnn-cu12')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", + "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "We will perform a causal depthwise conv1d forward pass with:\n", + "\n", + "- batch size: 2\n", + "- dim (channels): 768\n", + "- sequence length: 4096\n", + "- kernel size: 4\n", + "- activation: SiLU\n", + "- data type: BFloat16\n", + "\n", + "We compare the cuDNN result against a PyTorch reference implementation." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:16.231190Z", + "iopub.status.busy": "2026-04-03T17:50:16.230925Z", + "iopub.status.idle": "2026-04-03T17:50:20.115117Z", + "shell.execute_reply": "2026-04-03T17:50:20.114082Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuDNN backend version: 92200\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "import cudnn\n", + "\n", + "print(\"cuDNN backend version:\", cudnn.backend_version())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reference Implementation\n", + "\n", + "A simple PyTorch reference for causal depthwise conv1d + SiLU."
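Before the full reference implementation in the next cell, a minimal standalone sketch (toy sizes chosen for illustration, not taken from the notebook) makes the causality property concrete: left-only padding of `kernel_size - 1` zeros keeps the output length equal to `seq_len` and makes each output position depend only on current and past inputs.

```python
import torch
import torch.nn.functional as F

# Toy example: seq_len = 6, kernel_size K = 3, one batch and one channel.
x = torch.arange(6, dtype=torch.float32).reshape(1, 1, 6)  # x = [0, 1, 2, 3, 4, 5]
w = torch.ones(1, 1, 3)                                    # all-ones kernel

# Pad K - 1 = 2 zeros on the left only; y[t] then sees x[t-2], x[t-1], x[t].
y = F.conv1d(F.pad(x, (2, 0)), w)
print(y)  # tensor([[[ 0.,  1.,  3.,  6.,  9., 12.]]]), running sums of the last 3 inputs
```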
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:20.138732Z", + "iopub.status.busy": "2026-04-03T17:50:20.138458Z", + "iopub.status.idle": "2026-04-03T17:50:20.142364Z", + "shell.execute_reply": "2026-04-03T17:50:20.141618Z" + } + }, + "outputs": [], + "source": [ + "def causal_conv1d_ref(x, weight, bias, activation=\"silu\"):\n", + " \"\"\"Reference: causal depthwise conv1d using PyTorch ops.\n", + "\n", + " Args:\n", + " x: (batch, dim, seq_len)\n", + " weight: (dim, kernel_size)\n", + " bias: (dim,)\n", + " activation: 'silu' or 'identity'\n", + "\n", + " Returns:\n", + " y: (batch, dim, seq_len)\n", + " \"\"\"\n", + " batch, dim, seq_len = x.shape\n", + " kernel_size = weight.shape[1]\n", + "\n", + " # Causal padding: (kernel_size - 1) on the left, 0 on the right\n", + " x_padded = F.pad(x, (kernel_size - 1, 0))\n", + "\n", + " # Depthwise conv1d: groups=dim\n", + " # weight needs shape (dim, 1, kernel_size) for groups=dim\n", + " w = weight.unsqueeze(1) # (dim, 1, kernel_size)\n", + " y = F.conv1d(x_padded, w, bias=bias, groups=dim)\n", + "\n", + " if activation == \"silu\":\n", + " y = F.silu(y)\n", + "\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Tensors" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:20.144248Z", + "iopub.status.busy": "2026-04-03T17:50:20.144099Z", + "iopub.status.idle": "2026-04-03T17:50:20.331438Z", + "shell.execute_reply": "2026-04-03T17:50:20.330473Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA available: True\n", + "x: torch.Size([2, 768, 4096]), dtype=torch.bfloat16\n", + "weight: torch.Size([768, 4]), dtype=torch.bfloat16\n", + "bias: torch.Size([768]), dtype=torch.bfloat16\n" + ] + } + ], + "source": [ + "batch = 2\n", + "dim = 768\n", + "seq_len = 4096\n", + "kernel_size = 4\n", + "dtype = torch.bfloat16\n", + "\n", + "has_cuda = torch.cuda.is_available()\n", + "\n", + "torch.manual_seed(42)\n", + "\n", + "x = torch.randn(batch, dim, seq_len, dtype=dtype)\n", + "weight = torch.randn(dim, kernel_size, dtype=dtype)\n", + "bias = torch.randn(dim, dtype=dtype)\n", + "\n", + "print(f\"CUDA available: {has_cuda}\")\n", + "print(f\"x: {x.shape}, dtype={x.dtype}\")\n", + "print(f\"weight: {weight.shape}, dtype={weight.dtype}\")\n", + "print(f\"bias: {bias.shape}, dtype={bias.dtype}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute Reference Output" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:20.333511Z", + "iopub.status.busy": "2026-04-03T17:50:20.333360Z", + "iopub.status.idle": "2026-04-03T17:50:20.363243Z", + "shell.execute_reply": "2026-04-03T17:50:20.362726Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_ref: torch.Size([2, 768, 4096]), dtype=torch.bfloat16\n" + ] + } + ], + "source": [ + "y_ref = causal_conv1d_ref(x, weight, bias, activation=\"silu\")\n", + "print(f\"y_ref: {y_ref.shape}, dtype={y_ref.dtype}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute cuDNN Output" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:20.364697Z", + "iopub.status.busy": 
"2026-04-03T17:50:20.364555Z", + "iopub.status.idle": "2026-04-03T17:50:23.912909Z", + "shell.execute_reply": "2026-04-03T17:50:23.911867Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y_cudnn: torch.Size([2, 768, 4096]), dtype=torch.bfloat16\n" + ] + } + ], + "source": [ + "if has_cuda:\n", + " x_gpu = x.cuda()\n", + " weight_gpu = weight.cuda()\n", + " bias_gpu = bias.cuda()\n", + " y_cudnn = cudnn.ops.causal_conv1d(x_gpu, weight_gpu, bias_gpu, activation=\"silu\")\n", + " print(f\"y_cudnn: {y_cudnn.shape}, dtype={y_cudnn.dtype}\")\n", + "else:\n", + " print(\"Skipping cuDNN forward (no CUDA device). Run on a GPU machine to test cuDNN.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Verify Correctness" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:23.915428Z", + "iopub.status.busy": "2026-04-03T17:50:23.915236Z", + "iopub.status.idle": "2026-04-03T17:50:23.964477Z", + "shell.execute_reply": "2026-04-03T17:50:23.963766Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Max absolute difference: 6.250000e-02\n", + "Max relative difference: 3.278166e-02\n", + "PASSED: cuDNN causal_conv1d forward matches reference.\n" + ] + } + ], + "source": [ + "if has_cuda:\n", + " y_cudnn_cpu = y_cudnn.cpu()\n", + " max_abs_diff = (y_cudnn_cpu.float() - y_ref.float()).abs().max().item()\n", + " max_rel_diff = ((y_cudnn_cpu.float() - y_ref.float()).abs() / (y_ref.float().abs() + 1e-6)).max().item()\n", + "\n", + " print(f\"Max absolute difference: {max_abs_diff:.6e}\")\n", + " print(f\"Max relative difference: {max_rel_diff:.6e}\")\n", + "\n", + " atol = 1e-2\n", + " rtol = 1e-2\n", + "\n", + " assert torch.allclose(y_cudnn_cpu.float(), y_ref.float(), atol=atol, rtol=rtol), \\\n", + " f\"FAILED: max_abs={max_abs_diff}, max_rel={max_rel_diff}\"\n", + "\n", + " print(\"PASSED: cuDNN causal_conv1d forward matches reference.\")\n", + "else:\n", + " print(\"Skipping verification (no CUDA device).\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test with Different Data Types and Kernel Sizes" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:23.966213Z", + "iopub.status.busy": "2026-04-03T17:50:23.966006Z", + "iopub.status.idle": "2026-04-03T17:50:24.090002Z", + "shell.execute_reply": "2026-04-03T17:50:24.089403Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PASS] dtype=torch.float32 K= 4 act=silu max_abs_diff=1.9073e-06\n", + "[PASS] dtype=torch.float16 K= 4 act=silu max_abs_diff=7.8125e-03\n", + "[PASS] dtype=torch.bfloat16 K= 4 act=silu max_abs_diff=6.2500e-02\n", + "[PASS] dtype=torch.bfloat16 K= 3 act=silu max_abs_diff=6.2500e-02\n", + "[PASS] dtype=torch.float16 K= 7 act=silu max_abs_diff=7.8125e-03\n", + "[PASS] dtype=torch.bfloat16 K= 4 act=identity max_abs_diff=0.0000e+00\n", + "\n", + "All tests passed!\n" + ] + } + ], + "source": [ + "tolerances = {\n", + " torch.float32: (5e-6, 5e-6),\n", + " torch.float16: (5e-3, 5e-3),\n", + " torch.bfloat16: (1e-2, 1e-2),\n", + "}\n", + "\n", + "test_configs = [\n", + " {\"dtype\": torch.float32, \"kernel_size\": 4, \"activation\": \"silu\"},\n", + " {\"dtype\": torch.float16, \"kernel_size\": 4, \"activation\": \"silu\"},\n", + " {\"dtype\": torch.bfloat16, \"kernel_size\": 4, \"activation\": 
\"silu\"},\n", + " {\"dtype\": torch.bfloat16, \"kernel_size\": 3, \"activation\": \"silu\"},\n", + " {\"dtype\": torch.float16, \"kernel_size\": 7, \"activation\": \"silu\"},\n", + " {\"dtype\": torch.bfloat16, \"kernel_size\": 4, \"activation\": \"identity\"},\n", + "]\n", + "\n", + "batch, dim, seq_len = 2, 256, 1024\n", + "\n", + "for cfg in test_configs:\n", + " dt = cfg[\"dtype\"]\n", + " ks = cfg[\"kernel_size\"]\n", + " act = cfg[\"activation\"]\n", + " atol, rtol = tolerances[dt]\n", + "\n", + " x = torch.randn(batch, dim, seq_len, dtype=dt)\n", + " w = torch.randn(dim, ks, dtype=dt)\n", + " b = torch.randn(dim, dtype=dt)\n", + "\n", + " y_ref = causal_conv1d_ref(x, w, b, activation=act)\n", + "\n", + " if has_cuda:\n", + " y_cudnn = cudnn.ops.causal_conv1d(x.cuda(), w.cuda(), b.cuda(), activation=act)\n", + " max_abs = (y_cudnn.cpu().float() - y_ref.float()).abs().max().item()\n", + " passed = torch.allclose(y_cudnn.cpu().float(), y_ref.float(), atol=atol, rtol=rtol)\n", + " status = \"PASS\" if passed else \"FAIL\"\n", + " print(f\"[{status}] dtype={str(dt):15s} K={ks:3d} act={act:10s} max_abs_diff={max_abs:.4e}\")\n", + " assert passed, f\"Test failed for {cfg}\"\n", + " else:\n", + " print(f\"[REF ] dtype={str(dt):15s} K={ks:3d} act={act:10s} y_ref shape={y_ref.shape}\")\n", + "\n", + "print(\"\\nAll tests passed!\" if has_cuda else \"\\nAll reference tests completed (CPU only).\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## torch.compile Support\n", + "\n", + "`cudnn.ops.causal_conv1d` is registered as a PyTorch custom operator, so it works seamlessly with `torch.compile`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:50:24.092083Z", + "iopub.status.busy": "2026-04-03T17:50:24.091922Z", + "iopub.status.idle": "2026-04-03T17:50:35.078779Z", + "shell.execute_reply": "2026-04-03T17:50:35.077808Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.compile matches eager: max_abs_diff=0.0000e+00\n", + "y_compiled: torch.Size([2, 256, 1024]), dtype=torch.bfloat16\n", + "PASSED: torch.compile produces identical output to eager mode.\n" + ] + } + ], + "source": [ + "if has_cuda:\n", + " B, D, L, K = 2, 256, 1024, 4\n", + " x_c = torch.randn(B, D, L, dtype=torch.bfloat16, device=\"cuda\")\n", + " w_c = torch.randn(D, K, dtype=torch.bfloat16, device=\"cuda\")\n", + " b_c = torch.randn(D, dtype=torch.bfloat16, device=\"cuda\")\n", + "\n", + " @torch.compile\n", + " def compiled_causal_conv1d(x, w, b):\n", + " return cudnn.ops.causal_conv1d(x, w, b, activation=\"silu\")\n", + "\n", + " y_eager = cudnn.ops.causal_conv1d(x_c, w_c, b_c, activation=\"silu\")\n", + " y_compiled = compiled_causal_conv1d(x_c, w_c, b_c)\n", + "\n", + " max_abs = (y_compiled.float() - y_eager.float()).abs().max().item()\n", + " print(f\"torch.compile matches eager: max_abs_diff={max_abs:.4e}\")\n", + " assert max_abs == 0.0, f\"torch.compile output differs from eager: {max_abs}\"\n", + " print(f\"y_compiled: {y_compiled.shape}, dtype={y_compiled.dtype}\")\n", + " print(\"PASSED: torch.compile produces identical output to eager mode.\")\n", + "else:\n", + " print(\"Skipping torch.compile test (no CUDA device).\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + 
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/samples/python/61_causal_conv1d_backward.ipynb b/samples/python/61_causal_conv1d_backward.ipynb new file mode 100644 index 00000000..a663f4f2 --- /dev/null +++ b/samples/python/61_causal_conv1d_backward.ipynb @@ -0,0 +1,563 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "86918253", + "metadata": {}, + "source": [ + "# Causal Conv1D Backward Operation" + ] + }, + { + "cell_type": "markdown", + "id": "ce999033", + "metadata": {}, + "source": [ + "This notebook shows how to compute gradients for the causal depthwise 1D convolution using cuDNN.\n", + "\n", + "Given the forward operation:\n", + "\n", + "$$y = \\text{Activation}(\\text{CausalConv1D}(x, w) + b)$$\n", + "\n", + "where Activation is Identity or SiLU, this notebook computes:\n", + "\n", + "- $dx$ — gradient w.r.t. input $x$\n", + "- $dw$ — gradient w.r.t. weight $w$\n", + "- $db$ — gradient w.r.t. bias $b$\n", + "\n", + "The `cudnn.ops.causal_conv1d` API integrates with `torch.autograd`, so backward is handled automatically via `.backward()`. We compare the cuDNN autograd result against a PyTorch reference." + ] + }, + { + "cell_type": "markdown", + "id": "b14229b1", + "metadata": {}, + "source": [ + "## Prerequisites and Setup\n", + "This notebook requires an NVIDIA GPU (Hopper or later recommended).\n", + "\n", + "**Environment setup** — make sure the cuDNN runtime library and the `cudnn` Python package are discoverable before launching the notebook:\n", + "\n", + "- **Option A – pip install:**\n", + " ```bash\n", + " pip install nvidia-cudnn-frontend\n", + " ```\n", + "- **Option B – set paths manually:**\n", + " ```bash\n", + " export LD_LIBRARY_PATH=/path/to/cudnn/lib:${LD_LIBRARY_PATH}\n", + " export PYTHONPATH=/path/to/cudnn_frontend/build_python:${PYTHONPATH}\n", + " ```\n", + "\n", + "Adjust the paths above to match your local build or installation directory." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3da2cd41", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:52:29.029407Z", + "iopub.status.busy": "2026-04-03T17:52:29.029100Z", + "iopub.status.idle": "2026-04-03T17:52:29.033790Z", + "shell.execute_reply": "2026-04-03T17:52:29.032514Z" + } + }, + "outputs": [], + "source": [ + "# !nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9a439b6e", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:52:29.036616Z", + "iopub.status.busy": "2026-04-03T17:52:29.036341Z", + "iopub.status.idle": "2026-04-03T17:52:31.788865Z", + "shell.execute_reply": "2026-04-03T17:52:31.787996Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuDNN backend version: 92200\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "import cudnn\n", + "\n", + "print(\"cuDNN backend version:\", cudnn.backend_version())" + ] + }, + { + "cell_type": "markdown", + "id": "13788665", + "metadata": {}, + "source": [ + "## Reference Implementation\n", + "\n", + "A PyTorch reference for causal depthwise conv1d forward, used with autograd to compute reference gradients." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "793404f1", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:52:31.791763Z", + "iopub.status.busy": "2026-04-03T17:52:31.791566Z", + "iopub.status.idle": "2026-04-03T17:52:31.795642Z", + "shell.execute_reply": "2026-04-03T17:52:31.794935Z" + } + }, + "outputs": [], + "source": [ + "def causal_conv1d_ref(x, weight, bias, activation=\"silu\"):\n", + " \"\"\"Reference: causal depthwise conv1d using PyTorch ops.\n", + "\n", + " Args:\n", + " x: (batch, dim, seq_len)\n", + " weight: (dim, kernel_size)\n", + " bias: (dim,)\n", + " activation: 'silu' or 'identity'\n", + "\n", + " Returns:\n", + " y: (batch, dim, seq_len)\n", + " \"\"\"\n", + " batch, dim, seq_len = x.shape\n", + " kernel_size = weight.shape[1]\n", + "\n", + " x_padded = F.pad(x, (kernel_size - 1, 0))\n", + "\n", + " w = weight.unsqueeze(1) # (dim, 1, kernel_size)\n", + " y = F.conv1d(x_padded, w, bias=bias, groups=dim)\n", + "\n", + " if activation == \"silu\":\n", + " y = F.silu(y)\n", + "\n", + " return y\n", + "\n", + "\n", + "def causal_conv1d_bwd_ref(x, weight, bias, dy, activation=\"silu\"):\n", + " \"\"\"Reference backward via PyTorch autograd.\n", + "\n", + " Returns:\n", + " dx, dweight, dbias (all float32 for comparison)\n", + " \"\"\"\n", + " x = x.float().detach().requires_grad_(True)\n", + " weight = weight.float().detach().requires_grad_(True)\n", + " bias = bias.float().detach().requires_grad_(True)\n", + " dy = dy.float().detach()\n", + "\n", + " y = causal_conv1d_ref(x, weight, bias, activation=activation)\n", + " y.backward(dy)\n", + "\n", + " return x.grad, weight.grad, bias.grad" + ] + }, + { + "cell_type": "markdown", + "id": "f6fe9c44", + "metadata": {}, + "source": [ + "## Setup Tensors" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "02c03769", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:52:31.797778Z", + "iopub.status.busy": "2026-04-03T17:52:31.797650Z", + "iopub.status.idle": "2026-04-03T17:52:32.093898Z", + "shell.execute_reply": "2026-04-03T17:52:32.092906Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA available: True\n", + "x: torch.Size([2, 768, 4096]), dtype=torch.bfloat16\n", + "weight: torch.Size([768, 4]), dtype=torch.bfloat16\n", + "bias: torch.Size([768]), dtype=torch.bfloat16\n", + "dy: torch.Size([2, 768, 4096]), dtype=torch.bfloat16\n" + ] + } + ], + "source": [ + "batch = 2\n", + "dim = 768\n", + "seq_len = 4096\n", + "kernel_size = 4\n", + "dtype = torch.bfloat16\n", + "\n", + "has_cuda = torch.cuda.is_available()\n", + "\n", + "torch.manual_seed(42)\n", + "\n", + "x = torch.randn(batch, dim, seq_len, dtype=dtype)\n", + "weight = torch.randn(dim, kernel_size, dtype=dtype)\n", + "bias = torch.randn(dim, dtype=dtype)\n", + "dy = torch.randn(batch, dim, seq_len, dtype=dtype)\n", + "\n", + "print(f\"CUDA available: {has_cuda}\")\n", + "print(f\"x: {x.shape}, dtype={x.dtype}\")\n", + "print(f\"weight: {weight.shape}, dtype={weight.dtype}\")\n", + "print(f\"bias: {bias.shape}, dtype={bias.dtype}\")\n", + "print(f\"dy: {dy.shape}, dtype={dy.dtype}\")" + ] + }, + { + "cell_type": "markdown", + "id": "98c68adc", + "metadata": {}, + "source": [ + "## Compute Reference Gradients (CPU, via autograd)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a7b7e1fe", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:52:32.096767Z", + "iopub.status.busy": 
"2026-04-03T17:52:32.096626Z", + "iopub.status.idle": "2026-04-03T17:52:33.422754Z", + "shell.execute_reply": "2026-04-03T17:52:33.422005Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dx_ref: torch.Size([2, 768, 4096]), dtype=torch.float32\n", + "dw_ref: torch.Size([768, 4]), dtype=torch.float32\n", + "db_ref: torch.Size([768]), dtype=torch.float32\n" + ] + } + ], + "source": [ + "dx_ref, dw_ref, db_ref = causal_conv1d_bwd_ref(x, weight, bias, dy, activation=\"silu\")\n", + "print(f\"dx_ref: {dx_ref.shape}, dtype={dx_ref.dtype}\")\n", + "print(f\"dw_ref: {dw_ref.shape}, dtype={dw_ref.dtype}\")\n", + "print(f\"db_ref: {db_ref.shape}, dtype={db_ref.dtype}\")" + ] + }, + { + "cell_type": "markdown", + "id": "ea77f584", + "metadata": {}, + "source": [ + "## Compute cuDNN Gradients (via autograd)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7f9e824f", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:52:33.424733Z", + "iopub.status.busy": "2026-04-03T17:52:33.424415Z", + "iopub.status.idle": "2026-04-03T17:52:34.972397Z", + "shell.execute_reply": "2026-04-03T17:52:34.971285Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dx_cudnn: torch.Size([2, 768, 4096]), dtype=torch.bfloat16\n", + "dw_cudnn: torch.Size([768, 4]), dtype=torch.bfloat16\n", + "db_cudnn: torch.Size([768]), dtype=torch.bfloat16\n" + ] + } + ], + "source": [ + "if has_cuda:\n", + " x_gpu = x.cuda().detach().requires_grad_(True)\n", + " w_gpu = weight.cuda().detach().requires_grad_(True)\n", + " b_gpu = bias.cuda().detach().requires_grad_(True)\n", + " dy_gpu = dy.cuda()\n", + "\n", + " y_cudnn = cudnn.ops.causal_conv1d(x_gpu, w_gpu, b_gpu, activation=\"silu\")\n", + " y_cudnn.backward(dy_gpu)\n", + "\n", + " dx_cudnn = x_gpu.grad\n", + " dw_cudnn = w_gpu.grad\n", + " db_cudnn = b_gpu.grad\n", + " print(f\"dx_cudnn: {dx_cudnn.shape}, dtype={dx_cudnn.dtype}\")\n", + " print(f\"dw_cudnn: {dw_cudnn.shape}, dtype={dw_cudnn.dtype}\")\n", + " print(f\"db_cudnn: {db_cudnn.shape}, dtype={db_cudnn.dtype}\")\n", + "else:\n", + " print(\"Skipping cuDNN backward (no CUDA device).\")" + ] + }, + { + "cell_type": "markdown", + "id": "f5290bb0", + "metadata": {}, + "source": [ + "## Verify Correctness" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "42562bfc", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:52:34.975588Z", + "iopub.status.busy": "2026-04-03T17:52:34.975287Z", + "iopub.status.idle": "2026-04-03T17:52:35.012053Z", + "shell.execute_reply": "2026-04-03T17:52:35.011095Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PASS] dx max_abs_diff=3.1241e-02\n", + "[PASS] dweight max_abs_diff=4.9881e-01\n", + "[PASS] dbias max_abs_diff=4.9422e-01\n", + "\n", + "PASSED: cuDNN causal_conv1d autograd backward matches reference.\n" + ] + } + ], + "source": [ + "if has_cuda:\n", + " atol = 1e-2\n", + " rtol = 1e-2\n", + "\n", + " for name, cudnn_grad, ref_grad in [\n", + " (\"dx\", dx_cudnn.cpu().float(), dx_ref),\n", + " (\"dweight\", dw_cudnn.cpu().float(), dw_ref),\n", + " (\"dbias\", db_cudnn.cpu().float(), db_ref),\n", + " ]:\n", + " max_abs = (cudnn_grad - ref_grad).abs().max().item()\n", + " passed = torch.allclose(cudnn_grad, ref_grad, atol=atol, rtol=rtol)\n", + " status = \"PASS\" if passed else \"FAIL\"\n", + " print(f\"[{status}] {name:8s} max_abs_diff={max_abs:.4e}\")\n", + " assert passed, f\"{name} 
verification failed: max_abs={max_abs}\"\n", + "\n", + " print(\"\\nPASSED: cuDNN causal_conv1d autograd backward matches reference.\")\n", + "else:\n", + " print(\"Skipping verification (no CUDA device).\")" + ] + }, + { + "cell_type": "markdown", + "id": "6959bb89", + "metadata": {}, + "source": [ + "## Test with Different Data Types and Kernel Sizes" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3588ba20", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:52:35.014220Z", + "iopub.status.busy": "2026-04-03T17:52:35.014061Z", + "iopub.status.idle": "2026-04-03T17:52:35.188192Z", + "shell.execute_reply": "2026-04-03T17:52:35.187375Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PASS] dtype=torch.float32 K= 4 act=silu max_abs: dx=4.2915e-06 dw=3.0518e-05 db=5.7220e-05\n", + "[PASS] dtype=torch.float16 K= 4 act=silu max_abs: dx=3.8843e-03 dw=3.1090e-02 db=2.6917e-02\n", + "[PASS] dtype=torch.bfloat16 K= 4 act=silu max_abs: dx=3.0970e-02 dw=2.4898e-01 db=2.3223e-01\n", + "[PASS] dtype=torch.bfloat16 K= 3 act=silu max_abs: dx=3.1178e-02 dw=2.8888e-01 db=2.3131e-01\n", + "[PASS] dtype=torch.float16 K= 7 act=silu max_abs: dx=5.2376e-03 dw=3.4286e-02 db=2.4826e-02\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PASS] dtype=torch.bfloat16 K= 4 act=identity max_abs: dx=3.1250e-02 dw=4.8404e-01 db=2.4219e-01\n", + "\n", + "All tests passed!\n" + ] + } + ], + "source": [ + "tolerances = {\n", + " torch.float32: (5e-6, 5e-6),\n", + " torch.float16: (5e-3, 5e-3),\n", + " torch.bfloat16: (1e-2, 1e-2),\n", + "}\n", + "# dweight/dbias are always FP32; error is from atomicAdd ordering\n", + "dw_db_atol, dw_db_rtol = 5e-3, 5e-3\n", + "\n", + "test_configs = [\n", + " {\"dtype\": torch.float32, \"kernel_size\": 4, \"activation\": \"silu\"},\n", + " {\"dtype\": torch.float16, \"kernel_size\": 4, \"activation\": \"silu\"},\n", + " {\"dtype\": torch.bfloat16, \"kernel_size\": 4, \"activation\": \"silu\"},\n", + " {\"dtype\": torch.bfloat16, \"kernel_size\": 3, \"activation\": \"silu\"},\n", + " {\"dtype\": torch.float16, \"kernel_size\": 7, \"activation\": \"silu\"},\n", + " {\"dtype\": torch.bfloat16, \"kernel_size\": 4, \"activation\": \"identity\"},\n", + "]\n", + "\n", + "batch, dim, seq_len = 2, 256, 1024\n", + "\n", + "for cfg in test_configs:\n", + " dt = cfg[\"dtype\"]\n", + " ks = cfg[\"kernel_size\"]\n", + " act = cfg[\"activation\"]\n", + " atol, rtol = tolerances[dt]\n", + "\n", + " x = torch.randn(batch, dim, seq_len, dtype=dt)\n", + " w = torch.randn(dim, ks, dtype=dt)\n", + " b = torch.randn(dim, dtype=dt)\n", + " dy = torch.randn(batch, dim, seq_len, dtype=dt)\n", + "\n", + " dx_ref, dw_ref, db_ref = causal_conv1d_bwd_ref(x, w, b, dy, activation=act)\n", + "\n", + " if has_cuda:\n", + " x_g = x.cuda().detach().requires_grad_(True)\n", + " w_g = w.cuda().detach().requires_grad_(True)\n", + " b_g = b.cuda().detach().requires_grad_(True)\n", + "\n", + " y_g = cudnn.ops.causal_conv1d(x_g, w_g, b_g, activation=act)\n", + " y_g.backward(dy.cuda())\n", + "\n", + " dx_ok = torch.allclose(x_g.grad.cpu().float(), dx_ref, atol=atol, rtol=rtol)\n", + " dw_ok = torch.allclose(w_g.grad.cpu().float(), dw_ref, atol=max(atol, dw_db_atol), rtol=max(rtol, dw_db_rtol))\n", + " db_ok = torch.allclose(b_g.grad.cpu().float(), db_ref, atol=max(atol, dw_db_atol), rtol=max(rtol, dw_db_rtol))\n", + " passed = dx_ok and dw_ok and db_ok\n", + "\n", + " dx_abs = (x_g.grad.cpu().float() - 
dx_ref).abs().max().item()\n", + " dw_abs = (w_g.grad.cpu().float() - dw_ref).abs().max().item()\n", + " db_abs = (b_g.grad.cpu().float() - db_ref).abs().max().item()\n", + "\n", + " status = \"PASS\" if passed else \"FAIL\"\n", + " print(f\"[{status}] dtype={str(dt):15s} K={ks:3d} act={act:10s} \"\n", + " f\"max_abs: dx={dx_abs:.4e} dw={dw_abs:.4e} db={db_abs:.4e}\")\n", + " assert passed, f\"Test failed for {cfg}\"\n", + " else:\n", + " print(f\"[REF ] dtype={str(dt):15s} K={ks:3d} act={act:10s} \"\n", + " f\"dx={dx_ref.shape} dw={dw_ref.shape} db={db_ref.shape}\")\n", + "\n", + "print(\"\\nAll tests passed!\" if has_cuda else \"\\nAll reference tests completed (CPU only).\")" + ] + }, + { + "cell_type": "markdown", + "id": "d28a112e", + "metadata": {}, + "source": [ + "## torch.compile Support\n", + "\n", + "`cudnn.ops.causal_conv1d` works with `torch.compile` for both forward and backward passes. The compiled graph captures the custom op nodes and the autograd backward." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "390c20ff", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-03T17:52:35.190609Z", + "iopub.status.busy": "2026-04-03T17:52:35.190460Z", + "iopub.status.idle": "2026-04-03T17:52:41.692848Z", + "shell.execute_reply": "2026-04-03T17:52:41.691995Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PASS] dx eager vs compiled max_abs_diff=0.0000e+00\n", + "[PASS] dweight eager vs compiled max_abs_diff=0.0000e+00\n", + "[PASS] dbias eager vs compiled max_abs_diff=0.0000e+00\n", + "\n", + "PASSED: torch.compile backward produces identical output to eager mode.\n" + ] + } + ], + "source": [ + "if has_cuda:\n", + " B, D, L, K = 2, 256, 1024, 4\n", + " act = \"silu\"\n", + "\n", + " def train_step(x, w, b, dy):\n", + " x = x.detach().requires_grad_(True)\n", + " w = w.detach().requires_grad_(True)\n", + " b = b.detach().requires_grad_(True)\n", + " y = cudnn.ops.causal_conv1d(x, w, b, activation=act)\n", + " y.backward(dy)\n", + " return x.grad, w.grad, b.grad\n", + "\n", + " compiled_train_step = torch.compile(train_step)\n", + "\n", + " x_c = torch.randn(B, D, L, dtype=torch.bfloat16, device=\"cuda\")\n", + " w_c = torch.randn(D, K, dtype=torch.bfloat16, device=\"cuda\")\n", + " b_c = torch.randn(D, dtype=torch.bfloat16, device=\"cuda\")\n", + " dy_c = torch.randn(B, D, L, dtype=torch.bfloat16, device=\"cuda\")\n", + "\n", + " dx_eager, dw_eager, db_eager = train_step(x_c, w_c, b_c, dy_c)\n", + " dx_compiled, dw_compiled, db_compiled = compiled_train_step(x_c, w_c, b_c, dy_c)\n", + "\n", + " for name, eager, compiled in [\n", + " (\"dx\", dx_eager, dx_compiled),\n", + " (\"dweight\", dw_eager, dw_compiled),\n", + " (\"dbias\", db_eager, db_compiled),\n", + " ]:\n", + " max_abs = (eager.float() - compiled.float()).abs().max().item()\n", + " print(f\"[{'PASS' if max_abs == 0 else 'FAIL'}] {name:8s} eager vs compiled max_abs_diff={max_abs:.4e}\")\n", + " assert max_abs == 0.0, f\"{name}: torch.compile differs from eager\"\n", + "\n", + " print(\"\\nPASSED: torch.compile backward produces identical output to eager mode.\")\n", + "else:\n", + " print(\"Skipping torch.compile test (no CUDA device).\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/samples/python/70_boolean_cmp_logic.ipynb b/samples/python/70_boolean_cmp_logic.ipynb new file mode 100644 index 00000000..8cb5cc59 --- /dev/null +++ b/samples/python/70_boolean_cmp_logic.ipynb @@ -0,0 +1,436 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Boolean Comparison and Logic Fusion\n", + "\n", + "This notebook demonstrates a boolean fusion graph using the cuDNN Python API:\n", + "\n", + "$$\\text{output} = (\\mathbf{x} > \\text{threshold}) \\;\\wedge\\; \\mathbf{b}$$\n", + "\n", + "The graph fuses two pointwise operations:\n", + "1. **CMP_GT** — element-wise greater-than comparison producing a boolean mask\n", + "2. **LOGICAL_AND** — element-wise AND of the comparison result with another boolean tensor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites and Setup\n", + "\n", + "This notebook requires an NVIDIA Blackwell GPU (SM100+) or later.\n", + "\n", + "**Environment setup:**\n", + "\n", + "- **Option A — pip install:**\n", + " ```bash\n", + " pip install nvidia-cudnn-frontend\n", + " ```\n", + "- **Option B — set paths manually:**\n", + " ```bash\n", + " export LD_LIBRARY_PATH=/path/to/cudnn/lib:${LD_LIBRARY_PATH}\n", + " export PYTHONPATH=/path/to/cudnn_frontend/build_python:${PYTHONPATH}\n", + " ```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-28T19:01:04.117372Z", + "iopub.status.busy": "2026-04-28T19:01:04.117110Z", + "iopub.status.idle": "2026-04-28T19:01:04.121104Z", + "shell.execute_reply": "2026-04-28T19:01:04.120209Z" + } + }, + "outputs": [], + "source": [ + "# get_ipython().system('pip install nvidia-cudnn-cu13')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", + "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "We build a two-node fusion graph:\n", + "\n", + "| Node | Operation | Inputs | Output | Compute Precision |\n", + "|------|-----------|--------|--------|-------------------|\n", + "| `cmp0` | `CMP_GT` | `x` (half), `threshold` (float) | `after_cmp` (bool, virtual) | float |\n", + "| `logical0` | `LOGICAL_AND` | `after_cmp` (bool), `b` (bool) | `output` (bool) | bool |\n", + "\n", + "Tensors use row-major layout with shape `[d0, d1, d2]`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-28T19:01:04.124069Z", + "iopub.status.busy": "2026-04-28T19:01:04.123806Z", + "iopub.status.idle": "2026-04-28T19:01:06.580620Z", + "shell.execute_reply": "2026-04-28T19:01:06.579851Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuDNN backend version: 92200\n" + ] + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\", message=\"Failed to initialize NumPy\")\n", + "\n", + "import torch\n", + "import cudnn\n", + "\n", + "# Boolean fusion (CMP_GT + LOGICAL_AND) requires cuDNN backend version >= 92200\n", + "print(\"cuDNN backend version:\", cudnn.backend_version())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Tensors\n", + "\n", + "We use a small 3D shape `[4, 8, 16]` in row-major layout for quick verification." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-28T19:01:06.605175Z", + "iopub.status.busy": "2026-04-28T19:01:06.605022Z", + "iopub.status.idle": "2026-04-28T19:01:06.891857Z", + "shell.execute_reply": "2026-04-28T19:01:06.891160Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x: torch.Size([4, 8, 16]), dtype=torch.float16, stride=(128, 16, 1)\n", + "threshold: torch.Size([4, 8, 16]), dtype=torch.float32, stride=(128, 16, 1)\n", + "b: torch.Size([4, 8, 16]), dtype=torch.bool, stride=(128, 16, 1)\n" + ] + } + ], + "source": [ + "shape = (4, 8, 16)\n", + "has_cuda = torch.cuda.is_available()\n", + "device = torch.device(\"cuda\" if has_cuda else \"cpu\")\n", + "\n", + "torch.manual_seed(42)\n", + "\n", + "x_gpu = torch.randn(*shape, dtype=torch.float16, device=device).contiguous()\n", + "threshold_gpu = torch.randn(*shape, dtype=torch.float32, device=device).contiguous()\n", + "b_gpu = torch.randint(0, 2, shape, dtype=torch.bool, device=device).contiguous()\n", + "\n", + "print(f\"x: {x_gpu.shape}, dtype={x_gpu.dtype}, stride={x_gpu.stride()}\")\n", + "print(f\"threshold: {threshold_gpu.shape}, dtype={threshold_gpu.dtype}, stride={threshold_gpu.stride()}\")\n", + "print(f\"b: {b_gpu.shape}, dtype={b_gpu.dtype}, stride={b_gpu.stride()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build and Execute cuDNN Graph\n", + "\n", + "We use the low-level `pygraph` API to construct the fusion graph with `cmp_gt` and `logical_and` nodes." 
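Before constructing the cuDNN graph, it helps to see the pure-PyTorch expression the fusion is expected to reproduce; the verification cell later uses the same expression as the reference.

```python
import torch

torch.manual_seed(42)
x = torch.randn(4, 8, 16, dtype=torch.float16)
threshold = torch.randn(4, 8, 16, dtype=torch.float32)
b = torch.randint(0, 2, (4, 8, 16), dtype=torch.bool)

# CMP_GT produces a boolean mask; LOGICAL_AND combines it with b elementwise.
out = (x.float() > threshold) & b
print(out.dtype, out.shape)  # torch.bool torch.Size([4, 8, 16])
```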
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-28T19:01:06.893234Z", + "iopub.status.busy": "2026-04-28T19:01:06.893121Z", + "iopub.status.idle": "2026-04-28T19:01:07.205991Z", + "shell.execute_reply": "2026-04-28T19:01:07.205120Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "after_and_gpu: torch.Size([4, 8, 16]), dtype=torch.bool" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "True count: 114 / 512\n" + ] + } + ], + "source": [ + "if has_cuda:\n", + " handle = cudnn.create_handle()\n", + "\n", + " graph = cudnn.pygraph(\n", + " handle=handle,\n", + " intermediate_data_type=cudnn.data_type.FLOAT,\n", + " compute_data_type=cudnn.data_type.FLOAT,\n", + " )\n", + "\n", + " x_cudnn = graph.tensor_like(x_gpu)\n", + " threshold_cudnn = graph.tensor_like(threshold_gpu)\n", + " b_cudnn = graph.tensor_like(b_gpu)\n", + "\n", + " after_cmp = graph.cmp_gt(\n", + " name=\"cmp0\",\n", + " input=x_cudnn,\n", + " comparison=threshold_cudnn,\n", + " compute_data_type=cudnn.data_type.FLOAT,\n", + " )\n", + " after_cmp.set_data_type(cudnn.data_type.BOOLEAN)\n", + "\n", + " after_and = graph.logical_and(\n", + " name=\"logical0\",\n", + " a=after_cmp,\n", + " b=b_cudnn,\n", + " compute_data_type=cudnn.data_type.BOOLEAN,\n", + " )\n", + " after_and.set_output(True).set_data_type(cudnn.data_type.BOOLEAN)\n", + "\n", + " graph.validate()\n", + " graph.build_operation_graph()\n", + " graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n", + " graph.check_support()\n", + " graph.build_plans()\n", + "\n", + " after_and_gpu = torch.empty(*shape, dtype=torch.bool, device=device).contiguous()\n", + " workspace = torch.empty(graph.get_workspace_size(), device=device, dtype=torch.uint8)\n", + "\n", + " graph.execute(\n", + " {x_cudnn: x_gpu, threshold_cudnn: threshold_gpu, b_cudnn: b_gpu, after_and: after_and_gpu},\n", + " workspace,\n", + " handle=handle,\n", + " )\n", + " torch.cuda.synchronize()\n", + "\n", + " print(f\"after_and_gpu: {after_and_gpu.shape}, dtype={after_and_gpu.dtype}\")\n", + " print(f\"True count: {after_and_gpu.sum().item()} / {after_and_gpu.numel()}\")\n", + "else:\n", + " print(\"Skipping cuDNN graph (no CUDA device).\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Verify Correctness" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-28T19:01:07.208645Z", + "iopub.status.busy": "2026-04-28T19:01:07.208529Z", + "iopub.status.idle": "2026-04-28T19:01:07.290338Z", + "shell.execute_reply": "2026-04-28T19:01:07.289782Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ref true count: 114 / 512\n", + "Mismatches: 0 / 512\n", + "PASSED: cuDNN boolean fusion matches reference.\n" + ] + } + ], + "source": [ + "if has_cuda:\n", + " after_and_ref = (x_gpu.float() > threshold_gpu) & b_gpu\n", + " mismatches = (after_and_gpu != after_and_ref).sum().item()\n", + " total = after_and_ref.numel()\n", + "\n", + " print(f\"Ref true count: {after_and_ref.sum().item()} / {total}\")\n", + " print(f\"Mismatches: {mismatches} / {total}\")\n", + " assert mismatches == 0, f\"FAILED: {mismatches} mismatches out of {total} elements\"\n", + " print(\"PASSED: cuDNN boolean fusion matches reference.\")\n", + "else:\n", + " print(\"Skipping verification (no CUDA device).\")" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test with Multiple Shapes and Data Types\n", + "\n", + "Sweep across different tensor shapes (2D–4D) and input data types for the comparison operand." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-28T19:01:07.291657Z", + "iopub.status.busy": "2026-04-28T19:01:07.291532Z", + "iopub.status.idle": "2026-04-28T19:01:07.985594Z", + "shell.execute_reply": "2026-04-28T19:01:07.984724Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PASS] dim=[8, 32] x_dtype=torch.float16 mismatches=0\n", + "[PASS] dim=[4, 8, 16] x_dtype=torch.float16 mismatches=0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PASS] dim=[4, 8, 16] x_dtype=torch.bfloat16 mismatches=0\n", + "[PASS] dim=[4, 8, 16] x_dtype=torch.float32 mismatches=0\n", + "[PASS] dim=[1, 32] x_dtype=torch.float16 mismatches=0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PASS] dim=[16, 64, 196] x_dtype=torch.float16 mismatches=0\n", + "[PASS] dim=[2, 64, 28, 28] x_dtype=torch.float16 mismatches=0\n", + "\n", + "All tests passed!\n" + ] + } + ], + "source": [ + "if has_cuda:\n", + " test_configs = [\n", + " {\"dim\": [8, 32], \"x_dtype\": torch.float16},\n", + " {\"dim\": [4, 8, 16], \"x_dtype\": torch.float16},\n", + " {\"dim\": [4, 8, 16], \"x_dtype\": torch.bfloat16},\n", + " {\"dim\": [4, 8, 16], \"x_dtype\": torch.float32},\n", + " {\"dim\": [1, 32], \"x_dtype\": torch.float16},\n", + " {\"dim\": [16, 64, 196], \"x_dtype\": torch.float16},\n", + " {\"dim\": [2, 64, 28, 28], \"x_dtype\": torch.float16},\n", + " ]\n", + "\n", + " all_pass = True\n", + " for cfg in test_configs:\n", + " dims = cfg[\"dim\"]\n", + " x_dt = cfg[\"x_dtype\"]\n", + "\n", + " x_t = torch.randn(*dims, dtype=x_dt, device=device).contiguous()\n", + " thresh_t = torch.randn(*dims, dtype=torch.float32, device=device).contiguous()\n", + " b_t = torch.randint(0, 2, dims, dtype=torch.bool, device=device).contiguous()\n", + "\n", + " ref_t = (x_t.float() > thresh_t) & b_t\n", + "\n", + " g = cudnn.pygraph(\n", + " handle=handle,\n", + " intermediate_data_type=cudnn.data_type.FLOAT,\n", + " compute_data_type=cudnn.data_type.FLOAT,\n", + " )\n", + " x_c = g.tensor_like(x_t)\n", + " th_c = g.tensor_like(thresh_t)\n", + " b_c = g.tensor_like(b_t)\n", + "\n", + " cmp_out = g.cmp_gt(input=x_c, comparison=th_c, compute_data_type=cudnn.data_type.FLOAT)\n", + " cmp_out.set_data_type(cudnn.data_type.BOOLEAN)\n", + "\n", + " and_out = g.logical_and(a=cmp_out, b=b_c, compute_data_type=cudnn.data_type.BOOLEAN)\n", + " and_out.set_output(True).set_data_type(cudnn.data_type.BOOLEAN)\n", + "\n", + " try:\n", + " g.validate()\n", + " g.build_operation_graph()\n", + " g.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n", + " g.check_support()\n", + " g.build_plans()\n", + " except cudnn.cudnnGraphNotSupportedError as e:\n", + " print(f\"[SKIP] dim={dims} x_dtype={x_dt}: {e}\")\n", + " continue\n", + "\n", + " out_t = torch.empty(*dims, dtype=torch.bool, device=device).contiguous()\n", + " ws = torch.empty(g.get_workspace_size(), device=device, dtype=torch.uint8)\n", + " g.execute({x_c: x_t, th_c: thresh_t, b_c: b_t, and_out: out_t}, ws, handle=handle)\n", + " torch.cuda.synchronize()\n", + "\n", + " mm = (out_t != ref_t).sum().item()\n", + " status = \"PASS\" if mm == 0 else \"FAIL\"\n", + " if mm != 0:\n", + " all_pass = 
False\n", + " print(f\"[{status}] dim={str(dims):20s} x_dtype={str(x_dt):18s} mismatches={mm}\")\n", + " assert mm == 0, f\"Test failed for {cfg}\"\n", + "\n", + " print(f\"\\nAll tests passed!\" if all_pass else \"\\nSome tests failed.\")\n", + "else:\n", + " print(\"Skipping multi-config tests (no CUDA device).\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-28T19:01:07.987304Z", + "iopub.status.busy": "2026-04-28T19:01:07.987178Z", + "iopub.status.idle": "2026-04-28T19:01:07.990049Z", + "shell.execute_reply": "2026-04-28T19:01:07.989373Z" + } + }, + "outputs": [], + "source": [ + "if has_cuda:\n", + " cudnn.destroy_handle(handle)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/setup.py b/setup.py index ac9ba60d..222d430f 100644 --- a/setup.py +++ b/setup.py @@ -29,15 +29,9 @@ def build_extension(self, ext: CMakeExtension) -> None: cfg = "Debug" if debug else "Release" is_windows = os.name == "nt" - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON - cmake_args = [] - - if is_windows == False: - cmake_args += [ - f"-DPython_EXECUTABLE={sys.executable}", - ] - cmake_args = [ + f"-DPython_EXECUTABLE={sys.executable}", + f"-DPYBIND11_FINDPYTHON=ON", f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm f"-DCUDNN_FRONTEND_BUILD_PYTHON_BINDINGS=ON", # There's no need to build cpp samples and tests with python diff --git a/test/python/fe_api/test_api_base_logging.py b/test/python/fe_api/test_api_base_logging.py new file mode 100644 index 00000000..61dcb335 --- /dev/null +++ b/test/python/fe_api/test_api_base_logging.py @@ -0,0 +1,25 @@ +import logging + +import pytest + +from cudnn.api_base import _reset_experimental_api_warning_registry, warn_experimental_api_once + + +@pytest.mark.L0 +def test_experimental_api_warning_emits_once_per_api(caplog): + logger = logging.getLogger("cudnn.test.experimental") + _reset_experimental_api_warning_registry() + + try: + with caplog.at_level(logging.WARNING, logger=logger.name): + warn_experimental_api_once(logger, "FirstExperimentalApi") + warn_experimental_api_once(logger, "FirstExperimentalApi") + warn_experimental_api_once(logger, "SecondExperimentalApi") + + messages = [record.getMessage() for record in caplog.records] + assert messages == [ + "FirstExperimentalApi is an experimental API", + "SecondExperimentalApi is an experimental API", + ] + finally: + _reset_experimental_api_warning_registry() diff --git a/test/python/fe_api/test_gemm_dsrelu.py b/test/python/fe_api/test_gemm_dsrelu.py new file mode 100644 index 00000000..10f99fa0 --- /dev/null +++ b/test/python/fe_api/test_gemm_dsrelu.py @@ -0,0 +1,395 @@ +import pytest +import torch + +import cudnn +from test_utils import torch_fork_set_rng + +from fe_api.test_gemm_dsrelu_utils import ( + allocate_gemm_dsrelu_outputs, + allocate_gemm_dsrelu_tensors, + check_ref_gemm_dsrelu, + gemm_dsrelu_init, + with_gemm_dsrelu_params_fp4, +) + + +def _run_class_api(cfg, inputs, outputs): + op = cudnn.GemmDsreluSm100( + sample_a=inputs["a_tensor"], + sample_b=inputs["b_tensor"], + sample_c=inputs["c_tensor"], + 
sample_d=outputs["d_tensor"], + sample_dprob=outputs["dprob_tensor"], + sample_sfa=inputs["sfa_tensor"], + sample_sfb=inputs["sfb_tensor"], + sample_prob=inputs["prob_tensor"], + sample_sfd=outputs["sfd_tensor"], + sample_amax=outputs["amax_tensor"], + sample_norm_const=outputs["norm_const_tensor"], + alpha=cfg["alpha"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + try: + assert op.check_support() + except (ValueError, NotImplementedError, RuntimeError) as e: + pytest.skip(f"Unsupported testcase: {e}") + op.compile() + op.execute( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + d_tensor=outputs["d_tensor"], + dprob_tensor=outputs["dprob_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + sfd_tensor=outputs["sfd_tensor"], + amax_tensor=outputs["amax_tensor"], + norm_const_tensor=outputs["norm_const_tensor"], + alpha=cfg["alpha"], + ) + torch.cuda.synchronize() + + +def _run_wrapper_api(cfg, inputs): + try: + result = None + for _ in range(2): + result = cudnn.gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + d_major=cfg["c_major"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=( + None if cfg["d_dtype"] not in {torch.float8_e4m3fn, torch.float8_e5m2} else torch.tensor([1.0], dtype=torch.float32, device="cuda") + ), + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + except (ValueError, NotImplementedError, RuntimeError) as e: + pytest.skip(f"Unsupported testcase: {e}") + torch.cuda.synchronize() + return { + "d_tensor": result["d_tensor"], + "dprob_tensor": result["dprob_tensor"], + "amax_tensor": result["amax_tensor"], + "sfd_tensor": result["sfd_tensor"], + "norm_const_tensor": None, + } + + +def _make_dense_dsrelu_cfg(request, m: int, n: int = 256, k: int = 512, l: int = 2): + cfg = gemm_dsrelu_init( + request, + "k", + "k", + "n", + torch.float4_e2m1fn_x2, + torch.bfloat16, + torch.bfloat16, + torch.float32, + (256, 256), + (2, 1), + 16, + torch.float8_e8m0fnu, + False, + ) + cfg["m"] = m + cfg["n"] = n + cfg["k"] = k + cfg["l"] = l + return cfg + + +def _test_gemm_dsrelu_wrapper_dynamic_m_cache_behavior(request, monkeypatch, use_dynamic_m): + try: + from cudnn import gemm_dsrelu_wrapper_sm100 + from cudnn.gemm_dsrelu import api as gemm_dsrelu_api + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + if use_dynamic_m: + monkeypatch.setenv("CUDNN_FE_GEMM_DYNAMIC_M", "1") + else: + monkeypatch.delenv("CUDNN_FE_GEMM_DYNAMIC_M", raising=False) + + gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects.clear() + compile_count = {"value": 0} + + def counted_compile(self): + compile_count["value"] += 1 + + monkeypatch.setattr(gemm_dsrelu_api.GemmDsreluSm100, "check_support", lambda self: True) + monkeypatch.setattr(gemm_dsrelu_api.GemmDsreluSm100, "compile", counted_compile) + monkeypatch.setattr(gemm_dsrelu_api.GemmDsreluSm100, "execute", lambda self, **kwargs: None) + + try: + for m in (256, 384): + cfg = _make_dense_dsrelu_cfg(request, m) + inputs = 
allocate_gemm_dsrelu_tensors(cfg) + gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + d_major=cfg["c_major"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=None, + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + finally: + cache_entries = len(gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects) + gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects.clear() + + return compile_count["value"], cache_entries + + +def _test_gemm_dsrelu_wrapper_full_dynamic_cache_behavior(request, monkeypatch): + try: + from cudnn import gemm_dsrelu_wrapper_sm100 + from cudnn.gemm_dsrelu import api as gemm_dsrelu_api + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + monkeypatch.setenv("CUDNN_FE_GEMM_DYNAMIC_MNKL", "1") + monkeypatch.delenv("CUDNN_FE_GEMM_DYNAMIC_M", raising=False) + gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects.clear() + compile_count = {"value": 0} + + def counted_compile(self): + compile_count["value"] += 1 + + monkeypatch.setattr(gemm_dsrelu_api.GemmDsreluSm100, "check_support", lambda self: True) + monkeypatch.setattr(gemm_dsrelu_api.GemmDsreluSm100, "compile", counted_compile) + monkeypatch.setattr(gemm_dsrelu_api.GemmDsreluSm100, "execute", lambda self, **kwargs: None) + + try: + for mnkl in ((256, 256, 512, 2), (384, 384, 640, 3)): + cfg = _make_dense_dsrelu_cfg(request, *mnkl) + inputs = allocate_gemm_dsrelu_tensors(cfg) + gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + d_major=cfg["c_major"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=None, + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + finally: + cache_entries = len(gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects) + gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects.clear() + + return compile_count["value"], cache_entries + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=10) +@with_gemm_dsrelu_params_fp4 +def test_gemm_dsrelu_compile_execute_fp4( + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + request, +): + cfg = gemm_dsrelu_init( + request, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + ) + inputs = allocate_gemm_dsrelu_tensors(cfg) + outputs = allocate_gemm_dsrelu_outputs(cfg) + _run_class_api(cfg, inputs, outputs) + check_ref_gemm_dsrelu(inputs, outputs, cfg, check_d=True) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=11) +@with_gemm_dsrelu_params_fp4 +def test_gemm_dsrelu_wrapper_fp4( + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + request, +): + cfg = gemm_dsrelu_init( + request, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + 
cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + ) + inputs = allocate_gemm_dsrelu_tensors(cfg) + outputs = _run_wrapper_api(cfg, inputs) + check_ref_gemm_dsrelu(inputs, outputs, cfg, check_d=True) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=12) +def test_gemm_dsrelu_wrapper_cache_static_m_smoke(request, monkeypatch): + compile_count, cache_entries = _test_gemm_dsrelu_wrapper_dynamic_m_cache_behavior(request, monkeypatch, use_dynamic_m=False) + + assert compile_count == 2 + assert cache_entries == 2 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=13) +def test_gemm_dsrelu_wrapper_cache_dynamic_m_smoke(request, monkeypatch): + compile_count, cache_entries = _test_gemm_dsrelu_wrapper_dynamic_m_cache_behavior(request, monkeypatch, use_dynamic_m=True) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=14) +def test_gemm_dsrelu_wrapper_dynamic_m_fp4(request, monkeypatch): + try: + import cudnn + from cudnn.gemm_dsrelu import api as gemm_dsrelu_api + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + monkeypatch.setenv("CUDNN_FE_GEMM_DYNAMIC_M", "1") + gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects.clear() + + try: + for m in (256, 384): + cfg = _make_dense_dsrelu_cfg(request, m) + inputs = allocate_gemm_dsrelu_tensors(cfg) + outputs = cudnn.gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + d_major=cfg["c_major"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=None, + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + check_ref_gemm_dsrelu(inputs, outputs, cfg, check_d=True) + + assert len(gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects) == 1 + finally: + gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects.clear() + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=15) +def test_gemm_dsrelu_wrapper_cache_full_dynamic_smoke(request, monkeypatch): + compile_count, cache_entries = _test_gemm_dsrelu_wrapper_full_dynamic_cache_behavior(request, monkeypatch) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=16) +def test_gemm_dsrelu_wrapper_full_dynamic_fp4(request, monkeypatch): + try: + import cudnn + from cudnn.gemm_dsrelu import api as gemm_dsrelu_api + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + monkeypatch.setenv("CUDNN_FE_GEMM_DYNAMIC_MNKL", "1") + monkeypatch.delenv("CUDNN_FE_GEMM_DYNAMIC_M", raising=False) + gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects.clear() + + try: + for mnkl in ((256, 256, 512, 2), (384, 384, 640, 3)): + cfg = _make_dense_dsrelu_cfg(request, *mnkl) + inputs = allocate_gemm_dsrelu_tensors(cfg) + outputs = cudnn.gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + d_major=cfg["c_major"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=None, + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + 
) + check_ref_gemm_dsrelu(inputs, outputs, cfg, check_d=True) + + assert len(gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects) == 1 + finally: + gemm_dsrelu_api._cache_of_GemmDsreluSm100Objects.clear() diff --git a/test/python/fe_api/test_gemm_dsrelu_utils.py b/test/python/fe_api/test_gemm_dsrelu_utils.py new file mode 100644 index 00000000..764257b1 --- /dev/null +++ b/test/python/fe_api/test_gemm_dsrelu_utils.py @@ -0,0 +1,152 @@ +import pytest +import torch + +from test_fe_api_utils import ( + compute_reference_amax, + create_and_permute_tensor, + create_scale_factor_tensor, +) + +GEMM_DSRELU_PARAM_MARKS_FP4 = [ + pytest.mark.parametrize("a_major", ["k"]), + pytest.mark.parametrize("b_major", ["k"]), + pytest.mark.parametrize("c_major", ["n"]), + pytest.mark.parametrize("ab_dtype", [torch.float4_e2m1fn_x2]), + pytest.mark.parametrize("c_dtype", [torch.bfloat16]), + pytest.mark.parametrize("d_dtype", [torch.bfloat16]), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize("mma_tiler_mn", [(256, 256), (128, 256)]), + pytest.mark.parametrize("cluster_shape_mn", [(2, 1), (1, 1)]), + pytest.mark.parametrize("sf_vec_size", [16]), + pytest.mark.parametrize("sf_dtype", [torch.float8_e8m0fnu]), + pytest.mark.parametrize("vector_f32", [True, False]), +] + + +def with_gemm_dsrelu_params_fp4(func): + for mark in reversed(GEMM_DSRELU_PARAM_MARKS_FP4): + func = mark(func) + return func + + +def gemm_dsrelu_init( + request, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, +): + major, minor = torch.cuda.get_device_capability() + if major * 10 + minor < 100: + pytest.skip(f"Environment not supported: requires compute capability >= 10, found {major}") + + mnkl_str = request.config.getoption("--gemm-dsrelu-mnkl", default=None) + if mnkl_str is not None: + m, n, k, l = [int(x.strip()) for x in mnkl_str.split(",")] + else: + m, n, k, l = 256, 256, 512, 2 + + return { + "m": m, + "n": n, + "k": k, + "l": l, + "a_major": a_major, + "b_major": b_major, + "c_major": c_major, + "ab_dtype": ab_dtype, + "c_dtype": c_dtype, + "d_dtype": d_dtype, + "acc_dtype": acc_dtype, + "mma_tiler_mn": mma_tiler_mn, + "cluster_shape_mn": cluster_shape_mn, + "sf_vec_size": sf_vec_size, + "sf_dtype": sf_dtype, + "vector_f32": vector_f32, + "alpha": 1.0, + "skip_ref": request.config.getoption("--skip-ref", default=False), + } + + +def allocate_gemm_dsrelu_tensors(cfg): + a_ref, a_tensor = create_and_permute_tensor(cfg["l"], cfg["m"], cfg["k"], cfg["a_major"] == "m", cfg["ab_dtype"]) + b_ref, b_tensor = create_and_permute_tensor(cfg["l"], cfg["n"], cfg["k"], cfg["b_major"] == "n", cfg["ab_dtype"]) + c_ref, c_tensor = create_and_permute_tensor(cfg["l"], cfg["m"], cfg["n"], cfg["c_major"] == "m", cfg["c_dtype"]) + sfa_ref, sfa_tensor = create_scale_factor_tensor(cfg["l"], cfg["m"], cfg["k"], cfg["sf_vec_size"], cfg["sf_dtype"]) + sfb_ref, sfb_tensor = create_scale_factor_tensor(cfg["l"], cfg["n"], cfg["k"], cfg["sf_vec_size"], cfg["sf_dtype"]) + prob_ref, prob_tensor = create_and_permute_tensor(cfg["l"], cfg["m"], 1, cfg["a_major"] == "m", torch.float32) + + return { + "a_ref": a_ref, + "a_tensor": a_tensor, + "b_ref": b_ref, + "b_tensor": b_tensor, + "c_ref": c_ref, + "c_tensor": c_tensor, + "sfa_ref": sfa_ref, + "sfa_tensor": sfa_tensor, + "sfb_ref": sfb_ref, + "sfb_tensor": sfb_tensor, + "prob_ref": prob_ref, + "prob_tensor": prob_tensor, + } + + +def allocate_gemm_dsrelu_outputs(cfg): + _, 
d_tensor = create_and_permute_tensor(cfg["l"], cfg["m"], cfg["n"], cfg["c_major"] == "m", cfg["d_dtype"]) + dprob_tensor = torch.zeros((cfg["m"], 1, cfg["l"]), dtype=torch.float32, device="cuda") + + sfd_tensor = None + if cfg["d_dtype"] in {torch.float8_e4m3fn, torch.float8_e5m2}: + _, sfd_tensor = create_scale_factor_tensor(cfg["l"], cfg["m"], cfg["n"], cfg["sf_vec_size"], cfg["sf_dtype"]) + + amax_tensor = None + if cfg["ab_dtype"] in {torch.float4_e2m1fn_x2, torch.uint8} and cfg["d_dtype"] in {torch.bfloat16, torch.float16, torch.float32}: + amax_tensor = torch.full((1,), float("-inf"), dtype=torch.float32, device="cuda") + + norm_const_tensor = None + if sfd_tensor is not None: + norm_const_tensor = torch.tensor([1.0], dtype=torch.float32, device="cuda") + + return { + "d_tensor": d_tensor, + "dprob_tensor": dprob_tensor, + "sfd_tensor": sfd_tensor, + "amax_tensor": amax_tensor, + "norm_const_tensor": norm_const_tensor, + } + + +def gemm_dsrelu_reference(inputs, cfg): + res_a = torch.einsum("mkl,mkl->mkl", inputs["a_ref"], inputs["sfa_ref"]) + res_b = torch.einsum("nkl,nkl->nkl", inputs["b_ref"], inputs["sfb_ref"]) + x_ref = cfg["alpha"] * torch.einsum("mkl,nkl->mnl", res_a, res_b) + d_ref = inputs["c_ref"].float() * inputs["prob_ref"].expand(-1, cfg["n"], -1).float() * 2 * torch.relu(x_ref) + dprob_ref = torch.sum(torch.relu(x_ref) ** 2 * inputs["c_ref"].float(), dim=1, keepdim=True) + return d_ref, dprob_ref + + +def check_ref_gemm_dsrelu(inputs, outputs, cfg, check_d=True): + if cfg["skip_ref"]: + return + + d_ref, dprob_ref = gemm_dsrelu_reference(inputs, cfg) + torch.testing.assert_close(outputs["dprob_tensor"].float(), dprob_ref.float(), atol=0.12, rtol=0.02) + + if check_d: + torch.testing.assert_close(outputs["d_tensor"].float(), d_ref.float(), atol=0.12, rtol=0.02) + if outputs["amax_tensor"] is not None: + amax_ref = torch.tensor( + [compute_reference_amax(d_ref)], + dtype=torch.float32, + device=outputs["amax_tensor"].device, + ) + torch.testing.assert_close(outputs["amax_tensor"], amax_ref, atol=0.12, rtol=0.02) diff --git a/test/python/fe_api/test_gemm_srelu.py b/test/python/fe_api/test_gemm_srelu.py new file mode 100644 index 00000000..55ad46b8 --- /dev/null +++ b/test/python/fe_api/test_gemm_srelu.py @@ -0,0 +1,393 @@ +import pytest +import torch + +import cudnn +from test_utils import torch_fork_set_rng + +from fe_api.test_gemm_srelu_utils import ( + allocate_gemm_srelu_outputs, + allocate_gemm_srelu_tensors, + check_ref_gemm_srelu, + gemm_srelu_init, + with_gemm_srelu_params_fp4, +) + + +def _run_class_api(cfg, inputs, outputs): + op = cudnn.GemmSreluSm100( + sample_a=inputs["a_tensor"], + sample_b=inputs["b_tensor"], + sample_c=outputs["c_tensor"], + sample_d=outputs["d_tensor"], + sample_sfa=inputs["sfa_tensor"], + sample_sfb=inputs["sfb_tensor"], + sample_prob=inputs["prob_tensor"], + sample_sfd=outputs["sfd_tensor"], + sample_amax=outputs["amax_tensor"], + sample_norm_const=outputs["norm_const_tensor"], + alpha=cfg["alpha"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + try: + assert op.check_support() + except (ValueError, NotImplementedError, RuntimeError) as e: + pytest.skip(f"Unsupported testcase: {e}") + op.compile() + op.execute( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=outputs["c_tensor"], + d_tensor=outputs["d_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + 
prob_tensor=inputs["prob_tensor"], + sfd_tensor=outputs["sfd_tensor"], + amax_tensor=outputs["amax_tensor"], + norm_const_tensor=outputs["norm_const_tensor"], + alpha=cfg["alpha"], + ) + torch.cuda.synchronize() + + +def _run_wrapper_api(cfg, inputs): + try: + result = None + for _ in range(2): + result = cudnn.gemm_srelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + c_major=cfg["c_major"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=( + None if cfg["d_dtype"] not in {torch.float8_e4m3fn, torch.float8_e5m2} else torch.tensor([1.0], dtype=torch.float32, device="cuda") + ), + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + except (ValueError, NotImplementedError, RuntimeError) as e: + pytest.skip(f"Unsupported testcase: {e}") + torch.cuda.synchronize() + return { + "c_tensor": result["c_tensor"], + "d_tensor": result["d_tensor"], + "amax_tensor": result["amax_tensor"], + "sfd_tensor": result["sfd_tensor"], + "norm_const_tensor": None, + } + + +def _make_dense_srelu_cfg(request, m: int, n: int = 256, k: int = 512, l: int = 2): + cfg = gemm_srelu_init( + request, + "k", + "k", + "n", + torch.float4_e2m1fn_x2, + torch.bfloat16, + torch.bfloat16, + torch.float32, + (256, 256), + (2, 1), + 16, + torch.float8_e8m0fnu, + False, + ) + cfg["m"] = m + cfg["n"] = n + cfg["k"] = k + cfg["l"] = l + return cfg + + +def _test_gemm_srelu_wrapper_dynamic_m_cache_behavior(request, monkeypatch, use_dynamic_m): + try: + from cudnn import gemm_srelu_wrapper_sm100 + from cudnn.gemm_srelu import api as gemm_srelu_api + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + if use_dynamic_m: + monkeypatch.setenv("CUDNN_FE_GEMM_DYNAMIC_M", "1") + else: + monkeypatch.delenv("CUDNN_FE_GEMM_DYNAMIC_M", raising=False) + + gemm_srelu_api._cache_of_GemmSreluSm100Objects.clear() + compile_count = {"value": 0} + + def counted_compile(self): + compile_count["value"] += 1 + + monkeypatch.setattr(gemm_srelu_api.GemmSreluSm100, "check_support", lambda self: True) + monkeypatch.setattr(gemm_srelu_api.GemmSreluSm100, "compile", counted_compile) + monkeypatch.setattr(gemm_srelu_api.GemmSreluSm100, "execute", lambda self, **kwargs: None) + + try: + for m in (256, 384): + cfg = _make_dense_srelu_cfg(request, m) + inputs = allocate_gemm_srelu_tensors(cfg) + gemm_srelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + c_major=cfg["c_major"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=None, + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + finally: + cache_entries = len(gemm_srelu_api._cache_of_GemmSreluSm100Objects) + gemm_srelu_api._cache_of_GemmSreluSm100Objects.clear() + + return compile_count["value"], cache_entries + + +def _test_gemm_srelu_wrapper_full_dynamic_cache_behavior(request, monkeypatch): + try: + from cudnn import gemm_srelu_wrapper_sm100 + from cudnn.gemm_srelu import api as gemm_srelu_api + except ImportError: + 
pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + monkeypatch.setenv("CUDNN_FE_GEMM_DYNAMIC_MNKL", "1") + monkeypatch.delenv("CUDNN_FE_GEMM_DYNAMIC_M", raising=False) + gemm_srelu_api._cache_of_GemmSreluSm100Objects.clear() + compile_count = {"value": 0} + + def counted_compile(self): + compile_count["value"] += 1 + + monkeypatch.setattr(gemm_srelu_api.GemmSreluSm100, "check_support", lambda self: True) + monkeypatch.setattr(gemm_srelu_api.GemmSreluSm100, "compile", counted_compile) + monkeypatch.setattr(gemm_srelu_api.GemmSreluSm100, "execute", lambda self, **kwargs: None) + + try: + for mnkl in ((256, 256, 512, 2), (384, 384, 640, 3)): + cfg = _make_dense_srelu_cfg(request, *mnkl) + inputs = allocate_gemm_srelu_tensors(cfg) + gemm_srelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + c_major=cfg["c_major"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=None, + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + finally: + cache_entries = len(gemm_srelu_api._cache_of_GemmSreluSm100Objects) + gemm_srelu_api._cache_of_GemmSreluSm100Objects.clear() + + return compile_count["value"], cache_entries + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_gemm_srelu_params_fp4 +def test_gemm_srelu_compile_execute_fp4( + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + request, +): + cfg = gemm_srelu_init( + request, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + ) + inputs = allocate_gemm_srelu_tensors(cfg) + outputs = allocate_gemm_srelu_outputs(cfg) + _run_class_api(cfg, inputs, outputs) + check_ref_gemm_srelu(inputs, outputs, cfg, check_d=True) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=1) +@with_gemm_srelu_params_fp4 +def test_gemm_srelu_wrapper_fp4( + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + request, +): + cfg = gemm_srelu_init( + request, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + ) + inputs = allocate_gemm_srelu_tensors(cfg) + outputs = _run_wrapper_api(cfg, inputs) + check_ref_gemm_srelu(inputs, outputs, cfg, check_d=True) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=2) +def test_gemm_srelu_wrapper_cache_static_m_smoke(request, monkeypatch): + compile_count, cache_entries = _test_gemm_srelu_wrapper_dynamic_m_cache_behavior(request, monkeypatch, use_dynamic_m=False) + + assert compile_count == 2 + assert cache_entries == 2 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=3) +def test_gemm_srelu_wrapper_cache_dynamic_m_smoke(request, monkeypatch): + compile_count, cache_entries = _test_gemm_srelu_wrapper_dynamic_m_cache_behavior(request, monkeypatch, use_dynamic_m=True) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=4) +def test_gemm_srelu_wrapper_dynamic_m_fp4(request, monkeypatch): + try: + import cudnn + from cudnn.gemm_srelu import api 
as gemm_srelu_api + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + monkeypatch.setenv("CUDNN_FE_GEMM_DYNAMIC_M", "1") + gemm_srelu_api._cache_of_GemmSreluSm100Objects.clear() + + try: + for m in (256, 384): + cfg = _make_dense_srelu_cfg(request, m) + inputs = allocate_gemm_srelu_tensors(cfg) + outputs = cudnn.gemm_srelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + c_major=cfg["c_major"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=None, + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + check_ref_gemm_srelu(inputs, outputs, cfg, check_d=True) + + assert len(gemm_srelu_api._cache_of_GemmSreluSm100Objects) == 1 + finally: + gemm_srelu_api._cache_of_GemmSreluSm100Objects.clear() + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=5) +def test_gemm_srelu_wrapper_cache_full_dynamic_smoke(request, monkeypatch): + compile_count, cache_entries = _test_gemm_srelu_wrapper_full_dynamic_cache_behavior(request, monkeypatch) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=6) +def test_gemm_srelu_wrapper_full_dynamic_fp4(request, monkeypatch): + try: + import cudnn + from cudnn.gemm_srelu import api as gemm_srelu_api + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + monkeypatch.setenv("CUDNN_FE_GEMM_DYNAMIC_MNKL", "1") + monkeypatch.delenv("CUDNN_FE_GEMM_DYNAMIC_M", raising=False) + gemm_srelu_api._cache_of_GemmSreluSm100Objects.clear() + + try: + for mnkl in ((256, 256, 512, 2), (384, 384, 640, 3)): + cfg = _make_dense_srelu_cfg(request, *mnkl) + inputs = allocate_gemm_srelu_tensors(cfg) + outputs = cudnn.gemm_srelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + prob_tensor=inputs["prob_tensor"], + alpha=cfg["alpha"], + c_major=cfg["c_major"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + norm_const_tensor=None, + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + ) + check_ref_gemm_srelu(inputs, outputs, cfg, check_d=True) + + assert len(gemm_srelu_api._cache_of_GemmSreluSm100Objects) == 1 + finally: + gemm_srelu_api._cache_of_GemmSreluSm100Objects.clear() diff --git a/test/python/fe_api/test_gemm_srelu_utils.py b/test/python/fe_api/test_gemm_srelu_utils.py new file mode 100644 index 00000000..1a47bc08 --- /dev/null +++ b/test/python/fe_api/test_gemm_srelu_utils.py @@ -0,0 +1,149 @@ +import pytest +import torch + +from test_fe_api_utils import ( + compute_reference_amax, + create_and_permute_tensor, + create_scale_factor_tensor, +) + +GEMM_SRELU_PARAM_MARKS_FP4 = [ + pytest.mark.parametrize("a_major", ["k"]), + pytest.mark.parametrize("b_major", ["k"]), + pytest.mark.parametrize("c_major", ["n"]), + pytest.mark.parametrize("ab_dtype", [torch.float4_e2m1fn_x2]), + pytest.mark.parametrize("c_dtype", [torch.bfloat16]), + pytest.mark.parametrize("d_dtype", [torch.bfloat16]), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize("mma_tiler_mn", [(256, 256), 
(128, 256)]), + pytest.mark.parametrize("cluster_shape_mn", [(2, 1), (1, 1)]), + pytest.mark.parametrize("sf_vec_size", [16]), + pytest.mark.parametrize("sf_dtype", [torch.float8_e8m0fnu]), + pytest.mark.parametrize("vector_f32", [True, False]), +] + + +def with_gemm_srelu_params_fp4(func): + for mark in reversed(GEMM_SRELU_PARAM_MARKS_FP4): + func = mark(func) + return func + + +def gemm_srelu_init( + request, + a_major, + b_major, + c_major, + ab_dtype, + c_dtype, + d_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, +): + major, minor = torch.cuda.get_device_capability() + if major * 10 + minor < 100: + pytest.skip(f"Environment not supported: requires compute capability >= 10, found {major}") + + mnkl_str = request.config.getoption("--gemm-srelu-mnkl", default=None) + if mnkl_str is not None: + m, n, k, l = [int(x.strip()) for x in mnkl_str.split(",")] + else: + m, n, k, l = 256, 256, 512, 2 + + return { + "m": m, + "n": n, + "k": k, + "l": l, + "a_major": a_major, + "b_major": b_major, + "c_major": c_major, + "ab_dtype": ab_dtype, + "c_dtype": c_dtype, + "d_dtype": d_dtype, + "acc_dtype": acc_dtype, + "mma_tiler_mn": mma_tiler_mn, + "cluster_shape_mn": cluster_shape_mn, + "sf_vec_size": sf_vec_size, + "sf_dtype": sf_dtype, + "vector_f32": vector_f32, + "alpha": 1.0, + "skip_ref": request.config.getoption("--skip-ref", default=False), + } + + +def allocate_gemm_srelu_tensors(cfg): + a_ref, a_tensor = create_and_permute_tensor(cfg["l"], cfg["m"], cfg["k"], cfg["a_major"] == "m", cfg["ab_dtype"]) + b_ref, b_tensor = create_and_permute_tensor(cfg["l"], cfg["n"], cfg["k"], cfg["b_major"] == "n", cfg["ab_dtype"]) + sfa_ref, sfa_tensor = create_scale_factor_tensor(cfg["l"], cfg["m"], cfg["k"], cfg["sf_vec_size"], cfg["sf_dtype"]) + sfb_ref, sfb_tensor = create_scale_factor_tensor(cfg["l"], cfg["n"], cfg["k"], cfg["sf_vec_size"], cfg["sf_dtype"]) + prob_ref, prob_tensor = create_and_permute_tensor(cfg["l"], cfg["m"], 1, cfg["a_major"] == "m", torch.float32) + + return { + "a_ref": a_ref, + "a_tensor": a_tensor, + "b_ref": b_ref, + "b_tensor": b_tensor, + "sfa_ref": sfa_ref, + "sfa_tensor": sfa_tensor, + "sfb_ref": sfb_ref, + "sfb_tensor": sfb_tensor, + "prob_ref": prob_ref, + "prob_tensor": prob_tensor, + } + + +def allocate_gemm_srelu_outputs(cfg): + _, c_tensor = create_and_permute_tensor(cfg["l"], cfg["m"], cfg["n"], cfg["c_major"] == "m", cfg["c_dtype"]) + _, d_tensor = create_and_permute_tensor(cfg["l"], cfg["m"], cfg["n"], cfg["c_major"] == "m", cfg["d_dtype"]) + + sfd_tensor = None + if cfg["d_dtype"] in {torch.float8_e4m3fn, torch.float8_e5m2}: + _, sfd_tensor = create_scale_factor_tensor(cfg["l"], cfg["m"], cfg["n"], cfg["sf_vec_size"], cfg["sf_dtype"]) + + amax_tensor = None + if cfg["ab_dtype"] in {torch.float4_e2m1fn_x2, torch.uint8} and cfg["d_dtype"] in {torch.bfloat16, torch.float16, torch.float32}: + amax_tensor = torch.full((1,), float("-inf"), dtype=torch.float32, device="cuda") + + norm_const_tensor = None + if sfd_tensor is not None: + norm_const_tensor = torch.tensor([1.0], dtype=torch.float32, device="cuda") + + return { + "c_tensor": c_tensor, + "d_tensor": d_tensor, + "sfd_tensor": sfd_tensor, + "amax_tensor": amax_tensor, + "norm_const_tensor": norm_const_tensor, + } + + +def gemm_srelu_reference(inputs, cfg): + res_a = torch.einsum("mkl,mkl->mkl", inputs["a_ref"], inputs["sfa_ref"]) + res_b = torch.einsum("nkl,nkl->nkl", inputs["b_ref"], inputs["sfb_ref"]) + c_ref = cfg["alpha"] * torch.einsum("mkl,nkl->mnl", res_a, 
res_b) + d_ref = torch.relu(c_ref) ** 2 + d_ref = d_ref * inputs["prob_ref"].expand(-1, cfg["n"], -1) + return c_ref, d_ref + + +def check_ref_gemm_srelu(inputs, outputs, cfg, check_d=True): + if cfg["skip_ref"]: + return + + c_ref, d_ref = gemm_srelu_reference(inputs, cfg) + torch.testing.assert_close(outputs["c_tensor"].float(), c_ref.float(), atol=0.12, rtol=0.02) + + if check_d: + torch.testing.assert_close(outputs["d_tensor"].float(), d_ref.float(), atol=0.12, rtol=0.02) + if outputs["amax_tensor"] is not None: + amax_ref = torch.tensor( + [compute_reference_amax(d_ref)], + dtype=torch.float32, + device=outputs["amax_tensor"].device, + ) + torch.testing.assert_close(outputs["amax_tensor"], amax_ref, atol=0.12, rtol=0.02) diff --git a/test/python/fe_api/test_grouped_gemm_dglu.py b/test/python/fe_api/test_grouped_gemm_dglu.py index 4e2d0cf8..c222f726 100644 --- a/test/python/fe_api/test_grouped_gemm_dglu.py +++ b/test/python/fe_api/test_grouped_gemm_dglu.py @@ -11,7 +11,7 @@ from fe_api.test_fe_api_utils import DYNAMIC_SHAPES_M_VALUES from fe_api.test_grouped_gemm_swiglu_utils import ( grouped_gemm_swiglu_init, - allocate_grouped_gemm_input_tensors, + allocate_grouped_gemm_input_tensors as allocate_grouped_gemm_input_tensors_base, ) from fe_api.test_grouped_gemm_dswiglu_utils import ( with_grouped_gemm_dswiglu_params_fp4, @@ -35,6 +35,23 @@ ) +def allocate_grouped_gemm_input_tensors(*args, **kwargs): + """Restore the upstream dGLU test-input range for backward kernels.""" + + tensors = allocate_grouped_gemm_input_tensors_base(*args, **kwargs) + + alpha_tensor = tensors["alpha_tensor"] + tensors["alpha_tensor"] = torch.randint(1, 2, alpha_tensor.shape, dtype=torch.float32, device=alpha_tensor.device) + + beta_tensor = tensors["beta_tensor"] + tensors["beta_tensor"] = torch.randint(1, 2, beta_tensor.shape, dtype=torch.float32, device=beta_tensor.device) + + prob_tensor = tensors["prob_tensor"] + tensors["prob_tensor"] = torch.randint(1, 2, prob_tensor.shape, dtype=torch.float32, device=prob_tensor.device) + + return tensors + + def _apply_grouped_gemm_cfg_overrides(cfg, cfg_overrides=None): if cfg_overrides is None: return cfg diff --git a/test/python/fe_api/test_grouped_gemm_dsrelu.py b/test/python/fe_api/test_grouped_gemm_dsrelu.py new file mode 100644 index 00000000..10211d5e --- /dev/null +++ b/test/python/fe_api/test_grouped_gemm_dsrelu.py @@ -0,0 +1,1080 @@ +""" +Tests for Grouped GEMM dSReLU Backward Kernel (SM100+) + +This module tests the contiguous grouped block-scaled GEMM backward pass +with dSReLU activation gradient for MoE (Mixture of Experts) workloads. 
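+
+As a sketch (mirroring the dense reference in test_gemm_dsrelu_utils.py; the
+names are illustrative, not the kernel contract), the per-element backward
+semantics being checked are:
+
+    x     = alpha * (A @ B^T)           # block-scaled GEMM accumulator
+    dD    = c * prob * 2 * relu(x)      # gradient through srelu(x) = relu(x)**2
+    dprob = sum_n(c * relu(x) ** 2)     # gradient w.r.t. the routing probability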
+""" + +import torch +import pytest +from test_utils import torch_fork_set_rng +from fe_api.test_grouped_gemm_dsrelu_utils import ( + with_grouped_gemm_dsrelu_params_fp4, + with_grouped_gemm_dsrelu_params_fp8, + allocate_grouped_gemm_dsrelu_tensors, + allocate_grouped_gemm_input_tensors, + check_ref_grouped_gemm_dsrelu, + grouped_gemm_dsrelu_init, +) +from fe_api.test_discrete_grouped_gemm_swiglu_utils import ( + allocate_discrete_input_tensors, + discrete_grouped_gemm_init, +) + +GROUPED_GEMM_DSRELU_DYNAMIC_SHAPES_M_VALUES = [64, 320, 576, 832, 1088, 1344, 1600, 1856, 2112, 2368] + +DISCRETE_GROUPED_GEMM_DSRELU_SUPPORTED_CONFIGS = [ + pytest.param(torch.float4_e2m1fn_x2, torch.bfloat16, torch.bfloat16, "k", id="fp4-k-major"), + pytest.param(torch.float8_e4m3fn, torch.bfloat16, torch.float8_e4m3fn, "k", id="fp8-k-major"), + pytest.param(torch.float8_e4m3fn, torch.bfloat16, torch.float8_e4m3fn, "n", id="fp8-n-major"), +] + + +def _dense_ref_inputs_from_discrete(inputs): + ref_inputs = dict(inputs) + ref_inputs["b_ref"] = torch.cat(inputs["b_ref_list"], dim=2) + ref_inputs["sfb_ref"] = torch.cat(inputs["sfb_ref_list"], dim=2) + return ref_inputs + + +def _prepare_discrete_dsrelu_inputs(inputs): + inputs["alpha_tensor"] = torch.ones_like(inputs["alpha_tensor"]) + inputs["prob_tensor"] = torch.ones_like(inputs["prob_tensor"]) + return inputs + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_dsrelu_params_fp4 +def test_grouped_gemm_dsrelu_compile_execute_fp4( + ab_dtype, + c_dtype, + d_dtype, + b_major, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_dsrelu_compile_execute( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + b_major=b_major, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_dsrelu_params_fp8 +def test_grouped_gemm_dsrelu_compile_execute_fp8( + ab_dtype, + c_dtype, + d_dtype, + b_major, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_dsrelu_compile_execute( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + b_major=b_major, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_dsrelu_params_fp4 +def test_grouped_gemm_dsrelu_wrapper_fp4( + ab_dtype, + c_dtype, + d_dtype, + b_major, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_dsrelu_wrapper( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + b_major=b_major, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_dsrelu_params_fp8 +def test_grouped_gemm_dsrelu_wrapper_fp8( + ab_dtype, + 
c_dtype, + d_dtype, + b_major, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_dsrelu_wrapper( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + b_major=b_major, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize( + "ab_dtype", + [ + pytest.param(torch.float4_e2m1fn_x2, id="fp4"), + pytest.param(torch.float8_e4m3fn, id="fp8"), + ], +) +def test_grouped_gemm_dsrelu_wrapper_cache_partial_dynamic_smoke(request, monkeypatch, ab_dtype): + compile_count, cache_entries = _test_grouped_gemm_dsrelu_wrapper_dynamic_shape_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=False, + ab_dtype=ab_dtype, + ) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize( + "ab_dtype", + [ + pytest.param(torch.float4_e2m1fn_x2, id="fp4"), + pytest.param(torch.float8_e4m3fn, id="fp8"), + ], +) +def test_grouped_gemm_dsrelu_wrapper_cache_full_dynamic_smoke(request, monkeypatch, ab_dtype): + compile_count, cache_entries = _test_grouped_gemm_dsrelu_wrapper_dynamic_shape_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=True, + ab_dtype=ab_dtype, + ) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize( + "ab_dtype", + [ + pytest.param(torch.float4_e2m1fn_x2, id="fp4"), + pytest.param(torch.float8_e4m3fn, id="fp8"), + ], +) +def test_grouped_gemm_dsrelu_wrapper_cache_zero_m_after_compile_partial_dynamic_smoke(request, monkeypatch, ab_dtype): + compile_count, cache_entries = _test_grouped_gemm_dsrelu_wrapper_zero_m_after_compile_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=False, + ab_dtype=ab_dtype, + ) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize( + "ab_dtype", + [ + pytest.param(torch.float4_e2m1fn_x2, id="fp4"), + pytest.param(torch.float8_e4m3fn, id="fp8"), + ], +) +def test_grouped_gemm_dsrelu_wrapper_cache_zero_m_after_compile_full_dynamic_smoke(request, monkeypatch, ab_dtype): + compile_count, cache_entries = _test_grouped_gemm_dsrelu_wrapper_zero_m_after_compile_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=True, + ab_dtype=ab_dtype, + ) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize( + "ab_dtype", + [ + pytest.param(torch.float4_e2m1fn_x2, id="fp4"), + pytest.param(torch.float8_e4m3fn, id="fp8"), + ], +) +def test_grouped_gemm_dsrelu_wrapper_cache_zero_m_before_compile_partial_dynamic_smoke(request, monkeypatch, ab_dtype): + compile_count, cache_entries = _test_grouped_gemm_dsrelu_wrapper_zero_m_before_compile_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=False, + ab_dtype=ab_dtype, + ) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize( + "ab_dtype", + [ + pytest.param(torch.float4_e2m1fn_x2, id="fp4"), + pytest.param(torch.float8_e4m3fn, id="fp8"), + 
], +) +def test_grouped_gemm_dsrelu_wrapper_cache_zero_m_before_compile_full_dynamic_smoke(request, monkeypatch, ab_dtype): + compile_count, cache_entries = _test_grouped_gemm_dsrelu_wrapper_zero_m_before_compile_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=True, + ab_dtype=ab_dtype, + ) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=7) +def test_grouped_gemm_dsrelu_wrapper_uint8_raw_fp4_smoke(request): + try: + from cudnn import grouped_gemm_dsrelu_wrapper_sm100 + from cuda.bindings import driver as cuda + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = grouped_gemm_dsrelu_init( + request=request, + ab_dtype=torch.uint8, + c_dtype=torch.bfloat16, + d_dtype=torch.bfloat16, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + cluster_shape_mn=(2, 1), + sf_vec_size=16, + sf_dtype=torch.float8_e8m0fnu, + vector_f32=True, + discrete_col_sfd=False, + b_major="k", + ) + + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + b_major=cfg["b_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + ) + inputs, _ = allocate_grouped_gemm_dsrelu_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + input_tensors=inputs, + ) + + outputs = grouped_gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + acc_dtype=cfg["acc_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=cuda.CUstream(torch.cuda.current_stream().cuda_stream), + ) + + torch.cuda.synchronize() + assert torch.isfinite(outputs["d_row_tensor"].float()).all() + assert torch.isfinite(outputs["dprob_tensor"].float()).all() + assert torch.count_nonzero(outputs["d_row_tensor"]).item() > 0 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=13) +@pytest.mark.parametrize("ab_dtype,c_dtype,d_dtype,b_major", DISCRETE_GROUPED_GEMM_DSRELU_SUPPORTED_CONFIGS) +def test_grouped_gemm_dsrelu_discrete_compile_execute(request, ab_dtype, c_dtype, d_dtype, b_major): + try: + from cudnn import GroupedGemmDsreluSm100 + from cuda.bindings import driver as cuda + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = discrete_grouped_gemm_init( + request=request, + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + cluster_shape_mn=(2, 1), + sf_vec_size=32, + sf_dtype=torch.float8_e8m0fnu, + vector_f32=False, + discrete_col_sfd=False, + b_major=b_major, + ) + + inputs = _prepare_discrete_dsrelu_inputs( + allocate_discrete_input_tensors( + n=cfg["n"], + k=cfg["k"], + num_experts=cfg["l"], + group_m_list=cfg["group_m_list"], + 
ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + b_major=cfg["b_major"], + ) + ) + inputs, outputs = allocate_grouped_gemm_dsrelu_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + input_tensors=inputs, + ) + + api = GroupedGemmDsreluSm100( + sample_a=inputs["a_tensor"], + sample_c=inputs["c_tensor"], + sample_d_row=outputs["d_row_tensor"], + sample_d_col=outputs["d_col_tensor"], + sample_sfa=inputs["sfa_tensor"], + sample_padded_offsets=inputs["padded_offsets_tensor"], + sample_alpha=inputs["alpha_tensor"], + sample_prob=inputs["prob_tensor"], + sample_dprob=outputs["dprob_tensor"], + num_experts=cfg["l"], + b_shape=(cfg["n"], cfg["k"]), + b_dtype=inputs["b_list"][0].dtype, + sample_amax=outputs.get("amax_tensor"), + sample_sfd_row=outputs.get("sfd_row_tensor"), + sample_sfd_col=outputs.get("sfd_col_tensor"), + sample_norm_const=inputs.get("norm_const_tensor"), + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + b_major=cfg["b_major"], + ) + + try: + assert api.check_support(), "Unsupported testcase" + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + + api.compile() + api.execute( + a_tensor=inputs["a_tensor"], + b_ptrs=inputs["b_ptrs_tensor"], + sfb_ptrs=inputs["sfb_ptrs_tensor"], + c_tensor=inputs["c_tensor"], + d_row_tensor=outputs["d_row_tensor"], + d_col_tensor=outputs["d_col_tensor"], + sfa_tensor=inputs["sfa_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + dprob_tensor=outputs["dprob_tensor"], + sfd_row_tensor=outputs.get("sfd_row_tensor"), + sfd_col_tensor=outputs.get("sfd_col_tensor"), + norm_const_tensor=inputs.get("norm_const_tensor"), + amax_tensor=outputs.get("amax_tensor"), + current_stream=cuda.CUstream(torch.cuda.current_stream().cuda_stream), + ) + + torch.cuda.synchronize() + check_ref_grouped_gemm_dsrelu( + _dense_ref_inputs_from_discrete(inputs), + outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=13) +@pytest.mark.parametrize("ab_dtype,c_dtype,d_dtype,b_major", DISCRETE_GROUPED_GEMM_DSRELU_SUPPORTED_CONFIGS) +def test_grouped_gemm_dsrelu_discrete_wrapper(request, ab_dtype, c_dtype, d_dtype, b_major): + try: + from cudnn import grouped_gemm_dsrelu_wrapper_sm100 + from cuda.bindings import driver as cuda + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = discrete_grouped_gemm_init( + request=request, + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + cluster_shape_mn=(2, 1), + sf_vec_size=32, + sf_dtype=torch.float8_e8m0fnu, + vector_f32=False, + discrete_col_sfd=False, + b_major=b_major, + ) + + inputs = _prepare_discrete_dsrelu_inputs( + allocate_discrete_input_tensors( + n=cfg["n"], + k=cfg["k"], + num_experts=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + 
b_major=cfg["b_major"], + ) + ) + inputs, _ = allocate_grouped_gemm_dsrelu_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + input_tensors=inputs, + ) + + outputs = grouped_gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + b_ptrs=inputs["b_ptrs_tensor"], + sfb_ptrs=inputs["sfb_ptrs_tensor"], + n=cfg["n"], + b_dtype=inputs["b_list"][0].dtype, + b_major=cfg["b_major"], + norm_const_tensor=inputs.get("norm_const_tensor"), + acc_dtype=cfg["acc_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=cuda.CUstream(torch.cuda.current_stream().cuda_stream), + ) + + torch.cuda.synchronize() + wrapper_outputs = { + "d_row_tensor": outputs["d_row_tensor"], + "d_col_tensor": outputs["d_col_tensor"], + "dprob_tensor": outputs["dprob_tensor"], + "dbias_tensor": outputs["dbias_tensor"], + "amax_tensor": outputs["amax_tensor"], + "sfd_row_tensor": outputs["sfd_row_tensor"], + "sfd_col_tensor": outputs["sfd_col_tensor"], + } + check_ref_grouped_gemm_dsrelu( + _dense_ref_inputs_from_discrete(inputs), + wrapper_outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) + + +""" +GroupedGemmDsrelu API with explicit check_support, compile, and execute paths. +Use this method when running one static configuration for each GroupedGemmDsrelu object. 
+""" + + +def _test_grouped_gemm_dsrelu_compile_execute( + ab_dtype, + c_dtype, + d_dtype, + b_major, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + try: + from cudnn import GroupedGemmDsreluSm100 + from cuda.bindings import driver as cuda + except ImportError as e: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = grouped_gemm_dsrelu_init( + request=request, + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + b_major=b_major, + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + b_major=cfg["b_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + ) + + inputs, outputs = allocate_grouped_gemm_dsrelu_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + input_tensors=inputs, + ) + + api = GroupedGemmDsreluSm100( + sample_a=inputs["a_tensor"], + sample_b=inputs["b_tensor"], + sample_c=inputs["c_tensor"], + sample_d_row=outputs["d_row_tensor"], + sample_d_col=outputs["d_col_tensor"], + sample_sfa=inputs["sfa_tensor"], + sample_sfb=inputs["sfb_tensor"], + sample_padded_offsets=inputs["padded_offsets_tensor"], + sample_alpha=inputs["alpha_tensor"], + sample_prob=inputs["prob_tensor"], + sample_dprob=outputs["dprob_tensor"], + sample_amax=outputs.get("amax_tensor"), + sample_sfd_row=outputs.get("sfd_row_tensor"), + sample_sfd_col=outputs.get("sfd_col_tensor"), + sample_norm_const=inputs.get("norm_const_tensor"), + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + ) + + try: + assert api.check_support(), "Unsupported testcase" + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + + api.compile() + api.execute( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + d_row_tensor=outputs["d_row_tensor"], + d_col_tensor=outputs["d_col_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + dprob_tensor=outputs["dprob_tensor"], + sfd_row_tensor=outputs.get("sfd_row_tensor"), + sfd_col_tensor=outputs.get("sfd_col_tensor"), + norm_const_tensor=inputs.get("norm_const_tensor"), + amax_tensor=outputs.get("amax_tensor"), + current_stream=stream, + ) + + torch.cuda.synchronize() + check_ref_grouped_gemm_dsrelu( + inputs, + outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) + + +""" +GroupedGemmDsrelu API with grouped_gemm_dsrelu_wrapper: +Use the wrapper to directly call GroupedGemmDsrelu without explicit setup and compilation. 
+""" + + +def _test_grouped_gemm_dsrelu_wrapper( + ab_dtype, + c_dtype, + d_dtype, + b_major, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + try: + from cudnn import grouped_gemm_dsrelu_wrapper_sm100 + from cuda.bindings import driver as cuda + except ImportError as e: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = grouped_gemm_dsrelu_init( + request=request, + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + b_major=b_major, + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + b_major=cfg["b_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + ) + + inputs, _ = allocate_grouped_gemm_dsrelu_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + input_tensors=inputs, + ) + + try: + for _ in range(2): # Run twice to test caching path + wrapper_outputs = grouped_gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + norm_const_tensor=inputs.get("norm_const_tensor"), + acc_dtype=cfg["acc_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=stream, + ) + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + + torch.cuda.synchronize() + check_ref_grouped_gemm_dsrelu( + inputs, + wrapper_outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) + + +def _test_grouped_gemm_dsrelu_wrapper_dynamic_shape_cache_behavior( + request, + monkeypatch, + use_full_dynamic, + ab_dtype, +): + try: + from cudnn import grouped_gemm_dsrelu_wrapper_sm100 + from cudnn.grouped_gemm.grouped_gemm_dsrelu import api as grouped_gemm_dsrelu_api + from cuda.bindings import driver as cuda + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + if use_full_dynamic: + monkeypatch.setenv("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") + else: + monkeypatch.delenv("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", raising=False) + + grouped_gemm_dsrelu_api._cache_of_GroupedGemmDsreluSm100Objects.clear() + + compile_count = {"value": 0} + original_compile = grouped_gemm_dsrelu_api.GroupedGemmDsreluSm100.compile + + def counted_compile(self): + compile_count["value"] += 1 + return original_compile(self) + + monkeypatch.setattr(grouped_gemm_dsrelu_api.GroupedGemmDsreluSm100, "compile", counted_compile) + + d_dtype = torch.float8_e4m3fn if ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] else 
torch.bfloat16 + + cfg = grouped_gemm_dsrelu_init( + request=request, + ab_dtype=ab_dtype, + c_dtype=torch.bfloat16, + d_dtype=d_dtype, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + cluster_shape_mn=(2, 1), + sf_vec_size=32, + sf_dtype=torch.float8_e8m0fnu, + vector_f32=False, + discrete_col_sfd=ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2], + b_major="k", + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + try: + for group_m in GROUPED_GEMM_DSRELU_DYNAMIC_SHAPES_M_VALUES: + group_m_list = [group_m] * cfg["l"] + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=group_m_list, + ab_dtype=cfg["ab_dtype"], + b_major=cfg["b_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + ) + + inputs, _ = allocate_grouped_gemm_dsrelu_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + input_tensors=inputs, + ) + + wrapper_outputs = grouped_gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + norm_const_tensor=inputs.get("norm_const_tensor"), + acc_dtype=cfg["acc_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=stream, + ) + torch.cuda.synchronize() + # check_ref_grouped_gemm_dsrelu( + # inputs, + # wrapper_outputs, + # cfg, + # skip_ref=cfg["skip_ref"], + # ) + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + finally: + cache_entries = len(grouped_gemm_dsrelu_api._cache_of_GroupedGemmDsreluSm100Objects) + grouped_gemm_dsrelu_api._cache_of_GroupedGemmDsreluSm100Objects.clear() + + return compile_count["value"], cache_entries + + +def _test_grouped_gemm_dsrelu_wrapper_zero_m_after_compile_cache_behavior( + request, + monkeypatch, + use_full_dynamic, + ab_dtype, +): + return _test_grouped_gemm_dsrelu_wrapper_zero_m_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=use_full_dynamic, + ab_dtype=ab_dtype, + group_m_values=[512, 0], + ) + + +def _test_grouped_gemm_dsrelu_wrapper_zero_m_before_compile_cache_behavior( + request, + monkeypatch, + use_full_dynamic, + ab_dtype, +): + return _test_grouped_gemm_dsrelu_wrapper_zero_m_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=use_full_dynamic, + ab_dtype=ab_dtype, + group_m_values=[0, 512], + ) + + +def _test_grouped_gemm_dsrelu_wrapper_zero_m_cache_behavior( + request, + monkeypatch, + use_full_dynamic, + ab_dtype, + group_m_values, +): + try: + from cudnn import grouped_gemm_dsrelu_wrapper_sm100 + from cudnn.grouped_gemm.grouped_gemm_dsrelu import api as grouped_gemm_dsrelu_api + from cuda.bindings import driver as cuda + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + if use_full_dynamic: + monkeypatch.setenv("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", 
"1") + else: + monkeypatch.delenv("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", raising=False) + + grouped_gemm_dsrelu_api._cache_of_GroupedGemmDsreluSm100Objects.clear() + + compile_count = {"value": 0} + original_compile = grouped_gemm_dsrelu_api.GroupedGemmDsreluSm100.compile + + def counted_compile(self): + compile_count["value"] += 1 + return original_compile(self) + + monkeypatch.setattr(grouped_gemm_dsrelu_api.GroupedGemmDsreluSm100, "compile", counted_compile) + + d_dtype = torch.float8_e4m3fn if ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16 + + cfg = grouped_gemm_dsrelu_init( + request=request, + ab_dtype=ab_dtype, + c_dtype=torch.bfloat16, + d_dtype=d_dtype, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + cluster_shape_mn=(2, 1), + sf_vec_size=32, + sf_dtype=torch.float8_e8m0fnu, + vector_f32=False, + discrete_col_sfd=ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2], + b_major="k", + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + try: + for group_m in group_m_values: + group_m_list = [group_m] * cfg["l"] + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=group_m_list, + ab_dtype=cfg["ab_dtype"], + b_major=cfg["b_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + ) + + inputs, _ = allocate_grouped_gemm_dsrelu_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + input_tensors=inputs, + ) + + grouped_gemm_dsrelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=inputs["c_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + norm_const_tensor=inputs.get("norm_const_tensor"), + acc_dtype=cfg["acc_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=stream, + ) + torch.cuda.synchronize() + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + finally: + cache_entries = len(grouped_gemm_dsrelu_api._cache_of_GroupedGemmDsreluSm100Objects) + grouped_gemm_dsrelu_api._cache_of_GroupedGemmDsreluSm100Objects.clear() + + return compile_count["value"], cache_entries diff --git a/test/python/fe_api/test_grouped_gemm_dsrelu_utils.py b/test/python/fe_api/test_grouped_gemm_dsrelu_utils.py new file mode 100644 index 00000000..d5d48acf --- /dev/null +++ b/test/python/fe_api/test_grouped_gemm_dsrelu_utils.py @@ -0,0 +1,506 @@ +""" +Utilities and parameterization for Grouped GEMM dSReLU backward tests. +Contains test configuration fixtures, tensor creation, and reference implementations. 
+""" + +import torch +import pytest +from typing import Optional, Tuple, List, Dict, Any +from test_fe_api_utils import ( + ceil_div, + compute_reference_amax, + create_and_permute_tensor, + create_scale_factor_tensor, + create_sf_layout_tensor, + cvt_sf_MKL_to_M32x4xrm_K4xrk_L, +) +from test_grouped_gemm_swiglu_utils import ( + allocate_grouped_gemm_input_tensors as allocate_grouped_gemm_input_tensors_base, + get_dtype_rcp_limits as get_grouped_gemm_dtype_rcp_limits, + grouped_gemm_swiglu_init as grouped_gemm_dsrelu_init, +) + +# ============================================================================= +# Parameterization Marks +# ============================================================================= + +GROUPED_GEMM_DSRELU_COMMON_MARKS = [ + pytest.mark.parametrize("cd_major", ["n"]), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize("mma_tiler_mn", [(256, 256), (128, 256)]), + pytest.mark.parametrize("cluster_shape_mn", [(2, 1), (1, 1)]), + pytest.mark.parametrize("vector_f32", [True, False]), +] + +GROUPED_GEMM_DSRELU_FP8_TYPE_MARKS = [ + pytest.mark.parametrize( + "ab_dtype", + [ + torch.float8_e4m3fn, + ], + ), + pytest.mark.parametrize("c_dtype", [torch.bfloat16]), + pytest.mark.parametrize( + "d_dtype", + [ + torch.float8_e4m3fn, + ], + ), + pytest.mark.parametrize("b_major", ["k"]), +] + +GROUPED_GEMM_DSRELU_FP4_TYPE_MARKS = [ + pytest.mark.parametrize( + "ab_dtype", + [ + torch.uint8, + ], + ), + pytest.mark.parametrize( + "c_dtype", + [ + # torch.float16, + torch.bfloat16, + # torch.float32, + ], + ), + pytest.mark.parametrize( + "d_dtype", + [ + # torch.float16, + torch.bfloat16, + torch.float32, + ], + ), + pytest.mark.parametrize("b_major", ["k"]), +] + +GROUPED_GEMM_DSRELU_PARAM_MARKS_FP8 = GROUPED_GEMM_DSRELU_FP8_TYPE_MARKS + [ + pytest.mark.parametrize("cd_major", ["n"]), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize("mma_tiler_mn", [(256, 256)]), + pytest.mark.parametrize("cluster_shape_mn", [(2, 1)]), + pytest.mark.parametrize("vector_f32", [False]), + pytest.mark.parametrize("sf_vec_size,sf_dtype", [(32, torch.float8_e8m0fnu)]), + pytest.mark.parametrize("discrete_col_sfd", [True]), +] + +GROUPED_GEMM_DSRELU_PARAM_MARKS_FP4 = ( + GROUPED_GEMM_DSRELU_FP4_TYPE_MARKS + + GROUPED_GEMM_DSRELU_COMMON_MARKS + + [ + pytest.mark.parametrize( + "sf_vec_size,sf_dtype", + [ + (16, torch.float8_e8m0fnu), + (16, torch.float8_e4m3fn), + (32, torch.float8_e8m0fnu), + (32, torch.float8_e4m3fn), + ], + ), + pytest.mark.parametrize("discrete_col_sfd", [False]), + ] +) + +GROUPED_GEMM_DSRELU_PARAM_MARKS_DBIAS_FP4 = ( + GROUPED_GEMM_DSRELU_FP4_TYPE_MARKS + + GROUPED_GEMM_DSRELU_COMMON_MARKS + + [ + pytest.mark.parametrize( + "sf_vec_size,sf_dtype", + [ + (16, torch.float8_e8m0fnu), + (16, torch.float8_e4m3fn), + (32, torch.float8_e8m0fnu), + ], + ), + pytest.mark.parametrize("discrete_col_sfd", [False]), + ] +) + + +def with_grouped_gemm_dsrelu_params_fp4(func): + """Decorator to apply grouped GEMM dSReLU FP4 test parameters.""" + for mark in reversed(GROUPED_GEMM_DSRELU_PARAM_MARKS_FP4): + func = mark(func) + return func + + +def with_grouped_gemm_dsrelu_params_fp8(func): + """Decorator to apply grouped GEMM dSReLU FP8 test parameters.""" + for mark in reversed(GROUPED_GEMM_DSRELU_PARAM_MARKS_FP8): + func = mark(func) + return func + + +def with_grouped_gemm_dsrelu_params_dbias_fp4(func): + """Decorator to apply grouped GEMM dSReLU dense dbias FP4 test parameters.""" + for mark in 
reversed(GROUPED_GEMM_DSRELU_PARAM_MARKS_DBIAS_FP4): + func = mark(func) + return func + + +# ============================================================================= +# Tensor Allocation +# ============================================================================= +def allocate_grouped_gemm_input_tensors(*args, **kwargs) -> Dict[str, Any]: + """Allocate grouped GEMM inputs and restore the dsReLU-specific alpha/prob range. + + The shared grouped GEMM utility initializes alpha/prob from ``(-2, 2)`` to + match the forward kernel inputs. The upstream grouped dsReLU kernel uses + ``torch.randint(1, 2, ...)`` instead, which effectively produces constant ones. + Keep that behavior for the backward tests here. + """ + + tensors = allocate_grouped_gemm_input_tensors_base(*args, **kwargs) + + alpha_tensor = tensors["alpha_tensor"] + tensors["alpha_tensor"] = torch.randint(1, 2, alpha_tensor.shape, dtype=torch.float32, device=alpha_tensor.device) + + prob_tensor = tensors["prob_tensor"] + tensors["prob_tensor"] = torch.randint(1, 2, prob_tensor.shape, dtype=torch.float32, device=prob_tensor.device) + + return tensors + + +def allocate_grouped_gemm_dsrelu_tensors( + tensor_m: int, + n: int, + l: int, + ab_dtype: torch.dtype, + c_dtype: torch.dtype, + d_dtype: torch.dtype, + cd_major: str, + sf_dtype: torch.dtype, + sf_vec_size: int = 16, + generate_dbias: bool = False, + input_tensors: Optional[Dict] = None, + output_tensors: Optional[Dict] = None, +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Allocate backward tensors for grouped GEMM dSReLU backward. + + :return: Newly allocated input and output tensors. Modifies input_tensors and output_tensors dictionaries in place if provided. + """ + + # D has the same shape as C and holds the interleaved ab/dsrelu results in 32-column blocks + c_ref, c_tensor = create_and_permute_tensor( + 1, tensor_m, n, cd_major == "m", c_dtype + ) # Note: c_tensor is an input to the backward pass; it is allocated here like a forward-pass output so this helper can eventually be merged with the forward allocation + _, d_row_tensor = create_and_permute_tensor(1, tensor_m, n, cd_major == "m", d_dtype) + _, d_col_tensor = create_and_permute_tensor(1, tensor_m, n, cd_major == "m", d_dtype) + dprob_tensor = torch.zeros((tensor_m, 1, 1), dtype=torch.float32).cuda() + + _input_tensors = { + "c_ref": c_ref, + "c_tensor": c_tensor, + } + _output_tensors = { + "d_row_tensor": d_row_tensor, + "d_col_tensor": d_col_tensor, + "dprob_tensor": dprob_tensor, + "dbias_tensor": None, + "sfd_row_tensor": None, + "sfd_col_tensor": None, + "amax_tensor": None, + } + + if d_dtype in [torch.bfloat16, torch.float16]: + _output_tensors["amax_tensor"] = torch.full((l, 1), float("-inf"), dtype=torch.float32).cuda() + + if generate_dbias: + _output_tensors["dbias_tensor"] = torch.zeros((l, n, 1), dtype=torch.bfloat16).cuda() + + if ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sf_dtype in [ + torch.float8_e8m0fnu, + torch.float8_e4m3fn, + ]: # generate_sfd + sfd_row_ref, sfd_row_tensor = create_scale_factor_tensor(1, tensor_m, n, sf_vec_size, sf_dtype) + _output_tensors["sfd_row_tensor"] = sfd_row_tensor + _output_tensors["sfd_row_ref"] = sfd_row_ref + + sfd_col_ref, sfd_col_tensor = create_scale_factor_tensor(1, n, tensor_m, sf_vec_size, sf_dtype) + _output_tensors["sfd_col_tensor"] = sfd_col_tensor + _output_tensors["sfd_col_ref"] = sfd_col_ref + + if input_tensors is not None: + input_tensors.update(_input_tensors) + _input_tensors = input_tensors + if output_tensors is not None: + 
output_tensors.update(_output_tensors) + _output_tensors = output_tensors + return _input_tensors, _output_tensors + + +# ============================================================================= +# Reference Implementations +# ============================================================================= + + +def compute_reference_row_quant(src, d_dtype, sf_dtype, vec_size, norm_const) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the reference quantized output and its scale factors. + + Args: + src: torch.Tensor, source tensor + d_dtype: torch.dtype, destination (quantized) dtype + sf_dtype: torch.dtype, scale factor dtype + vec_size: int, scale factor vector size + norm_const: float, normalization constant + + Returns: + torch.Tensor: scale factor reference tensor + torch.Tensor: quantized reference tensor + """ + + try: + from cutlass.cute.runtime import from_dlpack + import cutlass.cute as cute + from cudnn.datatypes import _convert_to_cutlass_data_type + except ImportError: + pytest.skip("CUTLASS not available for scale factor conversion") + + m = src.shape[0] + n = src.shape[1] + l = src.shape[2] + + # 1. Compute reference SFD (m, sfn, l) in fp32 + sfn = ceil_div(n, vec_size) + n_aligned = ceil_div(n, 128) * 128 + sfn_aligned = n_aligned // vec_size + sfm = ceil_div(m, 128) * 128 + if sfn_aligned != sfn: + zeros = torch.zeros( + src.shape[0], + n_aligned - n, + src.shape[2], + dtype=src.dtype, + device=src.device, + ) + src_sf = torch.cat([src, zeros], dim=1) + src_reshaped = src_sf.permute(2, 0, 1).contiguous() + else: + src_reshaped = src.permute(2, 0, 1).contiguous() + src_reshaped = src_reshaped.view(l, sfm, sfn_aligned, vec_size) + # Take abs max over vec_size dimension + src_reshaped, _ = torch.abs(src_reshaped).max(dim=3) # (l, m, sfn) + # Multiply by norm_const and rcp_limits + src_sfd_f32 = src_reshaped * norm_const * get_grouped_gemm_dtype_rcp_limits(d_dtype) + # Permute to (m, sfn, l) + src_sfd_f32 = src_sfd_f32.permute(1, 2, 0) + # Convert fp32 -> f8 -> fp32 for src_sfd_f32 + src_sfd_f8_torch = torch.empty(*(l, sfm, sfn_aligned), dtype=torch.uint8, device=src.device).permute(1, 2, 0) + src_sfd_f8 = from_dlpack(src_sfd_f8_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1) + src_sfd_f8.element_type = _convert_to_cutlass_data_type(sf_dtype) + src_sfd_f32_device = src_sfd_f32.to(src.device) + ref_sfd_f32_tensor = from_dlpack(src_sfd_f32_device, assumed_align=16).mark_layout_dynamic(leading_dim=1) + + # 2. Convert sfd from fp32 to scale factor + cute.testing.convert(ref_sfd_f32_tensor, src_sfd_f8) + cute.testing.convert(src_sfd_f8, ref_sfd_f32_tensor) + src_sfd_f32 = src_sfd_f32_device.cpu() + # ref_sfd_f32 for fp32 reference check + ref_sfd_f32, _ = create_sf_layout_tensor(l, sfm, n, vec_size) + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(src_sfd_f32), + from_dlpack(ref_sfd_f32), + ) + + # 3.
Quantized output with scale factor + # Compute reciprocal of src_sfd_f32 and multiply by norm_const + src_sfd_f32_rcp = norm_const * src_sfd_f32.to(src.device).reciprocal() + # Expand the sfn dimension by repeating each value sf_vec_size times + # src_sfd_f32_rcp: (m, sfn, l) -> (m, sfn, sf_vec_size, l) -> (m, n, l) + src_sfd_f32_rcp_expanded = src_sfd_f32_rcp.unsqueeze(2).expand(sfm, sfn_aligned, vec_size, l) + src_sfd_f32_rcp_expanded = src_sfd_f32_rcp_expanded.reshape(sfm, sfn_aligned * vec_size, l) + # Trim to exact n dimension if needed + src_sfd_f32_rcp_expanded = src_sfd_f32_rcp_expanded[:, :n, :] + # Apply scale to reference output: ref = ref * src_sfd_f32_rcp + src_d_f32_torch = torch.einsum("mnl,mnl->mnl", src, src_sfd_f32_rcp_expanded) + # Convert to d_dtype, then convert back to fp32 for reference check + src_d_f8_torch = torch.empty(*(l, m, n), dtype=torch.uint8, device=src.device).permute(1, 2, 0) + src_d_f8 = from_dlpack(src_d_f8_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1) + src_d_f8.element_type = _convert_to_cutlass_data_type(d_dtype) + src_d_f32_torch = src_d_f32_torch.to(src.device) + src_d_f32 = from_dlpack(src_d_f32_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1) + cute.testing.convert(src_d_f32, src_d_f8) + cute.testing.convert(src_d_f8, src_d_f32) + + return (ref_sfd_f32, src_d_f32_torch) + + +def run_grouped_gemm_dsrelu_ref( + a_ref: torch.Tensor, + b_ref: torch.Tensor, + c_ref: torch.Tensor, # C tensor (intermediate from forward pass) in float32 + sfa_ref: torch.Tensor, + sfb_ref: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + aligned_group_m_list: List[int], + valid_m: int, + generate_dbias: bool = False, + generate_amax: bool = False, + generate_sfd: bool = False, + norm_const_tensor: Optional[torch.Tensor] = None, + c_dtype: torch.dtype = torch.bfloat16, + d_dtype: torch.dtype = torch.float32, + sf_vec_size: int = 16, + sf_dtype: torch.dtype = torch.float8_e8m0fnu, +) -> Dict[str, torch.Tensor]: + """Run reference implementation for grouped GEMM dSReLU backward. + + Based on the reference in contiguous_blockscaled_grouped_gemm_dsrelu_quant_fusion.py + + The dSReLU backward pass computes: + 1. GEMM: ref = alpha^2 * (SFA * A) @ (SFB * B)^T per group + 2. dprob = sum over N of (relu(ref)^2 * C) + 3. 
D = 2 * relu(ref) * C * prob + + :param a_ref: A tensor (tensor_m, k, 1) in float32 + :param b_ref: B tensor (n, k, l) in float32 + :param c_ref: C tensor (tensor_m, n, 1) from forward pass in float32 + :param sfa_ref: Scale factor A tensor (tensor_m, k, 1) in float32 + :param sfb_ref: Scale factor B tensor (n, k, l) in float32 + :param alpha_tensor: Per-group alpha scaling (l,) + :param prob_tensor: Per-row probability scaling (tensor_m, 1, 1) + :param aligned_group_m_list: Aligned M values per group + :param valid_m: Total valid M dimension + :param generate_dbias: Generate per-group dbias tensor + :param generate_amax: Generate AMAX tensor + :param generate_sfd: Generate SFD tensor + :param norm_const_tensor: Normalization constant tensor (1,) + :param c_dtype: Intermediate C tensor dtype + :param d_dtype: Output D tensor dtype + :param sf_vec_size: Scale factor vector size + :param sf_dtype: Scale factor dtype + :return: Dictionary of reference tensors + """ + n, k, l = b_ref.shape + ref_tensors = {} + + # Step 1: Compute GEMM per group with scale factors + ref = torch.empty((1, valid_m, n), dtype=torch.float32, device=a_ref.device) + start = 0 + for i, group_m in enumerate(aligned_group_m_list): + end = start + group_m + res_a = torch.einsum( + "mk,mk->mk", + a_ref[start:end, :, 0], + sfa_ref[start:end, :, 0], + ) + res_b = torch.einsum("nk,nk->nk", b_ref[:, :, i], sfb_ref[:, :, i]) + ref[0, start:end, :] = torch.einsum("mk,nk->mn", res_a * alpha_tensor[i].item(), res_b * alpha_tensor[i].item()) + start = end + ref = ref.permute((1, 2, 0)) # shape [M, N, 1] + + # Step 2: Apply dsquared-ReLU backward elementwise + c_full = c_ref.clone() + relu_ref = torch.relu(ref) + ref_dprob = relu_ref**2 * c_full + chunk_sums = [torch.sum(chunk, dim=1, keepdim=True) for chunk in torch.split(ref_dprob, 32, dim=1)] + ref_dprob = torch.sum(torch.cat(chunk_sums, dim=1), dim=1, keepdim=True) + ref_tensors["dprob_ref"] = ref_dprob + + # Step 3: Compute dSReLU formula + prob = prob_tensor.expand(-1, n, -1) + ref_d = 2 * relu_ref * c_full * prob + + ref_tensors["d_ref"] = ref_d.clone() + + # Step 4: Generate dbias per group + if generate_dbias: + ref_dbias = torch.zeros((l, n, 1), dtype=torch.bfloat16, device=a_ref.device) + start = 0 + for i, group_m in enumerate(aligned_group_m_list): + end = start + group_m + ref_dbias[i, :, 0] = ref_d[start:end, :, 0].sum(dim=0).to(torch.bfloat16) + start = end + ref_tensors["dbias_ref"] = ref_dbias + + # Step 5: Generate amax for FP4/BF16 output + if generate_amax: + ref_amax = torch.empty((l, 1), dtype=torch.float32, device=a_ref.device) + start = 0 + for i, group_m in enumerate(aligned_group_m_list): + end = start + group_m + ref_amax[i, 0] = torch.tensor(compute_reference_amax(ref_d[start:end, :, 0].clone())) + start = end + ref_tensors["amax_ref"] = ref_amax + + # Step 6: Generate SFD for FP8 output + if generate_sfd: + norm_const = norm_const_tensor[0].item() + sfd_row_ref_f32, d_ref_f32 = compute_reference_row_quant(ref_d, d_dtype, sf_dtype, sf_vec_size, norm_const) + ref_tensors["sfd_row_ref"] = sfd_row_ref_f32.clone() + ref_tensors["d_ref"] = d_ref_f32.clone() + + ref_d_col = ref_d.permute(2, 1, 0).contiguous().permute(1, 2, 0) + sfd_col_ref_f32, d_col_ref_f32 = compute_reference_row_quant(ref_d_col, d_dtype, sf_dtype, sf_vec_size, norm_const) + ref_tensors["sfd_col_ref"] = sfd_col_ref_f32.clone() + ref_tensors["d_col_ref"] = d_col_ref_f32.clone() + + return ref_tensors + + +# ============================================================================= +# Reference Checking +# 
============================================================================= + + +def check_ref_grouped_gemm_dsrelu( + inputs: Dict[str, Any], + outputs: Dict[str, Any], + cfg: Dict[str, Any], + atol: float = 1e-1, + rtol: float = 1e-2, + skip_ref: bool = False, +) -> None: + if skip_ref: + return + + torch.cuda.synchronize() + ref_tensors = run_grouped_gemm_dsrelu_ref( + a_ref=inputs["a_ref"], + b_ref=inputs["b_ref"], + c_ref=inputs["c_ref"], + sfa_ref=inputs["sfa_ref"], + sfb_ref=inputs["sfb_ref"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + aligned_group_m_list=inputs["aligned_group_m_list"], + valid_m=inputs["valid_m"], + generate_dbias=outputs.get("dbias_tensor") is not None, + generate_amax=outputs.get("amax_tensor") is not None, + generate_sfd=outputs.get("sfd_row_tensor") is not None and outputs.get("sfd_col_tensor") is not None, + norm_const_tensor=inputs.get("norm_const_tensor"), + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + sf_vec_size=cfg["sf_vec_size"], + sf_dtype=cfg["sf_dtype"], + ) + + torch.testing.assert_close(outputs["dprob_tensor"].float(), ref_tensors["dprob_ref"].float(), atol=atol, rtol=rtol) + torch.testing.assert_close(outputs["d_row_tensor"].float(), ref_tensors["d_ref"].float(), atol=atol, rtol=rtol) + + if "d_col_ref" in ref_tensors: + torch.testing.assert_close( + outputs["d_col_tensor"].float().permute(1, 0, 2), + ref_tensors["d_col_ref"].float(), + atol=atol, + rtol=rtol, + ) + + if outputs.get("amax_tensor") is not None and "amax_ref" in ref_tensors: + torch.testing.assert_close(outputs["amax_tensor"].float(), ref_tensors["amax_ref"].float(), atol=atol, rtol=rtol) + + if outputs.get("sfd_row_tensor") is not None and "sfd_row_ref" in ref_tensors: + torch.testing.assert_close( + outputs["sfd_row_tensor"].float(), + ref_tensors["sfd_row_ref"].to(outputs["sfd_row_tensor"].device).float(), + atol=atol, + rtol=rtol, + ) + + if outputs.get("sfd_col_tensor") is not None and "sfd_col_ref" in ref_tensors: + torch.testing.assert_close( + outputs["sfd_col_tensor"].float(), + ref_tensors["sfd_col_ref"].to(outputs["sfd_col_tensor"].device).float(), + atol=atol, + rtol=rtol, + ) diff --git a/test/python/fe_api/test_grouped_gemm_dswiglu.py b/test/python/fe_api/test_grouped_gemm_dswiglu.py index e4cf1eb8..49f0fe3f 100644 --- a/test/python/fe_api/test_grouped_gemm_dswiglu.py +++ b/test/python/fe_api/test_grouped_gemm_dswiglu.py @@ -10,7 +10,7 @@ from test_utils import torch_fork_set_rng from fe_api.test_grouped_gemm_swiglu_utils import ( grouped_gemm_swiglu_init, - allocate_grouped_gemm_input_tensors, + allocate_grouped_gemm_input_tensors as allocate_grouped_gemm_input_tensors_base, ) from fe_api.test_grouped_gemm_dswiglu_utils import ( with_grouped_gemm_dswiglu_params_fp4, @@ -22,6 +22,23 @@ GROUPED_GEMM_DSWIGLU_DYNAMIC_SHAPES_M_VALUES = [64, 320, 576, 832, 1088, 1344, 1600, 1856, 2112, 2368] +def allocate_grouped_gemm_input_tensors(*args, **kwargs): + """Restore the upstream dSwiGLU test-input range for backward kernels.""" + + tensors = allocate_grouped_gemm_input_tensors_base(*args, **kwargs) + + alpha_tensor = tensors["alpha_tensor"] + tensors["alpha_tensor"] = torch.randint(1, 2, alpha_tensor.shape, dtype=torch.float32, device=alpha_tensor.device) + + beta_tensor = tensors["beta_tensor"] + tensors["beta_tensor"] = torch.randint(1, 2, beta_tensor.shape, dtype=torch.float32, device=beta_tensor.device) + + prob_tensor = tensors["prob_tensor"] + tensors["prob_tensor"] = torch.randint(1, 2, prob_tensor.shape, dtype=torch.float32, 
device=prob_tensor.device) + + return tensors + + @pytest.mark.L0 @torch_fork_set_rng(seed=0) @with_grouped_gemm_dswiglu_params_fp4 diff --git a/test/python/fe_api/test_grouped_gemm_glu_hadamard.py b/test/python/fe_api/test_grouped_gemm_glu_hadamard.py new file mode 100644 index 00000000..00ae469d --- /dev/null +++ b/test/python/fe_api/test_grouped_gemm_glu_hadamard.py @@ -0,0 +1,349 @@ +"""Tests for grouped GEMM GLU + Hadamard forward fusion (SM100+).""" + +from typing import Dict + +import pytest +import torch + +from test_utils import torch_fork_set_rng +from fe_api.test_fe_api_utils import DYNAMIC_SHAPES_M_VALUES, compute_reference_amax +from fe_api.test_grouped_gemm_swiglu_utils import allocate_grouped_gemm_input_tensors, grouped_gemm_swiglu_init + +FP4_EXECUTION_CASES = [ + (torch.float4_e2m1fn_x2, torch.float8_e8m0fnu, 16), + (torch.float4_e2m1fn_x2, torch.float8_e4m3fn, 16), + (torch.float4_e2m1fn_x2, torch.float8_e8m0fnu, 32), + (torch.uint8, torch.float8_e8m0fnu, 16), +] + + +def _make_cfg(request, *, ab_dtype, sf_dtype, sf_vec_size, enable_bias=False) -> Dict: + return grouped_gemm_swiglu_init( + request, + ab_dtype=ab_dtype, + c_dtype=torch.bfloat16, + d_dtype=torch.bfloat16, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + cluster_shape_mn=(2, 1), + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=False, + discrete_col_sfd=False, + enable_bias=enable_bias, + ) + + +def _apply_hadamard(d_ref: torch.Tensor) -> torch.Tensor: + from cudnn.grouped_gemm.grouped_gemm_glu_hadamard.hadamard_utils import HADAMARD_SIZE, hadamard_matrix + + valid_m, n_out, _ = d_ref.shape + hadamard = hadamard_matrix(HADAMARD_SIZE, dtype=torch.float32, device=d_ref.device) + ref_view = d_ref.squeeze(-1).to(torch.bfloat16).to(torch.float32).view(valid_m, n_out // HADAMARD_SIZE, HADAMARD_SIZE) + return (ref_view @ hadamard).view(valid_m, n_out, 1) + + +def _run_grouped_gemm_glu_ref(inputs: Dict, act_func: str) -> Dict: + n, _, l = inputs["b_ref"].shape + n_out = n // 2 + valid_m = inputs["valid_m"] + aligned_group_m_list = inputs["aligned_group_m_list"] + + ref = torch.empty((1, valid_m, n), dtype=torch.float32, device=inputs["a_ref"].device) + start = 0 + for i, group_m in enumerate(aligned_group_m_list): + end = start + group_m + res_a = torch.einsum("mk,mk->mk", inputs["a_ref"][start:end, :, 0].to(torch.float32), inputs["sfa_ref"][start:end, :, 0].to(torch.float32)) + res_b = torch.einsum("nk,nk->nk", inputs["b_ref"][:, :, i].to(torch.float32), inputs["sfb_ref"][:, :, i].to(torch.float32)) + ref[0, start:end, :] = torch.einsum("mk,nk->mn", res_a, res_b) + start = end + ref = ref.permute((1, 2, 0)) + + start = 0 + for i, group_m in enumerate(aligned_group_m_list): + end = start + group_m + ref[start:end, :, 0] = ref[start:end, :, 0] * inputs["alpha_tensor"][i].item() + start = end + + if inputs.get("bias_tensor") is not None: + start = 0 + for i, group_m in enumerate(aligned_group_m_list): + end = start + group_m + ref[start:end, :, 0] = ref[start:end, :, 0] + inputs["bias_tensor"][:, i].unsqueeze(0).to(torch.float32) + start = end + + group = 32 + assert n % group == 0, "N must be divisible by 32 for GLU block grouping" + num_blocks = n // group + assert num_blocks % 2 == 0, "Number of 32-col blocks must be even" + + cols = torch.arange(n, device=ref.device, dtype=torch.long) + block_cols = cols.view(num_blocks, group) + gate_idx = block_cols[0::2].reshape(-1) + up_idx = block_cols[1::2].reshape(-1) + ref_gate = ref.index_select(1, gate_idx) + ref_up = ref.index_select(1, 
up_idx) + + if act_func == "swiglu": + ref_after_glu = ref_up * (ref_gate * torch.sigmoid(ref_gate)) + elif act_func == "geglu": + ref_gate = torch.clamp(ref_gate, max=7.0) + ref_up = torch.clamp(ref_up, min=-7.0, max=7.0) + ref_after_glu = (ref_up + 1.0) * ref_gate * torch.sigmoid(1.702 * ref_gate) + else: + raise ValueError(f"Unsupported act_func {act_func}") + + ref_after_glu = ref_after_glu * inputs["prob_tensor"].expand(-1, n_out, -1) + return {"c_ref": ref.clone(), "d_ref": ref_after_glu} + + +def _check_reference(inputs: Dict, outputs: Dict, cfg: Dict, *, act_func: str) -> None: + ref_tensors = _run_grouped_gemm_glu_ref(inputs, act_func) + + torch.testing.assert_close( + outputs["c_tensor"][: inputs["valid_m"]].cpu().float(), + ref_tensors["c_ref"].cpu().to(cfg["c_dtype"]).to(torch.float32), + atol=1e-1, + rtol=1e-2, + ) + + d_hadamard_ref = _apply_hadamard(ref_tensors["d_ref"]) + + torch.testing.assert_close( + outputs["d_tensor"][: inputs["valid_m"]].cpu().float(), + d_hadamard_ref.cpu().to(cfg["d_dtype"]).to(torch.float32), + atol=1e-1, + rtol=1e-2, + ) + + if outputs["amax_tensor"] is not None: + amax_ref = torch.empty((cfg["l"],), dtype=torch.float32) + start = 0 + for i, group_m in enumerate(inputs["aligned_group_m_list"]): + end = start + group_m + amax_ref[i] = compute_reference_amax(d_hadamard_ref[start:end, :, 0].clone()) + start = end + torch.testing.assert_close( + outputs["amax_tensor"].cpu().reshape(-1), + amax_ref, + atol=1e-1, + rtol=1e-2, + ) + + +def _allocate_outputs(inputs: Dict, cfg: Dict) -> Dict: + valid_m = inputs["valid_m"] + n = cfg["n"] + n_out = n // 2 + l = cfg["l"] + device = inputs["a_tensor"].device + + return { + "c_tensor": torch.empty_strided((valid_m, n, 1), (n, 1, valid_m * n), dtype=cfg["c_dtype"], device=device), + "d_tensor": torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=cfg["d_dtype"], device=device), + "amax_tensor": torch.full((l, 1), float("-inf"), dtype=torch.float32, device=device), + } + + +def _run_compile_execute(request, *, ab_dtype, sf_dtype, sf_vec_size, act_func="swiglu", enable_bias=False): + cfg = _make_cfg(request, ab_dtype=ab_dtype, sf_dtype=sf_dtype, sf_vec_size=sf_vec_size, enable_bias=enable_bias) + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + b_major=cfg["b_major"], + enable_bias=enable_bias, + ) + outputs = _allocate_outputs(inputs, cfg) + + from cudnn import GroupedGemmGluHadamardSm100 + + api = GroupedGemmGluHadamardSm100( + sample_a=inputs["a_tensor"], + sample_b=inputs["b_tensor"], + sample_c=outputs["c_tensor"], + sample_d=outputs["d_tensor"], + sample_sfa=inputs["sfa_tensor"], + sample_sfb=inputs["sfb_tensor"], + sample_padded_offsets=inputs["padded_offsets_tensor"], + sample_alpha=inputs["alpha_tensor"], + sample_prob=inputs["prob_tensor"], + sample_amax=outputs["amax_tensor"], + sample_bias=inputs["bias_tensor"], + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + act_func=act_func, + ) + api.check_support() + api.compile() + api.execute( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=outputs["c_tensor"], + d_tensor=outputs["d_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + 
padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + amax_tensor=outputs["amax_tensor"], + bias_tensor=inputs["bias_tensor"], + ) + + _check_reference(inputs, outputs, cfg, act_func=act_func) + + +def _run_wrapper(request, *, ab_dtype, sf_dtype, sf_vec_size, act_func="swiglu", enable_bias=False): + cfg = _make_cfg(request, ab_dtype=ab_dtype, sf_dtype=sf_dtype, sf_vec_size=sf_vec_size, enable_bias=enable_bias) + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + b_major=cfg["b_major"], + enable_bias=enable_bias, + ) + + from cudnn import grouped_gemm_glu_hadamard_wrapper_sm100 + + outputs = grouped_gemm_glu_hadamard_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + bias_tensor=inputs["bias_tensor"], + acc_dtype=cfg["acc_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + act_func=act_func, + ) + + _check_reference(inputs, outputs, cfg, act_func=act_func) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize("ab_dtype,sf_dtype,sf_vec_size", FP4_EXECUTION_CASES) +def test_grouped_gemm_glu_hadamard_compile_execute_fp4(request, ab_dtype, sf_dtype, sf_vec_size): + _run_compile_execute( + request, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + sf_vec_size=sf_vec_size, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize("ab_dtype,sf_dtype,sf_vec_size", FP4_EXECUTION_CASES) +@pytest.mark.parametrize("act_func", ["swiglu", "geglu"]) +def test_grouped_gemm_glu_hadamard_wrapper_fp4(request, ab_dtype, sf_dtype, sf_vec_size, act_func): + _run_wrapper( + request, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + sf_vec_size=sf_vec_size, + act_func=act_func, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +def test_grouped_gemm_glu_hadamard_wrapper_with_bias(request): + _run_wrapper( + request, + ab_dtype=torch.float4_e2m1fn_x2, + sf_dtype=torch.float8_e8m0fnu, + sf_vec_size=16, + act_func="swiglu", + enable_bias=True, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize("group_m_list", [[256, 256, 256, 256], DYNAMIC_SHAPES_M_VALUES]) +def test_grouped_gemm_glu_hadamard_wrapper_cache_dynamic_m_smoke(request, monkeypatch, group_m_list): + from cudnn import grouped_gemm_glu_hadamard_wrapper_sm100 + from cudnn.grouped_gemm.grouped_gemm_glu_hadamard import api as grouped_gemm_glu_hadamard_api + + grouped_gemm_glu_hadamard_api._cache_of_GroupedGemmGluHadamardSm100Objects.clear() + + compile_count = {"value": 0} + + def counted_compile(self): + compile_count["value"] += 1 + return None + + monkeypatch.setattr(grouped_gemm_glu_hadamard_api.GroupedGemmGluHadamardSm100, "compile", counted_compile) + monkeypatch.setattr(grouped_gemm_glu_hadamard_api.GroupedGemmGluHadamardSm100, "check_support", lambda self: True) + monkeypatch.setattr(grouped_gemm_glu_hadamard_api.GroupedGemmGluHadamardSm100, "execute", lambda self, **kwargs: None) + + cfg = 
_make_cfg( + request, + ab_dtype=torch.float4_e2m1fn_x2, + sf_dtype=torch.float8_e8m0fnu, + sf_vec_size=16, + ) + cfg["group_m_list"] = list(group_m_list) + cfg["l"] = len(group_m_list) + + for _ in range(2): + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + b_major=cfg["b_major"], + enable_bias=False, + ) + grouped_gemm_glu_hadamard_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + acc_dtype=cfg["acc_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + act_func="swiglu", + ) + + assert compile_count["value"] == 1 + assert len(grouped_gemm_glu_hadamard_api._cache_of_GroupedGemmGluHadamardSm100Objects) == 1 + grouped_gemm_glu_hadamard_api._cache_of_GroupedGemmGluHadamardSm100Objects.clear() diff --git a/test/python/fe_api/test_grouped_gemm_quant.py b/test/python/fe_api/test_grouped_gemm_quant.py index bd9d59f8..2c133b42 100644 --- a/test/python/fe_api/test_grouped_gemm_quant.py +++ b/test/python/fe_api/test_grouped_gemm_quant.py @@ -447,7 +447,6 @@ def test_grouped_gemm_quant_wrapper_requires_prob_tensor(request): norm_const_tensor=inputs.get("norm_const_tensor"), prob_tensor=None, acc_dtype=cfg["acc_dtype"], - c_dtype=cfg["c_dtype"], d_dtype=cfg["d_dtype"], cd_major=cfg["cd_major"], mma_tiler_mn=cfg["mma_tiler_mn"], @@ -505,7 +504,6 @@ def test_grouped_gemm_quant_wrapper_requires_norm_const_tensor_for_fp8(request): norm_const_tensor=None, prob_tensor=inputs["prob_tensor"], acc_dtype=cfg["acc_dtype"], - c_dtype=cfg["c_dtype"], d_dtype=cfg["d_dtype"], cd_major=cfg["cd_major"], mma_tiler_mn=cfg["mma_tiler_mn"], @@ -572,7 +570,6 @@ def test_grouped_gemm_quant_wrapper_with_bias_sm100(use_dynamic_sched, request): norm_const_tensor=inputs.get("norm_const_tensor"), prob_tensor=inputs["prob_tensor"], acc_dtype=cfg["acc_dtype"], - c_dtype=cfg["c_dtype"], d_dtype=cfg["d_dtype"], cd_major=cfg["cd_major"], mma_tiler_mn=cfg["mma_tiler_mn"], @@ -665,7 +662,6 @@ def counted_compile(self): norm_const_tensor=inputs.get("norm_const_tensor"), prob_tensor=inputs["prob_tensor"], acc_dtype=cfg["acc_dtype"], - c_dtype=cfg["c_dtype"], d_dtype=cfg["d_dtype"], cd_major=cfg["cd_major"], mma_tiler_mn=cfg["mma_tiler_mn"], @@ -750,7 +746,6 @@ def counted_compile(self): norm_const_tensor=inputs.get("norm_const_tensor"), prob_tensor=inputs["prob_tensor"], acc_dtype=cfg["acc_dtype"], - c_dtype=cfg["c_dtype"], d_dtype=cfg["d_dtype"], cd_major=cfg["cd_major"], mma_tiler_mn=cfg["mma_tiler_mn"], @@ -839,7 +834,6 @@ def counted_compile(self): norm_const_tensor=inputs.get("norm_const_tensor"), prob_tensor=inputs["prob_tensor"], acc_dtype=cfg["acc_dtype"], - c_dtype=cfg["c_dtype"], d_dtype=cfg["d_dtype"], cd_major=cfg["cd_major"], mma_tiler_mn=cfg["mma_tiler_mn"], @@ -942,7 +936,6 @@ def _test_grouped_gemm_quant_compile_execute( sample_amax=outputs.get("amax_tensor"), sample_norm_const=inputs.get("norm_const_tensor"), sample_prob=inputs["prob_tensor"], - sample_c=outputs["c_tensor"], 
acc_dtype=cfg["acc_dtype"], mma_tiler_mn=cfg["mma_tiler_mn"], cluster_shape_mn=cfg["cluster_shape_mn"], @@ -962,7 +955,6 @@ def _test_grouped_gemm_quant_compile_execute( api.execute( a_tensor=inputs["a_tensor"], b_tensor=inputs["b_tensor"], - c_tensor=outputs["c_tensor"], d_tensor=outputs["d_tensor"], sfa_tensor=inputs["sfa_tensor"], sfb_tensor=inputs["sfb_tensor"], @@ -1054,7 +1046,6 @@ def _test_grouped_gemm_quant_wrapper( norm_const_tensor=inputs.get("norm_const_tensor"), prob_tensor=inputs["prob_tensor"], acc_dtype=cfg["acc_dtype"], - c_dtype=cfg["c_dtype"], d_dtype=cfg["d_dtype"], cd_major=cfg["cd_major"], mma_tiler_mn=cfg["mma_tiler_mn"], @@ -1157,7 +1148,6 @@ def _test_grouped_gemm_quant_discrete_compile_execute( sample_amax=outputs.get("amax_tensor"), sample_norm_const=inputs.get("norm_const_tensor"), sample_prob=inputs["prob_tensor"], - sample_c=outputs["c_tensor"], acc_dtype=cfg["acc_dtype"], mma_tiler_mn=cfg["mma_tiler_mn"], cluster_shape_mn=cfg["cluster_shape_mn"], @@ -1183,7 +1173,6 @@ def _test_grouped_gemm_quant_discrete_compile_execute( d_tensor=outputs["d_tensor"], b_ptrs=inputs["b_ptrs_tensor"], sfb_ptrs=inputs["sfb_ptrs_tensor"], - c_tensor=outputs["c_tensor"], d_col_tensor=outputs["d_col_tensor"], sfd_row_tensor=outputs.get("sfd_row_tensor"), sfd_col_tensor=outputs.get("sfd_col_tensor"), @@ -1269,7 +1258,6 @@ def _test_grouped_gemm_quant_discrete_wrapper( norm_const_tensor=inputs.get("norm_const_tensor"), prob_tensor=inputs["prob_tensor"], acc_dtype=cfg["acc_dtype"], - c_dtype=cfg["c_dtype"], d_dtype=cfg["d_dtype"], cd_major=cfg["cd_major"], mma_tiler_mn=cfg["mma_tiler_mn"], diff --git a/test/python/fe_api/test_grouped_gemm_quant_utils.py b/test/python/fe_api/test_grouped_gemm_quant_utils.py index 4392487e..c5a89141 100644 --- a/test/python/fe_api/test_grouped_gemm_quant_utils.py +++ b/test/python/fe_api/test_grouped_gemm_quant_utils.py @@ -233,13 +233,10 @@ def allocate_grouped_gemm_quant_output_tensors( :return: Dictionary containing all output tensors """ - # C tensor is internal placeholder (generate_c=False), but needed for kernel compilation - _, c_tensor = create_and_permute_tensor(1, tensor_m, n, cd_major == "m", c_dtype) _, d_tensor = create_and_permute_tensor(1, tensor_m, n, cd_major == "m", d_dtype) _, d_col_tensor = create_and_permute_tensor(1, tensor_m, n, cd_major == "m", d_dtype) result = { - "c_tensor": c_tensor, "d_tensor": d_tensor, "d_col_tensor": d_col_tensor, "sfd_row_tensor": None, diff --git a/test/python/fe_api/test_grouped_gemm_srelu.py b/test/python/fe_api/test_grouped_gemm_srelu.py new file mode 100644 index 00000000..0804c230 --- /dev/null +++ b/test/python/fe_api/test_grouped_gemm_srelu.py @@ -0,0 +1,765 @@ +""" +Tests for Grouped GEMM SReLU Forward Kernel (SM100+) + +This module tests the contiguous grouped block-scaled GEMM with SReLU activation +for MoE (Mixture of Experts) workloads. 
+ +Reference: continugous_blockscaled_grouped_gemm_srelu_quant_fusion.py +""" + +import torch +import pytest +from test_utils import torch_fork_set_rng +from fe_api.test_grouped_gemm_srelu_utils import ( + grouped_gemm_srelu_init, + with_grouped_gemm_srelu_params_fp4, + with_grouped_gemm_srelu_params_fp8, + allocate_grouped_gemm_input_tensors, + allocate_grouped_gemm_output_tensors, + check_ref_grouped_gemm_srelu, +) +from fe_api.test_discrete_grouped_gemm_swiglu_utils import ( + allocate_discrete_input_tensors, + discrete_grouped_gemm_init, +) + +GROUPED_GEMM_SWIGLU_DYNAMIC_SHAPES_M_VALUES = [64, 320, 576, 832, 1088, 1344, 1600, 1856, 2112, 2368] + +DISCRETE_GROUPED_GEMM_SRELU_SUPPORTED_CONFIGS = [ + pytest.param(torch.float4_e2m1fn_x2, torch.bfloat16, torch.bfloat16, "k", id="fp4-k-major"), + pytest.param(torch.float8_e4m3fn, torch.bfloat16, torch.float8_e4m3fn, "k", id="fp8-k-major"), + pytest.param(torch.float8_e4m3fn, torch.bfloat16, torch.float8_e4m3fn, "n", id="fp8-n-major"), +] + + +def _dense_ref_inputs_from_discrete(inputs): + ref_inputs = dict(inputs) + ref_inputs["b_ref"] = torch.cat(inputs["b_ref_list"], dim=2) + ref_inputs["sfb_ref"] = torch.cat(inputs["sfb_ref_list"], dim=2) + return ref_inputs + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_srelu_params_fp4 +def test_grouped_gemm_srelu_compile_execute_fp4( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_srelu_compile_execute( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_srelu_params_fp8 +def test_grouped_gemm_srelu_compile_execute_fp8( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_srelu_compile_execute( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_srelu_params_fp4 +def test_grouped_gemm_srelu_wrapper_fp4( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_srelu_wrapper( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_srelu_params_fp8 +def test_grouped_gemm_srelu_wrapper_fp8( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_srelu_wrapper( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + 
d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize( + "ab_dtype", + [ + pytest.param(torch.float4_e2m1fn_x2, id="fp4"), + pytest.param(torch.float8_e4m3fn, id="fp8"), + ], +) +def test_grouped_gemm_srelu_wrapper_cache_partial_dynamic_smoke(request, monkeypatch, ab_dtype): + compile_count, cache_entries = _test_grouped_gemm_srelu_wrapper_dynamic_shape_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=False, + ab_dtype=ab_dtype, + ) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=1) +@pytest.mark.parametrize( + "ab_dtype", + [ + pytest.param(torch.float4_e2m1fn_x2, id="fp4"), + pytest.param(torch.float8_e4m3fn, id="fp8"), + ], +) +def test_grouped_gemm_srelu_wrapper_cache_full_dynamic_smoke(request, monkeypatch, ab_dtype): + compile_count, cache_entries = _test_grouped_gemm_srelu_wrapper_dynamic_shape_cache_behavior( + request=request, + monkeypatch=monkeypatch, + use_full_dynamic=True, + ab_dtype=ab_dtype, + ) + + assert compile_count == 1 + assert cache_entries == 1 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=7) +def test_grouped_gemm_srelu_wrapper_uint8_raw_fp4_smoke(request): + try: + from cudnn import grouped_gemm_srelu_wrapper_sm100 + from cuda.bindings import driver as cuda + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = grouped_gemm_srelu_init( + request=request, + ab_dtype=torch.uint8, + c_dtype=torch.bfloat16, + d_dtype=torch.bfloat16, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + cluster_shape_mn=(2, 1), + sf_vec_size=16, + sf_dtype=torch.float8_e8m0fnu, + vector_f32=True, + discrete_col_sfd=False, + ) + + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + ) + + outputs = grouped_gemm_srelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + acc_dtype=cfg["acc_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=cuda.CUstream(torch.cuda.current_stream().cuda_stream), + ) + + torch.cuda.synchronize() + assert torch.isfinite(outputs["c_tensor"].float()).all() + assert torch.isfinite(outputs["d_tensor"].float()).all() + assert torch.count_nonzero(outputs["d_tensor"]).item() > 0 + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=11) +@pytest.mark.parametrize("ab_dtype,c_dtype,d_dtype,b_major", DISCRETE_GROUPED_GEMM_SRELU_SUPPORTED_CONFIGS) +def test_grouped_gemm_srelu_discrete_compile_execute(request, ab_dtype, c_dtype, d_dtype, b_major): + try: + from cudnn import GroupedGemmSreluSm100 + from cuda.bindings import driver as cuda + except 
ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = discrete_grouped_gemm_init( + request=request, + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + cluster_shape_mn=(2, 1), + sf_vec_size=32, + sf_dtype=torch.float8_e8m0fnu, + vector_f32=False, + discrete_col_sfd=False, + b_major=b_major, + ) + + inputs = allocate_discrete_input_tensors( + n=cfg["n"], + k=cfg["k"], + num_experts=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + b_major=cfg["b_major"], + ) + outputs = allocate_grouped_gemm_output_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + ) + + api = GroupedGemmSreluSm100( + sample_a=inputs["a_tensor"], + sample_c=outputs["c_tensor"], + sample_d=outputs["d_tensor"], + sample_sfa=inputs["sfa_tensor"], + sample_padded_offsets=inputs["padded_offsets_tensor"], + sample_alpha=inputs["alpha_tensor"], + sample_d_col=outputs["d_col_tensor"], + num_experts=cfg["l"], + b_shape=(cfg["n"], cfg["k"]), + b_dtype=inputs["b_list"][0].dtype, + sample_amax=outputs.get("amax_tensor"), + sample_sfd_row=outputs.get("sfd_row_tensor"), + sample_sfd_col=outputs.get("sfd_col_tensor"), + sample_norm_const=inputs.get("norm_const_tensor"), + sample_prob=inputs.get("prob_tensor"), + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + b_major=cfg["b_major"], + ) + + try: + assert api.check_support(), "Unsupported testcase" + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + + api.compile() + api.execute( + a_tensor=inputs["a_tensor"], + b_ptrs=inputs["b_ptrs_tensor"], + sfb_ptrs=inputs["sfb_ptrs_tensor"], + c_tensor=outputs["c_tensor"], + d_tensor=outputs["d_tensor"], + d_col_tensor=outputs["d_col_tensor"], + sfa_tensor=inputs["sfa_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + sfd_row_tensor=outputs.get("sfd_row_tensor"), + sfd_col_tensor=outputs.get("sfd_col_tensor"), + norm_const_tensor=inputs.get("norm_const_tensor"), + prob_tensor=inputs.get("prob_tensor"), + amax_tensor=outputs.get("amax_tensor"), + current_stream=cuda.CUstream(torch.cuda.current_stream().cuda_stream), + ) + + torch.cuda.synchronize() + check_ref_grouped_gemm_srelu( + _dense_ref_inputs_from_discrete(inputs), + outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=11) +@pytest.mark.parametrize("ab_dtype,c_dtype,d_dtype,b_major", DISCRETE_GROUPED_GEMM_SRELU_SUPPORTED_CONFIGS) +def test_grouped_gemm_srelu_discrete_wrapper(request, ab_dtype, c_dtype, d_dtype, b_major): + try: + from cudnn import grouped_gemm_srelu_wrapper_sm100 + from cuda.bindings import driver as cuda + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = discrete_grouped_gemm_init( + request=request, + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + 
cluster_shape_mn=(2, 1), + sf_vec_size=32, + sf_dtype=torch.float8_e8m0fnu, + vector_f32=False, + discrete_col_sfd=False, + b_major=b_major, + ) + + inputs = allocate_discrete_input_tensors( + n=cfg["n"], + k=cfg["k"], + num_experts=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + b_major=cfg["b_major"], + ) + + outputs = grouped_gemm_srelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + sfa_tensor=inputs["sfa_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + b_ptrs=inputs["b_ptrs_tensor"], + sfb_ptrs=inputs["sfb_ptrs_tensor"], + n=cfg["n"], + b_dtype=inputs["b_list"][0].dtype, + b_major=cfg["b_major"], + norm_const_tensor=inputs.get("norm_const_tensor"), + prob_tensor=inputs["prob_tensor"], + acc_dtype=cfg["acc_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=cuda.CUstream(torch.cuda.current_stream().cuda_stream), + ) + + torch.cuda.synchronize() + check_ref_grouped_gemm_srelu( + _dense_ref_inputs_from_discrete(inputs), + outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) + + +""" +GroupedGemmSrelu API with explicit check_support, compile, and execute paths. +Use this method when running one static configuration for each GroupedGemmSrelu object. +""" + + +def _test_grouped_gemm_srelu_compile_execute( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + try: + from cudnn import GroupedGemmSreluSm100 + from cuda.bindings import driver as cuda + except ImportError as e: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = grouped_gemm_srelu_init( + request, + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + ) + + outputs = allocate_grouped_gemm_output_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + ) + + api = GroupedGemmSreluSm100( + sample_a=inputs["a_tensor"], + sample_b=inputs["b_tensor"], + sample_c=outputs["c_tensor"], + sample_d=outputs["d_tensor"], + sample_sfa=inputs["sfa_tensor"], + sample_sfb=inputs["sfb_tensor"], + sample_padded_offsets=inputs["padded_offsets_tensor"], + sample_alpha=inputs["alpha_tensor"], + sample_amax=outputs.get("amax_tensor"), + sample_d_col=outputs["d_col_tensor"], + sample_sfd_row=outputs.get("sfd_row_tensor"), + sample_sfd_col=outputs.get("sfd_col_tensor"), + sample_norm_const=inputs.get("norm_const_tensor"), + sample_prob=inputs.get("prob_tensor"), + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + 
cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + ) + + try: + assert api.check_support(), "Unsupported testcase" + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + + api.compile() + api.execute( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=outputs["c_tensor"], + d_tensor=outputs["d_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + d_col_tensor=outputs["d_col_tensor"], + sfd_row_tensor=outputs.get("sfd_row_tensor"), + sfd_col_tensor=outputs.get("sfd_col_tensor"), + norm_const_tensor=inputs.get("norm_const_tensor"), + prob_tensor=inputs.get("prob_tensor"), + amax_tensor=outputs.get("amax_tensor"), + current_stream=stream, + ) + + check_ref_grouped_gemm_srelu( + inputs, + outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) + + +""" +GroupedGemmSrelu API with grouped_gemm_srelu_wrapper: +Use the wrapper to directly call GroupedGemmSrelu without explicit setup and compilation. +""" + + +def _test_grouped_gemm_srelu_wrapper( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + try: + from cudnn import grouped_gemm_srelu_wrapper_sm100 + from cuda.bindings import driver as cuda + except ImportError as e: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = grouped_gemm_srelu_init( + request, + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + ) + + try: + for _ in range(2): # Run twice to test caching path + outputs = grouped_gemm_srelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + norm_const_tensor=inputs.get("norm_const_tensor"), + prob_tensor=inputs.get("prob_tensor"), + acc_dtype=cfg["acc_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=stream, + ) + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + + check_ref_grouped_gemm_srelu( + inputs, + outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) + + +def _test_grouped_gemm_srelu_wrapper_dynamic_shape_cache_behavior( + request, + monkeypatch, + use_full_dynamic, + ab_dtype, +): + try: + from cudnn import grouped_gemm_srelu_wrapper_sm100 + from cudnn.grouped_gemm.grouped_gemm_srelu import api as grouped_gemm_srelu_api + from cuda.bindings import driver as cuda + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not 
installed") + + if use_full_dynamic: + monkeypatch.setenv("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") + else: + monkeypatch.delenv("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", raising=False) + + grouped_gemm_srelu_api._cache_of_GroupedGemmSreluSm100Objects.clear() + + compile_count = {"value": 0} + original_compile = grouped_gemm_srelu_api.GroupedGemmSreluSm100.compile + + def counted_compile(self): + compile_count["value"] += 1 + return original_compile(self) + + monkeypatch.setattr(grouped_gemm_srelu_api.GroupedGemmSreluSm100, "compile", counted_compile) + + d_dtype = torch.float8_e4m3fn if ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16 + + cfg = grouped_gemm_srelu_init( + request=request, + ab_dtype=ab_dtype, + c_dtype=torch.bfloat16, + d_dtype=d_dtype, + cd_major="n", + acc_dtype=torch.float32, + mma_tiler_mn=(256, 256), + cluster_shape_mn=(2, 1), + sf_vec_size=32, + sf_dtype=torch.float8_e8m0fnu, + vector_f32=False, + discrete_col_sfd=ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2], + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + try: + for group_m in GROUPED_GEMM_SWIGLU_DYNAMIC_SHAPES_M_VALUES: + group_m_list = [group_m] * cfg["l"] + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=group_m_list, + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + ) + + wrapper_outputs = grouped_gemm_srelu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + padded_offsets=inputs["padded_offsets_tensor"], + alpha_tensor=inputs["alpha_tensor"], + norm_const_tensor=inputs.get("norm_const_tensor"), + prob_tensor=inputs.get("prob_tensor"), + acc_dtype=cfg["acc_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=stream, + ) + torch.cuda.synchronize() + + # check_ref_grouped_gemm_srelu( + # inputs, + # wrapper_outputs, + # cfg, + # skip_ref=cfg["skip_ref"], + # ) + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + finally: + cache_entries = len(grouped_gemm_srelu_api._cache_of_GroupedGemmSreluSm100Objects) + grouped_gemm_srelu_api._cache_of_GroupedGemmSreluSm100Objects.clear() + + return compile_count["value"], cache_entries diff --git a/test/python/fe_api/test_grouped_gemm_srelu_utils.py b/test/python/fe_api/test_grouped_gemm_srelu_utils.py new file mode 100644 index 00000000..63a24d29 --- /dev/null +++ b/test/python/fe_api/test_grouped_gemm_srelu_utils.py @@ -0,0 +1,737 @@ +""" +Utilities and parameterization for Grouped GEMM SReLU tests. +Contains test configuration fixtures, tensor creation, and reference implementations. 
+ +Reference: continugous_blockscaled_grouped_gemm_srelu_quant_fusion.py (lines 3518-4825) +""" + +import torch +import pytest +from typing import Optional, Tuple, List, Dict, Any +from test_fe_api_utils import ( + ceil_div, + compute_reference_amax, + create_and_permute_tensor, + create_scale_factor_tensor, + create_sf_layout_tensor, + cvt_sf_MKL_to_M32x4xrm_K4xrk_L, +) + +# ============================================================================= +# Parameterization Marks +# ============================================================================= + +GROUPED_GEMM_SWIGLU_COMMON_MARKS = [ + pytest.mark.parametrize("cd_major", ["n"]), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize("cluster_shape_mn", [(2, 1), (1, 1)]), + pytest.mark.parametrize("vector_f32", [True, False]), +] + +GROUPED_GEMM_SWIGLU_FP8_TYPE_MARKS = [ + pytest.mark.parametrize( + "ab_dtype", + [ + torch.float8_e4m3fn, + ], + ), + pytest.mark.parametrize("c_dtype", [torch.bfloat16]), + pytest.mark.parametrize( + "d_dtype", + [ + torch.float8_e4m3fn, + ], + ), +] + +GROUPED_GEMM_SWIGLU_FP4_TYPE_MARKS = [ + pytest.mark.parametrize( + "ab_dtype", + [ + torch.uint8, + ], + ), + pytest.mark.parametrize( + "c_dtype", + [ + # torch.float16, + torch.bfloat16, + ], + ), + pytest.mark.parametrize( + "d_dtype", + [ + # torch.float16, + torch.bfloat16, + torch.float32, + ], + ), +] + +GROUPED_GEMM_SWIGLU_PARAM_MARKS_FP8 = GROUPED_GEMM_SWIGLU_FP8_TYPE_MARKS + [ + pytest.mark.parametrize("cd_major", ["n"]), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize("mma_tiler_mn", [(256, 256)]), + pytest.mark.parametrize("cluster_shape_mn", [(2, 1)]), + pytest.mark.parametrize("vector_f32", [False]), + pytest.mark.parametrize("sf_vec_size,sf_dtype", [(32, torch.float8_e8m0fnu)]), + pytest.mark.parametrize("discrete_col_sfd", [True]), +] + +GROUPED_GEMM_SWIGLU_PARAM_MARKS_FP4 = ( + GROUPED_GEMM_SWIGLU_FP4_TYPE_MARKS + + GROUPED_GEMM_SWIGLU_COMMON_MARKS + + [ + pytest.mark.parametrize("mma_tiler_mn", [(256, 256), (128, 256)]), + pytest.mark.parametrize( + "sf_vec_size,sf_dtype", + [ + (16, torch.float8_e8m0fnu), + (16, torch.float8_e4m3fn), + (32, torch.float8_e8m0fnu), + (32, torch.float8_e4m3fn), + ], + ), + pytest.mark.parametrize("discrete_col_sfd", [False]), + ] +) + +GROUPED_GEMM_SWIGLU_PARAM_MARKS_BIAS_FP4 = ( + GROUPED_GEMM_SWIGLU_FP4_TYPE_MARKS + + GROUPED_GEMM_SWIGLU_COMMON_MARKS + + [ + pytest.mark.parametrize("mma_tiler_mn", [(128, 256), (256, 256)]), + pytest.mark.parametrize( + "sf_vec_size,sf_dtype", + [ + (16, torch.float8_e8m0fnu), + (16, torch.float8_e4m3fn), + (32, torch.float8_e8m0fnu), + ], + ), + pytest.mark.parametrize("discrete_col_sfd", [False]), + ] +) + + +def with_grouped_gemm_srelu_params_fp4(func): + """Decorator to apply grouped GEMM SReLU FP4 test parameters.""" + for mark in reversed(GROUPED_GEMM_SWIGLU_PARAM_MARKS_FP4): + func = mark(func) + return func + + +def with_grouped_gemm_srelu_params_fp8(func): + """Decorator to apply grouped GEMM SReLU FP8 test parameters.""" + for mark in reversed(GROUPED_GEMM_SWIGLU_PARAM_MARKS_FP8): + func = mark(func) + return func + + +def with_grouped_gemm_srelu_params_bias_fp4(func): + """Decorator to apply grouped GEMM SReLU dense bias FP4 test parameters.""" + for mark in reversed(GROUPED_GEMM_SWIGLU_PARAM_MARKS_BIAS_FP4): + func = mark(func) + return func + + +# ============================================================================= +# Configuration Initialization +# 
============================================================================= + + +def grouped_gemm_srelu_init( + request, + ab_dtype: torch.dtype, + c_dtype: torch.dtype, + d_dtype: torch.dtype, + cd_major: str, + acc_dtype: torch.dtype, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + sf_vec_size: int, + sf_dtype: torch.dtype, + vector_f32: bool = False, + discrete_col_sfd: bool = False, + b_major: str = "k", + enable_bias: bool = False, +) -> Dict[str, Any]: + """Initialize configuration for Grouped GEMM SReLU tests. + + :param request: pytest request object + :param ab_dtype: Data type for A and B tensors + :param c_dtype: Data type for intermediate C tensor (always bfloat16) + :param d_dtype: Data type for output D tensor (fp8 when ab is fp8, bf16 when ab is fp4) + :param cd_major: Major dimension for output C and D tensors + :param acc_dtype: Accumulator data type + :param mma_tiler_mn: MMA tiler shape + :param cluster_shape_mn: Cluster shape + :param sf_vec_size: Scale factor vector size + :param sf_dtype: Scale factor data type + :param vector_f32: Use vectorized f32 operations + :param discrete_col_sfd: Generate discrete col-major scale factor tensor + :param b_major: Major dimension for B tensor. + :param enable_bias: Allocate dense bias tensor for fused bias tests + :return: Configuration dictionary + """ + major, minor = torch.cuda.get_device_capability() + compute_capability = major * 10 + minor + if compute_capability < 100: + pytest.skip(f"Environment not supported: requires compute capability >= 10, found {major}") + + # Parse CLI options + nkl_str = request.config.getoption("--grouped-gemm-nkl", default=None) + group_m_str = request.config.getoption("--grouped-gemm-group-m", default=None) + skip_ref = request.config.getoption("--skip-ref", default=False) + + # Default values + if nkl_str is not None: + n, k, l = [int(x.strip()) for x in nkl_str.split(",")] + else: + n, k, l = 512, 512, 4 + + if group_m_str is not None: + group_m_list = [int(x.strip()) for x in group_m_str.split(",")] + else: + # Default: equal M values per group + group_m_list = [256] * l + + config = { + "n": n, + "k": k, + "l": l, + "group_m_list": group_m_list, + "m_aligned": 256, + "mma_tiler_mn": mma_tiler_mn, + "cluster_shape_mn": cluster_shape_mn, + "ab_dtype": ab_dtype, + "c_dtype": c_dtype, + "d_dtype": d_dtype, + "b_major": b_major, + "cd_major": cd_major, + "acc_dtype": acc_dtype, + "sf_vec_size": sf_vec_size, + "sf_dtype": sf_dtype, + "vector_f32": vector_f32, + "skip_ref": skip_ref, + "discrete_col_sfd": discrete_col_sfd, + "enable_bias": enable_bias, + } + + return config + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def get_dtype_rcp_limits(dtype: torch.dtype) -> float: + """Get reciprocal of max value for quantization.""" + if dtype == torch.float8_e5m2: + return 1 / 128.0 + elif dtype == torch.float8_e4m3fn: + return 1 / 448.0 + elif dtype in {torch.float4_e2m1fn_x2, torch.uint8}: + return 1 / 6.0 + return 1.0 + + +def create_mask( + group_m_list: List[int], + m_aligned: int = 256, + permuted_m: Optional[int] = None, +) -> Tuple[int, List[int], torch.Tensor]: + """Create padded_offsets tensor from group_m_list. + + :param group_m_list: List of M values for each group (will be aligned to m_aligned) + :param m_aligned: Alignment requirement for group M dimension. 
MUST equal + the grouped GEMM kernel FIX_PAD_SIZE (256) + :param permuted_m: Optional padded M dimension for CUDA graph support. If provided, + padded_offsets will be padded to include this size. + The kernel determines valid tiles from padded_offsets[-1]. + + :return: Tuple of (valid_m, aligned_group_m_list, padded_offsets_tensor) + """ + valid_m = 0 + aligned_group_m_list = [] + padded_offsets = [] + + for group_m in group_m_list: + aligned_group_m = ((group_m + m_aligned - 1) // m_aligned) * m_aligned + valid_m += aligned_group_m + aligned_group_m_list.append(aligned_group_m) + + # padded_offsets[i] = cumulative sum up to and including expert i + padded_offsets.append(valid_m) + + # Apply padding if requested (for cuda_graph support) + if permuted_m is not None: + if permuted_m < valid_m: + raise ValueError(f"permuted_m ({permuted_m}) must be >= valid_m ({valid_m}). " f"Cannot pad to a smaller size.") + # Note: permuted_m padding is handled by the caller creating A/D tensors with larger M + # padded_offsets[-1] still equals valid_m (not permuted_m) + + # Convert to tensor + padded_offsets_tensor = torch.tensor(padded_offsets, dtype=torch.int32).cuda() + + return ( + valid_m, + aligned_group_m_list, + padded_offsets_tensor, + ) + + +# ============================================================================= +# Tensor Allocation +# ============================================================================= + + +def allocate_grouped_gemm_input_tensors( + n: int, + k: int, + l: int, + group_m_list: List[int], + ab_dtype: torch.dtype, + sf_dtype: torch.dtype, + sf_vec_size: int, + m_aligned: int, + permuted_m: Optional[int] = None, + norm_const: float = 0.01, + b_major: str = "k", + enable_bias: bool = False, + device: str = "cuda", +) -> Dict[str, Any]: + """Allocate input tensors for grouped GEMM SReLU. + + :param permuted_m: Optional padded M dimension for cuda_graph support. If provided, + A matrix, D matrix, and scale factor A will be padded to this size. + The kernel calculates valid tiles from padded_offsets[-1]. + + :return: Dictionary containing all input tensors and metadata + """ + + valid_m, aligned_group_m_list, padded_offsets_tensor = create_mask(group_m_list, m_aligned, permuted_m) + + tensor_m = permuted_m if permuted_m is not None else valid_m + + # Standalone grouped kernels use raw-byte tensors for FP4 payloads with the + # full logical K still present in the visible tensor shape. 
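+    # (torch has no native FP4 dtype, so torch.uint8 serves as the stand-in
+    # element type for the packed Float4E2M1FN payload below.)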
+ if ab_dtype == torch.uint8: + try: + import cutlass + import cutlass.torch as cutlass_torch + except ImportError: + pytest.skip("CUTLASS is not installed; skipping grouped uint8 raw-FP4 tests.") + + a_ref = cutlass_torch.matrix(1, tensor_m, k, False, cutlass.Float32).cuda() + b_ref = cutlass_torch.matrix(l, n, k, b_major == "n", cutlass.Float32).cuda() + _, a_tensor = cutlass_torch.cute_tensor_like( + a_ref, + cutlass.Float4E2M1FN, + is_dynamic_layout=True, + assumed_align=16, + ) + _, b_tensor = cutlass_torch.cute_tensor_like( + b_ref, + cutlass.Float4E2M1FN, + is_dynamic_layout=True, + assumed_align=16, + ) + a_tensor = a_tensor.view(torch.uint8) + b_tensor = b_tensor.view(torch.uint8) + else: + # Note: b tensor can be n-major for mxfp8 dSrelu; otherwise, a and b tensors are always k-major + a_ref, a_tensor = create_and_permute_tensor(1, tensor_m, k, False, ab_dtype) + b_ref, b_tensor = create_and_permute_tensor(l, n, k, b_major == "n", ab_dtype) + + sfa_ref, sfa_tensor = create_scale_factor_tensor(1, tensor_m, k, sf_vec_size, sf_dtype) + sfb_ref, sfb_tensor = create_scale_factor_tensor(l, n, k, sf_vec_size, sf_dtype) + + alpha_tensor = torch.randint(-2, 2, (l,), dtype=torch.float32, device=device).float() + beta_tensor = torch.randint(-2, 2, (l,), dtype=torch.float32, device=device).float() # dSrelu only + + prob_tensor = torch.randint(-2, 2, (tensor_m, 1, 1), dtype=torch.float32, device=device).float() + + result = { + "a_tensor": a_tensor, + "a_ref": a_ref, + "b_tensor": b_tensor, + "b_ref": b_ref, + "sfa_tensor": sfa_tensor, + "sfa_ref": sfa_ref, + "sfb_tensor": sfb_tensor, + "sfb_ref": sfb_ref, + "alpha_tensor": alpha_tensor, + "beta_tensor": beta_tensor, + "prob_tensor": prob_tensor, + "bias_tensor": None, + "padded_offsets_tensor": padded_offsets_tensor, + "aligned_group_m_list": aligned_group_m_list, + "valid_m": valid_m, + "tensor_m": tensor_m, + "norm_const_tensor": None, + } + + # Norm constant tensor + if ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sf_dtype in [ + torch.float8_e8m0fnu, + torch.float8_e4m3fn, + ]: + result["norm_const_tensor"] = torch.tensor([norm_const], dtype=torch.float32, device=device) + + if enable_bias: + result["bias_tensor"] = torch.empty((l, n), dtype=torch.bfloat16, device=device).uniform_(-2.0, 2.0).transpose(0, 1) + + return result + + +def allocate_grouped_gemm_output_tensors( + tensor_m: int, + n: int, + l: int, + ab_dtype: torch.dtype, + c_dtype: torch.dtype, + d_dtype: torch.dtype, + cd_major: str, + sf_dtype: torch.dtype, + sf_vec_size: int = 16, + device: str = "cuda", +) -> Dict[str, Any]: + """Allocate output tensors for grouped GEMM SReLU. 
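+    Produces the C and D buffers plus a second D buffer for the column-quantized
+    output, and, depending on the dtypes, a per-group amax buffer or row/column
+    scale-factor (SFD) buffers.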
+
+    :return: Dictionary containing all output tensors
+    """
+    n_out = n  # After SReLU
+
+    _, c_tensor = create_and_permute_tensor(1, tensor_m, n, cd_major == "m", c_dtype)
+    _, d_tensor = create_and_permute_tensor(1, tensor_m, n_out, cd_major == "m", d_dtype)
+    _, d_col_tensor = create_and_permute_tensor(1, tensor_m, n_out, cd_major == "m", d_dtype)
+
+    result = {
+        "c_tensor": c_tensor,
+        "d_tensor": d_tensor,
+        "d_col_tensor": d_col_tensor,
+        "sfd_row_tensor": None,
+        "sfd_col_tensor": None,
+    }
+
+    if d_dtype in [torch.bfloat16, torch.float16]:
+        result["amax_tensor"] = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=device)
+
+    if ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sf_dtype in [
+        torch.float8_e8m0fnu,
+        torch.float8_e4m3fn,
+    ]:  # generate_sfd
+        sfd_row_ref, sfd_row_tensor = create_scale_factor_tensor(1, tensor_m, n_out, sf_vec_size, sf_dtype)
+        result["sfd_row_tensor"] = sfd_row_tensor
+        result["sfd_row_ref"] = sfd_row_ref
+
+        sfd_col_ref, sfd_col_tensor = create_scale_factor_tensor(1, n_out, tensor_m, sf_vec_size, sf_dtype)
+        result["sfd_col_tensor"] = sfd_col_tensor
+        result["sfd_col_ref"] = sfd_col_ref
+
+    return result
+
+
+# =============================================================================
+# Reference Implementations
+# =============================================================================
+
+
+def run_grouped_gemm_srelu_ref(
+    a_ref: torch.Tensor,
+    b_ref: torch.Tensor,
+    sfa_ref: torch.Tensor,
+    sfb_ref: torch.Tensor,
+    alpha_tensor: torch.Tensor,
+    prob_tensor: torch.Tensor,
+    aligned_group_m_list: List[int],
+    valid_m: int,
+    bias_tensor: Optional[torch.Tensor] = None,
+    generate_amax: bool = False,
+    generate_sfd: bool = False,
+    norm_const_tensor: Optional[torch.Tensor] = None,
+    c_dtype: torch.dtype = torch.bfloat16,
+    d_dtype: torch.dtype = torch.float32,
+    sf_vec_size: int = 16,
+    sf_dtype: torch.dtype = torch.float8_e8m0fnu,
+) -> Dict[str, torch.Tensor]:
+    """Run reference implementation for grouped GEMM SReLU.
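+
+    The reference mirrors the kernel epilogue: per-group scale-factored GEMM,
+    per-group alpha (and optional bias), squared-ReLU with probability gating,
+    then optional per-group amax and row/column scale-factor quantization.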
+
+    Matches the reference checking in continugous_blockscaled_grouped_gemm_srelu_quant_fusion.py
+    (lines 4113-4179)
+
+    :param a_ref: A tensor (tensor_m, k, 1) in float32
+    :param b_ref: B tensor (n, k, l) in float32
+    :param sfa_ref: Scale factor A tensor (tensor_m, k, 1) in float32
+    :param sfb_ref: Scale factor B tensor (n, k, l) in float32
+    :param alpha_tensor: Per-group alpha scaling (l,)
+    :param prob_tensor: Per-row probability scaling (tensor_m, 1, 1)
+    :param aligned_group_m_list: Aligned M values per group
+    :param valid_m: Total valid M dimension
+    :param bias_tensor: Optional dense bias (n, l), added per group before the activation
+    :param generate_amax: Generate AMAX tensor
+    :param generate_sfd: Generate SFD tensor
+    :param norm_const_tensor: Normalization constant tensor (1,)
+    :param c_dtype: Intermediate C tensor dtype (always bfloat16)
+    :param d_dtype: Output D tensor dtype
+    :param sf_vec_size: Scale factor vector size
+    :param sf_dtype: Scale factor dtype
+    :return: Dictionary of reference tensors: "c_ref" and "d_ref" always, plus
+        "amax_ref", "sfd_row_ref", "sfd_col_ref", and "d_col_ref" when requested
+    """
+    n, k, l = b_ref.shape
+    n_out = n
+    ref_tensors = {}
+
+    # Step 1: Compute GEMM per group with scale factors
+    ref = torch.empty((1, valid_m, n), dtype=torch.float32, device=a_ref.device)
+    start = 0
+    for i, group_m in enumerate(aligned_group_m_list):
+        end = start + group_m
+        res_a = torch.einsum("mk,mk->mk", a_ref[start:end, :, 0], sfa_ref[start:end, :, 0])
+        res_b = torch.einsum("nk,nk->nk", b_ref[:, :, i], sfb_ref[:, :, i])
+        ref[0, start:end, :] = torch.einsum("mk,nk->mn", res_a, res_b)
+        start = end
+    ref = ref.permute((1, 2, 0))
+
+    # Step 2: Apply alpha per group
+    start = 0
+    for i, group_m in enumerate(aligned_group_m_list):
+        end = start + group_m
+        ref[start:end, :, 0] = ref[start:end, :, 0] * alpha_tensor[i].item()
+        start = end
+
+    if bias_tensor is not None:
+        start = 0
+        for i, group_m in enumerate(aligned_group_m_list):
+            end = start + group_m
+            ref[start:end, :, 0] = ref[start:end, :, 0] + bias_tensor[:, i].unsqueeze(0).to(torch.float32)
+            start = end
+
+    ref_tensors["c_ref"] = ref.clone()
+
+    # Step 3: Apply squared-ReLU and probability gating elementwise
+    ref_after_srelu = torch.relu(ref) ** 2
+    ref_after_srelu = ref_after_srelu * prob_tensor.expand(-1, n_out, -1)
+    ref_tensors["d_ref"] = ref_after_srelu.clone()
+
+    if generate_amax:
+        amax_ref = torch.empty((l, 1), dtype=torch.float32, device=a_ref.device)
+        start = 0
+        for i, group_m in enumerate(aligned_group_m_list):
+            end = start + group_m
+            amax_ref[i, 0] = compute_reference_amax(ref_after_srelu[start:end, :, 0].clone())
+            start = end
+        ref_tensors["amax_ref"] = amax_ref
+
+    if generate_sfd:
+        try:
+            from cutlass.cute.runtime import from_dlpack
+            import cutlass.cute as cute
+            from cudnn.datatypes import _convert_to_cutlass_data_type
+        except ImportError:
+            pytest.skip("CUTLASS not available for scale factor conversion")
+
+        norm_const = norm_const_tensor[0].item()
+
+        n_out_aligned = ceil_div(n_out, 128) * 128
+        if n_out_aligned != n_out:
+            zeros = torch.zeros(
+                ref_after_srelu.shape[0],
+                n_out_aligned - n_out,
+                ref_after_srelu.shape[2],
+                dtype=ref_after_srelu.dtype,
+                device=ref_after_srelu.device,
+            )
+            ref_after_srelu_sf = torch.cat([ref_after_srelu, zeros], dim=1)
+        else:
+            ref_after_srelu_sf = ref_after_srelu
+
+        # 1. Compute reference SFDRow (m, sfn, l) in fp32
+        sfn = ceil_div(n_out_aligned, sf_vec_size)
+        # Reshape ref to (1, valid_m, sfn, sf_vec_size)
+        ref_for_sf = ref_after_srelu_sf.permute(2, 0, 1).contiguous()  # (l, m, n)
+        # The group dimension l is already folded into valid_m
+        ref_for_sf = ref_for_sf.view(1, valid_m, sfn, sf_vec_size)
+        # Take abs max over sf_vec_size dimension
+        ref_for_sf, _ = torch.abs(ref_for_sf).max(dim=3)  # (l, m, sfn)
+        # Multiply by norm_const and rcp_limits
+        ref_sfd_row_f32 = ref_for_sf * norm_const * get_dtype_rcp_limits(d_dtype)
+        # Permute to (m, sfn, l)
+        ref_sfd_row_f32 = ref_sfd_row_f32.permute(1, 2, 0)
+
+        # Convert fp32 -> f8 -> fp32 for ref_sfd_row_f32
+        ref_sfd_row_f8_torch = torch.empty(*(1, valid_m, sfn), dtype=torch.uint8, device="cuda").permute(1, 2, 0)
+        ref_sfd_row_f8 = from_dlpack(ref_sfd_row_f8_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1)
+        ref_sfd_row_f8.element_type = _convert_to_cutlass_data_type(sf_dtype)
+        ref_sfd_row_f32_device = ref_sfd_row_f32.cuda()
+        ref_sfd_row_f32_tensor = from_dlpack(ref_sfd_row_f32_device, assumed_align=16).mark_layout_dynamic(leading_dim=1)
+        cute.testing.convert(ref_sfd_row_f32_tensor, ref_sfd_row_f8)
+        cute.testing.convert(ref_sfd_row_f8, ref_sfd_row_f32_tensor)
+        ref_sfd_row_f32 = ref_sfd_row_f32_device.cpu()
+
+        # 2. Convert ref_sfd_row_f32 to scale factor layout and compare with kernel sfd tensor
+        ref_sfd_row_f32_cute_torch_tensor_cpu, _ = create_sf_layout_tensor(1, valid_m, n_out, sf_vec_size)
+
+        # Convert the row scale-factor f32 tensor into the cute SF layout
+        cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
+            from_dlpack(ref_sfd_row_f32),
+            from_dlpack(ref_sfd_row_f32_cute_torch_tensor_cpu),
+        )
+        ref_sfd_row_f32 = ref_sfd_row_f32.cuda()
+        ref_tensors["sfd_row_ref"] = ref_sfd_row_f32_cute_torch_tensor_cpu.clone()
+
+        # 3.
Quantized output with scale factor + # Compute reciprocal of ref_sfd_row_f32 and multiply by norm_const + ref_sfd_row_rcp = norm_const * ref_sfd_row_f32.reciprocal() + ref_sfd_row_rcp = torch.clamp(ref_sfd_row_rcp, max=3.40282346638528859812e38) + # Expand the sfn dimension by repeating each value sf_vec_size times + # ref_sfd_row_rcp: (m, sfn, l) -> (m, sfn, sf_vec_size, l) -> (m, n, l) + ref_sfd_row_rcp_expanded = ref_sfd_row_rcp[:valid_m, :, :].unsqueeze(2).expand(valid_m, sfn, sf_vec_size, 1) + ref_sfd_row_rcp_expanded = ref_sfd_row_rcp_expanded.reshape(valid_m, sfn * sf_vec_size, 1) + # Trim to exact n dimension if needed + ref_sfd_row_rcp_expanded = ref_sfd_row_rcp_expanded[:, :n_out, :] + + # Apply scale to reference output: ref = ref * ref_sfd_row_rcp + ref_after_row_quant = torch.einsum("mnl,mnl->mnl", ref_after_srelu, ref_sfd_row_rcp_expanded) + ref_tensors["d_ref"] = ref_after_row_quant.cuda().to(d_dtype).to(torch.float32).clone() + + ref_d_col = ref_after_srelu.permute(2, 1, 0).contiguous().permute(1, 2, 0) + ref_col_sf = ref_after_srelu_sf.permute(2, 1, 0).contiguous().permute(1, 2, 0) + n_col = ref_d_col.shape[1] + sfn_col = ceil_div(n_col, sf_vec_size) + valid_m_col = ref_d_col.shape[0] + valid_m_col_aligned = ceil_div(valid_m_col, 128) * 128 + ref_for_sf_col = ref_col_sf.permute(2, 0, 1).contiguous() + ref_for_sf_col = ref_for_sf_col.view(1, valid_m_col_aligned, sfn_col, sf_vec_size) + ref_for_sf_col, _ = torch.abs(ref_for_sf_col).max(dim=3) + ref_sfd_col_f32 = ref_for_sf_col * norm_const * get_dtype_rcp_limits(d_dtype) + ref_sfd_col_f32 = ref_sfd_col_f32.permute(1, 2, 0) + + ref_sfd_col_f8_torch = torch.empty(*(1, valid_m_col_aligned, sfn_col), dtype=torch.uint8, device="cuda").permute(1, 2, 0) + ref_sfd_col_f8 = from_dlpack(ref_sfd_col_f8_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1) + ref_sfd_col_f8.element_type = _convert_to_cutlass_data_type(sf_dtype) + ref_sfd_col_f32_device = ref_sfd_col_f32.cuda() + ref_sfd_col_f32_tensor = from_dlpack(ref_sfd_col_f32_device, assumed_align=16).mark_layout_dynamic(leading_dim=1) + cute.testing.convert(ref_sfd_col_f32_tensor, ref_sfd_col_f8) + cute.testing.convert(ref_sfd_col_f8, ref_sfd_col_f32_tensor) + ref_sfd_col_f32 = ref_sfd_col_f32_device.cpu() + + ref_sfd_col_f32_cute_torch_tensor_cpu, _ = create_sf_layout_tensor(1, valid_m_col_aligned, n_col, sf_vec_size) + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref_sfd_col_f32), + from_dlpack(ref_sfd_col_f32_cute_torch_tensor_cpu), + ) + ref_sfd_col_f32 = ref_sfd_col_f32.cuda() + ref_tensors["sfd_col_ref"] = ref_sfd_col_f32_cute_torch_tensor_cpu.clone() + + ref_sfd_col_rcp = norm_const * ref_sfd_col_f32.reciprocal() + ref_sfd_col_rcp = torch.clamp(ref_sfd_col_rcp, max=3.40282346638528859812e38) + ref_sfd_col_rcp_expanded = ref_sfd_col_rcp[:valid_m_col, :, :].unsqueeze(2).expand(valid_m_col, sfn_col, sf_vec_size, 1) + ref_sfd_col_rcp_expanded = ref_sfd_col_rcp_expanded.reshape(valid_m_col, sfn_col * sf_vec_size, 1) + ref_sfd_col_rcp_expanded = ref_sfd_col_rcp_expanded[:, :n_col, :] + + ref_after_col_quant = torch.einsum("mnl,mnl->mnl", ref_d_col, ref_sfd_col_rcp_expanded) + + ref_col_f8_torch = torch.empty(*(1, valid_m_col, n_col), dtype=torch.uint8, device="cuda").permute(1, 2, 0) + ref_col_f8 = from_dlpack(ref_col_f8_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1) + ref_col_f8.element_type = _convert_to_cutlass_data_type(d_dtype) + ref_col_device = ref_after_col_quant.cuda() + ref_col_tensor = from_dlpack(ref_col_device, 
assumed_align=16).mark_layout_dynamic(leading_dim=1) + cute.testing.convert(ref_col_tensor, ref_col_f8) + cute.testing.convert(ref_col_f8, ref_col_tensor) + + ref_tensors["d_col_ref"] = ref_col_device.clone().permute(1, 0, 2) + + return ref_tensors + + +# ============================================================================= +# Reference Checking +# ============================================================================= + + +def check_ref_grouped_gemm_srelu( + inputs: Dict[str, Any], + outputs: Dict[str, Any], + cfg: Dict[str, Any], + atol: float = 1e-1, + rtol: float = 1e-2, + skip_ref: bool = False, +) -> None: + if skip_ref: + return + + torch.cuda.synchronize() + ref_tensors = run_grouped_gemm_srelu_ref( + a_ref=inputs["a_ref"], + b_ref=inputs["b_ref"], + sfa_ref=inputs["sfa_ref"], + sfb_ref=inputs["sfb_ref"], + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + aligned_group_m_list=inputs["aligned_group_m_list"], + valid_m=inputs["valid_m"], + bias_tensor=inputs.get("bias_tensor"), + generate_amax=outputs.get("amax_tensor") is not None, + generate_sfd=outputs.get("sfd_row_tensor") is not None and outputs.get("sfd_col_tensor") is not None, + norm_const_tensor=inputs.get("norm_const_tensor"), + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + sf_vec_size=cfg["sf_vec_size"], + sf_dtype=cfg["sf_dtype"], + ) + + torch.testing.assert_close(outputs["c_tensor"].float(), ref_tensors["c_ref"].float(), atol=atol, rtol=rtol) + torch.testing.assert_close(outputs["d_tensor"].float(), ref_tensors["d_ref"].float(), atol=atol, rtol=rtol) + + if "d_col_ref" in ref_tensors: + torch.testing.assert_close(outputs["d_col_tensor"].float(), ref_tensors["d_col_ref"].float(), atol=atol, rtol=rtol) + + if outputs.get("amax_tensor") is not None and "amax_ref" in ref_tensors: + torch.testing.assert_close(outputs["amax_tensor"].float(), ref_tensors["amax_ref"].float(), atol=atol, rtol=rtol) + + if outputs.get("sfd_row_tensor") is not None and "sfd_row_ref" in ref_tensors: + torch.testing.assert_close( + outputs["sfd_row_tensor"].float(), + ref_tensors["sfd_row_ref"].to(outputs["sfd_row_tensor"].device).float(), + atol=atol, + rtol=rtol, + ) + + if outputs.get("sfd_col_tensor") is not None and "sfd_col_ref" in ref_tensors: + sfd_col_tensor = outputs["sfd_col_tensor"].float() + sfd_col_ref = ref_tensors["sfd_col_ref"].to(outputs["sfd_col_tensor"].device).float() + if cfg.get("discrete_col_sfd", False): + # Mirror the original standalone discrete-col verification, which + # remaps packed tiles rather than comparing the whole buffer directly. 
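+            # res_real_idx walks the packed result tiles in flat order; each
+            # packed tile is compared against the reference tile at its
+            # original (m_idx, n_idx + cumsum_n) coordinates.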
+ group_n_tile_list = [group // 128 for group in inputs["aligned_group_m_list"]] + m_tile = sfd_col_ref.shape[2] + res_real_idx = 0 + cumsum_n = 0 + total_n = sum(group_n_tile_list) + + for n_tile in group_n_tile_list: + for m_idx in range(m_tile): + for n_idx in range(n_tile): + res_real_m_idx = res_real_idx // total_n + res_real_n_idx = res_real_idx % total_n + ref_real_n_idx = n_idx + cumsum_n + + ref_slice = sfd_col_ref[:, :, m_idx, :, ref_real_n_idx, :] + res_slice = sfd_col_tensor[:, :, res_real_m_idx, :, res_real_n_idx, :] + torch.testing.assert_close( + res_slice, + ref_slice, + atol=atol, + rtol=rtol, + ) + res_real_idx += 1 + cumsum_n += n_tile + else: + torch.testing.assert_close( + sfd_col_tensor, + sfd_col_ref, + atol=atol, + rtol=rtol, + ) diff --git a/test/python/fe_api/test_grouped_gemm_swiglu_utils.py b/test/python/fe_api/test_grouped_gemm_swiglu_utils.py index 5e0535d8..28757a12 100644 --- a/test/python/fe_api/test_grouped_gemm_swiglu_utils.py +++ b/test/python/fe_api/test_grouped_gemm_swiglu_utils.py @@ -345,9 +345,35 @@ def allocate_grouped_gemm_input_tensors( tensor_m = permuted_m if permuted_m is not None else valid_m - # Note: b tensor can be n-major for mxfp8 dSwiglu; otherwise, a and b tensors are always k-major - a_ref, a_tensor = create_and_permute_tensor(1, tensor_m, k, False, ab_dtype) - b_ref, b_tensor = create_and_permute_tensor(l, n, k, b_major == "n", ab_dtype) + # Standalone grouped kernels use raw-byte tensors for FP4 payloads with the + # full logical K still present in the visible tensor shape. + if ab_dtype == torch.uint8: + try: + import cutlass + import cutlass.torch as cutlass_torch + except ImportError: + pytest.skip("CUTLASS is not installed; skipping grouped uint8 raw-FP4 tests.") + + a_ref = cutlass_torch.matrix(1, tensor_m, k, False, cutlass.Float32).cuda() + b_ref = cutlass_torch.matrix(l, n, k, b_major == "n", cutlass.Float32).cuda() + _, a_tensor = cutlass_torch.cute_tensor_like( + a_ref, + cutlass.Float4E2M1FN, + is_dynamic_layout=True, + assumed_align=16, + ) + _, b_tensor = cutlass_torch.cute_tensor_like( + b_ref, + cutlass.Float4E2M1FN, + is_dynamic_layout=True, + assumed_align=16, + ) + a_tensor = a_tensor.view(torch.uint8) + b_tensor = b_tensor.view(torch.uint8) + else: + # Note: b tensor can be n-major for mxfp8 dSwiglu; otherwise, a and b tensors are always k-major + a_ref, a_tensor = create_and_permute_tensor(1, tensor_m, k, False, ab_dtype) + b_ref, b_tensor = create_and_permute_tensor(l, n, k, b_major == "n", ab_dtype) sfa_ref, sfa_tensor = create_scale_factor_tensor(1, tensor_m, k, sf_vec_size, sf_dtype) sfb_ref, sfb_tensor = create_scale_factor_tensor(l, n, k, sf_vec_size, sf_dtype) diff --git a/test/python/fe_api/test_grouped_gemm_wgrad.py b/test/python/fe_api/test_grouped_gemm_wgrad.py index 4a80d526..304a1208 100644 --- a/test/python/fe_api/test_grouped_gemm_wgrad.py +++ b/test/python/fe_api/test_grouped_gemm_wgrad.py @@ -557,9 +557,7 @@ def _make_wgrad_wrapper_cache_inputs(group_k_list, sf_vec_size=16): expert_cnt = len(group_k_list) tokens_sum = sum(group_k_list) - scale_cols = 0 - for group_k in group_k_list: - scale_cols += ((group_k + sf_vec_size - 1) // sf_vec_size + 3) // 4 * 4 + scale_cols = ((tokens_sum + sf_vec_size - 1) // sf_vec_size + 3) // 4 * 4 return { "a_tensor": torch.empty((hidden, tokens_sum), dtype=torch.bfloat16), diff --git a/test/python/fe_api/test_rmsnorm_rht_amax.py b/test/python/fe_api/test_rmsnorm_rht_amax.py new file mode 100644 index 00000000..c00d04b0 --- /dev/null +++ 
b/test/python/fe_api/test_rmsnorm_rht_amax.py @@ -0,0 +1,124 @@ +"""Tests for the FE-OSS RMSNorm + RHT + amax API.""" + +import math + +import pytest +import torch + +from test_utils import torch_fork_set_rng + +SUPPORTED_N_NUM_THREADS = [ + (2048, 128), + (4096, 256), + (7168, 128), + (8192, 512), + (16384, 1024), + (32768, 512), +] + + +def _hadamard_matrix(n: int, *, device: torch.device) -> torch.Tensor: + matrix = torch.tensor([[1.0]], device=device, dtype=torch.float32) + while matrix.shape[0] < n: + top = torch.cat((matrix, matrix), dim=1) + bottom = torch.cat((matrix, -matrix), dim=1) + matrix = torch.cat((top, bottom), dim=0) + return matrix + + +def _rmsnorm_rht_amax_ref(x: torch.Tensor, w: torch.Tensor, eps: float, rows_per_cta: int): + m, n = x.shape + x_f32 = x.float() + rms = torch.sqrt((x_f32 * x_f32).mean(dim=-1, keepdim=True) + eps) + y = x_f32 / rms * w.float().unsqueeze(0) + + had_block = 16 + hadamard = _hadamard_matrix(had_block, device=x.device) / math.sqrt(had_block) + y = y.view(m, n // had_block, had_block) + y = torch.matmul(y, hadamard).view(m, n) + + num_ctas = m // rows_per_cta + amax = y.abs().view(num_ctas, rows_per_cta, n).amax(dim=(1, 2)) + return y.to(torch.bfloat16), amax.to(torch.float32) + + +def _make_inputs(*, m: int, n: int): + x = torch.randn((m, n), dtype=torch.bfloat16, device="cuda") + w = torch.randn((n,), dtype=torch.bfloat16, device="cuda") + return x.contiguous(), w.contiguous() + + +def _assert_ref_close(x, w, o, amax, *, eps: float, rows_per_cta: int, skip_ref: bool = False): + if skip_ref: + return + o_ref, amax_ref = _rmsnorm_rht_amax_ref(x, w, eps, rows_per_cta) + torch.testing.assert_close(o.float().cpu(), o_ref.float().cpu(), atol=4e-2, rtol=1e-2) + torch.testing.assert_close(amax.cpu(), amax_ref.cpu(), atol=2e-3, rtol=1e-3) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize("n,num_threads", SUPPORTED_N_NUM_THREADS) +def test_rmsnorm_rht_amax_compile_execute(n, num_threads, request): + try: + from cudnn import RmsNormRhtAmaxSm100 + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + skip_ref = request.config.getoption("--skip-ref", default=False) + eps = 1e-5 + m = 256 + rows_per_cta = 2 + x, w = _make_inputs(m=m, n=n) + o = torch.empty_like(x) + amax = torch.full((m // rows_per_cta,), float("-inf"), dtype=torch.float32, device="cuda") + + api = RmsNormRhtAmaxSm100( + sample_x=x, + sample_w=w, + sample_o=o, + sample_amax=amax, + eps=eps, + num_threads=num_threads, + rows_per_cta=rows_per_cta, + ) + + try: + assert api.check_support(), "Unsupported testcase" + except (ValueError, RuntimeError) as exc: + pytest.skip(f"Unsupported testcase: {exc}") + + api.compile() + api.execute(x_tensor=x, w_tensor=w, o_tensor=o, amax_tensor=amax) + _assert_ref_close(x, w, o, amax, eps=eps, rows_per_cta=rows_per_cta, skip_ref=skip_ref) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@pytest.mark.parametrize("n,num_threads", SUPPORTED_N_NUM_THREADS) +@pytest.mark.parametrize("rows_per_cta", [2, 4, 8]) +def test_rmsnorm_rht_amax_wrapper(n, num_threads, rows_per_cta, request): + try: + from cudnn import rmsnorm_rht_amax_wrapper_sm100 + except ImportError: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + skip_ref = request.config.getoption("--skip-ref", default=False) + eps = 1e-5 + m = 256 + x, w = _make_inputs(m=m, n=n) + + try: + outputs = rmsnorm_rht_amax_wrapper_sm100( + x_tensor=x, + w_tensor=w, + eps=eps, + 
num_threads=num_threads, + rows_per_cta=rows_per_cta, + ) + except (ValueError, RuntimeError) as exc: + pytest.skip(f"Unsupported testcase: {exc}") + + assert outputs["o_tensor"].shape == (m, n) + assert outputs["amax_tensor"].shape == (m // rows_per_cta,) + _assert_ref_close(x, w, outputs["o_tensor"], outputs["amax_tensor"], eps=eps, rows_per_cta=rows_per_cta, skip_ref=skip_ref) diff --git a/test/python/fe_api/test_sdpa_bwd.py b/test/python/fe_api/test_sdpa_bwd.py index a9d7e19f..6685cc09 100644 --- a/test/python/fe_api/test_sdpa_bwd.py +++ b/test/python/fe_api/test_sdpa_bwd.py @@ -1,5 +1,11 @@ """ Tests for SDPA backward SM100 API and wrapper. + +These tests are marked ``gpu_exclusive`` because the 2-CTA cluster kernel +uses Blackwell TMEM which produces NaN when GPU kernels from other CUDA +contexts (i.e. other pytest-xdist workers) execute concurrently on the +same device. Run with ``-m gpu_exclusive -n 0`` in CI. +See: https://gitlab-master.nvidia.com/cudnn/cudnn_frontend/-/jobs/299938128 """ import pytest @@ -14,6 +20,8 @@ with_sdpa_bwd_params, ) +pytestmark = [pytest.mark.gpu_exclusive, pytest.mark.xdist_group(name="gpu_exclusive")] + @pytest.mark.L0 @torch_fork_set_rng(seed=0) diff --git a/test/python/pytest.ini b/test/python/pytest.ini index b87edc3c..412f6d38 100644 --- a/test/python/pytest.ini +++ b/test/python/pytest.ini @@ -1,9 +1,10 @@ [pytest] -markers = +markers = L0: specifies L0 level (use -m L0) L1: specifies L1 level (use -m L1) L2: specifies L2 level (use -m L2) L3: specifies L3 level (use -m L3) L4: specifies L4 level (use -m L4) + gpu_exclusive: tests that require exclusive GPU access (no concurrent kernels from other processes) addopts = -m L0 --tb=short --no-header diff --git a/test/python/sdpa/mxfp8.py b/test/python/sdpa/mxfp8.py index 7aabf2c2..98f8a57f 100644 --- a/test/python/sdpa/mxfp8.py +++ b/test/python/sdpa/mxfp8.py @@ -193,7 +193,8 @@ def generate_graph_fwd(b, h_q, h_k, h_v, cudnn_itype=cudnn.data_type.FP8_E4M3, cudnn_otype=cudnn.data_type.HALF, left_bound=None, right_bound=None, diag_align=None, - with_sink_token=False): + with_sink_token=False, + with_unfuse_fma=False): # Compute padded dimensions for F8_128x4 scale factors s_q_padded = ceil_div(s_qo, 128) * 128 s_kv_padded = ceil_div(s_kv, 128) * 128 @@ -278,6 +279,7 @@ def generate_graph_fwd(b, h_q, h_k, h_v, diagonal_band_left_bound=left_bound, diagonal_band_right_bound=right_bound, sink_token=sink_token, + unfuse_fma=with_unfuse_fma, ) # Set output tensor properties @@ -513,7 +515,8 @@ def exec_sdpa_mxfp8(cfg, request, cudnn_handle): right_bound = getattr(cfg, 'right_bound', None) diag_align = getattr(cfg, 'diag_align', None) with_sink_token = getattr(cfg, 'with_sink_token', False) - rescale_threshold = getattr(cfg, 'rescale_threshold', 4.0) + with_unfuse_fma = getattr(cfg, 'with_unfuse_fma', False) + rescale_threshold = cfg.rescale_threshold if hasattr(cfg, 'rescale_threshold') and cfg.rescale_threshold is not None else 4.0 # Get input/output types from config torch_itype = cfg.data_type if hasattr(cfg, 'data_type') and cfg.data_type else torch.float8_e4m3fn @@ -537,6 +540,7 @@ def exec_sdpa_mxfp8(cfg, request, cudnn_handle): cudnn_itype, cudnn_otype, left_bound=left_bound, right_bound=right_bound, diag_align=diag_align, with_sink_token=with_sink_token, + with_unfuse_fma=with_unfuse_fma, ) graph_fwd.validate() graph_fwd.build_operation_graph() diff --git a/test/python/test_mhas.py b/test/python/test_mhas.py index 0e13f69a..0fbd85b3 100644 --- a/test/python/test_mhas.py +++ 
b/test/python/test_mhas.py @@ -1538,5 +1538,9 @@ def test_sdpa_backward( ) if is_bias: torch.testing.assert_close( - dBias_ref, dBias_gpu, check_dtype=False, atol=2e-2, rtol=2e-2 + dBias_ref, + dBias_gpu, + check_dtype=False, + atol=2e-2 if input_type != torch.bfloat16 else 7e-2, + rtol=2e-2, ) diff --git a/test/python/test_mhas_v2.py b/test/python/test_mhas_v2.py index 618fed64..838c9ec0 100644 --- a/test/python/test_mhas_v2.py +++ b/test/python/test_mhas_v2.py @@ -935,7 +935,6 @@ def test_sdpa_mxfp8_fwd_L0(env_info, test_no, request, cudnn_handle): test.cfg = randomization_ctx(rng, data_seed, geom_seed) test.cfg.is_mxfp8 = True - test.showConfig(test_no, request) # Randomly enable unfuse_fma via environment variable for SM100 unfuse_fma = rng.choice([True, False]) @@ -953,6 +952,8 @@ def test_sdpa_mxfp8_fwd_L0(env_info, test_no, request, cudnn_handle): test.cfg.rescale_threshold = rescale_threshold os.environ["CUDNN_RESCALE_THRESHOLD"] = str(test.cfg.rescale_threshold) + test.showConfig(test_no, request) + if request.node.name in test.blocked_tests: pytest.skip(f"blocked test: {request.node.name}") try: @@ -996,11 +997,12 @@ def test_sdpa_mxfp8_bwd_L0(env_info, test_no, request, cudnn_handle): test.cfg.is_mxfp8 = True test.cfg.is_infer = False - test.showConfig(test_no, request) test.cfg.rescale_threshold = 0.0 os.environ["CUDNN_RESCALE_THRESHOLD"] = str(test.cfg.rescale_threshold) + test.showConfig(test_no, request) + if request.node.name in test.blocked_tests: pytest.skip(f"blocked test: {request.node.name}") try: diff --git a/tools/cudnn_repro/README.md b/tools/cudnn_repro/README.md index bbc56bb4..e7c3fc09 100644 --- a/tools/cudnn_repro/README.md +++ b/tools/cudnn_repro/README.md @@ -65,8 +65,8 @@ The tool auto-detects SDPA operation tags and routes to the appropriate handler: - `SDPA_BWD` - `SDPA_FP8_FWD` - `SDPA_FP8_BWD` - -Non-MXFP8 FP8 forward and backward repro are supported. MXFP8 repro is not yet implemented. 
+- `SDPA_MXFP8_FWD` +- `SDPA_MXFP8_BWD` **Debug mode** (`CUDNN_DEBUG_REPRO=1`) writes: - `cudnn_repro_stage0.txt` - Raw log diff --git a/tools/cudnn_repro/cudnn_repro/routing.py b/tools/cudnn_repro/cudnn_repro/routing.py index abaf6722..8f65a8c0 100644 --- a/tools/cudnn_repro/cudnn_repro/routing.py +++ b/tools/cudnn_repro/cudnn_repro/routing.py @@ -14,9 +14,9 @@ def detect_operation_key(payload: dict) -> str: """Detect the operation key from the JSON payload.""" for node in payload.get("nodes", []): tag = node.get("tag", "") - if tag == "SDPA_FP8_FWD": + if tag in ("SDPA_FP8_FWD", "SDPA_MXFP8_FWD"): return "sdpa_fp8_fwd" - if tag == "SDPA_FP8_BWD": + if tag in ("SDPA_FP8_BWD", "SDPA_MXFP8_BWD"): return "sdpa_fp8_bwd" if tag == "SDPA_BWD": return "sdpa_bwd" diff --git a/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_bwd.py b/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_bwd.py index e419970a..1ccd7f01 100644 --- a/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_bwd.py +++ b/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_bwd.py @@ -12,7 +12,7 @@ def _find_bwd_node(payload: dict) -> dict: tag = candidate.get("tag") if tag == "SDPA_BWD": return candidate - if tag == "SDPA_FP8_BWD": + if tag in ("SDPA_FP8_BWD", "SDPA_MXFP8_BWD"): raise NotImplementedError("SDPA FP8 backward repro is not yet implemented") if payload.get("nodes"): return payload["nodes"][0] diff --git a/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_bwd.py b/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_bwd.py index e2d1a292..fc005785 100644 --- a/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_bwd.py +++ b/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_bwd.py @@ -8,17 +8,16 @@ def _find_node(payload: dict) -> dict: - node = utils.node_by_tag(payload, "SDPA_FP8_BWD") + node = utils.node_by_tag(payload, "SDPA_MXFP8_BWD", "SDPA_FP8_BWD") if node is None: - raise ValueError("SDPA FP8 backward node not found in log") + raise ValueError("SDPA FP8/MXFP8 backward node not found in log") return node def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: """Build FP8 backward test configuration from JSON payload.""" node = _find_node(payload) - if utils.is_mxfp8_payload(payload, node): - raise NotImplementedError("MXFP8 repro is not yet implemented") + is_mxfp8 = utils.is_mxfp8_payload(payload, node) tensors = payload.get("tensors", {}) node_name = node.get("name") @@ -96,10 +95,12 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict: cfg["is_ragged"] = is_ragged cfg["is_dropout"] = dropout_prob > 0.0 cfg["is_determin"] = bool(node.get("is_deterministic_algorithm", False)) - cfg["is_mxfp8"] = False + cfg["is_mxfp8"] = is_mxfp8 cfg["with_score_max"] = "Max" in outputs cfg["with_score_sum_exp"] = "Sum_exp" in outputs cfg["with_sink_token"] = "SINK_TOKEN" in inputs or "DSINK_TOKEN" in outputs + if "rescale_threshold" in node: + cfg["rescale_threshold"] = utils.parse_hex_float(node.get("rescale_threshold")) left_bound = utils.parse_optional_int(node.get("left_bound")) right_bound = utils.parse_optional_int(node.get("right_bound")) diff --git a/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_fwd.py b/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_fwd.py index 63e0c4a4..52dfa1ab 100644 --- a/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_fwd.py +++ b/tools/cudnn_repro/cudnn_repro/stage1_annotate_sdpa_fp8_fwd.py @@ -8,17 +8,16 @@ def _find_node(payload: dict) -> dict: - node = utils.node_by_tag(payload, "SDPA_FP8_FWD") + 
node = utils.node_by_tag(payload, "SDPA_MXFP8_FWD", "SDPA_FP8_FWD")
     if node is None:
-        raise ValueError("SDPA FP8 forward node not found in log")
+        raise ValueError("SDPA FP8/MXFP8 forward node not found in log")
     return node


 def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict:
     """Build FP8 forward test configuration from JSON payload."""
     node = _find_node(payload)
-    if utils.is_mxfp8_payload(payload, node):
-        raise NotImplementedError("MXFP8 repro is not yet implemented")
+    is_mxfp8 = utils.is_mxfp8_payload(payload, node)

     tensors = payload.get("tensors", {})
     node_name = node.get("name")
@@ -96,10 +95,14 @@ def build_cfg(raw_line: str, payload: dict, seed: Optional[int] = None) -> dict:
     cfg["is_ragged"] = is_ragged
     cfg["is_dropout"] = dropout_prob > 0.0
     cfg["is_determin"] = None
-    cfg["is_mxfp8"] = False
+    cfg["is_mxfp8"] = is_mxfp8
     cfg["with_score_max"] = "Max" in outputs
     cfg["with_score_sum_exp"] = "Sum_exp" in outputs
     cfg["with_sink_token"] = "SINK_TOKEN" in inputs
+    if is_mxfp8:
+        cfg["with_unfuse_fma"] = bool(node.get("unfuse_fma", False))
+    if "rescale_threshold" in node:
+        cfg["rescale_threshold"] = utils.parse_hex_float(node.get("rescale_threshold"))

     left_bound = utils.parse_optional_int(node.get("left_bound"))
     right_bound = utils.parse_optional_int(node.get("right_bound"))
diff --git a/tools/cudnn_repro/cudnn_repro/utils.py b/tools/cudnn_repro/cudnn_repro/utils.py
index 7e8b0e9e..32b61443 100644
--- a/tools/cudnn_repro/cudnn_repro/utils.py
+++ b/tools/cudnn_repro/cudnn_repro/utils.py
@@ -28,7 +28,7 @@ def parse_hex_float(value: Any) -> Optional[float]:
         hex_str = hex_str[2:]
     if len(hex_str) == 8:
         try:
-            return struct.unpack("<f", bytes.fromhex(hex_str))[0]
+            return struct.unpack(">f", bytes.fromhex(hex_str))[0]
         except (ValueError, struct.error):
             pass
     try:
@@ -206,7 +206,13 @@ def infer_block_size(page_table_entry: Optional[dict], seq_len_kv: list[int], k_


 def is_mxfp8_payload(payload: dict, node: dict) -> bool:
-    """Detect MXFP8 payloads from scale-factor tensor metadata."""
+    """Detect MXFP8 payloads from tags, explicit flags, or scale-factor metadata."""
+    if (node.get("tag") or "").upper().startswith("SDPA_MXFP8_"):
+        return True
+
+    if node.get("is_mxfp8") is True:
+        return True
+
     tensors = payload.get("tensors", {})
     node_name = node.get("name")
     for label, hint in node.get("inputs", {}).items():
diff --git a/tools/cudnn_repro/tests/test_cudnn_repro.py b/tools/cudnn_repro/tests/test_cudnn_repro.py
deleted file mode 100644
index 2c9a2b7e..00000000
--- a/tools/cudnn_repro/tests/test_cudnn_repro.py
+++ /dev/null
@@ -1,189 +0,0 @@
-import os
-import shlex
-import subprocess
-import sys
-from pathlib import Path
-
-import pytest
-
-REPO_ROOT = Path(__file__).resolve().parents[3]
-
-
-def _run(cmd, env):
-    proc = subprocess.run(
-        cmd,
-        cwd=REPO_ROOT,
-        env=env,
-        capture_output=True,
-        text=True,
-    )
-    if proc.returncode != 0:
-        raise AssertionError(f"Command failed: {' '.join(cmd)}\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}")
-    return proc
-
-
-def _last_payload(log_path):
-    import cudnn_repro as repro
-
-    lines = log_path.read_text().splitlines()
-    entries = list(repro._iter_context_entries(lines))
-    assert entries, f"No context entries found in {log_path}"
-    return entries[-1]
-
-
-def _target_tests():
-    raw = os.environ.get("CUDNN_REPRO_TARGETS")
-    if raw is None:
-        raw = ",".join(
-            [
-                *(f"test/python/test_mhas_v2.py::test_sdpa_random_fwd_L0[test{i}]" for i in range(1, 11)),
-                *(f"test/python/test_mhas_v2.py::test_sdpa_random_fwd_ragged_L0[test{i}]" for i in range(1, 11)),
-            ]
-        )
-    return
[item.strip() for item in raw.split(",") if item.strip()] - - -def _target_test_id(target): - return target.split("::", 1)[-1] - - -def _assert_reproducer_json_matches_target(tmp_path, target): - torch = pytest.importorskip("torch") - if not torch.cuda.is_available(): - pytest.skip("CUDA is required to generate SDPA logs") - - env_base = os.environ.copy() - env_base["CUDNN_FRONTEND_LOG_INFO"] = "1" - log_a = tmp_path / "initial.log" - log_b = tmp_path / "repro.log" - - cmd_test = [ - sys.executable, - "-m", - "pytest", - "-vv", - "-s", - "-rA", - target, - ] - env_first = env_base.copy() - env_first["CUDNN_FRONTEND_LOG_FILE"] = str(log_a) - _run(cmd_test, env_first) - - import cudnn_repro as repro - - raw_line, payload = _last_payload(log_a) - cfg = repro._build_cfg(raw_line, payload) - - repro_cmd = shlex.split(repro._build_command(cfg)) - env_second = env_base.copy() - env_second["CUDNN_FRONTEND_LOG_FILE"] = str(log_b) - _run(repro_cmd, env_second) - - _, repro_payload = _last_payload(log_b) - assert payload == repro_payload, f"Payload mismatch for target {target}" - - -@pytest.mark.parametrize("target", _target_tests(), ids=_target_test_id) -def test_reproducer_json_matches(tmp_path, target): - _assert_reproducer_json_matches_target(tmp_path, target) - - -def test_build_cfg_maps_causal_without_explicit_right_bound(): - import cudnn_repro as repro - - payload = { - "context": {"io_data_type": "FLOAT16"}, - "nodes": [ - { - "tag": "SDPA_FWD", - "name": "sdpa_fwd", - "inputs": {"Q": 1, "K": 2, "V": 3}, - "outputs": {"O": 4}, - "diagonal_alignment": "TOP_LEFT", - "causal_mask": True, - "left_bound": None, - "right_bound": None, - } - ], - "tensors": { - "1": {"uid": 1, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, - "2": {"uid": 2, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, - "3": {"uid": 3, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, - "4": {"uid": 4, "dim": [2, 4, 128, 64], "stride": [32768, 8192, 64, 1]}, - }, - } - - cfg = repro._build_cfg("{}", payload) - assert cfg["left_bound"] is None - assert cfg["right_bound"] == 0 - assert cfg["diag_align"] == 0 - - -def test_build_cfg_preserves_logged_tensor_layout(): - import cudnn_repro as repro - - payload = { - "context": {"io_data_type": "FLOAT16"}, - "nodes": [ - { - "tag": "SDPA_FWD", - "name": "sdpa_fwd", - "inputs": {"Q": 1, "K": 2, "V": 3}, - "outputs": {"O": 4}, - "diagonal_alignment": "TOP_LEFT", - "left_bound": None, - "right_bound": None, - } - ], - # BSHD shape/stride: (b, s, h, d) - "tensors": { - "1": {"uid": 1, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, - "2": {"uid": 2, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, - "3": {"uid": 3, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, - "4": {"uid": 4, "dim": [2, 128, 4, 64], "stride": [32768, 64, 8192, 1]}, - }, - } - - cfg = repro._build_cfg("{}", payload) - assert cfg["shape_q"] == (2, 128, 4, 64) - assert cfg["stride_q"] == (32768, 64, 8192, 1) - assert cfg["h_q"] == 128 - assert cfg["s_q"] == 4 - assert cfg["left_bound"] is None - assert cfg["right_bound"] is None - - -def test_build_command_normalizes_enum_fields(): - import cudnn_repro as repro - - cfg = { - "data_type": "torch.float16", - "rng_data_seed": 123, - "batches": 1, - "h_q": 2, - "h_k": 2, - "h_v": 2, - "s_q": 16, - "s_kv": 16, - "d_qk": 64, - "d_v": 64, - "shape_q": (1, 2, 16, 64), - "stride_q": (2048, 1024, 64, 1), - "shape_k": (1, 2, 16, 64), - "stride_k": (2048, 1024, 64, 1), - "shape_v": (1, 2, 16, 64), - "stride_v": (2048, 1024, 64, 1), 
- "shape_o": (1, 2, 16, 64), - "stride_o": (2048, 1024, 64, 1), - "seq_len_q": [], - "seq_len_kv": [], - "left_bound": None, - "right_bound": 0, - "diag_align": 0, - "implementation": "AUTO", - } - - command = repro._build_command(cfg) - assert "cudnn.diagonal_alignment.TOP_LEFT" in command - assert "cudnn.attention_implementation.AUTO" in command diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py b/tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py index 3ec6a5db..c09743de 100644 --- a/tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py +++ b/tools/cudnn_repro/tests/test_cudnn_repro_closed_loop.py @@ -1,3 +1,4 @@ +import json import os import shlex import subprocess @@ -31,6 +32,49 @@ def _last_payload(log_path): return entries[-1] +def _normalize_tensor_entry(entry): + normalized = {} + for key in ( + "data_type", + "dim", + "stride", + "is_virtual", + "is_pass_by_value", + "pass_by_value", + "reordering_type", + "ragged_offset_uid", + ): + if key in entry: + normalized[key] = entry[key] + return normalized + + +def _normalize_payload(payload): + tensors = payload.get("tensors", {}) + + def resolve(uid): + return _normalize_tensor_entry(tensors[str(uid)]) + + normalized = { + "context": payload.get("context"), + "nodes": [], + "tensors": sorted( + json.dumps(_normalize_tensor_entry(entry), sort_keys=True) for entry in tensors.values() + ), + } + for node in payload.get("nodes", []): + normalized_node = {} + for key, value in node.items(): + if key == "inputs": + normalized_node[key] = {label: resolve(uid) for label, uid in value.items()} + elif key == "outputs": + normalized_node[key] = {label: resolve(uid) for label, uid in value.items()} + else: + normalized_node[key] = value + normalized["nodes"].append(normalized_node) + return normalized + + def _target_tests(): raw = os.environ.get("CUDNN_REPRO_TARGETS") if raw is None: @@ -72,17 +116,20 @@ def _assert_reproducer_json_matches_target(tmp_path, target): _run(cmd_test, env_first) import cudnn_repro as repro + import cudnn_repro.routing as routing raw_line, payload = _last_payload(log_a) - cfg = repro._build_cfg(raw_line, payload) - - repro_cmd = shlex.split(repro._build_command(cfg, payload)) + stage1, stage2 = routing.select_stage_modules(payload) + stage1_json = stage1.extract_and_annotate(raw_line, payload, log_a.read_text()) + seed = stage1_json.get("repro_metadata", {}).get("rng_data_seed") + cfg = stage1.build_cfg(raw_line, stage1_json, seed) + repro_cmd = shlex.split(stage2.build_command(cfg)) env_second = env_base.copy() env_second["CUDNN_FRONTEND_LOG_FILE"] = str(log_b) _run(repro_cmd, env_second) _, repro_payload = _last_payload(log_b) - assert payload == repro_payload, f"Payload mismatch for target {target}" + assert _normalize_payload(payload) == _normalize_payload(repro_payload), f"Payload mismatch for target {target}" @pytest.mark.parametrize("target", _target_tests(), ids=_target_test_id) diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_fp8.py b/tools/cudnn_repro/tests/test_cudnn_repro_fp8.py index 80bb882b..2d107f20 100644 --- a/tools/cudnn_repro/tests/test_cudnn_repro_fp8.py +++ b/tools/cudnn_repro/tests/test_cudnn_repro_fp8.py @@ -8,6 +8,8 @@ def _fp8_fwd_payload(*, tag="SDPA_FP8_FWD", ragged=False, paged=False, output_dtype="HALF", mxfp8=False): + if mxfp8 and tag == "SDPA_FP8_FWD": + tag = "SDPA_MXFP8_FWD" inputs = { "Q": 1, "K": 2, @@ -59,6 +61,8 @@ def _fp8_fwd_payload(*, tag="SDPA_FP8_FWD", ragged=False, paged=False, output_dt { "tag": tag, "name": "sdpa_fp8_fwd", + "is_mxfp8": 
mxfp8, + "unfuse_fma": mxfp8, "inputs": inputs, "outputs": outputs, "generate_stats": True, @@ -102,8 +106,9 @@ def _fp8_bwd_payload(*, ragged=False, output_dtype="HALF", mxfp8=False): "context": {"io_data_type": "FP8_E4M3"}, "nodes": [ { - "tag": "SDPA_FP8_BWD", + "tag": "SDPA_MXFP8_BWD" if mxfp8 else "SDPA_FP8_BWD", "name": "sdpa_fp8_bwd", + "is_mxfp8": mxfp8, "inputs": inputs, "outputs": {"dQ": 21, "dK": 22, "dV": 23, "Amax_dQ": 24, "Amax_dK": 25, "Amax_dV": 26, "Amax_d": 27}, "padding_mask": ragged, @@ -156,6 +161,8 @@ def test_routing_distinguishes_fp8_and_non_fp8_tags(): assert routing.detect_operation_key({"nodes": [{"tag": "SDPA_BWD"}]}) == "sdpa_bwd" assert routing.detect_operation_key({"nodes": [{"tag": "SDPA_FP8_FWD"}]}) == "sdpa_fp8_fwd" assert routing.detect_operation_key({"nodes": [{"tag": "SDPA_FP8_BWD"}]}) == "sdpa_fp8_bwd" + assert routing.detect_operation_key({"nodes": [{"tag": "SDPA_MXFP8_FWD"}]}) == "sdpa_fp8_fwd" + assert routing.detect_operation_key({"nodes": [{"tag": "SDPA_MXFP8_BWD"}]}) == "sdpa_fp8_bwd" def test_build_fp8_fwd_cfg_extracts_output_type_and_stats(): @@ -183,9 +190,21 @@ def test_build_fp8_fwd_cfg_infers_paged_block_size(): assert cfg["stride_v"] is None -def test_build_fp8_fwd_cfg_rejects_mxfp8(): - with pytest.raises(NotImplementedError, match="MXFP8"): - stage1_fp8_fwd.build_cfg("{}", _fp8_fwd_payload(mxfp8=True), seed=123) +def test_build_mxfp8_fwd_cfg_extracts_mxfp8_fields(): + payload = _fp8_fwd_payload(output_dtype="BFLOAT16", mxfp8=True) + payload["nodes"][0]["rescale_threshold"] = 2.0 + cfg = stage1_fp8_fwd.build_cfg("{}", payload, seed=123) + assert cfg["is_mxfp8"] is True + assert cfg["output_type"] == "torch.bfloat16" + assert cfg["with_unfuse_fma"] is True + assert cfg["rescale_threshold"] == 2.0 + + +def test_build_mxfp8_fwd_cfg_uses_mxfp8_tag_without_flag(): + payload = _fp8_fwd_payload(tag="SDPA_MXFP8_FWD", mxfp8=False) + payload["nodes"][0].pop("is_mxfp8") + cfg = stage1_fp8_fwd.build_cfg("{}", payload, seed=123) + assert cfg["is_mxfp8"] is True def test_build_fp8_bwd_cfg_extracts_output_type_and_determinism(): @@ -206,9 +225,21 @@ def test_build_fp8_bwd_cfg_extracts_output_type_and_determinism(): assert cfg["seq_len_kv"] == [15, 11] -def test_build_fp8_bwd_cfg_rejects_mxfp8(): - with pytest.raises(NotImplementedError, match="MXFP8"): - stage1_fp8_bwd.build_cfg("{}", _fp8_bwd_payload(mxfp8=True), seed=123) +def test_build_mxfp8_bwd_cfg_extracts_mxfp8_fields(): + payload = _fp8_bwd_payload(output_dtype="BFLOAT16", mxfp8=True) + payload["nodes"][0]["rescale_threshold"] = 0.0 + cfg = stage1_fp8_bwd.build_cfg("{}", payload, seed=123) + assert cfg["is_mxfp8"] is True + assert cfg["output_type"] == "torch.bfloat16" + assert cfg["rescale_threshold"] == 0.0 + + +def test_build_mxfp8_bwd_cfg_uses_mxfp8_tag_without_flag(): + payload = _fp8_bwd_payload(mxfp8=False) + payload["nodes"][0]["tag"] = "SDPA_MXFP8_BWD" + payload["nodes"][0].pop("is_mxfp8") + cfg = stage1_fp8_bwd.build_cfg("{}", payload, seed=123) + assert cfg["is_mxfp8"] is True def test_build_fp8_forward_command_preserves_fp8_fields(): @@ -227,3 +258,13 @@ def test_build_fp8_backward_command_preserves_bwd_fields(): assert "torch.float8_e4m3fn" in command assert "'is_mxfp8': False" in command assert "cudnn.diagonal_alignment.BOTTOM_RIGHT" in command + + +def test_build_mxfp8_forward_command_preserves_mxfp8_fields(): + payload = _fp8_fwd_payload(output_dtype="BFLOAT16", mxfp8=True) + payload["nodes"][0]["rescale_threshold"] = 4.0 + command = 
stage2_fp8_fwd.build_command(stage1_fp8_fwd.build_cfg("{}", payload, seed=7)) + assert "'is_mxfp8': True" in command + assert "'with_unfuse_fma': True" in command + assert "'rescale_threshold': 4.0" in command + assert "torch.bfloat16" in command diff --git a/tools/cudnn_repro/tests/test_cudnn_repro_mxfp8_closed_loop.py b/tools/cudnn_repro/tests/test_cudnn_repro_mxfp8_closed_loop.py new file mode 100644 index 00000000..1312c7aa --- /dev/null +++ b/tools/cudnn_repro/tests/test_cudnn_repro_mxfp8_closed_loop.py @@ -0,0 +1,23 @@ +import pytest +from looseversion import LooseVersion + +from .test_cudnn_repro_closed_loop import _assert_reproducer_json_matches_target + + +@pytest.mark.parametrize( + "target", + [ + "test/python/test_mhas_v2.py::test_sdpa_mxfp8_fwd_L0[test1]", + "test/python/test_mhas_v2.py::test_sdpa_mxfp8_bwd_L0[test1]", + ], +) +def test_mxfp8_reproducer_json_matches(tmp_path, target): + cudnn = pytest.importorskip("cudnn") + torch = pytest.importorskip("torch") + + if LooseVersion(cudnn.backend_version_string()) < "9.21.0": + pytest.skip("MXFP8 repro requires cuDNN 9.21.0 or higher") + if torch.cuda.get_device_capability()[0] < 10: + pytest.skip("MXFP8 repro requires Blackwell or higher") + + _assert_reproducer_json_matches_target(tmp_path, target)
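For reviewers of the new grouped-GEMM SReLU utilities above, the subtlest bookkeeping is `create_mask`: each expert's M is rounded up to `m_aligned`, and `padded_offsets` holds the inclusive cumulative sums, so `padded_offsets[-1] == valid_m`. Below is a minimal CPU-only sketch of that arithmetic; `padded_offsets_sketch` is an illustrative name, not part of the patch, and the real helper additionally returns the aligned per-group list and an int32 CUDA tensor.

```python
# Minimal sketch of create_mask's offset arithmetic (illustrative only):
# round each expert's M up to m_aligned, accumulate, and record the
# inclusive cumulative sum per expert.
def padded_offsets_sketch(group_m_list, m_aligned=256):
    offsets, valid_m = [], 0
    for group_m in group_m_list:
        valid_m += ((group_m + m_aligned - 1) // m_aligned) * m_aligned
        offsets.append(valid_m)  # inclusive prefix sum up to this expert
    return valid_m, offsets

# Example: three experts with M = 100, 300, 256 and the default 256 alignment.
assert padded_offsets_sketch([100, 300, 256]) == (1024, [256, 768, 1024])
```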