diff --git a/CMakeLists.txt b/CMakeLists.txt
index e4deecbe..61359a0d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.23)
-project(cudnn_frontend VERSION 1.22.1)
+project(cudnn_frontend VERSION 1.23.0)
option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)
diff --git a/README.md b/README.md
index ba010c00..e93898f1 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,17 @@
-# cuDNN FrontEnd(FE)
+# cuDNN Frontend (FE)
-**cuDNN FE** is the modern, open-source entry point to the NVIDIA cuDNN library and high performance open-source kernels. It provides a C++ header-only library and a Python interface to access the powerful cuDNN Graph API and open-source kernels.
+[](https://pypi.org/project/nvidia-cudnn-frontend/)
+[](https://pypi.org/project/nvidia-cudnn-frontend/)
+[](https://pypi.org/project/nvidia-cudnn-frontend/)
+[](LICENSE)
+[](https://nvidia.github.io/cudnn-frontend/)
+
+**cuDNN Frontend** is NVIDIA's modern, open-source entry point to the cuDNN library and a growing collection of high-performance open-source kernels — scaled dot-product attention (**SDPA / Flash Attention**), grouped GEMM fusions for **Mixture-of-Experts (MoE)** training, fused normalization + activation, and more.
+
+It provides a **header-only C++ API** and a **Python interface** (with native PyTorch integration) to the cuDNN Graph API, targeting NVIDIA **Hopper** (H100/H200) and **Blackwell** (B200/GB200/GB300) GPUs across FP16, BF16, FP8, and **MXFP8** precisions.
+
+**Links:** [Documentation](https://docs.nvidia.com/deeplearning/cudnn/frontend/latest/) · [Blog & Deep Dives](https://nvidia.github.io/cudnn-frontend/) · [PyPI](https://pypi.org/project/nvidia-cudnn-frontend/) · [Release Notes](https://github.com/NVIDIA/cudnn-frontend/releases) · [Samples](samples/)
## 🚀 Latest news:
@@ -11,10 +21,15 @@ We are now shipping **OSS kernels**, allowing you to inspect, modify, and contri
* **[GEMM + Amax](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/gemm_amax):** Optimized FP8 matrix multiplication with absolute maximum calculation.
* **[GEMM + SwiGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/gemm_swiglu):** High-performance implementation of the SwiGLU activation fused with GEMM.
+* **[GEMM + sReLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/gemm_srelu):** High-performance implementation of the squared-ReLU activation fused with GEMM (reference definitions of these activations follow the list below).
+* **[GEMM + dsReLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/gemm_dsrelu):** High-performance implementation of the squared-ReLU backward (dsReLU) fused with GEMM.
* **[Grouped GEMM + GLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_glu):** Unified grouped GEMM GLU API supporting dense and discrete MoE weight layouts.
+* **[Grouped GEMM + GLU + Hadamard](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard):** Dense grouped GEMM GLU forward fusion with a fused Hadamard transform and per-expert AMAX reduction.
* **[Grouped GEMM + dGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_dglu):** Unified grouped GEMM dGLU backward API supporting dense and discrete MoE weight layouts.
* **[Grouped GEMM + SwiGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu):** SwiGLU activation fused with Grouped GEMM.
* **[Grouped GEMM + dSwiglu](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_dswiglu):** dSwiglu activation fused with Grouped GEMM.
+* **[Grouped GEMM + sReLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_srelu):** Contiguous grouped squared-ReLU GEMM for MoE workloads.
+* **[Grouped GEMM + dsReLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_dsrelu):** Contiguous grouped GEMM fused with the squared-ReLU backward (dsReLU) for MoE workloads.
* **[Discrete Grouped GEMM + SwiGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu):** Per-expert-pointer SwiGLU grouped GEMM for MoE workloads without weight packing.
* **[Discrete Grouped GEMM + dSwiGLU](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu):** Per-expert-pointer dSwiGLU backward grouped GEMM for MoE workloads without weight packing.
* **[Grouped GEMM + Quant](https://github.com/NVIDIA/cudnn-frontend/tree/main/python/cudnn/grouped_gemm/grouped_gemm_quant):** Legacy dense-only grouped GEMM quant API for MoE FC2/dFC1 workloads.
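+
+For quick reference, a minimal PyTorch sketch of the activations named above (the mathematical definitions only, not the fused kernels; tensor names are illustrative):
+
+```python
+import torch
+import torch.nn.functional as F
+
+def swiglu(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
+    # SwiGLU: SiLU (swish) applied to the gate half, multiplied element-wise
+    # with the linear ("up") half of the preceding GEMM output.
+    return F.silu(gate) * up
+
+def srelu(x: torch.Tensor) -> torch.Tensor:
+    # Squared-ReLU: relu(x) ** 2; its derivative (what the dsReLU kernels
+    # consume in the backward pass) is 2 * relu(x).
+    return F.relu(x).square()
+```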
@@ -30,13 +45,13 @@ We are now shipping **OSS kernels**, allowing you to inspect, modify, and contri
#### Llama 3.1 style Forward and Bprop with causal masking (GB300)
-
+
#### Deepseek v3 style Forward and Bprop with causal masking (GB300)
-
+
## Key Features
@@ -56,8 +71,9 @@ pip install nvidia-cudnn-frontend
```
**Requirements:**
-* Python 3.8+
+* Python 3.9+
* NVIDIA driver and CUDA Toolkit
+* NVIDIA cuDNN (minimum 8.5.0)
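+
+A quick post-install smoke test (this assumes the Python package exposes `cudnn.backend_version()`, as the current samples do; adjust if your installed version differs):
+
+```python
+# Verify that the Python bindings import and can see the cuDNN backend.
+import cudnn
+print(cudnn.backend_version())
+```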
### ⚙️ C++ (Header Only)
@@ -93,9 +109,12 @@ cmake --build . -j16
## Documentation & Examples
-* **Developer Guide:** [Official NVIDIA Documentation](https://docs.nvidia.com/deeplearning/cudnn/frontend/v1.9.0/developer/overview.html)
-* **C++ Samples:** See `samples/cpp` for comprehensive usage examples.
-* **Python Samples:** See `samples/python` for pythonic implementations.
+* **Developer Guide:** [Official NVIDIA Documentation (latest)](https://docs.nvidia.com/deeplearning/cudnn/frontend/latest/)
+* **Blog & Deep Dives:** [nvidia.github.io/cudnn-frontend](https://nvidia.github.io/cudnn-frontend/) — release notes, installation guides, and technical deep-dives (MXFP8 attention, FP8 scale layouts, etc.)
+* **C++ Samples:** See [`samples/cpp`](samples/cpp) for end-to-end examples covering convolution, matmul, SDPA / Flash Attention, normalization, and more.
+* **Python Samples:** See [`samples/python`](samples/python) for Jupyter notebooks and PyTorch integration patterns.
+* **OSS Kernels:** See [`python/cudnn/`](python/cudnn/) for the source of SDPA, grouped GEMM + SwiGLU/GLU, RMSNorm + SiLU, Native Sparse Attention, and other open-sourced kernels.
+* **PyTorch Custom Ops:** See [`python/cudnn/experimental/ops`](python/cudnn/experimental/ops) for `torch.compile`-compatible wrappers around cuDNN kernels.
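+
+For orientation, a condensed matmul graph example in the style of the notebooks under [`samples/python`](samples/python). This is a sketch only; the argument names follow the current samples and may differ slightly across releases, so defer to the samples for the exact API:
+
+```python
+import cudnn
+import torch
+
+# Torch tensors that back the graph's inputs and output (b=1, m=64, k=32, n=16).
+a_gpu = torch.randn(1, 64, 32, dtype=torch.float16, device="cuda")
+b_gpu = torch.randn(1, 32, 16, dtype=torch.float16, device="cuda")
+c_gpu = torch.empty(1, 64, 16, dtype=torch.float16, device="cuda")
+
+# Describe the computation as a cuDNN graph.
+graph = cudnn.pygraph(
+    io_data_type=cudnn.data_type.HALF,
+    intermediate_data_type=cudnn.data_type.FLOAT,
+    compute_data_type=cudnn.data_type.FLOAT,
+)
+a = graph.tensor(name="A", dim=a_gpu.size(), stride=a_gpu.stride(), data_type=cudnn.data_type.HALF)
+b = graph.tensor(name="B", dim=b_gpu.size(), stride=b_gpu.stride(), data_type=cudnn.data_type.HALF)
+c = graph.matmul(name="matmul", A=a, B=b)
+c.set_output(True).set_data_type(cudnn.data_type.HALF)
+
+# Validate the graph, pick an engine via heuristics, and compile execution plans.
+graph.build([cudnn.heur_mode.A])
+
+# Execute with a user-provided workspace buffer.
+workspace = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda")
+graph.execute({a: a_gpu, b: b_gpu, c: c_gpu}, workspace)
+```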
## 🤝 Contributing
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_20260424_101009.csv b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_20260424_101009.csv
new file mode 100644
index 00000000..abbcaaec
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_20260424_101009.csv
@@ -0,0 +1,91 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,fwd,False,50.456,1743.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,bwd,False,212.546,1076.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,bwd,True,210.687,1086.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,fwd,False,111.031,1584.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,bwd,False,421.194,1086.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,bwd,True,411.727,1111.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,fwd,False,35.942,2447.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,bwd,False,123.181,1857.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,bwd,True,122.033,1874.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,fwd,False,75.166,2340.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,bwd,False,236.705,1932.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,bwd,True,234.246,1953.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,fwd,False,37.773,2329.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,bwd,False,149.238,1532.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,bwd,True,152.079,1504.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,fwd,False,81.379,2162.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,bwd,False,293.095,1561.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,bwd,True,289.395,1581.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,fwd,False,13.103,1678.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,bwd,False,50.424,1134.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,bwd,True,50.136,1140.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,fwd,False,24.270,1812.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,bwd,False,105.115,1088.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,bwd,True,102.798,1112.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,fwd,False,8.845,2486.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,bwd,False,29.050,1968.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,bwd,True,29.760,1921.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,fwd,False,16.984,2590.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,bwd,False,57.481,1989.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,bwd,True,57.615,1985.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,fwd,False,9.337,2355.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,bwd,False,36.835,1552.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,bwd,True,36.960,1547.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,fwd,False,17.648,2492.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,bwd,False,69.820,1638.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,bwd,True,71.538,1598.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,fwd,False,3.470,1585.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,bwd,False,11.984,1193.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,bwd,True,13.023,1098.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,fwd,False,6.249,1759.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,bwd,False,25.403,1125.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,bwd,True,25.571,1118.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,fwd,False,2.294,2397.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,bwd,False,7.296,1959.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,bwd,True,7.530,1898.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,fwd,False,4.258,2582.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,bwd,False,13.472,2122.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,bwd,True,13.256,2157.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,fwd,False,2.478,2219.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,bwd,False,9.567,1494.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,bwd,True,9.205,1553.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,fwd,False,4.567,2408.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,bwd,False,17.014,1680.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,bwd,True,16.251,1759.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,fwd,False,0.977,1406.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,bwd,False,3.383,1057.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,bwd,True,3.301,1083.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,fwd,False,1.650,1666.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,bwd,False,6.085,1175.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,bwd,True,6.043,1183.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,fwd,False,0.656,2095.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,bwd,False,2.112,1692.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,bwd,True,2.108,1696.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,fwd,False,1.092,2518.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,bwd,False,3.450,2071.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,bwd,True,3.449,2072.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,fwd,False,0.714,1925.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,bwd,False,2.620,1364.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,bwd,True,2.621,1364.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,fwd,False,1.187,2316.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,bwd,False,4.267,1675.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,bwd,True,4.272,1673.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,fwd,False,0.321,1072.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,bwd,False,1.056,846.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,bwd,True,1.015,881.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.461,1492.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,bwd,False,1.706,1048.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,bwd,True,1.633,1094.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,fwd,False,0.207,1659.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,bwd,False,0.681,1312.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,bwd,True,0.682,1311.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.316,2175.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,bwd,False,0.989,1807.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,bwd,True,0.988,1808.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,fwd,False,0.224,1535.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,bwd,False,0.854,1047.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,bwd,True,0.854,1046.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.341,2018.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,bwd,False,1.188,1504.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,bwd,True,1.206,1482.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask.png
new file mode 100644
index 00000000..b9b11572
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask_det_overhead.png
new file mode 100644
index 00000000..c2aa14f0
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_no_mask_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left.png
new file mode 100644
index 00000000..f6c3b22c
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left_det_overhead.png
new file mode 100644
index 00000000..04365b83
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb200/dsv3_top_left_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_20260424_101002.csv b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_20260424_101002.csv
new file mode 100644
index 00000000..c01eb2bb
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_20260424_101002.csv
@@ -0,0 +1,91 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,fwd,False,44.189,1991.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,bwd,False,177.494,1289.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,32768,32768,128,128,192,128,bwd,True,175.497,1303.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,fwd,False,91.788,1917.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,bwd,False,366.969,1246.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,32768,32768,128,128,192,128,bwd,True,353.235,1295.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,fwd,False,26.146,3364.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,bwd,False,95.057,2406.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,32768,32768,128,128,192,128,bwd,True,94.596,2418.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,fwd,False,52.327,3362.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,bwd,False,185.459,2466.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,32768,32768,128,128,192,128,bwd,True,187.470,2440.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,fwd,False,29.731,2959.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,bwd,False,122.768,1863.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,32768,32768,128,128,192,128,bwd,True,121.830,1877.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,fwd,False,57.427,3063.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,bwd,False,234.862,1948.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,32768,32768,128,128,192,128,bwd,True,237.119,1929.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,fwd,False,11.178,1967.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,bwd,False,43.911,1302.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,16384,16384,128,128,192,128,bwd,True,41.569,1375.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,fwd,False,21.640,2032.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,bwd,False,91.189,1254.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,16384,16384,128,128,192,128,bwd,True,86.044,1329.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,fwd,False,6.653,3305.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,bwd,False,23.257,2459.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,16384,16384,128,128,192,128,bwd,True,23.761,2406.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,fwd,False,12.794,3438.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,bwd,False,44.344,2579.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,16384,16384,128,128,192,128,bwd,True,43.948,2602.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,fwd,False,7.336,2998.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,bwd,False,31.002,1844.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,16384,16384,128,128,192,128,bwd,True,30.575,1870.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,fwd,False,14.110,3117.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,bwd,False,56.977,2007.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,16384,16384,128,128,192,128,bwd,True,56.908,2009.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,fwd,False,3.013,1825.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,bwd,False,11.074,1291.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,8192,8192,128,128,192,128,bwd,True,10.887,1313.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,fwd,False,5.439,2021.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,bwd,False,21.973,1301.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,8192,8192,128,128,192,128,bwd,True,20.598,1388.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,fwd,False,1.758,3128.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,bwd,False,6.188,2310.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,8192,8192,128,128,192,128,bwd,True,6.181,2313.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,fwd,False,3.233,3400.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,bwd,False,11.395,2509.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,8192,8192,128,128,192,128,bwd,True,11.339,2521.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,fwd,False,1.944,2829.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,bwd,False,8.030,1780.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,8192,8192,128,128,192,128,bwd,True,8.027,1781.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,fwd,False,3.538,3108.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,bwd,False,14.104,2027.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,8192,8192,128,128,192,128,bwd,True,14.079,2030.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,fwd,False,0.864,1590.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,bwd,False,3.127,1143.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,4096,4096,128,128,192,128,bwd,True,3.017,1185.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,fwd,False,1.419,1937.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,bwd,False,5.641,1267.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,4096,4096,128,128,192,128,bwd,True,5.603,1276.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,fwd,False,0.493,2788.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,bwd,False,1.803,1982.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,4096,4096,128,128,192,128,bwd,True,1.800,1985.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,fwd,False,0.832,3304.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,bwd,False,2.971,2406.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,4096,4096,128,128,192,128,bwd,True,2.969,2407.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,fwd,False,0.546,2517.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,bwd,False,2.348,1522.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,4096,4096,128,128,192,128,bwd,True,2.351,1520.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,fwd,False,0.912,3015.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,bwd,False,3.720,1921.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,4096,4096,128,128,192,128,bwd,True,3.721,1921.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,fwd,False,0.268,1285.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,bwd,False,0.977,914.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,top_left,2,2048,2048,128,128,192,128,bwd,True,0.935,956.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.388,1770.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,bwd,False,1.600,1117.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,bfloat16,no_mask,2,2048,2048,128,128,192,128,bwd,True,1.553,1150.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,fwd,False,0.150,2290.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,bwd,False,0.591,1511.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,top_left,2,2048,2048,128,128,192,128,bwd,True,0.592,1509.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.225,3058.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,bwd,False,0.851,2099.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,fp8,no_mask,2,2048,2048,128,128,192,128,bwd,True,0.852,2098.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,fwd,False,0.166,2069.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,bwd,False,0.772,1158.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,top_left,2,2048,2048,128,128,192,128,bwd,True,0.772,1158.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,fwd,False,0.249,2758.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,bwd,False,1.073,1665.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+dsv3,dsv3,cudnn,mxfp8,no_mask,2,2048,2048,128,128,192,128,bwd,True,1.074,1663.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask.png
new file mode 100644
index 00000000..dec37f4e
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask_det_overhead.png
new file mode 100644
index 00000000..fea637e8
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_no_mask_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left.png
new file mode 100644
index 00000000..a921070c
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left_det_overhead.png
new file mode 100644
index 00000000..7c38409c
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/dsv3/gb300/dsv3_top_left_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_20260424_100011.csv b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_20260424_100011.csv
new file mode 100644
index 00000000..ec97be3a
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_20260424_100011.csv
@@ -0,0 +1,46 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,fwd,False,2.634,104.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,bwd,False,3.692,186.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,bwd,True,4.691,146.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,fwd,False,1.529,179.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,bwd,False,90.348,8.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,bwd,True,90.348,8.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,fwd,False,1.637,168.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,bwd,False,108.370,6.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,bwd,True,108.368,6.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,fwd,False,0.822,167.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,bwd,False,1.818,188.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,bwd,True,2.354,145.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,fwd,False,0.766,179.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,bwd,False,23.231,15.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,bwd,True,23.234,15.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,fwd,False,0.826,166.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,bwd,False,27.940,12.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,bwd,True,27.910,12.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,fwd,False,0.414,165.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,bwd,False,0.911,187.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,bwd,True,1.185,144.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,fwd,False,0.385,177.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,bwd,False,6.176,28.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,bwd,True,6.177,28.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,fwd,False,0.417,164.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,bwd,False,7.427,23.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,bwd,True,7.426,23.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,fwd,False,0.210,161.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,bwd,False,0.465,182.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,bwd,True,0.602,140.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,fwd,False,0.195,173.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,bwd,False,1.731,49.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,bwd,True,1.732,49.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,fwd,False,0.210,161.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,bwd,False,2.080,41.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,bwd,True,2.081,41.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,fwd,False,0.109,153.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,bwd,False,0.237,176.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,bwd,True,0.311,134.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,fwd,False,0.101,165.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,bwd,False,0.536,78.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,bwd,True,0.536,78.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,fwd,False,0.109,153.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,bwd,False,0.642,65.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,bwd,True,0.642,65.000,0.000,10,128,True,,NVIDIA GB200,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left.png b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left.png
new file mode 100644
index 00000000..27eb0ed0
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left_det_overhead.png
new file mode 100644
index 00000000..a1becf02
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb200/gpt_oss_top_left_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_20260424_100022.csv b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_20260424_100022.csv
new file mode 100644
index 00000000..60792295
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_20260424_100022.csv
@@ -0,0 +1,46 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,fwd,False,2.530,108.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,bwd,False,3.450,199.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,32768,32768,128,128,64,64,bwd,True,4.376,157.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,fwd,False,1.296,212.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,bwd,False,74.763,9.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,32768,32768,128,128,64,64,bwd,True,74.760,9.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,fwd,False,1.345,204.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,bwd,False,96.913,7.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,32768,32768,128,128,64,64,bwd,True,96.937,7.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,fwd,False,0.790,173.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,bwd,False,1.712,200.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,16384,16384,128,128,64,64,bwd,True,2.194,156.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,fwd,False,0.650,211.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,bwd,False,19.354,18.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,16384,16384,128,128,64,64,bwd,True,19.347,18.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,fwd,False,0.676,203.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,bwd,False,24.996,14.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,16384,16384,128,128,64,64,bwd,True,24.995,14.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,fwd,False,0.397,172.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,bwd,False,0.862,198.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,8192,8192,128,128,64,64,bwd,True,1.106,154.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,fwd,False,0.327,208.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,bwd,False,5.189,33.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,8192,8192,128,128,64,64,bwd,True,5.189,33.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,fwd,False,0.342,200.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,bwd,False,6.648,26.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,8192,8192,128,128,64,64,bwd,True,6.649,26.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,fwd,False,0.202,167.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,bwd,False,0.442,191.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,4096,4096,128,128,64,64,bwd,True,0.565,150.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,fwd,False,0.167,203.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,bwd,False,1.480,57.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,4096,4096,128,128,64,64,bwd,True,1.481,57.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,fwd,False,0.175,194.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,bwd,False,1.871,45.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,4096,4096,128,128,64,64,bwd,True,1.871,45.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,fwd,False,0.104,160.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,bwd,False,0.227,184.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,bfloat16,top_left,2,2048,2048,128,128,64,64,bwd,True,0.291,143.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,fwd,False,0.086,193.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,bwd,False,0.467,89.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,fp8,top_left,2,2048,2048,128,128,64,64,bwd,True,0.467,89.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,fwd,False,0.092,182.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,bwd,False,0.581,72.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
+gpt_oss,gpt_oss,cudnn,mxfp8,top_left,2,2048,2048,128,128,64,64,bwd,True,0.580,72.000,0.000,10,128,True,,NVIDIA GB300,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left.png b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left.png
new file mode 100644
index 00000000..dcc9f6f0
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left_det_overhead.png
new file mode 100644
index 00000000..7c88aa95
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gpt_oss/gb300/gpt_oss_top_left_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_20260424_100953.csv b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_20260424_100953.csv
new file mode 100644
index 00000000..30b06508
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_20260424_100953.csv
@@ -0,0 +1,91 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,fwd,False,25.720,1710.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,bwd,False,105.445,1084.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,bwd,True,102.196,1119.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,fwd,False,51.776,1699.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,bwd,False,213.066,1073.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,bwd,True,202.309,1130.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,fwd,False,18.194,2417.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,bwd,False,58.126,1967.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,bwd,True,60.971,1876.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,fwd,False,36.262,2426.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,bwd,False,113.717,2011.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,bwd,True,112.404,2035.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,fwd,False,18.417,2388.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,bwd,False,72.533,1577.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,bwd,True,70.153,1630.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,fwd,False,37.733,2331.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,bwd,False,141.358,1618.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,bwd,True,142.283,1607.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,fwd,False,6.398,1719.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,bwd,False,24.921,1147.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,bwd,True,25.278,1131.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,fwd,False,12.457,1765.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,bwd,False,50.963,1122.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,bwd,True,47.797,1196.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,fwd,False,4.326,2542.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,bwd,False,14.195,2014.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,bwd,True,13.830,2067.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,fwd,False,8.242,2668.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,bwd,False,27.238,2099.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,bwd,True,27.125,2108.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,fwd,False,4.667,2356.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,bwd,False,17.427,1641.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,bwd,True,17.224,1660.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,fwd,False,9.089,2419.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,bwd,False,34.587,1653.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,bwd,True,32.626,1752.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,fwd,False,1.733,1586.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,bwd,False,6.005,1190.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,bwd,True,5.967,1198.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,fwd,False,3.087,1781.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,bwd,False,11.719,1220.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,bwd,True,12.095,1182.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,fwd,False,1.139,2414.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,bwd,False,3.689,1937.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,bwd,True,3.697,1933.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,fwd,False,2.155,2551.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,bwd,False,6.535,2187.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,bwd,True,6.532,2188.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,fwd,False,1.254,2193.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,bwd,False,4.546,1572.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,bwd,True,4.540,1575.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,fwd,False,2.272,2420.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,bwd,False,8.032,1780.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,bwd,True,7.874,1815.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,fwd,False,0.518,1328.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,bwd,False,1.698,1053.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,bwd,True,1.644,1087.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.841,1635.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,bwd,False,3.009,1187.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,bwd,True,3.004,1189.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,fwd,False,0.338,2034.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,bwd,False,1.082,1652.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,bwd,True,1.082,1651.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.575,2389.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,bwd,False,1.723,2074.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,bwd,True,1.750,2042.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,fwd,False,0.363,1891.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,bwd,False,1.331,1343.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,bwd,True,1.337,1336.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.625,2200.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,bwd,False,2.120,1686.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,bwd,True,2.125,1681.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,fwd,False,0.167,1029.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,bwd,False,0.536,834.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,bwd,True,0.526,850.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.237,1448.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.879,1016.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.863,1036.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,fwd,False,0.110,1557.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,bwd,False,0.355,1260.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,bwd,True,0.356,1256.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.165,2077.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.512,1743.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.512,1746.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,fwd,False,0.119,1447.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,bwd,False,0.443,1009.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,bwd,True,0.443,1008.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.178,1927.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.631,1417.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.631,1416.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask.png
new file mode 100644
index 00000000..db3418af
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask_det_overhead.png
new file mode 100644
index 00000000..e495d05f
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_no_mask_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left.png
new file mode 100644
index 00000000..59f87bae
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left_det_overhead.png
new file mode 100644
index 00000000..8a973840
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb200/kimiK26_top_left_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_20260424_100915.csv b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_20260424_100915.csv
new file mode 100644
index 00000000..0966bc87
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_20260424_100915.csv
@@ -0,0 +1,91 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,fwd,False,21.498,2046.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,bwd,False,88.821,1287.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,32768,32768,64,64,192,128,bwd,True,86.626,1320.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,fwd,False,43.698,2013.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,bwd,False,182.491,1253.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,32768,32768,64,64,192,128,bwd,True,175.560,1303.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,fwd,False,12.849,3423.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,bwd,False,45.480,2514.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,32768,32768,64,64,192,128,bwd,True,45.333,2523.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,fwd,False,25.416,3461.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,bwd,False,90.921,2515.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,32768,32768,64,64,192,128,bwd,True,90.997,2513.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,fwd,False,14.100,3119.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,bwd,False,59.247,1930.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,32768,32768,64,64,192,128,bwd,True,58.968,1939.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,fwd,False,28.687,3066.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,bwd,False,115.014,1988.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,32768,32768,64,64,192,128,bwd,True,114.422,1999.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,fwd,False,5.442,2021.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,bwd,False,21.253,1345.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,16384,16384,64,64,192,128,bwd,True,20.838,1372.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,fwd,False,10.590,2077.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,bwd,False,43.165,1325.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,16384,16384,64,64,192,128,bwd,True,40.794,1402.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,fwd,False,3.266,3366.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,bwd,False,11.276,2535.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,16384,16384,64,64,192,128,bwd,True,11.578,2469.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,fwd,False,6.268,3509.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,bwd,False,21.662,2639.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,16384,16384,64,64,192,128,bwd,True,21.863,2615.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,fwd,False,3.613,3044.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,bwd,False,14.766,1936.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,16384,16384,64,64,192,128,bwd,True,14.824,1929.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,fwd,False,6.924,3176.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,bwd,False,27.533,2077.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,16384,16384,64,64,192,128,bwd,True,27.085,2111.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,fwd,False,1.466,1875.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,bwd,False,5.466,1308.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,8192,8192,64,64,192,128,bwd,True,5.271,1356.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,fwd,False,2.632,2089.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,bwd,False,10.654,1342.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,8192,8192,64,64,192,128,bwd,True,10.128,1411.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,fwd,False,0.865,3178.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,bwd,False,3.064,2333.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,8192,8192,64,64,192,128,bwd,True,3.064,2333.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,fwd,False,1.584,3471.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,bwd,False,5.463,2617.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,8192,8192,64,64,192,128,bwd,True,5.481,2608.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,fwd,False,0.957,2872.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,bwd,False,3.991,1791.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,8192,8192,64,64,192,128,bwd,True,3.990,1791.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,fwd,False,1.739,3160.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,bwd,False,6.904,2070.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,8192,8192,64,64,192,128,bwd,True,6.872,2080.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,fwd,False,0.421,1632.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,bwd,False,1.548,1154.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,4096,4096,64,64,192,128,bwd,True,1.507,1186.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.701,1960.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,bwd,False,2.809,1272.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,4096,4096,64,64,192,128,bwd,True,2.802,1275.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,fwd,False,0.244,2822.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,bwd,False,0.904,1976.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,4096,4096,64,64,192,128,bwd,True,0.906,1973.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.418,3289.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,bwd,False,1.472,2428.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,4096,4096,64,64,192,128,bwd,True,1.474,2424.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,fwd,False,0.269,2555.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,bwd,False,1.185,1508.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,4096,4096,64,64,192,128,bwd,True,1.184,1509.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,fwd,False,0.462,2973.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,bwd,False,1.854,1928.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,4096,4096,64,64,192,128,bwd,True,1.854,1928.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,fwd,False,0.136,1265.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,bwd,False,0.492,909.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,top_left,2,2048,2048,64,64,192,128,bwd,True,0.477,937.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.190,1809.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.819,1090.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,bfloat16,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.799,1117.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,fwd,False,0.078,2193.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,bwd,False,0.304,1470.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,top_left,2,2048,2048,64,64,192,128,bwd,True,0.304,1470.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.114,3026.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.443,2019.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,fp8,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.441,2027.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,fwd,False,0.086,1991.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,bwd,False,0.397,1125.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,top_left,2,2048,2048,64,64,192,128,bwd,True,0.396,1127.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,fwd,False,0.127,2701.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,bwd,False,0.557,1603.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+kimiK26,kimiK26,cudnn,mxfp8,no_mask,2,2048,2048,64,64,192,128,bwd,True,0.558,1601.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask.png
new file mode 100644
index 00000000..1c04da20
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask_det_overhead.png
new file mode 100644
index 00000000..6cb653e3
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_no_mask_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left.png
new file mode 100644
index 00000000..0f005ab9
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left_det_overhead.png
new file mode 100644
index 00000000..6de346a4
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/kimiK26/gb300/kimiK26_top_left_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_20260424_100750.csv b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_20260424_100750.csv
new file mode 100644
index 00000000..987e0f0c
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_20260424_100750.csv
@@ -0,0 +1,91 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,fwd,False,21.692,1622.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,bwd,False,71.756,1226.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,bwd,True,82.343,1068.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,fwd,False,43.056,1634.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,bwd,False,139.364,1262.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,bwd,True,159.976,1100.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,fwd,False,16.012,2197.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,bwd,False,52.155,1687.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,bwd,True,51.781,1699.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,fwd,False,31.655,2223.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,bwd,False,96.384,1825.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,bwd,True,97.941,1796.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,fwd,False,17.201,2046.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,bwd,False,61.716,1425.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,bwd,True,61.437,1432.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,fwd,False,33.996,2070.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,bwd,False,118.844,1480.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,bwd,True,118.852,1480.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,fwd,False,5.150,1708.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,bwd,False,16.554,1328.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,bwd,True,19.007,1157.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,fwd,False,10.369,1697.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,bwd,False,32.014,1374.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,bwd,True,38.480,1143.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,fwd,False,4.084,2154.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,bwd,False,13.476,1632.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,bwd,True,13.439,1636.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,fwd,False,7.773,2263.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,bwd,False,23.960,1836.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,bwd,True,23.444,1876.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,fwd,False,4.400,1999.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,bwd,False,15.315,1436.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,bwd,True,15.408,1427.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,fwd,False,8.365,2103.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,bwd,False,28.309,1554.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,bwd,True,28.551,1540.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,fwd,False,1.397,1574.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,bwd,False,4.127,1332.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,bwd,True,4.754,1157.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,fwd,False,2.483,1771.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,bwd,False,7.831,1404.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,bwd,True,9.306,1182.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,fwd,False,1.074,2047.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,bwd,False,3.781,1454.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,bwd,True,3.777,1456.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,fwd,False,1.962,2241.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,bwd,False,6.206,1772.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,bwd,True,6.209,1771.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,fwd,False,1.156,1902.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,bwd,False,4.272,1287.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,bwd,True,4.270,1288.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,fwd,False,2.115,2080.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,bwd,False,7.352,1495.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,bwd,True,7.353,1495.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,fwd,False,0.391,1408.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,bwd,False,1.198,1148.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,bwd,True,1.352,1017.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.687,1599.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,bwd,False,2.037,1350.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,bwd,True,2.339,1175.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,fwd,False,0.305,1805.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,bwd,False,1.167,1178.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,bwd,True,1.164,1181.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.525,2093.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,bwd,False,1.743,1577.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,bwd,True,1.743,1577.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,fwd,False,0.326,1685.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,bwd,False,1.305,1054.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,bwd,True,1.312,1048.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.565,1944.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,bwd,False,2.046,1344.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,bwd,True,2.056,1337.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,fwd,False,0.125,1097.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,bwd,False,0.389,885.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,bwd,True,0.431,797.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.192,1433.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.611,1124.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.686,1001.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,fwd,False,0.097,1420.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,bwd,False,0.414,831.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,bwd,True,0.413,832.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.146,1881.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.566,1214.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.565,1216.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,fwd,False,0.103,1330.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,bwd,False,0.461,745.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,bwd,True,0.461,746.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.157,1754.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.643,1069.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.642,1071.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask.png
new file mode 100644
index 00000000..c1fbe9b6
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask_det_overhead.png
new file mode 100644
index 00000000..f35d16e1
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_no_mask_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left.png
new file mode 100644
index 00000000..caf946b0
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left_det_overhead.png
new file mode 100644
index 00000000..3561cc30
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb200/llama3.1_top_left_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_20260424_100757.csv b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_20260424_100757.csv
new file mode 100644
index 00000000..655b74f1
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_20260424_100757.csv
@@ -0,0 +1,91 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,fwd,False,18.014,1953.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,bwd,False,59.024,1490.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,32768,32768,64,8,128,128,bwd,True,68.488,1284.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,fwd,False,38.317,1836.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,bwd,False,123.272,1427.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,32768,32768,64,8,128,128,bwd,True,142.081,1238.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,fwd,False,11.319,3108.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,bwd,False,41.328,2128.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,32768,32768,64,8,128,128,bwd,True,42.165,2086.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,fwd,False,23.072,3050.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,bwd,False,80.754,2178.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,32768,32768,64,8,128,128,bwd,True,81.705,2153.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,fwd,False,12.632,2785.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,bwd,False,55.340,1590.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,32768,32768,64,8,128,128,bwd,True,55.204,1593.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,fwd,False,25.005,2814.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,bwd,False,106.358,1654.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,32768,32768,64,8,128,128,bwd,True,105.826,1662.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,fwd,False,4.593,1915.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,bwd,False,14.593,1507.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,16384,16384,64,8,128,128,bwd,True,16.462,1336.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,fwd,False,8.734,2014.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,bwd,False,28.638,1536.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,16384,16384,64,8,128,128,bwd,True,32.449,1355.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,fwd,False,2.880,3054.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,bwd,False,10.805,2035.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,16384,16384,64,8,128,128,bwd,True,10.823,2032.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,fwd,False,5.529,3182.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,bwd,False,20.203,2177.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,16384,16384,64,8,128,128,bwd,True,20.296,2167.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,fwd,False,3.225,2727.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,bwd,False,14.072,1563.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,16384,16384,64,8,128,128,bwd,True,14.074,1563.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,fwd,False,6.089,2889.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,bwd,False,26.041,1689.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,16384,16384,64,8,128,128,bwd,True,25.719,1710.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,fwd,False,1.185,1856.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,bwd,False,3.832,1435.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,8192,8192,64,8,128,128,bwd,True,4.298,1279.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,fwd,False,2.181,2017.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,bwd,False,7.181,1531.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,8192,8192,64,8,128,128,bwd,True,7.932,1386.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,fwd,False,0.759,2897.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,bwd,False,3.087,1781.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,8192,8192,64,8,128,128,bwd,True,3.088,1780.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,fwd,False,1.392,3160.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,bwd,False,5.209,2111.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,8192,8192,64,8,128,128,bwd,True,5.198,2115.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,fwd,False,0.856,2570.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,bwd,False,3.917,1404.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,8192,8192,64,8,128,128,bwd,True,3.915,1404.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,fwd,False,1.532,2870.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,bwd,False,6.723,1636.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,8192,8192,64,8,128,128,bwd,True,6.716,1637.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,fwd,False,0.323,1703.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,bwd,False,1.101,1248.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,4096,4096,64,8,128,128,bwd,True,1.220,1127.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.579,1900.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,bwd,False,1.924,1429.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,4096,4096,64,8,128,128,bwd,True,2.195,1252.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,fwd,False,0.217,2532.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,bwd,False,0.985,1396.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,4096,4096,64,8,128,128,bwd,True,0.986,1395.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.368,2988.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,bwd,False,1.495,1839.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,4096,4096,64,8,128,128,bwd,True,1.494,1840.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,fwd,False,0.242,2272.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,bwd,False,1.203,1143.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,4096,4096,64,8,128,128,bwd,True,1.203,1143.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,fwd,False,0.408,2693.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,bwd,False,1.871,1469.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,4096,4096,64,8,128,128,bwd,True,1.871,1469.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,fwd,False,0.100,1373.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,bwd,False,0.356,967.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,top_left,2,2048,2048,64,8,128,128,bwd,True,0.388,885.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.152,1805.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.567,1213.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,bfloat16,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.634,1083.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,fwd,False,0.070,1954.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,bwd,False,0.364,944.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,top_left,2,2048,2048,64,8,128,128,bwd,True,0.365,942.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.101,2728.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.494,1390.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,fp8,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.495,1387.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,fwd,False,0.079,1751.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,bwd,False,0.424,811.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,top_left,2,2048,2048,64,8,128,128,bwd,True,0.424,810.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,fwd,False,0.113,2443.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,bwd,False,0.591,1164.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+llama3.1,llama3.1,cudnn,mxfp8,no_mask,2,2048,2048,64,8,128,128,bwd,True,0.591,1163.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask.png
new file mode 100644
index 00000000..dcad4c50
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask_det_overhead.png
new file mode 100644
index 00000000..28da7dce
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_no_mask_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left.png
new file mode 100644
index 00000000..2746fa0d
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left_det_overhead.png b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left_det_overhead.png
new file mode 100644
index 00000000..ca1b3057
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/llama3.1/gb300/llama3.1_top_left_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_20260424_095758.csv b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_20260424_095758.csv
new file mode 100644
index 00000000..c53bcf74
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_20260424_095758.csv
@@ -0,0 +1,16 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,fwd,False,0.437,1415.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,bwd,False,1.199,1290.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,bwd,True,1.338,1156.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,fwd,False,1.844,1589.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,bwd,False,5.225,1403.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,bwd,True,6.410,1143.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,fwd,False,2.959,1706.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,bwd,False,8.952,1410.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,bwd,True,11.175,1130.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,fwd,False,8.657,1731.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,bwd,False,28.505,1314.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,bwd,True,35.620,1052.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,fwd,False,14.403,1611.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,bwd,False,43.187,1343.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,bwd,True,53.136,1092.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask.png b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask.png
new file mode 100644
index 00000000..8d386e6d
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask_det_overhead.png
new file mode 100644
index 00000000..e31f8f26
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/ltx2/gb200/ltx2_no_mask_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_20260424_095719.csv b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_20260424_095719.csv
new file mode 100644
index 00000000..e4325010
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_20260424_095719.csv
@@ -0,0 +1,16 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,fwd,False,0.363,1704.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,bwd,False,1.089,1420.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,6144,6144,32,32,128,128,bwd,True,1.213,1275.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,fwd,False,1.551,1890.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,bwd,False,4.796,1528.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,13376,13376,32,32,128,128,bwd,True,5.838,1255.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,fwd,False,2.569,1966.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,bwd,False,8.245,1531.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,17556,17556,32,32,128,128,bwd,True,10.058,1255.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,fwd,False,7.623,1965.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,bwd,False,25.013,1497.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,30240,30240,32,32,128,128,bwd,True,30.447,1230.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,fwd,False,12.044,1926.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,bwd,False,38.088,1523.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+ltx2,ltx2,cudnn,bfloat16,no_mask,1,37632,37632,32,32,128,128,bwd,True,42.066,1379.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask.png b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask.png
new file mode 100644
index 00000000..1992f9d0
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask_det_overhead.png
new file mode 100644
index 00000000..c0b0fbd3
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/ltx2/gb300/ltx2_no_mask_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_20260424_095249.csv b/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_20260424_095249.csv
new file mode 100644
index 00000000..cc7f3cff
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_20260424_095249.csv
@@ -0,0 +1,6 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+qwen35,qwen35,cudnn,bfloat16,top_left,1,32768,32768,32,2,256,256,fwd,False,10.148,1734.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+qwen35,qwen35,cudnn,bfloat16,top_left,1,16384,16384,32,2,256,256,fwd,False,2.438,1804.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+qwen35,qwen35,cudnn,bfloat16,top_left,1,8192,8192,32,2,256,256,fwd,False,0.644,1708.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+qwen35,qwen35,cudnn,bfloat16,top_left,1,4096,4096,32,2,256,256,fwd,False,0.191,1441.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
+qwen35,qwen35,cudnn,bfloat16,top_left,1,2048,2048,32,2,256,256,fwd,False,0.061,1126.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_top_left.png b/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_top_left.png
new file mode 100644
index 00000000..37e45c80
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/qwen35/gb200/qwen35_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_20260424_095247.csv b/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_20260424_095247.csv
new file mode 100644
index 00000000..63464f51
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_20260424_095247.csv
@@ -0,0 +1,6 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+qwen35,qwen35,cudnn,bfloat16,top_left,1,32768,32768,32,2,256,256,fwd,False,8.394,2096.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+qwen35,qwen35,cudnn,bfloat16,top_left,1,16384,16384,32,2,256,256,fwd,False,2.132,2063.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+qwen35,qwen35,cudnn,bfloat16,top_left,1,8192,8192,32,2,256,256,fwd,False,0.560,1965.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+qwen35,qwen35,cudnn,bfloat16,top_left,1,4096,4096,32,2,256,256,fwd,False,0.158,1743.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
+qwen35,qwen35,cudnn,bfloat16,top_left,1,2048,2048,32,2,256,256,fwd,False,0.053,1307.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200.000
diff --git a/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_top_left.png b/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_top_left.png
new file mode 100644
index 00000000..bdab2e02
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/qwen35/gb300/qwen35_top_left.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_20260424_095743.csv b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_20260424_095743.csv
new file mode 100644
index 00000000..2e064ca9
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_20260424_095743.csv
@@ -0,0 +1,16 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,fwd,False,0.803,1552.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,bwd,False,2.236,1393.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,bwd,True,2.785,1118.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,fwd,False,3.383,1783.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,bwd,False,10.638,1417.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,bwd,True,13.730,1098.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,fwd,False,13.191,1666.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,bwd,False,41.653,1319.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,bwd,True,52.261,1051.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,fwd,False,28.902,1657.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,bwd,False,93.715,1278.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,bwd,True,117.248,1021.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,fwd,False,76.372,1533.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,bwd,False,233.302,1254.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,bwd,True,287.864,1017.000,0.000,10,,True,,NVIDIA GB200,1.21.1,92200
diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask.png b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask.png
new file mode 100644
index 00000000..a77ce366
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask_det_overhead.png
new file mode 100644
index 00000000..ad2a883b
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/wan22/gb200/wan22_no_mask_det_overhead.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_20260424_095741.csv b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_20260424_095741.csv
new file mode 100644
index 00000000..4b37e548
--- /dev/null
+++ b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_20260424_095741.csv
@@ -0,0 +1,16 @@
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,time_ms,tflops,max_diff,num_iterations,sliding_window_size,success,error_message,gpu_name,cudnn_version,cudnn_backend_version
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,fwd,False,0.678,1836.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,bwd,False,2.139,1456.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,7800,7800,40,40,128,128,bwd,True,2.631,1184.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,fwd,False,3.002,2009.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,bwd,False,9.883,1525.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,17160,17160,40,40,128,128,bwd,True,12.048,1251.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,fwd,False,11.213,1960.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,bwd,False,37.545,1464.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,32760,32760,40,40,128,128,bwd,True,45.146,1217.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,fwd,False,25.686,1865.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,bwd,False,88.194,1358.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,48360,48360,40,40,128,128,bwd,True,102.918,1163.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,fwd,False,64.212,1823.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,bwd,False,214.474,1364.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
+wan22,wan22_a14b,cudnn,bfloat16,no_mask,1,75600,75600,40,40,128,128,bwd,True,258.728,1131.000,0.000,10,,True,,NVIDIA GB300,1.21.1,92200
diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask.png b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask.png
new file mode 100644
index 00000000..80014b77
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask.png differ
diff --git a/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask_det_overhead.png b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask_det_overhead.png
new file mode 100644
index 00000000..44df0932
Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/wan22/gb300/wan22_no_mask_det_overhead.png differ
diff --git a/cmake/cuDNN.cmake b/cmake/cuDNN.cmake
index 61a3cba9..a6998046 100644
--- a/cmake/cuDNN.cmake
+++ b/cmake/cuDNN.cmake
@@ -110,6 +110,7 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9)
find_cudnn_library(cudnn_adv OPTIONAL)
find_cudnn_library(cudnn_engines_precompiled OPTIONAL)
find_cudnn_library(cudnn_heuristic OPTIONAL)
+ find_cudnn_library(cudnn_ext OPTIONAL)
target_link_libraries(
CUDNN::cudnn_all
diff --git a/include/cudnn_frontend/cudnn_interface.h b/include/cudnn_frontend/cudnn_interface.h
index 06121ddb..e20c9599 100644
--- a/include/cudnn_frontend/cudnn_interface.h
+++ b/include/cudnn_frontend/cudnn_interface.h
@@ -69,6 +69,17 @@ create_cudnn_tensor(
tensor_builder.setVectorCountAndDimension(props->get_vector_count(), props->get_vector_dimension());
}
+ // Set compile-time constant value before build (if present)
+ if (props->get_has_compile_time_constant()) {
+ auto const_value = props->get_compile_time_constant();
+ if (const_value.has_value()) {
+ std::visit([&tensor_builder](auto&& val) { tensor_builder.setConstValue(val); }, *const_value);
+
+ CUDNN_FE_LOG_LABEL_ENDL("INFO: Compile-time constant value set for tensor '" << props->get_name()
+ << "'");
+ }
+ }
+
if (auto ragged_offset_props = props->get_ragged_offset()) {
CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(ragged_offset_props, tensors, potential_uid, used_uids));
tensor_builder.setRaggedOffset(tensors.at(ragged_offset_props->get_uid()));
diff --git a/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d128_fprop_kernel.h b/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d128_fprop_kernel.h
index 175986cb..ee3664a8 100644
--- a/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d128_fprop_kernel.h
+++ b/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d128_fprop_kernel.h
@@ -69,8 +69,6 @@ inline __device__ void fastDivMod(const FastDivisor_t &d, uint32_t val,
mod = val - div * d.val;
}
-__device__ __inline__ void cfence() {}
-
inline __device__ char *get_smem_loc_epilogue_swizzle_128b(
char *smem_addr, int local_block_id, int tid, int local_row, int column,
size_t element_size, int block_size, int row_per_tile) {
@@ -1765,7 +1763,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
#pragma unroll
for (int i = 0; i < BMM1_TILE_N; i += 2) {
- cfence();
if (i - kConvertPipeCount == 32) {
sttm_32dp32bit_x16(tmem_fp16_S, &reg_12_0[0]);
}
@@ -1780,12 +1777,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
&reg_12_0[((i - kConvertPipeCount) / 2) % 32],
reinterpret_cast(&reg_8_0[i - kConvertPipeCount]));
}
- cfence();
reinterpret_cast(reg_8_0[i + 0]) =
exp2f(reinterpret_cast(reg_8_0[i + 0]));
- cfence();
if (i + kFmaPipeCount < BMM1_TILE_N) {
float2 in = make_float2(
@@ -1841,7 +1836,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
named_barrier_arrive(SOFTMAX_BARRIER + 1, 256);
named_barrier_wait(SOFTMAX_BARRIER, 256);
}
- cfence();
named_barrier_wait(SOFTMAX_BARRIER + 2 + softmax_gid, 128);
}
bmm_mbar_state ^= 1;
@@ -1947,7 +1941,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
#pragma unroll
for (int i = 0; i < BMM1_TILE_N; i += 2) {
- cfence();
if (i - kConvertPipeCount == 32) {
sttm_32dp32bit_x16(tmem_fp16_S, &reg_12_0[0]);
}
@@ -1962,12 +1955,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
&reg_12_0[((i - kConvertPipeCount) / 2) % 32],
reinterpret_cast(&reg_8_0[i - kConvertPipeCount]));
}
- cfence();
reinterpret_cast(reg_8_0[i + 0]) =
exp2f(reinterpret_cast(reg_8_0[i + 0]));
- cfence();
if (i + kFmaPipeCount < BMM1_TILE_N) {
float2 in = make_float2(
@@ -2023,7 +2014,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
named_barrier_arrive(SOFTMAX_BARRIER + 1, 256);
named_barrier_wait(SOFTMAX_BARRIER, 256);
}
- cfence();
named_barrier_wait(SOFTMAX_BARRIER + 2 + softmax_gid, 128);
}
bmm_mbar_state ^= 1;
@@ -2232,12 +2222,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
sttm_step * 8,
&fp32_O[8 * sttm_step]);
}
- cfence();
}
}
fence_view_async_tmem_store();
arrive_barrier(cast_smem_ptr_to_uint(bmm_ready_mbar));
- cfence();
}
stat_mbar_state ^= 1;
bmm_mbar_state ^= 1;
@@ -2372,7 +2360,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
sts_128(cast_smem_ptr_to_uint(smem_loc),
reinterpret_cast(&reg_O[i * 4]));
}
- cfence();
uint64_t *tma_o_full_mbar =
sub_tile_id == 0 ? &(shared_storage.tma_o_0_full_mbar[block])
@@ -2380,7 +2367,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
fence_view_async_shared();
arrive_barrier(cast_smem_ptr_to_uint(tma_o_full_mbar));
}
- cfence();
}
stat_mbar_state ^= 1;
bmm_mbar_state ^= 1;
diff --git a/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d64_fprop_kernel.h b/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d64_fprop_kernel.h
index c1aa5278..33b9ef2d 100644
--- a/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d64_fprop_kernel.h
+++ b/include/cudnn_frontend/generated/sdpa/sm100/prefill/full_seqlens/d64_fprop_kernel.h
@@ -69,8 +69,6 @@ inline __device__ void fastDivMod(const FastDivisor_t &d, uint32_t val,
mod = val - div * d.val;
}
-__device__ __inline__ void cfence() {}
-
inline __device__ char *get_smem_loc_epilogue_swizzle_128b(
char *smem_addr, int local_block_id, int tid, int local_row, int column,
size_t element_size, int block_size, int row_per_tile) {
@@ -1766,7 +1764,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
#pragma unroll
for (int i = 0; i < BMM1_TILE_N; i += 2) {
- cfence();
if (i - kConvertPipeCount == 32) {
sttm_32dp32bit_x16(tmem_fp16_S, &reg_12_0[0]);
}
@@ -1781,12 +1778,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
&reg_12_0[((i - kConvertPipeCount) / 2) % 32],
reinterpret_cast(&reg_8_0[i - kConvertPipeCount]));
}
- cfence();
reinterpret_cast(reg_8_0[i + 0]) =
exp2f(reinterpret_cast(reg_8_0[i + 0]));
- cfence();
if (i + kFmaPipeCount < BMM1_TILE_N) {
float2 in = make_float2(
@@ -1842,7 +1837,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
named_barrier_arrive(SOFTMAX_BARRIER + 1, 256);
named_barrier_wait(SOFTMAX_BARRIER, 256);
}
- cfence();
named_barrier_wait(SOFTMAX_BARRIER + 2 + softmax_gid, 128);
}
bmm_mbar_state ^= 1;
@@ -1948,7 +1942,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
#pragma unroll
for (int i = 0; i < BMM1_TILE_N; i += 2) {
- cfence();
if (i - kConvertPipeCount == 32) {
sttm_32dp32bit_x16(tmem_fp16_S, &reg_12_0[0]);
}
@@ -1963,12 +1956,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
&reg_12_0[((i - kConvertPipeCount) / 2) % 32],
reinterpret_cast(&reg_8_0[i - kConvertPipeCount]));
}
- cfence();
reinterpret_cast(reg_8_0[i + 0]) =
exp2f(reinterpret_cast(reg_8_0[i + 0]));
- cfence();
if (i + kFmaPipeCount < BMM1_TILE_N) {
float2 in = make_float2(
@@ -2024,7 +2015,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
named_barrier_arrive(SOFTMAX_BARRIER + 1, 256);
named_barrier_wait(SOFTMAX_BARRIER, 256);
}
- cfence();
named_barrier_wait(SOFTMAX_BARRIER + 2 + softmax_gid, 128);
}
bmm_mbar_state ^= 1;
@@ -2234,12 +2224,10 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
sttm_step * 8,
&fp32_O[8 * sttm_step]);
}
- cfence();
}
}
fence_view_async_tmem_store();
arrive_barrier(cast_smem_ptr_to_uint(bmm_ready_mbar));
- cfence();
}
stat_mbar_state ^= 1;
bmm_mbar_state ^= 1;
@@ -2374,7 +2362,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
sts_128(cast_smem_ptr_to_uint(smem_loc),
reinterpret_cast(&reg_O[i * 4]));
}
- cfence();
uint64_t *tma_o_full_mbar =
sub_tile_id == 0 ? &(shared_storage.tma_o_0_full_mbar[block])
@@ -2382,7 +2369,6 @@ __launch_bounds__(512, 1) void cudnn_generated_oss_sdpa_sm100_flash_fprop_f16_kn
fence_view_async_shared();
arrive_barrier(cast_smem_ptr_to_uint(tma_o_full_mbar));
}
- cfence();
}
stat_mbar_state ^= 1;
bmm_mbar_state ^= 1;
diff --git a/include/cudnn_frontend/graph_interface.h b/include/cudnn_frontend/graph_interface.h
index 1a0e3d7e..8ce19101 100644
--- a/include/cudnn_frontend/graph_interface.h
+++ b/include/cudnn_frontend/graph_interface.h
@@ -22,6 +22,8 @@
#include "node/resample.h"
#include "node/reshape.h"
#include "node/slice.h"
+#include "node/transpose.h"
+// #include "node/scaled_dot_product_attention.h"
#include "node/scaled_dot_product_flash_attention.h"
#include "node/sdpa_fp8_bwd.h"
#include "node/block_scale_quantize.h"
@@ -194,6 +196,8 @@ class Graph : public ICudnn, public INode {
tensor_to_pointer_map.emplace(uid, int64_t_value_ptr);
} else if (float *float_value_ptr = std::get_if<float>(&value)) {
tensor_to_pointer_map.emplace(uid, float_value_ptr);
+ } else if (double *double_value_ptr = std::get_if<double>(&value)) {
+ tensor_to_pointer_map.emplace(uid, double_value_ptr);
} else {
RETURN_CUDNN_FRONTEND_ERROR_IF(
true, error_code_t::INVALID_VARIANT_PACK, "Unexpected type for pass by value tensor.");
@@ -1450,12 +1454,30 @@ class Graph : public ICudnn, public INode {
std::unordered_map tensor_to_pass_by_value;
CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value));
- j["pass_by_values"] = tensor_to_pass_by_value;
+
+ // Convert pass_by_values to JSON (unordered_map with numeric keys needs manual conversion)
+ json pass_by_values_json = json::object();
+ for (const auto &[uid, variant_value] : tensor_to_pass_by_value) {
+ json variant_json;
+ variant_json = variant_value;
+ pass_by_values_json[std::to_string(uid)] = variant_json;
+ }
+ j["pass_by_values"] = pass_by_values_json;
std::unordered_map>> workspace_modifications;
int64_t workspace_offset = 0;
CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset));
- j["workspace_modifications"] = workspace_modifications;
+
+ // Convert workspace_modifications to JSON (nlohmann::json doesn't support std::tuple directly)
+ json workspace_modifications_json = json::object();
+ for (const auto &[uid, tuple_value] : workspace_modifications) {
+ json tuple_json = json::array();
+ tuple_json.push_back(std::get<0>(tuple_value));
+ tuple_json.push_back(std::get<1>(tuple_value));
+ tuple_json.push_back(std::get<2>(tuple_value));
+ workspace_modifications_json[std::to_string(uid)] = tuple_json;
+ }
+ j["workspace_modifications"] = workspace_modifications_json;
j["variant_pack_replacements"] = variant_pack_replacements;
@@ -1500,9 +1522,28 @@ class Graph : public ICudnn, public INode {
variant_pack_uids = j["variant_pack_uids"].get>();
- deserialized_pass_by_value = j["pass_by_values"];
+ // Deserialize pass_by_values from JSON
+ if (j.contains("pass_by_values")) {
+ auto pass_by_values_json = j["pass_by_values"];
+ for (auto it = pass_by_values_json.begin(); it != pass_by_values_json.end(); ++it) {
+ uid_t uid = std::stoll(it.key());
+ pass_by_values_t value = it.value().get<pass_by_values_t>();
+ deserialized_pass_by_value[uid] = value;
+ }
+ }
- deserialized_workspace_modifications = j["workspace_modifications"];
+ // Deserialize workspace_modifications from JSON
+ if (j.contains("workspace_modifications")) {
+ auto workspace_modifications_json = j["workspace_modifications"];
+ for (auto it = workspace_modifications_json.begin(); it != workspace_modifications_json.end(); ++it) {
+ uid_t uid = std::stoll(it.key());
+ auto tuple_json = it.value();
+ auto tuple_value = std::make_tuple(tuple_json[0].get(),
+ tuple_json[1].get(),
+ tuple_json[2].get>());
+ deserialized_workspace_modifications[uid] = tuple_value;
+ }
+ }
variant_pack_replacements = j["variant_pack_replacements"];
@@ -1572,6 +1613,25 @@ class Graph : public ICudnn, public INode {
std::shared_ptr
tensor(Tensor_attributes const &tensor);
+ // Overloaded tensor() methods for compile-time constants
+ std::shared_ptr<Tensor_attributes>
+ tensor(float const &scalar, ScalarType scalar_type);
+
+ std::shared_ptr<Tensor_attributes>
+ tensor(half const &scalar, ScalarType scalar_type);
+
+ std::shared_ptr<Tensor_attributes>
+ tensor(nv_bfloat16 const &scalar, ScalarType scalar_type);
+
+ std::shared_ptr<Tensor_attributes>
+ tensor(int32_t const &scalar, ScalarType scalar_type);
+
+ std::shared_ptr<Tensor_attributes>
+ tensor(int64_t const &scalar, ScalarType scalar_type);
+
+ std::shared_ptr<Tensor_attributes>
+ tensor(double const &scalar, ScalarType scalar_type);
+
std::shared_ptr
tensor_like(std::shared_ptr const &tensor, std::string const &name = std::string{});
@@ -1736,6 +1796,8 @@ class Graph : public ICudnn, public INode {
std::shared_ptr slice(std::shared_ptr, Slice_attributes);
+ std::shared_ptr<Tensor_attributes> transpose(std::shared_ptr<Tensor_attributes>, Transpose_attributes);
+
std::array, 2> block_scale_quantize(std::shared_ptr,
Block_scale_quantize_attributes);
@@ -2353,6 +2415,12 @@ class Graph : public ICudnn, public INode {
CHECK_TENSORS(slice_attributes);
FILL_GLOBAL_IO_TENSOR_MAP(slice_attributes);
sub_nodes.emplace_back(std::make_unique(std::move(slice_attributes), context));
+ } else if (tag == "TRANSPOSE") {
+ auto transpose_attributes = j_sub_node.get<Transpose_attributes>();
+ CHECK_TENSORS(transpose_attributes);
+ FILL_GLOBAL_IO_TENSOR_MAP(transpose_attributes);
+ sub_nodes.emplace_back(
+ std::make_unique<TransposeNode>(std::move(transpose_attributes), context));
} else if (tag == "RESAMPLE") {
auto resample_attributes = j_sub_node.get();
CHECK_TENSORS(resample_attributes);
@@ -2686,6 +2754,49 @@ Graph::tensor(Tensor_attributes const &tensor) {
return tensor_ptr;
}
+// Overloaded tensor() methods for compile-time constants
+inline std::shared_ptr<Tensor_attributes>
+Graph::tensor(float const &scalar, ScalarType scalar_type) {
+ auto tensor_ptr = std::make_shared<Tensor_attributes>(scalar, scalar_type);
+ full_graph_inputs.emplace(tensor_ptr);
+ return tensor_ptr;
+}
+
+inline std::shared_ptr<Tensor_attributes>
+Graph::tensor(half const &scalar, ScalarType scalar_type) {
+ auto tensor_ptr = std::make_shared<Tensor_attributes>(scalar, scalar_type);
+ full_graph_inputs.emplace(tensor_ptr);
+ return tensor_ptr;
+}
+
+inline std::shared_ptr<Tensor_attributes>
+Graph::tensor(nv_bfloat16 const &scalar, ScalarType scalar_type) {
+ auto tensor_ptr = std::make_shared<Tensor_attributes>(scalar, scalar_type);
+ full_graph_inputs.emplace(tensor_ptr);
+ return tensor_ptr;
+}
+
+inline std::shared_ptr<Tensor_attributes>
+Graph::tensor(int32_t const &scalar, ScalarType scalar_type) {
+ auto tensor_ptr = std::make_shared<Tensor_attributes>(scalar, scalar_type);
+ full_graph_inputs.emplace(tensor_ptr);
+ return tensor_ptr;
+}
+
+inline std::shared_ptr<Tensor_attributes>
+Graph::tensor(int64_t const &scalar, ScalarType scalar_type) {
+ auto tensor_ptr = std::make_shared<Tensor_attributes>(scalar, scalar_type);
+ full_graph_inputs.emplace(tensor_ptr);
+ return tensor_ptr;
+}
+
+inline std::shared_ptr<Tensor_attributes>
+Graph::tensor(double const &scalar, ScalarType scalar_type) {
+ auto tensor_ptr = std::make_shared<Tensor_attributes>(scalar, scalar_type);
+ full_graph_inputs.emplace(tensor_ptr);
+ return tensor_ptr;
+}
+
inline error_t
Graph::query_tensor_attributes_of_uid(int64_t const uid, Tensor_attributes &tensor) const {
for (auto const &o_tensor : full_graph_outputs) {
@@ -3363,6 +3474,15 @@ Graph::slice(std::shared_ptr input, Slice_attributes attribut
return Y;
}
+inline std::shared_ptr<Tensor_attributes>
+Graph::transpose(std::shared_ptr<Tensor_attributes> input, Transpose_attributes attributes) {
+ attributes.inputs[Transpose_attributes::input_names::X] = input;
+ auto Y = attributes.outputs[Transpose_attributes::output_names::Y] = output_tensor(attributes.name + "::Y");
+
+ sub_nodes.emplace_back(std::make_unique<TransposeNode>(std::move(attributes), context));
+ return Y;
+}
+
inline std::array, 2>
Graph::block_scale_quantize(std::shared_ptr x, Block_scale_quantize_attributes attributes) {
// Set outputs
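For context on the scalar-tensor overloads added above: a minimal usage sketch, assuming the existing FE 1.x graph-building API (`Graph::pointwise`, `Tensor_attributes::set_output`); only `Graph::tensor(scalar, ScalarType)` and `ScalarType` come from this patch, and the shapes, names, and the `build_scaled_graph` helper are purely illustrative.

```cpp
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Sketch: multiply an input tensor by a scalar that is baked into the graph as a
// compile-time constant (as opposed to a runtime pass-by-value parameter). Compile-time
// constants are excluded from the variant pack, so no pointer for the scalar is needed
// at execute() time.
void build_scaled_graph() {
    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::HALF).set_compute_data_type(fe::DataType_t::FLOAT);

    auto X = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("X")
                              .set_dim({8, 128, 64})
                              .set_stride({128 * 64, 64, 1}));

    // New overload from this patch: the value 0.125f is fixed at graph build time.
    auto scale = graph.tensor(0.125f, fe::graph::ScalarType::COMPILE_TIME_CONST);

    auto Y = graph.pointwise(X, scale,
                             fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::MUL));
    Y->set_output(true);
}
```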
diff --git a/include/cudnn_frontend/graph_properties.h b/include/cudnn_frontend/graph_properties.h
index dee506a2..71ae8962 100644
--- a/include/cudnn_frontend/graph_properties.h
+++ b/include/cudnn_frontend/graph_properties.h
@@ -1,6 +1,8 @@
#pragma once
+#include <cerrno>
+#include <cstdlib>
#include
#include
#include
@@ -17,8 +19,31 @@ namespace cudnn_frontend {
namespace graph {
+inline std::optional<float>
+get_rescale_threshold_from_env() {
+ auto const* env_value = std::getenv("CUDNN_RESCALE_THRESHOLD");
+ if (env_value == nullptr || env_value[0] == '\0') {
+ return std::nullopt;
+ }
+
+ errno = 0;
+ char* end = nullptr;
+ auto value = std::strtof(env_value, &end);
+ if (env_value == end || (end != nullptr && *end != '\0') || errno == ERANGE) {
+ return std::nullopt;
+ }
+
+ return value;
+}
+
using managed_backend_descriptor_t = std::vector;
+// Enum to distinguish between runtime parameters and compile-time constants
+enum class ScalarType {
+ RUNTIME_PARAM, // Value provided at execution time (can change)
+ COMPILE_TIME_CONST // Value baked into graph (fixed, optimizable)
+};
+
// simple structure to hold all properties of a tensor.
// Each property has a getter setter.
class Tensor_attributes {
@@ -31,7 +56,7 @@ class Tensor_attributes {
// In approach 1, users provide a value to embed into the graph.
// In approach 2, users set is_pass_by_value boolean and then pass a pointer to scalar value with execute() API.
// A closed set of types that are allowed to be passed by value.
- using pass_by_values_t = std::variant;
+ using pass_by_values_t = std::variant;
error_t
validate() const {
@@ -52,6 +77,22 @@ class Tensor_attributes {
error_code_t::ATTRIBUTE_NOT_SET,
"Tensor '" + name + "' can't be a fused scalar and not a pass_by_value tensor at the same time.");
+ // Validate compile-time constant constraints
+ RETURN_CUDNN_FRONTEND_ERROR_IF(
+ has_compile_time_constant && !is_pass_by_value,
+ error_code_t::ATTRIBUTE_NOT_SET,
+ "Tensor '" + name + "' with compile-time constant must have is_pass_by_value=true.");
+
+ RETURN_CUDNN_FRONTEND_ERROR_IF(has_compile_time_constant && is_virtual,
+ error_code_t::ATTRIBUTE_NOT_SET,
+ "Tensor '" + name + "' can't be both compile-time constant and virtual.");
+
+ // Can't have both compile-time constant and runtime parameter
+ RETURN_CUDNN_FRONTEND_ERROR_IF(
+ has_compile_time_constant && pass_by_value.has_value(),
+ error_code_t::ATTRIBUTE_NOT_SET,
+ "Tensor '" + name + "' can't have both compile-time constant and runtime pass_by_value.");
+
return {error_code_t::OK, ""};
}
@@ -68,6 +109,10 @@ class Tensor_attributes {
std::optional pass_by_value = std::nullopt;
bool is_pass_by_value = false;
+ // Compile-time constant support (distinct from pass_by_value, which is for runtime parameters)
+ bool has_compile_time_constant = false;
+ std::optional<pass_by_values_t> compile_time_constant_value = std::nullopt;
+
TensorReordering_t reordering_type = TensorReordering_t::NONE;
uid_t uid = 0;
bool uid_assigned = false;
@@ -135,6 +180,86 @@ class Tensor_attributes {
data_type = DataType_t::INT64;
}
+ Tensor_attributes(double const& scalar) {
+ pass_by_value = scalar;
+ is_pass_by_value = true;
+ dim = stride = {1};
+ data_type = DataType_t::DOUBLE;
+ }
+
+ // Constructors with ScalarType for compile-time constant or runtime parameter control
+ Tensor_attributes(float const& scalar, ScalarType scalar_type) {
+ if (scalar_type == ScalarType::COMPILE_TIME_CONST) {
+ compile_time_constant_value = scalar;
+ has_compile_time_constant = true;
+ } else {
+ pass_by_value = scalar;
+ }
+ is_pass_by_value = true;
+ dim = stride = {1};
+ data_type = DataType_t::FLOAT;
+ }
+
+ Tensor_attributes(half const& scalar, ScalarType scalar_type) {
+ if (scalar_type == ScalarType::COMPILE_TIME_CONST) {
+ compile_time_constant_value = scalar;
+ has_compile_time_constant = true;
+ } else {
+ pass_by_value = scalar;
+ }
+ is_pass_by_value = true;
+ dim = stride = {1};
+ data_type = DataType_t::HALF;
+ }
+
+ Tensor_attributes(nv_bfloat16 const& scalar, ScalarType scalar_type) {
+ if (scalar_type == ScalarType::COMPILE_TIME_CONST) {
+ compile_time_constant_value = scalar;
+ has_compile_time_constant = true;
+ } else {
+ pass_by_value = scalar;
+ }
+ is_pass_by_value = true;
+ dim = stride = {1};
+ data_type = DataType_t::BFLOAT16;
+ }
+
+ Tensor_attributes(int32_t const& scalar, ScalarType scalar_type) {
+ if (scalar_type == ScalarType::COMPILE_TIME_CONST) {
+ compile_time_constant_value = scalar;
+ has_compile_time_constant = true;
+ } else {
+ pass_by_value = scalar;
+ }
+ is_pass_by_value = true;
+ dim = stride = {1};
+ data_type = DataType_t::INT32;
+ }
+
+ Tensor_attributes(int64_t const& scalar, ScalarType scalar_type) {
+ if (scalar_type == ScalarType::COMPILE_TIME_CONST) {
+ compile_time_constant_value = scalar;
+ has_compile_time_constant = true;
+ } else {
+ pass_by_value = scalar;
+ }
+ is_pass_by_value = true;
+ dim = stride = {1};
+ data_type = DataType_t::INT64;
+ }
+
+ Tensor_attributes(double const& scalar, ScalarType scalar_type) {
+ if (scalar_type == ScalarType::COMPILE_TIME_CONST) {
+ compile_time_constant_value = scalar;
+ has_compile_time_constant = true;
+ } else {
+ pass_by_value = scalar;
+ }
+ is_pass_by_value = true;
+ dim = stride = {1};
+ data_type = DataType_t::DOUBLE;
+ }
+
std::string
get_name() const {
return name;
@@ -225,6 +350,35 @@ class Tensor_attributes {
return *this;
}
+ // Compile-time constant accessors
+ bool
+ get_has_compile_time_constant() const {
+ return has_compile_time_constant;
+ }
+
+ std::optional<pass_by_values_t>
+ get_compile_time_constant() const {
+ return compile_time_constant_value;
+ }
+
+ auto
+ set_compile_time_constant(pass_by_values_t const& value) -> Tensor_attributes& {
+ compile_time_constant_value = value;
+ has_compile_time_constant = true;
+ if (!is_pass_by_value) {
+ is_pass_by_value = true;
+ }
+ return *this;
+ }
+
+ auto
+ set_as_runtime_parameter() -> Tensor_attributes& {
+ is_pass_by_value = true;
+ has_compile_time_constant = false;
+ compile_time_constant_value = std::nullopt;
+ return *this;
+ }
+
TensorReordering_t
get_reordering_type() const {
return reordering_type;
@@ -315,9 +469,15 @@ class Attributes {
get_non_virtual_uids() const {
std::vector non_virtual_uids;
auto derived = static_cast(this);
+
+ // Compile-time constants are excluded from the variant pack
+ auto should_be_in_variant_pack = [](std::shared_ptr<Tensor_attributes> const& tensor) {
+ return tensor && tensor->get_is_virtual() == false && !tensor->get_has_compile_time_constant();
+ };
+
if constexpr (std::is_same_v) {
for (auto tensor : derived->inputs) {
- if (tensor && tensor->get_is_virtual() == false) {
+ if (should_be_in_variant_pack(tensor)) {
non_virtual_uids.push_back(tensor->get_uid());
if (auto ragged_offset = tensor->get_ragged_offset()) {
non_virtual_uids.push_back(ragged_offset->get_uid());
@@ -327,7 +487,7 @@ class Attributes {
} else {
for (auto& [name, tensor] : derived->inputs) {
(void)name;
- if (tensor && tensor->get_is_virtual() == false) {
+ if (should_be_in_variant_pack(tensor)) {
non_virtual_uids.push_back(tensor->get_uid());
if (auto ragged_offset = tensor->get_ragged_offset()) {
non_virtual_uids.push_back(ragged_offset->get_uid());
@@ -338,7 +498,7 @@ class Attributes {
for (auto& [name, tensor] : derived->outputs) {
(void)name;
- if (tensor && tensor->get_is_virtual() == false) {
+ if (should_be_in_variant_pack(tensor)) {
non_virtual_uids.push_back(tensor->get_uid());
if (auto ragged_offset = tensor->get_ragged_offset()) {
non_virtual_uids.push_back(ragged_offset->get_uid());
@@ -350,7 +510,7 @@ class Attributes {
if constexpr (std::is_same_v ||
std::is_same_v) {
for (auto& tensor : derived->peer_stats) {
- if (tensor && tensor->get_is_virtual() == false) {
+ if (should_be_in_variant_pack(tensor)) {
non_virtual_uids.push_back(tensor->get_uid());
if (auto ragged_offset = tensor->get_ragged_offset()) {
non_virtual_uids.push_back(ragged_offset->get_uid());
@@ -423,13 +583,16 @@ class Attributes {
set_compute_data_type(context.get_compute_data_type());
}
- // Handle shape and stride inferencing for fused scalars.
- // Pick number of dimensions from anyone of non-fused-scalar input/output tensors
- // In case, all tensors are fused scalars, just keep them 1D.
+ // Infer shape and stride for fused scalars (runtime params and compile-time constants).
+ // Fused scalars expand to match the dimensionality of non-scalar tensors; if all are scalars, keep 1D.
+ auto is_fused_scalar = [](std::shared_ptr<Tensor_attributes> const& tensor) {
+ return tensor && (tensor->get_pass_by_value().has_value() || tensor->get_has_compile_time_constant());
+ };
+
int64_t number_of_dims = 1;
if constexpr (std::is_same_v) {
for (auto tensor : derived->inputs) {
- if (tensor && (tensor->get_pass_by_value().has_value() == false)) {
+ if (tensor && !is_fused_scalar(tensor)) {
number_of_dims = tensor->get_dim().size();
break;
}
@@ -437,7 +600,7 @@ class Attributes {
} else {
for (auto [name, tensor] : derived->inputs) {
(void)name;
- if (tensor && (tensor->get_pass_by_value().has_value() == false)) {
+ if (tensor && !is_fused_scalar(tensor)) {
number_of_dims = tensor->get_dim().size();
break;
}
@@ -448,16 +611,17 @@ class Attributes {
if (number_of_dims == 1) {
for (auto [name, tensor] : derived->outputs) {
(void)name;
- if (tensor && (tensor->get_pass_by_value().has_value() == false)) {
+ if (tensor && !is_fused_scalar(tensor)) {
number_of_dims = tensor->get_dim().size();
break;
}
}
}
+ // Expand fused scalar dimensions to match the number of dims of non-scalar tensors
if constexpr (std::is_same_v) {
for (auto tensor : derived->inputs) {
- if (tensor && tensor->get_pass_by_value().has_value()) {
+ if (is_fused_scalar(tensor)) {
tensor->set_dim(std::vector(number_of_dims, 1));
tensor->set_stride(std::vector(number_of_dims, 1));
}
@@ -465,7 +629,7 @@ class Attributes {
} else {
for (auto [name, tensor] : derived->inputs) {
(void)name;
- if (tensor && tensor->get_pass_by_value().has_value()) {
+ if (is_fused_scalar(tensor)) {
tensor->set_dim(std::vector(number_of_dims, 1));
tensor->set_stride(std::vector(number_of_dims, 1));
}
@@ -1419,13 +1583,21 @@ class Reshape_attributes : public Attributes {
std::vector dim = {};
std::vector stride = {};
+ ReshapeMode_t reshape_mode = ReshapeMode_t::VIEW_ONLY;
public:
enum class input_names { X };
std::unordered_map> inputs;
enum class output_names { Y };
std::unordered_map> outputs;
- NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reshape_attributes, name, compute_data_type, inputs, outputs, dim, stride)
+ NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reshape_attributes,
+ name,
+ compute_data_type,
+ inputs,
+ outputs,
+ dim,
+ stride,
+ reshape_mode)
std::vector
get_dim() const {
@@ -1448,6 +1620,54 @@ class Reshape_attributes : public Attributes {
stride = value;
return *this;
}
+
+ ReshapeMode_t
+ get_reshape_mode() const {
+ return reshape_mode;
+ }
+
+ auto
+ set_reshape_mode(ReshapeMode_t const& value) -> Reshape_attributes& {
+ reshape_mode = value;
+ return *this;
+ }
+};
+
+class Transpose_attributes : public Attributes<Transpose_attributes> {
+ friend class Attributes<Transpose_attributes>;
+ friend class TransposeNode;
+ friend class Graph;
+
+ std::vector<int64_t> permutation;
+
+ public:
+ std::string
+ get_name() const {
+ return name;
+ }
+
+ auto
+ set_name(std::string const& value) -> Transpose_attributes& {
+ name = value;
+ return *this;
+ }
+
+ enum class input_names { X };
+ std::unordered_map<input_names, std::shared_ptr<Tensor_attributes>> inputs;
+ enum class output_names { Y };
+ std::unordered_map<output_names, std::shared_ptr<Tensor_attributes>> outputs;
+ NLOHMANN_DEFINE_TYPE_INTRUSIVE(Transpose_attributes, name, compute_data_type, inputs, outputs, permutation)
+
+ std::vector<int64_t>
+ get_permutation() const {
+ return permutation;
+ }
+
+ auto
+ set_permutation(std::vector<int64_t> const& value) -> Transpose_attributes& {
+ permutation = value;
+ return *this;
+ }
};
class Rmsnorm_attributes : public Attributes {
@@ -2562,13 +2782,14 @@ class Slice_attributes : public Attributes {
friend class INode;
std::vector> slices;
+ std::vector<int64_t> slice_strides = {1};
public:
enum class input_names { X };
std::unordered_map> inputs;
enum class output_names { Y };
std::unordered_map> outputs;
- NLOHMANN_DEFINE_TYPE_INTRUSIVE(Slice_attributes, name, compute_data_type, inputs, outputs, slices)
+ NLOHMANN_DEFINE_TYPE_INTRUSIVE(Slice_attributes, name, compute_data_type, inputs, outputs, slices, slice_strides)
Slice_attributes&
set_slices(std::vector> const value) {
@@ -2576,6 +2797,12 @@ class Slice_attributes : public Attributes {
return *this;
}
+ Slice_attributes&
+ set_strides(std::vector<int64_t> const value) {
+ slice_strides = value;
+ return *this;
+ }
+
int64_t
get_offset() const {
auto& input = inputs.at(input_names::X);
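A minimal sketch of the new transpose path, assuming the standard FE 1.x graph flow; `Graph::transpose` and `Transpose_attributes` are introduced by this patch (declared above, implemented in node/transpose.h below), while the shapes and the `build_transpose_graph` helper are illustrative.

```cpp
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Sketch: permute a (B, H, S, D) tensor to (B, S, H, D). Per TransposeNode below,
// output_dim[i] = input_dim[permutation[i]], and the output strides are permuted
// the same way.
void build_transpose_graph() {
    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::BFLOAT16).set_compute_data_type(fe::DataType_t::FLOAT);

    auto X = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("X")
                              .set_dim({4, 16, 1024, 128})
                              .set_stride({16 * 1024 * 128, 1024 * 128, 128, 1}));

    auto Y = graph.transpose(X,
                             fe::graph::Transpose_attributes()
                                 .set_name("bhsd_to_bshd")
                                 .set_permutation({0, 2, 1, 3}));
    Y->set_output(true);  // expected dims: {4, 1024, 16, 128}
}
```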
diff --git a/include/cudnn_frontend/node/concatenate.h b/include/cudnn_frontend/node/concatenate.h
index 0c051186..5f62589b 100644
--- a/include/cudnn_frontend/node/concatenate.h
+++ b/include/cudnn_frontend/node/concatenate.h
@@ -29,9 +29,6 @@ class ConcatenateNode : public NodeCRTP {
RETURN_CUDNN_FRONTEND_ERROR_IF(!attributes.axis.has_value(), error_code_t::ATTRIBUTE_NOT_SET, "Axis not set\n");
- RETURN_CUDNN_FRONTEND_ERROR_IF(
- !attributes.in_place_index.has_value(), error_code_t::ATTRIBUTE_NOT_SET, "In-place index not set\n");
-
auto X = attributes.inputs;
RETURN_CUDNN_FRONTEND_ERROR_IF(
@@ -116,12 +113,14 @@ class ConcatenateNode : public NodeCRTP {
1,
&axis));
- int64_t in_place_index = attributes.in_place_index.value();
- _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(concatenate_operation->get_backend_descriptor(),
- CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX,
- CUDNN_TYPE_INT64,
- 1,
- &in_place_index));
+ if (attributes.in_place_index.has_value()) {
+ int64_t in_place_index = attributes.in_place_index.value();
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(concatenate_operation->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX,
+ CUDNN_TYPE_INT64,
+ 1,
+ &in_place_index));
+ }
_CUDNN_CHECK_CUDNN_ERROR(detail::finalize(concatenate_operation->get_backend_descriptor()));
diff --git a/include/cudnn_frontend/node/reshape.h b/include/cudnn_frontend/node/reshape.h
index 39b08c79..e127f65e 100644
--- a/include/cudnn_frontend/node/reshape.h
+++ b/include/cudnn_frontend/node/reshape.h
@@ -48,6 +48,19 @@ class ReshapeNode : public NodeCRTP {
return {error_code_t::SHAPE_DEDUCTION_FAILED, "Reshape node output shape deduction failed"};
}
+ CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Reshape_attributes::input_names::X);
+ auto const& input_data_type = X->second->get_data_type();
+ if (y_tensor->get_data_type() == DataType_t::NOT_SET) {
+ y_tensor->set_data_type(input_data_type);
+ } else if (attributes.get_reshape_mode() == ReshapeMode_t::LOGICAL) {
+ // Lexicographic reshape preserves element type; reject inconsistent metadata.
+ // VIEW_ONLY paths (e.g. SDPA backward) may set Y dtype after reshape to match a consumer.
+ RETURN_CUDNN_FRONTEND_ERROR_IF(
+ y_tensor->get_data_type() != input_data_type,
+ error_code_t::INVALID_VALUE,
+ "Output and input tensor data types must match for LOGICAL reshape operation.");
+ }
+
return {error_code_t::OK, ""};
}
@@ -75,7 +88,16 @@ class ReshapeNode : public NodeCRTP {
CUDNN_TYPE_BACKEND_DESCRIPTOR,
1,
&x_desc));
-
+#if (CUDNN_VERSION >= 92200)
+ // Set reshape mode
+ cudnnBackendReshapeMode_t cudnn_reshape_mode;
+ _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.get_reshape_mode(), cudnn_reshape_mode));
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reshape_operation.get_raw_desc(),
+ CUDNN_ATTR_OPERATION_RESHAPE_MODE,
+ CUDNN_TYPE_RESHAPE_MODE,
+ 1,
+ &cudnn_reshape_mode));
+#endif
// Set output tensor Y
CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Reshape_attributes::output_names::Y);
auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc();
diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
index 71edf64d..ee1d0b27 100644
--- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
+++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
@@ -512,8 +512,16 @@ class SDPANodeBase : public NodeCRTP {
#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB
virtual void
serialize(json& j) const override final {
- j = attributes;
- if (attributes.mma_core_mode == DataType_t::FP8_E4M3 || attributes.mma_core_mode == DataType_t::FP8_E5M2) {
+ j = attributes;
+ j["is_mxfp8"] = is_mxfp8_scaling();
+ j["unfuse_fma"] = attributes.unfuse_fma;
+ if (auto const rescale_threshold = get_rescale_threshold_from_env(); rescale_threshold.has_value()) {
+ j["rescale_threshold"] = rescale_threshold.value();
+ }
+ if (is_mxfp8_scaling()) {
+ j.update(R"({"tag": "SDPA_MXFP8_FWD"})"_json);
+ } else if (attributes.mma_core_mode == DataType_t::FP8_E4M3 ||
+ attributes.mma_core_mode == DataType_t::FP8_E5M2) {
j.update(R"({"tag": "SDPA_FP8_FWD"})"_json);
} else {
j.update(R"({"tag": "SDPA"})"_json);
@@ -1296,11 +1304,6 @@ class CompositeSDPABackwardNode : public NodeCRTP {
error_code_t::GRAPH_NOT_SUPPORTED,
"For cuDNN version below 9.6.0, group-query attention with raggged offset is not supported");
- // TODO add version check once fixed
- RETURN_CUDNN_FRONTEND_ERROR_IF(prop_major == 10 && is_rng,
- error_code_t::GRAPH_NOT_SUPPORTED,
- "Dropout RNG dump is not supported for SM Major version 10");
-
// TODO add version check once fixed
RETURN_CUDNN_FRONTEND_ERROR_IF(prop_major == 10 && is_ragged && is_dbias,
error_code_t::GRAPH_NOT_SUPPORTED,
@@ -1952,10 +1955,32 @@ class CompositeSDPABackwardNode : public NodeCRTP {
std::pair>
override_heuristics_query() const {
int32_t const sm_version = context.get_sm_version();
+ bool const use_new_knobs = detail::get_backend_version() >= 92300;
+ // {128,128} bprop: tileM=3, tileN=2, kernelCfg=2(bprop warp), streamK=0, cgaM=0
if (sm_version > 103 && is_deterministic_algorithm_supported_on_blackwell) {
- return {17, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}};
+ if (use_new_knobs) {
+ return {17,
+ {{KnobType_t::TILE_M, 3},
+ {KnobType_t::TILE_N, 2},
+ {KnobType_t::KERNEL_CFG, 2},
+ {KnobType_t::STREAM_K, 0},
+ {KnobType_t::TILE_CGA_M, 0},
+ {KnobType_t::STAGES, 2}}};
+ } else {
+ return {17, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}};
+ }
} else if (is_deterministic_algorithm_supported_on_blackwell) {
- return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}};
+ if (use_new_knobs) {
+ return {5,
+ {{KnobType_t::TILE_M, 3},
+ {KnobType_t::TILE_N, 2},
+ {KnobType_t::KERNEL_CFG, 2},
+ {KnobType_t::STREAM_K, 0},
+ {KnobType_t::TILE_CGA_M, 0},
+ {KnobType_t::STAGES, 2}}};
+ } else {
+ return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}};
+ }
} else {
return {-1, {}};
}
diff --git a/include/cudnn_frontend/node/sdpa_fp8_bwd.h b/include/cudnn_frontend/node/sdpa_fp8_bwd.h
index 12b390c7..29b8f200 100644
--- a/include/cudnn_frontend/node/sdpa_fp8_bwd.h
+++ b/include/cudnn_frontend/node/sdpa_fp8_bwd.h
@@ -1089,10 +1089,32 @@ class SDPAFP8BackwardNode : public NodeCRTP {
std::pair>
override_heuristics_query() const {
int32_t const sm_version = context.get_sm_version();
+ bool const use_new_knobs = detail::get_backend_version() >= 92300;
+ // {128,128} bprop: tileM=3, tileN=2, kernelCfg=2(bprop warp), streamK=0, cgaM=0
if (sm_version > 103 && is_deterministic_algorithm_supported_on_blackwell) {
- return {17, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}};
+ if (use_new_knobs) {
+ return {17,
+ {{KnobType_t::TILE_M, 3},
+ {KnobType_t::TILE_N, 2},
+ {KnobType_t::KERNEL_CFG, 2},
+ {KnobType_t::STREAM_K, 0},
+ {KnobType_t::TILE_CGA_M, 0},
+ {KnobType_t::STAGES, 2}}};
+ } else {
+ return {17, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}};
+ }
} else if (is_deterministic_algorithm_supported_on_blackwell) {
- return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}};
+ if (use_new_knobs) {
+ return {5,
+ {{KnobType_t::TILE_M, 3},
+ {KnobType_t::TILE_N, 2},
+ {KnobType_t::KERNEL_CFG, 2},
+ {KnobType_t::STREAM_K, 0},
+ {KnobType_t::TILE_CGA_M, 0},
+ {KnobType_t::STAGES, 2}}};
+ } else {
+ return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}};
+ }
} else {
return {-1, {}};
}
@@ -1102,7 +1124,15 @@ class SDPAFP8BackwardNode : public NodeCRTP {
virtual void
serialize(json& j) const override final {
j = attributes;
- j.update(R"({"tag": "SDPA_FP8_BWD"})"_json);
+ j["is_mxfp8"] = is_mxfp8_scaling();
+ if (auto const rescale_threshold = get_rescale_threshold_from_env(); rescale_threshold.has_value()) {
+ j["rescale_threshold"] = rescale_threshold.value();
+ }
+ if (is_mxfp8_scaling()) {
+ j.update(R"({"tag": "SDPA_MXFP8_BWD"})"_json);
+ } else {
+ j.update(R"({"tag": "SDPA_FP8_BWD"})"_json);
+ }
}
#endif
};
diff --git a/include/cudnn_frontend/node/sdpa_support_surface.h b/include/cudnn_frontend/node/sdpa_support_surface.h
index 42e40476..223539b3 100644
--- a/include/cudnn_frontend/node/sdpa_support_surface.h
+++ b/include/cudnn_frontend/node/sdpa_support_surface.h
@@ -350,10 +350,6 @@ SDPA_attributes::validate_sdpa_support_surface(const detail::Context& context,
"THD (ragged offset) is only supported in Hopper and above : " +
std::to_string(context.get_sm_version()));
}
- // TODO add version check once fixed
- RETURN_CUDNN_FRONTEND_ERROR_IF(prop_major == 10 && is_rng,
- error_code_t::GRAPH_NOT_SUPPORTED,
- "dropout RNG dump is not supported for Blackwell architecture");
} else {
RETURN_CUDNN_FRONTEND_ERROR_IF(true, error_code_t::GRAPH_NOT_SUPPORTED, "Unsupported mma core mode");
}
@@ -464,8 +460,13 @@ SDPA_attributes::verify_sdpa_support_surface_for_implementation(const detail::Co
"Diagonal alignment for unified SDPA node requires cuDNN 9.21.0 or above"};
}
- if (dropout_probability.has_value() && effective_cudnn_ver < 92100) {
- return {error_code_t::GRAPH_NOT_SUPPORTED, "Dropout for unified SDPA node requires cuDNN 9.21.0"};
+ if (dropout_probability.has_value() && effective_cudnn_ver < 92200) {
+ return {error_code_t::GRAPH_NOT_SUPPORTED, "Dropout for unified SDPA node requires cuDNN 9.22.0"};
+ }
+
+ if (dropout_probability.has_value() && generate_stats.value_or(false)) {
+ return {error_code_t::GRAPH_NOT_SUPPORTED,
+ "Dropout for unified SDPA node with generated stats is not supported"};
}
// Unified engine in cuDNN < 9.15 can't meaningfully support max sequence length,
@@ -496,4 +497,4 @@ SDPA_attributes::verify_sdpa_support_surface_for_implementation(const detail::Co
return {error_code_t::OK, ""};
}
-} // namespace cudnn_frontend::graph
\ No newline at end of file
+} // namespace cudnn_frontend::graph
diff --git a/include/cudnn_frontend/node/slice.h b/include/cudnn_frontend/node/slice.h
index e40f5c53..ebcad6a0 100644
--- a/include/cudnn_frontend/node/slice.h
+++ b/include/cudnn_frontend/node/slice.h
@@ -1,5 +1,8 @@
#pragma once
+#include "../graph_helpers.h"
+#include "../node_interface.h"
+
namespace cudnn_frontend::graph {
class SliceNode : public NodeCRTP {
@@ -21,12 +24,28 @@ class SliceNode : public NodeCRTP {
attributes.fill_from_context(context);
+ for (size_t i = 0; i < attributes.slice_strides.size(); ++i) {
+ RETURN_CUDNN_FRONTEND_ERROR_IF(
+ attributes.slice_strides[i] <= 0,
+ error_code_t::INVALID_VALUE,
+ "Slice slice_strides[" + std::to_string(i) + "] must be strictly positive (got " +
+ std::to_string(attributes.slice_strides[i]) +
+ "). Non-positive strides break output dimension calculation and can cause division by zero.");
+ }
+
auto output = attributes.outputs.at(Slice_attributes::output_names::Y);
auto output_dim = output->get_dim();
if (output_dim.empty()) {
for (size_t i = 0; i < attributes.slices.size(); ++i) {
- output_dim.push_back(attributes.slices[i].second - attributes.slices[i].first);
+ int64_t start = attributes.slices[i].first;
+ int64_t limit = attributes.slices[i].second;
+ int64_t stride = (!attributes.slice_strides.empty() && i < attributes.slice_strides.size())
+ ? attributes.slice_strides[i]
+ : 1;
+ // Output dimension = ceil((limit - start) / stride)
+ int64_t dim = (limit - start + stride - 1) / stride;
+ output_dim.push_back(dim);
}
output->set_dim(output_dim);
}
@@ -44,10 +63,15 @@ class SliceNode : public NodeCRTP {
auto const input_stride = input->get_stride();
if (output->get_stride().empty()) {
- // For simple slicing without changing the step, the stride remains the same
- // std::vector stride_order =
- // detail::generate_stride_order_preserving_format(input_stride, output_dim.size());
- output->set_stride(input_stride);
+ // When slice strides > 1, output strides need to be multiplied accordingly
+ std::vector<int64_t> output_stride;
+ for (size_t i = 0; i < input_stride.size(); ++i) {
+ int64_t stride = (!attributes.slice_strides.empty() && i < attributes.slice_strides.size())
+ ? attributes.slice_strides[i]
+ : 1;
+ output_stride.push_back(input_stride[i] * stride);
+ }
+ output->set_stride(output_stride);
}
return {error_code_t::OK, ""};
@@ -57,16 +81,27 @@ class SliceNode : public NodeCRTP {
create_cudnn_tensors_node(std::unordered_map>& tensors,
int64_t& potential_uid,
std::unordered_set const& used_uids) const override final {
- // Do not make input tensor for backend.
- // But assign it a uid
- auto const input = attributes.inputs.at(Slice_attributes::input_names::X);
+ getLogger() << "[cudnn_frontend] INFO: Creating cudnn tensors for SliceNode " << attributes.name << std::endl;
+
+ auto const input = attributes.inputs.at(Slice_attributes::input_names::X);
+ auto const output = attributes.outputs.at(Slice_attributes::output_names::Y);
+
+#if (CUDNN_VERSION >= 92200)
+ // For cuDNN >= 9.22.0: Use native slice operation, create both input and output tensors
+ CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(input, tensors, potential_uid, used_uids));
+ output->set_is_virtual(false);
+ CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors, potential_uid, used_uids));
+#else
+ // For cuDNN < 9.22.0: Fallback to pointer arithmetic approach
+ // Only assign UID to input, don't create backend tensor
if (input->has_uid() == false) {
detail::assign_uid(input.get(), potential_uid, used_uids);
}
- auto const output = attributes.outputs.at(Slice_attributes::output_names::Y);
+ // Create output tensor
output->set_is_virtual(false);
CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors, potential_uid, used_uids));
+#endif
return {error_code_t::OK, ""};
}
@@ -74,20 +109,98 @@ class SliceNode : public NodeCRTP {
error_t
create_cudnn_operations(
std::unordered_set& uids_involved_in_operations,
- std::vector>&,
+ std::vector>& operations,
managed_backend_descriptor_t& raw_operations,
- std::unordered_map>&) const override final {
+ std::unordered_map>& tensors) const override final {
+ getLogger() << "[cudnn_frontend] INFO: " << "Building SliceNode operations " << attributes.name << std::endl;
+
+#if (CUDNN_VERSION >= 92200)
+ // cuDNN >= 9.22.0: Use native backend slice operation
+ auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Slice requires cuDNN v9.22.0"};
+ NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(92200, cudnn_ver_error);
+ CUDNN_FRONTEND_UNUSED(operations);
+
+ auto slice_operation = make_shared_backend_pointer(CUDNN_BACKEND_OPERATION_SLICE_DESCRIPTOR);
+
+ // Set input tensor
+ auto X = attributes.inputs.at(Slice_attributes::input_names::X);
+ auto backend_x = tensors[X->get_uid()]->get_desc()->get_backend_descriptor();
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_SLICE_XDESC,
+ CUDNN_TYPE_BACKEND_DESCRIPTOR,
+ 1,
+ &backend_x));
+
+ // Set output tensor
+ auto Y = attributes.outputs.at(Slice_attributes::output_names::Y);
+ auto backend_y = tensors[Y->get_uid()]->get_desc()->get_backend_descriptor();
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_SLICE_YDESC,
+ CUDNN_TYPE_BACKEND_DESCRIPTOR,
+ 1,
+ &backend_y));
+
+ // Extract start and limit indices from slices
+ std::vector<int64_t> start_indices;
+ std::vector<int64_t> limit_indices;
+
+ for (const auto& slice : attributes.slices) {
+ start_indices.push_back(slice.first);
+ limit_indices.push_back(slice.second);
+ }
+
+ // Per-dimension strides: use user slice_strides[i] when set, else 1 (preserves partial configuration)
+ std::vector<int64_t> strides(attributes.slices.size());
+ for (size_t i = 0; i < strides.size(); ++i) {
+ strides[i] = (i < attributes.slice_strides.size()) ? attributes.slice_strides[i] : 1;
+ }
+
+ // Set start indices
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_SLICE_START_INDICES,
+ CUDNN_TYPE_INT64,
+ start_indices.size(),
+ start_indices.data()));
+
+ // Set limit indices
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_SLICE_LIMIT_INDICES,
+ CUDNN_TYPE_INT64,
+ limit_indices.size(),
+ limit_indices.data()));
+
+ // Set strides
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(slice_operation->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_SLICE_STRIDES,
+ CUDNN_TYPE_INT64,
+ strides.size(),
+ strides.data()));
+
+ _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(slice_operation->get_backend_descriptor()));
+
+ raw_operations.push_back(slice_operation);
+
+ auto const& non_virtual_uids = attributes.get_non_virtual_uids();
+ uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end());
+#else
+ // cuDNN < 9.22.0: Fallback to pointer arithmetic (no backend operation needed)
+ // The collect_variant_pack_replacements_node method handles the pointer offset mapping
+ CUDNN_FRONTEND_UNUSED(operations);
CUDNN_FRONTEND_UNUSED(raw_operations);
- // No corresponding backend operation
+ CUDNN_FRONTEND_UNUSED(tensors);
- auto const virutal_output = attributes.outputs.at(Slice_attributes::output_names::Y);
- if (virutal_output && virutal_output->get_is_virtual() == false) {
- uids_involved_in_operations.insert(virutal_output->get_uid());
- if (auto ragged_offset = virutal_output->get_ragged_offset()) {
+ getLogger() << "[cudnn_frontend] INFO: " << "Using pointer arithmetic fallback for slice on cuDNN < 9.22.0"
+ << std::endl;
+
+ // Register the output tensor as involved (input is handled via variant pack replacement)
+ auto const output = attributes.outputs.at(Slice_attributes::output_names::Y);
+ if (output && output->get_is_virtual() == false) {
+ uids_involved_in_operations.insert(output->get_uid());
+ if (auto ragged_offset = output->get_ragged_offset()) {
uids_involved_in_operations.insert(ragged_offset->get_uid());
}
}
-
+#endif
return {error_code_t::OK, ""};
}
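A minimal sketch of strided slicing with the new `Slice_attributes::set_strides`, assuming the existing `Graph::slice` entry point; output dimensions follow the ceil((limit - start) / stride) rule implemented in SliceNode above, and the shapes and the `build_strided_slice_graph` helper are illustrative.

```cpp
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Sketch: keep every other element along the innermost dimension of a (2, 64, 256)
// tensor. With slices {[0,2), [0,64), [0,256)} and strides {1, 1, 2}, SliceNode
// computes output dims as ceil((limit - start) / stride) = {2, 64, 128} and scales
// the output strides by the per-dimension slice stride.
void build_strided_slice_graph() {
    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT);

    auto X = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("X")
                              .set_dim({2, 64, 256})
                              .set_stride({64 * 256, 256, 1}));

    auto Y = graph.slice(X,
                         fe::graph::Slice_attributes()
                             .set_name("every_other_element")
                             .set_slices({{0, 2}, {0, 64}, {0, 256}})
                             .set_strides({1, 1, 2}));  // set_strides() is new in this patch
    Y->set_output(true);  // expected dims: {2, 64, 128}
}
```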
diff --git a/include/cudnn_frontend/node/transpose.h b/include/cudnn_frontend/node/transpose.h
new file mode 100644
index 00000000..7f764be7
--- /dev/null
+++ b/include/cudnn_frontend/node/transpose.h
@@ -0,0 +1,164 @@
+#pragma once
+
+#include "../../cudnn_frontend_Logging.h"
+
+#include "../graph_helpers.h"
+#include "../node_interface.h"
+
+namespace cudnn_frontend::graph {
+
+class TransposeNode : public NodeCRTP<TransposeNode> {
+ public:
+ Transpose_attributes attributes;
+
+ TransposeNode(Transpose_attributes&& attributes_, detail::Context const& context)
+ : NodeCRTP(context), attributes(std::move(attributes_)) {}
+
+ Type
+ getType() override final {
+ return Type::TRANSPOSE;
+ }
+
+ error_t
+ infer_properties_node() override final {
+ getLogger() << "[cudnn_frontend] INFO: Inferrencing properties for transpose node " << attributes.name
+ << std::endl;
+
+ attributes.fill_from_context(context);
+
+ auto const& X_tensor = attributes.inputs[Transpose_attributes::input_names::X];
+ auto Y_tensor = attributes.outputs[Transpose_attributes::output_names::Y];
+
+ // Get input properties
+ auto const& input_dim = X_tensor->get_dim();
+ auto const& input_stride = X_tensor->get_stride();
+ auto const& input_data_type = X_tensor->get_data_type();
+ auto const& permutation = attributes.get_permutation();
+
+ // Validate permutation
+ RETURN_CUDNN_FRONTEND_ERROR_IF(
+ permutation.empty(), error_code_t::ATTRIBUTE_NOT_SET, "Permutation must be set for transpose operation.");
+
+ RETURN_CUDNN_FRONTEND_ERROR_IF(permutation.size() != input_dim.size(),
+ error_code_t::INVALID_VALUE,
+ "Permutation size must match input tensor dimensionality.");
+
+ // Check that permutation is a valid permutation (contains each index 0 to n-1 exactly once)
+ std::vector<bool> seen(permutation.size(), false);
+ for (auto idx : permutation) {
+ RETURN_CUDNN_FRONTEND_ERROR_IF(idx < 0 || idx >= static_cast<int64_t>(permutation.size()),
+ error_code_t::INVALID_VALUE,
+ "Permutation indices must be in range [0, n-1].");
+ RETURN_CUDNN_FRONTEND_ERROR_IF(
+ seen[idx], error_code_t::INVALID_VALUE, "Permutation indices must be unique.");
+ seen[idx] = true;
+ }
+
+ // Infer output dimensions by permuting input dimensions
+ std::vector<int64_t> output_dim(input_dim.size());
+ for (size_t i = 0; i < permutation.size(); ++i) {
+ output_dim[i] = input_dim[permutation[i]];
+ }
+
+ // Infer output strides by permuting input strides
+ std::vector<int64_t> output_stride(input_stride.size());
+ for (size_t i = 0; i < permutation.size(); ++i) {
+ output_stride[i] = input_stride[permutation[i]];
+ }
+
+ // Set output tensor properties
+ if (Y_tensor->get_dim().empty()) {
+ Y_tensor->set_dim(output_dim);
+ }
+ if (Y_tensor->get_stride().empty()) {
+ Y_tensor->set_stride(output_stride);
+ }
+ if (Y_tensor->get_data_type() == DataType_t::NOT_SET) {
+ Y_tensor->set_data_type(input_data_type);
+ }
+
+ RETURN_CUDNN_FRONTEND_ERROR_IF(Y_tensor->get_data_type() != input_data_type,
+ error_code_t::INVALID_VALUE,
+ "Output and input tensor data types must match for transpose operation.");
+
+ return {error_code_t::OK, ""};
+ }
+
+ error_t
+ create_cudnn_operations(
+ std::unordered_set<Tensor_attributes::uid_t>& uids_involved_in_operations,
+ std::vector<std::shared_ptr<cudnn_frontend::Operation>>& operations,
+ managed_backend_descriptor_t& raw_operations,
+ std::unordered_map<int64_t, std::shared_ptr<cudnn_frontend::Tensor>>& tensors) const override final {
+ getLogger() << "[cudnn_frontend] INFO: " << "Building TransposeNode operations " << attributes.name
+ << std::endl;
+
+#if (CUDNN_VERSION >= 92200)
+ // cuDNN >= 9.22.0: Use native backend transpose operation
+ auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Transpose requires cuDNN v9.22.0"};
+ NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(92200, cudnn_ver_error);
+ CUDNN_FRONTEND_UNUSED(operations);
+
+ auto transpose_operation = make_shared_backend_pointer(CUDNN_BACKEND_OPERATION_TRANSPOSE_DESCRIPTOR);
+
+ // Set input tensor
+ auto X = attributes.inputs.at(Transpose_attributes::input_names::X);
+ auto backend_x = tensors[X->get_uid()]->get_desc()->get_backend_descriptor();
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(transpose_operation->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_TRANSPOSE_XDESC,
+ CUDNN_TYPE_BACKEND_DESCRIPTOR,
+ 1,
+ &backend_x));
+
+ // Set output tensor
+ auto Y = attributes.outputs.at(Transpose_attributes::output_names::Y);
+ auto backend_y = tensors[Y->get_uid()]->get_desc()->get_backend_descriptor();
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(transpose_operation->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_TRANSPOSE_YDESC,
+ CUDNN_TYPE_BACKEND_DESCRIPTOR,
+ 1,
+ &backend_y));
+
+ // Set permutation
+ auto permutation = attributes.get_permutation();
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(transpose_operation->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_TRANSPOSE_PERMUTATION,
+ CUDNN_TYPE_INT64,
+ permutation.size(),
+ permutation.data()));
+
+ _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(transpose_operation->get_backend_descriptor()));
+
+ raw_operations.push_back(transpose_operation);
+
+ auto const& non_virtual_uids = attributes.get_non_virtual_uids();
+ uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end());
+#else
+ CUDNN_FRONTEND_UNUSED(operations);
+ CUDNN_FRONTEND_UNUSED(raw_operations);
+ CUDNN_FRONTEND_UNUSED(tensors);
+ CUDNN_FRONTEND_UNUSED(uids_involved_in_operations);
+ return {error_code_t::GRAPH_NOT_SUPPORTED, "Transpose operation requires cuDNN version >= 9.22.0"};
+#endif
+ return {error_code_t::OK, ""};
+ }
+
+#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB
+ virtual void
+ serialize(json& j) const override final {
+ j = attributes;
+ j.update(R"( {"tag": "TRANSPOSE"})"_json);
+ }
+#endif
+};
+
+inline std::shared_ptr<Tensor_attributes>
+INode::transpose(std::shared_ptr<Tensor_attributes> input, Transpose_attributes attributes) {
+ attributes.inputs[Transpose_attributes::input_names::X] = input;
+ auto Y = attributes.outputs[Transpose_attributes::output_names::Y] = output_tensor(attributes.name + "::Y");
+
+ sub_nodes.emplace_back(std::make_unique<TransposeNode>(std::move(attributes), context));
+ return Y;
+}
+
+} // namespace cudnn_frontend::graph
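
The dimension and stride inference in TransposeNode above is a straight gather through the permutation: output axis i takes the input dimension and stride at permutation[i]. A minimal standalone sketch of that logic (plain Python for illustration only, not the FE API):

```python
def infer_transpose_layout(dims, strides, permutation):
    """Mirror of TransposeNode::infer_properties_node: validate the permutation
    and gather dims/strides through it."""
    n = len(dims)
    assert sorted(permutation) == list(range(n)), "permutation must contain 0..n-1 exactly once"
    out_dims = [dims[p] for p in permutation]
    out_strides = [strides[p] for p in permutation]
    return out_dims, out_strides

# e.g. a contiguous NCHW tensor viewed as NHWC
dims, strides = [8, 3, 32, 32], [3 * 32 * 32, 32 * 32, 32, 1]
print(infer_transpose_layout(dims, strides, [0, 2, 3, 1]))
# ([8, 32, 32, 3], [3072, 32, 1, 1024])
```
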
diff --git a/include/cudnn_frontend/node_interface.h b/include/cudnn_frontend/node_interface.h
index e9109ce3..7bc922ad 100644
--- a/include/cudnn_frontend/node_interface.h
+++ b/include/cudnn_frontend/node_interface.h
@@ -38,6 +38,8 @@ class CompositeSoftmaxNode;
class UnifiedSoftmaxNode;
class MoeGroupedMatmulNode;
class UnifiedDiagonalBandMaskNode;
+class TransposeNode;
+class SliceNode;
// Interface for all nodes to follow.
class INode {
@@ -138,6 +140,7 @@ class INode {
RMSNORM,
RNG,
SLICE,
+ TRANSPOSE,
WGRAD,
PAGED_CACHE_LOAD,
BLOCK_SCALE_QUANTIZE,
@@ -253,6 +256,15 @@ class INode {
Moe_grouped_matmul_bwd_attributes attributes,
std::shared_ptr<Tensor_attributes> dweight);
+ void
+ transpose(std::shared_ptr<Tensor_attributes> input,
+ Transpose_attributes attributes,
+ std::shared_ptr<Tensor_attributes> output);
+
+ void
+ slice(std::shared_ptr<Tensor_attributes> input,
+ Slice_attributes attributes,
+ std::shared_ptr<Tensor_attributes> output);
error_t
validate_subtree() {
// pre validate to catch errors early
@@ -392,6 +404,8 @@ class INode {
std::shared_ptr<Tensor_attributes> reduction(std::shared_ptr<Tensor_attributes>, Reduction_attributes);
std::array<std::shared_ptr<Tensor_attributes>, 2> resample(std::shared_ptr<Tensor_attributes>, Resample_attributes);
std::shared_ptr<Tensor_attributes> reshape(std::shared_ptr<Tensor_attributes>, Reshape_attributes);
+ std::shared_ptr<Tensor_attributes> transpose(std::shared_ptr<Tensor_attributes>, Transpose_attributes);
+ std::shared_ptr<Tensor_attributes> slice(std::shared_ptr<Tensor_attributes>, Slice_attributes);
std::shared_ptr<Tensor_attributes> rng(std::shared_ptr<Tensor_attributes>,
std::shared_ptr<Tensor_attributes>,
diff --git a/include/cudnn_frontend/utils/attn_score_modifiers.h b/include/cudnn_frontend/utils/attn_score_modifiers.h
index ff0a3877..e0597907 100644
--- a/include/cudnn_frontend/utils/attn_score_modifiers.h
+++ b/include/cudnn_frontend/utils/attn_score_modifiers.h
@@ -253,7 +253,7 @@ sliding_window_mask(std::shared_ptr graph,
[[maybe_unused]] inline error_t
build_operation_subgraph(std::shared_ptr<Graph> graph) {
return graph->build_operation_graph(/*handle=*/nullptr);
-};
+}
class Softcap {
private:
diff --git a/include/cudnn_frontend/utils/serialize.h b/include/cudnn_frontend/utils/serialize.h
index 6cc4590a..1277c5b8 100644
--- a/include/cudnn_frontend/utils/serialize.h
+++ b/include/cudnn_frontend/utils/serialize.h
@@ -298,6 +298,13 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Reshape_attributes::input_names,
NLOHMANN_JSON_SERIALIZE_ENUM(Reshape_attributes::output_names, {{Reshape_attributes::output_names::Y, "Y"}})
+NLOHMANN_JSON_SERIALIZE_ENUM(Transpose_attributes::input_names,
+ {
+ {Transpose_attributes::input_names::X, "X"},
+ })
+
+NLOHMANN_JSON_SERIALIZE_ENUM(Transpose_attributes::output_names, {{Transpose_attributes::output_names::Y, "Y"}})
+
NLOHMANN_JSON_SERIALIZE_ENUM(Rmsnorm_attributes::input_names,
{
{Rmsnorm_attributes::input_names::X, "X"},
@@ -457,6 +464,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_backward_attributes::input_names,
{SDPA_fp8_backward_attributes::input_names::Scale_dV, "Scale_dV"},
{SDPA_fp8_backward_attributes::input_names::Scale_S, "Scale_S"},
{SDPA_fp8_backward_attributes::input_names::Scale_dP, "Scale_dP"},
+ {SDPA_fp8_backward_attributes::input_names::SINK_TOKEN, "SINK_TOKEN"},
})
NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_backward_attributes::output_names,
@@ -468,6 +476,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_backward_attributes::output_names,
{SDPA_fp8_backward_attributes::output_names::Amax_dK, "Amax_dK"},
{SDPA_fp8_backward_attributes::output_names::Amax_dV, "Amax_dV"},
{SDPA_fp8_backward_attributes::output_names::Amax_dP, "Amax_d"},
+ {SDPA_fp8_backward_attributes::output_names::DSINK_TOKEN, "DSINK_TOKEN"},
})
NLOHMANN_JSON_SERIALIZE_ENUM(Block_scale_quantize_attributes::input_names,
diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h
index 6bf01019..5be62537 100644
--- a/include/cudnn_frontend_Operation.h
+++ b/include/cudnn_frontend_Operation.h
@@ -236,6 +236,7 @@ class Operation_v8 : public BackendDescriptor {
NormFwdPhase_t norm_fwd_phase;
NormMode_t norm_mode;
+ ReshapeMode_t reshape_mode = ReshapeMode_t::VIEW_ONLY;
float alpha_s = 1.0f, beta_s = .0f, alpha2_s = 1.0f;
double alpha_d = 1.0, beta_d = 0.0, alpha2_d = 1.0;
@@ -1790,6 +1791,29 @@ class OperationBuilder_v8 {
"CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_RESHAPE_YDESC Failed");
return std::move(m_operation);
}
+
+#if (CUDNN_VERSION >= 92200)
+ // Set reshape mode if it's not NOT_SET
+ if (m_operation.reshape_mode != ReshapeMode_t::NOT_SET) {
+ cudnnBackendReshapeMode_t cudnn_reshape_mode;
+ status = detail::convert_to_cudnn_type(m_operation.reshape_mode, cudnn_reshape_mode);
+ if (status == CUDNN_STATUS_SUCCESS) {
+ status = detail::set_attribute(m_operation.pointer->get_backend_descriptor(),
+ CUDNN_ATTR_OPERATION_RESHAPE_MODE,
+ CUDNN_TYPE_RESHAPE_MODE,
+ 1,
+ &cudnn_reshape_mode);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ set_error_and_throw_exception(
+ &m_operation,
+ status,
+ "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_RESHAPE_MODE Failed");
+ return std::move(m_operation);
+ }
+ }
+ }
+#endif
+
status = detail::finalize(m_operation.pointer->get_backend_descriptor());
if (status != CUDNN_STATUS_SUCCESS) {
set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed");
@@ -2561,6 +2585,21 @@ class OperationBuilder_v8 {
return *this;
}
+ auto
+ setReshapeMode(ReshapeMode_t mode) -> OperationBuilder_v8 & {
+ m_operation.reshape_mode = mode;
+ return *this;
+ }
+
+#if (CUDNN_VERSION >= 92200)
+ // To be deprecated. Please use setReshapeMode(cudnn_frontend::ReshapeMode_t mode) instead.
+ auto
+ setReshapeMode(cudnnBackendReshapeMode_t mode) -> OperationBuilder_v8 & {
+ detail::convert_from_cudnn_type(mode, m_operation.reshape_mode);
+ return *this;
+ }
+#endif
+
auto
setNormalizationMode(NormMode_t mode) -> OperationBuilder_v8 & {
m_operation.norm_mode = mode;
diff --git a/include/cudnn_frontend_Tensor.h b/include/cudnn_frontend_Tensor.h
index 91585003..ba538844 100644
--- a/include/cudnn_frontend_Tensor.h
+++ b/include/cudnn_frontend_Tensor.h
@@ -24,6 +24,7 @@
#include
#include
+#include <cstring>
#include
#include
#include
@@ -179,6 +180,10 @@ class Tensor_v8 : public BackendDescriptor {
int64_t vectorCount = 1; //! What is the vectorization count (4 or 32)
bool isVirtual = false; //! Whether it is an intermediate tensor of an op graph
bool isByValue = false; //! Whether the tensor is in host memory that needs to be passed to the kernel by value
+
+ bool hasConstValue = false; //! Whether this tensor has a compile-time constant value
+ std::vector<uint8_t> constValueBytes; //! Raw bytes of the compile-time constant value
+
cudnn_frontend::TensorReordering_t reorder_type =
cudnn_frontend::TensorReordering_t::NONE; //! Type of reordering in the tensor
std::shared_ptr raggedOffset; //! Ragged offsets for ragged tensors
@@ -242,6 +247,18 @@ class TensorBuilder_v8 {
m_tensor.isByValue = isByValue_;
return *this;
}
+
+ //! Set compile-time constant value for this tensor
+ template <typename T>
+ auto
+ setConstValue(T const &value) -> TensorBuilder_v8 & {
+ m_tensor.hasConstValue = true;
+ m_tensor.constValueBytes.resize(sizeof(T));
+ std::memcpy(m_tensor.constValueBytes.data(), &value, sizeof(T));
+ m_tensor.isByValue = true;
+ return *this;
+ }
+
auto
setVectorCountAndDimension(int64_t vectorCount_, int64_t vectorDimension_) -> TensorBuilder_v8 & {
m_tensor.vectorCount = vectorCount_;
@@ -506,6 +523,38 @@ class TensorBuilder_v8 {
return std::move(m_tensor);
}
}
+
+ // Set compile-time constant value (if present); CUDNN_ATTR_TENSOR_CONSTANT_VALUE is cuDNN 9.22.0+
+#if (CUDNN_VERSION >= 92200)
+ NV_CUDNN_FE_DYNAMIC_CHECK_BACKEND_DESCRIPTOR(92200,
+ m_tensor,
+ "CUDNN_BACKEND_TENSOR_DESCRIPTOR: SetAttribute "
+ "CUDNN_ATTR_TENSOR_CONSTANT_VALUE requires cudnn version 9.22.0");
+ if (m_tensor.hasConstValue) {
+ void *value_ptr = m_tensor.constValueBytes.data();
+ status = detail::set_attribute(m_tensor.pointer->get_backend_descriptor(),
+ CUDNN_ATTR_TENSOR_CONSTANT_VALUE,
+ CUDNN_TYPE_VOID_PTR,
+ 1,
+ &value_ptr);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ set_error_and_throw_exception(
+ &m_tensor,
+ status,
+ "CUDNN_BACKEND_TENSOR_DESCRIPTOR: SetAttribute CUDNN_ATTR_TENSOR_CONSTANT_VALUE Failed");
+ return std::move(m_tensor);
+ }
+ }
+#else
+ if (m_tensor.hasConstValue) {
+ set_error_and_throw_exception(
+ &m_tensor,
+ CUDNN_STATUS_INVALID_VALUE,
+ "CUDNN_BACKEND_TENSOR_DESCRIPTOR: CUDNN_ATTR_TENSOR_CONSTANT_VALUE requires cudnn version 9.22.0");
+ return std::move(m_tensor);
+ }
+#endif // CUDNN_VERSION >= 92200
+
// Finalizing the descriptor
status = detail::finalize(m_tensor.pointer->get_backend_descriptor());
if (status != CUDNN_STATUS_SUCCESS) {
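
setConstValue snapshots the scalar's object representation at build time and later hands the backend a pointer to those bytes. A tiny Python sketch of the same byte-snapshot idea (illustrative only, using the standard struct module, not an FE API):

```python
import struct

# Snapshot the bytes of a host scalar, the way setConstValue copies `value`
# into constValueBytes; the backend later receives a pointer to these bytes.
const_bytes = struct.pack("<f", 0.5)   # a float32 constant
print(const_bytes.hex())               # '0000003f'
```
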
diff --git a/include/cudnn_frontend_shim.h b/include/cudnn_frontend_shim.h
index e07ca484..ea838d33 100644
--- a/include/cudnn_frontend_shim.h
+++ b/include/cudnn_frontend_shim.h
@@ -427,6 +427,70 @@ destroy_handle(cudnnHandle_t handle) {
NV_FE_CALL_TO_BACKEND(destroy_handle, cudnnDestroy, handle);
}
+#if CUDNN_VERSION >= 92200
+inline cudnnStatus_t
+causal_conv1d_forward(cudaStream_t stream,
+ const void *x,
+ const void *weight,
+ const void *bias,
+ void *y,
+ int batch,
+ int dim,
+ int seq_len,
+ int kernel_size,
+ cudnnDataType_t data_type,
+ cudnnCausalConv1dActivation_t activation) {
+ NV_FE_CALL_TO_BACKEND(causal_conv1d_forward,
+ cudnnCausalConv1dForward,
+ stream,
+ x,
+ weight,
+ bias,
+ y,
+ batch,
+ dim,
+ seq_len,
+ kernel_size,
+ data_type,
+ activation);
+}
+
+inline cudnnStatus_t
+causal_conv1d_backward(cudaStream_t stream,
+ const void *x,
+ const void *weight,
+ const void *bias,
+ const void *dy,
+ void *dx,
+ void *dweight,
+ void *dbias,
+ int batch,
+ int dim,
+ int seq_len,
+ int kernel_size,
+ cudnnDataType_t data_type,
+ cudnnDataType_t dw_data_type,
+ cudnnCausalConv1dActivation_t activation) {
+ NV_FE_CALL_TO_BACKEND(causal_conv1d_backward,
+ cudnnCausalConv1dBackward,
+ stream,
+ x,
+ weight,
+ bias,
+ dy,
+ dx,
+ dweight,
+ dbias,
+ batch,
+ dim,
+ seq_len,
+ kernel_size,
+ data_type,
+ dw_data_type,
+ activation);
+}
+#endif
+
inline size_t
get_backend_version(void) {
#if defined NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
diff --git a/include/cudnn_frontend_utils.h b/include/cudnn_frontend_utils.h
index 603b9cf1..33e7c3a5 100644
--- a/include/cudnn_frontend_utils.h
+++ b/include/cudnn_frontend_utils.h
@@ -96,21 +96,21 @@ struct nlohmann::adl_serializer {
};
template <>
-struct nlohmann::adl_serializer> {
+struct nlohmann::adl_serializer> {
static void
- to_json(nlohmann::json& j, const std::variant& data) {
+ to_json(nlohmann::json& j, const std::variant& data) {
std::visit([&](const auto& v) { j = {{"index", data.index()}, {"value", v}}; }, data);
}
static void
- from_json(const nlohmann::json& j, std::variant& data) {
+ from_json(const nlohmann::json& j, std::variant& data) {
if (!j.is_object() || !j.contains("index") || !j.contains("value")) {
return;
}
size_t type_index = j.at("index").get<size_t>();
if (type_index == 0) {
- data = j.at("value").get();
+ data = j.at("value").get();
} else if (type_index == 1) {
data = j.at("value").get();
} else if (type_index == 2) {
@@ -118,6 +118,8 @@ struct nlohmann::adl_serializer();
} else if (type_index == 4) {
+ data = j.at("value").get();
+ } else if (type_index == 5) {
data = j.at("value").get();
} else {
return;
@@ -386,6 +388,20 @@ enum class PaddingMode_t {
ZERO_PAD
};
+enum class ReshapeMode_t {
+ NOT_SET,
+
+ VIEW_ONLY,
+ LOGICAL
+};
+
+NLOHMANN_JSON_SERIALIZE_ENUM(ReshapeMode_t,
+ {
+ {ReshapeMode_t::NOT_SET, nullptr},
+ {ReshapeMode_t::VIEW_ONLY, "VIEW_ONLY"},
+ {ReshapeMode_t::LOGICAL, "LOGICAL"},
+ })
+
enum class ConvolutionMode_t {
NOT_SET,
@@ -480,6 +496,8 @@ enum class DescriptorType_t {
OPERATION_CONCATENATE_DESCRIPTOR,
OPERATION_MOE_GROUPED_MATMUL_DESCRIPTOR,
OPERATION_MOE_GROUPED_MATMUL_BWD_DESCRIPTOR,
+ OPERATION_TRANSPOSE_DESCRIPTOR,
+ OPERATION_SLICE_DESCRIPTOR
};
enum class NormMode_t {
@@ -962,6 +980,12 @@ operator<<(std::ostream& os, const DescriptorType_t& mode) {
case DescriptorType_t::OPERATION_MOE_GROUPED_MATMUL_BWD_DESCRIPTOR:
os << "OPERATION_MOE_GROUPED_MATMUL_BWD_DESCRIPTOR";
break;
+ case DescriptorType_t::OPERATION_TRANSPOSE_DESCRIPTOR:
+ os << "OPERATION_TRANSPOSE_DESCRIPTOR";
+ break;
+ case DescriptorType_t::OPERATION_SLICE_DESCRIPTOR:
+ os << "OPERATION_SLICE_DESCRIPTOR";
+ break;
case DescriptorType_t::NOT_SET:
os << "NOT_SET";
break;
@@ -1684,6 +1708,22 @@ convert_to_cudnn_type(cudnn_frontend::DescriptorType_t const mode, cudnnBackendD
#else
return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE;
#endif
+ case DescriptorType_t::OPERATION_TRANSPOSE_DESCRIPTOR:
+#if (CUDNN_VERSION >= 92200) && (CUDNN_VERSION < 99900)
+ NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(92200, cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE);
+ cudnn_mode = CUDNN_BACKEND_OPERATION_TRANSPOSE_DESCRIPTOR;
+ return cudnnStatus_t::CUDNN_STATUS_SUCCESS;
+#else
+ return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE;
+#endif
+ case DescriptorType_t::OPERATION_SLICE_DESCRIPTOR:
+#if (CUDNN_VERSION >= 92200) && (CUDNN_VERSION < 99900)
+ NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(92200, cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE);
+ cudnn_mode = CUDNN_BACKEND_OPERATION_SLICE_DESCRIPTOR;
+ return cudnnStatus_t::CUDNN_STATUS_SUCCESS;
+#else
+ return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE;
+#endif
#ifndef NO_DEFAULT_IN_SWITCH
default:
@@ -1799,6 +1839,26 @@ convert_to_cudnn_type(cudnn_frontend::NormFwdPhase_t const mode, cudnnBackendNor
return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE;
}
+#if (CUDNN_VERSION >= 92200)
+static inline cudnnStatus_t
+convert_to_cudnn_type(cudnn_frontend::ReshapeMode_t const mode, cudnnBackendReshapeMode_t& cudnn_mode) {
+ switch (mode) {
+ case ReshapeMode_t::VIEW_ONLY:
+ cudnn_mode = CUDNN_RESHAPE_VIEW_ONLY;
+ return cudnnStatus_t::CUDNN_STATUS_SUCCESS;
+ case ReshapeMode_t::LOGICAL:
+ cudnn_mode = CUDNN_RESHAPE_LOGICAL;
+ return cudnnStatus_t::CUDNN_STATUS_SUCCESS;
+
+#ifndef NO_DEFAULT_IN_SWITCH
+ default:
+ return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE;
+#endif
+ }
+ return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE;
+}
+#endif
+
// To be deprecated. Only exists as setResampleMode(cudnnPaddingMode_t) requires it.
static inline void
convert_from_cudnn_type(cudnnPaddingMode_t const cudnn_mode, cudnn_frontend::PaddingMode_t& mode) {
@@ -1930,6 +1990,27 @@ convert_from_cudnn_type(cudnnBackendNormFwdPhase_t const cudnn_mode, cudnn_front
}
}
+#if (CUDNN_VERSION >= 92200)
+// To be deprecated. Only exists as setReshapeMode(cudnnBackendReshapeMode_t) requires it.
+static inline void
+convert_from_cudnn_type(cudnnBackendReshapeMode_t const cudnn_mode, cudnn_frontend::ReshapeMode_t& mode) {
+ mode = ReshapeMode_t::NOT_SET;
+ switch (cudnn_mode) {
+ case CUDNN_RESHAPE_VIEW_ONLY:
+ mode = ReshapeMode_t::VIEW_ONLY;
+ break;
+ case CUDNN_RESHAPE_LOGICAL:
+ mode = ReshapeMode_t::LOGICAL;
+ break;
+
+#ifndef NO_DEFAULT_IN_SWITCH
+ default:
+ break;
+#endif
+ }
+}
+#endif
+
static inline cudnnStatus_t
convert_to_cudnn_type(cudnn_frontend::TensorReordering_t const mode, cudnnBackendTensorReordering_t& cudnn_mode) {
switch (mode) {
@@ -2094,6 +2175,13 @@ convert_from_cudnn_type(cudnnBackendDescriptorType_t const cudnn_mode) {
return DescriptorType_t::OPERATION_MOE_GROUPED_MATMUL_BWD_DESCRIPTOR;
#endif
+#if (CUDNN_VERSION >= 92200) && (CUDNN_VERSION < 99900)
+ case CUDNN_BACKEND_OPERATION_TRANSPOSE_DESCRIPTOR:
+ return DescriptorType_t::OPERATION_TRANSPOSE_DESCRIPTOR;
+ case CUDNN_BACKEND_OPERATION_SLICE_DESCRIPTOR:
+ return DescriptorType_t::OPERATION_SLICE_DESCRIPTOR;
+#endif
+
#ifndef NO_DEFAULT_IN_SWITCH
default:
return DescriptorType_t::NOT_SET;
diff --git a/include/cudnn_frontend_version.h b/include/cudnn_frontend_version.h
index b3f38800..bfbaa99f 100644
--- a/include/cudnn_frontend_version.h
+++ b/include/cudnn_frontend_version.h
@@ -23,7 +23,7 @@
#pragma once
#define CUDNN_FRONTEND_MAJOR_VERSION 1
-#define CUDNN_FRONTEND_MINOR_VERSION 22
-#define CUDNN_FRONTEND_PATCH_VERSION 1
+#define CUDNN_FRONTEND_MINOR_VERSION 23
+#define CUDNN_FRONTEND_PATCH_VERSION 0
#define CUDNN_FRONTEND_VERSION \
((CUDNN_FRONTEND_MAJOR_VERSION * 10000) + (CUDNN_FRONTEND_MINOR_VERSION * 100) + CUDNN_FRONTEND_PATCH_VERSION)
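
With these values, CUDNN_FRONTEND_VERSION packs to 1 * 10000 + 23 * 100 + 0 = 12300. A one-line check of that encoding (illustrative):

```python
major, minor, patch = 1, 23, 0
assert major * 10000 + minor * 100 + patch == 12300
```
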
diff --git a/pyproject.toml b/pyproject.toml
index 0b1d55ad..1ec44cd1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,12 +5,48 @@ build-backend = "setuptools.build_meta"
[project]
name = "nvidia-cudnn-frontend"
dynamic = ["version"]
-description = "CUDNN FrontEnd python library"
+description = "NVIDIA cuDNN Frontend — Python and C++ Graph API with SOTA attention (SDPA / Flash Attention), MoE grouped GEMM fusions, and FP8/MXFP8 kernels for Hopper and Blackwell GPUs."
readme = "README.md"
requires-python = ">=3.9"
-license = {text = "NVIDIA Proprietary Software"}
+license = {text = "MIT"}
+keywords = [
+ "cudnn",
+ "cuda",
+ "gpu",
+ "nvidia",
+ "deep-learning",
+ "attention",
+ "sdpa",
+ "flash-attention",
+ "transformer",
+ "moe",
+ "mixture-of-experts",
+ "grouped-gemm",
+ "fp8",
+ "mxfp8",
+ "blackwell",
+ "hopper",
+ "pytorch",
+ "kernel",
+ "graph-api",
+]
classifiers = [
+ "Development Status :: 5 - Production/Stable",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: POSIX :: Linux",
+ "Operating System :: Microsoft :: Windows",
+ "Programming Language :: C++",
"Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Software Development :: Libraries :: Python Modules",
]
[tool.setuptools]
@@ -19,8 +55,12 @@ package-dir = {"" = "python", "include" = "include"}
include-package-data = true
[project.urls]
-"Homepage" = "https://github.com/nvidia/cudnn-frontend"
-"Bug Tracker" = "https://github.com/nvidia/cudnn-frontend/issues"
+"Homepage" = "https://github.com/NVIDIA/cudnn-frontend"
+"Documentation" = "https://docs.nvidia.com/deeplearning/cudnn/frontend/latest/"
+"Blog" = "https://nvidia.github.io/cudnn-frontend/"
+"Repository" = "https://github.com/NVIDIA/cudnn-frontend"
+"Bug Tracker" = "https://github.com/NVIDIA/cudnn-frontend/issues"
+"Release Notes" = "https://github.com/NVIDIA/cudnn-frontend/releases"
[tool.setuptools.dynamic]
version = {attr = "cudnn.__version__"}
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index 80836a14..f0de8773 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -37,6 +37,7 @@ find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
include(${PROJECT_SOURCE_DIR}/cmake/cuDNN.cmake)
option(CUDNN_FRONTEND_FETCH_PYBINDS_IN_CMAKE "Whether cmake build system should fetch pybinds." ON)
+set(PYBIND11_FINDPYTHON ON)
if(CUDNN_FRONTEND_FETCH_PYBINDS_IN_CMAKE)
diff --git a/python/cudnn/README.md b/python/cudnn/README.md
index 1d766455..f574cd66 100644
--- a/python/cudnn/README.md
+++ b/python/cudnn/README.md
@@ -38,11 +38,17 @@ To add a new frontend-only API, follow these steps:
**Currently implemented frontend-only APIs**:
- `GEMM + Amax`
+- `RMSNorm + RHT + Amax`
- `GEMM + SwiGLU`
+- `GEMM + sReLU`
+- `GEMM + dsReLU`
- `Grouped Gemm + GLU (Unified)`
+- `Grouped Gemm + GLU + Hadamard`
- `Grouped Gemm + dGLU (Unified)`
- `Grouped Gemm + SwiGLU (Legacy, Contiguous-only)`
- `Grouped Gemm + dSwiglu (Legacy, Contiguous-only)`
+- `Grouped Gemm + sReLU (Contiguous-only)`
+- `Grouped Gemm + dsReLU (Contiguous-only)`
- `Discrete Grouped Gemm + SwiGLU`
- `Discrete Grouped Gemm + dSwiglu`
- `Grouped Gemm + Quant (Legacy, Dense-only)`
diff --git a/python/cudnn/__init__.py b/python/cudnn/__init__.py
index a1125b2d..b9b9d39d 100644
--- a/python/cudnn/__init__.py
+++ b/python/cudnn/__init__.py
@@ -40,14 +40,20 @@ def is_windows():
"diagonal_alignment",
"attention_implementation",
"moe_grouped_matmul_mode",
+ "scalar_type",
+ "reshape_mode",
]
for symbol_name in symbols_to_import:
globals()[symbol_name] = getattr(_pybind_module, symbol_name)
+for _optional_symbol in ["causal_conv1d_forward", "causal_conv1d_backward"]:
+ if hasattr(_pybind_module, _optional_symbol):
+ globals()[_optional_symbol] = getattr(_pybind_module, _optional_symbol)
+
from .datatypes import _library_type, _is_torch_tensor
-__version__ = "1.22.1"
+__version__ = "1.23.0"
def _tensor(
@@ -230,287 +236,83 @@ def _dlopen_cudnn():
from typing import Any
+_OPTIONAL_DEPENDENCY_INSTALL_HINT = "Install with 'pip install nvidia-cudnn-frontend[cutedsl]'"
+
+_LAZY_OPTIONAL_IMPORTS = {
+ "NSA": (".native_sparse_attention", "NSA"),
+ "GemmSwigluSm100": (".gemm_swiglu", "GemmSwigluSm100"),
+ "gemm_swiglu_wrapper_sm100": (".gemm_swiglu", "gemm_swiglu_wrapper_sm100"),
+ "GemmSreluSm100": (".gemm_srelu", "GemmSreluSm100"),
+ "gemm_srelu_wrapper_sm100": (".gemm_srelu", "gemm_srelu_wrapper_sm100"),
+ "GemmDsreluSm100": (".gemm_dsrelu", "GemmDsreluSm100"),
+ "gemm_dsrelu_wrapper_sm100": (".gemm_dsrelu", "gemm_dsrelu_wrapper_sm100"),
+ "GemmAmaxSm100": (".gemm_amax", "GemmAmaxSm100"),
+ "gemm_amax_wrapper_sm100": (".gemm_amax", "gemm_amax_wrapper_sm100"),
+ "RmsNormRhtAmaxSm100": (".rmsnorm_rht_amax", "RmsNormRhtAmaxSm100"),
+ "rmsnorm_rht_amax_wrapper_sm100": (".rmsnorm_rht_amax", "rmsnorm_rht_amax_wrapper_sm100"),
+ "grouped_gemm": (".grouped_gemm", None),
+ "GroupedGemmSwigluSm100": (".grouped_gemm", "GroupedGemmSwigluSm100"),
+ "grouped_gemm_swiglu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_swiglu_wrapper_sm100"),
+ "GroupedGemmDswigluSm100": (".grouped_gemm", "GroupedGemmDswigluSm100"),
+ "grouped_gemm_dswiglu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_dswiglu_wrapper_sm100"),
+ "GroupedGemmSreluSm100": (".grouped_gemm", "GroupedGemmSreluSm100"),
+ "grouped_gemm_srelu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_srelu_wrapper_sm100"),
+ "GroupedGemmDsreluSm100": (".grouped_gemm", "GroupedGemmDsreluSm100"),
+ "grouped_gemm_dsrelu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_dsrelu_wrapper_sm100"),
+ "SdpafwdSm100D256": (".sdpa", "SdpafwdSm100D256"),
+ "sdpa_fwd_wrapper_sm100_d256": (".sdpa", "sdpa_fwd_wrapper_sm100_d256"),
+ "SdpabwdSm100D256": (".sdpa", "SdpabwdSm100D256"),
+ "sdpa_bwd_wrapper_sm100_d256": (".sdpa", "sdpa_bwd_wrapper_sm100_d256"),
+ "GroupedGemmQuantSm100": (".grouped_gemm", "GroupedGemmQuantSm100"),
+ "grouped_gemm_quant_wrapper_sm100": (".grouped_gemm", "grouped_gemm_quant_wrapper_sm100"),
+ "GroupedGemmGluSm100": (".grouped_gemm", "GroupedGemmGluSm100"),
+ "grouped_gemm_glu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_glu_wrapper_sm100"),
+ "GroupedGemmGluHadamardSm100": (".grouped_gemm", "GroupedGemmGluHadamardSm100"),
+ "grouped_gemm_glu_hadamard_wrapper_sm100": (".grouped_gemm", "grouped_gemm_glu_hadamard_wrapper_sm100"),
+ "GroupedGemmDgluSm100": (".grouped_gemm", "GroupedGemmDgluSm100"),
+ "grouped_gemm_dglu_wrapper_sm100": (".grouped_gemm", "grouped_gemm_dglu_wrapper_sm100"),
+ "GroupedGemmWgradSm100": (".grouped_gemm", "GroupedGemmWgradSm100"),
+ "grouped_gemm_wgrad_wrapper_sm100": (".grouped_gemm", "grouped_gemm_wgrad_wrapper_sm100"),
+ "discrete_grouped_gemm": (".discrete_grouped_gemm", None),
+ "DiscreteGroupedGemmSwigluSm100": (".discrete_grouped_gemm", "DiscreteGroupedGemmSwigluSm100"),
+ "discrete_grouped_gemm_swiglu_wrapper_sm100": (".discrete_grouped_gemm", "discrete_grouped_gemm_swiglu_wrapper_sm100"),
+ "DiscreteGroupedGemmDswigluSm100": (".discrete_grouped_gemm", "DiscreteGroupedGemmDswigluSm100"),
+ "discrete_grouped_gemm_dswiglu_wrapper_sm100": (".discrete_grouped_gemm", "discrete_grouped_gemm_dswiglu_wrapper_sm100"),
+}
+
+
+def _load_optional_symbol(name: str) -> Any:
+ module_name, attr_name = _LAZY_OPTIONAL_IMPORTS[name]
+ try:
+ module = importlib.import_module(module_name, package=__name__)
+ value = module if attr_name is None else getattr(module, attr_name)
+ except Exception as e:
+ raise ImportError(f"{name} requires optional dependencies. {_OPTIONAL_DEPENDENCY_INSTALL_HINT}: {e}") from e
+
+ globals()[name] = value
+ return value
-def __getattr__(name: str) -> Any:
- if name == "NSA":
- try:
- from .native_sparse_attention import NSA as _NSA
-
- return _NSA
- except Exception as e:
- raise ImportError(f"NSA requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "GemmSwigluSm100":
- try:
- from .gemm_swiglu import GemmSwigluSm100 as _GemmSwigluSm100
-
- return _GemmSwigluSm100
- except Exception as e:
- raise ImportError(f"GemmSwigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "gemm_swiglu_wrapper_sm100":
- try:
- from .gemm_swiglu import (
- gemm_swiglu_wrapper_sm100 as _gemm_swiglu_wrapper_sm100,
- )
-
- return _gemm_swiglu_wrapper_sm100
- except Exception as e:
- raise ImportError(
- f"gemm_swiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- elif name == "GemmAmaxSm100":
- try:
- from .gemm_amax import GemmAmaxSm100 as _GemmAmaxSm100
-
- return _GemmAmaxSm100
- except Exception as e:
- raise ImportError(f"GemmAmaxSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "gemm_amax_wrapper_sm100":
- try:
- from .gemm_amax import (
- gemm_amax_wrapper_sm100 as _gemm_amax_wrapper_sm100,
- )
-
- return _gemm_amax_wrapper_sm100
- except Exception as e:
- raise ImportError(f"gemm_amax_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- # Grouped GEMM module
- elif name == "grouped_gemm":
- try:
- from . import grouped_gemm as _grouped_gemm
-
- return _grouped_gemm
- except Exception as e:
- raise ImportError(f"grouped_gemm requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "GroupedGemmSwigluSm100":
- try:
- from .grouped_gemm import GroupedGemmSwigluSm100 as _GroupedGemmSwigluSm100
-
- return _GroupedGemmSwigluSm100
- except Exception as e:
- raise ImportError(f"GroupedGemmSwigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "grouped_gemm_swiglu_wrapper_sm100":
- try:
- from .grouped_gemm import (
- grouped_gemm_swiglu_wrapper_sm100 as _grouped_gemm_swiglu_wrapper_sm100,
- )
-
- return _grouped_gemm_swiglu_wrapper_sm100
- except Exception as e:
- raise ImportError(
- f"grouped_gemm_swiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- elif name == "GroupedGemmDswigluSm100":
- try:
- from .grouped_gemm import (
- GroupedGemmDswigluSm100 as _GroupedGemmDswigluSm100,
- )
-
- return _GroupedGemmDswigluSm100
- except Exception as e:
- raise ImportError(f"GroupedGemmDswigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "grouped_gemm_dswiglu_wrapper_sm100":
- try:
- from .grouped_gemm import (
- grouped_gemm_dswiglu_wrapper_sm100 as _grouped_gemm_dswiglu_wrapper_sm100,
- )
-
- return _grouped_gemm_dswiglu_wrapper_sm100
- except Exception as e:
- raise ImportError(
- f"grouped_gemm_dswiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
- elif name == "SdpafwdSm100D256":
- try:
- from .sdpa import SdpafwdSm100D256 as _SdpafwdSm100D256
-
- return _SdpafwdSm100D256
- except Exception as e:
- raise ImportError(f"SdpafwdSm100D256 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "sdpa_fwd_wrapper_sm100_d256":
- try:
- from .sdpa import (
- sdpa_fwd_wrapper_sm100_d256 as _sdpa_fwd_wrapper_sm100_d256,
- )
-
- return _sdpa_fwd_wrapper_sm100_d256
- except Exception as e:
- raise ImportError(
- f"sdpa_fwd_wrapper_sm100_d256 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- elif name == "SdpabwdSm100D256":
- try:
- from .sdpa import SdpabwdSm100D256 as _SdpabwdSm100D256
-
- return _SdpabwdSm100D256
- except Exception as e:
- raise ImportError(f"SdpabwdSm100D256 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "sdpa_bwd_wrapper_sm100_d256":
- try:
- from .sdpa import (
- sdpa_bwd_wrapper_sm100_d256 as _sdpa_bwd_wrapper_sm100_d256,
- )
-
- return _sdpa_bwd_wrapper_sm100_d256
- except Exception as e:
- raise ImportError(
- f"sdpa_bwd_wrapper_sm100_d256 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- elif name == "GroupedGemmQuantSm100":
- try:
- from .grouped_gemm import GroupedGemmQuantSm100 as _GroupedGemmQuantSm100
-
- return _GroupedGemmQuantSm100
- except Exception as e:
- raise ImportError(f"GroupedGemmQuantSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "grouped_gemm_quant_wrapper_sm100":
- try:
- from .grouped_gemm import (
- grouped_gemm_quant_wrapper_sm100 as _grouped_gemm_quant_wrapper_sm100,
- )
-
- return _grouped_gemm_quant_wrapper_sm100
- except Exception as e:
- raise ImportError(
- f"grouped_gemm_quant_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- # Unified Grouped GEMM GLU (forward)
- elif name == "GroupedGemmGluSm100":
- try:
- from .grouped_gemm import GroupedGemmGluSm100 as _GroupedGemmGluSm100
-
- return _GroupedGemmGluSm100
- except Exception as e:
- raise ImportError(f"GroupedGemmGluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "grouped_gemm_glu_wrapper_sm100":
- try:
- from .grouped_gemm import (
- grouped_gemm_glu_wrapper_sm100 as _grouped_gemm_glu_wrapper_sm100,
- )
-
- return _grouped_gemm_glu_wrapper_sm100
- except Exception as e:
- raise ImportError(
- f"grouped_gemm_glu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- # Unified Grouped GEMM dGLU (backward)
- elif name == "GroupedGemmDgluSm100":
- try:
- from .grouped_gemm import GroupedGemmDgluSm100 as _GroupedGemmDgluSm100
-
- return _GroupedGemmDgluSm100
- except Exception as e:
- raise ImportError(f"GroupedGemmDgluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "grouped_gemm_dglu_wrapper_sm100":
- try:
- from .grouped_gemm import (
- grouped_gemm_dglu_wrapper_sm100 as _grouped_gemm_dglu_wrapper_sm100,
- )
- return _grouped_gemm_dglu_wrapper_sm100
- except Exception as e:
- raise ImportError(
- f"grouped_gemm_dglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- elif name == "GroupedGemmWgradSm100":
- try:
- from .grouped_gemm import GroupedGemmWgradSm100 as _GroupedGemmWgradSm100
-
- return _GroupedGemmWgradSm100
- except Exception as e:
- raise ImportError(f"GroupedGemmWgradSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "grouped_gemm_wgrad_wrapper_sm100":
- try:
- from .grouped_gemm import (
- grouped_gemm_wgrad_wrapper_sm100 as _grouped_gemm_wgrad_wrapper_sm100,
- )
-
- return _grouped_gemm_wgrad_wrapper_sm100
- except Exception as e:
- raise ImportError(
- f"grouped_gemm_wgrad_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- # Discrete-weight Grouped GEMM GLU module
- elif name == "discrete_grouped_gemm":
- try:
- from . import discrete_grouped_gemm as _discrete_grouped_gemm
-
- return _discrete_grouped_gemm
- except Exception as e:
- raise ImportError(f"discrete_grouped_gemm requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e
-
- elif name == "DiscreteGroupedGemmSwigluSm100":
- try:
- from .discrete_grouped_gemm import (
- DiscreteGroupedGemmSwigluSm100 as _DiscreteGroupedGemmSwigluSm100,
- )
-
- return _DiscreteGroupedGemmSwigluSm100
- except Exception as e:
- raise ImportError(
- f"DiscreteGroupedGemmSwigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- elif name == "discrete_grouped_gemm_swiglu_wrapper_sm100":
- try:
- from .discrete_grouped_gemm import (
- discrete_grouped_gemm_swiglu_wrapper_sm100 as _discrete_grouped_gemm_swiglu_wrapper_sm100,
- )
-
- return _discrete_grouped_gemm_swiglu_wrapper_sm100
- except Exception as e:
- raise ImportError(
- f"discrete_grouped_gemm_swiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- elif name == "DiscreteGroupedGemmDswigluSm100":
- try:
- from .discrete_grouped_gemm import (
- DiscreteGroupedGemmDswigluSm100 as _DiscreteGroupedGemmDswigluSm100,
- )
-
- return _DiscreteGroupedGemmDswigluSm100
- except Exception as e:
- raise ImportError(
- f"DiscreteGroupedGemmDswigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- elif name == "discrete_grouped_gemm_dswiglu_wrapper_sm100":
- try:
- from .discrete_grouped_gemm import (
- discrete_grouped_gemm_dswiglu_wrapper_sm100 as _discrete_grouped_gemm_dswiglu_wrapper_sm100,
- )
-
- return _discrete_grouped_gemm_dswiglu_wrapper_sm100
- except Exception as e:
- raise ImportError(
- f"discrete_grouped_gemm_dswiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}"
- ) from e
-
- elif name == "experimental":
+def __getattr__(name: str) -> Any:
+ if name == "ops":
+ # Use importlib rather than "from . import ops" to avoid infinite
+ # recursion. The cycle:
+ # 1. cudnn.ops accessed → __getattr__("ops") fires
+ # 2. "from . import ops" → _handle_fromlist(cudnn, ["ops"], ...)
+ # 3. _handle_fromlist calls hasattr(cudnn, "ops")
+ # 4. "ops" not in __dict__ yet → __getattr__("ops") again → goto 1
+ # importlib.import_module bypasses _handle_fromlist entirely.
+ _ops = importlib.import_module(".ops", __name__)
+ globals()["ops"] = _ops
+ return _ops
+
+ if name == "experimental":
from . import experimental as _experimental
globals()["experimental"] = _experimental
return _experimental
- else:
- raise AttributeError(name)
+
+ if name in _LAZY_OPTIONAL_IMPORTS:
+ return _load_optional_symbol(name)
+
+ raise AttributeError(name)
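
The refactor above replaces the long if/elif chain with a single name-to-(module, attribute) table consumed by a module-level __getattr__ (PEP 562), so optional symbols are imported only on first access and then cached in globals(). A minimal self-contained sketch of the same pattern, with hypothetical module and attribute names:

```python
# lazy_pkg/__init__.py -- illustrative only; module and attribute names are hypothetical
import importlib
from typing import Any

_LAZY = {
    "FastThing": (".fast_thing", "FastThing"),   # an attribute of a submodule
    "extras":    (".extras", None),              # the submodule itself
}


def __getattr__(name: str) -> Any:
    if name in _LAZY:
        module_name, attr = _LAZY[name]
        try:
            module = importlib.import_module(module_name, package=__name__)
        except Exception as e:
            raise ImportError(f"{name} needs optional dependencies: {e}") from e
        value = module if attr is None else getattr(module, attr)
        globals()[name] = value   # cache so __getattr__ fires only once per name
        return value
    raise AttributeError(name)
```
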
diff --git a/python/cudnn/api_base.py b/python/cudnn/api_base.py
index 82fb161f..e60a52e1 100644
--- a/python/cudnn/api_base.py
+++ b/python/cudnn/api_base.py
@@ -14,6 +14,7 @@
from dataclasses import dataclass, field
from typing import Any, List, Tuple, Optional
import logging
+import threading
import cuda.bindings.driver as cuda
import cutlass
import torch
@@ -31,6 +32,26 @@ def is_power_of_2(n: int) -> bool:
return n > 0 and (n & (n - 1)) == 0
+_experimental_api_warnings_emitted = set()
+_experimental_api_warnings_lock = threading.Lock()
+
+
+def warn_experimental_api_once(logger: logging.Logger, api_name: str) -> None:
+ """Emit the experimental API warning once per API class per process."""
+ with _experimental_api_warnings_lock:
+ if api_name in _experimental_api_warnings_emitted:
+ return
+ _experimental_api_warnings_emitted.add(api_name)
+
+ logger.warning("%s is an experimental API", api_name)
+
+
+def _reset_experimental_api_warning_registry() -> None:
+ """Reset experimental API warning state for tests."""
+ with _experimental_api_warnings_lock:
+ _experimental_api_warnings_emitted.clear()
+
+
@dataclass(frozen=True)
class TensorDesc:
"""Metadata needed to validate/compile tensor signatures without storage."""
@@ -40,6 +61,7 @@ class TensorDesc:
stride: Tuple[int, ...]
stride_order: Tuple[int, ...]
device: torch.device
+ interpret_uint8_as_fp4x2: bool = False
ndim: int = field(init=False)
name: str = ""
@@ -156,6 +178,7 @@ def _with_layout(self, shape: Tuple[int, ...], stride: Tuple[int, ...]) -> "Tens
stride=stride,
stride_order=self._compute_stride_order(shape, stride),
device=self.device,
+ interpret_uint8_as_fp4x2=self.interpret_uint8_as_fp4x2,
name=self.name,
)
@@ -363,6 +386,9 @@ def __init__(self):
self._interpret_uint8_as_fp4x2 = False
self._logger = logging.getLogger(self.__class__.__name__)
+ def _warn_experimental_api(self) -> None:
+ warn_experimental_api_once(self._logger, self.__class__.__name__)
+
@abstractmethod
def check_support(self) -> bool:
"""Check if the current configuration is supported by the kernel.
@@ -571,8 +597,16 @@ def _is_fp4x2(self, tensor_or_dtype: torch.Tensor | torch.dtype | TensorDesc) ->
"""
if tensor_or_dtype is None:
return False
- dtype = tensor_or_dtype.dtype if isinstance(tensor_or_dtype, (torch.Tensor, TensorDesc)) else tensor_or_dtype
- return (dtype == torch.float4_e2m1fn_x2) or (self._interpret_uint8_as_fp4x2 and dtype == torch.uint8)
+ if isinstance(tensor_or_dtype, TensorDesc):
+ dtype = tensor_or_dtype.dtype
+ interpret_uint8_as_fp4x2 = tensor_or_dtype.interpret_uint8_as_fp4x2
+ elif isinstance(tensor_or_dtype, torch.Tensor):
+ dtype = tensor_or_dtype.dtype
+ interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2
+ else:
+ dtype = tensor_or_dtype
+ interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2
+ return (dtype == torch.float4_e2m1fn_x2) or (interpret_uint8_as_fp4x2 and dtype == torch.uint8)
def _is_fp8(self, tensor_or_dtype: torch.Tensor | torch.dtype | TensorDesc) -> bool:
"""Check if tensor or dtype is an FP8 datatype.
@@ -875,12 +909,24 @@ def _make_fake_cute_tensor_like(
assumed_align=assumed_align,
)
- def _make_tensor_desc(self, tensor: Optional[torch.Tensor], name: str = "") -> Optional[TensorDesc]:
+ def _make_tensor_desc(
+ self,
+ tensor: Optional[torch.Tensor],
+ name: str = "",
+ interpret_uint8_as_fp4x2: Optional[bool] = None,
+ ) -> Optional[TensorDesc]:
"""Capture logical tensor metadata that is sufficient for validation/compile."""
if tensor is None:
return None
- tensor_shape = self._tensor_shape(tensor, name=name)
- tensor_stride = self._tensor_stride(tensor, name=name)
+ if interpret_uint8_as_fp4x2 is None:
+ interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2
+ prev_interpret = self._interpret_uint8_as_fp4x2
+ self._interpret_uint8_as_fp4x2 = interpret_uint8_as_fp4x2
+ try:
+ tensor_shape = self._tensor_shape(tensor, name=name)
+ tensor_stride = self._tensor_stride(tensor, name=name)
+ finally:
+ self._interpret_uint8_as_fp4x2 = prev_interpret
tensor_stride_order = tuple(i for i, s in sorted(enumerate(tensor_stride), key=lambda x: (x[1], tensor_shape[x[0]])))
return TensorDesc(
dtype=tensor.dtype,
@@ -888,6 +934,7 @@ def _make_tensor_desc(self, tensor: Optional[torch.Tensor], name: str = "") -> O
stride=tensor_stride,
stride_order=tensor_stride_order,
device=tensor.device,
+ interpret_uint8_as_fp4x2=interpret_uint8_as_fp4x2,
name=name,
)
@@ -904,6 +951,7 @@ def _make_fake_cute_tensor_from_desc(
shape=tensor_desc.shape,
stride=tensor_desc.stride,
assumed_align=assumed_align,
+ interpret_uint8_as_fp4x2=tensor_desc.interpret_uint8_as_fp4x2,
)
def _make_fake_cute_tensor(
@@ -912,6 +960,7 @@ def _make_fake_cute_tensor(
shape: Tuple[int, ...],
stride: Tuple[int, ...],
assumed_align: int = 16,
+ interpret_uint8_as_fp4x2: Optional[bool] = None,
) -> cute.Pointer:
"""Make a fake tensor.
@@ -926,8 +975,10 @@ def _make_fake_cute_tensor(
:return: A fake tensor
:rtype: cute.Pointer
"""
+ if interpret_uint8_as_fp4x2 is None:
+ interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2
return cute.runtime.make_fake_tensor(
- dtype=_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2),
+ dtype=_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=interpret_uint8_as_fp4x2),
shape=shape,
stride=stride,
assumed_align=assumed_align,
@@ -941,6 +992,7 @@ def _make_fake_cute_compact_tensor(
assumed_align: int = 16,
dynamic_mode: Optional[int] = None,
divisibility: int = 16,
+ interpret_uint8_as_fp4x2: Optional[bool] = None,
) -> cute.Pointer:
"""Make a fake compact tensor.
:param dtype: The dtype of the tensor
@@ -958,8 +1010,10 @@ def _make_fake_cute_compact_tensor(
dynamic_dim = cute.sym_int(divisibility=divisibility)
shape = shape[:dynamic_mode] + (dynamic_dim,) + shape[dynamic_mode + 1 :]
+ if interpret_uint8_as_fp4x2 is None:
+ interpret_uint8_as_fp4x2 = self._interpret_uint8_as_fp4x2
return cute.runtime.make_fake_compact_tensor(
- dtype=_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2),
+ dtype=_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=interpret_uint8_as_fp4x2),
shape=shape,
stride_order=stride_order,
assumed_align=assumed_align,
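
warn_experimental_api_once above takes a lock and records the API name in a process-wide set so each experimental class warns at most once. The same pattern in isolation (an illustrative sketch, not the FE code):

```python
import logging
import threading

_seen = set()
_lock = threading.Lock()


def warn_once(logger: logging.Logger, key: str, message: str) -> None:
    # First caller for a given key logs; later calls return silently.
    with _lock:
        if key in _seen:
            return
        _seen.add(key)
    logger.warning("%s: %s", key, message)
```
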
diff --git a/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py b/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py
index 4c110e17..dc52c846 100644
--- a/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py
+++ b/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_dswiglu/api.py
@@ -133,7 +133,7 @@ def __init__(
"""
super().__init__()
- self._logger.warning("DiscreteGroupedGemmDswigluSm100 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
self._value_error_if(num_experts == 0, "num_experts must be > 0")
@@ -196,7 +196,7 @@ def __init__(
self._kernel = BlockScaledDiscreteWeightDgluDbiasGroupedGemmKernel
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
- print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
self._workspace = None
self._logger.debug("__init__ completed")
diff --git a/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py b/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py
index d3e736ca..a3d3c0fd 100644
--- a/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py
+++ b/python/cudnn/discrete_grouped_gemm/discrete_grouped_gemm_swiglu/api.py
@@ -141,7 +141,7 @@ def __init__(
"""
super().__init__()
- self._logger.warning("DiscreteGroupedGemmSwigluSm100 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
self._value_error_if(num_experts == 0, "num_experts must be > 0")
@@ -194,7 +194,7 @@ def __init__(
self._kernel = BlockScaledDiscreteWeightGroupedGemmBiasKernel
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
- print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
self._workspace = None
diff --git a/python/cudnn/gemm_amax/api.py b/python/cudnn/gemm_amax/api.py
index eea7d816..62fc475b 100644
--- a/python/cudnn/gemm_amax/api.py
+++ b/python/cudnn/gemm_amax/api.py
@@ -31,7 +31,7 @@ def __init__(
):
super().__init__()
- self._logger.warning("GemmAmaxSm100 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
self.a_desc = self._make_tensor_desc(sample_a, name="sample_a")
@@ -47,7 +47,7 @@ def __init__(
self.sf_vec_size = sf_vec_size
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
- print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
# used to reshape sfa/sfb tensors to atom layout
self.atom_m = (32, 4)
diff --git a/python/cudnn/gemm_dsrelu/__init__.py b/python/cudnn/gemm_dsrelu/__init__.py
new file mode 100644
index 00000000..ade25458
--- /dev/null
+++ b/python/cudnn/gemm_dsrelu/__init__.py
@@ -0,0 +1,9 @@
+from .api import (
+ GemmDsreluSm100,
+ gemm_dsrelu_wrapper_sm100,
+)
+
+__all__ = [
+ "GemmDsreluSm100",
+ "gemm_dsrelu_wrapper_sm100",
+]
diff --git a/python/cudnn/gemm_dsrelu/api.py b/python/cudnn/gemm_dsrelu/api.py
new file mode 100644
index 00000000..fb02fa02
--- /dev/null
+++ b/python/cudnn/gemm_dsrelu/api.py
@@ -0,0 +1,612 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Optional, Tuple
+
+import cutlass
+import cutlass.cute as cute
+import torch
+from cuda.bindings import driver as cuda
+from cutlass.cute.runtime import make_fake_stream
+
+from cudnn.api_base import APIBase, TupleDict, ceil_div, is_power_of_2
+from cudnn.datatypes import _convert_to_cutlass_data_type
+
+from .dense_blockscaled_gemm_persistent_dsrelu_quant import (
+ Sm100BlockScaledPersistentDenseGemmKernel,
+)
+
+
+def _major_from_stride_order(stride_order: Tuple[int, ...], mode0_label: str, mode1_label: str) -> str:
+ if stride_order == (0, 1, 2):
+ return mode0_label
+ if stride_order == (1, 0, 2):
+ return mode1_label
+ raise ValueError(f"Unsupported stride order {stride_order}")
+
+
+class GemmDsreluSm100(APIBase):
+ def __init__(
+ self,
+ sample_a: torch.Tensor,
+ sample_b: torch.Tensor,
+ sample_c: torch.Tensor,
+ sample_d: torch.Tensor,
+ sample_dprob: torch.Tensor,
+ sample_sfa: torch.Tensor,
+ sample_sfb: torch.Tensor,
+ sample_prob: torch.Tensor,
+ sample_sfd: Optional[torch.Tensor] = None,
+ sample_amax: Optional[torch.Tensor] = None,
+ sample_norm_const: Optional[torch.Tensor] = None,
+ alpha: float = 1.0,
+ acc_dtype: torch.dtype = torch.float32,
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ ):
+ super().__init__()
+
+ self._warn_experimental_api()
+
+ self.a_desc = self._make_tensor_desc(sample_a, name="sample_a")
+ self.b_desc = self._make_tensor_desc(sample_b, name="sample_b")
+ self.c_desc = self._make_tensor_desc(sample_c, name="sample_c")
+ self.d_desc = self._make_tensor_desc(sample_d, name="sample_d")
+ self.dprob_desc = self._make_tensor_desc(sample_dprob, name="sample_dprob")
+ self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa")
+ self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb")
+ self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob")
+ self.sfd_desc = self._make_tensor_desc(sample_sfd, name="sample_sfd")
+ self.amax_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_amax, name="sample_amax"), 1, "amax")
+ self.norm_const_desc = self._unpad_tensor_to_ndim(
+ self._make_tensor_desc(sample_norm_const, name="sample_norm_const"),
+ 1,
+ "norm_const",
+ )
+
+ self.alpha = alpha
+ self.acc_dtype = acc_dtype
+ self.mma_tiler_mn = mma_tiler_mn
+ self.cluster_shape_mn = cluster_shape_mn if cluster_shape_mn is not None else ((2, 1) if mma_tiler_mn[0] == 256 else (1, 1))
+ self.sf_vec_size = sf_vec_size
+ self.vector_f32 = vector_f32
+ self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
+
+ self._interpret_uint8_as_fp4x2 = True
+ self._kernel = Sm100BlockScaledPersistentDenseGemmKernel
+
+ def check_support(self) -> bool:
+ m, k, l = self._tensor_shape(self.a_desc, name="sample_a")
+ n, b_k, b_l = self._tensor_shape(self.b_desc, name="sample_b")
+
+ self._value_error_if((b_k, b_l) != (k, l), f"B shape mismatch: expected (*, {k}, {l}), got {(n, b_k, b_l)}")
+ self._check_tensor_shape(self.c_desc, (m, n, l), "C")
+ self._check_tensor_shape(self.d_desc, (m, n, l), "D")
+ self._check_tensor_shape(self.prob_desc, (m, 1, l), "prob")
+ self._check_tensor_shape(self.dprob_desc, (m, 1, l), "dprob")
+
+ rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4)
+ self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(m, 128), 4, rest_k, l), "SFA")
+ self._check_tensor_shape(self.sfb_desc, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB")
+
+ if self.sfd_desc is not None:
+ rest_n = ceil_div(ceil_div(n, self.sf_vec_size), 4)
+ self._check_tensor_shape(self.sfd_desc, (32, 4, ceil_div(m, 128), 4, rest_n, l), "SFD")
+
+ self._check_tensor_shape(self.amax_desc, (1,), "amax")
+ self._check_tensor_shape(self.norm_const_desc, (1,), "norm_const")
+
+ self.ab_dtype = self._check_dtype(
+ self.a_desc,
+ dtype=[torch.float4_e2m1fn_x2, torch.uint8, torch.float8_e4m3fn, torch.float8_e5m2],
+ name="A",
+ )
+ self._check_dtype(self.b_desc, dtype=self.ab_dtype, name="B", extra_error_msg="A and B must have the same dtype")
+ self.c_dtype = self._check_dtype(
+ self.c_desc,
+ dtype=[torch.float16, torch.bfloat16, torch.float32],
+ name="C",
+ )
+ self.d_dtype = self._check_dtype(
+ self.d_desc,
+ dtype=[torch.float16, torch.bfloat16, torch.float32, torch.float8_e4m3fn, torch.float8_e5m2],
+ name="D",
+ )
+ self._check_dtype(self.prob_desc, dtype=torch.float32, name="prob")
+ self._check_dtype(self.dprob_desc, dtype=torch.float32, name="dprob")
+
+ self.sf_dtype = self._check_dtype(self.sfa_desc, dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn], name="SFA")
+ self._check_dtype(self.sfb_desc, dtype=self.sf_dtype, name="SFB", extra_error_msg="SFB must have the same dtype as SFA")
+ self._check_dtype(self.sfd_desc, dtype=self.sf_dtype, name="SFD", extra_error_msg="SFD must have the same dtype as SFA")
+
+ self._check_dtype(self.acc_dtype, dtype=torch.float32, name="Accumulator")
+
+ self._value_error_if(self.sf_vec_size not in {16, 32}, f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}")
+ self._value_error_if(
+ self._is_fp8(self.d_desc) and (self.sfd_desc is None or self.norm_const_desc is None), "sfd and norm_const are required when D is FP8"
+ )
+ self._value_error_if(
+ self._is_fp4x2(self.ab_dtype) and self.d_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}, "FP4 input with FP8 output is not supported"
+ )
+
+ a_major = _major_from_stride_order(self.a_desc.stride_order, "m", "k")
+ b_major = _major_from_stride_order(self.b_desc.stride_order, "n", "k")
+ c_major = _major_from_stride_order(self.c_desc.stride_order, "m", "n")
+ d_major = _major_from_stride_order(self.d_desc.stride_order, "m", "n")
+ self._value_error_if(c_major != d_major, f"C and D must share the same layout, got {c_major} and {d_major}")
+
+ self._value_error_if(
+ self.mma_tiler_mn[0] not in {128, 256} or self.mma_tiler_mn[1] not in {64, 128, 192, 256},
+ f"Unsupported mma_tiler_mn {self.mma_tiler_mn}",
+ )
+ self._value_error_if(
+ not (
+ self.cluster_shape_mn[0] > 0
+ and self.cluster_shape_mn[1] > 0
+ and self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16
+ and is_power_of_2(self.cluster_shape_mn[0])
+ and is_power_of_2(self.cluster_shape_mn[1])
+ ),
+ f"Invalid cluster shape {self.cluster_shape_mn}",
+ )
+
+ self._runtime_error_if(not torch.cuda.is_available(), "CUDA is not available")
+ major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+ self._runtime_error_if(major * 10 + minor < 100, f"GemmDsreluSm100 requires SM100+, found SM{major}{minor}")
+
+ self._value_error_if(
+ not self._kernel.can_implement(
+ ab_dtype=_convert_to_cutlass_data_type(self.ab_dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2),
+ sf_dtype=_convert_to_cutlass_data_type(self.sf_dtype),
+ sf_vec_size=self.sf_vec_size,
+ d_dtype=_convert_to_cutlass_data_type(self.d_dtype),
+ mma_tiler_mn=self.mma_tiler_mn,
+ cluster_shape_mn=self.cluster_shape_mn,
+ m=m,
+ n=n,
+ k=k,
+ l=l,
+ a_major=a_major,
+ b_major=b_major,
+ d_major=d_major,
+ ),
+ "Unsupported configuration for dense dsReLU kernel",
+ )
+
+ self._is_supported = True
+ return True
+
+ def compile(self) -> None:
+ self._ensure_support_checked()
+ if self._compiled_kernel is not None:
+ return
+
+ gemm = self._kernel(
+ sf_vec_size=self.sf_vec_size,
+ mma_tiler_mn=self.mma_tiler_mn,
+ cluster_shape_mn=self.cluster_shape_mn,
+ vector_f32=self.vector_f32,
+ )
+
+ hardware_info = cutlass.utils.HardwareInfo()
+ max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1])
+ max_active_clusters -= self.num_cluster_overlap_margin
+ self._value_error_if(max_active_clusters <= 0, "max_active_clusters must be > 0 after overlap margin")
+
+ fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False)
+ epilogue_op = lambda x, y: cute.where(x > 0, x, cute.full_like(x, 0)) * 2 * y
+ use_full_dynamic = os.environ.get(_DENSE_GEMM_DYNAMIC_MNKL_ENV) is not None
+ use_dynamic_m = not use_full_dynamic and os.environ.get(_DENSE_GEMM_DYNAMIC_M_ENV) is not None
+
+ if use_dynamic_m:
+ valid_m = cute.sym_int()
+
+ a_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, *self.a_desc.shape[1:]),
+ stride_order=self.a_desc.stride_order,
+ )
+ b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16)
+ c_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, *self.c_desc.shape[1:]),
+ stride_order=self.c_desc.stride_order,
+ )
+ d_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_desc.dtype,
+ shape=(valid_m, *self.d_desc.shape[1:]),
+ stride_order=self.d_desc.stride_order,
+ )
+ prob_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, *self.prob_desc.shape[1:]),
+ stride_order=self.prob_desc.stride_order,
+ )
+ dprob_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.dprob_desc.dtype,
+ shape=(valid_m, *self.dprob_desc.shape[1:]),
+ stride_order=self.dprob_desc.stride_order,
+ )
+
+ tensor_m_128 = cute.sym_int()
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], self.sfa_desc.shape[5]),
+ stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128),
+ )
+ sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16)
+
+ sfd_cute_fake = None
+ if self.sfd_desc is not None:
+ stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfd_desc.shape[4], self.sfd_desc.shape[5]),
+ stride=(16, 4, self.sfd_desc.stride[2], 1, 512, stride_sfd_m),
+ )
+ elif use_full_dynamic:
+ valid_m = cute.sym_int()
+ n_sym = cute.sym_int()
+ k_sym = cute.sym_int()
+ l_sym = cute.sym_int()
+
+ a_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, k_sym, l_sym),
+ stride_order=self.a_desc.stride_order,
+ dynamic_mode=self.a_desc.stride_order[0],
+ divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
+ )
+ b_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.b_desc.dtype,
+ shape=(n_sym, k_sym, l_sym),
+ stride_order=self.b_desc.stride_order,
+ dynamic_mode=self.b_desc.stride_order[0],
+ divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
+ )
+ c_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, n_sym, l_sym),
+ stride_order=self.c_desc.stride_order,
+ dynamic_mode=self.c_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.c_desc.dtype) else 16,
+ )
+ d_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_desc.dtype,
+ shape=(valid_m, n_sym, l_sym),
+ stride_order=self.d_desc.stride_order,
+ dynamic_mode=self.d_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.d_desc.dtype) else 16,
+ )
+ prob_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, *self.prob_desc.shape[1:-1], l_sym),
+ stride=(1, 1, valid_m),
+ )
+ dprob_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.dprob_desc.dtype,
+ shape=(valid_m, *self.dprob_desc.shape[1:-1], l_sym),
+ stride=(l_sym, l_sym, 1),
+ )
+
+ tensor_m_128 = cute.sym_int()
+ rest_k = cute.sym_int()
+ stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_shape = list(self.sfa_desc.shape)
+ sfa_shape[2] = tensor_m_128
+ sfa_shape[4] = rest_k
+ sfa_shape[5] = l_sym
+ sfa_stride = list(self.sfa_desc.stride)
+ sfa_stride[2] = stride_rest_k
+ sfa_stride[5] = stride_tensor_m_128
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=tuple(sfa_shape),
+ stride=tuple(sfa_stride),
+ )
+
+ tensor_n_128 = cute.sym_int()
+ stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfb_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfb_desc.dtype,
+ shape=(32, 4, tensor_n_128, 4, rest_k, l_sym),
+ stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k),
+ )
+
+ sfd_cute_fake = None
+ if self.sfd_desc is not None:
+ rest_n = cute.sym_int()
+ stride_sfd_rest_n = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfd_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_shape = list(self.sfd_desc.shape)
+ sfd_shape[2] = tensor_m_128
+ sfd_shape[4] = rest_n
+ sfd_shape[5] = l_sym
+ sfd_stride = list(self.sfd_desc.stride)
+ sfd_stride[2] = stride_sfd_rest_n
+ sfd_stride[5] = stride_sfd_tensor_m_128
+ sfd_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_desc.dtype,
+ shape=tuple(sfd_shape),
+ stride=tuple(sfd_stride),
+ )
+ else:
+ a_cute_fake = self._make_fake_cute_tensor_from_desc(self.a_desc, assumed_align=16)
+ b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16)
+ c_cute_fake = self._make_fake_cute_tensor_from_desc(self.c_desc, assumed_align=16)
+ d_cute_fake = self._make_fake_cute_tensor_from_desc(self.d_desc, assumed_align=16)
+ prob_cute_fake = self._make_fake_cute_tensor_from_desc(self.prob_desc, assumed_align=16)
+ dprob_cute_fake = self._make_fake_cute_tensor_from_desc(self.dprob_desc, assumed_align=16)
+ sfa_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfa_desc, assumed_align=16)
+ sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16)
+ sfd_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfd_desc, assumed_align=16)
+
+ compiled = cute.compile(
+ gemm,
+ a_tensor=a_cute_fake,
+ b_tensor=b_cute_fake,
+ sfa_tensor=sfa_cute_fake,
+ sfb_tensor=sfb_cute_fake,
+ c_tensor=c_cute_fake,
+ d_tensor=d_cute_fake,
+ prob_tensor=prob_cute_fake,
+ dprob_tensor=dprob_cute_fake,
+ amax_tensor=self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16),
+ sfd_tensor=sfd_cute_fake,
+ norm_const_tensor=self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16),
+ alpha=self.alpha,
+ max_active_clusters=max_active_clusters,
+ stream=fake_stream,
+ epilogue_op=epilogue_op,
+ options="--enable-tvm-ffi",
+ )
+
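+        # Thin adapter around the compiled cute kernel: forwards the torch tensors as-is and unpads
+        # amax/norm_const back to 1-D before dispatch.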
+ def tensor_api(
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ dprob_tensor: torch.Tensor,
+ amax_tensor: Optional[torch.Tensor],
+ sfd_tensor: Optional[torch.Tensor],
+ norm_const_tensor: Optional[torch.Tensor],
+ alpha: float,
+ stream: cuda.CUstream,
+ ) -> None:
+ compiled(
+ a_tensor,
+ b_tensor,
+ sfa_tensor,
+ sfb_tensor,
+ c_tensor,
+ d_tensor,
+ prob_tensor,
+ dprob_tensor,
+ self._unpad_tensor_to_ndim(amax_tensor, 1, "amax"),
+ sfd_tensor,
+ self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const"),
+ alpha,
+ stream,
+ )
+
+ self._compiled_kernel = tensor_api
+
+ def execute(
+ self,
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_tensor: torch.Tensor,
+ dprob_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ sfd_tensor: Optional[torch.Tensor] = None,
+ amax_tensor: Optional[torch.Tensor] = None,
+ norm_const_tensor: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ current_stream: Optional[cuda.CUstream] = None,
+ ) -> None:
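+        """Launch the compiled kernel on the given device tensors; compile() must have been called first."""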
+ self._runtime_error_if(self._compiled_kernel is None, "GemmDsreluSm100 kernel not compiled; call compile() first")
+ current_stream = self._get_default_stream(current_stream)
+ self._compiled_kernel(
+ a_tensor=a_tensor,
+ b_tensor=b_tensor,
+ sfa_tensor=sfa_tensor,
+ sfb_tensor=sfb_tensor,
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ prob_tensor=prob_tensor,
+ dprob_tensor=dprob_tensor,
+ amax_tensor=amax_tensor,
+ sfd_tensor=sfd_tensor,
+ norm_const_tensor=norm_const_tensor,
+ alpha=self.alpha if alpha is None else alpha,
+ stream=current_stream,
+ )
+
+
+_logger = logging.getLogger(__name__)
+_cache_of_GemmDsreluSm100Objects = {}
+_DENSE_GEMM_DYNAMIC_M_ENV = "CUDNN_FE_GEMM_DYNAMIC_M"
+_DENSE_GEMM_DYNAMIC_MNKL_ENV = "CUDNN_FE_GEMM_DYNAMIC_MNKL"
+
+
+def _allocate_dense_output(shape: Tuple[int, int, int], major: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
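+    """Allocate a compact (m, n, l) output tensor that is m-major or n-major within each batch."""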
+ m, n, l = shape
+ if major == "m":
+ return torch.empty_strided((m, n, l), (1, m, m * n), dtype=dtype, device=device)
+ if major == "n":
+ return torch.empty_strided((m, n, l), (n, 1, m * n), dtype=dtype, device=device)
+ raise ValueError(f"major must be 'm' or 'n', got {major}")
+
+
+def gemm_dsrelu_wrapper_sm100(
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ alpha: float = 1.0,
+ d_major: str = "n",
+ d_dtype: torch.dtype = torch.bfloat16,
+ acc_dtype: torch.dtype = torch.float32,
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ norm_const_tensor: Optional[torch.Tensor] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ stream: Optional[cuda.CUstream] = None,
+) -> TupleDict:
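+    """Convenience wrapper around :class:`GemmDsreluSm100`.
+
+    Allocates the D and dprob outputs (plus SFD/amax buffers when the output is FP8 or the input is
+    FP4), builds or reuses a cached compiled kernel keyed on the tensor signatures, the dynamic-shape
+    mode, and the tuning parameters, executes it, and returns a TupleDict with d_tensor, dprob_tensor,
+    amax_tensor, and sfd_tensor.
+    """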
+ m, k, l = a_tensor.shape
+ n, _, _ = b_tensor.shape
+
+ d_tensor = _allocate_dense_output((m, n, l), d_major, d_dtype, a_tensor.device)
+ dprob_tensor = torch.zeros((m, 1, l), dtype=torch.float32, device=a_tensor.device)
+
+ sfd_tensor = None
+ if d_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
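+        # D scale factors follow the block-scaled atom layout expected by the kernel: allocated as
+        # (l, ceil(m/128), ceil(n/sf_vec_size/4), 32, 4, 4) and permuted to
+        # (32, 4, ceil(m/128), 4, ceil(n/sf_vec_size/4), l).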
+ sf_k = ceil_div(n, sf_vec_size)
+ mma_shape = (
+ l,
+ ceil_div(m, 128),
+ ceil_div(sf_k, 4),
+ 32,
+ 4,
+ 4,
+ )
+ sfd_tensor = torch.empty(mma_shape, dtype=sfa_tensor.dtype, device=a_tensor.device).permute(3, 4, 1, 5, 2, 0)
+
+ amax_tensor = None
+ if a_tensor.dtype in {torch.float4_e2m1fn_x2, torch.uint8} and d_dtype in {torch.bfloat16, torch.float16, torch.float32}:
+ amax_tensor = torch.full((1,), float("-inf"), dtype=torch.float32, device=a_tensor.device)
+
+ use_full_dynamic = os.environ.get(_DENSE_GEMM_DYNAMIC_MNKL_ENV) is not None
+ use_dynamic_m = not use_full_dynamic and os.environ.get(_DENSE_GEMM_DYNAMIC_M_ENV) is not None
+
+ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
+ return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1]))
+
+ def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype
+
+ def dynamic_compact_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return tuple(tensor.shape[1:]), stride_order(tensor), tensor.dtype
+
+ def dynamic_tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return None, stride_order(tensor), tensor.dtype
+
+ def dynamic_m_tensor_signature(
+ tensor: Optional[torch.Tensor], static_shape_suffix: Optional[Tuple[int, ...]], dynamic_stride_dims: Tuple[int, ...] = ()
+ ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride()))
+ return static_shape_suffix, stride_signature, tensor.dtype
+
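+    # Cache key: records, per tensor, only what the compiled kernel specializes on (full
+    # shape/stride/dtype in static mode, the shape without the dynamic M extent in dynamic-M mode,
+    # and only the stride order and dtype in the fully dynamic mode), plus the tuning parameters.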
+ cache_key = (
+ use_full_dynamic,
+ use_dynamic_m,
+ *(dynamic_tensor_signature(a_tensor) if use_full_dynamic else dynamic_compact_signature(a_tensor) if use_dynamic_m else tensor_signature(a_tensor)),
+ *(dynamic_tensor_signature(b_tensor) if use_full_dynamic else tensor_signature(b_tensor)),
+ *(dynamic_tensor_signature(c_tensor) if use_full_dynamic else dynamic_compact_signature(c_tensor) if use_dynamic_m else tensor_signature(c_tensor)),
+ *(dynamic_tensor_signature(d_tensor) if use_full_dynamic else dynamic_compact_signature(d_tensor) if use_dynamic_m else tensor_signature(d_tensor)),
+ *(
+ dynamic_tensor_signature(dprob_tensor)
+ if use_full_dynamic
+ else dynamic_compact_signature(dprob_tensor) if use_dynamic_m else tensor_signature(dprob_tensor)
+ ),
+ d_dtype,
+ *(
+ dynamic_tensor_signature(sfa_tensor)
+ if use_full_dynamic
+ else (
+ dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], sfa_tensor.shape[5]), dynamic_stride_dims=(5,))
+ if use_dynamic_m
+ else tensor_signature(sfa_tensor)
+ )
+ ),
+ *(dynamic_tensor_signature(sfb_tensor) if use_full_dynamic else tensor_signature(sfb_tensor)),
+ *(
+ dynamic_tensor_signature(prob_tensor)
+ if use_full_dynamic
+ else dynamic_compact_signature(prob_tensor) if use_dynamic_m else tensor_signature(prob_tensor)
+ ),
+ norm_const_tensor.shape if norm_const_tensor is not None else None,
+ norm_const_tensor.stride() if norm_const_tensor is not None else None,
+ norm_const_tensor.dtype if norm_const_tensor is not None else None,
+ alpha,
+ acc_dtype,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ d_major,
+ sf_vec_size,
+ vector_f32,
+ )
+
+ op = _cache_of_GemmDsreluSm100Objects.get(cache_key)
+ if op is None:
+ op = GemmDsreluSm100(
+ sample_a=a_tensor,
+ sample_b=b_tensor,
+ sample_c=c_tensor,
+ sample_d=d_tensor,
+ sample_dprob=dprob_tensor,
+ sample_sfa=sfa_tensor,
+ sample_sfb=sfb_tensor,
+ sample_prob=prob_tensor,
+ sample_sfd=sfd_tensor,
+ sample_amax=amax_tensor,
+ sample_norm_const=norm_const_tensor,
+ alpha=alpha,
+ acc_dtype=acc_dtype,
+ mma_tiler_mn=mma_tiler_mn,
+ cluster_shape_mn=cluster_shape_mn,
+ sf_vec_size=sf_vec_size,
+ vector_f32=vector_f32,
+ )
+ assert op.check_support(), "Unsupported testcase"
+ op.compile()
+ _cache_of_GemmDsreluSm100Objects[cache_key] = op
+
+ op.execute(
+ a_tensor=a_tensor,
+ b_tensor=b_tensor,
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ dprob_tensor=dprob_tensor,
+ sfa_tensor=sfa_tensor,
+ sfb_tensor=sfb_tensor,
+ prob_tensor=prob_tensor,
+ sfd_tensor=sfd_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ alpha=alpha,
+ current_stream=stream,
+ )
+
+ return TupleDict(
+ d_tensor=d_tensor,
+ dprob_tensor=dprob_tensor,
+ amax_tensor=amax_tensor,
+ sfd_tensor=sfd_tensor,
+ )
diff --git a/python/cudnn/gemm_dsrelu/dense_blockscaled_gemm_persistent_dsrelu_quant.py b/python/cudnn/gemm_dsrelu/dense_blockscaled_gemm_persistent_dsrelu_quant.py
new file mode 100644
index 00000000..5b197014
--- /dev/null
+++ b/python/cudnn/gemm_dsrelu/dense_blockscaled_gemm_persistent_dsrelu_quant.py
@@ -0,0 +1,2174 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from typing import Type, Tuple, Union, Optional
+
+import cuda.bindings.driver as cuda
+import torch
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu import cpasync, tcgen05
+import cutlass.utils as utils
+import cutlass.pipeline as pipeline
+from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
+import cutlass.utils.blackwell_helpers as sm100_utils
+import cutlass.utils.blockscaled_layout as blockscaled_utils
+
+import cutlass.cute.math as math
+from cutlass.cute.typing import Float32
+from cutlass._mlir.dialects import llvm, nvvm
+from cutlass._mlir.dialects.nvvm import AtomicOpKind
+from cutlass.cutlass_dsl import T
+
+
+def atomic_add_float32(
+ ptr,
+ value: Float32,
+ *,
+ loc=None,
+ ip=None,
+) -> Float32:
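+    """Atomically add ``value`` to the FP32 value at ``ptr`` (FADD) and return the previous value."""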
+ old_value = nvvm.atomicrmw(
+ AtomicOpKind.FADD,
+ ptr,
+ value.ir_value(loc=loc, ip=ip),
+ loc=loc,
+ ip=ip,
+ )
+
+ return Float32(llvm.bitcast(T.f32(), old_value, loc=loc, ip=ip))
+
+
+class Sm100BlockScaledPersistentDenseGemmKernel:
+ """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
+ and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
+
+ :param sf_vec_size: Scalefactor vector size.
+ :type sf_vec_size: int
+ :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
+ :type mma_tiler_mn: Tuple[int, int]
+ :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
+ :type cluster_shape_mn: Tuple[int, int]
+
+    :note: In the current version, the A and B tensors must have the same data type
+        - e.g., Float8E4M3FN for A and Float8E5M2 for B is not supported
+
+    :note: Supported combinations of A/B data types, SF data types and SF vector size:
+ - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32
+ - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16
+
+ :note: Supported accumulator data types:
+ - Float32
+
+ :note: Supported C data types:
+ - Float32
+ - Float16/BFloat16
+ - Float8E4M3FN/Float8E5M2
+ # {$nv-internal-release begin}
+ # Note: We don't have SFD generation support in this example for now, so Float4E2M1FN output is only for internal testing and will not be released.
+ - Float4E2M1FN
+ # {$nv-internal-release end}
+
+ :note: Constraints:
+ - MMA tiler M must be 128 or 256 (use_2cta_instrs)
+ - MMA tiler N must be 64/128/192/256
+ - Cluster shape M must be multiple of 2 if Mma tiler M is 256
+ - Cluster shape M/N must be positive and power of 2, total cluster size <= 16
+ - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors
+
+ Example:
+ >>> gemm = Sm100BlockScaledPersistentDenseGemmKernel(
+ ... sf_vec_size=16,
+ ... mma_tiler_mn=(256, 128),
+ ... cluster_shape_mn=(2, 1)
+ ... )
+ >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, d_tensor, prob_tensor, amax_tensor, sfd_tensor, norm_const_tensor, alpha, max_active_clusters, stream)
+ """
+
+ def __init__(
+ self,
+ sf_vec_size: int,
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ vector_f32: bool,
+ ):
+ """Initializes the configuration for a Blackwell dense GEMM kernel.
+
+ This configuration includes several key aspects:
+
+ 1. MMA Instruction Settings (tcgen05):
+ - acc_dtype: Data types for MMA accumulator, always set to Float32
+ - sf_vec_size: Scalefactor A/B vector size.
+ - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
+
+ 2. Cluster Shape:
+ - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
+
+ :param sf_vec_size: Scalefactor vector size.
+ :type sf_vec_size: int
+ :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
+ :type mma_tiler_mn: Tuple[int, int]
+ :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
+ :type cluster_shape_mn: Tuple[int, int]
+ """
+
+ self.acc_dtype = cutlass.Float32
+ self.sf_vec_size = sf_vec_size
+ self.use_2cta_instrs = mma_tiler_mn[0] == 256
+ self.cluster_shape_mn = cluster_shape_mn
+ # K dimension is deferred in _setup_attributes
+ self.mma_tiler = (*mma_tiler_mn, 1)
+
+ self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
+
+ self.occupancy = 1
+ # Set specialized warp ids
+ self.epilog_warp_id = (
+ 0,
+ 1,
+ 2,
+ 3,
+ )
+ self.mma_warp_id = 4
+ self.tma_warp_id = 5
+ self.epilog_load_tma_id = 6
+ self.threads_per_cta = 32 * len((self.mma_warp_id, self.tma_warp_id, self.epilog_load_tma_id, *self.epilog_warp_id))
+ # Set barrier id for epilogue sync and tmem ptr sync
+ self.epilog_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=1,
+ num_threads=32 * len(self.epilog_warp_id),
+ )
+ self.tmem_alloc_barrier = pipeline.NamedBarrier(
+ barrier_id=2,
+ num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)),
+ )
+ self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
+ SM100_TMEM_CAPACITY_COLUMNS = 512
+ self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
+
+ # Generate sfd output by 2xf32
+ self.vector_f32 = vector_f32
+
+ # Amax reduction configuration
+ self.num_epilog_warps = len(self.epilog_warp_id)
+
+ def _setup_attributes(self):
+ """Set up configurations that are dependent on GEMM inputs
+
+ This method configures various attributes based on the input tensor properties
+ (data types, leading dimensions) and kernel settings:
+ - Configuring tiled MMA
+ - Computing MMA/cluster/tile shapes
+ - Computing cluster layout
+ - Computing multicast CTAs for A/B/SFA/SFB
+ - Computing epilogue subtile
+ - Setting up A/B/SFA/SFB/C stage counts in shared memory
+ - Computing A/B/SFA/SFB/C shared memory layout
+ """
+ # Compute mma instruction shapes
+ # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K)
+ self.mma_inst_shape_mn = (
+ self.mma_tiler[0],
+ self.mma_tiler[1],
+ )
+ # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K)
+ self.mma_inst_shape_mn_sfb = (
+ self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
+ cute.round_up(self.mma_inst_shape_mn[1], 128),
+ )
+
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+
+ # Compute mma/cluster/tile shapes
+ mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
+ mma_inst_tile_k = 4
+ self.mma_tiler = (
+ self.mma_inst_shape_mn[0],
+ self.mma_inst_shape_mn[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.mma_tiler_sfb = (
+ self.mma_inst_shape_mn_sfb[0],
+ self.mma_inst_shape_mn_sfb[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.cta_tile_shape_mnk = (
+ self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler[1],
+ self.mma_tiler[2],
+ )
+
+ self.mma_tiler_c = (
+ self.mma_inst_shape_mn[0],
+ self.mma_inst_shape_mn[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.cta_tile_shape_mnk_c = (
+ self.mma_tiler_c[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler_c[1],
+ self.mma_tiler_c[2],
+ )
+
+ # Compute cluster layout
+ self.cluster_layout_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma.thr_id.shape,),
+ )
+ self.cluster_layout_sfb_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma_sfb.thr_id.shape,),
+ )
+
+ # Compute number of multicast CTAs for A/B
+ self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+ self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
+ self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
+ self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
+
+ # Always use subtile (128,32)
+ self.epi_tile = (cute.make_layout(128), cute.make_layout(32))
+ self.epi_tile_cnt = (
+ self.cta_tile_shape_mnk[0] // cute.size(self.epi_tile[0]),
+ self.cta_tile_shape_mnk[1] // cute.size(self.epi_tile[1]),
+ )
+ # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
+ self.num_acc_stage, self.num_ab_stage, self.num_c_stage, self.num_d_stage = self._compute_stages(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.b_dtype,
+ self.epi_tile,
+ self.c_dtype,
+ self.c_layout,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.d_dtype,
+ self.d_layout,
+ self.smem_capacity,
+ self.occupancy,
+ )
+
+ # Compute A/B/SFA/SFB/C shared memory layout
+ self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.num_ab_stage,
+ )
+ self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
+ tiled_mma,
+ self.mma_tiler,
+ self.b_dtype,
+ self.num_ab_stage,
+ )
+ self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.c_dtype,
+ self.c_layout,
+ self.epi_tile,
+ self.num_c_stage,
+ )
+ self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.d_dtype,
+ self.d_layout,
+ self.epi_tile,
+ self.num_d_stage,
+ )
+
+ # Overlap and double buffer accumulator when num_acc_stage == 1 for cta_tile_n = 256 case
+ self.overlapping_accum = False # self.num_acc_stage == 1
+
+ @cute.jit
+ def __call__(
+ self,
+ a_tensor: cute.Tensor,
+ b_tensor: cute.Tensor,
+ sfa_tensor: cute.Tensor,
+ sfb_tensor: cute.Tensor,
+ c_tensor: cute.Tensor,
+ d_tensor: cute.Tensor,
+ prob_tensor: cute.Tensor,
+ dprob_tensor: cute.Tensor,
+ amax_tensor: Optional[cute.Tensor],
+ sfd_tensor: Optional[cute.Tensor],
+ norm_const_tensor: Optional[cute.Tensor],
+ alpha: cutlass.Float32,
+ max_active_clusters: cutlass.Constexpr,
+ stream: cuda.CUstream,
+ epilogue_op: cutlass.Constexpr = lambda x, y: x,
+ ):
+ """Execute the GEMM operation in steps:
+ - Setup static attributes before smem/grid/tma computation
+ - Setup TMA load/store atoms and tensors
+ - Compute grid size with regard to hardware constraints
+ - Define shared storage for kernel
+ - Launch the kernel synchronously
+
+ :param a_tensor: Input tensor A
+ :type a_tensor: cute.Tensor
+ :param b_tensor: Input tensor B
+ :type b_tensor: cute.Tensor
+ :param sfa_tensor: Scale factor tensor A
+ :type sfa_tensor: cute.Tensor
+ :param sfb_tensor: Scale factor tensor B
+ :type sfb_tensor: cute.Tensor
+        :param c_tensor: Input tensor C (epilogue operand loaded from global memory)
+        :type c_tensor: cute.Tensor
+        :param d_tensor: Output tensor D
+        :type d_tensor: cute.Tensor
+ :param prob_tensor: Probability tensor
+ :type prob_tensor: cute.Tensor
+ :param dprob_tensor: Derivative of probability tensor
+ :type dprob_tensor: cute.Tensor
+        :param sfd_tensor: Scale factor tensor for the quantized output D
+ :type sfd_tensor: cute.Tensor
+ :param norm_const_tensor: Normalization constant tensor for quantization
+ :type norm_const_tensor: cute.Tensor
+ :param amax_tensor: Output tensor for absolute maximum value
+ :type amax_tensor: cute.Tensor
+ :param max_active_clusters: Maximum number of active clusters
+ :type max_active_clusters: cutlass.Constexpr
+ :param stream: CUDA stream for asynchronous execution
+ :type stream: cuda.CUstream
+ :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
+ :type epilogue_op: cutlass.Constexpr
+ :raises TypeError: If input data types are incompatible with the MMA instruction.
+ """
+ # Setup static attributes before smem/grid/tma computation
+ self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type
+ self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type
+ self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type
+ self.c_dtype: Type[cutlass.Numeric] = c_tensor.element_type
+ self.d_dtype: Type[cutlass.Numeric] = d_tensor.element_type
+ self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode()
+ self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode()
+ self.c_layout = utils.LayoutEnum.from_tensor(c_tensor)
+ self.d_layout = utils.LayoutEnum.from_tensor(d_tensor)
+
+ # Check if input data types are compatible with MMA instruction
+ if cutlass.const_expr(self.a_dtype != self.b_dtype):
+ raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
+
+        # Set up attributes that depend on the GEMM inputs
+ self._setup_attributes()
+
+        # Set up SFA/SFB tensors by tiling the scale factor atom layout over the A/B tensor shapes
+ # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL)
+ sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, self.sf_vec_size)
+ sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout)
+
+ # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b_tensor.shape, self.sf_vec_size)
+ sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout)
+
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+
+ # For 2CTA blockscaled kernels, SFB needs to be replicated across peer CTAs. # {$nv-internal-release}
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+ atom_thr_size = cute.size(tiled_mma.thr_id.shape)
+
+ # Setup TMA load for A
+ a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
+ tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
+ a_op,
+ a_tensor,
+ a_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ # Setup TMA load for B
+ b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
+ b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
+ tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
+ b_op,
+ b_tensor,
+ b_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ # Setup TMA load for SFA
+ sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
+ sfa_op,
+ sfa_tensor,
+ sfa_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ internal_type=cutlass.Int16,
+ )
+
+ # Setup TMA load for SFB
+ sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
+ sfb_op,
+ sfb_tensor,
+ sfb_smem_layout,
+ self.mma_tiler_sfb,
+ tiled_mma_sfb,
+ self.cluster_layout_sfb_vmnk.shape,
+ internal_type=cutlass.Int16,
+ )
+
+ # {$nv-internal-release begin}
+ # This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF))
+ # logical blocks for SFB when cta_tile_shape_n=192.
+ # {$nv-internal-release end}
+ if cutlass.const_expr(self.cta_tile_shape_mnk_c[1] == 192):
+ x = tma_tensor_sfb.stride[0][1]
+ y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4)
+
+ new_shape = (
+ (tma_tensor_sfb.shape[0][0], ((2, 2), y)),
+ tma_tensor_sfb.shape[1],
+ tma_tensor_sfb.shape[2],
+ )
+ # Use right multiplication for ScaledBasis (3 * x instead of x * 3)
+ x_times_3 = 3 * x
+ new_stride = (
+ (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)),
+ tma_tensor_sfb.stride[1],
+ tma_tensor_sfb.stride[2],
+ )
+ tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride)
+ tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout)
+
+ a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
+ b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+ sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
+ sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
+ self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size
+
+ # Setup TMA load for C
+ epi_c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
+ self.tma_c_load_bytes = cute.size_in_bytes(self.c_dtype, epi_c_smem_layout)
+ tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileG2SOp(),
+ c_tensor,
+ epi_c_smem_layout,
+ self.epi_tile,
+ )
+ epi_d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ d_tensor,
+ epi_d_smem_layout,
+ self.epi_tile,
+ )
+
+ # Compute grid size
+ self.tile_sched_params, grid = self._compute_grid(
+ c_tensor,
+ self.cta_tile_shape_mnk_c,
+ self.cluster_shape_mn,
+ max_active_clusters,
+ )
+
+ self.buffer_align_bytes = 1024
+
+ self.generate_sfd = sfd_tensor is not None and norm_const_tensor is not None
+ if cutlass.const_expr(self.generate_sfd):
+ sfd_layout = blockscaled_utils.tile_atom_to_shape_SF(c_tensor.shape, self.sf_vec_size)
+ sfd_tensor = cute.make_tensor(sfd_tensor.iterator, sfd_layout)
+
+ self.generate_amax = amax_tensor is not None
+
+ # Define shared storage for kernel
+ @cute.struct
+ class SharedStorage:
+ ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
+ ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
+ acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
+ acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
+ c_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage]
+ c_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage]
+ tmem_dealloc_mbar_ptr: cutlass.Int64
+ tmem_holding_buf: cutlass.Int32
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
+ sC: cute.struct.Align[
+ cute.struct.MemRange[
+ self.c_dtype,
+ cute.cosize(self.c_smem_layout_staged.outer),
+ ],
+ self.buffer_align_bytes,
+ ]
+ sD: cute.struct.Align[
+ cute.struct.MemRange[
+ self.d_dtype,
+ cute.cosize(self.d_smem_layout_staged.outer),
+ ],
+ self.buffer_align_bytes,
+ ]
+ # (MMA, MMA_M, MMA_K, STAGE)
+ sA: cute.struct.Align[
+ cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ # (MMA, MMA_N, MMA_K, STAGE)
+ sB: cute.struct.Align[
+ cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ # (MMA, MMA_M, MMA_K, STAGE)
+ sSFA: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+ # (MMA, MMA_N, MMA_K, STAGE)
+ sSFB: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+ # Amax reduction shared memory (one FP32 per epilogue warp)
+            # Only a few bytes are needed, but the common buffer alignment is reused for simplicity
+ sAmax: cute.struct.Align[
+ cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps],
+ self.buffer_align_bytes,
+ ]
+
+ self.shared_storage = SharedStorage
+
+ # Launch the kernel synchronously
+ self.kernel(
+ tiled_mma,
+ tiled_mma_sfb,
+ tma_atom_a,
+ tma_tensor_a,
+ tma_atom_b,
+ tma_tensor_b,
+ tma_atom_sfa,
+ tma_tensor_sfa,
+ tma_atom_sfb,
+ tma_tensor_sfb,
+ tma_atom_c,
+ tma_tensor_c,
+ tma_atom_d,
+ tma_tensor_d,
+ prob_tensor,
+ dprob_tensor,
+ amax_tensor,
+ sfd_tensor,
+ norm_const_tensor,
+ self.cluster_layout_vmnk,
+ self.cluster_layout_sfb_vmnk,
+ self.a_smem_layout_staged,
+ self.b_smem_layout_staged,
+ self.sfa_smem_layout_staged,
+ self.sfb_smem_layout_staged,
+ self.c_smem_layout_staged,
+ self.d_smem_layout_staged,
+ self.epi_tile,
+ self.tile_sched_params,
+ epilogue_op,
+ alpha,
+ ).launch(
+ grid=grid,
+ block=[self.threads_per_cta, 1, 1],
+ cluster=(*self.cluster_shape_mn, 1),
+ stream=stream,
+ )
+ return
+
+ # GPU device kernel
+ @cute.kernel
+ def kernel(
+ self,
+ tiled_mma: cute.TiledMma,
+ tiled_mma_sfb: cute.TiledMma,
+ tma_atom_a: cute.CopyAtom,
+ mA_mkl: cute.Tensor,
+ tma_atom_b: cute.CopyAtom,
+ mB_nkl: cute.Tensor,
+ tma_atom_sfa: cute.CopyAtom,
+ mSFA_mkl: cute.Tensor,
+ tma_atom_sfb: cute.CopyAtom,
+ mSFB_nkl: cute.Tensor,
+ tma_atom_c: cute.CopyAtom,
+ mC_mnl: cute.Tensor,
+ tma_atom_d: cute.CopyAtom,
+ mD_mnl: cute.Tensor,
+ mProb_mnl: cute.Tensor,
+ mDProb_mnl: cute.Tensor,
+ mAmax_tensor: Optional[cute.Tensor],
+ mSFD_mnl: Optional[cute.Tensor],
+ norm_const_tensor: Optional[cute.Tensor],
+ cluster_layout_vmnk: cute.Layout,
+ cluster_layout_sfb_vmnk: cute.Layout,
+ a_smem_layout_staged: cute.ComposedLayout,
+ b_smem_layout_staged: cute.ComposedLayout,
+ sfa_smem_layout_staged: cute.Layout,
+ sfb_smem_layout_staged: cute.Layout,
+ c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
+ d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
+ epi_tile: cute.Tile,
+ tile_sched_params: utils.PersistentTileSchedulerParams,
+ epilogue_op: cutlass.Constexpr,
+ alpha: cutlass.Float32,
+ ):
+ """
+ GPU device kernel performing the Persistent batched GEMM computation.
+ """
+ warp_idx = cute.arch.warp_idx()
+ warp_idx = cute.arch.make_warp_uniform(warp_idx)
+
+ #
+ # Prefetch tma desc
+ #
+ if warp_idx == self.tma_warp_id:
+ cpasync.prefetch_descriptor(tma_atom_a)
+ cpasync.prefetch_descriptor(tma_atom_b)
+ cpasync.prefetch_descriptor(tma_atom_sfa)
+ cpasync.prefetch_descriptor(tma_atom_sfb)
+ cpasync.prefetch_descriptor(tma_atom_c)
+ cpasync.prefetch_descriptor(tma_atom_d)
+
+ use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
+
+ #
+ # Setup cta/thread coordinates
+ #
+ # Coords inside cluster
+ bidx, bidy, bidz = cute.arch.block_idx()
+ mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
+ is_leader_cta = mma_tile_coord_v == 0
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
+ block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster)
+ # Coord inside cta
+ tidx, _, _ = cute.arch.thread_idx()
+
+ #
+ # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
+ #
+ smem = utils.SmemAllocator()
+ storage = smem.allocate(self.shared_storage)
+
+ # Initialize mainloop ab_pipeline (barrier) and states
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer)
+ ab_pipeline = pipeline.PipelineTmaUmma.create(
+ barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
+ num_stages=self.num_ab_stage,
+ producer_group=ab_pipeline_producer_group,
+ consumer_group=ab_pipeline_consumer_group,
+ tx_count=self.num_tma_load_bytes,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ defer_sync=True,
+ )
+
+ # Initialize acc_pipeline (barrier) and states
+ acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
+ acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads)
+ acc_pipeline = pipeline.PipelineUmmaAsync.create(
+ barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
+ num_stages=self.num_acc_stage,
+ producer_group=acc_pipeline_producer_group,
+ consumer_group=acc_pipeline_consumer_group,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ defer_sync=True,
+ )
+ c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ c_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ len(self.epilog_warp_id),
+ )
+ c_pipeline = pipeline.PipelineTmaAsync.create(
+ barrier_storage=storage.c_full_mbar_ptr.data_ptr(),
+ num_stages=self.num_c_stage,
+ producer_group=c_producer_group,
+ consumer_group=c_consumer_group,
+ tx_count=self.tma_c_load_bytes,
+ )
+
+ # Tensor memory dealloc barrier init
+ tmem = utils.TmemAllocator(
+ storage.tmem_holding_buf,
+ barrier_for_retrieve=self.tmem_alloc_barrier,
+ allocator_warp_id=self.epilog_warp_id[0],
+ is_two_cta=use_2cta_instrs,
+ two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
+ )
+
+ # Cluster arrive after barrier init
+ pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
+
+ #
+ # Setup smem tensor A/B/SFA/SFB/C
+ #
+ # (MMA, MMA_M, MMA_K, STAGE)
+ sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
+ # (MMA, MMA_N, MMA_K, STAGE)
+ sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
+ # (MMA, MMA_M, MMA_K, STAGE)
+ sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
+ # (MMA, MMA_N, MMA_K, STAGE)
+ sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
+ sC = storage.sC.get_tensor(
+ c_smem_layout_staged.outer,
+ swizzle=c_smem_layout_staged.inner,
+ dtype=self.c_dtype,
+ )
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
+ sD = storage.sD.get_tensor(
+ d_smem_layout_staged.outer,
+ swizzle=d_smem_layout_staged.inner,
+ dtype=self.d_dtype,
+ )
+
+ # Shared memory for amax reduction (one FP32 per epilogue warp)
+        # Simple 1D layout. The allocation is always made, even when no amax is generated,
+        # since the overhead is minimal and it keeps the code simple.
+ amax_layout = cute.make_layout((self.num_epilog_warps,))
+ sAmax = storage.sAmax.get_tensor(amax_layout)
+
+ #
+ # Compute multicast mask for A/B/SFA/SFB buffer full
+ #
+ a_full_mcast_mask = None
+ b_full_mcast_mask = None
+ sfa_full_mcast_mask = None
+ sfb_full_mcast_mask = None
+ if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
+ a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1)
+ sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1)
+
+ #
+ # Local_tile partition global tensors
+ #
+ # (bM, bK, RestM, RestK, RestL)
+ gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ # (bN, bK, RestN, RestK, RestL)
+ gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
+ # (bM, bK, RestM, RestK, RestL)
+ gSFA_mkl = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ # (bN, bK, RestN, RestK, RestL)
+ gSFB_nkl = cute.local_tile(
+ mSFB_nkl,
+ cute.slice_(self.mma_tiler_sfb, (0, None, None)),
+ (None, None, None),
+ )
+ # (bM, bN, RestM, RestN, RestL)
+ gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler_c, (None, None, 0)), (None, None, None))
+ # (bM, bN, RestM, RestN, RestL)
+ gD_mnl = cute.local_tile(
+ mD_mnl,
+ cute.slice_(self.mma_tiler_c, (None, None, 0)),
+ (None, None, None),
+ )
+ k_tile_cnt = cute.size(gA_mkl, mode=[3])
+
+ #
+ # Partition global tensor for TiledMMA_A/B/C
+ #
+ thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
+ thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v)
+ # (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
+ tCgA = thr_mma.partition_A(gA_mkl)
+ # (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
+ tCgB = thr_mma.partition_B(gB_nkl)
+ # (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
+ tCgSFA = thr_mma.partition_A(gSFA_mkl)
+ # (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
+ tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
+ # (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
+ tCgC = thr_mma.partition_C(gC_mnl)
+ # (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
+ tCgD = thr_mma.partition_C(gD_mnl)
+
+ #
+ # Partition global/shared tensor for TMA load A/B
+ #
+ # TMA load A partition_S/D
+ a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
+ # ((atom_v, rest_v), STAGE)
+ # ((atom_v, rest_v), RestM, RestK, RestL)
+ tAsA, tAgA = cpasync.tma_partition(
+ tma_atom_a,
+ block_in_cluster_coord_vmnk[2],
+ a_cta_layout,
+ cute.group_modes(sA, 0, 3),
+ cute.group_modes(tCgA, 0, 3),
+ )
+ # TMA load B partition_S/D
+ b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
+ # ((atom_v, rest_v), STAGE)
+ # ((atom_v, rest_v), RestN, RestK, RestL)
+ tBsB, tBgB = cpasync.tma_partition(
+ tma_atom_b,
+ block_in_cluster_coord_vmnk[1],
+ b_cta_layout,
+ cute.group_modes(sB, 0, 3),
+ cute.group_modes(tCgB, 0, 3),
+ )
+
+ # TMA load SFA partition_S/D
+ sfa_cta_layout = a_cta_layout
+ # ((atom_v, rest_v), STAGE)
+ # ((atom_v, rest_v), RestM, RestK, RestL)
+ tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition(
+ tma_atom_sfa,
+ block_in_cluster_coord_vmnk[2],
+ sfa_cta_layout,
+ cute.group_modes(sSFA, 0, 3),
+ cute.group_modes(tCgSFA, 0, 3),
+ )
+ tAsSFA = cute.filter_zeros(tAsSFA)
+ tAgSFA = cute.filter_zeros(tAgSFA)
+
+ # TMA load SFB partition_S/D
+ sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
+ # ((atom_v, rest_v), STAGE)
+ # ((atom_v, rest_v), RestN, RestK, RestL)
+ tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition(
+ tma_atom_sfb,
+ block_in_cluster_coord_sfb_vmnk[1],
+ sfb_cta_layout,
+ cute.group_modes(sSFB, 0, 3),
+ cute.group_modes(tCgSFB, 0, 3),
+ )
+ tBsSFB = cute.filter_zeros(tBsSFB)
+ tBgSFB = cute.filter_zeros(tBgSFB)
+
+ #
+ # Partition shared/tensor memory tensor for TiledMMA_A/B/C
+ #
+ # (MMA, MMA_M, MMA_K, STAGE)
+ tCrA = tiled_mma.make_fragment_A(sA)
+ # (MMA, MMA_N, MMA_K, STAGE)
+ tCrB = tiled_mma.make_fragment_B(sB)
+ # (MMA, MMA_M, MMA_N)
+ acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
+ # (MMA, MMA_M, MMA_N, STAGE)
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
+
+ #
+ # Cluster wait before tensor memory alloc
+ #
+ pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
+
+ #
+ # Specialized TMA load warp
+ #
+ if warp_idx == self.tma_warp_id:
+ #
+ # Persistent tile scheduling loop
+ #
+ tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim())
+ work_tile = tile_sched.initial_work_tile_info()
+
+ ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage)
+
+ while work_tile.is_valid_tile:
+ # Get tile coord from tile scheduler
+ cur_tile_coord = work_tile.tile_idx
+ mma_tile_coord_mnl = (
+ cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
+ cur_tile_coord[1],
+ cur_tile_coord[2],
+ )
+
+ #
+ # Slice to per mma tile index
+ #
+ # ((atom_v, rest_v), RestK)
+ tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]
+ # ((atom_v, rest_v), RestK)
+ tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])]
+
+ # ((atom_v, rest_v), RestK)
+ tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]
+
+ # Apply SFB slicing hack when cta_tile_shape_n=64 # {$nv-internal-release}
+ slice_n = mma_tile_coord_mnl[1]
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ slice_n = mma_tile_coord_mnl[1] // 2
+ # ((atom_v, rest_v), RestK)
+ tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])]
+
+ # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt
+ ab_producer_state.reset_count()
+ peek_ab_empty_status = cutlass.Boolean(1)
+ if ab_producer_state.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+ #
+ # Tma load loop
+ #
+ for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+ # Conditionally wait for AB buffer empty
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
+
+ # TMA load A/B/SFA/SFB
+ cute.copy(
+ tma_atom_a,
+ tAgA_slice[(None, ab_producer_state.count)],
+ tAsA[(None, ab_producer_state.index)],
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
+ mcast_mask=a_full_mcast_mask,
+ )
+ cute.copy(
+ tma_atom_b,
+ tBgB_slice[(None, ab_producer_state.count)],
+ tBsB[(None, ab_producer_state.index)],
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
+ mcast_mask=b_full_mcast_mask,
+ )
+ cute.copy(
+ tma_atom_sfa,
+ tAgSFA_slice[(None, ab_producer_state.count)],
+ tAsSFA[(None, ab_producer_state.index)],
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
+ mcast_mask=sfa_full_mcast_mask,
+ )
+ cute.copy(
+ tma_atom_sfb,
+ tBgSFB_slice[(None, ab_producer_state.count)],
+ tBsSFB[(None, ab_producer_state.index)],
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
+ mcast_mask=sfb_full_mcast_mask,
+ )
+
+ # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1
+ ab_producer_state.advance()
+ peek_ab_empty_status = cutlass.Boolean(1)
+ if ab_producer_state.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+
+ #
+ # Advance to next tile
+ #
+ tile_sched.advance_to_next_work()
+ work_tile = tile_sched.get_current_work()
+
+ #
+ # Wait A/B buffer empty
+ #
+ ab_pipeline.producer_tail(ab_producer_state)
+
+ #
+ # Specialized MMA warp
+ #
+ if warp_idx == self.mma_warp_id:
+ #
+ # Bar sync for retrieve tensor memory ptr from shared mem
+ #
+ tmem.wait_for_alloc()
+
+ #
+ # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor
+ #
+ acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ # Make accumulator tmem tensor
+ # (MMA, MMA_M, MMA_N, STAGE)
+ tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
+
+ # Make SFA tmem tensor
+ sfa_tmem_ptr = cute.recast_ptr(
+ acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base),
+ dtype=self.sf_dtype,
+ )
+ # (MMA, MMA_M, MMA_K)
+ tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
+
+ # Make SFB tmem tensor
+ sfb_tmem_ptr = cute.recast_ptr(
+ acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA),
+ dtype=self.sf_dtype,
+ )
+ # (MMA, MMA_N, MMA_K)
+ tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
+ #
+ # Partition for S2T copy of SFA/SFB
+ #
+ (
+ tiled_copy_s2t_sfa,
+ tCsSFA_compact_s2t,
+ tCtSFA_compact_s2t,
+ ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA)
+ (
+ tiled_copy_s2t_sfb,
+ tCsSFB_compact_s2t,
+ tCtSFB_compact_s2t,
+ ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
+
+ #
+ # Persistent tile scheduling loop
+ #
+ tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim())
+ work_tile = tile_sched.initial_work_tile_info()
+
+ ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage)
+ acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
+
+ while work_tile.is_valid_tile:
+ # Get tile coord from tile scheduler
+ cur_tile_coord = work_tile.tile_idx
+ mma_tile_coord_mnl = (
+ cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
+ cur_tile_coord[1],
+ cur_tile_coord[2],
+ )
+
+ # Set tensor memory buffer for current tile
+ # (MMA, MMA_M, MMA_N)
+ tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
+
+ # Peek (try_wait) AB buffer full for k_tile = 0
+ ab_consumer_state.reset_count()
+ peek_ab_full_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
+
+ #
+ # Wait for accumulator buffer empty
+ #
+ if is_leader_cta:
+ acc_pipeline.producer_acquire(acc_producer_state)
+
+ # Apply TMEM pointer offset hack when cta_tile_shape_n=192 or cta_tile_shape_n=64 # {$nv-internal-release}
+ tCtSFB_mma = tCtSFB
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
+ # If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB)
+ offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+ elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ # Move in increments of 64 columns of SFB
+ offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+
+ #
+ # Reset the ACCUMULATE field for each tile
+ #
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
+
+ #
+ # Mma mainloop
+ #
+ for k_tile in range(k_tile_cnt):
+ if is_leader_cta:
+ # Conditionally wait for AB buffer full
+ ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
+
+ # Copy SFA/SFB from smem to tmem
+ s2t_stage_coord = (
+ None,
+ None,
+ None,
+ None,
+ ab_consumer_state.index,
+ )
+ tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord]
+ tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord]
+ cute.copy(
+ tiled_copy_s2t_sfa,
+ tCsSFA_compact_s2t_staged,
+ tCtSFA_compact_s2t,
+ )
+ cute.copy(
+ tiled_copy_s2t_sfb,
+ tCsSFB_compact_s2t_staged,
+ tCtSFB_compact_s2t,
+ )
+
+ # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB
+ num_kblocks = cute.size(tCrA, mode=[2])
+ for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
+ kblock_coord = (
+ None,
+ None,
+ kblock_idx,
+ ab_consumer_state.index,
+ )
+
+ # Set SFA/SFB tensor to tiled_mma
+ sf_kblock_coord = (None, None, kblock_idx)
+ tiled_mma.set(
+ tcgen05.Field.SFA,
+ tCtSFA[sf_kblock_coord].iterator,
+ )
+ tiled_mma.set(
+ tcgen05.Field.SFB,
+ tCtSFB_mma[sf_kblock_coord].iterator,
+ )
+
+ cute.gemm(
+ tiled_mma,
+ tCtAcc,
+ tCrA[kblock_coord],
+ tCrB[kblock_coord],
+ tCtAcc,
+ )
+
+ # Enable accumulate on tCtAcc after first kblock
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
+
+ # Async arrive AB buffer empty
+ ab_pipeline.consumer_release(ab_consumer_state)
+
+ # Peek (try_wait) AB buffer full for k_tile = k_tile + 1
+ ab_consumer_state.advance()
+ peek_ab_full_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt:
+ if is_leader_cta:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
+
+ #
+ # Async arrive accumulator buffer full
+ #
+ if is_leader_cta:
+ acc_pipeline.producer_commit(acc_producer_state)
+ acc_producer_state.advance()
+
+ #
+ # Advance to next tile
+ #
+ tile_sched.advance_to_next_work()
+ work_tile = tile_sched.get_current_work()
+
+ #
+ # Wait for accumulator buffer empty
+ #
+ acc_pipeline.producer_tail(acc_producer_state)
+ #
+ # Specialized epilogue warps
+ #
+ if warp_idx < self.mma_warp_id:
+ #
+ # Alloc tensor memory buffer
+ #
+ tmem.allocate(self.num_tmem_alloc_cols)
+
+ #
+ # Bar sync for retrieve tensor memory ptr from shared memory
+ #
+ tmem.wait_for_alloc()
+
+ #
+ # Retrieving tensor memory ptr and make accumulator tensor
+ #
+ acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ # (MMA, MMA_M, MMA_N, STAGE)
+ tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
+
+ #
+ # Partition for epilogue
+ #
+ epi_tidx = tidx
+ (
+ tiled_copy_t2r,
+ tTR_tAcc_base,
+ tTR_rAcc,
+ ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgD, epi_tile, use_2cta_instrs)
+
+ tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
+ tiled_copy_s2r, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rC, epi_tidx, sC)
+ tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype)
+ _, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rD, epi_tidx, sD)
+ (
+ bSG_sD,
+ bSG_gD_mnl,
+ ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_d, tCgD, epi_tile, sD)
+
+ #
+ # Persistent tile scheduling loop
+ #
+ tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim())
+ work_tile = tile_sched.initial_work_tile_info()
+
+ acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
+
+ # Threads/warps participating in tma store pipeline
+ d_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ 32 * len(self.epilog_warp_id),
+ )
+ d_pipeline = pipeline.PipelineTmaStore.create(
+ num_stages=self.num_d_stage,
+ producer_group=d_producer_group,
+ )
+ # Load C pipeline
+ c_pipeline_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_c_stage)
+
+ while work_tile.is_valid_tile:
+ # Get tile coord from tile scheduler
+ cur_tile_coord = work_tile.tile_idx
+ mma_tile_coord_mnl = (
+ cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
+ cur_tile_coord[1],
+ cur_tile_coord[2],
+ )
+
+ #
+ # Slice to per mma tile index
+ #
+ # ((ATOM_V, REST_V), EPI_M, EPI_N)
+ bSG_gD = bSG_gD_mnl[
+ (
+ None,
+ None,
+ None,
+ *mma_tile_coord_mnl,
+ )
+ ]
+ # Set tensor memory buffer for current tile
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
+ tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)]
+
+ # Initialize thread-local amax accumulator for this tile
+ # Use 0.0 as initial value since we're computing absolute maximum
+ if cutlass.const_expr(self.generate_amax):
+ thread_tile_amax = cutlass.Float32(0.0)
+
+ #
+ # Wait for accumulator buffer full
+ #
+ acc_pipeline.consumer_wait(acc_consumer_state)
+
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
+ bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
+
+ #
+ # Store accumulator to global memory in subtiles
+ #
+ subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) ## tTR_tAcc.shape: (((32, 32), 1), 1, 1, (1, 8))
+ num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
+ #
+ # Get PROB
+ # Note: this assumes T2R_M/EPI_M is 1; otherwise the result will be incorrect.
+ #
+ mPosition = cur_tile_coord[0] * self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape) + tidx
+ mProb = mProb_mnl[mPosition, 0, cur_tile_coord[2]]
+ dProb = cutlass.Float32(0.0)
+ for subtile_idx in cutlass.range(0, subtile_cnt, 1):
+ #
+ # Load accumulator from tensor memory buffer to register
+ #
+ tTR_tAcc_subtile = tTR_tAcc[(None, None, None, subtile_idx)]
+ cute.copy(tiled_copy_t2r, tTR_tAcc_subtile, tTR_rAcc)
+
+ #
+ # Retile the accumulator for the epilogue register-to-smem copy
+ #
+ acc_vec = tiled_copy_s2r.retile(tTR_rAcc)
+
+ #
+ # Apply alpha
+ #
+ acc_vec_ = acc_vec.load()
+ acc_vec.store(acc_vec_ * alpha)
+
+ #
+ # Load C from shared memory
+ #
+ c_pipeline.consumer_wait(c_pipeline_consumer_state)
+ cute.copy(
+ tiled_copy_s2r,
+ tRS_sC[(None, None, None, c_pipeline_consumer_state.index)],
+ tRS_rC,
+ )
+ cute.arch.fence_proxy("async.shared", space="cta")
+ c_pipeline.consumer_release(c_pipeline_consumer_state)
+ c_pipeline_consumer_state.advance()
+ c_vec = tiled_copy_s2r.retile(tRS_rC)
+
+ #
+ # Generate dsquared-ReLU
+ #
+ acc_values = acc_vec.load()
+ c_values = c_vec.load()
+ dsquared_relu = epilogue_op(acc_values, c_values * mProb)
+
+ #
+ # Generate dProb
+ #
+ acc_relu = cute.where(acc_values > 0, acc_values, cute.full_like(acc_values, 0))
+ dProb += (acc_relu**2 * c_values).reduce(
+ cute.ReductionOp.ADD,
+ cutlass.Float32(0.0),
+ 0,
+ )
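+ # Editorial note (illustrative): with s = max(acc, 0), the reduction above accumulates
+ # the thread's partial sum dProb += sum over this subtile's elements of s^2 * c, i.e.
+ # the per-row reduction for the probability gradient; the atomic add after the subtile
+ # loop merges these partial sums into global memory.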
+
+ # Write dsquared_relu values back for store
+ acc_vec.store(dsquared_relu)
+
+ #
+ # Generate amax
+ #
+ if cutlass.const_expr(self.generate_amax):
+ # Apply element-wise absolute value using math.absf (supports vectors)
+ dsquared_relu_ir = cutlass._mlir.dialects.math.absf(dsquared_relu.ir_value()) # operand (positional)
+ dsquared_relu_values = type(acc_values)(dsquared_relu_ir, acc_values.shape, acc_values.dtype)
+ subtile_amax = dsquared_relu_values.reduce(
+ cute.ReductionOp.MAX,
+ cutlass.Float32(0.0),  # Use 0.0 as init for abs values
+ 0,
+ )
+ thread_tile_amax = cute.arch.fmax(thread_tile_amax, subtile_amax)
+
+ #
+ # Generate sfd
+ #
+ if cutlass.const_expr(self.generate_sfd):
+ cute.printf("SFD not implemented\n")
+ else:
+ #
+ # Convert to D type directly
+ #
+ acc_vec = tiled_copy_s2r.retile(acc_vec).load()
+ tRS_rD.store(acc_vec.to(self.d_dtype))
+
+ #
+ # Store D to shared memory
+ #
+ d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage
+ cute.copy(
+ tiled_copy_s2r,
+ tRS_rD,
+ tRS_sD[(None, None, None, d_buffer)],
+ )
+ # Fence and barrier to make sure shared memory store is visible to TMA store
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.epilog_sync_barrier.arrive_and_wait()
+ #
+ # TMA store D to global memory
+ #
+ if warp_idx == self.epilog_warp_id[0]:
+ cute.copy(
+ tma_atom_d,
+ bSG_sD[(None, d_buffer)],
+ bSG_gD[(None, subtile_idx)],
+ )
+ # Commit the TMA store and acquire a free D stage before reuse
+ d_pipeline.producer_commit()
+ d_pipeline.producer_acquire()
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ # Perform amax reduction after all subtiles are processed
+ if cutlass.const_expr(self.generate_amax):
+ # Warp-level reduction using wrapper function
+ warp_amax = cute.arch.warp_redux_sync(
+ value=thread_tile_amax,
+ kind="fmax",
+ mask_and_clamp=0xFFFFFFFF,
+ nan=True,
+ )
+ # Each epilogue warp's lane 0 writes warp amax to shared memory
+ if cute.arch.lane_idx() == 0:
+ sAmax[warp_idx] = cutlass.Float32(warp_amax)
+
+ # Ensure all epilogue warps complete their writes before block reduction
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ # Block-level reduction: only first epilogue warp's lane 0 handles this
+ if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0:
+ block_amax = cutlass.Float32(0.0) # Initial value for absolute maximum
+ for i in cutlass.range(self.num_epilog_warps):
+ warp_amax_val = sAmax[i]
+ block_amax = cute.arch.fmax(block_amax, warp_amax_val)
+
+ # Global atomic max (accumulates across all tiles for final tensor amax)
+ # Since we compute absolute values, all values are non-negative
+ # Use wrapper function for atomic max operation
+ _ = cute.arch.atomic_max_float32(ptr=mAmax_tensor.iterator.llvm_ptr, value=block_amax)
+
+ # write dProb result to global memory
+ _ = atomic_add_float32(
+ ptr=mDProb_mnl[(mPosition, None, cur_tile_coord[2])].iterator.llvm_ptr,
+ value=dProb,
+ )
+
+ #
+ # Async arrive accumulator buffer empty
+ #
+ with cute.arch.elect_one():
+ acc_pipeline.consumer_release(acc_consumer_state)
+ acc_consumer_state.advance()
+
+ #
+ # Advance to next tile
+ #
+ tile_sched.advance_to_next_work()
+ work_tile = tile_sched.get_current_work()
+
+ #
+ # Dealloc the tensor memory buffer
+ #
+ tmem.relinquish_alloc_permit()
+ self.epilog_sync_barrier.arrive_and_wait()
+ tmem.free(acc_tmem_ptr)
+ #
+ # Wait for D store complete
+ #
+ d_pipeline.producer_tail()
+
+ #
+ # Specialized epilog load warp
+ #
+ if warp_idx == self.epilog_load_tma_id:
+ ## M 1024, N 512
+ ## tCgC (((32,128),1),1,8,4,2,1) : (((1@0,1@1),0),0,32@0,256@1,512@0,1@2)
+ ## bGS_gC_mnl (((32,128),1),1,8,EPI_M,EPI_N,L) : (((1@0,1@1),0),0,32@0,256@1,512@0,1@2)
+ ## bGS_sC ((4096, 1), (1, 4)) : ((1, 0), (0, 4096))
+ (
+ bGS_sC,
+ bGS_gC_mnl,
+ ) = self.epilog_gmem_copy_and_partition(tidx, tma_atom_c, tCgC, epi_tile, sC)
+ tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim())
+ work_tile = tile_sched.initial_work_tile_info()
+
+ c_pipeline_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_c_stage)
+ is_reverse = True
+ while work_tile.is_valid_tile:
+ # if it needs to be reversed
+ if cutlass.const_expr(self.overlapping_accum):
+ reverse_subtile = is_reverse
+ is_reverse = not is_reverse
+ # Get tile coord from tile scheduler
+ cur_tile_coord = work_tile.tile_idx
+ mma_tile_coord_mnl = (
+ cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
+ cur_tile_coord[1],
+ cur_tile_coord[2],
+ )
+ bGS_gC = bGS_gC_mnl[
+ (
+ None,
+ None,
+ None,
+ *mma_tile_coord_mnl,
+ )
+ ]
+ bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
+ subtile_cnt = cute.size(bGS_gC.shape, mode=[1])
+ for subtile_idx in cutlass.range(subtile_cnt):
+ # Check real subtile index
+ real_subtile_idx = subtile_idx
+ if cutlass.const_expr(self.overlapping_accum):
+ if reverse_subtile:
+ # Subtile always iterates on N dimension as we only have 4x1DP tmem load pattern for cta_tile_m = 128 cases. # {$nv-internal-release}
+ real_subtile_idx = subtile_cnt - 1 - subtile_idx
+ # Load C from global memory to shared memory using TMALDG
+ c_pipeline.producer_acquire(c_pipeline_producer_state)
+ cute.copy(
+ tma_atom_c,
+ bGS_gC[(None, real_subtile_idx)],
+ bGS_sC[(None, c_pipeline_producer_state.index)],
+ tma_bar_ptr=c_pipeline.producer_get_barrier(c_pipeline_producer_state),
+ )
+ c_pipeline_producer_state.advance()
+
+ #
+ # Advance to next tile
+ #
+ tile_sched.advance_to_next_work()
+ work_tile = tile_sched.get_current_work()
+
+ #
+ # Wait C buffer empty
+ #
+ c_pipeline.producer_tail(c_pipeline_producer_state)
+
+ def mainloop_s2t_copy_and_partition(
+ self,
+ sSF: cute.Tensor,
+ tSF: cute.Tensor,
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+ """
+ Make tiledCopy for the smem-to-tmem load of the scale factor tensor, then use it to partition shared memory (source) and tensor memory (destination).
+
+ :param sSF: The scale factor tensor in smem
+ :type sSF: cute.Tensor
+ :param tSF: The scale factor tensor in tmem
+ :type tSF: cute.Tensor
+
+ :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where:
+ - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t)
+ - tCsSF_compact_s2t: The partitioned scale factor tensor in smem
+ - tCtSF_compact_s2t: The partitioned scale factor tensor in tmem
+ :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
+ """
+ # (MMA, MMA_MN, MMA_K, STAGE)
+ tCsSF_compact = cute.filter_zeros(sSF)
+ # (MMA, MMA_MN, MMA_K)
+ tCtSF_compact = cute.filter_zeros(tSF)
+
+ # Make S2T CopyAtom and tiledCopy
+ copy_atom_s2t = cute.make_copy_atom(
+ tcgen05.Cp4x32x128bOp(self.cta_group),
+ self.sf_dtype,
+ )
+ tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
+ thr_copy_s2t = tiled_copy_s2t.get_slice(0)
+
+ # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
+ tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
+ # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
+ tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_)
+ # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K)
+ tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
+
+ return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
+
+ def epilog_tmem_copy_and_partition(
+ self,
+ tidx: cutlass.Int32,
+ tAcc: cute.Tensor,
+ gD_mnl: cute.Tensor,
+ epi_tile: cute.Tile,
+ use_2cta_instrs: Union[cutlass.Boolean, bool],
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+ """
+ Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
+
+ :param tidx: The thread index in epilogue warp groups
+ :type tidx: cutlass.Int32
+ :param tAcc: The accumulator tensor to be copied and partitioned
+ :type tAcc: cute.Tensor
+ :param gD_mnl: The global tensor D
+ :type gD_mnl: cute.Tensor
+ :param epi_tile: The epilogue tiler
+ :type epi_tile: cute.Tile
+ :param use_2cta_instrs: Whether use_2cta_instrs is enabled
+ :type use_2cta_instrs: bool
+
+ :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where:
+ - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
+ - tTR_tAcc: The partitioned accumulator tensor
+ - tTR_rAcc: The partitioned accumulator tensor in registers
+ :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
+ """
+ # Make tiledCopy for tensor memory load
+ copy_atom_t2r = sm100_utils.get_tmem_load_op(
+ self.cta_tile_shape_mnk,
+ self.c_layout,
+ self.c_dtype,
+ self.acc_dtype,
+ epi_tile,
+ use_2cta_instrs,
+ )
+ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
+ tAcc_epi = cute.flat_divide(
+ tAcc[((None, None), 0, 0, None)],
+ epi_tile,
+ )
+ # (EPI_TILE_M, EPI_TILE_N)
+ tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
+
+ thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
+ tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
+
+ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
+ gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
+ tTR_gD = thr_copy_t2r.partition_D(gD_mnl_epi)
+ # (T2R, T2R_M, T2R_N)
+ tTR_rAcc = cute.make_rmem_tensor(tTR_gD[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype)
+ return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
+
+ def epilog_smem_copy_and_partition(
+ self,
+ tiled_copy_t2r: cute.TiledCopy,
+ tTR_rC: cute.Tensor,
+ tidx: cutlass.Int32,
+ sC: cute.Tensor,
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+ """
+ Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
+
+ :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
+ :type tiled_copy_t2r: cute.TiledCopy
+ :param tTR_rC: The partitioned accumulator tensor
+ :type tTR_rC: cute.Tensor
+ :param tidx: The thread index in epilogue warp groups
+ :type tidx: cutlass.Int32
+ :param sC: The shared memory tensor to be copied and partitioned
+ :type sC: cute.Tensor
+
+ :return: A tuple containing (tiled_copy_s2r, tRS_rC, tRS_sC) where:
+ - tiled_copy_s2r: The tiled copy operation for register to smem copy(r2s)
+ - tRS_rC: The partitioned tensor C (register source)
+ - tRS_sC: The partitioned tensor C (smem destination)
+ :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
+ """
+ copy_atom_r2s = sm100_utils.get_smem_store_op(self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r)
+ tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
+ # (R2S, R2S_M, R2S_N, PIPE_D)
+ thr_copy_r2s = tiled_copy_s2r.get_slice(tidx)
+ tRS_sC = thr_copy_r2s.partition_D(sC)
+ # (R2S, R2S_M, R2S_N)
+ tRS_rC = tiled_copy_s2r.retile(tTR_rC)
+ return tiled_copy_s2r, tRS_rC, tRS_sC
+
+ def epilog_gmem_copy_and_partition(
+ self,
+ tidx: cutlass.Int32,
+ atom: Union[cute.CopyAtom, cute.TiledCopy],
+ gC_mnl: cute.Tensor,
+ epi_tile: cute.Tile,
+ sC: cute.Tensor,
+ ) -> Tuple[cute.Tensor, cute.Tensor]:
+ """Make tiledCopy for global memory store, then use it to:
+ partition shared memory (source) and global memory (destination) for TMA store version.
+
+ :param tidx: The thread index in epilogue warp groups
+ :type tidx: cutlass.Int32
+ :param atom: The copy_atom_c to be used for the TMA store version, or tiled_copy_t2r for the non-TMA store version
+ :type atom: cute.CopyAtom or cute.TiledCopy
+ :param gC_mnl: The global tensor C
+ :type gC_mnl: cute.Tensor
+ :param epi_tile: The epilogue tiler
+ :type epi_tile: cute.Tile
+ :param sC: The shared memory tensor to be copied and partitioned
+ :type sC: cute.Tensor
+
+ :return: A tuple containing (bSG_sC, bSG_gC) where:
+ - bSG_sC: The partitioned shared memory tensor C
+ - bSG_gC: The partitioned global tensor C
+ :rtype: Tuple[cute.Tensor, cute.Tensor]
+ """
+ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
+ gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+ # ((ATOM_V, REST_V), EPI_M, EPI_N)
+ # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
+ bSG_sC, bSG_gC = cpasync.tma_partition(
+ atom,
+ 0,
+ cute.make_layout(1),
+ cute.group_modes(sC, 0, 2),
+ cute.group_modes(gC_epi, 0, 2),
+ )
+ return bSG_sC, bSG_gC
+
+ @staticmethod
+ def _compute_stages(
+ tiled_mma: cute.TiledMma,
+ mma_tiler_mnk: Tuple[int, int, int],
+ a_dtype: Type[cutlass.Numeric],
+ b_dtype: Type[cutlass.Numeric],
+ epi_tile: cute.Tile,
+ c_dtype: Type[cutlass.Numeric],
+ c_layout: utils.LayoutEnum,
+ sf_dtype: Type[cutlass.Numeric],
+ sf_vec_size: int,
+ d_dtype: Type[cutlass.Numeric],
+ d_layout: utils.LayoutEnum,
+ smem_capacity: int,
+ occupancy: int,
+ ) -> Tuple[int, int, int, int]:
+ """Computes the number of stages for the A/B, C, and D operands and the accumulator based on heuristics.
+
+ :param tiled_mma: The tiled MMA object defining the core computation.
+ :type tiled_mma: cute.TiledMma
+ :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
+ :type mma_tiler_mnk: tuple[int, int, int]
+ :param a_dtype: Data type of operand A.
+ :type a_dtype: type[cutlass.Numeric]
+ :param b_dtype: Data type of operand B.
+ :type b_dtype: type[cutlass.Numeric]
+ :param epi_tile: The epilogue tile shape.
+ :type epi_tile: cute.Tile
+ :param c_dtype: Data type of operand C (output).
+ :type c_dtype: type[cutlass.Numeric]
+ :param c_layout: Layout enum of operand C.
+ :type c_layout: utils.LayoutEnum
+ :param sf_dtype: Data type of Scale factor.
+ :type sf_dtype: type[cutlass.Numeric]
+ :param sf_vec_size: Scale factor vector size.
+ :type sf_vec_size: int
+ :param d_dtype: Data type of operand D.
+ :type d_dtype: type[cutlass.Numeric]
+ :param d_layout: Layout enum of operand D.
+ :type d_layout: utils.LayoutEnum
+ :param smem_capacity: Total available shared memory capacity in bytes.
+ :type smem_capacity: int
+ :param occupancy: Target number of CTAs per SM (occupancy).
+ :type occupancy: int
+
+ :return: A tuple containing the computed number of stages for:
+ (ACC stages, A/B operand stages, C stages, D stages)
+ :rtype: tuple[int, int, int, int]
+ """
+ # ACC stages
+ num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
+
+ # Default C stages
+ num_c_stage = 2 # mma_tiler_mnk[1] // cute.cosize(epi_tile[1])
+ num_d_stage = num_c_stage
+
+ # Calculate smem layout and size for one stage of A, B, SFA, SFB and C
+ a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(
+ tiled_mma,
+ mma_tiler_mnk,
+ a_dtype,
+ 1, # temporary single stage for size calculation
+ )
+ b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
+ tiled_mma,
+ mma_tiler_mnk,
+ b_dtype,
+ 1, # temporary single stage for size calculation
+ )
+ sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(
+ tiled_mma,
+ mma_tiler_mnk,
+ sf_vec_size,
+ 1, # temporary single stage for size calculation
+ )
+ sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(
+ tiled_mma,
+ mma_tiler_mnk,
+ sf_vec_size,
+ 1, # temporary single stage for size calculation
+ )
+
+ c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
+ c_dtype,
+ c_layout,
+ epi_tile,
+ 1,
+ )
+
+ d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
+ d_dtype,
+ d_layout,
+ epi_tile,
+ 1,
+ )
+
+ ab_bytes_per_stage = (
+ cute.size_in_bytes(a_dtype, a_smem_layout_stage_one)
+ + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
+ + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one)
+ + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
+ )
+ mbar_helpers_bytes = 1024
+ c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
+ d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
+ amax_bytes = Sm100BlockScaledPersistentDenseGemmKernel.get_amax_smem_size()
+ epi_bytes = c_bytes_per_stage * num_c_stage + d_bytes_per_stage * num_d_stage + amax_bytes
+
+ # Calculate A/B/SFA/SFB stages:
+ # Start with total smem per CTA (capacity / occupancy)
+ # Subtract mbarrier reserved bytes and epilogue (C/D/amax) bytes
+ # Divide remaining by bytes needed per A/B/SFA/SFB stage
+ num_ab_stage = (smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)) // ab_bytes_per_stage
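+ # Illustrative sizing (editorial sketch; the numbers below are assumptions, not measured):
+ # with ~228 KiB of smem per CTA at occupancy 1, ~1 KiB reserved for mbarriers,
+ # ~32 KiB of epilogue (C/D/amax) buffers, and ~20 KiB per A/B/SFA/SFB stage,
+ # this heuristic yields (228 - 1 - 32) // 20 = 9 mainloop stages.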
+
+ return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage
+
+ @staticmethod
+ def _compute_grid(
+ c: cute.Tensor,
+ cta_tile_shape_mnk: Tuple[int, int, int],
+ cluster_shape_mn: Tuple[int, int],
+ max_active_clusters: cutlass.Constexpr,
+ ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
+ """Use persistent tile scheduler to compute the grid size for the output tensor C.
+
+ :param c: The output tensor C
+ :type c: cute.Tensor
+ :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
+ :type cta_tile_shape_mnk: tuple[int, int, int]
+ :param cluster_shape_mn: Shape of each cluster in M, N dimensions.
+ :type cluster_shape_mn: tuple[int, int]
+ :param max_active_clusters: Maximum number of active clusters.
+ :type max_active_clusters: cutlass.Constexpr
+
+ :return: A tuple containing:
+ - tile_sched_params: Parameters for the persistent tile scheduler.
+ - grid: Grid shape for kernel launch.
+ :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
+ """
+ c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
+ gc = cute.zipped_divide(c, tiler=c_shape)
+ num_ctas_mnl = gc[(0, (None, None, None))].shape
+ cluster_shape_mnl = (*cluster_shape_mn, 1)
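+ # Worked example (illustrative): for C of shape (1024, 512, 1) with a CTA tile of
+ # (128, 128), zipped_divide yields num_ctas_mnl = (8, 4, 1); the persistent scheduler
+ # then covers these 32 CTA tiles with a grid capped by max_active_clusters.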
+
+ tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl)
+ grid = utils.StaticPersistentTileScheduler.get_grid_shape(tile_sched_params, max_active_clusters)
+
+ return tile_sched_params, grid
+
+ @staticmethod
+ def get_dtype_rcp_limits(dtype: Type[cutlass.Numeric]) -> float:
+ """
+ Calculates the reciprocal of the maximum absolute value for a given data type.
+
+ :param dtype: Data type
+ :type dtype: Type[cutlass.Numeric]
+
+ :return: A float representing the reciprocal of the maximum absolute value
+ :rtype: float
+ """
+ if dtype == cutlass.Float4E2M1FN:
+ return 1 / 6.0
+ if dtype == cutlass.Float8E4M3FN:
+ return 1 / 448.0
+ if dtype == cutlass.Float8E5M2:
+ return 1 / 128.0
+ return 1.0
+
+ @staticmethod
+ def is_valid_dtypes_and_scale_factor_vec_size(
+ ab_dtype: Type[cutlass.Numeric],
+ sf_dtype: Type[cutlass.Numeric],
+ sf_vec_size: int,
+ d_dtype: Type[cutlass.Numeric],
+ ) -> bool:
+ """
+ Check if the dtypes and sf_vec_size are valid combinations
+
+ :param ab_dtype: The data type of the A and B operands
+ :type ab_dtype: Type[cutlass.Numeric]
+ :param sf_dtype: The data type of the scale factor
+ :type sf_dtype: Type[cutlass.Numeric]
+ :param sf_vec_size: The vector size of the scale factor
+ :type sf_vec_size: int
+ :param d_dtype: The data type of the output tensor
+ :type d_dtype: Type[cutlass.Numeric]
+
+ :return: True if the dtypes and sf_vec_size are valid, False otherwise
+ :rtype: bool
+ """
+ is_valid = True
+
+ # Check valid ab_dtype
+ if ab_dtype not in {
+ cutlass.Float4E2M1FN,
+ cutlass.Float8E5M2,
+ cutlass.Float8E4M3FN,
+ }:
+ is_valid = False
+
+ if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
+ # Check valid sf_vec_size
+ if sf_vec_size not in {16, 32}:
+ is_valid = False
+ # Check valid sf_dtype
+ if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}:
+ is_valid = False
+ # Check valid sf_dtype and sf_vec_size combinations
+ if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32:
+ is_valid = False
+ if sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 16:
+ is_valid = False
+
+ # Check valid d_dtype
+ if d_dtype not in {
+ cutlass.Float32,
+ cutlass.Float16,
+ cutlass.BFloat16,
+ cutlass.Float8E5M2,
+ cutlass.Float8E4M3FN,
+ cutlass.Float4E2M1FN, # {$nv-internal-release}
+ }:
+ is_valid = False
+
+ return is_valid
+
+ @staticmethod
+ def is_valid_layouts(
+ ab_dtype: Type[cutlass.Numeric],
+ c_dtype: Type[cutlass.Numeric],
+ a_major: str,
+ b_major: str,
+ c_major: str,
+ ) -> bool:
+ """
+ Check if layouts and dtypes are valid combinations
+
+ :param ab_dtype: The data type of the A and B operands
+ :type ab_dtype: Type[cutlass.Numeric]
+ :param c_dtype: The data type of the output tensor
+ :type c_dtype: Type[cutlass.Numeric]
+ :param a_major: The major dimension of the A tensor
+ :type a_major: str
+ :param b_major: The major dimension of the B tensor
+ :type b_major: str
+ :param c_major: The major dimension of the C tensor
+ :type c_major: str
+
+ :return: True if the layouts are valid, False otherwise
+ :rtype: bool
+ """
+ is_valid = True
+
+ # {$nv-internal-release begin}
+ if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
+ is_valid = False
+ # TODO: Currently we don't support m major output for Float4E2M1FN
+ if c_dtype is cutlass.Float4E2M1FN and c_major == "m":
+ is_valid = False
+ # {$nv-internal-release end}
+
+ return is_valid
+
+ @staticmethod
+ def is_valid_mma_tiler_and_cluster_shape(
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ ) -> bool:
+ """
+ Check if the mma tiler and cluster shape are valid
+
+ :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
+ :type mma_tiler_mn: Tuple[int, int]
+ :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
+ :type cluster_shape_mn: Tuple[int, int]
+
+ :return: True if the mma tiler and cluster shape are valid, False otherwise
+ :rtype: bool
+ """
+ is_valid = True
+ # Skip invalid mma tile shape
+ if mma_tiler_mn[0] not in [128, 256]:
+ is_valid = False
+ if mma_tiler_mn[1] not in [64, 128, 192, 256]:
+ is_valid = False
+ # Skip illegal cluster shape
+ if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0:
+ is_valid = False
+ # Skip invalid cluster shape
+ is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
+ if (
+ cluster_shape_mn[0] * cluster_shape_mn[1] > 16
+ or cluster_shape_mn[0] <= 0
+ or cluster_shape_mn[1] <= 0
+ # Special cluster shape check for scale factor multicasts.
+ # Due to limited size of scale factors, we can't multicast among more than 4 CTAs.
+ or cluster_shape_mn[0] > 4
+ or cluster_shape_mn[1] > 4
+ or not is_power_of_2(cluster_shape_mn[0])
+ or not is_power_of_2(cluster_shape_mn[1])
+ ):
+ is_valid = False
+ return is_valid
+
+ @staticmethod
+ def is_valid_tensor_alignment(
+ m: int,
+ n: int,
+ k: int,
+ l: int,
+ ab_dtype: Type[cutlass.Numeric],
+ c_dtype: Type[cutlass.Numeric],
+ a_major: str,
+ b_major: str,
+ c_major: str,
+ ) -> bool:
+ """
+ Check if the tensor alignment is valid
+
+ :param m: The number of rows in the A tensor
+ :type m: int
+ :param n: The number of columns in the B tensor
+ :type n: int
+ :param k: The number of columns in the A tensor
+ :type k: int
+ :param l: The number of batches (the L dimension)
+ :type l: int
+ :param ab_dtype: The data type of the A and B operands
+ :type ab_dtype: Type[cutlass.Numeric]
+ :param c_dtype: The data type of the output tensor
+ :type c_dtype: Type[cutlass.Numeric]
+ :param a_major: The major axis of the A tensor
+ :type a_major: str
+ :param b_major: The major axis of the B tensor
+ :type b_major: str
+ :param c_major: The major axis of the C tensor
+ :type c_major: str
+
+ :return: True if the problem shape is valid, False otherwise
+ :rtype: bool
+ """
+ is_valid = True
+
+ def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
+ major_mode_idx = 0 if is_mode0_major else 1
+ num_major_elements = tensor_shape[major_mode_idx]
+ num_contiguous_elements = 16 * 8 // dtype.width
+ return num_major_elements % num_contiguous_elements == 0
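+ # Illustrative: 16-byte contiguity means 16 * 8 // dtype.width elements along the
+ # major (contiguous) mode — 16 elements for 8-bit FP8 and 32 for 4-bit FP4 — so a
+ # k-major FP8 A tensor, for example, requires k % 16 == 0.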
+
+ if (
+ not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l))
+ or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l))
+ or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l))
+ ):
+ is_valid = False
+ return is_valid
+
+ @staticmethod
+ def can_implement(
+ ab_dtype: Type[cutlass.Numeric],
+ sf_dtype: Type[cutlass.Numeric],
+ sf_vec_size: int,
+ d_dtype: Type[cutlass.Numeric],
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ m: int,
+ n: int,
+ k: int,
+ l: int,
+ a_major: str,
+ b_major: str,
+ d_major: str,
+ ) -> bool:
+ """
+ Check if the gemm can be implemented
+
+ :param ab_dtype: The data type of the A and B operands
+ :type ab_dtype: Type[cutlass.Numeric]
+ :param sf_dtype: The data type of the scale factor tensor
+ :type sf_dtype: Type[cutlass.Numeric]
+ :param sf_vec_size: The vector size
+ :type sf_vec_size: int
+ :param d_dtype: The data type of the output tensor
+ :type d_dtype: Type[cutlass.Numeric]
+ :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
+ :type mma_tiler_mn: Tuple[int, int]
+ :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
+ :type cluster_shape_mn: Tuple[int, int]
+ :param m: The number of rows in the A tensor
+ :type m: int
+ :param n: The number of columns in the B tensor
+ :type n: int
+ :param k: The number of columns in the A tensor
+ :type k: int
+ :param l: The number of batches (the L dimension)
+ :type l: int
+ :param a_major: The major axis of the A tensor
+ :type a_major: str
+ :param b_major: The major axis of the B tensor
+ :type b_major: str
+ :param d_major: The major axis of the D tensor
+ :type d_major: str
+
+ :return: True if the gemm can be implemented, False otherwise
+ :rtype: bool
+ """
+ can_implement = True
+ # Skip unsupported types
+ if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(ab_dtype, sf_dtype, sf_vec_size, d_dtype):
+ can_implement = False
+ # Skip unsupported layouts
+ if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts(ab_dtype, d_dtype, a_major, b_major, d_major):
+ can_implement = False
+ # Skip invalid mma tile shape and cluster shape
+ if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn):
+ can_implement = False
+ # Skip illegal problem shape for load/store alignment
+ if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment(m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major):
+ can_implement = False
+ return can_implement
+
+ @staticmethod
+ def get_amax_smem_size():
+ # Note: 4 is hardcoded for num_epilog_warps
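+ # i.e. 4 warps * sizeof(Float32) = 16 bytes, one per-warp amax slot in shared memory.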
+ return 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,)))
diff --git a/python/cudnn/gemm_srelu/__init__.py b/python/cudnn/gemm_srelu/__init__.py
new file mode 100644
index 00000000..29324f36
--- /dev/null
+++ b/python/cudnn/gemm_srelu/__init__.py
@@ -0,0 +1,9 @@
+from .api import (
+ GemmSreluSm100,
+ gemm_srelu_wrapper_sm100,
+)
+
+__all__ = [
+ "GemmSreluSm100",
+ "gemm_srelu_wrapper_sm100",
+]
diff --git a/python/cudnn/gemm_srelu/api.py b/python/cudnn/gemm_srelu/api.py
new file mode 100644
index 00000000..3bf21f95
--- /dev/null
+++ b/python/cudnn/gemm_srelu/api.py
@@ -0,0 +1,587 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import Optional, Tuple
+
+import cutlass
+import cutlass.cute as cute
+import torch
+from cuda.bindings import driver as cuda
+from cutlass.cute.runtime import make_fake_stream
+
+from cudnn.api_base import APIBase, TupleDict, ceil_div, is_power_of_2
+from cudnn.datatypes import _convert_to_cutlass_data_type
+
+from .dense_blockscaled_gemm_persistent_srelu_quant import (
+ Sm100BlockScaledPersistentDenseGemmKernel,
+)
+
+
+def _major_from_stride_order(stride_order: Tuple[int, ...], mode0_label: str, mode1_label: str) -> str:
+ if stride_order == (0, 1, 2):
+ return mode0_label
+ if stride_order == (1, 0, 2):
+ return mode1_label
+ raise ValueError(f"Unsupported stride order {stride_order}")
+
+
+class GemmSreluSm100(APIBase):
+ def __init__(
+ self,
+ sample_a: torch.Tensor,
+ sample_b: torch.Tensor,
+ sample_c: torch.Tensor,
+ sample_d: torch.Tensor,
+ sample_sfa: torch.Tensor,
+ sample_sfb: torch.Tensor,
+ sample_prob: torch.Tensor,
+ sample_sfd: Optional[torch.Tensor] = None,
+ sample_amax: Optional[torch.Tensor] = None,
+ sample_norm_const: Optional[torch.Tensor] = None,
+ alpha: float = 1.0,
+ acc_dtype: torch.dtype = torch.float32,
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ ):
+ super().__init__()
+
+ self._warn_experimental_api()
+
+ self.a_desc = self._make_tensor_desc(sample_a, name="sample_a")
+ self.b_desc = self._make_tensor_desc(sample_b, name="sample_b")
+ self.c_desc = self._make_tensor_desc(sample_c, name="sample_c")
+ self.d_desc = self._make_tensor_desc(sample_d, name="sample_d")
+ self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa")
+ self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb")
+ self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob")
+ self.sfd_desc = self._make_tensor_desc(sample_sfd, name="sample_sfd")
+ self.amax_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_amax, name="sample_amax"), 1, "amax")
+ self.norm_const_desc = self._unpad_tensor_to_ndim(
+ self._make_tensor_desc(sample_norm_const, name="sample_norm_const"),
+ 1,
+ "norm_const",
+ )
+
+ self.alpha = alpha
+ self.acc_dtype = acc_dtype
+ self.mma_tiler_mn = mma_tiler_mn
+ self.cluster_shape_mn = cluster_shape_mn if cluster_shape_mn is not None else ((2, 1) if mma_tiler_mn[0] == 256 else (1, 1))
+ self.sf_vec_size = sf_vec_size
+ self.vector_f32 = vector_f32
+ self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
+
+ self._interpret_uint8_as_fp4x2 = True
+ self._kernel = Sm100BlockScaledPersistentDenseGemmKernel
+
+ def check_support(self) -> bool:
+ m, k, l = self._tensor_shape(self.a_desc, name="sample_a")
+ n, b_k, b_l = self._tensor_shape(self.b_desc, name="sample_b")
+ c_m, c_n, c_l = self._tensor_shape(self.c_desc, name="sample_c")
+ d_m, d_n, d_l = self._tensor_shape(self.d_desc, name="sample_d")
+
+ self._value_error_if((b_k, b_l) != (k, l), f"B shape mismatch: expected (*, {k}, {l}), got {(n, b_k, b_l)}")
+ self._check_tensor_shape(self.c_desc, (m, n, l), "C")
+ self._check_tensor_shape(self.d_desc, (m, n, l), "D")
+ self._check_tensor_shape(self.prob_desc, (m, 1, l), "prob")
+
+ rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4)
+ self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(m, 128), 4, rest_k, l), "SFA")
+ self._check_tensor_shape(self.sfb_desc, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB")
+
+ if self.sfd_desc is not None:
+ rest_n = ceil_div(ceil_div(n, self.sf_vec_size), 4)
+ self._check_tensor_shape(self.sfd_desc, (32, 4, ceil_div(m, 128), 4, rest_n, l), "SFD")
+
+ self._check_tensor_shape(self.amax_desc, (1,), "amax")
+ self._check_tensor_shape(self.norm_const_desc, (1,), "norm_const")
+
+ self.ab_dtype = self._check_dtype(
+ self.a_desc,
+ dtype=[torch.float4_e2m1fn_x2, torch.uint8, torch.float8_e4m3fn, torch.float8_e5m2],
+ name="A",
+ )
+ self._check_dtype(self.b_desc, dtype=self.ab_dtype, name="B", extra_error_msg="A and B must have the same dtype")
+ self.c_dtype = self._check_dtype(
+ self.c_desc,
+ dtype=[torch.float16, torch.bfloat16, torch.float32, torch.float8_e4m3fn, torch.float8_e5m2],
+ name="C",
+ )
+ self.d_dtype = self._check_dtype(
+ self.d_desc,
+ dtype=[torch.float16, torch.bfloat16, torch.float32, torch.float8_e4m3fn, torch.float8_e5m2],
+ name="D",
+ )
+ self._check_dtype(self.prob_desc, dtype=torch.float32, name="prob")
+
+ self.sf_dtype = self._check_dtype(self.sfa_desc, dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn], name="SFA")
+ self._check_dtype(self.sfb_desc, dtype=self.sf_dtype, name="SFB", extra_error_msg="SFB must have the same dtype as SFA")
+ self._check_dtype(self.sfd_desc, dtype=self.sf_dtype, name="SFD", extra_error_msg="SFD must have the same dtype as SFA")
+
+ self._check_dtype(self.acc_dtype, dtype=torch.float32, name="Accumulator")
+
+ self._value_error_if(self.sf_vec_size not in {16, 32}, f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}")
+ self._value_error_if(
+ self._is_fp8(self.d_desc) and (self.sfd_desc is None or self.norm_const_desc is None), "sfd and norm_const are required when D is FP8"
+ )
+ self._value_error_if(
+ self._is_fp4x2(self.ab_dtype) and self.d_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}, "FP4 input with FP8 output is not supported"
+ )
+
+ a_major = _major_from_stride_order(self.a_desc.stride_order, "m", "k")
+ b_major = _major_from_stride_order(self.b_desc.stride_order, "n", "k")
+ c_major = _major_from_stride_order(self.c_desc.stride_order, "m", "n")
+ d_major = _major_from_stride_order(self.d_desc.stride_order, "m", "n")
+ self._value_error_if(c_major != d_major, f"C and D must share the same layout, got {c_major} and {d_major}")
+
+ self._value_error_if(
+ self.mma_tiler_mn[0] not in {128, 256} or self.mma_tiler_mn[1] not in {64, 128, 192, 256},
+ f"Unsupported mma_tiler_mn {self.mma_tiler_mn}",
+ )
+ self._value_error_if(
+ not (
+ self.cluster_shape_mn[0] > 0
+ and self.cluster_shape_mn[1] > 0
+ and self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16
+ and is_power_of_2(self.cluster_shape_mn[0])
+ and is_power_of_2(self.cluster_shape_mn[1])
+ ),
+ f"Invalid cluster shape {self.cluster_shape_mn}",
+ )
+
+ self._runtime_error_if(not torch.cuda.is_available(), "CUDA is not available")
+ major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+ self._runtime_error_if(major * 10 + minor < 100, f"GemmSreluSm100 requires SM100+, found SM{major}{minor}")
+
+ self._value_error_if(
+ not self._kernel.can_implement(
+ ab_dtype=_convert_to_cutlass_data_type(self.ab_dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2),
+ sf_dtype=_convert_to_cutlass_data_type(self.sf_dtype),
+ sf_vec_size=self.sf_vec_size,
+ d_dtype=_convert_to_cutlass_data_type(self.d_dtype),
+ mma_tiler_mn=self.mma_tiler_mn,
+ cluster_shape_mn=self.cluster_shape_mn,
+ m=m,
+ n=n,
+ k=k,
+ l=l,
+ a_major=a_major,
+ b_major=b_major,
+ d_major=d_major,
+ ),
+ "Unsupported configuration for dense sReLU kernel",
+ )
+
+ self._is_supported = True
+ return True
+
+ def compile(self) -> None:
+ self._ensure_support_checked()
+ if self._compiled_kernel is not None:
+ return
+
+ gemm = self._kernel(
+ sf_vec_size=self.sf_vec_size,
+ mma_tiler_mn=self.mma_tiler_mn,
+ cluster_shape_mn=self.cluster_shape_mn,
+ vector_f32=self.vector_f32,
+ )
+
+ hardware_info = cutlass.utils.HardwareInfo()
+ max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1])
+ max_active_clusters -= self.num_cluster_overlap_margin
+ self._value_error_if(max_active_clusters <= 0, "max_active_clusters must be > 0 after overlap margin")
+
+ fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False)
+ epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) ** 2
+ use_full_dynamic = os.environ.get(_DENSE_GEMM_DYNAMIC_MNKL_ENV) is not None
+ use_dynamic_m = not use_full_dynamic and os.environ.get(_DENSE_GEMM_DYNAMIC_M_ENV) is not None
+
+ if use_dynamic_m:
+ valid_m = cute.sym_int()
+
+ a_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, *self.a_desc.shape[1:]),
+ stride_order=self.a_desc.stride_order,
+ )
+ b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16)
+ c_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, *self.c_desc.shape[1:]),
+ stride_order=self.c_desc.stride_order,
+ )
+ d_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_desc.dtype,
+ shape=(valid_m, *self.d_desc.shape[1:]),
+ stride_order=self.d_desc.stride_order,
+ )
+ prob_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, *self.prob_desc.shape[1:]),
+ stride_order=self.prob_desc.stride_order,
+ )
+
+ tensor_m_128 = cute.sym_int()
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], self.sfa_desc.shape[5]),
+ stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128),
+ )
+ sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16)
+
+ sfd_cute_fake = None
+ if self.sfd_desc is not None:
+ stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfd_desc.shape[4], self.sfd_desc.shape[5]),
+ stride=(16, 4, self.sfd_desc.stride[2], 1, 512, stride_sfd_m),
+ )
+ elif use_full_dynamic:
+ valid_m = cute.sym_int()
+ n_sym = cute.sym_int()
+ k_sym = cute.sym_int()
+ l_sym = cute.sym_int()
+
+ a_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, k_sym, l_sym),
+ stride_order=self.a_desc.stride_order,
+ dynamic_mode=self.a_desc.stride_order[0],
+ divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
+ )
+ b_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.b_desc.dtype,
+ shape=(n_sym, k_sym, l_sym),
+ stride_order=self.b_desc.stride_order,
+ dynamic_mode=self.b_desc.stride_order[0],
+ divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
+ )
+ c_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, n_sym, l_sym),
+ stride_order=self.c_desc.stride_order,
+ dynamic_mode=self.c_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.c_desc.dtype) else 16,
+ )
+ d_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_desc.dtype,
+ shape=(valid_m, n_sym, l_sym),
+ stride_order=self.d_desc.stride_order,
+ dynamic_mode=self.d_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.d_desc.dtype) else 16,
+ )
+ prob_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, *self.prob_desc.shape[1:-1], l_sym),
+ stride=(1, 1, valid_m),
+ )
+
+ tensor_m_128 = cute.sym_int()
+ rest_k = cute.sym_int()
+ stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_shape = list(self.sfa_desc.shape)
+ sfa_shape[2] = tensor_m_128
+ sfa_shape[4] = rest_k
+ sfa_shape[5] = l_sym
+ sfa_stride = list(self.sfa_desc.stride)
+ sfa_stride[2] = stride_rest_k
+ sfa_stride[5] = stride_tensor_m_128
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=tuple(sfa_shape),
+ stride=tuple(sfa_stride),
+ )
+
+ tensor_n_128 = cute.sym_int()
+ stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfb_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfb_desc.dtype,
+ shape=(32, 4, tensor_n_128, 4, rest_k, l_sym),
+ stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k),
+ )
+
+ sfd_cute_fake = None
+ if self.sfd_desc is not None:
+ rest_n = cute.sym_int()
+ stride_sfd_rest_n = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfd_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_shape = list(self.sfd_desc.shape)
+ sfd_shape[2] = tensor_m_128
+ sfd_shape[4] = rest_n
+ sfd_shape[5] = l_sym
+ sfd_stride = list(self.sfd_desc.stride)
+ sfd_stride[2] = stride_sfd_rest_n
+ sfd_stride[5] = stride_sfd_tensor_m_128
+ sfd_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_desc.dtype,
+ shape=tuple(sfd_shape),
+ stride=tuple(sfd_stride),
+ )
+ else:
+ a_cute_fake = self._make_fake_cute_tensor_from_desc(self.a_desc, assumed_align=16)
+ b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16)
+ c_cute_fake = self._make_fake_cute_tensor_from_desc(self.c_desc, assumed_align=16)
+ d_cute_fake = self._make_fake_cute_tensor_from_desc(self.d_desc, assumed_align=16)
+ prob_cute_fake = self._make_fake_cute_tensor_from_desc(self.prob_desc, assumed_align=16)
+ sfa_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfa_desc, assumed_align=16)
+ sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16)
+ sfd_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfd_desc, assumed_align=16)
+
+ compiled = cute.compile(
+ gemm,
+ a_tensor=a_cute_fake,
+ b_tensor=b_cute_fake,
+ sfa_tensor=sfa_cute_fake,
+ sfb_tensor=sfb_cute_fake,
+ c_tensor=c_cute_fake,
+ d_tensor=d_cute_fake,
+ prob_tensor=prob_cute_fake,
+ amax_tensor=self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16),
+ sfd_tensor=sfd_cute_fake,
+ norm_const_tensor=self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16),
+ alpha=self.alpha,
+ max_active_clusters=max_active_clusters,
+ stream=fake_stream,
+ epilogue_op=epilogue_op,
+ options="--enable-tvm-ffi",
+ )
+
+ def tensor_api(
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ amax_tensor: Optional[torch.Tensor],
+ sfd_tensor: Optional[torch.Tensor],
+ norm_const_tensor: Optional[torch.Tensor],
+ alpha: float,
+ stream: cuda.CUstream,
+ ) -> None:
+ compiled(
+ a_tensor,
+ b_tensor,
+ sfa_tensor,
+ sfb_tensor,
+ c_tensor,
+ d_tensor,
+ prob_tensor,
+ self._unpad_tensor_to_ndim(amax_tensor, 1, "amax"),
+ sfd_tensor,
+ self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const"),
+ alpha,
+ stream,
+ )
+
+ self._compiled_kernel = tensor_api
+
+ def execute(
+ self,
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ sfd_tensor: Optional[torch.Tensor] = None,
+ amax_tensor: Optional[torch.Tensor] = None,
+ norm_const_tensor: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ current_stream: Optional[cuda.CUstream] = None,
+ ) -> None:
+ self._runtime_error_if(self._compiled_kernel is None, "GemmSreluSm100 kernel not compiled; call compile() first")
+ current_stream = self._get_default_stream(current_stream)
+ self._compiled_kernel(
+ a_tensor=a_tensor,
+ b_tensor=b_tensor,
+ sfa_tensor=sfa_tensor,
+ sfb_tensor=sfb_tensor,
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ prob_tensor=prob_tensor,
+ amax_tensor=amax_tensor,
+ sfd_tensor=sfd_tensor,
+ norm_const_tensor=norm_const_tensor,
+ alpha=self.alpha if alpha is None else alpha,
+ stream=current_stream,
+ )
+
+
+_logger = logging.getLogger(__name__)
+_cache_of_GemmSreluSm100Objects = {}
+_DENSE_GEMM_DYNAMIC_M_ENV = "CUDNN_FE_GEMM_DYNAMIC_M"
+_DENSE_GEMM_DYNAMIC_MNKL_ENV = "CUDNN_FE_GEMM_DYNAMIC_MNKL"
+
+
+def _allocate_dense_output(shape: Tuple[int, int, int], major: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
+ m, n, l = shape
+ if major == "m":
+ return torch.empty_strided((m, n, l), (1, m, m * n), dtype=dtype, device=device)
+ if major == "n":
+ return torch.empty_strided((m, n, l), (n, 1, m * n), dtype=dtype, device=device)
+ raise ValueError(f"major must be 'm' or 'n', got {major}")
+
+
+def gemm_srelu_wrapper_sm100(
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ alpha: float = 1.0,
+ c_major: str = "n",
+ c_dtype: torch.dtype = torch.bfloat16,
+ d_dtype: torch.dtype = torch.bfloat16,
+ acc_dtype: torch.dtype = torch.float32,
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ norm_const_tensor: Optional[torch.Tensor] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ stream: Optional[cuda.CUstream] = None,
+) -> TupleDict:
+ m, k, l = a_tensor.shape
+ n, _, _ = b_tensor.shape
+
+ c_tensor = _allocate_dense_output((m, n, l), c_major, c_dtype, a_tensor.device)
+ d_tensor = _allocate_dense_output((m, n, l), c_major, d_dtype, a_tensor.device)
+
+ sfd_tensor = None
+ if d_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
+ sf_k = ceil_div(n, sf_vec_size)
+ mma_shape = (
+ l,
+ ceil_div(m, 128),
+ ceil_div(sf_k, 4),
+ 32,
+ 4,
+ 4,
+ )
+ sfd_tensor = torch.empty(mma_shape, dtype=sfa_tensor.dtype, device=a_tensor.device).permute(3, 4, 1, 5, 2, 0)
+
+ amax_tensor = None
+ if a_tensor.dtype in {torch.float4_e2m1fn_x2, torch.uint8} and d_dtype in {torch.bfloat16, torch.float16, torch.float32}:
+ amax_tensor = torch.full((1,), float("-inf"), dtype=torch.float32, device=a_tensor.device)
+
+ use_full_dynamic = os.environ.get(_DENSE_GEMM_DYNAMIC_MNKL_ENV) is not None
+ use_dynamic_m = not use_full_dynamic and os.environ.get(_DENSE_GEMM_DYNAMIC_M_ENV) is not None
+
+ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
+ return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1]))
+
+ def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype
+
+ def dynamic_compact_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return tuple(tensor.shape[1:]), stride_order(tensor), tensor.dtype
+
+ def dynamic_tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return None, stride_order(tensor), tensor.dtype
+
+ def dynamic_m_tensor_signature(
+ tensor: Optional[torch.Tensor], static_shape_suffix: Optional[Tuple[int, ...]], dynamic_stride_dims: Tuple[int, ...] = ()
+ ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride()))
+ return static_shape_suffix, stride_signature, tensor.dtype
+
+ cache_key = (
+ use_full_dynamic,
+ use_dynamic_m,
+ *(dynamic_tensor_signature(a_tensor) if use_full_dynamic else dynamic_compact_signature(a_tensor) if use_dynamic_m else tensor_signature(a_tensor)),
+ *(dynamic_tensor_signature(b_tensor) if use_full_dynamic else tensor_signature(b_tensor)),
+ *(dynamic_tensor_signature(c_tensor) if use_full_dynamic else dynamic_compact_signature(c_tensor) if use_dynamic_m else tensor_signature(c_tensor)),
+ c_dtype,
+ d_dtype,
+ *(
+ dynamic_tensor_signature(sfa_tensor)
+ if use_full_dynamic
+ else (
+ dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], sfa_tensor.shape[5]), dynamic_stride_dims=(5,))
+ if use_dynamic_m
+ else tensor_signature(sfa_tensor)
+ )
+ ),
+ *(dynamic_tensor_signature(sfb_tensor) if use_full_dynamic else tensor_signature(sfb_tensor)),
+ *(
+ dynamic_tensor_signature(prob_tensor)
+ if use_full_dynamic
+ else dynamic_compact_signature(prob_tensor) if use_dynamic_m else tensor_signature(prob_tensor)
+ ),
+ norm_const_tensor.shape if norm_const_tensor is not None else None,
+ norm_const_tensor.stride() if norm_const_tensor is not None else None,
+ norm_const_tensor.dtype if norm_const_tensor is not None else None,
+ alpha,
+ acc_dtype,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ c_major,
+ sf_vec_size,
+ vector_f32,
+ )
+
+ op = _cache_of_GemmSreluSm100Objects.get(cache_key)
+ if op is None:
+ op = GemmSreluSm100(
+ sample_a=a_tensor,
+ sample_b=b_tensor,
+ sample_c=c_tensor,
+ sample_d=d_tensor,
+ sample_sfa=sfa_tensor,
+ sample_sfb=sfb_tensor,
+ sample_prob=prob_tensor,
+ sample_sfd=sfd_tensor,
+ sample_amax=amax_tensor,
+ sample_norm_const=norm_const_tensor,
+ alpha=alpha,
+ acc_dtype=acc_dtype,
+ mma_tiler_mn=mma_tiler_mn,
+ cluster_shape_mn=cluster_shape_mn,
+ sf_vec_size=sf_vec_size,
+ vector_f32=vector_f32,
+ )
+ assert op.check_support(), "Unsupported testcase"
+ op.compile()
+ _cache_of_GemmSreluSm100Objects[cache_key] = op
+
+ op.execute(
+ a_tensor=a_tensor,
+ b_tensor=b_tensor,
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ sfa_tensor=sfa_tensor,
+ sfb_tensor=sfb_tensor,
+ prob_tensor=prob_tensor,
+ sfd_tensor=sfd_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ alpha=alpha,
+ current_stream=stream,
+ )
+
+ return TupleDict(
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ amax_tensor=amax_tensor,
+ sfd_tensor=sfd_tensor,
+ )
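+
+
+# Usage sketch (editorial, illustrative only — the shapes, dtypes and scale factor
+# layouts below are assumptions for demonstration, not a tested configuration):
+#
+#   a    = <FP8 operand of shape (m, k, l) with the k dimension contiguous>
+#   b    = <FP8 operand of shape (n, k, l) with the k dimension contiguous>
+#   sfa  = <scale factors of shape (32, 4, m // 128, 4, ceil_div(ceil_div(k, sf_vec), 4), l)>
+#   sfb  = <scale factors of shape (32, 4, n // 128, 4, ceil_div(ceil_div(k, sf_vec), 4), l)>
+#   prob = torch.rand(m, 1, l, dtype=torch.float32, device="cuda")
+#   out  = gemm_srelu_wrapper_sm100(a, b, sfa, sfb, prob, d_dtype=torch.bfloat16)
+#   d    = out.d_tensor   # out also carries c_tensor, amax_tensor and sfd_tensor
+#                         # (None when not produced); attribute access on TupleDict assumed.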
diff --git a/python/cudnn/gemm_srelu/dense_blockscaled_gemm_persistent_srelu_quant.py b/python/cudnn/gemm_srelu/dense_blockscaled_gemm_persistent_srelu_quant.py
new file mode 100644
index 00000000..cac870d6
--- /dev/null
+++ b/python/cudnn/gemm_srelu/dense_blockscaled_gemm_persistent_srelu_quant.py
@@ -0,0 +1,2094 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from typing import Type, Tuple, Union, Optional
+
+import cuda.bindings.driver as cuda
+import torch
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu import cpasync, tcgen05
+import cutlass.utils as utils
+import cutlass.pipeline as pipeline
+from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
+import cutlass.utils.blackwell_helpers as sm100_utils
+import cutlass.utils.blockscaled_layout as blockscaled_utils
+
+import cutlass.cute.math as math
+from cutlass.cute.typing import Float32
+
+
+def get_divisibility(dtype) -> int:
+ """
+ Get the divisibility value based on the data type.
+ """
+ if dtype in [cutlass.Float4E2M1FN]:
+ divisibility = 32
+ elif dtype in [cutlass.Float8E4M3FN, cutlass.Float8E5M2, cutlass.Float8E8M0FNU]:
+ divisibility = 16
+ elif dtype in [cutlass.Float16, cutlass.BFloat16]:
+ divisibility = 8
+ elif dtype == cutlass.Float32:
+ divisibility = 4
+ else:
+ raise ValueError(f"Unsupported data type: {dtype}")
+ return divisibility
+
+
+class Sm100BlockScaledPersistentDenseGemmKernel:
+ """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
+ and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
+
+ :param sf_vec_size: Scalefactor vector size.
+ :type sf_vec_size: int
+ :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
+ :type mma_tiler_mn: Tuple[int, int]
+ :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
+ :type cluster_shape_mn: Tuple[int, int]
+
+ :note: In the current version, the A and B tensors must have the same data type
+ - e.g., Float8E4M3FN for A and Float8E5M2 for B is not supported
+
+ :note: Supported combinations of A/B data types, SF data types and SF vector size:
+ - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32
+ - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16
+
+ :note: Supported accumulator data types:
+ - Float32
+
+ :note: Supported C data types:
+ - Float32
+ - Float16/BFloat16
+ - Float8E4M3FN/Float8E5M2
+ # {$nv-internal-release begin}
+ # Note: We don't have SFD generation support in this example for now, so Float4E2M1FN output is only for internal testing and will not be released.
+ - Float4E2M1FN
+ # {$nv-internal-release end}
+
+ :note: Constraints:
+ - MMA tiler M must be 128 or 256 (use_2cta_instrs)
+ - MMA tiler N must be 64/128/192/256
+ - Cluster shape M must be multiple of 2 if Mma tiler M is 256
+ - Cluster shape M/N must be positive and power of 2, total cluster size <= 16
+ - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors
+
+ Example:
+ >>> gemm = Sm100BlockScaledPersistentDenseGemmKernel(
+ ... sf_vec_size=16,
+ ... mma_tiler_mn=(256, 128),
+ ... cluster_shape_mn=(2, 1)
+ ... )
+ >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, d_tensor, prob_tensor, amax_tensor, sfd_tensor, norm_const_tensor, alpha, max_active_clusters, stream)
+ """
+
+ def __init__(
+ self,
+ sf_vec_size: int,
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ vector_f32: bool,
+ ):
+ """Initializes the configuration for a Blackwell dense GEMM kernel.
+
+ This configuration includes several key aspects:
+
+ 1. MMA Instruction Settings (tcgen05):
+ - acc_dtype: Data types for MMA accumulator, always set to Float32
+ - sf_vec_size: Scalefactor A/B vector size.
+ - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler.
+
+ 2. Cluster Shape:
+ - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster.
+
+ :param sf_vec_size: Scalefactor vector size.
+ :type sf_vec_size: int
+ :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
+ :type mma_tiler_mn: Tuple[int, int]
+        :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster.
+        :type cluster_shape_mn: Tuple[int, int]
+        :param vector_f32: Whether the SFD output is generated using 2x Float32 (vectorized) values.
+        :type vector_f32: bool
+        """
+
+ self.acc_dtype = cutlass.Float32
+ self.sf_vec_size = sf_vec_size
+ self.use_2cta_instrs = mma_tiler_mn[0] == 256
+ self.cluster_shape_mn = cluster_shape_mn
+        # The K dimension is filled in later by _setup_attributes
+ self.mma_tiler = (*mma_tiler_mn, 1)
+
+ self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
+
+ self.occupancy = 1
+ # Set specialized warp ids
+ self.epilog_warp_id = (
+ 0,
+ 1,
+ 2,
+ 3,
+ )
+ self.mma_warp_id = 4
+ self.tma_warp_id = 5
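+        # Six warps per CTA: warps 0-3 run the epilogue, warp 4 issues tcgen05 MMAs,
+        # and warp 5 issues TMA loads, giving 6 * 32 = 192 threads below.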
+ self.threads_per_cta = 32 * len((self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id))
+ # Set barrier id for epilogue sync and tmem ptr sync
+ self.epilog_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=1,
+ num_threads=32 * len(self.epilog_warp_id),
+ )
+ self.tmem_alloc_barrier = pipeline.NamedBarrier(
+ barrier_id=2,
+ num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)),
+ )
+ self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
+ SM100_TMEM_CAPACITY_COLUMNS = 512
+ self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
+
+        # Whether SFD generation uses 2x Float32 (vectorized) values
+ self.vector_f32 = vector_f32
+
+ # Amax reduction configuration
+ self.num_epilog_warps = len(self.epilog_warp_id)
+
+ def _setup_attributes(self):
+ """Set up configurations that are dependent on GEMM inputs
+
+ This method configures various attributes based on the input tensor properties
+ (data types, leading dimensions) and kernel settings:
+ - Configuring tiled MMA
+ - Computing MMA/cluster/tile shapes
+ - Computing cluster layout
+ - Computing multicast CTAs for A/B/SFA/SFB
+ - Computing epilogue subtile
+ - Setting up A/B/SFA/SFB/C stage counts in shared memory
+ - Computing A/B/SFA/SFB/C shared memory layout
+ """
+ # Compute mma instruction shapes
+ # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K)
+ self.mma_inst_shape_mn = (
+ self.mma_tiler[0],
+ self.mma_tiler[1],
+ )
+ # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K)
+ self.mma_inst_shape_mn_sfb = (
+ self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
+ cute.round_up(self.mma_inst_shape_mn[1], 128),
+ )
+
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+
+ # Compute mma/cluster/tile shapes
+ mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
+ mma_inst_tile_k = 4
+ self.mma_tiler = (
+ self.mma_inst_shape_mn[0],
+ self.mma_inst_shape_mn[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.mma_tiler_sfb = (
+ self.mma_inst_shape_mn_sfb[0],
+ self.mma_inst_shape_mn_sfb[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.cta_tile_shape_mnk = (
+ self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler[1],
+ self.mma_tiler[2],
+ )
+
+ self.mma_tiler_c = (
+ self.mma_inst_shape_mn[0],
+ self.mma_inst_shape_mn[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.cta_tile_shape_mnk_c = (
+ self.mma_tiler_c[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler_c[1],
+ self.mma_tiler_c[2],
+ )
+
+ # Compute cluster layout
+ self.cluster_layout_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma.thr_id.shape,),
+ )
+ self.cluster_layout_sfb_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma_sfb.thr_id.shape,),
+ )
+
+ # Compute number of multicast CTAs for A/B
+ self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+ self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
+ self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
+ self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
+
+ # Always use subtile (128,32)
+ self.epi_tile = (cute.make_layout(128), cute.make_layout(32))
+ self.epi_tile_cnt = (
+ self.cta_tile_shape_mnk[0] // cute.size(self.epi_tile[0]),
+ self.cta_tile_shape_mnk[1] // cute.size(self.epi_tile[1]),
+ )
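+        # For example, a (128, 128) CTA tile processed with the fixed (128, 32) subtile
+        # yields epi_tile_cnt = (1, 4); the epilogue loops over these subtiles one at a time.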
+ # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory
+ self.num_acc_stage, self.num_ab_stage, self.num_c_stage, self.num_d_stage = self._compute_stages(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.b_dtype,
+ self.epi_tile,
+ self.c_dtype,
+ self.c_layout,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.d_dtype,
+ self.d_layout,
+ self.smem_capacity,
+ self.occupancy,
+ )
+
+ # Compute A/B/SFA/SFB/C shared memory layout
+ self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.num_ab_stage,
+ )
+ self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
+ tiled_mma,
+ self.mma_tiler,
+ self.b_dtype,
+ self.num_ab_stage,
+ )
+ self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.c_dtype,
+ self.c_layout,
+ self.epi_tile,
+ self.num_c_stage,
+ )
+ self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.d_dtype,
+ self.d_layout,
+ self.epi_tile,
+ self.num_d_stage,
+ )
+
+ @cute.jit
+ def __call__(
+ self,
+ a_tensor: cute.Tensor,
+ b_tensor: cute.Tensor,
+ sfa_tensor: cute.Tensor,
+ sfb_tensor: cute.Tensor,
+ c_tensor: cute.Tensor,
+ d_tensor: cute.Tensor,
+ prob_tensor: cute.Tensor,
+ amax_tensor: Optional[cute.Tensor],
+ sfd_tensor: Optional[cute.Tensor],
+ norm_const_tensor: Optional[cute.Tensor],
+ alpha: cutlass.Float32,
+ max_active_clusters: cutlass.Constexpr,
+ stream: cuda.CUstream,
+ epilogue_op: cutlass.Constexpr = lambda x: x,
+ ):
+ """Execute the GEMM operation in steps:
+ - Setup static attributes before smem/grid/tma computation
+ - Setup TMA load/store atoms and tensors
+ - Compute grid size with regard to hardware constraints
+ - Define shared storage for kernel
+ - Launch the kernel synchronously
+
+ :param a_tensor: Input tensor A
+ :type a_tensor: cute.Tensor
+ :param b_tensor: Input tensor B
+ :type b_tensor: cute.Tensor
+ :param sfa_tensor: Scale factor tensor A
+ :type sfa_tensor: cute.Tensor
+ :param sfb_tensor: Scale factor tensor B
+ :type sfb_tensor: cute.Tensor
+        :param c_tensor: Output tensor C
+        :type c_tensor: cute.Tensor
+        :param d_tensor: Output tensor D
+        :type d_tensor: cute.Tensor
+        :param prob_tensor: Per-row scaling (probability) tensor applied to the accumulator in the epilogue
+        :type prob_tensor: cute.Tensor
+        :param amax_tensor: Optional output tensor for the absolute maximum value
+        :type amax_tensor: Optional[cute.Tensor]
+        :param sfd_tensor: Optional scale factor tensor for output D
+        :type sfd_tensor: Optional[cute.Tensor]
+        :param norm_const_tensor: Optional normalization constant tensor for quantization
+        :type norm_const_tensor: Optional[cute.Tensor]
+        :param alpha: Scalar multiplier applied to the accumulator before the epilogue conversions
+        :type alpha: cutlass.Float32
+        :param max_active_clusters: Maximum number of active clusters
+        :type max_active_clusters: cutlass.Constexpr
+        :param stream: CUDA stream for asynchronous execution
+        :type stream: cuda.CUstream
+ :param epilogue_op: Optional elementwise lambda function to apply to the output tensor
+ :type epilogue_op: cutlass.Constexpr
+ :raises TypeError: If input data types are incompatible with the MMA instruction.
+ """
+ # Setup static attributes before smem/grid/tma computation
+ self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type
+ self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type
+ self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type
+ self.c_dtype: Type[cutlass.Numeric] = c_tensor.element_type
+ self.d_dtype: Type[cutlass.Numeric] = d_tensor.element_type
+ self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode()
+ self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode()
+ self.c_layout = utils.LayoutEnum.from_tensor(c_tensor)
+ self.d_layout = utils.LayoutEnum.from_tensor(d_tensor)
+
+ # Check if input data types are compatible with MMA instruction
+ if cutlass.const_expr(self.a_dtype != self.b_dtype):
+ raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
+
+        # Set up attributes that depend on the GEMM inputs
+ self._setup_attributes()
+
+        # Set up the SFA/SFB tensors by tiling the A/B shapes into the scale factor atom layout
+ # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL)
+ sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, self.sf_vec_size)
+ sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout)
+
+ # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL)
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b_tensor.shape, self.sf_vec_size)
+ sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout)
+
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+
+ # For 2CTA blockscaled kernels, SFB needs to be replicated across peer CTAs. # {$nv-internal-release}
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+ atom_thr_size = cute.size(tiled_mma.thr_id.shape)
+
+ # Setup TMA load for A
+ a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
+ tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
+ a_op,
+ a_tensor,
+ a_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ # Setup TMA load for B
+ b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
+ b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
+ tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
+ b_op,
+ b_tensor,
+ b_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ # Setup TMA load for SFA
+ sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
+ sfa_op,
+ sfa_tensor,
+ sfa_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ internal_type=cutlass.Int16,
+ )
+
+ # Setup TMA load for SFB
+ sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
+ sfb_op,
+ sfb_tensor,
+ sfb_smem_layout,
+ self.mma_tiler_sfb,
+ tiled_mma_sfb,
+ self.cluster_layout_sfb_vmnk.shape,
+ internal_type=cutlass.Int16,
+ )
+
+ # {$nv-internal-release begin}
+ # This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF))
+ # logical blocks for SFB when cta_tile_shape_n=192.
+ # {$nv-internal-release end}
+ if cutlass.const_expr(self.cta_tile_shape_mnk_c[1] == 192):
+ x = tma_tensor_sfb.stride[0][1]
+ y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4)
+
+ new_shape = (
+ (tma_tensor_sfb.shape[0][0], ((2, 2), y)),
+ tma_tensor_sfb.shape[1],
+ tma_tensor_sfb.shape[2],
+ )
+            # ScaledBasis strides require the scalar on the left (3 * x, i.e. reflected multiplication, instead of x * 3)
+ x_times_3 = 3 * x
+ new_stride = (
+ (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)),
+ tma_tensor_sfb.stride[1],
+ tma_tensor_sfb.stride[2],
+ )
+ tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride)
+ tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout)
+
+ a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
+ b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+ sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
+ sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
+ self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size
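+        # Total bytes one mainloop stage expects from TMA (A + B + SFA + SFB, scaled by the
+        # number of CTAs in the MMA atom); used as the barrier tx_count when the AB pipeline
+        # is created inside the kernel.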
+
+ # Setup TMA store for C
+ epi_c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
+ tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ c_tensor,
+ epi_c_smem_layout,
+ self.epi_tile,
+ )
+ epi_d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ d_tensor,
+ epi_d_smem_layout,
+ self.epi_tile,
+ )
+
+ # Compute grid size
+ self.tile_sched_params, grid = self._compute_grid(
+ c_tensor,
+ self.cta_tile_shape_mnk_c,
+ self.cluster_shape_mn,
+ max_active_clusters,
+ )
+
+ self.buffer_align_bytes = 1024
+
+ self.generate_sfd = sfd_tensor is not None and norm_const_tensor is not None
+ if cutlass.const_expr(self.generate_sfd):
+ sfd_layout = blockscaled_utils.tile_atom_to_shape_SF(c_tensor.shape, self.sf_vec_size)
+ sfd_tensor = cute.make_tensor(sfd_tensor.iterator, sfd_layout)
+
+ self.generate_amax = amax_tensor is not None
+
+ # Define shared storage for kernel
+ @cute.struct
+ class SharedStorage:
+ ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
+ ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
+ acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
+ acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
+ tmem_dealloc_mbar_ptr: cutlass.Int64
+ tmem_holding_buf: cutlass.Int32
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
+ sC: cute.struct.Align[
+ cute.struct.MemRange[
+ self.c_dtype,
+ cute.cosize(self.c_smem_layout_staged.outer),
+ ],
+ self.buffer_align_bytes,
+ ]
+ sD: cute.struct.Align[
+ cute.struct.MemRange[
+ self.d_dtype,
+ cute.cosize(self.d_smem_layout_staged.outer),
+ ],
+ self.buffer_align_bytes,
+ ]
+ # (MMA, MMA_M, MMA_K, STAGE)
+ sA: cute.struct.Align[
+ cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ # (MMA, MMA_N, MMA_K, STAGE)
+ sB: cute.struct.Align[
+ cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ # (MMA, MMA_M, MMA_K, STAGE)
+ sSFA: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+ # (MMA, MMA_N, MMA_K, STAGE)
+ sSFB: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+            # Amax reduction shared memory (one FP32 per epilogue warp).
+            # The buffer is only a few words but reuses the common buffer alignment for simplicity.
+ sAmax: cute.struct.Align[
+ cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps],
+ self.buffer_align_bytes,
+ ]
+
+ self.shared_storage = SharedStorage
+
+ # Launch the kernel synchronously
+ self.kernel(
+ tiled_mma,
+ tiled_mma_sfb,
+ tma_atom_a,
+ tma_tensor_a,
+ tma_atom_b,
+ tma_tensor_b,
+ tma_atom_sfa,
+ tma_tensor_sfa,
+ tma_atom_sfb,
+ tma_tensor_sfb,
+ tma_atom_c,
+ tma_tensor_c,
+ tma_atom_d,
+ tma_tensor_d,
+ prob_tensor,
+ amax_tensor,
+ sfd_tensor,
+ norm_const_tensor,
+ self.cluster_layout_vmnk,
+ self.cluster_layout_sfb_vmnk,
+ self.a_smem_layout_staged,
+ self.b_smem_layout_staged,
+ self.sfa_smem_layout_staged,
+ self.sfb_smem_layout_staged,
+ self.c_smem_layout_staged,
+ self.d_smem_layout_staged,
+ self.epi_tile,
+ self.tile_sched_params,
+ epilogue_op,
+ alpha,
+ ).launch(
+ grid=grid,
+ block=[self.threads_per_cta, 1, 1],
+ cluster=(*self.cluster_shape_mn, 1),
+ stream=stream,
+ )
+ return
+
+ # GPU device kernel
+ @cute.kernel
+ def kernel(
+ self,
+ tiled_mma: cute.TiledMma,
+ tiled_mma_sfb: cute.TiledMma,
+ tma_atom_a: cute.CopyAtom,
+ mA_mkl: cute.Tensor,
+ tma_atom_b: cute.CopyAtom,
+ mB_nkl: cute.Tensor,
+ tma_atom_sfa: cute.CopyAtom,
+ mSFA_mkl: cute.Tensor,
+ tma_atom_sfb: cute.CopyAtom,
+ mSFB_nkl: cute.Tensor,
+ tma_atom_c: cute.CopyAtom,
+ mC_mnl: cute.Tensor,
+ tma_atom_d: cute.CopyAtom,
+ mD_mnl: cute.Tensor,
+ mProb_mnl: cute.Tensor,
+ mAmax_tensor: Optional[cute.Tensor],
+ mSFD_mnl: Optional[cute.Tensor],
+ norm_const_tensor: Optional[cute.Tensor],
+ cluster_layout_vmnk: cute.Layout,
+ cluster_layout_sfb_vmnk: cute.Layout,
+ a_smem_layout_staged: cute.ComposedLayout,
+ b_smem_layout_staged: cute.ComposedLayout,
+ sfa_smem_layout_staged: cute.Layout,
+ sfb_smem_layout_staged: cute.Layout,
+ c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
+ d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
+ epi_tile: cute.Tile,
+ tile_sched_params: utils.PersistentTileSchedulerParams,
+ epilogue_op: cutlass.Constexpr,
+ alpha: cutlass.Float32,
+ ):
+ """
+ GPU device kernel performing the Persistent batched GEMM computation.
+ """
+ warp_idx = cute.arch.warp_idx()
+ warp_idx = cute.arch.make_warp_uniform(warp_idx)
+
+ #
+ # Prefetch tma desc
+ #
+ if warp_idx == self.tma_warp_id:
+ cpasync.prefetch_descriptor(tma_atom_a)
+ cpasync.prefetch_descriptor(tma_atom_b)
+ cpasync.prefetch_descriptor(tma_atom_sfa)
+ cpasync.prefetch_descriptor(tma_atom_sfb)
+ cpasync.prefetch_descriptor(tma_atom_c)
+ cpasync.prefetch_descriptor(tma_atom_d)
+
+ use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
+
+ #
+ # Setup cta/thread coordinates
+ #
+ # Coords inside cluster
+ bidx, bidy, bidz = cute.arch.block_idx()
+ mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
+ is_leader_cta = mma_tile_coord_v == 0
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
+ block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster)
+ # Coord inside cta
+ tidx, _, _ = cute.arch.thread_idx()
+
+ #
+ # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier
+ #
+ smem = utils.SmemAllocator()
+ storage = smem.allocate(self.shared_storage)
+
+ # Initialize mainloop ab_pipeline (barrier) and states
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer)
+ ab_pipeline = pipeline.PipelineTmaUmma.create(
+ barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
+ num_stages=self.num_ab_stage,
+ producer_group=ab_pipeline_producer_group,
+ consumer_group=ab_pipeline_consumer_group,
+ tx_count=self.num_tma_load_bytes,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ defer_sync=True,
+ )
+
+ # Initialize acc_pipeline (barrier) and states
+ acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
+ acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads)
+ acc_pipeline = pipeline.PipelineUmmaAsync.create(
+ barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
+ num_stages=self.num_acc_stage,
+ producer_group=acc_pipeline_producer_group,
+ consumer_group=acc_pipeline_consumer_group,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ defer_sync=True,
+ )
+
+ # Tensor memory dealloc barrier init
+ tmem = utils.TmemAllocator(
+ storage.tmem_holding_buf,
+ barrier_for_retrieve=self.tmem_alloc_barrier,
+ allocator_warp_id=self.epilog_warp_id[0],
+ is_two_cta=use_2cta_instrs,
+ two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
+ )
+
+ # Cluster arrive after barrier init
+ pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
+
+ #
+ # Setup smem tensor A/B/SFA/SFB/C
+ #
+ # (MMA, MMA_M, MMA_K, STAGE)
+ sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
+ # (MMA, MMA_N, MMA_K, STAGE)
+ sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
+ # (MMA, MMA_M, MMA_K, STAGE)
+ sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
+ # (MMA, MMA_N, MMA_K, STAGE)
+ sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
+ sC = storage.sC.get_tensor(
+ c_smem_layout_staged.outer,
+ swizzle=c_smem_layout_staged.inner,
+ dtype=self.c_dtype,
+ )
+ # (EPI_TILE_M, EPI_TILE_N, STAGE)
+ sD = storage.sD.get_tensor(
+ d_smem_layout_staged.outer,
+ swizzle=d_smem_layout_staged.inner,
+ dtype=self.d_dtype,
+ )
+
+        # Shared memory for amax reduction (one FP32 per epilogue warp).
+        # Simple 1D layout; the buffer is allocated even when no amax is generated,
+        # since the overhead is minimal and it keeps the code simple.
+ amax_layout = cute.make_layout((self.num_epilog_warps,))
+ sAmax = storage.sAmax.get_tensor(amax_layout)
+
+ #
+ # Compute multicast mask for A/B/SFA/SFB buffer full
+ #
+ a_full_mcast_mask = None
+ b_full_mcast_mask = None
+ sfa_full_mcast_mask = None
+ sfb_full_mcast_mask = None
+ if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
+ a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1)
+ sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1)
+
+ #
+ # Local_tile partition global tensors
+ #
+ # (bM, bK, RestM, RestK, RestL)
+ gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ # (bN, bK, RestN, RestK, RestL)
+ gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
+ # (bM, bK, RestM, RestK, RestL)
+ gSFA_mkl = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ # (bN, bK, RestN, RestK, RestL)
+ gSFB_nkl = cute.local_tile(
+ mSFB_nkl,
+ cute.slice_(self.mma_tiler_sfb, (0, None, None)),
+ (None, None, None),
+ )
+ # (bM, bN, RestM, RestN, RestL)
+ gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler_c, (None, None, 0)), (None, None, None))
+ # (bM, bN, RestM, RestN, RestL)
+ gD_mnl = cute.local_tile(
+ mD_mnl,
+ cute.slice_(self.mma_tiler_c, (None, None, 0)),
+ (None, None, None),
+ )
+ k_tile_cnt = cute.size(gA_mkl, mode=[3])
+
+ #
+ # Partition global tensor for TiledMMA_A/B/C
+ #
+ thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
+ thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v)
+ # (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
+ tCgA = thr_mma.partition_A(gA_mkl)
+ # (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
+ tCgB = thr_mma.partition_B(gB_nkl)
+ # (MMA, MMA_M, MMA_K, RestM, RestK, RestL)
+ tCgSFA = thr_mma.partition_A(gSFA_mkl)
+ # (MMA, MMA_N, MMA_K, RestN, RestK, RestL)
+ tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
+ # (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
+ tCgC = thr_mma.partition_C(gC_mnl)
+ # (MMA, MMA_M, MMA_N, RestM, RestN, RestL)
+ tCgD = thr_mma.partition_C(gD_mnl)
+
+ #
+ # Partition global/shared tensor for TMA load A/B
+ #
+ # TMA load A partition_S/D
+ a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
+ # ((atom_v, rest_v), STAGE)
+ # ((atom_v, rest_v), RestM, RestK, RestL)
+ tAsA, tAgA = cpasync.tma_partition(
+ tma_atom_a,
+ block_in_cluster_coord_vmnk[2],
+ a_cta_layout,
+ cute.group_modes(sA, 0, 3),
+ cute.group_modes(tCgA, 0, 3),
+ )
+ # TMA load B partition_S/D
+ b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
+ # ((atom_v, rest_v), STAGE)
+ # ((atom_v, rest_v), RestN, RestK, RestL)
+ tBsB, tBgB = cpasync.tma_partition(
+ tma_atom_b,
+ block_in_cluster_coord_vmnk[1],
+ b_cta_layout,
+ cute.group_modes(sB, 0, 3),
+ cute.group_modes(tCgB, 0, 3),
+ )
+
+ # TMA load SFA partition_S/D
+ sfa_cta_layout = a_cta_layout
+ # ((atom_v, rest_v), STAGE)
+ # ((atom_v, rest_v), RestM, RestK, RestL)
+ tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition(
+ tma_atom_sfa,
+ block_in_cluster_coord_vmnk[2],
+ sfa_cta_layout,
+ cute.group_modes(sSFA, 0, 3),
+ cute.group_modes(tCgSFA, 0, 3),
+ )
+ tAsSFA = cute.filter_zeros(tAsSFA)
+ tAgSFA = cute.filter_zeros(tAgSFA)
+
+ # TMA load SFB partition_S/D
+ sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
+ # ((atom_v, rest_v), STAGE)
+ # ((atom_v, rest_v), RestN, RestK, RestL)
+ tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition(
+ tma_atom_sfb,
+ block_in_cluster_coord_sfb_vmnk[1],
+ sfb_cta_layout,
+ cute.group_modes(sSFB, 0, 3),
+ cute.group_modes(tCgSFB, 0, 3),
+ )
+ tBsSFB = cute.filter_zeros(tBsSFB)
+ tBgSFB = cute.filter_zeros(tBgSFB)
+
+ #
+ # Partition shared/tensor memory tensor for TiledMMA_A/B/C
+ #
+ # (MMA, MMA_M, MMA_K, STAGE)
+ tCrA = tiled_mma.make_fragment_A(sA)
+ # (MMA, MMA_N, MMA_K, STAGE)
+ tCrB = tiled_mma.make_fragment_B(sB)
+ # (MMA, MMA_M, MMA_N)
+ acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
+ # (MMA, MMA_M, MMA_N, STAGE)
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
+
+ #
+ # Cluster wait before tensor memory alloc
+ #
+ pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
+
+ #
+ # Specialized TMA load warp
+ #
+ if warp_idx == self.tma_warp_id:
+ #
+ # Persistent tile scheduling loop
+ #
+ tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim())
+ work_tile = tile_sched.initial_work_tile_info()
+
+ ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage)
+
+ while work_tile.is_valid_tile:
+ # Get tile coord from tile scheduler
+ cur_tile_coord = work_tile.tile_idx
+ mma_tile_coord_mnl = (
+ cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
+ cur_tile_coord[1],
+ cur_tile_coord[2],
+ )
+
+ #
+ # Slice to per mma tile index
+ #
+ # ((atom_v, rest_v), RestK)
+ tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]
+ # ((atom_v, rest_v), RestK)
+ tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])]
+
+ # ((atom_v, rest_v), RestK)
+ tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])]
+
+ # Apply SFB slicing hack when cta_tile_shape_n=64 # {$nv-internal-release}
+ slice_n = mma_tile_coord_mnl[1]
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ slice_n = mma_tile_coord_mnl[1] // 2
+ # ((atom_v, rest_v), RestK)
+ tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])]
+
+                # Peek (try_wait) AB buffer empty for the first k_tile
+ ab_producer_state.reset_count()
+ peek_ab_empty_status = cutlass.Boolean(1)
+ if ab_producer_state.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+ #
+ # Tma load loop
+ #
+ for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+ # Conditionally wait for AB buffer empty
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
+
+ # TMA load A/B/SFA/SFB
+ cute.copy(
+ tma_atom_a,
+ tAgA_slice[(None, ab_producer_state.count)],
+ tAsA[(None, ab_producer_state.index)],
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
+ mcast_mask=a_full_mcast_mask,
+ )
+ cute.copy(
+ tma_atom_b,
+ tBgB_slice[(None, ab_producer_state.count)],
+ tBsB[(None, ab_producer_state.index)],
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
+ mcast_mask=b_full_mcast_mask,
+ )
+ cute.copy(
+ tma_atom_sfa,
+ tAgSFA_slice[(None, ab_producer_state.count)],
+ tAsSFA[(None, ab_producer_state.index)],
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
+ mcast_mask=sfa_full_mcast_mask,
+ )
+ cute.copy(
+ tma_atom_sfb,
+ tBgSFB_slice[(None, ab_producer_state.count)],
+ tBsSFB[(None, ab_producer_state.index)],
+ tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
+ mcast_mask=sfb_full_mcast_mask,
+ )
+
+                    # Peek (try_wait) AB buffer empty for the next k_tile (k_tile + 1)
+ ab_producer_state.advance()
+ peek_ab_empty_status = cutlass.Boolean(1)
+ if ab_producer_state.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+
+ #
+ # Advance to next tile
+ #
+ tile_sched.advance_to_next_work()
+ work_tile = tile_sched.get_current_work()
+
+ #
+ # Wait A/B buffer empty
+ #
+ ab_pipeline.producer_tail(ab_producer_state)
+
+ #
+ # Specialized MMA warp
+ #
+ if warp_idx == self.mma_warp_id:
+ #
+ # Bar sync for retrieve tensor memory ptr from shared mem
+ #
+ tmem.wait_for_alloc()
+
+ #
+ # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor
+ #
+ acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ # Make accumulator tmem tensor
+ # (MMA, MMA_M, MMA_N, STAGE)
+ tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
+
+ # Make SFA tmem tensor
+ sfa_tmem_ptr = cute.recast_ptr(
+ acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base),
+ dtype=self.sf_dtype,
+ )
+ # (MMA, MMA_M, MMA_K)
+ tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
+
+ # Make SFB tmem tensor
+ sfb_tmem_ptr = cute.recast_ptr(
+ acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA),
+ dtype=self.sf_dtype,
+ )
+ # (MMA, MMA_N, MMA_K)
+ tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
+ #
+ # Partition for S2T copy of SFA/SFB
+ #
+ (
+ tiled_copy_s2t_sfa,
+ tCsSFA_compact_s2t,
+ tCtSFA_compact_s2t,
+ ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA)
+ (
+ tiled_copy_s2t_sfb,
+ tCsSFB_compact_s2t,
+ tCtSFB_compact_s2t,
+ ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
+
+ #
+ # Persistent tile scheduling loop
+ #
+ tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim())
+ work_tile = tile_sched.initial_work_tile_info()
+
+ ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage)
+ acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
+
+ while work_tile.is_valid_tile:
+ # Get tile coord from tile scheduler
+ cur_tile_coord = work_tile.tile_idx
+ mma_tile_coord_mnl = (
+ cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
+ cur_tile_coord[1],
+ cur_tile_coord[2],
+ )
+
+ # Set tensor memory buffer for current tile
+ # (MMA, MMA_M, MMA_N)
+ tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
+
+ # Peek (try_wait) AB buffer full for k_tile = 0
+ ab_consumer_state.reset_count()
+ peek_ab_full_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
+
+ #
+ # Wait for accumulator buffer empty
+ #
+ if is_leader_cta:
+ acc_pipeline.producer_acquire(acc_producer_state)
+
+ # Apply TMEM pointer offset hack when cta_tile_shape_n=192 or cta_tile_shape_n=64 # {$nv-internal-release}
+ tCtSFB_mma = tCtSFB
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
+ # If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB)
+ offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+ elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ # Move in increments of 64 columns of SFB
+ offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+
+ #
+ # Reset the ACCUMULATE field for each tile
+ #
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
+
+ #
+ # Mma mainloop
+ #
+ for k_tile in range(k_tile_cnt):
+ if is_leader_cta:
+ # Conditionally wait for AB buffer full
+ ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
+
+ # Copy SFA/SFB from smem to tmem
+ s2t_stage_coord = (
+ None,
+ None,
+ None,
+ None,
+ ab_consumer_state.index,
+ )
+ tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord]
+ tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord]
+ cute.copy(
+ tiled_copy_s2t_sfa,
+ tCsSFA_compact_s2t_staged,
+ tCtSFA_compact_s2t,
+ )
+ cute.copy(
+ tiled_copy_s2t_sfb,
+ tCsSFB_compact_s2t_staged,
+ tCtSFB_compact_s2t,
+ )
+
+ # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB
+ num_kblocks = cute.size(tCrA, mode=[2])
+ for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
+ kblock_coord = (
+ None,
+ None,
+ kblock_idx,
+ ab_consumer_state.index,
+ )
+
+ # Set SFA/SFB tensor to tiled_mma
+ sf_kblock_coord = (None, None, kblock_idx)
+ tiled_mma.set(
+ tcgen05.Field.SFA,
+ tCtSFA[sf_kblock_coord].iterator,
+ )
+ tiled_mma.set(
+ tcgen05.Field.SFB,
+ tCtSFB_mma[sf_kblock_coord].iterator,
+ )
+
+ cute.gemm(
+ tiled_mma,
+ tCtAcc,
+ tCrA[kblock_coord],
+ tCrB[kblock_coord],
+ tCtAcc,
+ )
+
+ # Enable accumulate on tCtAcc after first kblock
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
+
+ # Async arrive AB buffer empty
+ ab_pipeline.consumer_release(ab_consumer_state)
+
+ # Peek (try_wait) AB buffer full for k_tile = k_tile + 1
+ ab_consumer_state.advance()
+ peek_ab_full_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt:
+ if is_leader_cta:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
+
+ #
+ # Async arrive accumulator buffer full
+ #
+ if is_leader_cta:
+ acc_pipeline.producer_commit(acc_producer_state)
+ acc_producer_state.advance()
+
+ #
+ # Advance to next tile
+ #
+ tile_sched.advance_to_next_work()
+ work_tile = tile_sched.get_current_work()
+
+ #
+ # Wait for accumulator buffer empty
+ #
+ acc_pipeline.producer_tail(acc_producer_state)
+ #
+ # Specialized epilogue warps
+ #
+ if warp_idx < self.mma_warp_id:
+ #
+ # Alloc tensor memory buffer
+ #
+ tmem.allocate(self.num_tmem_alloc_cols)
+
+ #
+ # Bar sync for retrieve tensor memory ptr from shared memory
+ #
+ tmem.wait_for_alloc()
+
+ #
+ # Retrieving tensor memory ptr and make accumulator tensor
+ #
+ acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ # (MMA, MMA_M, MMA_N, STAGE)
+ tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
+
+ #
+ # Partition for epilogue
+ #
+ epi_tidx = tidx
+ (
+ tiled_copy_t2r,
+ tTR_tAcc_base,
+ tTR_rAcc,
+ ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgD, epi_tile, use_2cta_instrs)
+
+ tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
+ tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rC, epi_tidx, sC)
+ (
+ bSG_sC,
+ bSG_gC_mnl,
+ ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, epi_tile, sC)
+ tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype)
+ _, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rD, epi_tidx, sD)
+ (
+ bSG_sD,
+ bSG_gD_mnl,
+ ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_d, tCgD, epi_tile, sD)
+
+ #
+ # Persistent tile scheduling loop
+ #
+ tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim())
+ work_tile = tile_sched.initial_work_tile_info()
+
+ acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
+
+ # Threads/warps participating in tma store pipeline
+ c_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ 32 * len(self.epilog_warp_id),
+ )
+ c_pipeline = pipeline.PipelineTmaStore.create(
+ num_stages=self.num_c_stage,
+ producer_group=c_producer_group,
+ )
+ d_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ 32 * len(self.epilog_warp_id),
+ )
+ d_pipeline = pipeline.PipelineTmaStore.create(
+ num_stages=self.num_d_stage,
+ producer_group=d_producer_group,
+ )
+
+ while work_tile.is_valid_tile:
+ # Get tile coord from tile scheduler
+ cur_tile_coord = work_tile.tile_idx
+ mma_tile_coord_mnl = (
+ cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
+ cur_tile_coord[1],
+ cur_tile_coord[2],
+ )
+
+ #
+ # Slice to per mma tile index
+ #
+ # ((ATOM_V, REST_V), EPI_M, EPI_N)
+ bSG_gC = bSG_gC_mnl[
+ (
+ None,
+ None,
+ None,
+ *mma_tile_coord_mnl,
+ )
+ ]
+ bSG_gD = bSG_gD_mnl[
+ (
+ None,
+ None,
+ None,
+ *mma_tile_coord_mnl,
+ )
+ ]
+ # Set tensor memory buffer for current tile
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
+ tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)]
+
+ # Initialize thread-local amax accumulator for this tile
+ # Use 0.0 as initial value since we're computing absolute maximum
+ if cutlass.const_expr(self.generate_amax):
+ thread_tile_amax = cutlass.Float32(0.0)
+
+ #
+ # Wait for accumulator buffer full
+ #
+ acc_pipeline.consumer_wait(acc_consumer_state)
+
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
+ bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
+ bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
+
+ #
+ # Store accumulator to global memory in subtiles
+ #
+ subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) ## tTR_tAcc.shape: (((32, 32), 1), 1, 1, (1, 8))
+ num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
+ #
+ # Get PROB
+                # Note: this assumes T2R_M/EPI_M is 1; otherwise the result will be incorrect.
+ #
+ mPosition = cur_tile_coord[0] * self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape) + tidx
+ mProb = mProb_mnl[mPosition, 0, cur_tile_coord[2]]
+ for subtile_idx in cutlass.range(0, subtile_cnt, 1):
+ #
+ # Load accumulator from tensor memory buffer to register
+ #
+ tTR_tAcc_subtile = tTR_tAcc[(None, None, None, subtile_idx)]
+ cute.copy(tiled_copy_t2r, tTR_tAcc_subtile, tTR_rAcc)
+
+ #
+ # Convert to C type
+ #
+ acc_vec = tiled_copy_r2s.retile(tTR_rAcc)
+
+ #
+ # Apply alpha
+ #
+ acc_vec_ = acc_vec.load()
+ acc_vec.store(acc_vec_ * alpha)
+
+ #
+ # Store C to shared memory for bprop
+ #
+ c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage
+                        # Convert the accumulator to c_dtype and store it back to the register tensor
+                        # before staging C in shared memory
+ tRS_rC.store(acc_vec.load().to(self.c_dtype))
+ cute.copy(
+ tiled_copy_r2s,
+ tRS_rC[(None, None, 0)],
+ tRS_sC[(None, None, 0, c_buffer)],
+ )
+
+ # Fence and barrier to make sure shared memory store is visible to TMA store
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.epilog_sync_barrier.arrive_and_wait()
+ #
+ # TMA store C to global memory
+ #
+ if warp_idx == self.epilog_warp_id[0]:
+ cute.copy(
+ tma_atom_c,
+ bSG_sC[(None, c_buffer)], # ((8192, 1), (1, 4)), ((1, 0), (0, 8192))
+ bSG_gC[(None, subtile_idx)], # (((64, 128), 1), (1, 4)) : (((1@0,1@1),0),(0,64@0))
+ )
+                            # Commit the in-flight TMA store and wait until a C stage is free again
+ c_pipeline.producer_commit()
+ c_pipeline.producer_acquire()
+
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ #
+ # Generate amax
+ #
+ if cutlass.const_expr(self.generate_amax):
+ acc_values = acc_vec.load()
+ acc_values = epilogue_op(acc_values)
+ acc_values = acc_values * mProb
+ acc_vec.store(acc_values)
+
+ # Apply element-wise absolute value using math.absf (supports vectors)
+ abs_acc_values_ir = cutlass._mlir.dialects.math.absf(acc_values.ir_value()) # operand (positional)
+ abs_acc_values = type(acc_values)(abs_acc_values_ir, acc_values.shape, acc_values.dtype)
+ subtile_amax = abs_acc_values.reduce(
+ cute.ReductionOp.MAX,
+ cutlass.Float32(0.0),
+ 0, # Use 0.0 as init for abs values
+ )
+ thread_tile_amax = cute.arch.fmax(thread_tile_amax, subtile_amax)
+
+ #
+ # Generate sfd
+ #
+ if cutlass.const_expr(self.generate_sfd):
+ cute.printf("SFD not implemented\n")
+ else:
+ #
+ # Convert to D type directly
+ #
+ acc_vec = tiled_copy_r2s.retile(acc_vec).load()
+ tRS_rD.store(acc_vec.to(self.d_dtype))
+
+ #
+ # Store D to shared memory
+ #
+ d_buffer = (num_prev_subtiles + subtile_idx) % self.num_d_stage
+ cute.copy(
+ tiled_copy_r2s,
+ tRS_rD,
+ tRS_sD[(None, None, None, d_buffer)],
+ )
+ # Fence and barrier to make sure shared memory store is visible to TMA store
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.epilog_sync_barrier.arrive_and_wait()
+ #
+ # TMA store D to global memory
+ #
+ if warp_idx == self.epilog_warp_id[0]:
+ cute.copy(
+ tma_atom_d,
+ bSG_sD[(None, d_buffer)],
+ bSG_gD[(None, subtile_idx)],
+ )
+                                # Commit the in-flight TMA store and wait until a D stage is free again
+ d_pipeline.producer_commit()
+ d_pipeline.producer_acquire()
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ # Perform amax reduction after all subtiles are processed
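+                    # Three-level reduction: each thread keeps a running max over its subtiles,
+                    # each warp reduces via shuffle and writes one slot to sAmax, and a single
+                    # lane finally combines the warp results and issues a global atomic max.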
+ if cutlass.const_expr(self.generate_amax):
+ # Warp-level reduction using wrapper function
+ warp_amax = cute.arch.warp_redux_sync(
+ value=thread_tile_amax,
+ kind="fmax",
+ mask_and_clamp=0xFFFFFFFF,
+ nan=True,
+ )
+ # Each epilogue warp's lane 0 writes warp amax to shared memory
+ if cute.arch.lane_idx() == 0:
+ sAmax[warp_idx] = cutlass.Float32(warp_amax)
+
+ # Ensure all epilogue warps complete their writes before block reduction
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ # Block-level reduction: only first epilogue warp's lane 0 handles this
+ if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0:
+ block_amax = cutlass.Float32(0.0) # Initial value for absolute maximum
+ for i in cutlass.range(self.num_epilog_warps):
+ warp_amax_val = sAmax[i]
+ block_amax = cute.arch.fmax(block_amax, warp_amax_val)
+
+ # Global atomic max (accumulates across all tiles for final tensor amax)
+ # Since we compute absolute values, all values are non-negative
+ # Use wrapper function for atomic max operation
+ _ = cute.arch.atomic_max_float32(ptr=mAmax_tensor.iterator.llvm_ptr, value=block_amax)
+
+ #
+ # Async arrive accumulator buffer empty
+ #
+ with cute.arch.elect_one():
+ acc_pipeline.consumer_release(acc_consumer_state)
+ acc_consumer_state.advance()
+
+ #
+ # Advance to next tile
+ #
+ tile_sched.advance_to_next_work()
+ work_tile = tile_sched.get_current_work()
+
+ #
+ # Dealloc the tensor memory buffer
+ #
+ tmem.relinquish_alloc_permit()
+ self.epilog_sync_barrier.arrive_and_wait()
+ tmem.free(acc_tmem_ptr)
+ #
+ # Wait for C store complete
+ #
+ c_pipeline.producer_tail()
+ d_pipeline.producer_tail()
+
+ def mainloop_s2t_copy_and_partition(
+ self,
+ sSF: cute.Tensor,
+ tSF: cute.Tensor,
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+ """
+        Make a tiledCopy for the smem-to-tmem load of a scale factor tensor, then use it to partition shared memory (source) and tensor memory (destination).
+
+ :param sSF: The scale factor tensor in smem
+ :type sSF: cute.Tensor
+ :param tSF: The scale factor tensor in tmem
+ :type tSF: cute.Tensor
+
+ :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where:
+            - tiled_copy_s2t: The tiled copy operation for the smem-to-tmem (s2t) load of the scale factor tensor
+            - tCsSF_compact_s2t: The partitioned scale factor tensor in smem
+            - tCtSF_compact_s2t: The partitioned scale factor tensor in tmem
+ :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
+ """
+ # (MMA, MMA_MN, MMA_K, STAGE)
+ tCsSF_compact = cute.filter_zeros(sSF)
+ # (MMA, MMA_MN, MMA_K)
+ tCtSF_compact = cute.filter_zeros(tSF)
+
+ # Make S2T CopyAtom and tiledCopy
+ copy_atom_s2t = cute.make_copy_atom(
+ tcgen05.Cp4x32x128bOp(self.cta_group),
+ self.sf_dtype,
+ )
+ tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
+ thr_copy_s2t = tiled_copy_s2t.get_slice(0)
+
+ # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
+ tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
+ # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
+ tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_)
+ # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K)
+ tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
+
+ return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
+
+ def epilog_tmem_copy_and_partition(
+ self,
+ tidx: cutlass.Int32,
+ tAcc: cute.Tensor,
+ gD_mnl: cute.Tensor,
+ epi_tile: cute.Tile,
+ use_2cta_instrs: Union[cutlass.Boolean, bool],
+    ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+ """
+ Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
+
+ :param tidx: The thread index in epilogue warp groups
+ :type tidx: cutlass.Int32
+ :param tAcc: The accumulator tensor to be copied and partitioned
+ :type tAcc: cute.Tensor
+        :param gD_mnl: The global output tensor D
+ :type gD_mnl: cute.Tensor
+ :param epi_tile: The epilogue tiler
+ :type epi_tile: cute.Tile
+ :param use_2cta_instrs: Whether use_2cta_instrs is enabled
+ :type use_2cta_instrs: bool
+
+        :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where:
+            - tiled_copy_t2r: The tiled copy operation for the tmem-to-register (t2r) copy
+            - tTR_tAcc: The partitioned accumulator tensor in tmem
+            - tTR_rAcc: The register fragment that receives one accumulator subtile
+        :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
+ """
+ # Make tiledCopy for tensor memory load
+ copy_atom_t2r = sm100_utils.get_tmem_load_op(
+ self.cta_tile_shape_mnk,
+ self.c_layout,
+ self.c_dtype,
+ self.acc_dtype,
+ epi_tile,
+ use_2cta_instrs,
+ )
+ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
+ tAcc_epi = cute.flat_divide(
+ tAcc[((None, None), 0, 0, None)],
+ epi_tile,
+ )
+ # (EPI_TILE_M, EPI_TILE_N)
+ tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
+
+ thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE)
+ tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
+
+ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
+ gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
+ tTR_gD = thr_copy_t2r.partition_D(gD_mnl_epi)
+ # (T2R, T2R_M, T2R_N)
+ tTR_rAcc = cute.make_rmem_tensor(tTR_gD[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype)
+ return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
+
+ def epilog_smem_copy_and_partition(
+ self,
+ tiled_copy_t2r: cute.TiledCopy,
+ tTR_rC: cute.Tensor,
+ tidx: cutlass.Int32,
+ sC: cute.Tensor,
+ ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
+ """
+ Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination).
+
+ :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r)
+ :type tiled_copy_t2r: cute.TiledCopy
+ :param tTR_rC: The partitioned accumulator tensor
+ :type tTR_rC: cute.Tensor
+ :param tidx: The thread index in epilogue warp groups
+ :type tidx: cutlass.Int32
+ :param sC: The shared memory tensor to be copied and partitioned
+        :type sC: cute.Tensor
+
+ :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where:
+ - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s)
+ - tRS_rC: The partitioned tensor C (register source)
+ - tRS_sC: The partitioned tensor C (smem destination)
+ :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]
+ """
+ copy_atom_r2s = sm100_utils.get_smem_store_op(self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r)
+ tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
+ # (R2S, R2S_M, R2S_N, PIPE_D)
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
+ tRS_sC = thr_copy_r2s.partition_D(sC)
+ # (R2S, R2S_M, R2S_N)
+ tRS_rC = tiled_copy_r2s.retile(tTR_rC)
+ return tiled_copy_r2s, tRS_rC, tRS_sC
+
+ def epilog_gmem_copy_and_partition(
+ self,
+ tidx: cutlass.Int32,
+ atom: Union[cute.CopyAtom, cute.TiledCopy],
+ gC_mnl: cute.Tensor,
+ epi_tile: cute.Tile,
+ sC: cute.Tensor,
+    ) -> Tuple[cute.Tensor, cute.Tensor]:
+ """Make tiledCopy for global memory store, then use it to:
+ partition shared memory (source) and global memory (destination) for TMA store version.
+
+ :param tidx: The thread index in epilogue warp groups
+ :type tidx: cutlass.Int32
+ :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version
+ :type atom: cute.CopyAtom or cute.TiledCopy
+ :param gC_mnl: The global tensor C
+ :type gC_mnl: cute.Tensor
+ :param epi_tile: The epilogue tiler
+ :type epi_tile: cute.Tile
+ :param sC: The shared memory tensor to be copied and partitioned
+ :type sC: cute.Tensor
+
+        :return: A tuple containing (bSG_sC, bSG_gC) where:
+            - bSG_sC: The partitioned shared memory tensor C
+            - bSG_gC: The partitioned global tensor C
+        :rtype: Tuple[cute.Tensor, cute.Tensor]
+ """
+ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL)
+ gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+ # ((ATOM_V, REST_V), EPI_M, EPI_N)
+ # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL)
+ bSG_sC, bSG_gC = cpasync.tma_partition(
+ atom,
+ 0,
+ cute.make_layout(1),
+ cute.group_modes(sC, 0, 2),
+ cute.group_modes(gC_epi, 0, 2),
+ )
+ return bSG_sC, bSG_gC
+
+ @staticmethod
+ def _compute_stages(
+ tiled_mma: cute.TiledMma,
+ mma_tiler_mnk: Tuple[int, int, int],
+ a_dtype: Type[cutlass.Numeric],
+ b_dtype: Type[cutlass.Numeric],
+ epi_tile: cute.Tile,
+ c_dtype: Type[cutlass.Numeric],
+ c_layout: utils.LayoutEnum,
+ sf_dtype: Type[cutlass.Numeric],
+ sf_vec_size: int,
+ d_dtype: Type[cutlass.Numeric],
+ d_layout: utils.LayoutEnum,
+ smem_capacity: int,
+ occupancy: int,
+    ) -> Tuple[int, int, int, int]:
+        """Computes the number of pipeline stages for the accumulator and the A/B/C/D operands based on heuristics.
+
+ :param tiled_mma: The tiled MMA object defining the core computation.
+ :type tiled_mma: cute.TiledMma
+ :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler.
+ :type mma_tiler_mnk: tuple[int, int, int]
+ :param a_dtype: Data type of operand A.
+ :type a_dtype: type[cutlass.Numeric]
+ :param b_dtype: Data type of operand B.
+ :type b_dtype: type[cutlass.Numeric]
+ :param epi_tile: The epilogue tile shape.
+ :type epi_tile: cute.Tile
+ :param c_dtype: Data type of operand C (output).
+ :type c_dtype: type[cutlass.Numeric]
+ :param c_layout: Layout enum of operand C.
+ :type c_layout: utils.LayoutEnum
+ :param sf_dtype: Data type of Scale factor.
+ :type sf_dtype: type[cutlass.Numeric]
+ :param sf_vec_size: Scale factor vector size.
+ :type sf_vec_size: int
+ :param d_dtype: Data type of operand D.
+ :type d_dtype: type[cutlass.Numeric]
+ :param d_layout: Layout enum of operand D.
+ :type d_layout: utils.LayoutEnum
+ :param smem_capacity: Total available shared memory capacity in bytes.
+ :type smem_capacity: int
+ :param occupancy: Target number of CTAs per SM (occupancy).
+ :type occupancy: int
+
+        :return: A tuple containing the computed number of stages for:
+            (ACC stages, A/B operand stages, C stages, D stages)
+        :rtype: tuple[int, int, int, int]
+ """
+ # ACC stages
+ num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
+
+ # Default C stages
+ num_c_stage = 2 # mma_tiler_mnk[1] // cute.cosize(epi_tile[1])
+ num_d_stage = num_c_stage
+
+ # Calculate smem layout and size for one stage of A, B, SFA, SFB and C
+ a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(
+ tiled_mma,
+ mma_tiler_mnk,
+ a_dtype,
+ 1, # a tmp 1 stage is provided
+ )
+ b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(
+ tiled_mma,
+ mma_tiler_mnk,
+ b_dtype,
+ 1, # a tmp 1 stage is provided
+ )
+ sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(
+ tiled_mma,
+ mma_tiler_mnk,
+ sf_vec_size,
+ 1, # a tmp 1 stage is provided
+ )
+ sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(
+ tiled_mma,
+ mma_tiler_mnk,
+ sf_vec_size,
+ 1, # a tmp 1 stage is provided
+ )
+
+ c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
+ c_dtype,
+ c_layout,
+ epi_tile,
+ 1,
+ )
+
+ d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(
+ d_dtype,
+ d_layout,
+ epi_tile,
+ 1,
+ )
+
+ ab_bytes_per_stage = (
+ cute.size_in_bytes(a_dtype, a_smem_layout_stage_one)
+ + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
+ + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one)
+ + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
+ )
+ mbar_helpers_bytes = 1024
+ c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
+ d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
+ amax_bytes = Sm100BlockScaledPersistentDenseGemmKernel.get_amax_smem_size()
+ epi_bytes = c_bytes_per_stage * num_c_stage + d_bytes_per_stage * num_d_stage + amax_bytes
+
+ # Calculate A/B/SFA/SFB stages:
+ # Start with total smem per CTA (capacity / occupancy)
+        # Subtract the reserved barrier/helper bytes and the C/D/amax epilogue bytes
+        # Divide the remainder by the bytes needed per A/B/SFA/SFB stage
+ num_ab_stage = (smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)) // ab_bytes_per_stage
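+        # Illustration (hypothetical sizes): if ~200 KiB remain after the helper and epilogue
+        # reservations and one A/B/SFA/SFB stage needs ~40 KiB, five mainloop stages are used.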
+
+ return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage
+
+ @staticmethod
+ def _compute_grid(
+ c: cute.Tensor,
+ cta_tile_shape_mnk: Tuple[int, int, int],
+ cluster_shape_mn: Tuple[int, int],
+ max_active_clusters: cutlass.Constexpr,
+ ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
+ """Use persistent tile scheduler to compute the grid size for the output tensor C.
+
+ :param c: The output tensor C
+ :type c: cute.Tensor
+ :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile.
+ :type cta_tile_shape_mnk: tuple[int, int, int]
+ :param cluster_shape_mn: Shape of each cluster in M, N dimensions.
+ :type cluster_shape_mn: tuple[int, int]
+ :param max_active_clusters: Maximum number of active clusters.
+ :type max_active_clusters: cutlass.Constexpr
+
+ :return: A tuple containing:
+ - tile_sched_params: Parameters for the persistent tile scheduler.
+ - grid: Grid shape for kernel launch.
+ :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
+ """
+ c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0))
+ gc = cute.zipped_divide(c, tiler=c_shape)
+ num_ctas_mnl = gc[(0, (None, None, None))].shape
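+        # num_ctas_mnl counts CTA tiles along (M, N, L); e.g. a (1024, 1024, 1) output with a
+        # (128, 128) CTA tile gives (8, 8, 1). The persistent scheduler then maps these tiles
+        # onto at most max_active_clusters clusters.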
+ cluster_shape_mnl = (*cluster_shape_mn, 1)
+
+ tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl)
+ grid = utils.StaticPersistentTileScheduler.get_grid_shape(tile_sched_params, max_active_clusters)
+
+ return tile_sched_params, grid
+
+ @staticmethod
+ def get_dtype_rcp_limits(dtype: Type[cutlass.Numeric]) -> float:
+ """
+ Calculates the reciprocal of the maximum absolute value for a given data type.
+
+ :param dtype: Data type
+ :type dtype: Type[cutlass.Numeric]
+
+        :return: A float representing the reciprocal of the data type's maximum absolute value
+ :rtype: float
+ """
+ if dtype == cutlass.Float4E2M1FN:
+ return 1 / 6.0
+ if dtype == cutlass.Float8E4M3FN:
+ return 1 / 448.0
+ if dtype == cutlass.Float8E5M2:
+ return 1 / 128.0
+ return 1.0
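+        # For example, Float8E4M3FN has a largest finite magnitude of 448, so its reciprocal
+        # limit is 1/448; callers typically multiply this by an amax value to build a
+        # quantization scale for the narrow output type.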
+
+ @staticmethod
+ def is_valid_dtypes_and_scale_factor_vec_size(
+ ab_dtype: Type[cutlass.Numeric],
+ sf_dtype: Type[cutlass.Numeric],
+ sf_vec_size: int,
+ d_dtype: Type[cutlass.Numeric],
+ ) -> bool:
+ """
+ Check if the dtypes and sf_vec_size are valid combinations
+
+ :param ab_dtype: The data type of the A and B operands
+ :type ab_dtype: Type[cutlass.Numeric]
+ :param sf_dtype: The data type of the scale factor
+ :type sf_dtype: Type[cutlass.Numeric]
+ :param sf_vec_size: The vector size of the scale factor
+ :type sf_vec_size: int
+ :param d_dtype: The data type of the output tensor
+ :type d_dtype: Type[cutlass.Numeric]
+
+ :return: True if the dtypes and sf_vec_size are valid, False otherwise
+ :rtype: bool
+ """
+ is_valid = True
+
+ # Check valid ab_dtype
+ if ab_dtype not in {
+ cutlass.Float4E2M1FN,
+ cutlass.Float8E5M2,
+ cutlass.Float8E4M3FN,
+ }:
+ is_valid = False
+
+ if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
+ # Check valid sf_vec_size
+ if sf_vec_size not in {16, 32}:
+ is_valid = False
+ # Check valid sf_dtype
+ if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}:
+ is_valid = False
+ # Check valid sf_dtype and sf_vec_size combinations
+ if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32:
+ is_valid = False
+ if sf_dtype == cutlass.Float8E8M0FNU and sf_vec_size == 16:
+ is_valid = False
+
+        # Check valid d_dtype
+ if d_dtype not in {
+ cutlass.Float32,
+ cutlass.Float16,
+ cutlass.BFloat16,
+ cutlass.Float8E5M2,
+ cutlass.Float8E4M3FN,
+ cutlass.Float4E2M1FN, # {$nv-internal-release}
+ }:
+ is_valid = False
+
+ return is_valid
+
+ @staticmethod
+ def is_valid_layouts(
+ ab_dtype: Type[cutlass.Numeric],
+ c_dtype: Type[cutlass.Numeric],
+ a_major: str,
+ b_major: str,
+ c_major: str,
+ ) -> bool:
+ """
+ Check if layouts and dtypes are valid combinations
+
+ :param ab_dtype: The data type of the A and B operands
+ :type ab_dtype: Type[cutlass.Numeric]
+ :param c_dtype: The data type of the output tensor
+ :type c_dtype: Type[cutlass.Numeric]
+ :param a_major: The major dimension of the A tensor
+ :type a_major: str
+ :param b_major: The major dimension of the B tensor
+ :type b_major: str
+ :param c_major: The major dimension of the C tensor
+ :type c_major: str
+
+ :return: True if the layouts are valid, False otherwise
+ :rtype: bool
+ """
+ is_valid = True
+
+ # {$nv-internal-release begin}
+ if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"):
+ is_valid = False
+ # TODO: Currently we don't support m major output for Float4E2M1FN
+ if c_dtype is cutlass.Float4E2M1FN and c_major == "m":
+ is_valid = False
+ # {$nv-internal-release end}
+
+ return is_valid
+
+ @staticmethod
+ def is_valid_mma_tiler_and_cluster_shape(
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ ) -> bool:
+ """
+ Check if the mma tiler and cluster shape are valid
+
+ :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
+ :type mma_tiler_mn: Tuple[int, int]
+ :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
+ :type cluster_shape_mn: Tuple[int, int]
+
+ :return: True if the mma tiler and cluster shape are valid, False otherwise
+ :rtype: bool
+ """
+ is_valid = True
+ # Skip invalid mma tile shape
+ if mma_tiler_mn[0] not in [128, 256]:
+ is_valid = False
+ if mma_tiler_mn[1] not in [64, 128, 192, 256]:
+ is_valid = False
+ # Skip illegal cluster shape
+ if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0:
+ is_valid = False
+ # Skip invalid cluster shape
+ is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0
+ if (
+ cluster_shape_mn[0] * cluster_shape_mn[1] > 16
+ or cluster_shape_mn[0] <= 0
+ or cluster_shape_mn[1] <= 0
+ # Special cluster shape check for scale factor multicasts.
+ # Due to limited size of scale factors, we can't multicast among more than 4 CTAs.
+ or cluster_shape_mn[0] > 4
+ or cluster_shape_mn[1] > 4
+ or not is_power_of_2(cluster_shape_mn[0])
+ or not is_power_of_2(cluster_shape_mn[1])
+ ):
+ is_valid = False
+ return is_valid
+
+ @staticmethod
+ def is_valid_tensor_alignment(
+ m: int,
+ n: int,
+ k: int,
+ l: int,
+ ab_dtype: Type[cutlass.Numeric],
+ c_dtype: Type[cutlass.Numeric],
+ a_major: str,
+ b_major: str,
+ c_major: str,
+ ) -> bool:
+ """
+ Check if the tensor alignment is valid
+
+ :param m: The number of rows in the A tensor
+ :type m: int
+ :param n: The number of columns in the B tensor
+ :type n: int
+ :param k: The number of columns in the A tensor
+ :type k: int
+        :param l: The number of batches (the L dimension of the tensors)
+ :type l: int
+ :param ab_dtype: The data type of the A and B operands
+ :type ab_dtype: Type[cutlass.Numeric]
+ :param c_dtype: The data type of the output tensor
+ :type c_dtype: Type[cutlass.Numeric]
+ :param a_major: The major axis of the A tensor
+ :type a_major: str
+ :param b_major: The major axis of the B tensor
+ :type b_major: str
+ :param c_major: The major axis of the C tensor
+ :type c_major: str
+
+ :return: True if the problem shape is valid, False otherwise
+ :rtype: bool
+ """
+ is_valid = True
+
+        def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape):
+            major_mode_idx = 0 if is_mode0_major else 1
+            num_major_elements = tensor_shape[major_mode_idx]
+            num_contiguous_elements = 16 * 8 // dtype.width
+            return num_major_elements % num_contiguous_elements == 0
+
+        if (
+            not check_contiguous_16B_alignment(ab_dtype, a_major == "m", (m, k, l))
+            or not check_contiguous_16B_alignment(ab_dtype, b_major == "n", (n, k, l))
+            or not check_contiguous_16B_alignment(c_dtype, c_major == "m", (m, n, l))
+ ):
+ is_valid = False
+ return is_valid
+
+ @staticmethod
+ def can_implement(
+ ab_dtype: Type[cutlass.Numeric],
+ sf_dtype: Type[cutlass.Numeric],
+ sf_vec_size: int,
+ d_dtype: Type[cutlass.Numeric],
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ m: int,
+ n: int,
+ k: int,
+ l: int,
+ a_major: str,
+ b_major: str,
+ d_major: str,
+ ) -> bool:
+ """
+ Check if the gemm can be implemented
+
+ :param ab_dtype: The data type of the A and B operands
+ :type ab_dtype: Type[cutlass.Numeric]
+ :param sf_dtype: The data type of the scale factor tensor
+ :type sf_dtype: Type[cutlass.Numeric]
+ :param sf_vec_size: The vector size
+ :type sf_vec_size: int
+ :param d_dtype: The data type of the output tensor
+ :type d_dtype: Type[cutlass.Numeric]
+ :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
+ :type mma_tiler_mn: Tuple[int, int]
+ :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
+ :type cluster_shape_mn: Tuple[int, int]
+ :param m: The number of rows in the A tensor
+ :type m: int
+ :param n: The number of columns in the B tensor
+ :type n: int
+ :param k: The number of columns in the A tensor
+ :type k: int
+        :param l: The number of batches (the L dimension of the tensors)
+ :type l: int
+ :param a_major: The major axis of the A tensor
+ :type a_major: str
+ :param b_major: The major axis of the B tensor
+ :type b_major: str
+        :param d_major: The major axis of the output tensor D
+ :type d_major: str
+
+ :return: True if the gemm can be implemented, False otherwise
+ :rtype: bool
+ """
+ can_implement = True
+ # Skip unsupported types
+ if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(ab_dtype, sf_dtype, sf_vec_size, d_dtype):
+ can_implement = False
+ # Skip unsupported layouts
+ if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts(ab_dtype, d_dtype, a_major, b_major, d_major):
+ can_implement = False
+ # Skip invalid mma tile shape and cluster shape
+ if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn):
+ can_implement = False
+ # Skip illegal problem shape for load/store alignment
+ if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment(m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major):
+ can_implement = False
+ return can_implement
+
+ @staticmethod
+ def get_amax_smem_size():
+ # Note: 4 is hardcoded for num_epilog_warps
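+        # One Float32 slot per epilogue warp, i.e. 16 bytes of shared memory for the amax reduction.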
+ return 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,)))
diff --git a/python/cudnn/gemm_swiglu/api.py b/python/cudnn/gemm_swiglu/api.py
index 270c026a..2daeaa74 100644
--- a/python/cudnn/gemm_swiglu/api.py
+++ b/python/cudnn/gemm_swiglu/api.py
@@ -69,7 +69,7 @@ def __init__(
):
super().__init__()
- self._logger.warning("GemmSwigluSm100 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
self.a_desc = self._make_tensor_desc(sample_a, name="sample_a")
@@ -94,7 +94,7 @@ def __init__(
self.vector_f32 = vector_f32
self.ab12_stages = ab12_stages
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
- print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
# Kernel selection
if self.sfa_desc is None and self.sfb_desc is None and self.amax_desc is None and self.sfc_desc is None and self.norm_const_desc is None:
diff --git a/python/cudnn/grouped_gemm/__init__.py b/python/cudnn/grouped_gemm/__init__.py
index a19f6c2f..818a1ab9 100644
--- a/python/cudnn/grouped_gemm/__init__.py
+++ b/python/cudnn/grouped_gemm/__init__.py
@@ -16,11 +16,26 @@
grouped_gemm_quant_wrapper_sm100,
)
+from .grouped_gemm_srelu.api import (
+ GroupedGemmSreluSm100,
+ grouped_gemm_srelu_wrapper_sm100,
+)
+
+from .grouped_gemm_dsrelu.api import (
+ GroupedGemmDsreluSm100,
+ grouped_gemm_dsrelu_wrapper_sm100,
+)
+
from .grouped_gemm_glu.api import (
GroupedGemmGluSm100,
grouped_gemm_glu_wrapper_sm100,
)
+from .grouped_gemm_glu_hadamard.api import (
+ GroupedGemmGluHadamardSm100,
+ grouped_gemm_glu_hadamard_wrapper_sm100,
+)
+
from .grouped_gemm_dglu.api import (
GroupedGemmDgluSm100,
grouped_gemm_dglu_wrapper_sm100,
@@ -38,8 +53,14 @@
"grouped_gemm_dswiglu_wrapper_sm100",
"GroupedGemmQuantSm100",
"grouped_gemm_quant_wrapper_sm100",
+ "GroupedGemmSreluSm100",
+ "grouped_gemm_srelu_wrapper_sm100",
+ "GroupedGemmDsreluSm100",
+ "grouped_gemm_dsrelu_wrapper_sm100",
"GroupedGemmGluSm100",
"grouped_gemm_glu_wrapper_sm100",
+ "GroupedGemmGluHadamardSm100",
+ "grouped_gemm_glu_hadamard_wrapper_sm100",
"GroupedGemmDgluSm100",
"grouped_gemm_dglu_wrapper_sm100",
"GroupedGemmWgradSm100",
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py
index 843b3412..d7a19666 100644
--- a/python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py
+++ b/python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py
@@ -174,7 +174,7 @@ def __init__(
"""
super().__init__()
- self._logger.warning("GroupedGemmDgluSm100 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
# ---- Weight mode auto-detection ----
@@ -260,7 +260,7 @@ def __init__(
self._kernel = BlockScaledMoEGroupedGemmDgluDbiasKernel
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
- print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
self._workspace = None
@@ -676,7 +676,7 @@ def compile(self) -> None:
def _compile_dense(self, gemm_dglu, max_active_clusters, fake_stream) -> None:
"""Compile for dense (contiguous) weight mode."""
- use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
+ use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
fake_workspace_ptr = cute.runtime.nullptr(
dtype=cutlass.Uint8,
@@ -1446,7 +1446,7 @@ def dynamic_m_tensor_signature(
stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride()))
return static_shape_suffix, stride_signature, tensor.dtype
- use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
+ use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
if is_dense:
cache_key = (
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/__init__.py b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/__init__.py
new file mode 100644
index 00000000..e1513779
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from .api import (
+ GroupedGemmDsreluSm100,
+ grouped_gemm_dsrelu_wrapper_sm100,
+)
+
+__all__ = [
+ "GroupedGemmDsreluSm100",
+ "grouped_gemm_dsrelu_wrapper_sm100",
+]
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/api.py
new file mode 100644
index 00000000..2226a627
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/api.py
@@ -0,0 +1,1581 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""
+Unified API for Grouped GEMM dSReLU Backward Kernel (SM100+)
+
+This module provides a single API class that supports both contiguous (dense)
+and discrete weight modes for block-scaled grouped GEMM with dSReLU activation
+gradient in MoE (Mixture of Experts) workloads.
+
+Dense mode
+ All expert weights are packed contiguously in a 3-D tensor (N, K, L).
+ Callers supply ``sample_b`` and ``sample_sfb``.
+
+Discrete mode
+ Each expert has its own memory allocation. Callers supply
+ ``num_experts``, ``b_shape``, ``b_dtype``, and per-expert pointer arrays
+ at execution time.
+"""
+
+from .moe_blockscaled_grouped_gemm_dsrelu_quant import (
+ BlockScaledMoEGroupedGemmQuantBwdKernel,
+ EpilogueType,
+)
+from ..moe_utils import MoEWeightMode
+from cuda.bindings import driver as cuda
+import logging
+import os
+import torch
+from typing import Tuple, Optional
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
+from cutlass.cute.runtime import from_dlpack, make_fake_stream
+
+from cudnn.datatypes import _convert_to_cutlass_data_type
+from cudnn.api_base import APIBase, TupleDict, ceil_div, is_power_of_2
+
+
+def _reinterpret_raw_grouped_fp4_tensor(tensor: torch.Tensor) -> torch.Tensor:
+ if tensor.dtype == torch.uint8:
+ cute_tensor = from_dlpack(tensor, assumed_align=16, enable_tvm_ffi=True).mark_layout_dynamic(leading_dim=1)
+ cute_tensor.element_type = cutlass.Float4E2M1FN
+ return cute_tensor
+ return tensor
+
+
+class GroupedGemmDsreluSm100(APIBase):
+ """Unified API for grouped GEMM dSReLU backward operation on SM100+ GPUs.
+
+    This kernel performs block-scaled grouped GEMM with the dSReLU activation
+    gradient, designed for MoE workloads. It supports both dense (contiguous)
+    and discrete (per-expert pointer) weight layouts through the
+    ``BlockScaledMoEGroupedGemmQuantBwdKernel``.
+
+ Weight mode is auto-detected from the constructor arguments:
+
+ - **Dense**: provide ``sample_b`` and ``sample_sfb``.
+ - **Discrete**: provide ``num_experts``, ``b_shape``, and ``b_dtype``.
+
+ Example::
+
+ # Dense mode
+ api = GroupedGemmDsreluSm100(
+ sample_a=a, sample_c=c,
+ sample_d_row=d_row, sample_d_col=d_col,
+ sample_sfa=sfa, sample_padded_offsets=offsets,
+ sample_alpha=alpha,
+ sample_prob=prob, sample_dprob=dprob,
+ sample_b=b, sample_sfb=sfb,
+ )
+
+ # Discrete mode
+ api = GroupedGemmDsreluSm100(
+ sample_a=a, sample_c=c,
+ sample_d_row=d_row, sample_d_col=d_col,
+ sample_sfa=sfa, sample_padded_offsets=offsets,
+ sample_alpha=alpha,
+ sample_prob=prob, sample_dprob=dprob,
+ num_experts=8, b_shape=(n, k), b_dtype=torch.uint8,
+ )
+
+ api.check_support()
+ api.compile()
+ api.execute(...)
+ """
+
+ def __init__(
+ self,
+ sample_a: torch.Tensor,
+ # Dense mode (contiguous) -- provide these. sample_dbias is optional:
+ sample_b: Optional[torch.Tensor] = None,
+ sample_c: Optional[torch.Tensor] = None,
+ sample_d_row: Optional[torch.Tensor] = None,
+ sample_d_col: Optional[torch.Tensor] = None,
+ sample_sfa: Optional[torch.Tensor] = None,
+ sample_sfb: Optional[torch.Tensor] = None,
+ sample_padded_offsets: Optional[torch.Tensor] = None,
+ sample_alpha: Optional[torch.Tensor] = None,
+ sample_prob: Optional[torch.Tensor] = None,
+ sample_dprob: Optional[torch.Tensor] = None,
+ sample_dbias: Optional[torch.Tensor] = None,
+ # Discrete mode -- provide these instead:
+ num_experts: Optional[int] = None,
+ b_shape: Optional[Tuple[int, ...]] = None,
+ b_dtype: Optional[torch.dtype] = None,
+ # Optional quantization output arguments
+ sample_sfd_row: Optional[torch.Tensor] = None,
+ sample_sfd_col: Optional[torch.Tensor] = None,
+ sample_amax: Optional[torch.Tensor] = None,
+ sample_norm_const: Optional[torch.Tensor] = None,
+ # Configuration
+ acc_dtype: torch.dtype = torch.float32,
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ m_aligned: int = 256,
+ discrete_col_sfd: bool = False,
+ b_major: str = "k",
+ use_dynamic_sched: bool = False,
+ ):
+ """Initialize the GroupedGemmDsreluSm100 API.
+
+ :param sample_a: Sample A tensor (valid_m, k, 1)
+ :param sample_c: Sample C tensor -- forward activations (valid_m, n, 1)
+ :param sample_d_row: Sample D row output tensor (valid_m, n, 1)
+ :param sample_d_col: Sample D col output tensor (valid_m, n, 1)
+ :param sample_sfa: Sample scale factor A tensor
+ :param sample_padded_offsets: End offset for each expert after padding
+ :param sample_alpha: Per-group alpha scaling factors
+ :param sample_prob: Per-row probability tensor (valid_m, 1, 1)
+ :param sample_dprob: Gradient of probability tensor (valid_m, 1, 1), must be zero-initialized
+ :param sample_b: (Dense) Sample B tensor (n, k, l)
+ :param sample_sfb: (Dense) Sample scale factor B tensor
+ :param sample_dbias: Optional dbias output tensor (expert_cnt, n, 1)
+ :param num_experts: (Discrete) Number of experts
+ :param b_shape: (Discrete) Shape of a single expert B tensor, e.g. (n, k)
+ :param b_dtype: (Discrete) Data type of B tensors
+ :param sample_sfd_row: Optional row scale factor for D
+ :param sample_sfd_col: Optional column scale factor for D
+ :param sample_amax: Optional amax tensor for quantization, shape (expert_cnt, 1)
+ :param sample_norm_const: Optional normalization constant
+ :param acc_dtype: Accumulator data type
+ :param mma_tiler_mn: MMA tiler shape (M, N)
+ :param cluster_shape_mn: Cluster shape (M, N)
+ :param sf_vec_size: Scale factor vector size
+ :param vector_f32: Use vectorized f32 operations
+ :param m_aligned: Alignment for group M dimension
+ :param discrete_col_sfd: Generate discrete col-major scale factor tensor
+ :param b_major: Major dimension for B tensor, one of "k" or "n"
+ :param use_dynamic_sched: Enable dynamic tile scheduling for load balancing
+ """
+ super().__init__()
+
+ self._warn_experimental_api()
+ self._logger.debug("Entering __init__")
+
+ # ---- Weight mode auto-detection ----
+ if sample_b is not None and num_experts is None:
+ self.weight_mode = MoEWeightMode.DENSE
+ if sample_sfb is None:
+ raise ValueError("sample_sfb is required when sample_b is provided (dense mode)")
+ elif num_experts is not None and sample_b is None:
+ self.weight_mode = MoEWeightMode.DISCRETE
+ if b_shape is None or b_dtype is None:
+ raise ValueError("b_shape and b_dtype are required in discrete mode")
+ else:
+ raise ValueError("Provide either (sample_b, sample_sfb) for dense mode " "or (num_experts, b_shape, b_dtype) for discrete mode, but not both.")
+
+ self._sample_a_tensor = sample_a
+ self._sample_b_tensor = sample_b
+
+ # ---- Common tensor descriptors ----
+ self.a_desc = self._make_tensor_desc(sample_a, name="sample_a", interpret_uint8_as_fp4x2=False)
+ self.c_desc = self._make_tensor_desc(sample_c, name="sample_c")
+ self.d_row_desc = self._make_tensor_desc(sample_d_row, name="sample_d_row")
+ self.d_col_desc = self._make_tensor_desc(sample_d_col, name="sample_d_col")
+ self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa")
+ self.padded_offsets_desc = self._make_tensor_desc(sample_padded_offsets, name="sample_padded_offsets")
+ self.alpha_desc = self._make_tensor_desc(sample_alpha, name="sample_alpha")
+ self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob")
+ self.dprob_desc = self._make_tensor_desc(sample_dprob, name="sample_dprob")
+ self.dbias_desc = self._make_tensor_desc(sample_dbias, name="sample_dbias")
+
+ self.sfd_row_desc = self._make_tensor_desc(sample_sfd_row, name="sample_sfd_row")
+ self.sfd_col_desc = self._make_tensor_desc(sample_sfd_col, name="sample_sfd_col")
+ self.amax_desc = self._make_tensor_desc(sample_amax, name="sample_amax")
+ self.norm_const_desc = self._unpad_tensor_to_ndim(
+ self._make_tensor_desc(sample_norm_const, name="sample_norm_const"),
+ 1,
+ "norm_const",
+ )
+
+ # ---- Mode-specific state ----
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self.b_desc = self._make_tensor_desc(sample_b, name="sample_b", interpret_uint8_as_fp4x2=False)
+ self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb")
+ self.expert_cnt = self.padded_offsets_desc.shape[0]
+ else:
+ self._value_error_if(num_experts == 0, "num_experts must be > 0")
+ self.expert_cnt = num_experts
+ self.b_shape = b_shape
+ self.b_dtype = b_dtype
+ self.b_major = b_major
+ self._value_error_if(
+ self.padded_offsets_desc.shape[0] != self.expert_cnt,
+ f"padded_offsets length ({self.padded_offsets_desc.shape[0]}) " f"must equal num_experts ({self.expert_cnt})",
+ )
+
+ # ---- Configuration ----
+ self.acc_dtype = acc_dtype
+ self.mma_tiler_mn = mma_tiler_mn
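+        # An MMA tiler M of 256 maps to the paired 2-CTA MMA instruction path.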
+ self.use_2cta_instrs = mma_tiler_mn[0] == 256
+ if cluster_shape_mn is None:
+ self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1)
+ else:
+ self.cluster_shape_mn = cluster_shape_mn
+ self.sf_vec_size = sf_vec_size
+ self.vector_f32 = vector_f32
+ self.m_aligned = m_aligned
+ self.discrete_col_sfd = discrete_col_sfd
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self.b_major = b_major # stored for both modes
+
+ self.use_dynamic_sched = use_dynamic_sched
+
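+        # Raw torch.uint8 tensors are reinterpreted as packed FP4 (two e2m1 values per byte).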
+ self._interpret_uint8_as_fp4x2 = True
+ self._has_dbias = self.dbias_desc is not None
+ self._kernel = BlockScaledMoEGroupedGemmQuantBwdKernel
+
+ self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+
+ self._workspace = None
+
+ self._logger.debug("__init__ completed")
+
+ def check_support(self) -> bool:
+ """Check if the kernel configuration is supported.
+
+ :return: True if supported, raises exception otherwise
+ """
+ self._logger.debug("Entering check_support")
+
+ # ---- SFD group validation ----
+ all_none = all(x is None for x in [self.sfd_row_desc, self.sfd_col_desc, self.norm_const_desc])
+ all_provided = all(x is not None for x in [self.sfd_row_desc, self.sfd_col_desc, self.norm_const_desc])
+ self._value_error_if(
+ not (all_none or all_provided),
+ "sfd_row_desc, sfd_col_desc, and norm_const_desc must be all None or all not None",
+ )
+ self._user_requested_sfd = all_provided
+
+ # ---- Shapes and strides ----
+ self._logger.debug("Checking tensor shapes and strides")
+ tensor_m, k, _one = self._tensor_shape(self.a_desc, name="sample_a")
+
+ if self.weight_mode == MoEWeightMode.DENSE:
+ n, _, l = self._tensor_shape(self.b_desc, name="sample_b")
+ else:
+ # Discrete: extract n, k from b_shape
+ if len(self.b_shape) == 2:
+ n, b_k = self.b_shape
+ else:
+ n, b_k, _ = self.b_shape
+ self._value_error_if(b_k != k, f"B K dimension ({b_k}) must match A K dimension ({k})")
+ l = self.expert_cnt # for shape checks that use l
+
+ n_out = n
+
+ self._check_tensor_shape(self.a_desc, (tensor_m, k, 1), "A")
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._check_tensor_shape(self.b_desc, (n, k, l), "B")
+ self._check_tensor_shape(self.c_desc, (tensor_m, n_out, 1), "C")
+ self._check_tensor_shape(self.d_row_desc, (tensor_m, n_out, 1), "D_row")
+ self._check_tensor_shape(self.d_col_desc, (tensor_m, n_out, 1), "D_col")
+
+ rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4)
+ self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(tensor_m, 128), 4, rest_k, 1), "SFA")
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._check_tensor_shape(self.sfb_desc, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB")
+
+ rest_n_out = ceil_div(ceil_div(n_out, self.sf_vec_size), 4)
+ self._check_tensor_shape(
+ self.sfd_row_desc,
+ (32, 4, ceil_div(tensor_m, 128), 4, rest_n_out, 1),
+ "SFD_row",
+ )
+ rest_m = ceil_div(ceil_div(tensor_m, self.sf_vec_size), 4)
+ self._check_tensor_shape(self.sfd_col_desc, (32, 4, ceil_div(n_out, 128), 4, rest_m, 1), "SFD_col")
+
+ self._check_tensor_shape(self.alpha_desc, (self.expert_cnt,), "alpha")
+ self._check_tensor_shape(self.prob_desc, (tensor_m, 1, 1), "prob")
+ self._check_tensor_shape(self.dprob_desc, (tensor_m, 1, 1), "dprob")
+ self._check_tensor_shape(self.dbias_desc, (self.expert_cnt, n_out, 1), "dbias")
+ self._check_tensor_shape(self.amax_desc, (self.expert_cnt, 1), "amax")
+ self._check_tensor_shape(self.norm_const_desc, (1,), "norm_const")
+ self._check_tensor_shape(self.padded_offsets_desc, (self.expert_cnt,), "padded_offsets")
+
+ # Strides
+ _ = self._check_tensor_stride(
+ self.a_desc,
+ stride=[(k, 1, tensor_m * k)],
+ extra_error_msg="A must have k-major layout",
+ )
+ if self.weight_mode == MoEWeightMode.DENSE:
+ if self._is_fp8(self.a_desc):
+ _ = self._check_tensor_stride(
+ self.b_desc,
+ stride=[(k, 1, n * k), (1, n, n * k)],
+ extra_error_msg="For fp8 ab_dtype, B must have k- or n-major layout",
+ )
+ else:
+ _ = self._check_tensor_stride(
+ self.b_desc,
+ stride=[(k, 1, n * k)],
+ extra_error_msg="For fp4 ab_dtype, B must have k-major layout",
+ )
+ _ = self._check_tensor_stride(
+ self.c_desc,
+ stride=[(n_out, 1, tensor_m * n_out)],
+ extra_error_msg="C must have n-major layout",
+ )
+ _ = self._check_tensor_stride(
+ self.d_row_desc,
+ stride=[(n_out, 1, tensor_m * n_out)],
+ extra_error_msg="D_row must have n-major layout",
+ )
+ _ = self._check_tensor_stride(
+ self.d_col_desc,
+ stride=[(n_out, 1, tensor_m * n_out)],
+ extra_error_msg="D_col must have n-major layout",
+ )
+
+ # ---- Data types ----
+ self._logger.debug("Checking data types")
+ self.ab_dtype = self._check_dtype(
+ self.a_desc,
+ dtype=[
+ torch.float4_e2m1fn_x2,
+ torch.uint8,
+ torch.float8_e5m2,
+ torch.float8_e4m3fn,
+ ],
+ name="A/B",
+ )
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._check_dtype(
+ self.b_desc,
+ dtype=self.ab_dtype,
+ name="B",
+ extra_error_msg="B must have the same dtype as A",
+ )
+ else:
+ self._value_error_if(
+ self.b_dtype != self.ab_dtype,
+ f"b_dtype ({self.b_dtype}) must match A dtype ({self.ab_dtype})",
+ )
+
+ self.sf_dtype = self._check_dtype(
+ self.sfa_desc,
+ dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn],
+ name="SFA/SFB/SFD",
+ )
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._check_dtype(
+ self.sfb_desc,
+ dtype=self.sf_dtype,
+ name="SFB",
+ extra_error_msg="SFB must have the same dtype as SFA",
+ )
+ self._check_dtype(
+ self.sfd_row_desc,
+ dtype=self.sf_dtype,
+ name="SFD_row",
+ extra_error_msg="SFD_row must have the same dtype as SFA",
+ )
+ self._check_dtype(
+ self.sfd_col_desc,
+ dtype=self.sf_dtype,
+ name="SFD_col",
+ extra_error_msg="SFD_col must have the same dtype as SFA",
+ )
+
+ self._value_error_if(
+ self.sf_vec_size not in [16, 32],
+ f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}",
+ )
+ self._value_error_if(
+ self.sf_dtype in [torch.float8_e4m3fn] and self.sf_vec_size == 32,
+ f"sf_dtype {self.sf_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported",
+ )
+ self._value_error_if(
+ self._is_fp8(self.ab_dtype) and self.sf_vec_size == 16,
+ f"ab_dtype {self.ab_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported",
+ )
+
+ self._check_dtype(
+ self.acc_dtype,
+ dtype=torch.float32,
+ name="Accumulator",
+ extra_error_msg="Accumulator must be float32",
+ )
+ self._check_dtype(
+ self.prob_desc,
+ dtype=torch.float32,
+ name="Prob",
+ extra_error_msg="Prob must be float32",
+ )
+ self._check_dtype(
+ self.dprob_desc,
+ dtype=torch.float32,
+ name="Dprob",
+ extra_error_msg="Dprob must be float32",
+ )
+ self._check_dtype(
+ self.dbias_desc,
+ dtype=torch.bfloat16,
+ name="Dbias",
+ extra_error_msg="dbias must be bfloat16",
+ )
+ self.c_dtype = self._check_dtype(
+ self.c_desc,
+ dtype=[torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2],
+ name="C",
+ )
+ if self._is_fp8(self.c_dtype) and self.vector_f32:
+ raise ValueError("Invalid configuration: fp8 c_dtype and vector_f32 is not supported. " "Please use vector_f32=False or c_dtype=bfloat16 instead")
+
+ if self._is_fp4x2(self.ab_dtype):
+ self.d_dtype = self._check_dtype(
+ self.d_row_desc,
+ dtype=[torch.float16, torch.bfloat16, torch.float32],
+ name="D_row",
+ extra_error_msg="D_row must be fp16, bf16, or float32 when ab_dtype is fp4",
+ )
+ elif self._is_fp8(self.ab_dtype):
+ self.d_dtype = self._check_dtype(
+ self.d_row_desc,
+ dtype=[
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ ],
+ name="D_row",
+ extra_error_msg="D_row must be fp8 dtype when ab_dtype is fp8",
+ )
+ else:
+ raise NotImplementedError(f"Invalid ab_dtype: {self.ab_dtype}, expected fp4 or fp8")
+ self._check_dtype(
+ self.d_col_desc,
+ dtype=self.d_dtype,
+ name="D_col",
+ extra_error_msg="D_col must have the same dtype as D_row",
+ )
+
+ # ---- SFD generation logic ----
+ kernel_generate_sfd = self._is_fp8(self.ab_dtype) and self.sf_dtype == torch.float8_e8m0fnu and self._is_fp8(self.d_dtype)
+ self._value_error_if(
+ kernel_generate_sfd and not self._user_requested_sfd,
+ "sfd_row, sfd_col, and norm_const are required for FP8 input/FP8 output with sf_dtype=torch.float8_e8m0fnu",
+ )
+ if not kernel_generate_sfd and self._user_requested_sfd:
+ self._logger.warning(
+ "sfd_row/sfd_col/norm_const were provided, but this configuration does not generate SFD outputs; " "the tensors will be ignored by the kernel",
+ )
+ self.generate_sfd = kernel_generate_sfd
+ if self.discrete_col_sfd and not self.generate_sfd:
+ self._logger.warning("discrete_col_sfd is True but generate_sfd is False, discrete_col_sfd will be ignored")
+ self.discrete_col_sfd = False
+
+ # ---- Activation function validation ----
+ # ---- Discrete-mode-specific validation ----
+ if self.weight_mode == MoEWeightMode.DISCRETE:
+ self._value_error_if(
+ self.b_major not in ["k", "n"],
+ f"b_major must be 'k' or 'n', got {self.b_major}",
+ )
+ self._value_error_if(
+ self._is_fp4x2(self.ab_dtype) and self.b_major != "k",
+ "b_major must be 'k' when ab_dtype is fp4",
+ )
+
+ # ---- MMA tile / cluster shape ----
+ self._logger.debug("Checking MMA tile shape and cluster shape")
+ self._value_error_if(
+ not self.use_2cta_instrs and self.mma_tiler_mn[0] != 128,
+ f"MMA tiler M must be 128 when use_2cta_instrs=False, got {self.mma_tiler_mn[0]}",
+ )
+ self._value_error_if(
+ self.use_2cta_instrs and self.mma_tiler_mn[0] != 256,
+ f"MMA tiler M must be 256 when use_2cta_instrs=True, got {self.mma_tiler_mn[0]}",
+ )
+ self._value_error_if(
+ self.mma_tiler_mn[1] != 256,
+ f"MMA tiler N must be 256, got {self.mma_tiler_mn[1]}",
+ )
+ self._value_error_if(
+ self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0,
+ f"cluster_shape_mn[0] must be divisible by 2 when use_2cta_instrs=True, got {self.cluster_shape_mn[0]}",
+ )
+ self._value_error_if(
+ not (
+ self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16
+ and self.cluster_shape_mn[0] > 0
+ and self.cluster_shape_mn[1] > 0
+ and self.cluster_shape_mn[0] <= 4
+ and self.cluster_shape_mn[1] <= 4
+ and is_power_of_2(self.cluster_shape_mn[0])
+ and is_power_of_2(self.cluster_shape_mn[1])
+ ),
+ f"Invalid cluster shape: expected values to be powers of 2 and product <= 16, got {self.cluster_shape_mn}",
+ )
+ cluster_tiler_m = (self.cluster_shape_mn[0] // (2 if self.use_2cta_instrs else 1)) * self.mma_tiler_mn[0]
+ self._value_error_if(
+ cluster_tiler_m not in [128, 256],
+ f"Invalid cluster tiler shape: expected cluster_tiler_m in {{128, 256}}, got {cluster_tiler_m}",
+ )
+ self._value_error_if(
+ self.m_aligned % self.mma_tiler_mn[0] != 0,
+ f"m_aligned must be divisible by mma_tiler_mn[0], got {self.m_aligned} % {self.mma_tiler_mn[0]} != 0",
+ )
+ self._value_error_if(
+ self.m_aligned != BlockScaledMoEGroupedGemmQuantBwdKernel.FIX_PAD_SIZE,
+ f"m_aligned must be {BlockScaledMoEGroupedGemmQuantBwdKernel.FIX_PAD_SIZE} (FIX_PAD_SIZE), got {self.m_aligned}",
+ )
+
+ # ---- Tensor alignment ----
+ self._logger.debug("Checking tensor alignment")
+
+ def check_contiguous_16B_alignment(dtype, stride_order, tensor_shape):
+ is_mode0_major = stride_order == (0, 1, 2)
+ major_mode_idx = 0 if is_mode0_major else 1
+ num_major_elements = tensor_shape[major_mode_idx]
+ num_contiguous_elements = 16 * 8 // (_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2).width)
+ return num_major_elements % num_contiguous_elements == 0
+
+ if self.weight_mode == MoEWeightMode.DENSE:
+ b_stride_order_for_check = self.b_desc.stride_order
+ b_shape_for_check = (n, k, l)
+ else:
+ b_stride_order_for_check = (0, 1, 2) if self.b_major == "n" else (1, 0, 2)
+ b_shape_for_check = (n, k, 1)
+
+ self._value_error_if(
+ not (
+ check_contiguous_16B_alignment(self.ab_dtype, self.a_desc.stride_order, (tensor_m, k, l))
+ and check_contiguous_16B_alignment(self.ab_dtype, b_stride_order_for_check, b_shape_for_check)
+ and check_contiguous_16B_alignment(self.d_dtype, self.d_row_desc.stride_order, (tensor_m, n_out, 1))
+ ),
+ "Invalid tensor alignment: tensors must be 16B aligned",
+ )
+
+ # ---- Expert count limit ----
+ self._value_error_if(
+ self.expert_cnt > 1024,
+ f"expert_cnt must be <= 1024, got {self.expert_cnt}",
+ )
+
+ # ---- Disabled configurations ----
+ self._not_implemented_error_if(
+ self.dbias_desc is None and self._is_fp4x2(self.ab_dtype) and self.sf_vec_size == 16 and self.d_dtype == torch.float32,
+ "Invalid configuration: fp4 ab_dtype, sf_vec_size 16, d_dtype float32 is not supported. " "Please use sf_vec_size 32 or d_dtype bf16 instead",
+ )
+
+ # ---- SM100+ check ----
+ if not torch.cuda.is_available():
+ raise RuntimeError("CUDA is not available")
+ device = torch.cuda.current_device()
+ major, minor = torch.cuda.get_device_capability(device)
+ compute_capability = major * 10 + minor
+ if compute_capability < 100:
+ raise RuntimeError(f"GroupedGemmDsrelu requires SM100+ compute capability, " f"but found SM{compute_capability} on device {device}")
+
+ self._is_supported = True
+ self._logger.debug("check_support completed successfully")
+ return True
+
+ def compile(self) -> None:
+ """Compile the kernel."""
+ self._logger.debug("Entering compile")
+ self._ensure_support_checked()
+ if self._compiled_kernel is not None:
+ self._logger.debug("Kernel already compiled; skipping recompilation")
+ return
+ if self.a_desc.shape[0] == 0:
+ self._logger.debug("sample valid_m is zero, skipping kernel compilation")
+ return
+
+ gemm_dsrelu = self._kernel(
+ sf_vec_size=self.sf_vec_size,
+ acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype),
+ use_2cta_instrs=self.use_2cta_instrs,
+ mma_tiler_mn=self.mma_tiler_mn,
+ cluster_shape_mn=self.cluster_shape_mn,
+ vectorized_f32=self.vector_f32,
+ generate_sfd=self.generate_sfd,
+ discrete_col_sfd=self.discrete_col_sfd,
+ expert_cnt=self.expert_cnt,
+ weight_mode=self.weight_mode,
+ use_dynamic_sched=self.use_dynamic_sched,
+ epilogue_type=EpilogueType.DSRELU.value,
+ generate_dbias=self._has_dbias,
+ )
+
+ hardware_info = cutlass.utils.HardwareInfo()
+ max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1])
+ max_active_clusters -= self.num_cluster_overlap_margin
+ self._value_error_if(
+ max_active_clusters <= 0,
+ "max_active_clusters must be > 0 after applying overlap margin; reduce CUDNNFE_CLUSTER_OVERLAP_MARGIN",
+ )
+ fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False)
+
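+        # Fully dynamic M/N/K/L compilation is the default; set CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL=0 to disable it.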
+ self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
+
+ workspace_bytes = gemm_dsrelu.get_workspace_bytes()
+ self._workspace = torch.empty(max(workspace_bytes, 1), dtype=torch.uint8, device="cuda")
+
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._compile_dense(gemm_dsrelu, max_active_clusters, fake_stream)
+ else:
+ self._compile_discrete(gemm_dsrelu, max_active_clusters, fake_stream)
+
+ self._logger.debug("Kernel compiled successfully")
+
+ def _compile_dense(self, gemm_dsrelu, max_active_clusters, fake_stream) -> None:
+ """Compile for dense (contiguous) weight mode."""
+ self._logger.debug("Compiling grouped_gemm_dsrelu kernel")
+ use_full_dynamic = self._use_full_dynamic_mnkl
+
+ fake_workspace_ptr = cute.runtime.nullptr(
+ dtype=cutlass.Uint8,
+ assumed_align=128,
+ )
+
+ if not use_full_dynamic:
+ valid_m = cute.sym_int(divisibility=256)
+
+ a_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, *self.a_desc.shape[1:]),
+ stride_order=self.a_desc.stride_order,
+ )
+ b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16)
+ c_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, *self.c_desc.shape[1:]),
+ stride_order=self.c_desc.stride_order,
+ )
+ d_row_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_row_desc.dtype,
+ shape=(valid_m, *self.d_row_desc.shape[1:]),
+ stride_order=self.d_row_desc.stride_order,
+ )
+ d_col_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_col_desc.dtype,
+ shape=(valid_m, *self.d_col_desc.shape[1:]),
+ stride_order=self.d_col_desc.stride_order,
+ )
+
+ tensor_m_128 = cute.sym_int()
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], 1),
+ stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128),
+ )
+
+ sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16)
+
+ prob_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, 1, 1),
+ stride_order=self.prob_desc.stride_order,
+ )
+ dprob_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.dprob_desc.dtype,
+ shape=(valid_m, 1, 1),
+ stride_order=self.dprob_desc.stride_order,
+ )
+
+ sfd_row_fake = None
+ sfd_col_fake = None
+ if self.sfd_row_desc is not None:
+ stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_row_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_row_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfd_row_desc.shape[4], 1),
+ stride=(16, 4, self.sfd_row_desc.stride[2], 1, 512, stride_sfd_m),
+ )
+ if self.sfd_col_desc is not None:
+ rest_m = cute.sym_int(divisibility=1)
+ stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_rest_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_col_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_col_desc.dtype,
+ shape=(32, 4, self.sfd_col_desc.shape[2], 4, rest_m, 1),
+ stride=(16, 4, stride_rest_m, 1, 512, stride_sfd_n),
+ )
+ else:
+ valid_m = cute.sym_int(divisibility=256)
+ n_sym = cute.sym_int()
+ n_out_sym = cute.sym_int()
+ k_sym = cute.sym_int()
+ l_sym = cute.sym_int()
+
+ a_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, k_sym, 1),
+ stride_order=self.a_desc.stride_order,
+ dynamic_mode=self.a_desc.stride_order[0],
+ divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
+ )
+ b_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.b_desc.dtype,
+ shape=(n_sym, k_sym, l_sym),
+ stride_order=self.b_desc.stride_order,
+ dynamic_mode=self.b_desc.stride_order[0],
+ divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
+ )
+
+ c_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, n_out_sym, 1),
+ stride_order=self.c_desc.stride_order,
+ dynamic_mode=self.c_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.c_desc.dtype) else 16,
+ )
+
+ d_row_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_row_desc.dtype,
+ shape=(valid_m, n_out_sym, 1),
+ stride_order=self.d_row_desc.stride_order,
+ dynamic_mode=self.d_row_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.d_row_desc.dtype) else 16,
+ )
+
+ d_col_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_col_desc.dtype,
+ shape=(valid_m, n_out_sym, 1),
+ stride_order=self.d_col_desc.stride_order,
+ dynamic_mode=self.d_col_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.d_col_desc.dtype) else 16,
+ )
+
+ tensor_m_128 = cute.sym_int()
+ rest_k = cute.sym_int()
+ stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_shape = list(self.sfa_desc.shape)
+ sfa_shape[2] = tensor_m_128
+ sfa_shape[4] = rest_k
+ sfa_stride = list(self.sfa_desc.stride)
+ sfa_stride[2] = stride_rest_k
+ sfa_stride[5] = stride_tensor_m_128
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=tuple(sfa_shape),
+ stride=tuple(sfa_stride),
+ )
+
+ tensor_n_128 = cute.sym_int()
+ stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfb_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfb_desc.dtype,
+ shape=(32, 4, tensor_n_128, 4, rest_k, l_sym),
+ stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k),
+ )
+
+ prob_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, *self.prob_desc.shape[1:]),
+ stride=self.prob_desc.stride,
+ )
+ dprob_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.dprob_desc.dtype,
+ shape=(valid_m, *self.dprob_desc.shape[1:]),
+ stride=self.dprob_desc.stride,
+ )
+
+ sfd_row_fake = None
+ sfd_col_fake = None
+ if self.sfd_row_desc is not None:
+ rest_n_out = cute.sym_int()
+ stride_sfd_rest_n_out = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfd_rest_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_row_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_row_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, rest_n_out, 1),
+ stride=(16, 4, stride_sfd_rest_n_out, 1, 512, stride_sfd_rest_tensor_m_128),
+ )
+ if self.sfd_col_desc is not None:
+ tensor_n_out_128 = cute.sym_int()
+ rest_m_dyn = cute.sym_int()
+ stride_sfd_rest_m = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfd_n_out = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_col_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_col_desc.dtype,
+ shape=(32, 4, tensor_n_out_128, 4, rest_m_dyn, 1),
+ stride=(16, 4, stride_sfd_rest_m, 1, 512, stride_sfd_n_out),
+ )
+
+ dbias_fake = self._make_fake_cute_tensor_from_desc(self.dbias_desc, assumed_align=16)
+
+ _compiled_kernel = cute.compile(
+ gemm_dsrelu,
+ a=_reinterpret_raw_grouped_fp4_tensor(self._sample_a_tensor) if self.a_desc.dtype == torch.uint8 else a_cute_fake,
+ b=_reinterpret_raw_grouped_fp4_tensor(self._sample_b_tensor) if self.b_desc.dtype == torch.uint8 else b_cute_fake,
+ sfb=sfb_cute_fake,
+ n=cutlass.Int32(0),
+ k=cutlass.Int32(0),
+ b_stride_size=cutlass.Int64(0),
+ b_major_mode=OperandMajorMode.K,
+ workspace_ptr=fake_workspace_ptr,
+ c=c_cute_fake,
+ d=d_row_cute_fake,
+ d_col=d_col_cute_fake,
+ sfa=sfa_cute_fake,
+ sfd_row_tensor=sfd_row_fake,
+ sfd_col_tensor=sfd_col_fake,
+ amax_tensor=self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16),
+ norm_const_tensor=self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16),
+ padded_offsets=self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16),
+ alpha=self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16),
+ prob=prob_cute_fake,
+ dprob=dprob_cute_fake,
+ dbias_tensor=dbias_fake,
+ max_active_clusters=max_active_clusters,
+ stream=fake_stream,
+ options="--enable-tvm-ffi",
+ )
+
+ cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator
+
+ def tensor_api(
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_row_tensor: torch.Tensor,
+ d_col_tensor: Optional[torch.Tensor],
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ sfd_row_tensor: Optional[torch.Tensor],
+ sfd_col_tensor: Optional[torch.Tensor],
+ amax_tensor: Optional[torch.Tensor],
+ norm_const_tensor: Optional[torch.Tensor],
+ padded_offsets: torch.Tensor,
+ alpha_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ dprob_tensor: torch.Tensor,
+ dbias_tensor: Optional[torch.Tensor],
+ stream: cuda.CUstream,
+ ) -> None:
+ norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const")
+ _compiled_kernel(
+ _reinterpret_raw_grouped_fp4_tensor(a_tensor),
+ _reinterpret_raw_grouped_fp4_tensor(b_tensor),
+ sfb_tensor,
+ cutlass.Int32(0),
+ cutlass.Int32(0),
+ cutlass.Int64(0),
+ cached_workspace_ptr,
+ c_tensor,
+ d_row_tensor,
+ d_col_tensor,
+ sfa_tensor,
+ sfd_row_tensor,
+ sfd_col_tensor,
+ amax_tensor,
+ norm_const_tensor,
+ padded_offsets,
+ alpha_tensor,
+ prob_tensor,
+ dprob_tensor,
+ dbias_tensor,
+ stream,
+ )
+
+ self._compiled_kernel = tensor_api
+
+ def _compile_discrete(self, gemm_dsrelu, max_active_clusters, fake_stream) -> None:
+ """Compile for discrete (per-expert pointer) weight mode."""
+ if len(self.b_shape) == 2:
+ n, k = self.b_shape
+ else:
+ n, k, _ = self.b_shape
+
+ b_major_mode = OperandMajorMode.K if self.b_major == "k" else OperandMajorMode.MN
+ if self.b_major == "k":
+ b_stride_size = k
+ else:
+ b_stride_size = n
+
+ ab_cutlass_dtype = _convert_to_cutlass_data_type(self.a_desc.dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2)
+ align = 32 if ab_cutlass_dtype.width == 4 else 16
+
+ valid_m = cute.sym_int(divisibility=256)
+ a_tensor = self._make_fake_cute_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, *self.a_desc.shape[1:]),
+ stride=(self.a_desc.stride[0], *self.a_desc.stride[1:]),
+ assumed_align=align,
+ )
+ c_tensor = self._make_fake_cute_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, *self.c_desc.shape[1:]),
+ stride=(self.c_desc.stride[0], *self.c_desc.stride[1:]),
+ )
+ d_row_tensor = self._make_fake_cute_compact_tensor(
+ dtype=self.d_row_desc.dtype,
+ shape=(valid_m, *self.d_row_desc.shape[1:]),
+ stride_order=self.d_row_desc.stride_order,
+ )
+ d_col_tensor = self._make_fake_cute_compact_tensor(
+ dtype=self.d_col_desc.dtype,
+ shape=(valid_m, *self.d_col_desc.shape[1:]),
+ stride_order=self.d_col_desc.stride_order,
+ )
+
+ tensor_m_128 = cute.sym_int()
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_shape = list(self.sfa_desc.shape)
+ sfa_shape[2] = tensor_m_128
+ sfa_stride = list(self.sfa_desc.stride)
+ sfa_stride[5] = stride_tensor_m_128
+ sfa_tensor = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=tuple(sfa_shape),
+ stride=tuple(sfa_stride),
+ assumed_align=16,
+ )
+ sfd_row_tensor = None
+ if self.sfd_row_desc is not None:
+ stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_row_tensor = self._make_fake_cute_tensor(
+ dtype=self.sfd_row_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfd_row_desc.shape[4], 1),
+ stride=(16, 4, self.sfd_row_desc.stride[2], 1, 512, stride_sfd_m),
+ assumed_align=16,
+ )
+ sfd_col_tensor = None
+ if self.sfd_col_desc is not None:
+ rest_m = cute.sym_int(divisibility=1)
+ stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_rest_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_col_tensor = self._make_fake_cute_tensor(
+ dtype=self.sfd_col_desc.dtype,
+ shape=(32, 4, self.sfd_col_desc.shape[2], 4, rest_m, 1),
+ stride=(16, 4, stride_rest_m, 1, 512, stride_sfd_n),
+ assumed_align=16,
+ )
+ amax_tensor = self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16)
+ norm_const_tensor_cute = self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16)
+ padded_offsets_tensor = self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16)
+ alpha_tensor = self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16)
+ prob_tensor = self._make_fake_cute_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, *self.prob_desc.shape[1:]),
+ stride=self.prob_desc.stride,
+ assumed_align=16,
+ )
+ dprob_tensor = self._make_fake_cute_tensor(
+ dtype=self.dprob_desc.dtype,
+ shape=(valid_m, *self.dprob_desc.shape[1:]),
+ stride=self.dprob_desc.stride,
+ assumed_align=16,
+ )
+ dbias_tensor = self._make_fake_cute_tensor_from_desc(self.dbias_desc, assumed_align=16)
+
+ b_ptrs_placeholder = torch.empty((self.expert_cnt,), dtype=torch.int64, device="cuda")
+ sfb_ptrs_placeholder = torch.empty((self.expert_cnt,), dtype=torch.int64, device="cuda")
+ b_ptrs_cute = from_dlpack(b_ptrs_placeholder, assumed_align=8).iterator
+ sfb_ptrs_cute = from_dlpack(sfb_ptrs_placeholder, assumed_align=8).iterator
+
+ workspace_ptr_cute = from_dlpack(self._workspace, assumed_align=128).iterator
+
+ self._logger.debug("Compiling discrete grouped GEMM dSReLU kernel")
+ _compiled_kernel = cute.compile(
+ gemm_dsrelu,
+ a_tensor,
+ b_ptrs_cute,
+ sfb_ptrs_cute,
+ cutlass.Int32(n),
+ cutlass.Int32(k),
+ cutlass.Int64(b_stride_size),
+ b_major_mode,
+ workspace_ptr_cute,
+ c_tensor,
+ d_row_tensor,
+ d_col_tensor,
+ sfa_tensor,
+ sfd_row_tensor,
+ sfd_col_tensor,
+ amax_tensor,
+ norm_const_tensor_cute,
+ padded_offsets_tensor,
+ alpha_tensor,
+ prob_tensor,
+ dprob_tensor,
+ dbias_tensor,
+ max_active_clusters,
+ fake_stream,
+ options="--enable-tvm-ffi",
+ )
+
+ self._n = n
+ self._k = k
+ self._b_stride_size = b_stride_size
+
+ cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator
+ cached_n = cutlass.Int32(self._n)
+ cached_k = cutlass.Int32(self._k)
+ cached_b_stride = cutlass.Int64(self._b_stride_size)
+
+ def tensor_api(
+ a_tensor: torch.Tensor,
+ b_ptrs_device: torch.Tensor,
+ sfb_ptrs_device: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_row_tensor: torch.Tensor,
+ d_col_tensor: Optional[torch.Tensor],
+ sfa_tensor: torch.Tensor,
+ sfd_row_tensor: Optional[torch.Tensor],
+ sfd_col_tensor: Optional[torch.Tensor],
+ amax_tensor: Optional[torch.Tensor],
+ norm_const_tensor: Optional[torch.Tensor],
+ padded_offsets: torch.Tensor,
+ alpha_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ dprob_tensor: torch.Tensor,
+ dbias_tensor: Optional[torch.Tensor],
+ stream: cuda.CUstream,
+ ) -> None:
+ norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const")
+ b_ptrs_addr = int(b_ptrs_device.data_ptr())
+ sfb_ptrs_addr = int(sfb_ptrs_device.data_ptr())
+
+ _compiled_kernel(
+ a_tensor,
+ b_ptrs_addr,
+ sfb_ptrs_addr,
+ cached_n,
+ cached_k,
+ cached_b_stride,
+ cached_workspace_ptr,
+ c_tensor,
+ d_row_tensor,
+ d_col_tensor,
+ sfa_tensor,
+ sfd_row_tensor,
+ sfd_col_tensor,
+ amax_tensor,
+ norm_const_tensor,
+ padded_offsets,
+ alpha_tensor,
+ prob_tensor,
+ dprob_tensor,
+ dbias_tensor,
+ stream,
+ )
+
+ self._compiled_kernel = tensor_api
+
+ def execute(
+ self,
+ a_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_row_tensor: torch.Tensor,
+ d_col_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ padded_offsets: torch.Tensor,
+ alpha_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ dprob_tensor: torch.Tensor,
+ # Dense mode:
+ b_tensor: Optional[torch.Tensor] = None,
+ sfb_tensor: Optional[torch.Tensor] = None,
+ dbias_tensor: Optional[torch.Tensor] = None,
+ # Discrete mode:
+ b_ptrs: Optional[torch.Tensor] = None,
+ sfb_ptrs: Optional[torch.Tensor] = None,
+ # Optional:
+ sfd_row_tensor: Optional[torch.Tensor] = None,
+ sfd_col_tensor: Optional[torch.Tensor] = None,
+ amax_tensor: Optional[torch.Tensor] = None,
+ norm_const_tensor: Optional[torch.Tensor] = None,
+ current_stream: Optional[cuda.CUstream] = None,
+ ) -> None:
+ """Execute the compiled kernel.
+
+ For dense mode, supply ``b_tensor`` and ``sfb_tensor``.
+ For discrete mode, supply ``b_ptrs`` and ``sfb_ptrs``.
+
+ :param a_tensor: Input A tensor (gradient input)
+ :param c_tensor: Forward activations input
+ :param d_row_tensor: Output D row tensor
+ :param d_col_tensor: Output D column tensor
+ :param sfa_tensor: Scale factor A
+ :param padded_offsets: End offset per expert after padding
+ :param alpha_tensor: Per-group alpha scaling factors
+ :param prob_tensor: Per-row probability (from forward)
+ :param dprob_tensor: Gradient of probability (output, must be zero-initialized)
+ :param b_tensor: (Dense) Input B tensor (weights)
+ :param sfb_tensor: (Dense) Scale factor B
+ :param dbias_tensor: Optional dbias output tensor.
+ :param b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers
+ :param sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers
+ :param sfd_row_tensor: Optional row scale factor D
+ :param sfd_col_tensor: Optional column scale factor D
+ :param amax_tensor: Optional amax tensor
+ :param norm_const_tensor: Optional normalization constant
+ :param current_stream: CUDA stream
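+
+        A minimal dense-mode call (placeholder tensors, prepared to match the
+        sample tensors passed to the constructor)::
+
+            api.execute(
+                a_tensor=a, c_tensor=c,
+                d_row_tensor=d_row, d_col_tensor=d_col,
+                sfa_tensor=sfa, padded_offsets=offsets,
+                alpha_tensor=alpha, prob_tensor=prob, dprob_tensor=dprob,
+                b_tensor=b, sfb_tensor=sfb,
+            )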
+ """
+ self._logger.debug("Entering execute")
+ current_stream = self._get_default_stream(current_stream)
+
+ if a_tensor.shape[0] == 0:
+ self._logger.debug("execute: valid_m is zero, skipping kernel execution")
+ return
+ self._runtime_error_if(
+ self._compiled_kernel is None,
+ "Kernel not compiled; call compile() first",
+ )
+
+ self._logger.debug("Executing grouped GEMM dSReLU kernel")
+ if self._has_dbias:
+ self._value_error_if(
+ dbias_tensor is None,
+ "dbias_tensor is required when GroupedGemmDsreluSm100 is configured with sample_dbias",
+ )
+
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._compiled_kernel(
+ a_tensor=a_tensor,
+ b_tensor=b_tensor,
+ c_tensor=c_tensor,
+ d_row_tensor=d_row_tensor,
+ d_col_tensor=d_col_tensor,
+ sfa_tensor=sfa_tensor,
+ sfb_tensor=sfb_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ prob_tensor=prob_tensor,
+ dprob_tensor=dprob_tensor,
+ dbias_tensor=dbias_tensor,
+ stream=current_stream,
+ )
+ else:
+ self._compiled_kernel(
+ a_tensor=a_tensor,
+ b_ptrs_device=b_ptrs,
+ sfb_ptrs_device=sfb_ptrs,
+ c_tensor=c_tensor,
+ d_row_tensor=d_row_tensor,
+ d_col_tensor=d_col_tensor,
+ sfa_tensor=sfa_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ prob_tensor=prob_tensor,
+ dprob_tensor=dprob_tensor,
+ dbias_tensor=dbias_tensor,
+ stream=current_stream,
+ )
+
+ self._logger.debug("Execute completed")
+
+
+_logger = logging.getLogger(__name__)
+_cache_of_GroupedGemmDsreluSm100Objects = {}
+
+
+def grouped_gemm_dsrelu_wrapper_sm100(
+ a_tensor: torch.Tensor,
+ b_tensor: Optional[torch.Tensor] = None,
+ c_tensor: Optional[torch.Tensor] = None,
+ sfa_tensor: Optional[torch.Tensor] = None,
+ sfb_tensor: Optional[torch.Tensor] = None,
+ padded_offsets: Optional[torch.Tensor] = None,
+ alpha_tensor: Optional[torch.Tensor] = None,
+ prob_tensor: Optional[torch.Tensor] = None,
+ dprob_tensor: Optional[torch.Tensor] = None,
+ # generate_dbias is optional in both modes:
+ generate_dbias: bool = False,
+ # Discrete mode:
+ b_ptrs: Optional[torch.Tensor] = None,
+ sfb_ptrs: Optional[torch.Tensor] = None,
+ n: Optional[int] = None,
+ b_dtype: Optional[torch.dtype] = None,
+ b_major: str = "k",
+ # Common:
+ norm_const_tensor: Optional[torch.Tensor] = None,
+ acc_dtype: torch.dtype = torch.float32,
+ d_dtype: torch.dtype = torch.bfloat16,
+ cd_major: str = "n",
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ m_aligned: int = 256,
+ discrete_col_sfd: bool = False,
+ use_dynamic_sched: bool = False,
+ current_stream: Optional[cuda.CUstream] = None,
+) -> TupleDict:
+ """Convenience wrapper for grouped GEMM dSReLU backward operation.
+
+ Auto-detects dense vs. discrete mode based on which weight arguments
+ are provided.
+
+ Dense mode: provide ``b_tensor`` and ``sfb_tensor``.
+ Discrete mode: provide ``b_ptrs``, ``sfb_ptrs``, ``n``, and ``b_dtype``.
+
+ Compiled kernels are cached for reuse when called with the same configuration.
+
+ Args:
+ a_tensor: Input A tensor (valid_m, k, 1) -- gradient input
+ c_tensor: Forward activations input (valid_m, n_out, 1)
+ sfa_tensor: Scale factor A
+ padded_offsets: End offset per expert after padding
+ alpha_tensor: Per-group alpha scaling
+ prob_tensor: Per-row probability (from forward)
+ dprob_tensor: Gradient of probability (output, must be zero-initialized)
+ b_tensor: (Dense) Weight B tensor (n, k, l)
+ sfb_tensor: (Dense) Scale factor B
+ generate_dbias: Optional flag to allocate and return dbias output
+ b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers
+ sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers
+ n: (Discrete) B weight N dimension
+ b_dtype: (Discrete) B weight data type
+ b_major: (Discrete) B tensor major dimension ("k" or "n")
+ norm_const_tensor: Optional normalization constant
+ acc_dtype: Accumulator data type
+ d_dtype: Output D tensor data type
+ cd_major: CD major dimension (only "n" supported)
+ mma_tiler_mn: MMA tiler shape
+ cluster_shape_mn: Cluster shape
+ sf_vec_size: Scale factor vector size
+ vector_f32: Use vectorized f32
+ m_aligned: M alignment (must be 256)
+ discrete_col_sfd: Generate discrete col-major scale factor tensor
+ use_dynamic_sched: Enable dynamic tile scheduling for load balancing
+ current_stream: CUDA stream
+
+ Returns:
+ TupleDict with keys: d_row_tensor, d_col_tensor, dprob_tensor,
+ dbias_tensor, amax_tensor, sfd_row_tensor, sfd_col_tensor
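+
+ Example (illustrative sketch only; the tensor names, shapes, and dtypes below are
+ assumptions chosen to mirror the Args description, not a prescriptive recipe):
+
+ result = grouped_gemm_dsrelu_wrapper_sm100(
+ a_tensor=grad_out, # (valid_m, k, 1) incoming gradient
+ b_tensor=weights, # (n, k, num_experts) dense weights
+ sfb_tensor=weight_scales, # scale factor B (fp8 path)
+ c_tensor=fwd_activations, # (valid_m, n, 1) forward activations
+ sfa_tensor=grad_scales, # scale factor A (fp8 path)
+ padded_offsets=padded_offsets, # per-expert end offsets after padding
+ alpha_tensor=alpha,
+ prob_tensor=prob,
+ generate_dbias=True,
+ )
+ # result is a TupleDict with the outputs listed under Returns.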
+ """
+ from cudnn.discrete_grouped_gemm.discrete_kernel_utils import _require_pointer_tensor
+
+ is_dense = b_tensor is not None
+ is_discrete = b_ptrs is not None
+
+ if is_dense and is_discrete:
+ raise ValueError("Provide either (b_tensor, sfb_tensor) or (b_ptrs, sfb_ptrs), not both")
+ if not is_dense and not is_discrete:
+ raise ValueError("Must provide either (b_tensor, sfb_tensor) or (b_ptrs, sfb_ptrs)")
+
+ valid_m, k_physical, _ = a_tensor.shape
+
+ if is_dense:
+ weight_mode = MoEWeightMode.DENSE
+ n_weight, _, l = b_tensor.shape
+ else:
+ weight_mode = MoEWeightMode.DISCRETE
+ _require_pointer_tensor(b_ptrs, "b_ptrs")
+ num_experts = b_ptrs.shape[0]
+ _require_pointer_tensor(sfb_ptrs, "sfb_ptrs", num_experts)
+ if n is None or b_dtype is None:
+ raise ValueError("n and b_dtype are required for discrete mode")
+ n_weight = n
+ k_logical = k_physical * 2 if b_dtype in (torch.float4_e2m1fn_x2, torch.uint8) else k_physical
+ b_shape = (n_weight, k_logical)
+ l = num_experts
+
+ n_out = n_weight
+
+ _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Creating output tensors")
+
+ if cd_major == "n":
+ d_row_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device)
+ d_col_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device)
+ else:
+ raise ValueError(f"cd_major must be 'n', got {cd_major}")
+
+ sfd_row_tensor = None
+ sfd_col_tensor = None
+ amax_tensor = None
+ dbias_tensor = None
+
+ if dprob_tensor is None:
+ dprob_tensor = torch.zeros((valid_m, 1, 1), dtype=torch.float32, device=a_tensor.device)
+
+ if a_tensor.dtype in [
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ ] and sfa_tensor.dtype in [torch.float8_e8m0fnu, torch.float8_e4m3fn]:
+ _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Detected fp8 config, constructing sfd tensors")
+
+ sf_dtype = sfa_tensor.dtype
+ mma_permute_order = (3, 4, 1, 5, 2, 0)
+
+ sf_k_row = ceil_div(n_out, sf_vec_size)
+ mma_shape_row = (1, ceil_div(valid_m, 128), ceil_div(sf_k_row, 4), 32, 4, 4)
+ sfd_row_tensor = torch.empty(mma_shape_row, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order)
+
+ sf_k_col = ceil_div(valid_m, sf_vec_size)
+ mma_shape_col = (1, ceil_div(n_out, 128), ceil_div(sf_k_col, 4), 32, 4, 4)
+ sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order)
+
+ if d_dtype in [torch.bfloat16, torch.float16]:
+ _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Constructing amax_tensor")
+ amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device)
+ if generate_dbias:
+ dbias_tensor = torch.zeros((l, n_out, 1), dtype=torch.bfloat16, device=a_tensor.device)
+
+ if valid_m == 0:
+ _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: valid_m is zero, skipping kernel execution")
+ return TupleDict(
+ d_row_tensor=d_row_tensor,
+ d_col_tensor=d_col_tensor,
+ dprob_tensor=dprob_tensor,
+ dbias_tensor=dbias_tensor,
+ amax_tensor=amax_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ )
+
+ # ---- Build cache key ----
+ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
+ return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1]))
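+ # Illustration: a contiguous (2, 3) tensor has strides (3, 1), giving a
+ # stride_order of (1, 0); a transposed view with strides (1, 3) gives (0, 1).
+ # The cache keys below use this to capture memory-layout order independently
+ # of the concrete extents.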
+
+ def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype
+
+ def dynamic_tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return None, stride_order(tensor), tensor.dtype
+
+ def dynamic_m_tensor_signature(
+ tensor: Optional[torch.Tensor], static_shape_suffix: Optional[Tuple[int, ...]], dynamic_stride_dims: Tuple[int, ...] = ()
+ ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride()))
+ return static_shape_suffix, stride_signature, tensor.dtype
+
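+ # With full-dynamic compilation enabled (dense mode only; on by default via
+ # CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL), concrete M/N/K extents are left out of the
+ # cache key, so one compiled kernel is reused across shapes that share dtypes and
+ # stride ordering; set the variable to "0" to key and compile per exact shape.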
+ use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
+
+ if is_dense:
+ cache_key = (
+ weight_mode,
+ use_full_dynamic,
+ a_tensor.shape[1:] if not use_full_dynamic else None,
+ b_tensor.shape[2] if use_full_dynamic else tuple(b_tensor.shape),
+ c_tensor.shape[1:] if not use_full_dynamic else None,
+ a_tensor.dtype,
+ b_tensor.dtype,
+ c_tensor.dtype,
+ stride_order(a_tensor),
+ stride_order(b_tensor),
+ stride_order(c_tensor),
+ *(
+ dynamic_tensor_signature(sfa_tensor)
+ if use_full_dynamic
+ else dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,))
+ ),
+ *tensor_signature(alpha_tensor),
+ *(dynamic_m_tensor_signature(prob_tensor, (1, 1)) if not use_full_dynamic else dynamic_tensor_signature(prob_tensor)),
+ *(dynamic_m_tensor_signature(dprob_tensor, (1, 1)) if not use_full_dynamic else dynamic_tensor_signature(dprob_tensor)),
+ *(dynamic_tensor_signature(dbias_tensor) if use_full_dynamic else tensor_signature(dbias_tensor)),
+ *(dynamic_tensor_signature(sfb_tensor) if use_full_dynamic else tensor_signature(sfb_tensor)),
+ *tensor_signature(norm_const_tensor),
+ tuple(padded_offsets.shape),
+ tuple(padded_offsets.stride()),
+ padded_offsets.dtype,
+ acc_dtype,
+ d_dtype,
+ cd_major,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ sf_vec_size,
+ vector_f32,
+ m_aligned,
+ discrete_col_sfd,
+ use_dynamic_sched,
+ )
+ else:
+ cache_key = (
+ weight_mode,
+ *dynamic_m_tensor_signature(a_tensor, tuple(a_tensor.shape[1:]), dynamic_stride_dims=(2,)),
+ b_shape,
+ b_dtype,
+ *dynamic_m_tensor_signature(c_tensor, tuple(c_tensor.shape[1:]), dynamic_stride_dims=(2,)),
+ *dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,)),
+ *tensor_signature(alpha_tensor),
+ *dynamic_m_tensor_signature(prob_tensor, (1, 1)),
+ *dynamic_m_tensor_signature(dprob_tensor, (1, 1)),
+ *tensor_signature(dbias_tensor),
+ *tensor_signature(norm_const_tensor),
+ tuple(b_ptrs.shape),
+ tuple(b_ptrs.stride()),
+ b_ptrs.dtype,
+ tuple(sfb_ptrs.shape),
+ tuple(sfb_ptrs.stride()),
+ sfb_ptrs.dtype,
+ tuple(padded_offsets.shape),
+ tuple(padded_offsets.stride()),
+ padded_offsets.dtype,
+ acc_dtype,
+ d_dtype,
+ cd_major,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ sf_vec_size,
+ vector_f32,
+ m_aligned,
+ discrete_col_sfd,
+ use_dynamic_sched,
+ b_major,
+ num_experts,
+ )
+
+ # ---- Cache lookup or create + compile ----
+ if cache_key in _cache_of_GroupedGemmDsreluSm100Objects:
+ _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Using cached object")
+ api = _cache_of_GroupedGemmDsreluSm100Objects[cache_key]
+ else:
+ _logger.debug("grouped_gemm_dsrelu_wrapper_sm100: Creating new object")
+ if is_dense:
+ api = GroupedGemmDsreluSm100(
+ sample_a=a_tensor,
+ sample_c=c_tensor,
+ sample_d_row=d_row_tensor,
+ sample_d_col=d_col_tensor,
+ sample_sfa=sfa_tensor,
+ sample_padded_offsets=padded_offsets,
+ sample_alpha=alpha_tensor,
+ sample_prob=prob_tensor,
+ sample_dprob=dprob_tensor,
+ sample_dbias=dbias_tensor,
+ sample_b=b_tensor,
+ sample_sfb=sfb_tensor,
+ sample_sfd_row=sfd_row_tensor,
+ sample_sfd_col=sfd_col_tensor,
+ sample_amax=amax_tensor,
+ sample_norm_const=norm_const_tensor,
+ acc_dtype=acc_dtype,
+ mma_tiler_mn=mma_tiler_mn,
+ cluster_shape_mn=cluster_shape_mn,
+ sf_vec_size=sf_vec_size,
+ vector_f32=vector_f32,
+ m_aligned=m_aligned,
+ discrete_col_sfd=discrete_col_sfd,
+ use_dynamic_sched=use_dynamic_sched,
+ )
+ else:
+ api = GroupedGemmDsreluSm100(
+ sample_a=a_tensor,
+ sample_c=c_tensor,
+ sample_d_row=d_row_tensor,
+ sample_d_col=d_col_tensor,
+ sample_sfa=sfa_tensor,
+ sample_padded_offsets=padded_offsets,
+ sample_alpha=alpha_tensor,
+ sample_prob=prob_tensor,
+ sample_dprob=dprob_tensor,
+ sample_dbias=dbias_tensor,
+ num_experts=num_experts,
+ b_shape=b_shape,
+ b_dtype=b_dtype,
+ sample_sfd_row=sfd_row_tensor,
+ sample_sfd_col=sfd_col_tensor,
+ sample_amax=amax_tensor,
+ sample_norm_const=norm_const_tensor,
+ acc_dtype=acc_dtype,
+ mma_tiler_mn=mma_tiler_mn,
+ cluster_shape_mn=cluster_shape_mn,
+ sf_vec_size=sf_vec_size,
+ vector_f32=vector_f32,
+ m_aligned=m_aligned,
+ discrete_col_sfd=discrete_col_sfd,
+ b_major=b_major,
+ use_dynamic_sched=use_dynamic_sched,
+ )
+
+ if not api.check_support():
+ raise RuntimeError("Unsupported configuration")
+ api.compile()
+ _cache_of_GroupedGemmDsreluSm100Objects[cache_key] = api
+
+ # ---- Execute ----
+ if is_dense:
+ api.execute(
+ a_tensor=a_tensor,
+ c_tensor=c_tensor,
+ d_row_tensor=d_row_tensor,
+ d_col_tensor=d_col_tensor,
+ sfa_tensor=sfa_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ prob_tensor=prob_tensor,
+ dprob_tensor=dprob_tensor,
+ dbias_tensor=dbias_tensor,
+ b_tensor=b_tensor,
+ sfb_tensor=sfb_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ current_stream=current_stream,
+ )
+ else:
+ api.execute(
+ a_tensor=a_tensor,
+ c_tensor=c_tensor,
+ d_row_tensor=d_row_tensor,
+ d_col_tensor=d_col_tensor,
+ sfa_tensor=sfa_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ prob_tensor=prob_tensor,
+ dprob_tensor=dprob_tensor,
+ dbias_tensor=dbias_tensor,
+ b_ptrs=b_ptrs,
+ sfb_ptrs=sfb_ptrs,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ current_stream=current_stream,
+ )
+
+ return TupleDict(
+ d_row_tensor=d_row_tensor,
+ d_col_tensor=d_col_tensor,
+ dprob_tensor=dprob_tensor,
+ dbias_tensor=dbias_tensor,
+ amax_tensor=amax_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ )
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/moe_blockscaled_grouped_gemm_dsrelu_quant.py b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/moe_blockscaled_grouped_gemm_dsrelu_quant.py
new file mode 100644
index 00000000..d6db56a7
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_dsrelu/moe_blockscaled_grouped_gemm_dsrelu_quant.py
@@ -0,0 +1,2249 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+MoE Block-Scaled Grouped GEMM Backward Kernel with DSRELU Support.
+
+Supports:
+ - Static / Dynamic persistent tile scheduling (MoEPersistentTileScheduler)
+ - Dense (contiguous 3-D B) / Discrete (per-expert pointer array B) weight layout
+ - FP8/FP4 output quantization with row/column scale factors (SFD)
+ - Optional routing-probability (prob) fusion
+ - DSRELU backward epilogue: dA = relu(alpha·acc)·C·2·prob; dprob = Σ_n(relu(alpha·acc)²·C)
+ - NONE epilogue for testing (identity, no SReLU gate)
+
+EpilogueType.NONE:
+ out[m,n] = alpha · (SFA·A ★ SFB·B)[m,n] · prob[m]
+
+EpilogueType.DSRELU (backward through srelu):
+ out[m,n] = relu(alpha · acc[m,n]) · C[m,n] · 2 · prob[m]
+ dprob[m] += Σ_n( relu(alpha · acc[m,n])² · C[m,n] )
+
+C is the upstream gradient (same shape as forward SReLU output D).
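+
+A plain-PyTorch reference of the DSRELU epilogue math (illustrative only; block
+scaling, tiling, and output quantization are ignored, and all names are assumptions):
+
+ gated = torch.relu(alpha * (a @ b.t())) # (m, n) gated accumulator per expert
+ d = gated * c * 2.0 * prob.unsqueeze(1) # dA written to the D output
+ dprob = (gated.square() * c).sum(dim=1) # per-token routing-prob gradient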
+"""
+
+from enum import Enum
+from typing import Type, Tuple, Union, Optional
+
+import cuda.bindings.driver as cuda
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
+import cutlass.utils as utils
+import cutlass.pipeline as pipeline
+import cutlass.utils.blackwell_helpers as sm100_utils
+import cutlass.utils.blockscaled_layout as blockscaled_utils
+from cutlass._mlir.dialects.nvvm import ReduxKind
+from cutlass._mlir.dialects import llvm
+from cutlass.cute.typing import Float32, Int32, AddressSpace
+
+
+from ..moe_persistent_scheduler import (
+ MoEPersistentTileScheduler,
+ MoESchedulerParams,
+ MoEWorkTileInfo,
+)
+from ..moe_utils import (
+ compute_expert_token_range,
+ MoEWeightMode,
+ TensormapWorkspace,
+ store_tma_desc,
+)
+from ..moe_sched_extension import (
+ DiscreteWeightScaledGemmSchedExtension,
+ ContiguousAndConsistentGroupedGemmSchedExtension,
+)
+from ..moe_kernel_helpers import (
+ fmin,
+ fmax,
+ warp_redux_sync,
+ atomic_add_float32,
+ atomic_max_float32,
+ compute_stages,
+ compute_grid,
+ can_implement,
+ amax_reduction_per_thread,
+ epilog_gmem_copy_and_partition,
+ get_dtype_rcp_limits,
+)
+
+
+def atomic_add_bf16x2(ptr, val_fp32_lo, val_fp32_hi, *, loc=None, ip=None):
+ """Packed BF16x2 atomic reduction to global memory (same impl as dglu_dbias ref)."""
+ lo_ir = val_fp32_lo.ir_value(loc=loc, ip=ip)
+ hi_ir = val_fp32_hi.ir_value(loc=loc, ip=ip)
+ llvm.inline_asm(
+ None,
+ [ptr, hi_ir, lo_ir],
+ "{ .reg .b32 packed; cvt.rn.bf16x2.f32 packed, $1, $2; red.global.add.noftz.bf16x2 [$0], packed; }",
+ "l,f,f",
+ has_side_effects=True,
+ is_align_stack=False,
+ asm_dialect=llvm.AsmDialect.AD_ATT,
+ )
+
+
+class EpilogueType(Enum):
+ NONE = 0
+ DSRELU = 1
+
+
+class BlockScaledMoEGroupedGemmQuantBwdKernel:
+ """Block-scaled grouped GEMM backward kernel with MoE tile scheduling and DSRELU.
+
+ Computes the backward pass through the SReLU epilogue:
+ out[m,n] = relu(alpha * acc[m,n]) * C[m,n] * 2 * prob[m] (DSRELU)
+ dprob[m] += sum_n( relu(alpha * acc[m,n])^2 * C[m,n] ) (DSRELU)
+ or identity (NONE epilogue):
+ out[m,n] = alpha * acc[m,n] * prob[m]
+
+ Supports both dense and discrete weight layouts, static and dynamic
+ scheduling, and quantized output with row/column scale factors.
+
+ :param sf_vec_size: Scale-factor vector size (16 or 32).
+ :param acc_dtype: Accumulator data type (Float32).
+ :param use_2cta_instrs: Use 2-CTA MMA instructions.
+ :param mma_tiler_mn: MMA tile shape (M, N).
+ :param cluster_shape_mn: Cluster shape (M, N).
+ :param vectorized_f32: Use packed FP32 arithmetic.
+ :param generate_sfd: Generate output scale factors.
+ :param discrete_col_sfd: Use discrete column SFD layout.
+ :param expert_cnt: Number of experts.
+ :param weight_mode: ``MoEWeightMode.DENSE`` or ``MoEWeightMode.DISCRETE``.
+ :param use_dynamic_sched: Enable dynamic tile scheduling.
+ :param epilogue_type: Epilogue type (``EpilogueType.NONE`` or ``EpilogueType.DSRELU``).
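+ :param generate_dbias: Accumulate the per-expert bias gradient (dbias) via atomic adds.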
+ """
+
+ FIX_PAD_SIZE = 256
+
+ @staticmethod
+ def can_implement(
+ ab_dtype: Type[cutlass.Numeric],
+ sf_dtype: Type[cutlass.Numeric],
+ sf_vec_size: int,
+ acc_dtype: Type[cutlass.Numeric],
+ d_dtype: Type[cutlass.Numeric],
+ use_2cta_instrs: bool,
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ m: int,
+ n: int,
+ k: int,
+ l: int,
+ a_major: str,
+ b_major: str,
+ cd_major: str,
+ m_aligned: int,
+ ) -> bool:
+ return can_implement(
+ ab_dtype,
+ sf_dtype,
+ sf_vec_size,
+ acc_dtype,
+ d_dtype,
+ use_2cta_instrs,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ m,
+ n,
+ k,
+ l,
+ a_major,
+ b_major,
+ cd_major,
+ m_aligned,
+ fix_pad_size=BlockScaledMoEGroupedGemmQuantBwdKernel.FIX_PAD_SIZE,
+ )
+
+ def __init__(
+ self,
+ sf_vec_size: int,
+ acc_dtype: Type[cutlass.Numeric],
+ use_2cta_instrs: bool,
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ vectorized_f32: bool,
+ generate_sfd: bool,
+ discrete_col_sfd: bool,
+ expert_cnt: int,
+ weight_mode: MoEWeightMode = MoEWeightMode.DENSE,
+ use_dynamic_sched: bool = False,
+ epilogue_type: int = EpilogueType.NONE.value,
+ generate_dbias: bool = False,
+ ):
+ mma_tile_m = mma_tiler_mn[0]
+ if self.FIX_PAD_SIZE % mma_tile_m != 0:
+ raise ValueError(
+ f"FIX_PAD_SIZE ({self.FIX_PAD_SIZE}) must be divisible by " f"mma_tiler_mn[0] ({mma_tile_m}). " f"Supported mma_tiler_mn[0] values: 128, 256."
+ )
+ if expert_cnt > 1024:
+ raise ValueError("Expert count > 1024 is not supported.")
+ if not isinstance(weight_mode, MoEWeightMode):
+ raise TypeError(f"weight_mode must be a MoEWeightMode, got {type(weight_mode)}")
+
+ self.sf_vec_size = sf_vec_size
+ self.expert_cnt = expert_cnt
+ self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
+ self.use_2cta_instrs = use_2cta_instrs
+ self.cluster_shape_mn = cluster_shape_mn
+ self.mma_tiler = (*mma_tiler_mn, 1)
+
+ self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
+
+ self.occupancy = 1
+ self.epilog_warp_id = (0, 1, 2, 3)
+ self.mma_warp_id = 4
+ self.tma_warp_id = 5
+ self.epilog_load_tma_id = 6 # new warp: TMA-loads C subtiles into sC
+ self.sched_warp_id = 7 # shifted from 6 in forward kernel
+ self.threads_per_warp = 32
+ all_warps = [
+ *self.epilog_warp_id,
+ self.mma_warp_id,
+ self.tma_warp_id,
+ self.epilog_load_tma_id,
+ self.sched_warp_id,
+ ]
+ warps_wo_sched = [
+ *self.epilog_warp_id,
+ self.mma_warp_id,
+ self.tma_warp_id,
+ self.epilog_load_tma_id,
+ ]
+ self.threads_per_cta = self.threads_per_warp * len(all_warps)
+ self.threads_wo_sched = self.threads_per_warp * len(warps_wo_sched)
+
+ self.cta_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=1,
+ num_threads=self.threads_per_cta,
+ )
+ self.epilog_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=2,
+ num_threads=32 * len(self.epilog_warp_id),
+ )
+ self.tmem_alloc_barrier = pipeline.NamedBarrier(
+ barrier_id=3,
+ num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)),
+ )
+ self.sched_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=4,
+ num_threads=self.threads_per_warp,
+ )
+ self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
+ SM100_TMEM_CAPACITY_COLUMNS = 512
+ self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
+
+ self.vectorized_f32 = vectorized_f32
+ self.generate_sfd = generate_sfd
+ self.discrete_col_sfd = discrete_col_sfd
+ self.generate_dbias = generate_dbias
+ self.dbias_cross_warp_reduce = generate_dbias
+
+ self.weight_mode = weight_mode
+ self.use_dynamic_sched = use_dynamic_sched
+
+ self.epilogue_use_functor = False
+ self.epilogue_type = epilogue_type
+
+ self.num_epilog_warps = len(self.epilog_warp_id)
+
+ # ------------------------------------------------------------------
+ # _setup_attributes
+ # ------------------------------------------------------------------
+
+ def _setup_attributes(self):
+ """Configure MMA / tile / stage / SMEM layouts from GEMM inputs."""
+
+ self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1])
+ self.mma_inst_shape_mn_sfb = (
+ self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
+ cute.round_up(self.mma_inst_shape_mn[1], 128),
+ )
+
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+
+ mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
+ mma_inst_tile_k = 4
+ self.mma_tiler = (
+ self.mma_tiler[0],
+ self.mma_tiler[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.mma_tiler_sfb = (
+ self.mma_inst_shape_mn_sfb[0],
+ self.mma_inst_shape_mn_sfb[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+
+ self.cta_tile_shape_mnk = (
+ self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler[1],
+ self.mma_tiler[2],
+ )
+ self.cta_tile_shape_mnk_sfb = (
+ self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler_sfb[1],
+ self.mma_tiler_sfb[2],
+ )
+
+ self.mma_tiler_d = (
+ self.mma_inst_shape_mn[0],
+ self.mma_inst_shape_mn[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.cta_tile_shape_mnk_d = (
+ self.mma_tiler_d[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler_d[1],
+ self.mma_tiler_d[2],
+ )
+
+ self.cluster_layout_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma.thr_id.shape,),
+ )
+ self.cluster_layout_sfb_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma_sfb.thr_id.shape,),
+ )
+
+ self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+ self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
+
+ self.epi_tile = (128, 32)
+
+ (
+ self.num_acc_stage,
+ self.num_ab_stage,
+ self.num_c_stage,
+ self.num_d_stage,
+ self.num_tile_stage,
+ ) = self._compute_stages(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.b_dtype,
+ self.epi_tile,
+ self.c_dtype,
+ self.c_layout,
+ self.d_dtype,
+ self.d_layout,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.num_smem_capacity,
+ self.occupancy,
+ self.generate_sfd,
+ self.generate_dbias,
+ )
+
+ self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.num_ab_stage,
+ )
+ self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
+ tiled_mma,
+ self.mma_tiler,
+ self.b_dtype,
+ self.num_ab_stage,
+ )
+ self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.c_dtype,
+ self.c_layout,
+ self.epi_tile,
+ self.num_c_stage,
+ )
+ self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.d_dtype,
+ self.d_layout,
+ self.epi_tile,
+ self.num_d_stage,
+ )
+
+ self.overlapping_accum = self.num_acc_stage == 1 and self.mma_tiler[1] == 256
+ self.epilogue_prefetch_more = False
+
+ sf_atom_mn = 32
+ self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
+ self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
+ self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
+ self.num_accumulator_tmem_cols = (
+ self.cta_tile_shape_mnk[1] * self.num_acc_stage if not self.overlapping_accum else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
+ )
+
+ self.epi_tile_n_required = cute.size(self.epi_tile[1])
+ self.iter_acc_early_release_in_epilogue = (self.num_sf_tmem_cols + self.epi_tile_n_required - 1) // self.epi_tile_n_required - 1
+
+ # ------------------------------------------------------------------
+ # _compute_stages
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def _compute_stages(
+ tiled_mma,
+ mma_tiler_mnk,
+ a_dtype,
+ b_dtype,
+ epi_tile,
+ c_dtype,
+ c_layout,
+ d_dtype,
+ d_layout,
+ sf_dtype,
+ sf_vec_size,
+ num_smem_capacity,
+ occupancy,
+ generate_sfd,
+ generate_dbias=False,
+ ):
+ num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
+ # Always 2 c stages for c_pipeline double-buffering
+ num_c_stage = 2
+ num_d_stage = 2 if generate_sfd else 1
+ num_tile_stage = 2
+
+ a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler_mnk, a_dtype, 1)
+ b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, b_dtype, 1)
+ sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(tiled_mma, mma_tiler_mnk, sf_vec_size, 1)
+ sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(tiled_mma, mma_tiler_mnk, sf_vec_size, 1)
+ c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1)
+ d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
+
+ ab_bytes_per_stage = (
+ cute.size_in_bytes(a_dtype, a_smem_layout_stage_one)
+ + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
+ + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one)
+ + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
+ )
+ mbar_helpers_bytes = 1024
+ sinfo_bytes = 4 * 4 * num_tile_stage
+ c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
+ c_bytes = c_bytes_per_stage * num_c_stage
+ d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
+ d_bytes = d_bytes_per_stage * num_d_stage * (2 if generate_sfd else 1)
+ amax_bytes = 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,))) if d_dtype == cutlass.BFloat16 else 0
+ # dBias SMEM transpose buffer: 128 M rows × epi_tile_n N cols × FP32
+ # (same formula as SharedStorage.sDbias field). Must be subtracted from
+ # the AB-stage budget or num_ab_stage gets over-allocated and the module
+ # fails to serialize.
+ dbias_bytes = 128 * cute.size(epi_tile[1]) * 4 if generate_dbias else 0
+
+ epi_bytes = c_bytes + d_bytes + amax_bytes + dbias_bytes
+ num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes + sinfo_bytes)) // ab_bytes_per_stage
+
+ return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage
+
+ # ------------------------------------------------------------------
+ # Workspace helpers
+ # ------------------------------------------------------------------
+
+ def get_desc_workspace_bytes(self) -> int:
+ if self.weight_mode == MoEWeightMode.DISCRETE:
+ from ..moe_utils import DiscreteWeightTensormapConstructor
+
+ return DiscreteWeightTensormapConstructor.get_workspace_size(self.expert_cnt)
+ return 0
+
+ def get_workspace_bytes(self) -> int:
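+ # Workspace layout: [per-expert TMA-descriptor slots (discrete mode only)]
+ # followed by a 4-byte Int32 scheduler counter when dynamic scheduling is enabled.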
+ desc_workspace_bytes = self.get_desc_workspace_bytes()
+ dynamic_sched_bytes = 4 if self.use_dynamic_sched else 0
+ return desc_workspace_bytes + dynamic_sched_bytes
+
+ @cute.jit
+ def _get_sched_counter_ptr(self, workspace_ptr):
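+ """Typed Int32 gmem pointer to the dynamic-scheduler counter, stored right after the TMA-descriptor workspace."""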
+ counter_addr = workspace_ptr.toint() + self.get_desc_workspace_bytes()
+ return cute.make_ptr(
+ cutlass.Int32,
+ counter_addr,
+ AddressSpace.gmem,
+ assumed_align=4,
+ )
+
+ # ------------------------------------------------------------------
+ # helper_kernel: pre-main-kernel initialization
+ # ------------------------------------------------------------------
+
+ @cute.kernel
+ def helper_kernel(
+ self,
+ ptrs_b: cute.Pointer,
+ ptrs_sfb: cute.Pointer,
+ n: Int32,
+ k: Int32,
+ b_stride_size: cutlass.Int64,
+ b_major_mode: cutlass.Constexpr,
+ workspace_ptr,
+ tiled_mma_arg: cute.TiledMma,
+ tiled_mma_sfb_arg: cute.TiledMma,
+ b_smem_layout_arg,
+ sfb_smem_layout_arg,
+ cluster_layout_vmnk_shape_arg: cutlass.Constexpr,
+ cluster_layout_sfb_vmnk_shape_arg: cutlass.Constexpr,
+ ):
+ """Pre-main-kernel initialization (discrete TMA desc + dynamic sched counter)."""
+ expert_idx = cute.arch.block_idx()[0]
+
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE):
+ b_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_arg.thr_id)
+ sfb_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma_arg.thr_id)
+
+ b_ptr_tensor = cute.make_tensor(
+ cute.make_ptr(cutlass.Int64, ptrs_b.toint(), AddressSpace.gmem, assumed_align=8), cute.make_layout((self.expert_cnt,))
+ )
+ sfb_ptr_tensor = cute.make_tensor(
+ cute.make_ptr(cutlass.Int64, ptrs_sfb.toint(), AddressSpace.gmem, assumed_align=8), cute.make_layout((self.expert_cnt,))
+ )
+
+ c1 = cutlass.Int32(1)
+ c0 = cutlass.Int64(0)
+ c1_64 = 1
+ if cutlass.const_expr(b_major_mode == OperandMajorMode.K):
+ stride_n = b_stride_size
+ stride_k = c1_64
+ else:
+ stride_n = c1_64
+ stride_k = b_stride_size
+
+ b_ptr_val = b_ptr_tensor[expert_idx]
+ b_ptr = cute.make_ptr(self.b_dtype, b_ptr_val, AddressSpace.gmem)
+ b_tensor_i = cute.make_tensor(
+ b_ptr,
+ cute.make_layout((n, k, c1), stride=(stride_n, stride_k, c0)),
+ )
+ tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
+ b_tma_op_arg,
+ b_tensor_i,
+ b_smem_layout_arg,
+ self.mma_tiler,
+ tiled_mma_arg,
+ cluster_layout_vmnk_shape_arg,
+ )
+ workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"])
+ store_tma_desc(tma_atom_b, workspace.get_ptr("b", expert_idx))
+
+ sfb_ptr_val = sfb_ptr_tensor[expert_idx]
+ sfb_ptr = cute.make_ptr(self.sf_dtype, sfb_ptr_val, AddressSpace.gmem)
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size)
+ sfb_tensor_i = cute.make_tensor(sfb_ptr, sfb_layout)
+ tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
+ sfb_tma_op_arg,
+ sfb_tensor_i,
+ sfb_smem_layout_arg,
+ self.mma_tiler_sfb,
+ tiled_mma_sfb_arg,
+ cluster_layout_sfb_vmnk_shape_arg,
+ internal_type=cutlass.Uint64,
+ )
+ store_tma_desc(tma_atom_sfb, workspace.get_ptr("sfb", expert_idx))
+
+ if cutlass.const_expr(self.use_dynamic_sched):
+ if expert_idx == cutlass.Int32(0):
+ sched_counter = cute.make_tensor(
+ self._get_sched_counter_ptr(workspace_ptr),
+ cute.make_layout(1),
+ )
+ sched_counter[0] = cutlass.Int32(0)
+
+ # ------------------------------------------------------------------
+ # __call__
+ # ------------------------------------------------------------------
+
+ @cute.jit
+ def __call__(
+ self,
+ a: cute.Tensor,
+ b, # Dense: cute.Tensor (N,K,L) | Discrete: cute.Pointer to int64[]
+ sfb, # Dense: cute.Tensor | Discrete: cute.Pointer to int64[]
+ n: Int32, # Ignored for dense mode
+ k: Int32, # Ignored for dense mode
+ b_stride_size: cutlass.Int64, # Ignored for dense mode
+ b_major_mode: cutlass.Constexpr, # Ignored for dense mode
+ workspace_ptr,
+ c: cute.Tensor, # INPUT: upstream gradient (forward output), shape (M,N,L)
+ d: cute.Tensor, # OUTPUT: dA gradient
+ d_col: Optional[cute.Tensor],
+ sfa: cute.Tensor,
+ sfd_row_tensor: Optional[cute.Tensor],
+ sfd_col_tensor: Optional[cute.Tensor],
+ amax_tensor: Optional[cute.Tensor],
+ norm_const_tensor: Optional[cute.Tensor],
+ padded_offsets: cute.Tensor,
+ alpha: cute.Tensor,
+ prob: cute.Tensor,
+ dprob: Optional[cute.Tensor], # OUTPUT: dL/d(prob), shape (M,1,1), Float32
+ dbias_tensor: Optional[cute.Tensor], # OUTPUT: dL/d(bias), shape (L,N), BF16 (accumulated via atomic)
+ max_active_clusters: cutlass.Constexpr,
+ stream: cuda.CUstream,
+ ):
+ """Execute the backward GEMM.
+
+ Dense mode: ``b`` and ``sfb`` are 3-D cute.Tensor (N, K, L).
+ Discrete mode: ``b`` and ``sfb`` are cute.Pointer to device int64[]
+ arrays of per-expert base addresses.
+
+ ``c`` is the upstream gradient tensor (same shape as forward output D),
+ loaded G2S by the specialized epilog_load_tma warp and consumed by epilog warps.
+
+ ``dprob`` (optional) receives per-token routing probability gradients,
+ accumulated via atomic FP32 add.
+ """
+ self.a_dtype: Type[cutlass.Numeric] = a.element_type
+ self.b_dtype: Type[cutlass.Numeric] = a.element_type
+ self.c_dtype: Type[cutlass.Numeric] = c.element_type
+ self.d_dtype: Type[cutlass.Numeric] = d.element_type
+ self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type
+ self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
+ self.c_layout = utils.LayoutEnum.from_tensor(c)
+ self.d_layout = utils.LayoutEnum.from_tensor(d)
+
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
+ else:
+ self.b_major_mode = b_major_mode
+
+ if cutlass.const_expr(self.a_dtype != self.b_dtype):
+ raise TypeError(f"A/B dtype must match: {self.a_dtype} != {self.b_dtype}")
+
+ self._setup_attributes()
+
+ # ---- SFA layout ----
+ sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, self.sf_vec_size)
+ sfa = cute.make_tensor(sfa.iterator, sfa_layout)
+
+ # ---- B / SFB setup (mode-dependent) ----
+ b_from_call_arg = b
+ sfb_from_call_arg = sfb
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size)
+ sfb = cute.make_tensor(sfb.iterator, sfb_layout)
+ else:
+ c1 = cutlass.Int32(1)
+ c0 = cutlass.Int64(0)
+ c1_64 = 1
+ if cutlass.const_expr(b_major_mode == OperandMajorMode.K):
+ b_template_stride = (b_stride_size, c1_64, c0)
+ else:
+ b_template_stride = (c1_64, b_stride_size, c0)
+ b_template_layout = cute.make_layout((n, k, c1), stride=b_template_stride)
+ b_ptr_typed = cute.make_ptr(self.b_dtype, b.toint(), AddressSpace.gmem, assumed_align=16)
+ b = cute.make_tensor(b_ptr_typed, b_template_layout)
+
+ sfb_ptr_typed = cute.make_ptr(self.sf_dtype, sfb.toint(), AddressSpace.gmem, assumed_align=16)
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size)
+ sfb = cute.make_tensor(sfb_ptr_typed, sfb_layout)
+
+ # ---- SFD setup ----
+ self.generate_sfd = sfd_row_tensor is not None and norm_const_tensor is not None
+ if cutlass.const_expr(self.generate_sfd == False):
+ self.discrete_col_sfd = False
+ if cutlass.const_expr(self.generate_sfd):
+ sfd_row_layout = blockscaled_utils.tile_atom_to_shape_SF(d.shape, self.sf_vec_size)
+ sfd_row_tensor = cute.make_tensor(sfd_row_tensor.iterator, sfd_row_layout)
+ sfd_col_layout = cute.tile_to_shape(
+ blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, OperandMajorMode.MN).layout,
+ d.shape,
+ (1, 2, 3),
+ )
+ if cutlass.const_expr(self.discrete_col_sfd):
+ sfd_col_layout = sfd_row_layout
+ sfd_col_tensor = cute.make_tensor(sfd_col_tensor.iterator, sfd_col_layout)
+
+ self.generate_amax = amax_tensor is not None
+ # self.generate_dbias was set in __init__ (needed at _compute_stages time);
+ # assert consistency with the dbias_tensor passed here.
+ assert self.generate_dbias == (dbias_tensor is not None), "dbias_tensor presence must match generate_dbias set at construction"
+
+ # ---- TMA atoms ----
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+ atom_thr_size = cute.size(tiled_mma.thr_id.shape)
+
+ a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
+ tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
+ a_op,
+ a,
+ a_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
+ b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
+ tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
+ b_op,
+ b,
+ b_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
+ sfa_op,
+ sfa,
+ sfa_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ internal_type=cutlass.Int16,
+ )
+
+ sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
+ sfb_op,
+ sfb,
+ sfb_smem_layout,
+ self.mma_tiler_sfb,
+ tiled_mma_sfb,
+ self.cluster_layout_sfb_vmnk.shape,
+ internal_type=cutlass.Uint64,
+ )
+
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
+ x = tma_tensor_sfb.stride[0][1]
+ y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4)
+ new_shape = (
+ (tma_tensor_sfb.shape[0][0], ((2, 2), y)),
+ tma_tensor_sfb.shape[1],
+ tma_tensor_sfb.shape[2],
+ )
+ x_times_3 = 3 * x
+ new_stride = (
+ (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)),
+ tma_tensor_sfb.stride[1],
+ tma_tensor_sfb.stride[2],
+ )
+ tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride)
+ tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout)
+
+ a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
+ b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+ sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
+ sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
+ self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size
+
+ # C is an INPUT (upstream gradient): use G2S TMA atom
+ c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
+ self.tma_c_load_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout)
+ tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileG2SOp(),
+ c,
+ c_smem_layout,
+ self.epi_tile,
+ )
+
+ # D is the output (dA gradient): S2G TMA atom
+ d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ d,
+ d_smem_layout,
+ self.epi_tile,
+ )
+ tma_atom_d_col, tma_tensor_d_col = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ d_col,
+ d_smem_layout,
+ self.epi_tile,
+ )
+
+ # ---- Helper kernel ----
+ _need_helper = cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE or self.use_dynamic_sched)
+ if cutlass.const_expr(_need_helper):
+ _helper_grid_x = self.expert_cnt if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else 1
+ _helper_args = (
+ b_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem),
+ sfb_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem),
+ n if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0),
+ k if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0),
+ b_stride_size if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int64(0),
+ b_major_mode if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else self.b_major_mode,
+ workspace_ptr,
+ tiled_mma,
+ tiled_mma_sfb,
+ b_smem_layout,
+ sfb_smem_layout,
+ self.cluster_layout_vmnk.shape,
+ self.cluster_layout_sfb_vmnk.shape,
+ )
+ self.helper_kernel(*_helper_args).launch(
+ grid=(_helper_grid_x, 1, 1),
+ block=(1, 1, 1),
+ stream=stream,
+ min_blocks_per_mp=1,
+ )
+
+ # ---- Grid computation via MoE scheduler ----
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ b_n, b_k, b_l = cute.shape(b)
+ sched_expert_shape = (self.expert_cnt, b_n, b_k)
+ else:
+ sched_expert_shape = (self.expert_cnt, n, k)
+
+ sched_params = MoESchedulerParams(
+ scenario="2Dx3D",
+ expert_shape=sched_expert_shape,
+ cta_tile_shape_mnk=self.cta_tile_shape_mnk,
+ cluster_shape_mn=self.cluster_shape_mn,
+ use_dynamic_sched=self.use_dynamic_sched,
+ )
+ self.sched_params, grid = compute_grid(sched_params, max_active_clusters, self.use_2cta_instrs)
+
+ self.buffer_align_bytes = 1024
+
+ # ---- Shared storage ----
+ sD_col_size = cute.cosize(self.d_smem_layout_staged.outer) if self.generate_sfd else 0
+ SchedulerStorage = MoEPersistentTileScheduler.make_storage_struct(self.num_tile_stage, self.use_dynamic_sched)
+
+ @cute.struct
+ class SharedStorage:
+ ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
+ acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
+ c_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage * 2]
+ scheduler: SchedulerStorage
+ tmem_dealloc_mbar_ptr: cutlass.Int64
+ tmem_holding_buf: cutlass.Int32
+ # sC: staging buffer for loading C (upstream gradient) from global memory
+ sC: cute.struct.Align[
+ cute.struct.MemRange[self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sD: cute.struct.Align[
+ cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sD_col: cute.struct.Align[
+ cute.struct.MemRange[self.d_dtype, sD_col_size],
+ self.buffer_align_bytes,
+ ]
+ sA: cute.struct.Align[
+ cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sB: cute.struct.Align[
+ cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sSFA: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+ sSFB: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+ sAmax: cute.struct.Align[
+ cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps],
+ 4,
+ ]
+ # dBias SMEM transpose buffer: 128 × epi_tile_n FP32 (single-vec dA, vs
+ # reference's 128 × epi_tile_n × 2 that held both d1+d2 for DGLU).
+ sDbias: cute.struct.Align[
+ cute.struct.MemRange[
+ cutlass.Float32,
+ 128 * self.epi_tile[1] if self.generate_dbias else 1,
+ ],
+ 128 if self.generate_dbias else 4,
+ ]
+
+ self.shared_storage = SharedStorage
+
+ # ---- Launch ----
+ self.kernel(
+ tiled_mma,
+ tiled_mma_sfb,
+ tma_atom_a,
+ tma_tensor_a,
+ tma_atom_b,
+ tma_tensor_b,
+ tma_atom_sfa,
+ tma_tensor_sfa,
+ tma_atom_sfb,
+ tma_tensor_sfb,
+ tma_atom_c,
+ tma_tensor_c,
+ tma_atom_d,
+ tma_tensor_d,
+ tma_atom_d_col,
+ tma_tensor_d_col,
+ sfd_row_tensor,
+ sfd_col_tensor,
+ norm_const_tensor,
+ amax_tensor,
+ padded_offsets,
+ alpha,
+ prob,
+ dprob,
+ dbias_tensor,
+ workspace_ptr,
+ self.cluster_layout_vmnk,
+ self.cluster_layout_sfb_vmnk,
+ self.a_smem_layout_staged,
+ self.b_smem_layout_staged,
+ self.sfa_smem_layout_staged,
+ self.sfb_smem_layout_staged,
+ self.c_smem_layout_staged,
+ self.d_smem_layout_staged,
+ self.epi_tile,
+ self.sched_params,
+ ).launch(
+ grid=grid,
+ block=[self.threads_per_cta, 1, 1],
+ cluster=(*self.cluster_shape_mn, 1),
+ max_number_threads=[self.threads_per_cta, 1, 1],
+ smem=self.shared_storage.size_in_bytes(),
+ stream=stream,
+ min_blocks_per_mp=1,
+ )
+ return
+
+ # ------------------------------------------------------------------
+ # Helper methods
+ # ------------------------------------------------------------------
+
+ def mainloop_s2t_copy_and_partition(self, sSF, tSF):
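+ """Build the SMEM-to-TMEM (S2T) scale-factor copy and partition the compacted SF tensors for it."""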
+ tCsSF_compact = cute.filter_zeros(sSF)
+ tCtSF_compact = cute.filter_zeros(tSF)
+ copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype)
+ tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
+ thr_copy_s2t = tiled_copy_s2t.get_slice(0)
+ tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
+ tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_)
+ tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
+ return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
+
+ @cute.jit
+ def amax_reduction_per_warp_and_cta(self, amax_fp32, warp_idx, amax_smem, amax_gmem):
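+ """Reduce the per-thread amax across the warp, combine across epilogue warps via SMEM, then atomic-max into global memory."""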
+ warp_amax = warp_redux_sync(
+ value=amax_fp32,
+ kind=ReduxKind.MAX,
+ mask_and_clamp=0xFFFFFFFF,
+ nan=True,
+ )
+ if cute.arch.lane_idx() == 0:
+ amax_smem[warp_idx] = cutlass.Float32(warp_amax)
+ self.epilog_sync_barrier.arrive_and_wait()
+ if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0:
+ block_amax = cutlass.Float32(0.0)
+ for i in cutlass.range(self.num_epilog_warps):
+ warp_amax_val = amax_smem[i]
+ block_amax = cute.arch.fmax(block_amax, warp_amax_val)
+ _ = atomic_max_float32(ptr=amax_gmem, value=block_amax)
+
+ @cute.jit
+ def dbias_reduction(
+ self,
+ dA_vec,
+ warp_idx,
+ sDbias,
+ dbias_gmem_2d,
+ expert_idx,
+ n_base,
+ dbias_n_total,
+ ) -> None:
+ """Sum dA across M within this subtile and atomic-add to dbias[expert, n].
+
+ Adapted from moe_blockscaled_grouped_gemm_dglu_dbias.py's dbias_reduction,
+ which handled two interleaved vectors (d1, d2). Here we have a single
+ vector dA, so 16 lanes handle the 32 N columns (2 cols each via bf16x2).
+ Lanes 16-31 remain idle in the atomic phase but still contribute to the
+ SMEM write below.
+
+ SMEM layout: sDbias[n, lane_idx, warp_local] with shape (epi_n, 32, num_warps),
+ holding fp32 values. Each thread writes its dA_vec (epi_n values) at its
+ (lane, warp) slot — M is spread across (lane_idx × num_warps).
+ """
+ epi_n = self.epi_tile[1]
+ lane_idx = cute.arch.lane_idx()
+ warp_local = warp_idx - self.epilog_warp_id[0]
+
+ for n in cutlass.range(epi_n, unroll_full=True):
+ sDbias[(n, lane_idx, warp_local)] = dA_vec[n]
+
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ # 16 active lanes, each handles 2 consecutive N columns (col_a, col_b)
+ # → 16 lanes × 2 cols = 32 cols = epi_n.
+ col_a = 2 * lane_idx if lane_idx < 16 else 0
+ col_b = col_a + 1
+
+ copy_128bit_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128)
+ # Each warp owns a contiguous slice of sDbias of size epi_n * 32
+ warp_base_ptr = sDbias.iterator + warp_local * epi_n * 32
+ swizzle_a = ((col_a >> 1) & 0x7) << 2
+ swizzle_b = ((col_b >> 1) & 0x7) << 2
+
+ sum_a = cutlass.Float32(0.0)
+ sum_b = cutlass.Float32(0.0)
+ rDst_a = cute.make_rmem_tensor(cute.make_layout((4,)), cutlass.Float32)
+ rDst_b = cute.make_rmem_tensor(cute.make_layout((4,)), cutlass.Float32)
+ # 8 groups × 4 M = 32 M rows per warp
+ for g in cutlass.range(8, unroll_full=True):
+ m_base = g * 4
+ sw_offset_a = col_a * 32 + (m_base ^ swizzle_a)
+ sSrc_a = cute.make_tensor(warp_base_ptr + sw_offset_a, cute.make_layout((4,)))
+ cute.copy_atom_call(copy_128bit_atom, sSrc_a, rDst_a)
+
+ sw_offset_b = col_b * 32 + (m_base ^ swizzle_b)
+ sSrc_b = cute.make_tensor(warp_base_ptr + sw_offset_b, cute.make_layout((4,)))
+ cute.copy_atom_call(copy_128bit_atom, sSrc_b, rDst_b)
+
+ for i in cutlass.range(4, unroll_full=True):
+ sum_a = sum_a + rDst_a[i]
+ sum_b = sum_b + rDst_b[i]
+
+ n_offset = n_base + 2 * lane_idx # lanes 0..15 cover N offsets 0,2,...,30
+
+ if cutlass.const_expr(self.dbias_cross_warp_reduce):
+ # Cross-warp reduction: each warp writes its partial (sum_a, sum_b) to
+ # a per-warp slot in SMEM, then warp 0 sums across all warps.
+ reduce_base = sDbias.iterator
+ copy_64bit_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=64)
+
+ self.epilog_sync_barrier.arrive_and_wait()
+ if lane_idx < 16:
+ rSrc_partial = cute.make_rmem_tensor(cute.make_layout((2,)), cutlass.Float32)
+ rSrc_partial[0] = sum_a
+ rSrc_partial[1] = sum_b
+ sDst_partial = cute.make_tensor(reduce_base + warp_local * 32 + lane_idx * 2, cute.make_layout((2,)))
+ cute.copy_atom_call(copy_64bit_atom, rSrc_partial, sDst_partial)
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ if warp_idx == self.epilog_warp_id[0] and lane_idx < 16:
+ cta_sum_a = cutlass.Float32(0.0)
+ cta_sum_b = cutlass.Float32(0.0)
+ rDst_w = cute.make_rmem_tensor(cute.make_layout((2,)), cutlass.Float32)
+ for w in cutlass.range(self.num_epilog_warps):
+ sSrc_w = cute.make_tensor(reduce_base + w * 32 + lane_idx * 2, cute.make_layout((2,)))
+ cute.copy_atom_call(copy_64bit_atom, sSrc_w, rDst_w)
+ cta_sum_a = cta_sum_a + rDst_w[0]
+ cta_sum_b = cta_sum_b + rDst_w[1]
+ if n_offset < dbias_n_total:
+ gmem_ptr = dbias_gmem_2d[(expert_idx, n_offset, None)].iterator.llvm_ptr
+ atomic_add_bf16x2(gmem_ptr, cta_sum_a, cta_sum_b)
+ else:
+ if lane_idx < 16 and n_offset < dbias_n_total:
+ gmem_ptr = dbias_gmem_2d[(expert_idx, n_offset, None)].iterator.llvm_ptr
+ atomic_add_bf16x2(gmem_ptr, sum_a, sum_b)
+
+ @cute.jit
+ def quant_sfd_row(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD):
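+ """Quantize one epilogue subtile with row (1 x sf_vec_size) scale factors: record the scale in ``pvscale``, rescale ``src``, and stage the converted output in ``tRSrD``."""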
+ tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size))
+ acc_frg = tTR_rAcc_frg.load()
+ abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value())
+ abs_acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype)
+ pvscale_f32x4 = cute.make_rmem_tensor(4, cutlass.Float32)
+ sfd_f8x4 = cute.make_rmem_tensor(4, self.sf_dtype)
+ tmp_f32 = abs_acc_frg[None, 0].reduce(cute.ReductionOp.MAX, cutlass.Float32(0.0), 0) * rcp_limit * norm_const
+ if tile_idx == 0:
+ pvscale[0] = tmp_f32
+ elif tile_idx == 1:
+ pvscale[1] = tmp_f32
+ elif tile_idx == 2:
+ pvscale[2] = tmp_f32
+ elif tile_idx == 3:
+ pvscale[3] = tmp_f32
+ pvscale_f32x4[0] = tmp_f32
+ sfd_f8x4.store(pvscale_f32x4.load().to(self.sf_dtype))
+ pvscale_f32x4.store(sfd_f8x4.load().to(cutlass.Float32))
+ qpvscale_up = pvscale_f32x4[0]
+ fp32_max = cutlass.Float32(3.40282346638528859812e38)
+ acc_scale = norm_const * cute.arch.rcp_approx(qpvscale_up)
+ acc_scale = fmin(acc_scale, fp32_max, nan=True)
+ if cutlass.const_expr(self.vectorized_f32):
+ vec = tTR_rAcc_frg[None, 0]
+ for ei in cutlass.range_constexpr(0, self.sf_vec_size, 2):
+ vec[ei], vec[ei + 1] = cute.arch.mul_packed_f32x2(
+ (vec[ei], vec[ei + 1]),
+ (acc_scale, acc_scale),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ vec = tTR_rAcc_frg[None, 0]
+ for ei in cutlass.range_constexpr(self.sf_vec_size):
+ vec[ei] = vec[ei] * acc_scale
+ acc_vec = tiled_copy_r2s.retile(src).load()
+ tRSrD.store(acc_vec.to(self.d_dtype))
+
+ @cute.jit
+ def quant_sfd_col(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD):
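+ """Column-scale-factor variant of ``quant_sfd_row``: block maxima are reduced across the warp, quantized to the SF dtype, and applied to ``src`` before staging into ``tRSrD``."""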
+ tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size))
+ acc_frg = tTR_rAcc_frg.load()
+ abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value())
+ acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype)
+ tmp_f32 = cutlass.Float32(0.0)
+ for vi in cutlass.range_constexpr(acc_frg.shape[0]):
+ max_value_original = (
+ cutlass.Float32(
+ warp_redux_sync(
+ value=acc_frg[vi, 0],
+ kind=ReduxKind.MAX,
+ mask_and_clamp=0xFFFFFFFF,
+ nan=True,
+ )
+ )
+ * rcp_limit
+ * norm_const
+ )
+ max_value_vec = cute.full(4, max_value_original, dtype=cutlass.Float32)
+ max_value_vec_f8 = max_value_vec.to(cutlass.Float8E8M0FNU)
+ max_value_vec_f32_chunked = max_value_vec_f8.to(cutlass.Float32)
+ max_value = max_value_vec_f32_chunked[0]
+ tidx = cute.arch.thread_idx()[0]
+ if tidx % 32 == vi:
+ tmp_f32 = max_value
+ acc_scale_col = cutlass.Float32(0.0)
+ if max_value_vec_f32_chunked[0] == 0.000000:
+ acc_scale_col = cutlass.Float32(0.0)
+ else:
+ acc_scale_col = norm_const * cute.arch.rcp_approx(max_value_vec_f32_chunked[0])
+ fp32_max = cutlass.Float32(3.40282346638528859812e38)
+ acc_scale_col = fmin(acc_scale_col, fp32_max)
+ tTR_rAcc_frg[vi] = tTR_rAcc_frg[vi] * acc_scale_col
+ pvscale[None, None, tile_idx][0] = tmp_f32
+ acc_vec = tiled_copy_r2s.retile(src).load()
+ tRSrD.store(acc_vec.to(self.d_dtype))
+
+ @cute.jit
+ def tile_info_to_mn_idx(self, tile_info: cute.Tensor):
+ m_idx = tile_info[1] * cute.size(self.cta_tile_shape_mnk[0])
+ n_idx = tile_info[2] * cute.size(self.cta_tile_shape_mnk[1])
+ return m_idx, n_idx
+
+ @cute.jit
+ def create_and_partition_new_SFDCol(self, tile_info, mSFDCol_mnl, padded_offsets):
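+ """Re-base the column SFD view at this expert's padded token offset and partition it for the per-thread 8-bit stores used in the epilogue."""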
+ m_idx, n_idx = self.tile_info_to_mn_idx(tile_info)
+ expert_idx = tile_info[0]
+ cumsum_tokens, tokens_this_group = compute_expert_token_range(padded_offsets, expert_idx)
+ n_total = cute.size(mSFDCol_mnl.shape[1])
+
+ sf_tile_idx_begin = cumsum_tokens // cute.size(mSFDCol_mnl.shape[0][0])
+ mSFDCol_mnl_new_ptr = mSFDCol_mnl[(None, sf_tile_idx_begin), None, 0].iterator
+
+ sfd_col_quant_layout = cute.tile_to_shape(
+ blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, OperandMajorMode.MN).layout,
+ (tokens_this_group, n_total, mSFDCol_mnl.shape[2]),
+ (1, 2, 3),
+ )
+ regPerSubtile = 4
+ sfd_tile = (cute.make_layout(128), cute.make_layout(32 * regPerSubtile))
+ mSFDCol_mnl_new = cute.make_tensor(mSFDCol_mnl_new_ptr, sfd_col_quant_layout)
+ gSFDCol_mnl_new = cute.local_tile(mSFDCol_mnl_new, sfd_tile, (None, None, None))
+
+ thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
+ val_layout = cute.make_ordered_layout((1,), order=(0,))
+ copy_atom_sfd_col_quant = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gSFDCol_mnl_new.element_type,
+ num_bits_per_copy=8,
+ )
+ tiled_copy_sfd_col_quant = cute.make_tiled_copy_tv(
+ copy_atom_sfd_col_quant,
+ thr_layout,
+ val_layout,
+ )
+ tidx = cute.arch.thread_idx()[0]
+ thr_copy_sfd_col_quant = tiled_copy_sfd_col_quant.get_slice(tidx)
+ tCgSFDCol_mnl = thr_copy_sfd_col_quant.partition_D(cute.filter_zeros(gSFDCol_mnl_new))
+ tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl)
+ return tCgSFDCol_mnl
+
+ def epilog_tmem_copy_and_partition(self, tidx, tAcc, gD_mnl, epi_tile, use_2cta_instrs):
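+ """Build the TMEM-to-register accumulator copy for the epilogue and partition the accumulator into per-thread fragments."""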
+ copy_atom_t2r = sm100_utils.get_tmem_load_op(
+ self.cta_tile_shape_mnk,
+ self.d_layout,
+ self.d_dtype,
+ self.acc_dtype,
+ epi_tile,
+ use_2cta_instrs,
+ )
+ tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile)
+ tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
+ thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
+ tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
+ gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+ tTR_gC = thr_copy_t2r.partition_D(gD_mnl_epi)
+ tTR_rAcc = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype)
+ return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
+
+ def epilog_smem_copy_and_partition_load(self, tiled_copy_t2r, tTR_rC, tidx, sC):
+ """Partition sC (smem) for S2R copy — used by epilog warps consuming the c_pipeline."""
+ copy_atom_s2r = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype)
+ tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r)
+ thr_copy_s2r = tiled_copy_s2r.get_slice(tidx)
+ tRS_sC = thr_copy_s2r.partition_D(sC)
+ tRS_rC = tiled_copy_s2r.retile(tTR_rC)
+ return tiled_copy_s2r, tRS_rC, tRS_sC
+
+ def epilog_smem_copy_and_partition(self, tiled_copy_t2r, tTR_rD, tidx, sD):
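+ """Partition sD (smem) for the R2S store -- used by epilog warps staging output tiles before the TMA store to global memory."""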
+ copy_atom_r2s = sm100_utils.get_smem_store_op(self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r)
+ tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
+ tRS_sD = thr_copy_r2s.partition_D(sD)
+ tRS_rD = tiled_copy_r2s.retile(tTR_rD)
+ return tiled_copy_r2s, tRS_rD, tRS_sD
+
+ # ------------------------------------------------------------------
+ # GPU device kernel
+ # ------------------------------------------------------------------
+
+ @cute.kernel
+ def kernel(
+ self,
+ tiled_mma: cute.TiledMma,
+ tiled_mma_sfb: cute.TiledMma,
+ tma_atom_a: cute.CopyAtom,
+ mA_mkl: cute.Tensor,
+ tma_atom_b: cute.CopyAtom,
+ mB_nkl: cute.Tensor,
+ tma_atom_sfa: cute.CopyAtom,
+ mSFA_mkl: cute.Tensor,
+ tma_atom_sfb: cute.CopyAtom,
+ mSFB_nkl: cute.Tensor,
+ tma_atom_c: cute.CopyAtom, # G2S atom for loading upstream gradient C
+ mC_mnl: cute.Tensor, # upstream gradient tensor (input)
+ tma_atom_d: cute.CopyAtom,
+ mD_mnl: cute.Tensor, # dA output tensor
+ tma_atom_d_col: cute.CopyAtom,
+        mD_col_mnl: cute.Tensor,  # second dA output, quantized with column-wise scale factors (generate_sfd)
+ mSFDRow_mnl: Optional[cute.Tensor],
+ mSFDCol_mnl: Optional[cute.Tensor],
+ norm_const_tensor: Optional[cute.Tensor],
+ mAmax_tensor: Optional[cute.Tensor],
+ padded_offsets: cute.Tensor,
+ alpha: cute.Tensor,
+ prob: cute.Tensor,
+ dprob: Optional[cute.Tensor], # dL/d(prob) output, shape (M,1,1), Float32
+ mDbias_tensor: Optional[cute.Tensor], # dL/d(bias) output, shape (L,N), BF16
+ workspace_ptr,
+ cluster_layout_vmnk: cute.Layout,
+ cluster_layout_sfb_vmnk: cute.Layout,
+ a_smem_layout_staged: cute.ComposedLayout,
+ b_smem_layout_staged: cute.ComposedLayout,
+ sfa_smem_layout_staged: cute.Layout,
+ sfb_smem_layout_staged: cute.Layout,
+ c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
+ d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
+ epi_tile: cute.Tile,
+ sched_params: MoESchedulerParams,
+ ):
+        """GPU device kernel for the persistent MoE grouped GEMM backward pass (DSRELU or NONE epilogue)."""
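+        # Warp specialization (one role per warp id):
+        #   sched_warp_id      - persistent tile scheduler, publishes work tiles via sInfo
+        #   tma_warp_id        - TMA loads of A/B/SFA/SFB into the AB pipeline
+        #   mma_warp_id        - issues block-scaled UMMAs into the TMEM accumulators
+        #   epilog_load_tma_id - TMA loads of the upstream gradient C (DSRELU only)
+        #   epilog_warp_id     - apply the epilogue and store D (plus SFD / dbias / amax / dprob)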
+ warp_idx = cute.arch.warp_idx()
+ warp_idx = cute.arch.make_warp_uniform(warp_idx)
+ lane_idx = cute.arch.lane_idx()
+
+ if warp_idx == self.tma_warp_id:
+ cpasync.prefetch_descriptor(tma_atom_a)
+ cpasync.prefetch_descriptor(tma_atom_sfa)
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ cpasync.prefetch_descriptor(tma_atom_b)
+ cpasync.prefetch_descriptor(tma_atom_sfb)
+ cpasync.prefetch_descriptor(tma_atom_d)
+ if cutlass.const_expr(self.generate_sfd):
+ cpasync.prefetch_descriptor(tma_atom_d_col)
+
+ if warp_idx == self.epilog_load_tma_id:
+ if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value):
+ cpasync.prefetch_descriptor(tma_atom_c)
+
+ use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
+ total_token = padded_offsets[self.expert_cnt - 1]
+
+ bidx, bidy, bidz = cute.arch.block_idx()
+ mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
+ is_leader_cta = mma_tile_coord_v == 0
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
+ block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster)
+ tidx, _, _ = cute.arch.thread_idx()
+
+ smem = utils.SmemAllocator()
+ storage = smem.allocate(self.shared_storage)
+ sched_storage = storage.scheduler
+
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer)
+ ab_pipeline = pipeline.PipelineTmaUmma.create(
+ barrier_storage=storage.ab_mbar_ptr.data_ptr(),
+ num_stages=self.num_ab_stage,
+ producer_group=ab_pipeline_producer_group,
+ consumer_group=ab_pipeline_consumer_group,
+ tx_count=self.num_tma_load_bytes,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ )
+
+ acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
+ acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads)
+ acc_pipeline = pipeline.PipelineUmmaAsync.create(
+ barrier_storage=storage.acc_mbar_ptr.data_ptr(),
+ num_stages=self.num_acc_stage,
+ producer_group=acc_pipeline_producer_group,
+ consumer_group=acc_pipeline_consumer_group,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ )
+
+ # C pipeline: epilog_load_tma warp produces, epilog warps consume
+ c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ c_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ len(self.epilog_warp_id),
+ )
+ c_pipeline = pipeline.PipelineTmaAsync.create(
+ barrier_storage=storage.c_full_mbar_ptr.data_ptr(),
+ num_stages=self.num_c_stage,
+ producer_group=c_producer_group,
+ consumer_group=c_consumer_group,
+ tx_count=self.tma_c_load_bytes,
+ )
+
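+        # Tile-info pipeline: the scheduler warp publishes work-tile descriptors through
+        # sInfo; the TMA, MMA, epilog-load and epilog warps each consume the same stream.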
+ tile_info_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_per_warp * 1)
+ tile_info_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_wo_sched)
+ tile_info_pipeline = pipeline.PipelineAsync.create(
+ barrier_storage=sched_storage.tile_info_mbar.data_ptr(),
+ num_stages=self.num_tile_stage,
+ producer_group=tile_info_pipeline_producer_group,
+ consumer_group=tile_info_pipeline_consumer_group,
+ )
+
+ scheduler = MoEPersistentTileScheduler.create(
+ sched_params,
+ padded_offsets,
+ cute.arch.block_idx(),
+ cute.arch.grid_dim(),
+ counter_ptr=self._get_sched_counter_ptr(workspace_ptr),
+ sched_storage=sched_storage,
+ )
+ scheduler.internal_init()
+
+ tmem = utils.TmemAllocator(
+ storage.tmem_holding_buf,
+ barrier_for_retrieve=self.tmem_alloc_barrier,
+ allocator_warp_id=self.epilog_warp_id[0],
+ is_two_cta=use_2cta_instrs,
+ two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
+ )
+
+ if cute.size(self.cluster_shape_mn) > 1:
+ cute.arch.cluster_arrive_relaxed()
+
+ sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner)
+ sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
+ sD_col = sD
+ if cutlass.const_expr(self.generate_sfd):
+ sD_col = storage.sD_col.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
+ sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
+ sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
+ sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
+ sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
+ amax_layout = cute.make_layout((self.num_epilog_warps,))
+ sAmax = storage.sAmax.get_tensor(amax_layout)
+ sDbias = None
+ if cutlass.const_expr(self.generate_dbias):
+ sDbias = storage.sDbias.get_tensor(
+ cute.make_layout(
+ (self.epi_tile[1], 32, self.num_epilog_warps),
+ stride=(32, 1, self.epi_tile[1] * 32),
+ )
+ )
+ info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4))
+ sInfo = sched_storage.sInfo.get_tensor(info_layout)
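+        # Each sInfo stage holds four Int32 fields: expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt.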
+
+ a_full_mcast_mask = None
+ b_full_mcast_mask = None
+ sfa_full_mcast_mask = None
+ sfb_full_mcast_mask = None
+ if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
+ a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1)
+ sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1)
+
+ thr_mma_common = tiled_mma.get_slice(0)
+ tCsA_common = thr_mma_common.partition_A(sA)
+ tCsB_common = thr_mma_common.partition_B(sB)
+ tCsA_common = cute.filter_zeros(tCsA_common)
+ tCsB_common = cute.filter_zeros(tCsB_common)
+
+ tCrA = tiled_mma.make_fragment_A(sA)
+ tCrB = tiled_mma.make_fragment_B(sB)
+
+ acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
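+        # With overlapping_accum, two accumulator stages share the TMEM column space: the
+        # custom stage stride places the second stage (256 - num_sf_tmem_cols) columns after
+        # the first, and producer/consumer alternate stages by pipeline phase (see below).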
+ if cutlass.const_expr(self.overlapping_accum):
+ num_acc_stage_overlapped = 2
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage_overlapped))
+ tCtAcc_fake = cute.make_tensor(
+ tCtAcc_fake.iterator,
+ cute.make_layout(
+ tCtAcc_fake.shape,
+ stride=(
+ tCtAcc_fake.stride[0],
+ tCtAcc_fake.stride[1],
+ tCtAcc_fake.stride[2],
+ (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1],
+ ),
+ ),
+ )
+ else:
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
+
+ if cute.size(self.cluster_shape_mn) > 1:
+ cute.arch.cluster_wait()
+ else:
+ self.cta_sync_barrier.arrive_and_wait()
+
+ if total_token <= 0:
+ cute.arch.nvvm.exit()
+
+ # ==============================================================
+ # Scheduler warp (MoE Persistent Tile Scheduler)
+ # ==============================================================
+ if warp_idx == self.sched_warp_id:
+ work_tile_info = scheduler.initial_work_tile_info()
+ tile_info_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_tile_stage)
+ while work_tile_info.is_valid_tile:
+ tile_info_pipeline.producer_acquire(tile_info_producer_state)
+ with cute.arch.elect_one():
+ sInfo[(0, tile_info_producer_state.index)] = work_tile_info.expert_idx
+ sInfo[(1, tile_info_producer_state.index)] = work_tile_info.tile_m_idx
+ sInfo[(2, tile_info_producer_state.index)] = work_tile_info.tile_n_idx
+ sInfo[(3, tile_info_producer_state.index)] = work_tile_info.k_tile_cnt
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.sched_sync_barrier.arrive_and_wait()
+ tile_info_pipeline.producer_commit(tile_info_producer_state)
+ tile_info_producer_state.advance()
+ work_tile_info = scheduler.advance_to_next_work()
+
+ tile_info_pipeline.producer_acquire(tile_info_producer_state)
+ with cute.arch.elect_one():
+ sInfo[(0, tile_info_producer_state.index)] = cutlass.Int32(-1)
+ sInfo[(1, tile_info_producer_state.index)] = cutlass.Int32(0)
+ sInfo[(2, tile_info_producer_state.index)] = cutlass.Int32(0)
+ sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.sched_sync_barrier.arrive_and_wait()
+ tile_info_pipeline.producer_commit(tile_info_producer_state)
+ tile_info_producer_state.advance()
+ tile_info_pipeline.producer_tail(tile_info_producer_state)
+
+ # ==============================================================
+ # DMA / TMA load warp (A, B, SFA, SFB)
+ # ==============================================================
+ if warp_idx == self.tma_warp_id:
+ ext = self._make_extension(workspace_ptr)
+ ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage)
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ while is_valid_tile:
+ work_tile_info = MoEWorkTileInfo(
+ expert_idx=tile_info[0],
+ tile_m_idx=tile_info[1],
+ tile_n_idx=tile_info[2],
+ k_tile_cnt=tile_info[3],
+ )
+ k_tile_cnt = work_tile_info.k_tile_cnt
+ ext.update_expert_info(padded_offsets, work_tile_info.expert_idx)
+
+ real_a, _ = ext.get_gmem_tensor("a", mA_mkl, padded_offsets, work_tile_info)
+ real_b, desc_ptr_b = ext.get_gmem_tensor("b", mB_nkl, padded_offsets, work_tile_info)
+ real_sfa, _ = ext.get_gmem_tensor("sfa", mSFA_mkl, padded_offsets, work_tile_info)
+ real_sfb, desc_ptr_sfb = ext.get_gmem_tensor("sfb", mSFB_nkl, padded_offsets, work_tile_info)
+
+ gA_mkl = cute.local_tile(real_a, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ gB_nkl = cute.local_tile(real_b, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
+ gSFA_mkl = cute.local_tile(real_sfa, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ gSFB_nkl = cute.local_tile(real_sfb, cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None))
+
+ thr_mma_dma = tiled_mma.get_slice(mma_tile_coord_v)
+ thr_mma_sfb_dma = tiled_mma_sfb.get_slice(mma_tile_coord_v)
+ tCgA = thr_mma_dma.partition_A(gA_mkl)
+ tCgB = thr_mma_dma.partition_B(gB_nkl)
+ tCgSFA = thr_mma_dma.partition_A(gSFA_mkl)
+ tCgSFB = thr_mma_sfb_dma.partition_B(gSFB_nkl)
+
+ a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
+ tAsA, tAgA = cpasync.tma_partition(
+ tma_atom_a,
+ block_in_cluster_coord_vmnk[2],
+ a_cta_layout,
+ cute.group_modes(sA, 0, 3),
+ cute.group_modes(tCgA, 0, 3),
+ )
+ b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
+ tBsB, tBgB = cpasync.tma_partition(
+ tma_atom_b,
+ block_in_cluster_coord_vmnk[1],
+ b_cta_layout,
+ cute.group_modes(sB, 0, 3),
+ cute.group_modes(tCgB, 0, 3),
+ )
+ sfa_cta_layout = a_cta_layout
+ tAsSFA, tAgSFA = cpasync.tma_partition(
+ tma_atom_sfa,
+ block_in_cluster_coord_vmnk[2],
+ sfa_cta_layout,
+ cute.group_modes(sSFA, 0, 3),
+ cute.group_modes(tCgSFA, 0, 3),
+ )
+ tAsSFA = cute.filter_zeros(tAsSFA)
+ tAgSFA = cute.filter_zeros(tAgSFA)
+ sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
+ tBsSFB, tBgSFB = cpasync.tma_partition(
+ tma_atom_sfb,
+ block_in_cluster_coord_sfb_vmnk[1],
+ sfb_cta_layout,
+ cute.group_modes(sSFB, 0, 3),
+ cute.group_modes(tCgSFB, 0, 3),
+ )
+ tBsSFB = cute.filter_zeros(tBsSFB)
+ tBgSFB = cute.filter_zeros(tBgSFB)
+
+ mma_tile_coord_m = work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape)
+ mma_tile_coord_n = work_tile_info.tile_n_idx
+ tAgA_slice = tAgA[(None, mma_tile_coord_m, None, 0)]
+ tBgB_slice = tBgB[(None, mma_tile_coord_n, None, 0)]
+ tAgSFA_slice = tAgSFA[(None, mma_tile_coord_m, None, 0)]
+ slice_n = mma_tile_coord_n
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ slice_n = mma_tile_coord_n // 2
+ tBgSFB_slice = tBgSFB[(None, slice_n, None, 0)]
+
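+            # Producer loop: for each k-tile, acquire one AB pipeline stage and issue TMA
+            # loads of A, B, SFA and SFB into it (multicast across the cluster when enabled).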
+ ab_producer_state.reset_count()
+ peek_ab_empty_status = cutlass.Boolean(1)
+ if ab_producer_state.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+
+ for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+ tAgA_k = tAgA_slice[(None, ab_producer_state.count)]
+ tBgB_k = tBgB_slice[(None, ab_producer_state.count)]
+ tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)]
+ tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)]
+ tAsA_pipe = tAsA[(None, ab_producer_state.index)]
+ tBsB_pipe = tBsB[(None, ab_producer_state.index)]
+ tAsSFA_pipe = tAsSFA[(None, ab_producer_state.index)]
+ tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)]
+
+ tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state)
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
+
+ cute.copy(tma_atom_a, tAgA_k, tAsA_pipe, tma_bar_ptr=tma_bar, mcast_mask=a_full_mcast_mask)
+ cute.copy(tma_atom_b, tBgB_k, tBsB_pipe, tma_bar_ptr=tma_bar, mcast_mask=b_full_mcast_mask, tma_desc_ptr=desc_ptr_b)
+ cute.copy(tma_atom_sfa, tAgSFA_k, tAsSFA_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfa_full_mcast_mask)
+ cute.copy(tma_atom_sfb, tBgSFB_k, tBsSFB_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfb_full_mcast_mask, tma_desc_ptr=desc_ptr_sfb)
+
+ ab_producer_state.advance()
+ peek_ab_empty_status = cutlass.Boolean(1)
+ if ab_producer_state.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+ ab_pipeline.producer_tail(ab_producer_state)
+
+ # ==============================================================
+ # MMA warp
+ # ==============================================================
+ if warp_idx == self.mma_warp_id:
+ tmem.wait_for_alloc()
+ acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
+
+ sfa_tmem_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols,
+ dtype=self.sf_dtype,
+ )
+ tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
+
+ sfb_tmem_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
+
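+            # Stage the A/B scale factors from SMEM into TMEM (S2T); the block-scaled
+            # tcgen05 MMA reads SFA/SFB directly from tensor memory.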
+ (
+ tiled_copy_s2t_sfa,
+ tCsSFA_compact_s2t,
+ tCtSFA_compact_s2t,
+ ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA)
+ (
+ tiled_copy_s2t_sfb,
+ tCsSFB_compact_s2t,
+ tCtSFB_compact_s2t,
+ ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
+
+ ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage)
+ acd_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
+
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ while is_valid_tile:
+ k_tile_cnt = tile_info[3]
+
+ ab_consumer_state.reset_count()
+ peek_ab_full_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
+
+ acd_producer_state.reset_count()
+ peek_acc_empty_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
+ peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state)
+
+ mma_tile_coord_mnl = (
+ tile_info[1] // cute.size(tiled_mma.thr_id.shape),
+ tile_info[2],
+ tile_info[0],
+ )
+
+ if cutlass.const_expr(self.overlapping_accum):
+ acc_stage_index = acd_producer_state.phase ^ 1
+ else:
+ acc_stage_index = acd_producer_state.index
+
+ tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)]
+
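+                # For CTA tile N of 192 or 64, odd N-tile coordinates read SFB from a
+                # 2-column offset within the SFB TMEM region, so shift the pointer here.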
+ tCtSFB_mma = tCtSFB
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
+ offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+ elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+
+ if is_leader_cta:
+ acc_pipeline.producer_acquire(acd_producer_state, peek_acc_empty_status)
+
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
+
+ for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+ if is_leader_cta:
+ ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
+
+ s2t_stage_coord = (None, None, None, None, ab_consumer_state.index)
+ cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t[s2t_stage_coord], tCtSFA_compact_s2t)
+ cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t[s2t_stage_coord], tCtSFB_compact_s2t)
+
+ num_kblocks = cute.size(tCrA, mode=[2])
+ ab_consumer_state_next = ab_consumer_state.clone()
+ ab_consumer_state_next.advance()
+ if ab_consumer_state_next.count < k_tile_cnt:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state_next)
+
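+                    # One block-scaled UMMA per k-block: point the MMA at the matching
+                    # SFA/SFB TMEM columns; ACCUMULATE is off only for the tile's first k-block.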
+ for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
+ kblock_coord = (None, None, kblock_idx, ab_consumer_state.index)
+ sf_kblock_coord = (None, None, kblock_idx)
+ tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
+ tiled_mma.set(tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator)
+ cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc)
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
+
+ ab_pipeline.consumer_release(ab_consumer_state)
+ ab_consumer_state = ab_consumer_state_next
+
+ if is_leader_cta:
+ acc_pipeline.producer_commit(acd_producer_state)
+
+ acd_producer_state.advance()
+ if acd_producer_state.count < k_tile_cnt:
+ if is_leader_cta:
+ peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state)
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ acc_pipeline.producer_tail(acd_producer_state)
+
+ # ==============================================================
+ # Epilog load TMA warp: streams upstream gradient C into sC
+ # Must always consume from tile_info_pipeline (it is part of the
+ # 224-thread consumer group), but only loads C when DSRELU.
+ # ==============================================================
+ if warp_idx == self.epilog_load_tma_id:
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+
+ if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value):
+ c_ext = self._make_extension(workspace_ptr)
+ c_pipeline_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_c_stage)
+ # Toggled per-tile to mirror the epilog consumer's
+ # reverse_subtile = (acc_consumer_state.phase == 0). Starts True so
+ # tile 0 (consumer phase=0) gets C in reverse order.
+ c_is_reverse = cutlass.Boolean(True)
+
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ while is_valid_tile:
+ if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value):
+ c_work_tile_info = MoEWorkTileInfo(
+ expert_idx=tile_info[0],
+ tile_m_idx=tile_info[1],
+ tile_n_idx=tile_info[2],
+ k_tile_cnt=tile_info[3],
+ )
+ c_ext.update_expert_info(padded_offsets, c_work_tile_info.expert_idx)
+
+ real_c, _ = c_ext.get_gmem_tensor("c", mC_mnl, padded_offsets, c_work_tile_info)
+
+ thr_mma_c = tiled_mma.get_slice(mma_tile_coord_v)
+ gC_mnl_loop = cute.local_tile(real_c, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
+ tCgC_loop = thr_mma_c.partition_C(gC_mnl_loop)
+
+ _, bGS_sC, bGS_gC_partitioned = epilog_gmem_copy_and_partition(
+ tidx,
+ tma_atom_c,
+ tCgC_loop,
+ epi_tile,
+ sC,
+ )
+
+ epi_mma_tile_coord = (
+ c_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape),
+ c_work_tile_info.tile_n_idx,
+ 0,
+ )
+ bGS_gC = bGS_gC_partitioned[(None, None, None, *epi_mma_tile_coord)]
+ bGS_gC = cute.group_modes(bGS_gC, 1, cute.rank(bGS_gC))
+ subtile_cnt = cute.size(bGS_gC.shape, mode=[1])
+
+ # Toggle per tile to mirror consumer's phase-based reverse_subtile.
+ c_reverse_this_tile = cutlass.Boolean(False)
+ if cutlass.const_expr(self.overlapping_accum):
+ c_reverse_this_tile = c_is_reverse
+ c_is_reverse = not c_is_reverse
+ for subtile_idx in cutlass.range(subtile_cnt, unroll=1):
+ real_c_subtile_idx = subtile_idx
+ if cutlass.const_expr(self.overlapping_accum):
+ if c_reverse_this_tile:
+ real_c_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx
+ c_pipeline.producer_acquire(c_pipeline_producer_state)
+ cute.copy(
+ tma_atom_c,
+ bGS_gC[(None, real_c_subtile_idx)],
+ bGS_sC[(None, c_pipeline_producer_state.index)],
+ tma_bar_ptr=c_pipeline.producer_get_barrier(c_pipeline_producer_state),
+ )
+ c_pipeline_producer_state.advance()
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value):
+ c_pipeline.producer_tail(c_pipeline_producer_state)
+
+ # ==============================================================
+ # Epilogue warps: compute dA (and dprob for DSRELU)
+ # ==============================================================
+ if warp_idx < self.mma_warp_id:
+ tmem.allocate(self.num_tmem_alloc_cols)
+ tmem.wait_for_alloc()
+ tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
+
+ epi_tidx = tidx
+ thr_mma_epi = tiled_mma.get_slice(mma_tile_coord_v)
+
+ gD_mnl_shape = cute.local_tile(mD_mnl, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None))
+ tCgD_shape = thr_mma_epi.partition_C(gD_mnl_shape)
+
+ tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
+ epi_tidx,
+ tCtAcc_base,
+ tCgD_shape,
+ epi_tile,
+ use_2cta_instrs,
+ )
+
+ # Register buffer for reading C from sC (c_pipeline consumer)
+ tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
+ tiled_copy_s2r, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition_load(
+ tiled_copy_t2r,
+ tTR_rC,
+ epi_tidx,
+ sC,
+ )
+
+ tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype)
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
+ tiled_copy_t2r,
+ tTR_rD,
+ epi_tidx,
+ sD,
+ )
+
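+            # SFD path: a second register/SMEM staging buffer (tRS_rD_col / sD_col) plus
+            # per-thread partitions of the row-wise and column-wise SFD tensors used when
+            # writing block scale factors for D.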
+ if cutlass.const_expr(self.generate_sfd):
+ tTR_rD_col = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype)
+ tiled_copy_r2s, tRS_rD_col, tRS_sD_col = self.epilog_smem_copy_and_partition(
+ tiled_copy_t2r,
+ tTR_rD_col,
+ epi_tidx,
+ sD_col,
+ )
+ norm_const = norm_const_tensor[0]
+ regPerSubtile = 4
+ sfd_row_tile = (cute.make_layout(128), cute.make_layout(32 * regPerSubtile))
+ gSFDRow_mnl = cute.local_tile(mSFDRow_mnl, sfd_row_tile, (None, None, None))
+ thr_copy_t2r_local = tiled_copy_t2r.get_slice(tidx)
+ tCgSFDRow_mnl = thr_copy_t2r_local.partition_D(gSFDRow_mnl)
+ tCgSFDRow_mnl = cute.filter_zeros(tCgSFDRow_mnl)
+ tCrSFDRow = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype)
+ tCrSFDRow_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32)
+ d_rcp_limits = get_dtype_rcp_limits(self.d_dtype)
+
+ sfd_col_tile = sfd_row_tile
+ gSFDCol_mnl = cute.local_tile(mSFDCol_mnl, sfd_col_tile, (None, None, None))
+ thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
+ val_layout = cute.make_ordered_layout((1,), order=(0,))
+ copy_atom_sfd_col = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gSFDCol_mnl.element_type, num_bits_per_copy=8)
+ tiled_copy_sfd_col = cute.make_tiled_copy_tv(copy_atom_sfd_col, thr_layout, val_layout)
+ thr_copy_sfd_col = tiled_copy_sfd_col.get_slice(tidx)
+ tCgSFDCol_mnl = thr_copy_sfd_col.partition_D(cute.filter_zeros(gSFDCol_mnl))
+ tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl)
+ tCrSFDCol = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].shape, self.sf_dtype)
+ tCrSFDCol_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32)
+
+ epi_ext = self._make_extension(workspace_ptr)
+
+ acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
+ c_pipeline_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_c_stage)
+ d_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id))
+ d_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_d_stage, producer_group=d_producer_group)
+
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ num_prev_subtiles = cutlass.Int32(0)
+ while is_valid_tile:
+ epi_work_tile_info = MoEWorkTileInfo(
+ expert_idx=tile_info[0],
+ tile_m_idx=tile_info[1],
+ tile_n_idx=tile_info[2],
+ k_tile_cnt=tile_info[3],
+ )
+ expert_idx = epi_work_tile_info.expert_idx
+ epi_ext.update_expert_info(padded_offsets, expert_idx)
+
+ alpha_val = alpha[expert_idx]
+
+ real_d, _ = epi_ext.get_gmem_tensor("d", mD_mnl, padded_offsets, epi_work_tile_info)
+
+ thr_mma_epi_loop = tiled_mma.get_slice(mma_tile_coord_v)
+
+ gD_mnl_loop = cute.local_tile(real_d, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None))
+ tCgD_loop = thr_mma_epi_loop.partition_C(gD_mnl_loop)
+ _, bSG_sD, bSG_gD_partitioned = epilog_gmem_copy_and_partition(
+ epi_tidx,
+ tma_atom_d,
+ tCgD_loop,
+ epi_tile,
+ sD,
+ )
+
+ real_d_col = real_d
+ if cutlass.const_expr(self.generate_sfd):
+ real_d_col, _ = epi_ext.get_gmem_tensor("d_col", mD_col_mnl, padded_offsets, epi_work_tile_info)
+
+ gD_col_mnl_loop = gD_mnl_loop
+ tCgD_col_loop = tCgD_loop
+ if cutlass.const_expr(self.generate_sfd):
+ gD_col_mnl_loop = cute.local_tile(real_d_col, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None))
+ tCgD_col_loop = thr_mma_epi_loop.partition_C(gD_col_mnl_loop)
+ _, bSG_sD_col, bSG_gD_col_partitioned = epilog_gmem_copy_and_partition(
+ epi_tidx,
+ tma_atom_d_col,
+ tCgD_col_loop,
+ epi_tile,
+ sD_col,
+ )
+
+ epi_mma_tile_coord = (
+ epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape),
+ epi_work_tile_info.tile_n_idx,
+ 0,
+ )
+ bSG_gD = bSG_gD_partitioned[(None, None, None, *epi_mma_tile_coord)]
+ bSG_gD_col = bSG_gD_col_partitioned[(None, None, None, *epi_mma_tile_coord)]
+ bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
+ bSG_gD_col = cute.group_modes(bSG_gD_col, 1, cute.rank(bSG_gD_col))
+
+ if cutlass.const_expr(self.generate_sfd):
+ tCgSFDRow_mn = tCgSFDRow_mnl[(None, None, None, None, None, 0)]
+ tCgSFDCol_mnl_new = tCgSFDCol_mnl
+ if cutlass.const_expr(self.discrete_col_sfd):
+ tCgSFDCol_mnl_new = self.create_and_partition_new_SFDCol(tile_info, mSFDCol_mnl, padded_offsets)
+ tCgSFDCol_mn = tCgSFDCol_mnl_new[(None, None, None, None, None, 0)]
+
+ if cutlass.const_expr(self.generate_amax):
+ thread_tile_amax = cutlass.Float32(0.0)
+
+ mPosition = epi_work_tile_info.tile_m_idx * self.cta_tile_shape_mnk[0] + tidx
+ real_prob, _ = epi_ext.get_gmem_tensor("prob", prob, padded_offsets, epi_work_tile_info)
+ mProb = real_prob[mPosition, 0, 0]
+
+ # Accumulator for dprob: summed over N subtiles, written once per tile
+ dProbVal = cutlass.Float32(0.0)
+
+ if cutlass.const_expr(self.overlapping_accum):
+ acc_stage_index = acc_consumer_state.phase
+ reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False)
+ else:
+ acc_stage_index = acc_consumer_state.index
+
+ tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
+
+ acc_pipeline.consumer_wait(acc_consumer_state)
+
+ subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
+
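+                # Walk the N subtiles of this tile's accumulator; with overlapping_accum the
+                # traversal direction alternates per tile (reverse_subtile) and the accumulator
+                # stage is released early once the last overlapped subtile has been read.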
+ for subtile_idx in cutlass.range(0, subtile_cnt, 1, unroll=1):
+ real_subtile_idx = subtile_idx
+ if cutlass.const_expr(self.overlapping_accum):
+ if reverse_subtile:
+ real_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx
+
+ if cutlass.const_expr(self.overlapping_accum):
+ if subtile_idx == self.iter_acc_early_release_in_epilogue:
+ cute.arch.fence_view_async_tmem_load()
+ with cute.arch.elect_one():
+ acc_pipeline.consumer_release(acc_consumer_state)
+ acc_consumer_state.advance()
+
+ tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)]
+ cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
+
+ # Apply alpha scaling
+ if cutlass.const_expr(self.vectorized_f32):
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
+ tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2(
+ (tTR_rAcc[i], tTR_rAcc[i + 1]),
+ (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc)):
+ tTR_rAcc[i] = tTR_rAcc[i] * cutlass.Float32(alpha_val)
+
+ acc_vec = tTR_rAcc.load()
+
+ # Consume one C subtile from c_pipeline
+ if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value):
+ c_pipeline.consumer_wait(c_pipeline_consumer_state)
+ cute.copy(
+ tiled_copy_s2r,
+ tRS_sC[(None, None, None, c_pipeline_consumer_state.index)],
+ tRS_rC,
+ )
+ cute.arch.fence_proxy("async.shared", space="cta")
+ c_pipeline.consumer_release(c_pipeline_consumer_state)
+ c_pipeline_consumer_state.advance()
+ c_vec = tiled_copy_s2r.retile(tRS_rC)
+
+ tCompute = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype)
+
+ if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value):
+ # DSRELU backward: dA = relu(acc) * C * 2 * prob
+ acc_relu = cute.where(acc_vec > 0, acc_vec, cute.full_like(acc_vec, 0))
+ tRelu = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype)
+ tRelu.store(acc_relu)
+ probx2 = 2 * mProb
+ if cutlass.const_expr(self.vectorized_f32):
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (tRelu[i], tRelu[i + 1]),
+ (c_vec[i].to(self.acc_dtype), c_vec[i + 1].to(self.acc_dtype)),
+ rnd="rn",
+ ftz=False,
+ )
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (tCompute[i], tCompute[i + 1]),
+ (cutlass.Float32(probx2), cutlass.Float32(probx2)),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc)):
+ tCompute[i] = tRelu[i] * c_vec[i].to(self.acc_dtype) * cutlass.Float32(probx2)
+
+ # Accumulate dprob: dprob[m] += sum_n(relu(acc)^2 * C)
+ if cutlass.const_expr(dprob is not None):
+ tDprob = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype)
+ if cutlass.const_expr(self.vectorized_f32):
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
+ tDprob[i], tDprob[i + 1] = cute.arch.mul_packed_f32x2(
+ (tRelu[i], tRelu[i + 1]),
+ (tRelu[i], tRelu[i + 1]),
+ rnd="rn",
+ ftz=False,
+ )
+ tDprob[i], tDprob[i + 1] = cute.arch.mul_packed_f32x2(
+ (tDprob[i], tDprob[i + 1]),
+ (c_vec[i].to(self.acc_dtype), c_vec[i + 1].to(self.acc_dtype)),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc)):
+ tDprob[i] = tRelu[i] * tRelu[i] * c_vec[i].to(self.acc_dtype)
+ dProbVal = dProbVal + tDprob.load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
+ else:
+ # NONE epilogue: dA = acc * prob (identity, no SReLU gate)
+ if cutlass.const_expr(self.vectorized_f32):
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (acc_vec[i], acc_vec[i + 1]),
+ (mProb, mProb),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc)):
+ tCompute[i] = acc_vec[i] * mProb
+
+ # dbias reduction: sum dA across M for each N, atomic-add to dbias[expert, n]
+ if cutlass.const_expr(self.generate_dbias):
+ dA_vec = tCompute.load()
+ n_base = epi_work_tile_info.tile_n_idx * self.mma_tiler[1] + real_subtile_idx * self.epi_tile[1]
+ dbias_n_total = cute.size(mDbias_tensor, mode=[1])
+ self.dbias_reduction(
+ dA_vec,
+ warp_idx,
+ sDbias,
+ mDbias_tensor,
+ expert_idx,
+ n_base,
+ dbias_n_total,
+ )
+
+ if cutlass.const_expr(self.generate_amax):
+ thread_tile_amax = amax_reduction_per_thread(tCompute, thread_tile_amax)
+
+ if cutlass.const_expr(self.generate_sfd):
+ tCompute_col = cute.make_rmem_tensor(tCompute.layout, tCompute.element_type)
+ tCompute_col.store(tCompute.load())
+ self.quant_sfd_row(
+ real_subtile_idx % 4,
+ tiled_copy_r2s,
+ tCompute,
+ tCrSFDRow_pvscale,
+ norm_const,
+ d_rcp_limits,
+ tRS_rD,
+ )
+ self.quant_sfd_col(
+ real_subtile_idx % 4,
+ tiled_copy_r2s,
+ tCompute_col,
+ tCrSFDCol_pvscale,
+ norm_const,
+ d_rcp_limits,
+ tRS_rD_col,
+ )
+ global_sfd_m = epi_work_tile_info.tile_m_idx + epi_ext.token_offset // self.cta_tile_shape_mnk[0]
+ if cutlass.const_expr(self.mma_tiler[1] == 256):
+ sfd_n = epi_work_tile_info.tile_n_idx * 2 + (real_subtile_idx >> 2)
+ else:
+ sfd_n = epi_work_tile_info.tile_n_idx
+ sfd_row_idx_mn = (global_sfd_m, sfd_n)
+ sfd_col_idx_mn = sfd_row_idx_mn
+ if cutlass.const_expr(self.discrete_col_sfd):
+ sfd_col_idx_mn = (
+ epi_work_tile_info.tile_m_idx,
+ sfd_n,
+ )
+ tCgSFDRow = tCgSFDRow_mn[(None, None, None, *sfd_row_idx_mn)]
+ tCgSFDCol = tCgSFDCol_mn[(None, None, None, *sfd_col_idx_mn)]
+ if subtile_idx == 3 or subtile_idx == 7:
+ if sfd_row_idx_mn[1] * 32 * regPerSubtile < cute.size(cute.shape(mSFDRow_mnl.layout, mode=[1])):
+ tCrSFDRow.store(tCrSFDRow_pvscale.load().to(self.sf_dtype))
+ cute.autovec_copy(tCrSFDRow, tCgSFDRow)
+ if sfd_col_idx_mn[1] * 32 * regPerSubtile < cute.size(cute.shape(mSFDCol_mnl.layout, mode=[1])):
+ tCrSFDCol.store(tCrSFDCol_pvscale.load().to(self.sf_dtype))
+ cute.autovec_copy(tCrSFDCol, tCgSFDCol)
+ else:
+ acc_vec = tiled_copy_r2s.retile(tCompute).load()
+ tRS_rD.store(acc_vec.to(self.d_dtype))
+
+ d_buffer = num_prev_subtiles % self.num_d_stage
+ num_prev_subtiles = num_prev_subtiles + 1
+ cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
+ if cutlass.const_expr(self.generate_sfd):
+ cute.copy(tiled_copy_r2s, tRS_rD_col, tRS_sD_col[(None, None, None, d_buffer)])
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.epilog_sync_barrier.arrive_and_wait()
+ if warp_idx == self.epilog_warp_id[0]:
+ cute.copy(tma_atom_d, bSG_sD[(None, d_buffer)], bSG_gD[(None, real_subtile_idx)])
+ if cutlass.const_expr(self.generate_sfd):
+ cute.copy(tma_atom_d_col, bSG_sD_col[(None, d_buffer)], bSG_gD_col[(None, real_subtile_idx)])
+ d_pipeline.producer_commit()
+ d_pipeline.producer_acquire()
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ if cutlass.const_expr(not self.overlapping_accum):
+ with cute.arch.elect_one():
+ acc_pipeline.consumer_release(acc_consumer_state)
+ acc_consumer_state.advance()
+
+ # Write accumulated dprob gradient atomically
+ if cutlass.const_expr(self.epilogue_type == EpilogueType.DSRELU.value):
+ if cutlass.const_expr(dprob is not None):
+ real_dprob, _ = epi_ext.get_gmem_tensor("dprob", dprob, padded_offsets, epi_work_tile_info)
+ _ = atomic_add_float32(
+ ptr=real_dprob[(mPosition, None, None)].iterator.llvm_ptr,
+ value=dProbVal,
+ )
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ if cutlass.const_expr(self.generate_amax):
+ gAmax = mAmax_tensor[(expert_idx, None)].iterator.llvm_ptr
+ self.amax_reduction_per_warp_and_cta(thread_tile_amax, warp_idx, sAmax, gAmax)
+
+ tmem.relinquish_alloc_permit()
+ self.epilog_sync_barrier.arrive_and_wait()
+ tmem.free(tmem_ptr)
+ d_pipeline.producer_tail()
+
+ # ------------------------------------------------------------------
+ # Internal: create extension based on weight_mode
+ # ------------------------------------------------------------------
+
+ @cute.jit
+ def _make_extension(self, workspace_ptr):
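+        """Build the scheduler extension for the configured weight mode: discrete weights need a tensormap workspace for B/SFB, dense weights use the contiguous grouped-GEMM extension."""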
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE):
+ desc_workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"])
+ return DiscreteWeightScaledGemmSchedExtension(
+ tensormap_ctor=desc_workspace,
+ sf_vec_size=self.sf_vec_size,
+ )
+ else:
+ return ContiguousAndConsistentGroupedGemmSchedExtension(
+ sf_vec_size=self.sf_vec_size,
+ )
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py
index 5da7b868..f6b51822 100644
--- a/python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py
+++ b/python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py
@@ -36,9 +36,7 @@
from .grouped_gemm_dswiglu_quant import (
BlockScaledContiguousGroupedGemmKernel,
)
-from ..utils import logical_shape_fp4x2_aware
from cuda.bindings import driver as cuda
-import os
import torch
from typing import Tuple, Optional
@@ -130,7 +128,7 @@ def __init__(
"""
super().__init__()
- self._logger.warning("GroupedGemmDswigluSm100 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
# Store sample tensor descriptors
@@ -182,7 +180,7 @@ def __init__(
self._interpret_uint8_as_fp4x2 = True
self._kernel = BlockScaledContiguousGroupedGemmKernel
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
- print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
self._logger.debug(f"__init__ completed")
def check_support(self) -> bool:
@@ -374,9 +372,6 @@ def compile(self) -> None:
if self._compiled_kernel is not None:
self._logger.debug("Kernel already compiled; skipping recompilation")
return
- if self.a_desc.shape[0] == 0:
- self._logger.debug("sample valid_m is zero, skipping kernel compilation")
- return
gemm_dswiglu = self._kernel(
sf_vec_size=self.sf_vec_size,
@@ -400,8 +395,8 @@ def compile(self) -> None:
fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False)
self._logger.debug("Compiling grouped_gemm_dswiglu kernel")
- use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
- if not use_full_dynamic: # only mark the m dimension as dynamic
+ use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
+ if not use_full_dynamic:
valid_m = cute.sym_int(divisibility=256)
a_cute_fake = self._make_fake_cute_compact_tensor(
@@ -425,27 +420,26 @@ def compile(self) -> None:
shape=(valid_m, *self.d_col_desc.shape[1:]),
stride_order=self.d_col_desc.stride_order,
)
-
- tensor_m_128 = cute.sym_int()
- stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
- stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
- sfa_cute_fake = self._make_fake_cute_tensor(
- dtype=self.sfa_desc.dtype,
- shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], 1),
- stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128),
- )
-
- sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16)
prob_cute_fake = self._make_fake_cute_compact_tensor(
dtype=self.prob_desc.dtype,
shape=(valid_m, 1, 1),
stride_order=self.prob_desc.stride_order,
)
- d_prob_fake = self._make_fake_cute_compact_tensor(
+ dprob_cute_fake = self._make_fake_cute_compact_tensor(
dtype=self.dprob_desc.dtype,
shape=(valid_m, 1, 1),
stride_order=self.dprob_desc.stride_order,
)
+
+ tensor_m_128 = cute.sym_int()
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], 1),
+ stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128),
+ )
+ sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16)
+
sfd_row_fake = None
sfd_col_fake = None
if self.sfd_row_desc is not None:
@@ -466,7 +460,6 @@ def compile(self) -> None:
)
else:
valid_m = cute.sym_int(divisibility=256)
- n = cute.sym_int()
n_2 = cute.sym_int()
k = cute.sym_int()
l = cute.sym_int()
@@ -480,12 +473,11 @@ def compile(self) -> None:
)
b_cute_fake = self._make_fake_cute_compact_tensor(
dtype=self.b_desc.dtype,
- shape=(n, k, l),
+ shape=(cute.sym_int(), k, l),
stride_order=self.b_desc.stride_order,
dynamic_mode=self.b_desc.stride_order[0],
divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
)
-
c_cute_fake = self._make_fake_cute_compact_tensor(
dtype=self.c_desc.dtype,
shape=(valid_m, n_2, 1),
@@ -493,7 +485,6 @@ def compile(self) -> None:
dynamic_mode=self.c_desc.stride_order[0],
divisibility=8 if self._is_f16(self.c_desc.dtype) else 16,
)
-
d_row_cute_fake = self._make_fake_cute_compact_tensor(
dtype=self.d_row_desc.dtype,
shape=(valid_m, n_2, 1),
@@ -501,7 +492,6 @@ def compile(self) -> None:
dynamic_mode=self.d_row_desc.stride_order[0],
divisibility=8 if self._is_f16(self.d_row_desc.dtype) else 16,
)
-
d_col_cute_fake = self._make_fake_cute_compact_tensor(
dtype=self.d_col_desc.dtype,
shape=(valid_m, n_2, 1),
@@ -509,8 +499,17 @@ def compile(self) -> None:
dynamic_mode=self.d_col_desc.stride_order[0],
divisibility=8 if self._is_f16(self.d_col_desc.dtype) else 16,
)
+ prob_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, 1, 1),
+ stride_order=self.prob_desc.stride_order,
+ )
+ dprob_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.dprob_desc.dtype,
+ shape=(valid_m, 1, 1),
+ stride_order=self.dprob_desc.stride_order,
+ )
- # 32, 4, tensor_m // 128, 4, rest_k, 1)
tensor_m_128 = cute.sym_int()
rest_k = cute.sym_int()
stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
@@ -520,7 +519,6 @@ def compile(self) -> None:
shape=(32, 4, tensor_m_128, 4, rest_k, 1),
stride=(16, 4, stride_rest_k, 1, 512, stride_tensor_m_128),
)
-
tensor_n_128 = cute.sym_int()
stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4)
@@ -530,18 +528,6 @@ def compile(self) -> None:
stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k),
)
- prob_cute_fake = self._make_fake_cute_compact_tensor(
- dtype=self.prob_desc.dtype,
- shape=(valid_m, 1, 1),
- stride_order=self.prob_desc.stride_order,
- )
-
- d_prob_fake = self._make_fake_cute_compact_tensor(
- dtype=self.dprob_desc.dtype,
- shape=(valid_m, 1, 1),
- stride_order=self.dprob_desc.stride_order,
- )
-
sfd_row_fake = None
sfd_col_fake = None
if self.sfd_row_desc is not None:
@@ -581,7 +567,7 @@ def compile(self) -> None:
alpha=self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16),
beta=self._make_fake_cute_tensor_from_desc(self.beta_desc, assumed_align=16),
prob=prob_cute_fake,
- dprob=d_prob_fake,
+ dprob=dprob_cute_fake,
max_active_clusters=max_active_clusters,
epilogue_op=self.epilogue_op,
stream=fake_stream,
@@ -675,9 +661,6 @@ def execute(
self._logger.debug("Entering execute")
current_stream = self._get_default_stream(current_stream)
- if a_tensor.shape[0] == 0:
- self._logger.debug("execute: valid_m is zero, skipping kernel execution")
- return
if self._compiled_kernel is None:
raise RuntimeError("Kernel not compiled; call compile() first")
self._logger.debug("Executing grouped_gemm_dswiglu kernel")
@@ -705,6 +688,7 @@ def execute(
import logging
+import os
_logger = logging.getLogger(__name__)
_cache_of_GroupedGemmDswigluSm100Objects = {}
@@ -718,7 +702,7 @@ def grouped_gemm_dswiglu_wrapper_sm100(
sfb_tensor: torch.Tensor,
padded_offsets: torch.Tensor,
alpha_tensor: torch.Tensor,
- beta_tensor: torch.Tensor,
+ beta_tensor: Optional[torch.Tensor],
prob_tensor: torch.Tensor,
norm_const_tensor: Optional[torch.Tensor] = None,
acc_dtype: torch.dtype = torch.float32,
@@ -732,6 +716,8 @@ def grouped_gemm_dswiglu_wrapper_sm100(
discrete_col_sfd: bool = False,
epilogue_op: Optional[str] = None,
current_stream: Optional[cuda.CUstream] = None,
+ dprob_tensor_buf: Optional[torch.Tensor] = None,
+ amax_tensor_buf: Optional[torch.Tensor] = None,
) -> TupleDict:
"""Convenience wrapper for grouped GEMM dSwiGLU backward operation.
@@ -771,67 +757,13 @@ def grouped_gemm_dswiglu_wrapper_sm100(
- **sfd_row_tensor** (torch.Tensor or None): Row-wise scale factors for D
- **sfd_col_tensor** (torch.Tensor or None): Column-wise scale factors for D
"""
- valid_m, _, _ = logical_shape_fp4x2_aware(a_tensor)
- n, _, l = logical_shape_fp4x2_aware(b_tensor)
-
- _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Creating output tensors d_row_tensor, d_col_tensor, dprob_tensor")
-
- if cd_major == "n":
- d_row_tensor = torch.empty_strided((valid_m, n * 2, 1), (n * 2, 1, valid_m * n * 2), dtype=d_dtype, device=a_tensor.device)
- d_col_tensor = torch.empty_strided((valid_m, n * 2, 1), (n * 2, 1, valid_m * n * 2), dtype=d_dtype, device=a_tensor.device)
- dprob_tensor = torch.zeros((valid_m, 1, 1), dtype=torch.float32, device=a_tensor.device)
- else:
+ valid_m = a_tensor.shape[0]
+ n, _, l = b_tensor.shape
+ if cd_major != "n":
raise ValueError(f"cd_major must be 'n', got {cd_major}")
- sfd_row_tensor = None
- sfd_col_tensor = None
- amax_tensor = None
-
- if a_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sfa_tensor.dtype in [torch.float8_e8m0fnu, torch.float8_e4m3fn]:
- _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Detected fp8 a_dtype and sfa_dtype, constructing sfd_row_tensor and sfd_col_tensor")
-
- sf_dtype = sfa_tensor.dtype
- mma_permute_order = (3, 4, 1, 5, 2, 0)
-
- sf_k_row = ceil_div(n * 2, sf_vec_size)
- mma_shape_row = (
- 1,
- ceil_div(valid_m, 128),
- ceil_div(sf_k_row, 4),
- 32,
- 4,
- 4,
- )
- sfd_row_tensor = torch.empty(mma_shape_row, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order)
-
- sf_k_col = ceil_div(valid_m, sf_vec_size)
- mma_shape_col = (
- 1,
- ceil_div(n * 2, 128),
- ceil_div(sf_k_col, 4),
- 32,
- 4,
- 4,
- )
- sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order)
-
- if d_dtype in [torch.bfloat16, torch.float16]:
- _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Detected bf16/float16 d_dtype, constructing amax_tensor")
- amax_tensor = torch.full((l, 2, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device)
-
- if valid_m == 0:
- _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: valid_m is zero, skipping kernel execution")
- return TupleDict(
- d_row_tensor=d_row_tensor,
- d_col_tensor=d_col_tensor,
- dprob_tensor=dprob_tensor,
- amax_tensor=amax_tensor,
- sfd_row_tensor=sfd_row_tensor,
- sfd_col_tensor=sfd_col_tensor,
- )
-
- use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
+ use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1]))
@@ -847,12 +779,14 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
stride_order(a_tensor),
stride_order(b_tensor),
stride_order(c_tensor),
- norm_const_tensor.shape if norm_const_tensor is not None else None,
- norm_const_tensor.stride() if norm_const_tensor is not None else None,
- norm_const_tensor.dtype if norm_const_tensor is not None else None,
+ sfa_tensor.dtype,
+ sfb_tensor.dtype,
padded_offsets.shape if not use_full_dynamic else None,
padded_offsets.stride() if not use_full_dynamic else None,
padded_offsets.dtype,
+ norm_const_tensor.shape if norm_const_tensor is not None else None,
+ norm_const_tensor.stride() if norm_const_tensor is not None else None,
+ norm_const_tensor.dtype if norm_const_tensor is not None else None,
acc_dtype,
d_dtype,
cd_major,
@@ -865,32 +799,69 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
epilogue_op,
)
- if cache_key in _cache_of_GroupedGemmDswigluSm100Objects:
- _logger.debug("group_gemm_dswiglu_wrapper_sm100: Using previously cached GroupedGemmDswigluSm100 object")
- grouped_gemm_dswiglu = _cache_of_GroupedGemmDswigluSm100Objects[cache_key]
- grouped_gemm_dswiglu.execute(
- a_tensor=a_tensor,
- b_tensor=b_tensor,
- c_tensor=c_tensor,
+ # Allocate M-dependent output tensors fresh every call (M varies across MoE steps).
+ # Only M-independent tensors (amax, beta) are cached to avoid repeated allocation.
+ _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Allocating M-dependent output tensors")
+ d_row_tensor = torch.empty_strided((valid_m, n * 2, 1), (n * 2, 1, valid_m * n * 2), dtype=d_dtype, device=a_tensor.device)
+ d_col_tensor = torch.empty_strided((valid_m, n * 2, 1), (n * 2, 1, valid_m * n * 2), dtype=d_dtype, device=a_tensor.device)
+ dprob_tensor = dprob_tensor_buf.zero_() if dprob_tensor_buf is not None else torch.zeros((valid_m, 1, 1), dtype=torch.float32, device=a_tensor.device)
+
+ if valid_m == 0:
+ amax_tensor = None
+ if d_dtype in [torch.bfloat16, torch.float16]:
+ amax_tensor = torch.full((l, 2, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device)
+ _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: valid_m is zero, skipping kernel execution")
+ return TupleDict(
d_row_tensor=d_row_tensor,
d_col_tensor=d_col_tensor,
- sfa_tensor=sfa_tensor,
- sfb_tensor=sfb_tensor,
- padded_offsets=padded_offsets,
- alpha_tensor=alpha_tensor,
- beta_tensor=beta_tensor,
- prob_tensor=prob_tensor,
dprob_tensor=dprob_tensor,
- sfd_row_tensor=sfd_row_tensor,
- sfd_col_tensor=sfd_col_tensor,
amax_tensor=amax_tensor,
- norm_const_tensor=norm_const_tensor,
- current_stream=current_stream,
+ sfd_row_tensor=None,
+ sfd_col_tensor=None,
)
+
+ sfd_row_tensor = None
+ sfd_col_tensor = None
+ if a_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sfa_tensor.dtype in [torch.float8_e8m0fnu, torch.float8_e4m3fn]:
+ _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Detected fp8 a_dtype and sfa_dtype, constructing sfd_row_tensor and sfd_col_tensor")
+ sf_dtype = sfa_tensor.dtype
+ mma_permute_order = (3, 4, 1, 5, 2, 0)
+ sf_k_row = ceil_div(n * 2, sf_vec_size)
+ mma_shape_row = (1, ceil_div(valid_m, 128), ceil_div(sf_k_row, 4), 32, 4, 4)
+ sfd_row_tensor = torch.empty(mma_shape_row, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order)
+ sf_k_col = ceil_div(valid_m, sf_vec_size)
+ mma_shape_col = (1, ceil_div(n * 2, 128), ceil_div(sf_k_col, 4), 32, 4, 4)
+ sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order)
+
+ if cache_key in _cache_of_GroupedGemmDswigluSm100Objects:
+ _logger.debug("group_gemm_dswiglu_wrapper_sm100: Using previously cached GroupedGemmDswigluSm100 object")
+ grouped_gemm_dswiglu, cached_amax_tensor, cached_beta_tensor = _cache_of_GroupedGemmDswigluSm100Objects[cache_key]
+ amax_tensor = amax_tensor_buf if amax_tensor_buf is not None else cached_amax_tensor
+ if beta_tensor is not None:
+ effective_beta = beta_tensor
+ elif cached_beta_tensor is not None:
+ effective_beta = cached_beta_tensor
+ else:
+ # Fallback: cache was populated without beta caching (non-NVFP4 path),
+ # but caller now passes None (NVFP4 path). Create ones tensor on-the-fly.
+ effective_beta = torch.ones(l, dtype=torch.float32, device=a_tensor.device)
else:
_logger.debug(
"group_gemm_dswiglu_wrapper_sm100: No previously cached GroupedGemmDswigluSm100 object found, creating new GroupedGemmDswigluSm100 object"
)
+
+ # For NVFP4 (beta_tensor=None): create and cache a ones tensor — avoids FillFunctor on every step.
+ # For non-NVFP4 (beta_tensor provided): use caller's value directly; don't cache (it changes each step).
+ cached_beta_tensor = torch.ones(l, dtype=torch.float32, device=a_tensor.device) if beta_tensor is None else None
+ effective_beta = cached_beta_tensor if beta_tensor is None else beta_tensor
+
+ cached_amax_tensor = None
+ amax_tensor = None
+ if d_dtype in [torch.bfloat16, torch.float16]:
+ _logger.debug("grouped_gemm_dswiglu_wrapper_sm100: Detected bf16/float16 d_dtype, constructing amax_tensor")
+ cached_amax_tensor = torch.empty((l, 2, 1), dtype=torch.float32, device=a_tensor.device)
+ amax_tensor = amax_tensor_buf if amax_tensor_buf is not None else cached_amax_tensor
+
grouped_gemm_dswiglu = GroupedGemmDswigluSm100(
sample_a=a_tensor,
sample_b=b_tensor,
@@ -901,10 +872,10 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
sample_sfb=sfb_tensor,
sample_padded_offsets=padded_offsets,
sample_alpha=alpha_tensor,
- sample_beta=beta_tensor,
+ sample_beta=effective_beta,
sample_prob=prob_tensor,
sample_dprob=dprob_tensor,
- sample_amax=amax_tensor,
+ sample_amax=cached_amax_tensor,
sample_sfd_row=sfd_row_tensor,
sample_sfd_col=sfd_col_tensor,
sample_norm_const=norm_const_tensor,
@@ -920,26 +891,27 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
assert grouped_gemm_dswiglu.check_support(), "Unsupported configuration"
grouped_gemm_dswiglu.compile()
- grouped_gemm_dswiglu.execute(
- a_tensor=a_tensor,
- b_tensor=b_tensor,
- c_tensor=c_tensor,
- d_row_tensor=d_row_tensor,
- d_col_tensor=d_col_tensor,
- sfa_tensor=sfa_tensor,
- sfb_tensor=sfb_tensor,
- padded_offsets=padded_offsets,
- alpha_tensor=alpha_tensor,
- beta_tensor=beta_tensor,
- prob_tensor=prob_tensor,
- dprob_tensor=dprob_tensor,
- sfd_row_tensor=sfd_row_tensor,
- sfd_col_tensor=sfd_col_tensor,
- amax_tensor=amax_tensor,
- norm_const_tensor=norm_const_tensor,
- current_stream=current_stream,
- )
- _cache_of_GroupedGemmDswigluSm100Objects[cache_key] = grouped_gemm_dswiglu
+ _cache_of_GroupedGemmDswigluSm100Objects[cache_key] = (grouped_gemm_dswiglu, cached_amax_tensor, cached_beta_tensor)
+
+ grouped_gemm_dswiglu.execute(
+ a_tensor=a_tensor,
+ b_tensor=b_tensor,
+ c_tensor=c_tensor,
+ d_row_tensor=d_row_tensor,
+ d_col_tensor=d_col_tensor,
+ sfa_tensor=sfa_tensor,
+ sfb_tensor=sfb_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ beta_tensor=effective_beta,
+ prob_tensor=prob_tensor,
+ dprob_tensor=dprob_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ current_stream=current_stream,
+ )
return TupleDict(
d_row_tensor=d_row_tensor,
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_glu/api.py
index b60eac72..cecaee3a 100644
--- a/python/cudnn/grouped_gemm/grouped_gemm_glu/api.py
+++ b/python/cudnn/grouped_gemm/grouped_gemm_glu/api.py
@@ -165,7 +165,7 @@ def __init__(
"""
super().__init__()
- self._logger.warning("GroupedGemmGluSm100 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
# ---- Weight mode auto-detection ----
@@ -239,7 +239,7 @@ def __init__(
self._kernel = BlockScaledMoEGroupedGemmGluBiasKernel
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
- print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
self._workspace = None
@@ -648,7 +648,7 @@ def compile(self) -> None:
def _compile_dense(self, gemm_glu, max_active_clusters, fake_stream) -> None:
"""Compile for dense (contiguous) weight mode."""
- use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
+ use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
fake_workspace_ptr = cute.runtime.nullptr(
dtype=cutlass.Uint8,
@@ -1385,7 +1385,7 @@ def dynamic_m_tensor_signature(
stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride()))
return static_shape_suffix, stride_signature, tensor.dtype
- use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
+ use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
if is_dense:
cache_key = (
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py b/python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py
index 7c2f2647..72b728f4 100644
--- a/python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py
+++ b/python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py
@@ -2313,6 +2313,19 @@ def kernel(
expert_idx = tile_info[0]
gBias_tile = gBias_nl[(None, mma_n_coord, expert_idx)]
+
+            # For dynamic MNKL, cuteDSL drops the 128-bit alignment requirement,
+            # but we know the alignment is always 16 bytes at runtime.
+ gBias_tile = cute.make_tensor(
+ cute.make_ptr(
+ gBias_tile.element_type,
+ gBias_tile.iterator.toint(),
+ AddressSpace.gmem,
+ assumed_align=16,
+ ),
+ gBias_tile.layout,
+ )
+
tBs_gBias = thr_bias_g2s.partition_S(gBias_tile)
# Predicate: check if this thread's chunk is within N
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/__init__.py b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/__init__.py
new file mode 100644
index 00000000..2ef06aa0
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from .api import (
+ GroupedGemmGluHadamardSm100,
+ grouped_gemm_glu_hadamard_wrapper_sm100,
+)
+
+__all__ = [
+ "GroupedGemmGluHadamardSm100",
+ "grouped_gemm_glu_hadamard_wrapper_sm100",
+]
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/api.py b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/api.py
new file mode 100644
index 00000000..73539492
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/api.py
@@ -0,0 +1,540 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""FE API for grouped GEMM GLU + Hadamard forward fusion."""
+
+import logging
+import os
+from typing import Optional, Tuple
+
+from cuda.bindings import driver as cuda
+import cutlass
+import cutlass.cute as cute
+import torch
+from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
+from cutlass.cute.runtime import from_dlpack, make_fake_stream
+
+from cudnn.api_base import APIBase, TupleDict, ceil_div, is_power_of_2
+from cudnn.datatypes import _convert_to_cutlass_data_type
+
+from ..moe_utils import MoEWeightMode
+from .hadamard_utils import HADAMARD_SIZE, hadamard_matrix
+from .moe_blockscaled_grouped_gemm_glu_hadamard import BlockScaledMoEGroupedGemmGluHadamardKernel
+
+
+def _reinterpret_raw_grouped_fp4_tensor(tensor: torch.Tensor) -> torch.Tensor:
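+    """Reinterpret a uint8-packed tensor as a cute FP4 (Float4E2M1FN) tensor; other dtypes pass through unchanged."""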
+ if tensor.dtype == torch.uint8:
+ cute_tensor = from_dlpack(tensor, assumed_align=16, enable_tvm_ffi=True).mark_layout_dynamic(leading_dim=1)
+ cute_tensor.element_type = cutlass.Float4E2M1FN
+ return cute_tensor
+ return tensor
+
+
+class GroupedGemmGluHadamardSm100(APIBase):
+ """Dense grouped GEMM GLU forward kernel with fused Hadamard transform."""
+
+ def __init__(
+ self,
+ sample_a: torch.Tensor,
+ sample_b: torch.Tensor,
+ sample_c: torch.Tensor,
+ sample_d: torch.Tensor,
+ sample_sfa: torch.Tensor,
+ sample_sfb: torch.Tensor,
+ sample_padded_offsets: torch.Tensor,
+ sample_alpha: torch.Tensor,
+ sample_prob: torch.Tensor,
+ sample_amax: Optional[torch.Tensor] = None,
+ sample_bias: Optional[torch.Tensor] = None,
+ sample_hadamard: Optional[torch.Tensor] = None,
+ acc_dtype: torch.dtype = torch.float32,
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ m_aligned: int = 256,
+ act_func: str = "swiglu",
+ use_dynamic_sched: bool = False,
+ ):
+ super().__init__()
+
+ self._warn_experimental_api()
+ self._interpret_uint8_as_fp4x2 = True
+ self._sample_a_tensor = sample_a
+ self._sample_b_tensor = sample_b
+
+ self.a_desc = self._make_tensor_desc(sample_a, name="sample_a", interpret_uint8_as_fp4x2=False)
+ self.b_desc = self._make_tensor_desc(sample_b, name="sample_b", interpret_uint8_as_fp4x2=False)
+ self.c_desc = self._make_tensor_desc(sample_c, name="sample_c")
+ self.d_desc = self._make_tensor_desc(sample_d, name="sample_d")
+ self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa")
+ self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb")
+ self.padded_offsets_desc = self._make_tensor_desc(sample_padded_offsets, name="sample_padded_offsets")
+ self.alpha_desc = self._make_tensor_desc(sample_alpha, name="sample_alpha")
+ self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob")
+ self.bias_desc = self._make_tensor_desc(sample_bias, name="sample_bias")
+ self.expert_cnt = self.padded_offsets_desc.shape[0]
+ if sample_amax is None:
+ sample_amax = torch.empty((self.expert_cnt, 1), dtype=torch.float32, device=sample_a.device)
+ self.amax_desc = self._make_tensor_desc(sample_amax, name="sample_amax")
+ if sample_hadamard is None:
+ self.hadamard_tensor = self._make_hadamard_tensor(sample_a.device)
+ else:
+ self.hadamard_tensor = self._normalize_hadamard_tensor(
+ sample_hadamard,
+ device=sample_a.device,
+ name="sample_hadamard",
+ )
+
+ self.acc_dtype = acc_dtype
+ self.mma_tiler_mn = mma_tiler_mn
+ self.use_2cta_instrs = mma_tiler_mn[0] == 256
+ self.cluster_shape_mn = cluster_shape_mn if cluster_shape_mn is not None else ((2, 1) if self.use_2cta_instrs else (1, 1))
+ self.sf_vec_size = sf_vec_size
+ self.vector_f32 = vector_f32
+ self.m_aligned = m_aligned
+ self.act_func = act_func
+ self.use_dynamic_sched = use_dynamic_sched
+ self.weight_mode = MoEWeightMode.DENSE
+ self._kernel = BlockScaledMoEGroupedGemmGluHadamardKernel
+ self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
+ self._workspace = None
+
+ @staticmethod
+ def _make_hadamard_tensor(device: torch.device) -> torch.Tensor:
+ return hadamard_matrix(HADAMARD_SIZE, dtype=torch.bfloat16, device=device).t().contiguous()
+
+ @classmethod
+ def _normalize_hadamard_tensor(
+ cls,
+ hadamard_tensor: torch.Tensor,
+ *,
+ device: torch.device,
+ name: str,
+ ) -> torch.Tensor:
+ expected_shape = (HADAMARD_SIZE, HADAMARD_SIZE)
+ if tuple(hadamard_tensor.shape) != expected_shape:
+ raise ValueError(f"{name} tensor shape mismatch: expected {expected_shape}, got {tuple(hadamard_tensor.shape)}")
+ if hadamard_tensor.dtype != torch.bfloat16 or hadamard_tensor.device != device:
+ hadamard_tensor = hadamard_tensor.to(device=device, dtype=torch.bfloat16)
+ if not hadamard_tensor.is_contiguous():
+ hadamard_tensor = hadamard_tensor.contiguous()
+ return hadamard_tensor
+
+ def check_support(self) -> bool:
+ tensor_m, k, _ = self._tensor_shape(self.a_desc, name="sample_a")
+ n, _, l = self._tensor_shape(self.b_desc, name="sample_b")
+ _, n_c, _ = self._tensor_shape(self.c_desc, name="sample_c")
+ _, n_d, _ = self._tensor_shape(self.d_desc, name="sample_d")
+
+ self._value_error_if(l != self.expert_cnt, f"B L dimension ({l}) must match expert_cnt ({self.expert_cnt})")
+ self._value_error_if(n % 64 != 0, f"N must be divisible by 64, got {n}")
+ self._value_error_if((n // 2) % HADAMARD_SIZE != 0, f"N/2 must be divisible by {HADAMARD_SIZE}, got {n // 2}")
+
+ self._check_tensor_shape(self.a_desc, (tensor_m, k, 1), "A")
+ self._check_tensor_shape(self.b_desc, (n, k, l), "B")
+ self._check_tensor_shape(self.c_desc, (tensor_m, n, 1), "C")
+ self._check_tensor_shape(self.d_desc, (tensor_m, n // 2, 1), "D")
+ self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(tensor_m, 128), 4, ceil_div(ceil_div(k, self.sf_vec_size), 4), 1), "SFA")
+ self._check_tensor_shape(self.sfb_desc, (32, 4, ceil_div(n, 128), 4, ceil_div(ceil_div(k, self.sf_vec_size), 4), l), "SFB")
+ self._check_tensor_shape(self.padded_offsets_desc, (l,), "padded_offsets")
+ self._check_tensor_shape(self.alpha_desc, (l,), "alpha")
+ self._check_tensor_shape(self.prob_desc, (tensor_m, 1, 1), "prob")
+ self._check_tensor_shape(self.bias_desc, (n, l), "bias")
+ self._check_tensor_shape(self.amax_desc, (l, 1), "amax")
+ self._check_tensor_shape(self.hadamard_tensor, (HADAMARD_SIZE, HADAMARD_SIZE), "hadamard")
+
+ self._check_tensor_stride(self.a_desc, stride=[(k, 1, tensor_m * k)], name="A", extra_error_msg="A must have k-major layout")
+ self._check_tensor_stride(self.b_desc, stride=[(k, 1, n * k)], name="B", extra_error_msg="B must have k-major layout")
+ self._check_tensor_stride(self.c_desc, stride=[(n_c, 1, tensor_m * n_c)], name="C", extra_error_msg="C must have n-major layout")
+ self._check_tensor_stride(self.d_desc, stride=[(n_d, 1, tensor_m * n_d)], name="D", extra_error_msg="D must have n-major layout")
+ self._check_tensor_stride(self.bias_desc, stride=[(1, n)], name="bias")
+
+ self.ab_dtype = self._check_dtype(
+ self.a_desc,
+ dtype=[torch.float4_e2m1fn_x2, torch.uint8],
+ name="A",
+ )
+ self._check_dtype(self.b_desc, dtype=self.ab_dtype, name="B", extra_error_msg="B must match A dtype")
+ self.sf_dtype = self._check_dtype(self.sfa_desc, dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn], name="SFA")
+ self._check_dtype(self.sfb_desc, dtype=self.sf_dtype, name="SFB", extra_error_msg="SFB must match SFA dtype")
+ self.c_dtype = self._check_dtype(self.c_desc, dtype=[torch.float16, torch.bfloat16], name="C")
+ self.d_dtype = self._check_dtype(self.d_desc, dtype=[torch.float16, torch.bfloat16], name="D")
+ self._check_dtype(self.alpha_desc, dtype=torch.float32, name="alpha")
+ self._check_dtype(self.prob_desc, dtype=torch.float32, name="prob")
+ self._check_dtype(self.bias_desc, dtype=[torch.float16, torch.bfloat16, torch.float32], name="bias")
+ self._check_dtype(self.amax_desc, dtype=torch.float32, name="amax")
+ self._check_dtype(self.hadamard_tensor, dtype=torch.bfloat16, name="hadamard")
+ self._check_dtype(self.acc_dtype, dtype=torch.float32, name="acc_dtype")
+
+ self._value_error_if(self.sf_vec_size not in [16, 32], f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}")
+ self._value_error_if(self.act_func not in ["swiglu", "geglu"], f"act_func must be 'swiglu' or 'geglu', got {self.act_func}")
+ self._value_error_if(
+ not self.use_2cta_instrs or self.mma_tiler_mn != (256, 256), f"Hadamard fusion requires mma_tiler_mn=(256, 256), got {self.mma_tiler_mn}"
+ )
+ self._value_error_if(self.cluster_shape_mn[0] % 2 != 0, f"cluster_shape_mn[0] must be divisible by 2, got {self.cluster_shape_mn[0]}")
+ self._value_error_if(
+ not (
+ self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16
+ and self.cluster_shape_mn[0] > 0
+ and self.cluster_shape_mn[1] > 0
+ and self.cluster_shape_mn[0] <= 4
+ and self.cluster_shape_mn[1] <= 4
+ and is_power_of_2(self.cluster_shape_mn[0])
+ and is_power_of_2(self.cluster_shape_mn[1])
+ ),
+ f"Invalid cluster shape: {self.cluster_shape_mn}",
+ )
+ self._value_error_if(
+ self.m_aligned != BlockScaledMoEGroupedGemmGluHadamardKernel.FIX_PAD_SIZE,
+ f"m_aligned must be {BlockScaledMoEGroupedGemmGluHadamardKernel.FIX_PAD_SIZE}, got {self.m_aligned}",
+ )
+ self._value_error_if(self.expert_cnt > 1024, f"expert_cnt must be <= 1024, got {self.expert_cnt}")
+
+ if not torch.cuda.is_available():
+ raise RuntimeError("CUDA is not available")
+ major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+ compute_capability = major * 10 + minor
+ if compute_capability < 100:
+ raise RuntimeError(f"GroupedGemmGluHadamardSm100 requires SM100+, found SM{compute_capability}")
+
+ if not self._kernel.can_implement(
+ _convert_to_cutlass_data_type(self.ab_dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2),
+ _convert_to_cutlass_data_type(self.sf_dtype),
+ self.sf_vec_size,
+ _convert_to_cutlass_data_type(self.acc_dtype),
+ _convert_to_cutlass_data_type(self.d_dtype),
+ self.use_2cta_instrs,
+ self.mma_tiler_mn,
+ self.cluster_shape_mn,
+ self.m_aligned,
+ n,
+ k,
+ l,
+ "k",
+ "k",
+ "n",
+ self.m_aligned,
+ ):
+ raise RuntimeError("Unsupported grouped GEMM GLU hadamard configuration")
+
+ self._is_supported = True
+ return True
+
+ def compile(self) -> None:
+ self._ensure_support_checked()
+ if self._compiled_kernel is not None:
+ return
+ if self.a_desc.shape[0] == 0:
+ return
+
+ kernel = self._kernel(
+ sf_vec_size=self.sf_vec_size,
+ acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype),
+ use_2cta_instrs=self.use_2cta_instrs,
+ mma_tiler_mn=self.mma_tiler_mn,
+ cluster_shape_mn=self.cluster_shape_mn,
+ vectorized_f32=self.vector_f32,
+ expert_cnt=self.expert_cnt,
+ weight_mode=MoEWeightMode.DENSE,
+ use_dynamic_sched=self.use_dynamic_sched,
+ act_func=self.act_func,
+ enable_bias=self.bias_desc is not None,
+ )
+
+ hardware_info = cutlass.utils.HardwareInfo()
+ max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1])
+ max_active_clusters -= self.num_cluster_overlap_margin
+ self._value_error_if(max_active_clusters <= 0, "max_active_clusters must be > 0 after overlap margin")
+ self._workspace = torch.empty(max(kernel.get_workspace_bytes(), 1), dtype=torch.uint8, device="cuda")
+ fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False)
+ fake_workspace_ptr = cute.runtime.nullptr(dtype=cutlass.Uint8, assumed_align=128)
+
+ valid_m = cute.sym_int(divisibility=self.m_aligned)
+ tensor_m_128 = cute.sym_int()
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+
+ a_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, self.a_desc.shape[1], 1),
+ stride_order=self.a_desc.stride_order,
+ dynamic_mode=self.a_desc.stride_order[0],
+ divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
+ )
+ c_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, self.c_desc.shape[1], 1),
+ stride_order=self.c_desc.stride_order,
+ dynamic_mode=self.c_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.c_desc) else 16,
+ )
+ d_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_desc.dtype,
+ shape=(valid_m, self.d_desc.shape[1], 1),
+ stride_order=self.d_desc.stride_order,
+ dynamic_mode=self.d_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.d_desc) else 16,
+ )
+ prob_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, 1, 1),
+ stride=self.prob_desc.stride,
+ )
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfa_desc.shape[4], 1),
+ stride=(16, 4, self.sfa_desc.stride[2], 1, 512, stride_tensor_m_128),
+ )
+ b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16)
+ sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16)
+ alpha_cute_fake = self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16)
+ padded_offsets_cute_fake = self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16)
+ amax_cute_fake = self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16)
+ bias_cute_fake = self._make_fake_cute_tensor_from_desc(self.bias_desc, assumed_align=16)
+ hadamard_cute_fake = self._make_fake_cute_tensor_like(self.hadamard_tensor, assumed_align=16, name="sample_hadamard")
+ cached_linear_offset = cutlass.Float32(1.0 if self.act_func == "geglu" else 0.0)
+
+ compiled_kernel = cute.compile(
+ kernel,
+ _reinterpret_raw_grouped_fp4_tensor(self._sample_a_tensor) if self.a_desc.dtype == torch.uint8 else a_cute_fake,
+ _reinterpret_raw_grouped_fp4_tensor(self._sample_b_tensor) if self.b_desc.dtype == torch.uint8 else b_cute_fake,
+ sfa_cute_fake,
+ sfb_cute_fake,
+ cutlass.Int32(0),
+ cutlass.Int32(0),
+ cutlass.Int64(0),
+ OperandMajorMode.K,
+ fake_workspace_ptr,
+ c_cute_fake,
+ d_cute_fake,
+ amax_cute_fake,
+ padded_offsets_cute_fake,
+ alpha_cute_fake,
+ prob_cute_fake,
+ hadamard_cute_fake,
+ bias_cute_fake,
+ max_active_clusters,
+ fake_stream,
+ cached_linear_offset,
+ options="--enable-tvm-ffi",
+ )
+
+ cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator
+
+ def tensor_api(
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ padded_offsets: torch.Tensor,
+ alpha_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ hadamard_tensor: torch.Tensor,
+ amax_tensor: Optional[torch.Tensor],
+ bias_tensor: Optional[torch.Tensor],
+ stream: cuda.CUstream,
+ ) -> None:
+ compiled_kernel(
+ a_tensor,
+ b_tensor,
+ sfa_tensor,
+ sfb_tensor,
+ cutlass.Int32(0),
+ cutlass.Int32(0),
+ cutlass.Int64(0),
+ cached_workspace_ptr,
+ c_tensor,
+ d_tensor,
+ amax_tensor,
+ padded_offsets,
+ alpha_tensor,
+ prob_tensor,
+ hadamard_tensor,
+ bias_tensor,
+ stream,
+ cached_linear_offset,
+ )
+
+ self._compiled_kernel = tensor_api
+
+ def execute(
+ self,
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ padded_offsets: torch.Tensor,
+ alpha_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ hadamard_tensor: Optional[torch.Tensor] = None,
+ amax_tensor: Optional[torch.Tensor] = None,
+ bias_tensor: Optional[torch.Tensor] = None,
+ current_stream: Optional[cuda.CUstream] = None,
+ ) -> None:
+ self._ensure_support_checked()
+ if self._compiled_kernel is None:
+ raise RuntimeError("Kernel has not been compiled")
+ if a_tensor.shape[0] == 0:
+ return
+ if current_stream is None:
+ current_stream = cuda.CUstream(torch.cuda.current_stream(a_tensor.device).cuda_stream)
+ if hadamard_tensor is None:
+ hadamard_tensor = self.hadamard_tensor
+ else:
+ hadamard_tensor = self._normalize_hadamard_tensor(
+ hadamard_tensor,
+ device=a_tensor.device,
+ name="hadamard",
+ )
+
+ self._compiled_kernel(
+ _reinterpret_raw_grouped_fp4_tensor(a_tensor),
+ _reinterpret_raw_grouped_fp4_tensor(b_tensor),
+ c_tensor,
+ d_tensor,
+ sfa_tensor,
+ sfb_tensor,
+ padded_offsets,
+ alpha_tensor,
+ prob_tensor,
+ hadamard_tensor,
+ amax_tensor,
+ bias_tensor,
+ current_stream,
+ )
+
+
+_logger = logging.getLogger(__name__)
+_cache_of_GroupedGemmGluHadamardSm100Objects = {}
+
+
+def grouped_gemm_glu_hadamard_wrapper_sm100(
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ padded_offsets: torch.Tensor,
+ alpha_tensor: torch.Tensor,
+ prob_tensor: torch.Tensor,
+ bias_tensor: Optional[torch.Tensor] = None,
+ acc_dtype: torch.dtype = torch.float32,
+ c_dtype: torch.dtype = torch.bfloat16,
+ d_dtype: torch.dtype = torch.bfloat16,
+ cd_major: str = "n",
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ m_aligned: int = 256,
+ act_func: str = "swiglu",
+ use_dynamic_sched: bool = False,
+ current_stream: Optional[cuda.CUstream] = None,
+) -> TupleDict:
+ """High-level wrapper for grouped GEMM GLU + Hadamard forward fusion."""
+
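+    # Illustrative call (a sketch; shapes follow the checks in GroupedGemmGluHadamardSm100.check_support):
+    #   a_tensor : (valid_m, k, 1) packed-FP4 (uint8 or float4_e2m1fn_x2) activations, k-major
+    #   b_tensor : (n, k, l) per-expert weights, k-major, where l is the number of experts
+    #   out = grouped_gemm_glu_hadamard_wrapper_sm100(a_tensor, b_tensor, sfa_tensor, sfb_tensor,
+    #                                                 padded_offsets, alpha_tensor, prob_tensor)
+    #   out.d_tensor : (valid_m, n // 2, 1) GLU + Hadamard output; out.amax_tensor : (l, 1) per-expert amax
+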
+ valid_m = a_tensor.shape[0]
+ n_full, _, l = b_tensor.shape
+ n_out = n_full // 2
+
+ if cd_major != "n":
+ raise ValueError(f"cd_major must be 'n', got {cd_major}")
+
+ c_tensor = torch.empty_strided((valid_m, n_full, 1), (n_full, 1, valid_m * n_full), dtype=c_dtype, device=a_tensor.device)
+ d_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device)
+ amax_tensor = None
+ if d_dtype in [torch.bfloat16, torch.float16]:
+ amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device)
+
+ if valid_m == 0:
+ return TupleDict(c_tensor=c_tensor, d_tensor=d_tensor, amax_tensor=amax_tensor)
+
+ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
+ return tuple(i for i, _ in sorted(enumerate(tensor.stride()), key=lambda item: item[1]))
+
+ def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype
+
+ def dynamic_m_tensor_signature(
+ tensor: Optional[torch.Tensor], static_shape_suffix: Optional[Tuple[int, ...]], dynamic_stride_dims: Tuple[int, ...] = ()
+ ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ stride_signature = tuple(None if idx in dynamic_stride_dims else value for idx, value in enumerate(tensor.stride()))
+ return static_shape_suffix, stride_signature, tensor.dtype
+
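+    # Note: the cache key intentionally omits the dynamic M extent (a_tensor.shape[0] and the
+    # M-dependent SFA stride), so one compiled kernel is reused across varying token counts.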
+ cache_key = (
+ act_func,
+ a_tensor.shape[1:],
+ tuple(b_tensor.shape),
+ c_tensor.shape[1:],
+ a_tensor.dtype,
+ b_tensor.dtype,
+ c_tensor.dtype,
+ d_tensor.dtype,
+ stride_order(a_tensor),
+ stride_order(b_tensor),
+ stride_order(c_tensor),
+ *dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1), dynamic_stride_dims=(5,)),
+ *tensor_signature(sfb_tensor),
+ *tensor_signature(alpha_tensor),
+ *dynamic_m_tensor_signature(prob_tensor, (1, 1)),
+ *tensor_signature(bias_tensor),
+ *tensor_signature(padded_offsets),
+ acc_dtype,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ sf_vec_size,
+ vector_f32,
+ m_aligned,
+ use_dynamic_sched,
+ )
+
+ if cache_key in _cache_of_GroupedGemmGluHadamardSm100Objects:
+ api = _cache_of_GroupedGemmGluHadamardSm100Objects[cache_key]
+ else:
+ api = GroupedGemmGluHadamardSm100(
+ sample_a=a_tensor,
+ sample_b=b_tensor,
+ sample_c=c_tensor,
+ sample_d=d_tensor,
+ sample_sfa=sfa_tensor,
+ sample_sfb=sfb_tensor,
+ sample_padded_offsets=padded_offsets,
+ sample_alpha=alpha_tensor,
+ sample_prob=prob_tensor,
+ sample_amax=amax_tensor,
+ sample_bias=bias_tensor,
+ acc_dtype=acc_dtype,
+ mma_tiler_mn=mma_tiler_mn,
+ cluster_shape_mn=cluster_shape_mn,
+ sf_vec_size=sf_vec_size,
+ vector_f32=vector_f32,
+ m_aligned=m_aligned,
+ act_func=act_func,
+ use_dynamic_sched=use_dynamic_sched,
+ )
+ api.check_support()
+ api.compile()
+ _cache_of_GroupedGemmGluHadamardSm100Objects[cache_key] = api
+
+ api.execute(
+ a_tensor=a_tensor,
+ b_tensor=b_tensor,
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ sfa_tensor=sfa_tensor,
+ sfb_tensor=sfb_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ prob_tensor=prob_tensor,
+ amax_tensor=amax_tensor,
+ bias_tensor=bias_tensor,
+ current_stream=current_stream,
+ )
+ return TupleDict(c_tensor=c_tensor, d_tensor=d_tensor, amax_tensor=amax_tensor)
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/hadamard_utils.py b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/hadamard_utils.py
new file mode 100644
index 00000000..205ef21d
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/hadamard_utils.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""Local Hadamard helpers for the grouped GEMM GLU hadamard kernel."""
+
+import cutlass
+import cutlass.cute as cute
+import cutlass.pipeline as pipeline
+import cutlass.utils.blackwell_helpers as sm100_utils
+from cutlass.cute.nvgpu import tcgen05
+from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode, OperandSource
+import torch
+
+HADAMARD_SIZE = 16
+TMEM_ROW_STRIDE = 1 << 16
+M_PER_CLUSTER = 256
+
+
+@cute.jit
+def hadamard_setup(g_hadamard, s_hadamard, tidx):
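+    """Stage the 16x16 Hadamard matrix into shared memory (with an index swizzle on the upper half)
+    and build the BF16 (M_PER_CLUSTER x HADAMARD_SIZE) tiled MMA used for the transform."""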
+ tiled_hmma = sm100_utils.make_trivial_tiled_mma(
+ cutlass.BFloat16,
+ OperandMajorMode.K,
+ OperandMajorMode.K,
+ cutlass.Float32,
+ tcgen05.CtaGroup.TWO,
+ (M_PER_CLUSTER, HADAMARD_SIZE),
+ OperandSource.TMEM,
+ )
+ s_hadamard[tidx] = g_hadamard[tidx if tidx < 64 else (tidx ^ 8)]
+ return tiled_hmma
+
+
+@cute.jit
+def hadamard_compute(tiled_hmma, tmem_a_ptr, tmem_acc_ptr, s_hadamard, epi_tile, tidx, pipeline_producer):
+ n = epi_tile[1]
+ mma_tiler = (M_PER_CLUSTER, HADAMARD_SIZE, HADAMARD_SIZE)
+ cta_rank = cute.arch.block_in_cluster_idx()[0]
+ thr = tiled_hmma.get_slice(cta_rank)
+ b_layout = sm100_utils.make_smem_layout_b(tiled_hmma, mma_tiler, cutlass.BFloat16, 1)
+ s_bh = s_hadamard.get_tensor(b_layout.outer, swizzle=b_layout.inner)
+ t_bs_b = tiled_hmma.make_fragment_B(s_bh)
+
+ t_a_subtile = cute.make_tensor(0, cute.make_layout((M_PER_CLUSTER, n), stride=(n, 1)))
+ t_ht_a_frg = thr.partition_A(t_a_subtile)
+ t_ht_a = cute.make_tensor(
+ cute.recast_ptr(tmem_a_ptr, dtype=cutlass.BFloat16),
+ tiled_hmma.make_fragment_A(t_ht_a_frg.layout).layout,
+ )
+
+ t_ht_c_frg = thr.partition_C(t_a_subtile)
+ t_ht_c = cute.make_tensor(
+ tmem_acc_ptr,
+ tiled_hmma.make_fragment_C(t_ht_c_frg.layout).layout,
+ )
+
+ if cta_rank == 0 and tidx < 32:
+ hadamard_empty = pipeline_producer.acquire_and_advance()
+ for i in cutlass.range_constexpr(cute.size(t_ht_c.shape, mode=[2]), unroll_full=True):
+ cute.gemm(
+ tiled_hmma,
+ cute.append_ones(t_ht_c[None, None, i], up_to_rank=3),
+ cute.append_ones(t_ht_a[(None, None, i)], up_to_rank=3),
+ cute.append_ones(t_bs_b[(None, None, 0, 0)], up_to_rank=3),
+ cute.append_ones(t_ht_c[None, None, i], up_to_rank=3),
+ )
+ hadamard_empty.commit()
+
+
+@cute.jit
+def hadamard_in(rmem_src: cute.Tensor, cols: cutlass.Constexpr, tmem_ptr, tidx):
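+    """Store a per-thread register fragment into a (128, cols) fp32 TMEM tile via tcgen05 St32x32b copies."""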
+ tmem_ptr = cute.make_ptr(
+ cutlass.Float32,
+ tmem_ptr.toint(),
+ cute.AddressSpace.tmem,
+ assumed_align=8,
+ )
+ tmem_tensor = cute.make_tensor(
+ tmem_ptr,
+ cute.make_layout((128, cols), stride=(TMEM_ROW_STRIDE, 1)),
+ )
+
+ if cutlass.const_expr(cols == 32):
+ st_atom = cute.make_copy_atom(tcgen05.St32x32bOp(tcgen05.Repetition.x32), cutlass.Float32)
+ else:
+ st_atom = cute.make_copy_atom(tcgen05.St32x32bOp(tcgen05.Repetition.x16), cutlass.Float32)
+ tiled_st = tcgen05.make_tmem_copy(st_atom, tmem_tensor)
+ thr_st = tiled_st.get_slice(tidx)
+ t_d_st = thr_st.partition_D(tmem_tensor)
+ cute.copy(thr_st, cute.recast_tensor(rmem_src, cutlass.Float32), t_d_st)
+
+
+@cute.jit
+def hadamard_out(rmem_dst: cute.Tensor, cols: cutlass.Constexpr, tmem_ptr, tidx):
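+    """Load a (128, cols) fp32 TMEM tile back into a per-thread register fragment via tcgen05 Ld32x32b copies."""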
+ tmem_ptr = cute.make_ptr(
+ cutlass.Float32,
+ tmem_ptr.toint(),
+ cute.AddressSpace.tmem,
+ assumed_align=8,
+ )
+ tmem_tensor = cute.make_tensor(
+ tmem_ptr,
+ cute.make_layout((128, cols), stride=(TMEM_ROW_STRIDE, 1)),
+ )
+
+ if cutlass.const_expr(cols == 32):
+ ld_atom = cute.make_copy_atom(tcgen05.Ld32x32bOp(tcgen05.Repetition.x32), cutlass.Float32)
+ else:
+ ld_atom = cute.make_copy_atom(tcgen05.Ld32x32bOp(tcgen05.Repetition.x16), cutlass.Float32)
+ tiled_ld = tcgen05.make_tmem_copy(ld_atom, tmem_tensor)
+ thr_ld = tiled_ld.get_slice(tidx)
+ t_d_ld_ = thr_ld.partition_D(tmem_tensor)
+ t_d_ld = cute.make_tensor(
+ cute.make_ptr(
+ cutlass.Float32,
+ t_d_ld_.iterator.toint(),
+ cute.AddressSpace.tmem,
+ assumed_align=8,
+ ),
+ t_d_ld_.layout,
+ )
+ cute.copy(thr_ld, t_d_ld, cute.recast_tensor(rmem_dst, cutlass.Float32))
+
+
+def hadamard_matrix(n, dtype=None, device=None):
+ if dtype is None:
+ dtype = torch.float32
+ if n < 1:
+ raise ValueError("n must be a positive integer")
+ if n & (n - 1):
+ raise ValueError("n must be a power of 2")
+
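+    # Sylvester-style construction: repeatedly Kronecker-multiply by the 2x2 base [[1, 1], [1, -1]].
+    # e.g. hadamard_matrix(2) == [[1, 1], [1, -1]]; each doubling replaces every entry h with h * base.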
+ kwargs = {"dtype": dtype}
+ if device is not None:
+ kwargs["device"] = device
+ matrix = torch.tensor([[1]], **kwargs)
+ base = torch.tensor([[1, 1], [1, -1]], **kwargs)
+ while matrix.shape[0] < n:
+ matrix = torch.kron(matrix, base)
+ return matrix
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/moe_blockscaled_grouped_gemm_glu_hadamard.py b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/moe_blockscaled_grouped_gemm_glu_hadamard.py
new file mode 100644
index 00000000..2cfc81e1
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_glu_hadamard/moe_blockscaled_grouped_gemm_glu_hadamard.py
@@ -0,0 +1,2513 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""
+MoE Block-Scaled Grouped GEMM Kernel with GLU (SwiGLU/GeGLU) + Hadamard Transform Fusion.
+
+Supports:
+ - Static / Dynamic persistent tile scheduling (MoEPersistentTileScheduler)
+ - Dense (contiguous 3-D B) / Discrete (per-expert pointer array B) weight layout
+ - BF16/F16 D output with Hadamard transform (pingpong epilogue)
+ - Optional C output (pre-activation GLU output)
+ - AMAX reduction for calibration
+ - GLU activation fusion (SwiGLU / GeGLU)
+
+Warp assignment (8 epilogue warps, pingpong):
+ warps 0-3 : ACT warps — TMEM→reg, alpha scale, GLU activation, C store, hadamard_in
+ warps 4-7 : RHT store warps — hadamard_compute, D store
+ warp 8 : MMA warp
+ warp 9 : TMA load warp
+ warp 10 : Scheduler warp (MoEPersistentTileScheduler)
+ warp 11 : Bias load warp (optional)
+
+sInfo format: (expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt)
+ Validity: tile_info[0] >= 0 (expert_idx == -1 signals end)
+"""
+
+from typing import Type, Tuple, Union, Optional
+
+import cuda.bindings.driver as cuda
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
+import cutlass.utils as utils
+import cutlass.pipeline as pipeline
+import cutlass.utils.blackwell_helpers as sm100_utils
+import cutlass.utils.blockscaled_layout as blockscaled_utils
+from cutlass._mlir.dialects.nvvm import ReduxKind
+from cutlass.cute.typing import Float32, Int32, AddressSpace
+from ..moe_persistent_scheduler import (
+ MoEPersistentTileScheduler,
+ MoESchedulerParams,
+ MoEWorkTileInfo,
+)
+from ..moe_utils import (
+ compute_expert_token_range,
+ MoEWeightMode,
+ TensormapWorkspace,
+ store_tma_desc,
+)
+from .hadamard_utils import (
+ hadamard_setup,
+ hadamard_compute,
+ hadamard_in,
+ hadamard_out,
+ HADAMARD_SIZE,
+)
+from ..moe_sched_extension import (
+ DiscreteWeightScaledGemmSchedExtension,
+ ContiguousAndConsistentGroupedGemmSchedExtension,
+)
+from ..moe_kernel_helpers import (
+ fmin,
+ fmax,
+ warp_redux_sync,
+ atomic_max_float32,
+ silu_f32,
+ silu_f32_geglu_scaled,
+ compute_grid,
+ get_dtype_rcp_limits,
+ can_implement,
+ amax_reduction_per_thread,
+)
+
+
+class BlockScaledMoEGroupedGemmGluHadamardKernel:
+ """Block-scaled MoE grouped GEMM with GLU activation and Hadamard transform fusion.
+
+ Always uses pingpong epilogue (8 epilogue warps: 4 ACT + 4 RHT-store).
+ D output is BF16 or F16 only (no FP8/FP4, no SFD).
+
+ :param sf_vec_size: Scalefactor vector size.
+ :param mma_tiler_mn: Shape of MMA tile (M, N).
+ :param cluster_shape_mn: Cluster dimensions (M, N).
+ :param expert_cnt: Number of experts (compile-time constant).
+ :param weight_mode: Dense or Discrete weight layout.
+ :param use_dynamic_sched: Use dynamic tile scheduling.
+ :param act_func: Activation function ('swiglu' or 'geglu').
+ :param enable_bias: Enable bias addition.
+ """
+
+ FIX_PAD_SIZE = 256
+
+ @staticmethod
+ def can_implement(
+ ab_dtype: Type[cutlass.Numeric],
+ sf_dtype: Type[cutlass.Numeric],
+ sf_vec_size: int,
+ acc_dtype: Type[cutlass.Numeric],
+ d_dtype: Type[cutlass.Numeric],
+ use_2cta_instrs: bool,
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ m: int,
+ n: int,
+ k: int,
+ l: int,
+ a_major: str,
+ b_major: str,
+ cd_major: str,
+ m_aligned: int,
+ ) -> bool:
+        # Special requirements for Hadamard fusion
+ if not use_2cta_instrs or mma_tiler_mn[0] != 256 or mma_tiler_mn[1] != 256:
+ return False
+ return can_implement(
+ ab_dtype,
+ sf_dtype,
+ sf_vec_size,
+ acc_dtype,
+ d_dtype,
+ use_2cta_instrs,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ m,
+ n,
+ k,
+ l,
+ a_major,
+ b_major,
+ cd_major,
+ m_aligned,
+ fix_pad_size=BlockScaledMoEGroupedGemmGluHadamardKernel.FIX_PAD_SIZE,
+ )
+
+ def __init__(
+ self,
+ sf_vec_size: int,
+ acc_dtype: Type[cutlass.Numeric],
+ use_2cta_instrs: bool,
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ vectorized_f32: bool,
+ expert_cnt: int,
+ weight_mode: MoEWeightMode = MoEWeightMode.DISCRETE,
+ use_dynamic_sched: bool = False,
+ act_func: str = "swiglu",
+ enable_bias: bool = False,
+ ):
+ mma_tile_m = mma_tiler_mn[0]
+ if self.FIX_PAD_SIZE % mma_tile_m != 0:
+ raise ValueError(f"FIX_PAD_SIZE ({self.FIX_PAD_SIZE}) must be divisible by " f"mma_tiler_mn[0] ({mma_tile_m}).")
+ if expert_cnt > 1024:
+ raise ValueError("Expert count > 1024 is not supported.")
+ if not isinstance(weight_mode, MoEWeightMode):
+ raise TypeError(f"weight_mode must be a MoEWeightMode, got {type(weight_mode)}")
+
+ self.sf_vec_size = sf_vec_size
+ self.expert_cnt = expert_cnt
+ self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
+ self.use_2cta_instrs = use_2cta_instrs
+ self.cluster_shape_mn = cluster_shape_mn
+ self.mma_tiler = (*mma_tiler_mn, 1)
+ self.weight_mode = weight_mode
+ self.use_dynamic_sched = use_dynamic_sched
+ self.enable_bias = enable_bias
+
+ # Always use pingpong epilogue for Hadamard
+ self.epilogue_pingpong = True
+ # Always delay TMA store acquire sync for Hadamard
+ self.delay_tma_store_acquire_sync = True
+
+ self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
+
+ self.occupancy = 1
+ self.threads_per_warp = 32
+
+ # Warp assignments: 8 epilogue warps (4 ACT + 4 RHT-store)
+ self.epilog_warp_id = (0, 1, 2, 3, 4, 5, 6, 7)
+ self.epilog_act_warp_id = (0, 1, 2, 3)
+ self.epilog_rht_store_warp_id = (4, 5, 6, 7)
+ self.mma_warp_id = 8
+ self.tma_warp_id = 9
+ self.sched_warp_id = 10
+ self.bias_load_warp_id = 11 if enable_bias else None
+
+ self.epilogue_warp_group_size = len(self.epilog_act_warp_id) # = 4
+
+ all_warps = [*self.epilog_warp_id, self.mma_warp_id, self.tma_warp_id, self.sched_warp_id]
+ warps_wo_sched = [*self.epilog_warp_id, self.mma_warp_id, self.tma_warp_id]
+ if enable_bias:
+ all_warps.append(self.bias_load_warp_id)
+ warps_wo_sched.append(self.bias_load_warp_id)
+ self.threads_per_cta = self.threads_per_warp * len(all_warps)
+ self.threads_wo_sched = self.threads_per_warp * len(warps_wo_sched)
+
+ # Named barriers
+ self.cta_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=1,
+ num_threads=self.threads_per_cta,
+ )
+ self.epilog_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=2,
+ num_threads=32 * len(self.epilog_warp_id),
+ )
+ self.tmem_alloc_barrier = pipeline.NamedBarrier(
+ barrier_id=3,
+ num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)),
+ )
+ self.sched_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=4,
+ num_threads=self.threads_per_warp,
+ )
+ # Pingpong barriers (group 0 = ACT warps, group 1 = RHT store warps)
+ self.epilog_sync_barrier_group0 = pipeline.NamedBarrier(
+ barrier_id=5,
+ num_threads=32 * self.epilogue_warp_group_size,
+ )
+ self.epilog_sync_barrier_group1 = pipeline.NamedBarrier(
+ barrier_id=6,
+ num_threads=32 * self.epilogue_warp_group_size,
+ )
+
+ self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
+ SM100_TMEM_CAPACITY_COLUMNS = 512
+ self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
+
+ self.vectorized_f32 = vectorized_f32
+
+ # Amax: only RHT store warps (4) do reduction
+ self.num_epilog_warps = len(self.epilog_rht_store_warp_id) # = 4
+
+ self.act_func = act_func
+ if act_func not in ["swiglu", "geglu"]:
+ raise ValueError(f"Invalid activation function: {act_func}")
+
+ def _setup_attributes(self):
+ """Set up configurations dependent on GEMM inputs (called inside __call__)."""
+
+ self.mma_inst_shape_mn = (
+ self.mma_tiler[0],
+ self.mma_tiler[1],
+ )
+ self.mma_inst_shape_mn_sfb = (
+ self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
+ cute.round_up(self.mma_inst_shape_mn[1], 128),
+ )
+
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+
+ mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
+ mma_inst_tile_k = 4
+ self.mma_tiler = (
+ self.mma_tiler[0],
+ self.mma_tiler[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+
+ self.mma_tiler_sfb = (
+ self.mma_inst_shape_mn_sfb[0],
+ self.mma_inst_shape_mn_sfb[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+
+ self.cta_tile_shape_mnk = (
+ self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler[1],
+ self.mma_tiler[2],
+ )
+ self.cta_tile_shape_mnk_sfb = (
+ self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler_sfb[1],
+ self.mma_tiler_sfb[2],
+ )
+
+ self.mma_tiler_d = (
+ self.mma_inst_shape_mn[0],
+ self.mma_inst_shape_mn[1] // 2,
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.cta_tile_shape_mnk_d = (
+ self.mma_tiler_d[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler_d[1],
+ self.mma_tiler_d[2],
+ )
+
+ self.cluster_layout_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma.thr_id.shape,),
+ )
+ self.cluster_layout_sfb_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma_sfb.thr_id.shape,),
+ )
+
+ self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+ self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
+
+ self.epi_tile = (128, 32)
+ self.epi_tile_cnt = (
+ self.cta_tile_shape_mnk_d[0] // self.epi_tile[0],
+ self.cta_tile_shape_mnk_d[1] // self.epi_tile[1],
+ )
+ self.epi_tile_c = (128, 64)
+
+ (
+ self.num_acc_stage,
+ self.num_ab_stage,
+ self.num_c_stage,
+ self.num_d_stage,
+ self.num_tile_stage,
+ self.num_bias_stage,
+ self.num_pingpong_stage,
+ ) = self._compute_stages(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.b_dtype,
+ self.epi_tile,
+ self.epi_tile_c,
+ self.c_dtype,
+ self.c_layout,
+ self.d_dtype,
+ self.d_layout,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.num_smem_capacity,
+ self.occupancy,
+ self.bias_dtype if self.enable_bias else None,
+ )
+
+ self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.num_ab_stage,
+ )
+ self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
+ tiled_mma,
+ self.mma_tiler,
+ self.b_dtype,
+ self.num_ab_stage,
+ )
+ self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.c_dtype,
+ self.c_layout,
+ self.epi_tile_c,
+ self.num_c_stage,
+ )
+ self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.d_dtype,
+ self.d_layout,
+ self.epi_tile,
+ self.num_d_stage,
+ )
+
+ if self.enable_bias:
+ self.bias_smem_layout_staged = cute.make_layout(
+ (self.mma_tiler[1], self.num_bias_stage),
+ stride=(1, self.mma_tiler[1]),
+ )
+ else:
+ self.bias_smem_layout_staged = cute.make_layout((1, 1))
+
+ self.overlapping_accum = self.num_acc_stage == 1 and self.mma_tiler[1] == 256
+
+ sf_atom_mn = 32
+ self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
+ self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
+ self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
+ self.num_accumulator_tmem_cols = (
+ self.cta_tile_shape_mnk[1] * self.num_acc_stage if not self.overlapping_accum else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
+ )
+
+ self.epi_tile_n_required = 2 * cute.size(self.epi_tile[1])
+ self.iter_acc_early_release_in_epilogue = ((self.num_sf_tmem_cols + self.epi_tile_n_required - 1) // self.epi_tile_n_required - 1) * 2
+
+ def get_desc_workspace_bytes(self) -> int:
+ """Return descriptor workspace size in bytes."""
+ if self.weight_mode == MoEWeightMode.DISCRETE:
+ from ..moe_utils import DiscreteWeightTensormapConstructor
+
+ return DiscreteWeightTensormapConstructor.get_workspace_size(self.expert_cnt)
+ return 0
+
+ def get_workspace_bytes(self) -> int:
+ """Return total workspace size in bytes."""
+ desc_workspace_bytes = self.get_desc_workspace_bytes()
+ dynamic_sched_bytes = 4 if self.use_dynamic_sched else 0
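+        # Workspace layout: [per-expert TMA descriptors (discrete mode only) | 4-byte dynamic-sched counter],
+        # matching the offset computed in _get_sched_counter_ptr.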
+ return desc_workspace_bytes + dynamic_sched_bytes
+
+ @cute.jit
+ def _get_sched_counter_ptr(self, workspace_ptr):
+ counter_addr = workspace_ptr.toint() + self.get_desc_workspace_bytes()
+ return cute.make_ptr(
+ cutlass.Int32,
+ counter_addr,
+ AddressSpace.gmem,
+ assumed_align=4,
+ )
+
+ @cute.kernel
+ def helper_kernel(
+ self,
+ ptrs_b: cute.Pointer,
+ ptrs_sfb: cute.Pointer,
+ n: Int32,
+ k: Int32,
+ b_stride_size: cutlass.Int64,
+ b_major_mode: cutlass.Constexpr,
+ workspace_ptr,
+ tiled_mma_arg: cute.TiledMma,
+ tiled_mma_sfb_arg: cute.TiledMma,
+ b_smem_layout_arg,
+ sfb_smem_layout_arg,
+ cluster_layout_vmnk_shape_arg: cutlass.Constexpr,
+ cluster_layout_sfb_vmnk_shape_arg: cutlass.Constexpr,
+ ):
+ """Pre-main-kernel: build per-expert TMA descriptors (discrete mode) and/or reset sched counter."""
+ expert_idx = cute.arch.block_idx()[0]
+
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE):
+ b_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_arg.thr_id)
+ sfb_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma_arg.thr_id)
+
+ # Read per-expert base addresses from the pointer arrays
+ b_ptr_tensor = cute.make_tensor(
+ cute.make_ptr(cutlass.Int64, ptrs_b.toint(), AddressSpace.gmem, assumed_align=8),
+ cute.make_layout((self.expert_cnt,)),
+ )
+ sfb_ptr_tensor = cute.make_tensor(
+ cute.make_ptr(cutlass.Int64, ptrs_sfb.toint(), AddressSpace.gmem, assumed_align=8),
+ cute.make_layout((self.expert_cnt,)),
+ )
+
+ c1 = cutlass.Int32(1)
+ c0 = cutlass.Int64(0)
+ c1_64 = 1
+ if cutlass.const_expr(b_major_mode == OperandMajorMode.K):
+ stride_n = b_stride_size
+ stride_k = c1_64
+ else:
+ stride_n = c1_64
+ stride_k = b_stride_size
+
+ b_ptr_val = b_ptr_tensor[expert_idx]
+ b_ptr = cute.make_ptr(self.b_dtype, b_ptr_val, AddressSpace.gmem)
+ b_expert = cute.make_tensor(
+ b_ptr,
+ cute.make_layout((n, k, c1), stride=(stride_n, stride_k, c0)),
+ )
+ tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
+ b_tma_op_arg,
+ b_expert,
+ b_smem_layout_arg,
+ self.mma_tiler,
+ tiled_mma_arg,
+ cluster_layout_vmnk_shape_arg,
+ )
+
+ workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"])
+ store_tma_desc(tma_atom_b, workspace.get_ptr("b", expert_idx))
+
+ sfb_ptr_val = sfb_ptr_tensor[expert_idx]
+ sfb_ptr = cute.make_ptr(self.sf_dtype, sfb_ptr_val, AddressSpace.gmem)
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size)
+ sfb_expert = cute.make_tensor(sfb_ptr, sfb_layout)
+ tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
+ sfb_tma_op_arg,
+ sfb_expert,
+ sfb_smem_layout_arg,
+ self.mma_tiler_sfb,
+ tiled_mma_sfb_arg,
+ cluster_layout_sfb_vmnk_shape_arg,
+ internal_type=cutlass.Uint64,
+ )
+ store_tma_desc(tma_atom_sfb, workspace.get_ptr("sfb", expert_idx))
+
+ if cutlass.const_expr(self.use_dynamic_sched):
+ if expert_idx == cutlass.Int32(0):
+ sched_counter = cute.make_tensor(
+ self._get_sched_counter_ptr(workspace_ptr),
+ cute.make_layout(1),
+ )
+ sched_counter[0] = cutlass.Int32(0)
+
+ @cute.jit
+ def __call__(
+ self,
+ a: cute.Tensor,
+ b, # Dense: cute.Tensor (N,K,L) | Discrete: cute.Pointer to int64[]
+ sfa: cute.Tensor,
+ sfb, # Dense: cute.Tensor | Discrete: cute.Pointer to int64[]
+ n: Int32, # Ignored for dense mode
+ k: Int32, # Ignored for dense mode
+ b_stride_size: cutlass.Int64, # Ignored for dense mode
+ b_major_mode: cutlass.Constexpr, # Ignored for dense mode
+ workspace_ptr,
+ c: cute.Tensor,
+ d: cute.Tensor,
+ amax_tensor: Optional[cute.Tensor],
+ padded_offsets: cute.Tensor,
+ alpha: cute.Tensor,
+ prob: cute.Tensor,
+ hadamard_tensor: Optional[cute.Tensor],
+ bias: Optional[cute.Tensor],
+ max_active_clusters: cutlass.Constexpr,
+ stream: cuda.CUstream,
+ epilogue_op: cutlass.Constexpr = lambda x: x,
+ linear_offset: cutlass.Float32 = 0.0,
+ ):
+ """Execute the MoE GEMM + GLU + Hadamard kernel.
+
+ Dense mode: ``b`` and ``sfb`` are 3-D cute.Tensor (N, K, L).
+ Discrete mode: ``b`` and ``sfb`` are cute.Pointer to device int64[]
+ arrays of per-expert base addresses.
+ """
+ self.a_dtype: Type[cutlass.Numeric] = a.element_type
+ self.b_dtype: Type[cutlass.Numeric] = a.element_type
+ self.c_dtype: Type[cutlass.Numeric] = c.element_type
+ self.d_dtype: Type[cutlass.Numeric] = d.element_type
+ self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type
+ self.bias_dtype = bias.element_type if cutlass.const_expr(self.enable_bias) else cutlass.BFloat16
+ self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
+ self.c_layout = utils.LayoutEnum.from_tensor(c)
+ self.d_layout = utils.LayoutEnum.from_tensor(d)
+
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
+ else:
+ self.b_major_mode = b_major_mode
+
+ if cutlass.const_expr(self.a_dtype != self.b_dtype):
+ raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}")
+
+ self._setup_attributes()
+
+ # ---- B / SFB setup (mode-dependent) ----
+ b_from_call_arg = b
+ sfb_from_call_arg = sfb
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size)
+ sfb = cute.make_tensor(sfb.iterator, sfb_layout)
+ else:
+ c1 = cutlass.Int32(1)
+ c0 = cutlass.Int64(0)
+ c1_64 = 1
+ if cutlass.const_expr(b_major_mode == OperandMajorMode.K):
+ b_template_stride = (b_stride_size, c1_64, c0)
+ else:
+ b_template_stride = (c1_64, b_stride_size, c0)
+ b_template_layout = cute.make_layout((n, k, c1), stride=b_template_stride)
+ b_ptr_typed = cute.make_ptr(self.b_dtype, b.toint(), AddressSpace.gmem, assumed_align=16)
+ b = cute.make_tensor(b_ptr_typed, b_template_layout)
+
+ sfb_ptr_typed = cute.make_ptr(self.sf_dtype, sfb.toint(), AddressSpace.gmem, assumed_align=16)
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size)
+ sfb = cute.make_tensor(sfb_ptr_typed, sfb_layout)
+
+ sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, self.sf_vec_size)
+ sfa = cute.make_tensor(sfa.iterator, sfa_layout)
+
+ self.generate_amax = amax_tensor is not None
+
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+ atom_thr_size = cute.size(tiled_mma.thr_id.shape)
+
+ # TMA load A
+ a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
+ tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
+ a_op,
+ a,
+ a_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ # TMA load B
+ b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
+ b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
+ tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
+ b_op,
+ b,
+ b_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ # TMA load SFA
+ sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
+ sfa_op,
+ sfa,
+ sfa_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ internal_type=cutlass.Int16,
+ )
+
+ # TMA load SFB
+ sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
+ sfb_op,
+ sfb,
+ sfb_smem_layout,
+ self.mma_tiler_sfb,
+ tiled_mma_sfb,
+ self.cluster_layout_sfb_vmnk.shape,
+ internal_type=cutlass.Uint64,
+ )
+
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
+ x = tma_tensor_sfb.stride[0][1]
+ y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4)
+ new_shape = (
+ (tma_tensor_sfb.shape[0][0], ((2, 2), y)),
+ tma_tensor_sfb.shape[1],
+ tma_tensor_sfb.shape[2],
+ )
+ x_times_3 = 3 * x
+ new_stride = (
+ (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)),
+ tma_tensor_sfb.stride[1],
+ tma_tensor_sfb.stride[2],
+ )
+ tma_tensor_sfb = cute.make_tensor(
+ tma_tensor_sfb.iterator,
+ cute.make_layout(new_shape, stride=new_stride),
+ )
+
+ a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
+ b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+ sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
+ sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
+ self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size
+
+ # TMA store C
+ c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
+ tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ c,
+ c_smem_layout,
+ self.epi_tile_c,
+ )
+
+ # TMA store D
+ d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ d,
+ d_smem_layout,
+ self.epi_tile,
+ )
+
+ # ---- Helper kernel (discrete TMA desc init + dynamic sched counter reset) ----
+ _need_helper = cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE or self.use_dynamic_sched)
+ if cutlass.const_expr(_need_helper):
+ _helper_grid_x = self.expert_cnt if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else 1
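+            # In dense mode the helper only resets the dynamic-sched counter, so the B/SFB pointer
+            # and shape arguments below are passed as null/zero placeholders.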
+ _helper_args = (
+ b_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem),
+ sfb_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem),
+ n if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0),
+ k if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0),
+ b_stride_size if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int64(0),
+ b_major_mode if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else self.b_major_mode,
+ workspace_ptr,
+ tiled_mma,
+ tiled_mma_sfb,
+ b_smem_layout,
+ sfb_smem_layout,
+ self.cluster_layout_vmnk.shape,
+ self.cluster_layout_sfb_vmnk.shape,
+ )
+ self.helper_kernel(*_helper_args).launch(
+ grid=(_helper_grid_x, 1, 1),
+ block=(1, 1, 1),
+ stream=stream,
+ min_blocks_per_mp=1,
+ )
+
+ # ---- Grid computation via MoE scheduler ----
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ b_n, b_k, b_l = cute.shape(b)
+ sched_expert_shape = (self.expert_cnt, b_n, b_k)
+ else:
+ sched_expert_shape = (self.expert_cnt, n, k)
+
+ sched_params = MoESchedulerParams(
+ scenario="2Dx3D",
+ expert_shape=sched_expert_shape,
+ cta_tile_shape_mnk=self.cta_tile_shape_mnk,
+ cluster_shape_mn=self.cluster_shape_mn,
+ use_dynamic_sched=self.use_dynamic_sched,
+ )
+ self.sched_params, grid = compute_grid(
+ sched_params,
+ max_active_clusters,
+ self.use_2cta_instrs,
+ )
+
+ self.buffer_align_bytes = 1024
+
+ # ---- Shared storage ----
+ SchedulerStorage = MoEPersistentTileScheduler.make_storage_struct(self.num_tile_stage, self.use_dynamic_sched)
+
+ @cute.struct
+ class SharedStorage:
+ ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
+ acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
+ hadamard_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2]
+ hadamard_prerequisite_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2]
+ pingpong_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_pingpong_stage * 2]
+ if cutlass.const_expr(self.enable_bias):
+ bias_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_bias_stage * 2]
+ scheduler: SchedulerStorage
+ tmem_dealloc_mbar_ptr: cutlass.Int64
+ tmem_holding_buf: cutlass.Int32
+ sC: cute.struct.Align[
+ cute.struct.MemRange[self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sD: cute.struct.Align[
+ cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sA: cute.struct.Align[
+ cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sB: cute.struct.Align[
+ cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sSFA: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+ sSFB: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+ if cutlass.const_expr(self.enable_bias):
+ sBias: cute.struct.Align[
+ cute.struct.MemRange[self.bias_dtype, cute.cosize(self.bias_smem_layout_staged)],
+ 16,
+ ]
+ sHadamard: cute.struct.Align[cute.struct.MemRange[cutlass.BFloat16, HADAMARD_SIZE * HADAMARD_SIZE], 16]
+ sBH: cute.struct.Align[cute.struct.MemRange[cutlass.BFloat16, HADAMARD_SIZE * HADAMARD_SIZE], 16]
+ sAmax: cute.struct.Align[
+ cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps],
+ 1,
+ ]
+
+ self.shared_storage = SharedStorage
+
+ # Launch main kernel
+ self.kernel(
+ tiled_mma,
+ tiled_mma_sfb,
+ tma_atom_a,
+ tma_tensor_a,
+ tma_atom_b,
+ tma_tensor_b,
+ tma_atom_sfa,
+ tma_tensor_sfa,
+ tma_atom_sfb,
+ tma_tensor_sfb,
+ tma_atom_c,
+ tma_tensor_c,
+ tma_atom_d,
+ tma_tensor_d,
+ amax_tensor,
+ padded_offsets,
+ alpha,
+ bias,
+ prob,
+ hadamard_tensor,
+ workspace_ptr,
+ self.cluster_layout_vmnk,
+ self.cluster_layout_sfb_vmnk,
+ self.a_smem_layout_staged,
+ self.b_smem_layout_staged,
+ self.sfa_smem_layout_staged,
+ self.sfb_smem_layout_staged,
+ self.c_smem_layout_staged,
+ self.d_smem_layout_staged,
+ self.bias_smem_layout_staged,
+ self.epi_tile,
+ self.sched_params,
+ epilogue_op,
+ linear_offset,
+ ).launch(
+ grid=grid,
+ block=[self.threads_per_cta, 1, 1],
+ cluster=(*self.cluster_shape_mn, 1),
+ max_number_threads=[self.threads_per_cta, 1, 1],
+ smem=self.shared_storage.size_in_bytes(),
+ stream=stream,
+ min_blocks_per_mp=1,
+ )
+ return
+
+ # ------------------------------------------------------------------
+ # Internal helpers
+ # ------------------------------------------------------------------
+
+ @cute.jit
+ def _make_extension(self, workspace_ptr):
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE):
+ desc_workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"])
+ return DiscreteWeightScaledGemmSchedExtension(
+ tensormap_ctor=desc_workspace,
+ sf_vec_size=self.sf_vec_size,
+ )
+ else:
+ return ContiguousAndConsistentGroupedGemmSchedExtension(
+ sf_vec_size=self.sf_vec_size,
+ )
+
+ def mainloop_s2t_copy_and_partition(self, sSF, tSF):
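+        # Build the SMEM->TMEM (S2T) copy for scale factors: compact away zero-stride modes,
+        # use the tcgen05 4x32x128b copy op, and partition the SMEM source and TMEM destination.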
+ tCsSF_compact = cute.filter_zeros(sSF)
+ tCtSF_compact = cute.filter_zeros(tSF)
+ copy_atom_s2t = cute.make_copy_atom(
+ tcgen05.Cp4x32x128bOp(self.cta_group),
+ self.sf_dtype,
+ )
+ tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
+ thr_copy_s2t = tiled_copy_s2t.get_slice(0)
+ tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
+ tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_)
+ tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
+ return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
+
+ @cute.jit
+ def amax_reduction_per_thread(self, vec_fp32, amax_fp32):
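+        # Per-thread amax: take |v| over the register vector via the MLIR absf op,
+        # MAX-reduce it, and fold the result into the running amax.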
+ vec_fp32_ssa = vec_fp32.load()
+ import cutlass._mlir.dialects.math as _math
+
+ abs_acc_values_ir = _math.absf(vec_fp32_ssa.ir_value())
+ abs_acc_values = type(vec_fp32_ssa)(abs_acc_values_ir, vec_fp32_ssa.shape, vec_fp32_ssa.dtype)
+ subtile_amax = abs_acc_values.reduce(cute.ReductionOp.MAX, cutlass.Float32(0.0), 0)
+ return cute.arch.fmax(amax_fp32, subtile_amax)
+
+ @cute.jit
+ def amax_reduction_per_warp_and_cta(self, amax_fp32, warp_idx, amax_smem, amax_gmem):
+ warp_amax = warp_redux_sync(value=amax_fp32, kind=ReduxKind.MAX, mask_and_clamp=0xFFFFFFFF, nan=True)
+ if cute.arch.lane_idx() == 0:
+ amax_smem[warp_idx & 0x3] = cutlass.Float32(warp_amax)
+ self.epilog_sync_barrier_group1.arrive_and_wait()
+ if warp_idx == self.epilog_rht_store_warp_id[0] and cute.arch.lane_idx() == 0:
+ block_amax = cutlass.Float32(0.0)
+ for i in cutlass.range(self.num_epilog_warps):
+ block_amax = cute.arch.fmax(block_amax, amax_smem[i])
+ _ = atomic_max_float32(ptr=amax_gmem, value=block_amax)
+
+ @cute.jit
+ def store_c(
+ self,
+ tiled_copy_r2s,
+ tma_atom_c,
+ warp_idx,
+ tTR_rAcc,
+ tTR_rAcc_up,
+ tRS_rC,
+ tRS_sC,
+ bSG_gC,
+ bSG_sC,
+ c_pipeline,
+ prev_subtile_idx,
+ real_subtile_idx,
+ ):
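+        # Stage the gate and up halves of the accumulator into the C SMEM buffer, then let the
+        # leading ACT warp issue the TMA store for this subtile.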
+ c_buffer = prev_subtile_idx % self.num_c_stage
+ tRS_rC.store(tTR_rAcc.load().to(self.c_dtype))
+ cute.copy(tiled_copy_r2s, tRS_rC[(None, None, 0)], tRS_sC[(None, None, 0, c_buffer)])
+ tRS_rC.store(tTR_rAcc_up.load().to(self.c_dtype))
+ cute.copy(tiled_copy_r2s, tRS_rC[(None, None, 0)], tRS_sC[(None, None, 1, c_buffer)])
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.epilog_sync_barrier_group0.arrive_and_wait()
+ if warp_idx == self.epilog_act_warp_id[0]:
+ cute.copy(tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, real_subtile_idx)])
+ c_pipeline.producer_commit()
+ if not cutlass.const_expr(self.delay_tma_store_acquire_sync):
+ c_pipeline.producer_acquire()
+ if not cutlass.const_expr(self.delay_tma_store_acquire_sync):
+ self.epilog_sync_barrier_group0.arrive_and_wait()
+
+ @cute.jit
+ def geglu_act(self, tCompute, acc_vec_up, acc_vec_gate, mProb, linear_offset=1.0):
+ if cutlass.const_expr(self.vectorized_f32):
+ LOG2_E = cutlass.Float32(1.4426950408889634)
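+            # Vectorized GeGLU: GELU(gate) is approximated as gate * sigmoid(1.702 * gate), with
+            # sigmoid(y) computed as 1 / (1 + exp2(-y * log2(e))) using packed f32x2 ops.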
+ for i in cutlass.range_constexpr(0, cute.size(tCompute), 2):
+ scaled_gate_0, scaled_gate_1 = cute.arch.mul_packed_f32x2(
+ (acc_vec_gate[i], acc_vec_gate[i + 1]),
+ (1.702, 1.702),
+ rnd="rn",
+ ftz=False,
+ )
+ tCompute_log2e = cute.arch.mul_packed_f32x2(
+ (scaled_gate_0, scaled_gate_1),
+ (-LOG2_E, -LOG2_E),
+ rnd="rn",
+ ftz=False,
+ )
+ tCompute[i], tCompute[i + 1] = cute.arch.add_packed_f32x2(
+ (cute.math.exp2(tCompute_log2e[0], fastmath=True), cute.math.exp2(tCompute_log2e[1], fastmath=True)),
+ (1.0, 1.0),
+ )
+ tCompute[i] = cute.arch.rcp_approx(tCompute[i])
+ tCompute[i + 1] = cute.arch.rcp_approx(tCompute[i + 1])
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (tCompute[i], tCompute[i + 1]),
+ (acc_vec_gate[i], acc_vec_gate[i + 1]),
+ rnd="rn",
+ ftz=False,
+ )
+ up0, up1 = cute.arch.add_packed_f32x2(
+ (linear_offset, linear_offset),
+ (acc_vec_up[i], acc_vec_up[i + 1]),
+ rnd="rn",
+ ftz=False,
+ )
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (tCompute[i], tCompute[i + 1]),
+ (up0, up1),
+ rnd="rn",
+ ftz=False,
+ )
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (tCompute[i], tCompute[i + 1]),
+ (mProb, mProb),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tCompute)):
+ tCompute[i] = (acc_vec_up[i] + linear_offset) * silu_f32_geglu_scaled(acc_vec_gate[i], fastmath=True)
+ tCompute[i] = tCompute[i] * mProb
+
+ @cute.jit
+ def swiglu_act(self, tCompute, acc_vec_up, acc_vec_gate, mProb):
+ if cutlass.const_expr(self.vectorized_f32):
+ LOG2_E = cutlass.Float32(1.4426950408889634)
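+            # Vectorized SwiGLU: SiLU(gate) = gate * sigmoid(gate), with
+            # sigmoid(y) computed as 1 / (1 + exp2(-y * log2(e))) using packed f32x2 ops.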
+ for i in cutlass.range_constexpr(0, cute.size(tCompute), 2):
+ tCompute_log2e = cute.arch.mul_packed_f32x2(
+ (acc_vec_gate[i], acc_vec_gate[i + 1]),
+ (-LOG2_E, -LOG2_E),
+ rnd="rn",
+ ftz=False,
+ )
+ tCompute[i], tCompute[i + 1] = cute.arch.add_packed_f32x2(
+ (cute.math.exp2(tCompute_log2e[0], fastmath=True), cute.math.exp2(tCompute_log2e[1], fastmath=True)),
+ (1.0, 1.0),
+ )
+ tCompute[i] = cute.arch.rcp_approx(tCompute[i])
+ tCompute[i + 1] = cute.arch.rcp_approx(tCompute[i + 1])
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (tCompute[i], tCompute[i + 1]),
+ (acc_vec_gate[i], acc_vec_gate[i + 1]),
+ rnd="rn",
+ ftz=False,
+ )
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (tCompute[i], tCompute[i + 1]),
+ (acc_vec_up[i], acc_vec_up[i + 1]),
+ rnd="rn",
+ ftz=False,
+ )
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (tCompute[i], tCompute[i + 1]),
+ (mProb, mProb),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tCompute)):
+ tCompute[i] = acc_vec_up[i] * silu_f32(acc_vec_gate[i], fastmath=True)
+ tCompute[i] = tCompute[i] * mProb
+
+ @cute.jit
+ def query_hadamard_tmem_a_ptr(self, tile_idx, reverse_subtile, tmem_ptr):
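+        # TMEM column offset of the Hadamard input staging buffer for this subtile;
+        # reverse_subtile mirrors the addressing to match the reversed subtile order used
+        # with the overlapped accumulator.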
+ hadamard_tmem_offset = 256 + (HADAMARD_SIZE + self.epi_tile[1]) * (tile_idx // 2)
+ if reverse_subtile:
+ hadamard_tmem_offset = 256 - self.num_sf_tmem_cols - HADAMARD_SIZE - (HADAMARD_SIZE + self.epi_tile[1]) * (tile_idx // 2)
+ return cute.recast_ptr(tmem_ptr + hadamard_tmem_offset, dtype=self.acc_dtype)
+
+ @cute.jit
+ def query_hadamard_tmem_acc_ptr(self, tile_idx, reverse_subtile, tmem_ptr):
+ hadamard_tmem_offset = 256 + (HADAMARD_SIZE + self.epi_tile[1]) * (tile_idx // 2) + HADAMARD_SIZE
+ if reverse_subtile:
+ hadamard_tmem_offset = 256 - self.num_sf_tmem_cols - HADAMARD_SIZE - (HADAMARD_SIZE + self.epi_tile[1]) * (tile_idx // 2) - self.epi_tile[1]
+ return cute.recast_ptr(tmem_ptr + hadamard_tmem_offset, dtype=self.acc_dtype)
+
+ def epilog_tmem_copy_and_partition(self, tidx, tAcc, gD_mnl, epi_tile, use_2cta_instrs):
+ copy_atom_t2r = sm100_utils.get_tmem_load_op(
+ self.cta_tile_shape_mnk,
+ self.d_layout,
+ self.d_dtype,
+ self.acc_dtype,
+ epi_tile,
+ use_2cta_instrs,
+ )
+ tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile)
+ tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
+ thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
+ tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
+ gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+ tTR_gC = thr_copy_t2r.partition_D(gD_mnl_epi)
+ tTR_rAcc_gate = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype)
+ tTR_rAcc_up = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype)
+ return tiled_copy_t2r, tTR_tAcc, tTR_rAcc_gate, tTR_rAcc_up
+
+ def epilog_smem_copy_and_partition(self, tiled_copy_t2r, tTR_rC, tidx, sD):
+ copy_atom_r2s = sm100_utils.get_smem_store_op(self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r)
+ tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
+ tRS_sD = thr_copy_r2s.partition_D(sD)
+ tRS_rD = tiled_copy_r2s.retile(tTR_rC)
+ return tiled_copy_r2s, tRS_rD, tRS_sD
+
+ def epilog_gmem_copy_and_partition(self, tidx, atom, gD_mnl, epi_tile, sD):
+ gD_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+ tma_atom_d = atom
+ sD_for_tma_partition = cute.group_modes(sD, 0, 2)
+ gD_for_tma_partition = cute.group_modes(gD_epi, 0, 2)
+ bSG_sD, bSG_gD = cpasync.tma_partition(
+ tma_atom_d,
+ 0,
+ cute.make_layout(1),
+ sD_for_tma_partition,
+ gD_for_tma_partition,
+ )
+ return tma_atom_d, bSG_sD, bSG_gD
+
+ @staticmethod
+ def _compute_stages(
+ tiled_mma,
+ mma_tiler_mnk,
+ a_dtype,
+ b_dtype,
+ epi_tile,
+ epi_tile_c,
+ c_dtype,
+ c_layout,
+ d_dtype,
+ d_layout,
+ sf_dtype,
+ sf_vec_size,
+ num_smem_capacity,
+ occupancy,
+ bias_dtype,
+ ):
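+        # Pick the epilogue/accumulator/tile-info stage counts up front, then size the AB stages
+        # from whatever SMEM capacity remains after the epilogue buffers, barriers, and scheduler info.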
+ num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
+ num_c_stage = 1
+ num_d_stage = 1
+ num_tile_stage = 2
+ num_pingpong_stage = mma_tiler_mnk[1] // epi_tile_c[1]
+
+ a_smem_layout_one = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler_mnk, a_dtype, 1)
+ b_smem_layout_one = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, b_dtype, 1)
+ sfa_smem_layout_one = blockscaled_utils.make_smem_layout_sfa(tiled_mma, mma_tiler_mnk, sf_vec_size, 1)
+ sfb_smem_layout_one = blockscaled_utils.make_smem_layout_sfb(tiled_mma, mma_tiler_mnk, sf_vec_size, 1)
+ c_smem_layout_one = sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile_c, 1)
+ d_smem_layout_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
+
+ ab_bytes_per_stage = (
+ cute.size_in_bytes(a_dtype, a_smem_layout_one)
+ + cute.size_in_bytes(b_dtype, b_smem_layout_one)
+ + cute.size_in_bytes(sf_dtype, sfa_smem_layout_one)
+ + cute.size_in_bytes(sf_dtype, sfb_smem_layout_one)
+ )
+ mbar_helpers_bytes = 1024
+        # sInfo lives in SchedulerStorage rather than in this struct, so budget 4 Int32s per tile stage for it here
+ sinfo_bytes = 4 * 4 * num_tile_stage
+ c_bytes = cute.size_in_bytes(c_dtype, c_smem_layout_one) * num_c_stage
+ d_bytes = cute.size_in_bytes(d_dtype, d_smem_layout_one) * num_d_stage
+ hadamard_bytes = (cutlass.BFloat16.width // 8) * HADAMARD_SIZE * HADAMARD_SIZE if d_dtype == cutlass.BFloat16 else 0
+ # sAmax for 4 RHT store warps
+ amax_bytes = 4 * 4 # 4 Float32 values
+
+ if bias_dtype is not None:
+ num_bias_stage = 2
+ bias_bytes = mma_tiler_mnk[1] * num_bias_stage * (bias_dtype.width // 8)
+ else:
+ num_bias_stage = 0
+ bias_bytes = 0
+
+ epi_bytes = c_bytes + d_bytes + hadamard_bytes + amax_bytes + bias_bytes
+
+ num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes + sinfo_bytes)) // ab_bytes_per_stage
+
+ return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage, num_bias_stage, num_pingpong_stage
+
+ # GPU device kernel
+ @cute.kernel
+ def kernel(
+ self,
+ tiled_mma: cute.TiledMma,
+ tiled_mma_sfb: cute.TiledMma,
+ tma_atom_a: cute.CopyAtom,
+ mA_mkl: cute.Tensor,
+ tma_atom_b: cute.CopyAtom,
+ mB_nkl: cute.Tensor,
+ tma_atom_sfa: cute.CopyAtom,
+ mSFA_mkl: cute.Tensor,
+ tma_atom_sfb: cute.CopyAtom,
+ mSFB_nkl: cute.Tensor,
+ tma_atom_c: cute.CopyAtom,
+ mC_mnl: cute.Tensor,
+ tma_atom_d: cute.CopyAtom,
+ mD_mnl: cute.Tensor,
+ mAmax_tensor: Optional[cute.Tensor],
+ padded_offsets: cute.Tensor,
+ alpha: cute.Tensor,
+ mBias_nl: Optional[cute.Tensor],
+ prob: cute.Tensor,
+ hadamard_tensor: Optional[cute.Tensor],
+ workspace_ptr,
+ cluster_layout_vmnk: cute.Layout,
+ cluster_layout_sfb_vmnk: cute.Layout,
+ a_smem_layout_staged: cute.ComposedLayout,
+ b_smem_layout_staged: cute.ComposedLayout,
+ sfa_smem_layout_staged: cute.Layout,
+ sfb_smem_layout_staged: cute.Layout,
+ c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
+ d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
+ bias_smem_layout_staged: cute.Layout,
+ epi_tile: cute.Tile,
+ sched_params: MoESchedulerParams,
+ epilogue_op: cutlass.Constexpr,
+ linear_offset: cutlass.Float32 = 0.0,
+ ):
+ """GPU device kernel: MoE persistent GEMM + GLU + Hadamard (pingpong epilogue)."""
+ warp_idx = cute.arch.warp_idx()
+ warp_idx = cute.arch.make_warp_uniform(warp_idx)
+ lane_idx = cute.arch.lane_idx()
+ cta_rank = cute.arch.block_in_cluster_idx()[0]
+
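+        # Total padded token count across all experts, read from the last padded offset entry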
+ total_token = padded_offsets[self.expert_cnt - 1]
+
+ # Prefetch TMA descriptors
+ if warp_idx == self.tma_warp_id:
+ cpasync.prefetch_descriptor(tma_atom_a)
+ cpasync.prefetch_descriptor(tma_atom_sfa)
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ cpasync.prefetch_descriptor(tma_atom_b)
+ cpasync.prefetch_descriptor(tma_atom_sfb)
+ cpasync.prefetch_descriptor(tma_atom_c)
+ cpasync.prefetch_descriptor(tma_atom_d)
+
+ use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
+
+ # CTA coordinates
+ bidx, bidy, bidz = cute.arch.block_idx()
+ mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
+ is_leader_cta = mma_tile_coord_v == 0
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
+ block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster)
+ tidx, _, _ = cute.arch.thread_idx()
+
+ # Shared memory allocation
+ smem = utils.SmemAllocator()
+ storage = smem.allocate(self.shared_storage)
+ sched_storage = storage.scheduler
+
+ # AB pipeline
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer)
+ ab_pipeline = pipeline.PipelineTmaUmma.create(
+ barrier_storage=storage.ab_mbar_ptr.data_ptr(),
+ num_stages=self.num_ab_stage,
+ producer_group=ab_pipeline_producer_group,
+ consumer_group=ab_pipeline_consumer_group,
+ tx_count=self.num_tma_load_bytes,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ )
+
+ # ACC pipeline
+ acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_acc_consumer_threads = len(self.epilog_act_warp_id) * (2 if use_2cta_instrs else 1)
+ acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads)
+ acc_pipeline = pipeline.PipelineUmmaAsync.create(
+ barrier_storage=storage.acc_mbar_ptr.data_ptr(),
+ num_stages=self.num_acc_stage,
+ producer_group=acc_pipeline_producer_group,
+ consumer_group=acc_pipeline_consumer_group,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ )
+
+ # Hadamard SMEM tensor
+ hadamard_b_layout_staged = cute.make_layout((16, 8), stride=(1, 16))
+ sHadamard = storage.sHadamard.get_tensor(hadamard_b_layout_staged)
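+        # Each CTA of a 2-CTA cluster reads its own 128-element half of the Hadamard matrix from GMEM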
+ hadamard_h_iter = hadamard_tensor.iterator + (cutlass.Int32(128) if cta_rank == 1 else cutlass.Int32(0))
+ hadamard_tensor_local = cute.make_tensor(
+ cute.make_ptr(
+ cutlass.BFloat16,
+ hadamard_h_iter.toint(),
+ hadamard_h_iter.memspace,
+ assumed_align=16,
+ ),
+ hadamard_b_layout_staged,
+ )
+
+ # Hadamard UMMA pipeline (1 stage)
+ hadamard_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_hadamard_consumer_threads = 128 * (2 if use_2cta_instrs else 1)
+ hadamard_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_hadamard_consumer_threads)
+ hadamard_producer, hadamard_consumer = pipeline.PipelineUmmaAsync.create(
+ barrier_storage=storage.hadamard_mbar_ptr.data_ptr(),
+ num_stages=1,
+ producer_group=hadamard_pipeline_producer_group,
+ consumer_group=hadamard_pipeline_consumer_group,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ ).make_participants()
+
+ # Hadamard prerequisite pipeline (cross-CTA sync, 1 stage)
+ num_hadamard_prerequisite_threads = 128
+ hadamard_prerequisite_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_hadamard_prerequisite_threads)
+ hadamard_prerequisite_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_hadamard_prerequisite_threads)
+ peer_rank = cutlass.Int32(1) - cta_rank
+ hadamard_prerequisite_producer, hadamard_prerequisite_consumer = pipeline.PipelineAsync.create(
+ barrier_storage=storage.hadamard_prerequisite_mbar_ptr.data_ptr(),
+ num_stages=1,
+ producer_group=hadamard_prerequisite_pipeline_producer_group,
+ consumer_group=hadamard_prerequisite_pipeline_consumer_group,
+ producer_mask=peer_rank,
+ defer_sync=True,
+ ).make_participants()
+
+ # Pingpong pipeline
+ pingpong_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilog_act_warp_id) * self.threads_per_warp)
+ pingpong_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilog_rht_store_warp_id) * self.threads_per_warp)
+ pingpong_pipeline = pipeline.PipelineAsync.create(
+ barrier_storage=storage.pingpong_mbar_ptr.data_ptr(),
+ num_stages=self.num_pingpong_stage,
+ producer_group=pingpong_producer_group,
+ consumer_group=pingpong_consumer_group,
+ )
+
+ # Tile info pipeline (uses SchedulerStorage's barrier)
+ tile_info_pipeline_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ self.threads_per_warp * 1,
+ )
+ tile_info_pipeline_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ self.threads_wo_sched,
+ )
+ tile_info_pipeline = pipeline.PipelineAsync.create(
+ barrier_storage=sched_storage.tile_info_mbar.data_ptr(),
+ num_stages=self.num_tile_stage,
+ producer_group=tile_info_pipeline_producer_group,
+ consumer_group=tile_info_pipeline_consumer_group,
+ )
+
+ # MoE persistent tile scheduler
+ scheduler = MoEPersistentTileScheduler.create(
+ sched_params,
+ padded_offsets,
+ cute.arch.block_idx(),
+ cute.arch.grid_dim(),
+ counter_ptr=self._get_sched_counter_ptr(workspace_ptr),
+ sched_storage=sched_storage,
+ )
+ scheduler.internal_init()
+
+ # Bias pipeline
+ if cutlass.const_expr(self.enable_bias):
+ bias_pipeline_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ self.threads_per_warp,
+ )
+ bias_pipeline_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ self.threads_per_warp * len(self.epilog_act_warp_id),
+ )
+ bias_pipeline = pipeline.PipelineCpAsync.create(
+ barrier_storage=storage.bias_mbar_ptr.data_ptr(),
+ num_stages=self.num_bias_stage,
+ producer_group=bias_pipeline_producer_group,
+ consumer_group=bias_pipeline_consumer_group,
+ )
+ sBias = storage.sBias.get_tensor(bias_smem_layout_staged)
+ gBias_nl = cute.local_tile(mBias_nl, cute.slice_(self.mma_tiler[:2], (0, None)), (None, None))
+
+ # TMEM allocator
+ tmem = utils.TmemAllocator(
+ storage.tmem_holding_buf,
+ barrier_for_retrieve=self.tmem_alloc_barrier,
+ allocator_warp_id=self.epilog_act_warp_id[0],
+ is_two_cta=use_2cta_instrs,
+ two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
+ )
+
+ # Cluster arrive after barrier init
+ if cute.size(self.cluster_shape_mn) > 1:
+ cute.arch.cluster_arrive_relaxed()
+
+ # SMEM tensors
+ sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner)
+ sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
+ sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
+ sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
+ sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
+ sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
+ amax_layout = cute.make_layout((self.num_epilog_warps,))
+ sAmax = storage.sAmax.get_tensor(amax_layout)
+
+ # sInfo from SchedulerStorage
+ info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4))
+ sInfo = sched_storage.sInfo.get_tensor(info_layout)
+
+ # Multicast masks
+ a_full_mcast_mask = None
+ b_full_mcast_mask = None
+ sfa_full_mcast_mask = None
+ sfb_full_mcast_mask = None
+ if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
+ a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1)
+ sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1)
+
+ # MMA fragments
+ tCrA = tiled_mma.make_fragment_A(sA)
+ tCrB = tiled_mma.make_fragment_B(sB)
+ acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
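+        # With overlapping_accum, build a 2-stage accumulator view and override the stage stride so
+        # the second stage starts at column (256 - num_sf_tmem_cols), letting the stages overlap
+        # within the same TMEM allocation.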
+ if cutlass.const_expr(self.overlapping_accum):
+ num_acc_stage_overlapped = 2
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage_overlapped))
+ tCtAcc_fake = cute.make_tensor(
+ tCtAcc_fake.iterator,
+ cute.make_layout(
+ tCtAcc_fake.shape,
+ stride=(
+ tCtAcc_fake.stride[0],
+ tCtAcc_fake.stride[1],
+ tCtAcc_fake.stride[2],
+ (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1],
+ ),
+ ),
+ )
+ else:
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
+
+ # Cluster wait / CTA sync
+ if cute.size(self.cluster_shape_mn) > 1:
+ cute.arch.cluster_wait()
+ else:
+ self.cta_sync_barrier.arrive_and_wait()
+
+ if total_token <= 0:
+ cute.arch.nvvm.exit()
+
+ # ---------------------------------------------------------------
+ # Specialized Scheduler warp (MoEPersistentTileScheduler)
+ # ---------------------------------------------------------------
+ if warp_idx == self.sched_warp_id:
+ work_tile_info = scheduler.initial_work_tile_info()
+ tile_info_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_tile_stage)
+
+ while work_tile_info.is_valid_tile:
+ tile_info_pipeline.producer_acquire(tile_info_producer_state)
+ with cute.arch.elect_one():
+ sInfo[(0, tile_info_producer_state.index)] = work_tile_info.expert_idx
+ sInfo[(1, tile_info_producer_state.index)] = work_tile_info.tile_m_idx
+ sInfo[(2, tile_info_producer_state.index)] = work_tile_info.tile_n_idx
+ sInfo[(3, tile_info_producer_state.index)] = work_tile_info.k_tile_cnt
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.sched_sync_barrier.arrive_and_wait()
+ tile_info_pipeline.producer_commit(tile_info_producer_state)
+ tile_info_producer_state.advance()
+ work_tile_info = scheduler.advance_to_next_work()
+
+ # Send invalid signal: expert_idx = -1
+ tile_info_pipeline.producer_acquire(tile_info_producer_state)
+ with cute.arch.elect_one():
+ sInfo[(0, tile_info_producer_state.index)] = cutlass.Int32(-1)
+ sInfo[(1, tile_info_producer_state.index)] = cutlass.Int32(0)
+ sInfo[(2, tile_info_producer_state.index)] = cutlass.Int32(0)
+ sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.sched_sync_barrier.arrive_and_wait()
+ tile_info_pipeline.producer_commit(tile_info_producer_state)
+ tile_info_producer_state.advance()
+ tile_info_pipeline.producer_tail(tile_info_producer_state)
+
+ # ---------------------------------------------------------------
+ # Specialized TMA load warp
+ # ---------------------------------------------------------------
+ if warp_idx == self.tma_warp_id:
+ ext = self._make_extension(workspace_ptr)
+
+ ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage)
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
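+            # Read the (expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) record published by the
+            # scheduler warp; a negative expert_idx is the end-of-work sentinel.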
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ while is_valid_tile:
+ work_tile_info = MoEWorkTileInfo(
+ expert_idx=tile_info[0],
+ tile_m_idx=tile_info[1],
+ tile_n_idx=tile_info[2],
+ k_tile_cnt=tile_info[3],
+ )
+ k_tile_cnt = work_tile_info.k_tile_cnt
+ ext.update_expert_info(padded_offsets, work_tile_info.expert_idx)
+
+ real_a, _ = ext.get_gmem_tensor("a", mA_mkl, padded_offsets, work_tile_info)
+ real_b, desc_ptr_b = ext.get_gmem_tensor("b", mB_nkl, padded_offsets, work_tile_info)
+ real_sfa, _ = ext.get_gmem_tensor("sfa", mSFA_mkl, padded_offsets, work_tile_info)
+ real_sfb, desc_ptr_sfb = ext.get_gmem_tensor("sfb", mSFB_nkl, padded_offsets, work_tile_info)
+
+ gA_mkl = cute.local_tile(real_a, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ gB_nkl = cute.local_tile(real_b, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
+ gSFA_mkl = cute.local_tile(real_sfa, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ gSFB_nkl = cute.local_tile(real_sfb, cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None))
+
+ thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
+ thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v)
+ tCgA = thr_mma.partition_A(gA_mkl)
+ tCgB = thr_mma.partition_B(gB_nkl)
+ tCgSFA = thr_mma.partition_A(gSFA_mkl)
+ tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl)
+
+ a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
+ tAsA, tAgA = cpasync.tma_partition(
+ tma_atom_a,
+ block_in_cluster_coord_vmnk[2],
+ a_cta_layout,
+ cute.group_modes(sA, 0, 3),
+ cute.group_modes(tCgA, 0, 3),
+ )
+ b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
+ tBsB, tBgB = cpasync.tma_partition(
+ tma_atom_b,
+ block_in_cluster_coord_vmnk[1],
+ b_cta_layout,
+ cute.group_modes(sB, 0, 3),
+ cute.group_modes(tCgB, 0, 3),
+ )
+ sfa_cta_layout = a_cta_layout
+ tAsSFA, tAgSFA = cpasync.tma_partition(
+ tma_atom_sfa,
+ block_in_cluster_coord_vmnk[2],
+ sfa_cta_layout,
+ cute.group_modes(sSFA, 0, 3),
+ cute.group_modes(tCgSFA, 0, 3),
+ )
+ tAsSFA = cute.filter_zeros(tAsSFA)
+ tAgSFA = cute.filter_zeros(tAgSFA)
+ sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
+ tBsSFB, tBgSFB = cpasync.tma_partition(
+ tma_atom_sfb,
+ block_in_cluster_coord_sfb_vmnk[1],
+ sfb_cta_layout,
+ cute.group_modes(sSFB, 0, 3),
+ cute.group_modes(tCgSFB, 0, 3),
+ )
+ tBsSFB = cute.filter_zeros(tBsSFB)
+ tBgSFB = cute.filter_zeros(tBgSFB)
+
+ mma_tile_coord_m = work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape)
+ mma_tile_coord_n = work_tile_info.tile_n_idx
+ tAgA_slice = tAgA[(None, mma_tile_coord_m, None, 0)]
+ tBgB_slice = tBgB[(None, mma_tile_coord_n, None, 0)]
+ tAgSFA_slice = tAgSFA[(None, mma_tile_coord_m, None, 0)]
+ slice_n = mma_tile_coord_n
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ slice_n = mma_tile_coord_n // 2
+ tBgSFB_slice = tBgSFB[(None, slice_n, None, 0)]
+
+ ab_producer_state.reset_count()
+ peek_ab_empty_status = cutlass.Boolean(1)
+ if ab_producer_state.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+
+ for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+ tAgA_k = tAgA_slice[(None, ab_producer_state.count)]
+ tBgB_k = tBgB_slice[(None, ab_producer_state.count)]
+ tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)]
+ tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)]
+ tAsA_pipe = tAsA[(None, ab_producer_state.index)]
+ tBsB_pipe = tBsB[(None, ab_producer_state.index)]
+ tAsSFA_pipe = tAsSFA[(None, ab_producer_state.index)]
+ tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)]
+ tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state)
+
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
+ ab_producer_state_next = ab_producer_state.clone()
+ ab_producer_state_next.advance()
+ if ab_producer_state_next.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state_next)
+ else:
+ peek_ab_empty_status = cutlass.Boolean(1)
+
+ cute.copy(tma_atom_a, tAgA_k, tAsA_pipe, tma_bar_ptr=tma_bar, mcast_mask=a_full_mcast_mask)
+ cute.copy(tma_atom_b, tBgB_k, tBsB_pipe, tma_bar_ptr=tma_bar, mcast_mask=b_full_mcast_mask, tma_desc_ptr=desc_ptr_b)
+ cute.copy(tma_atom_sfa, tAgSFA_k, tAsSFA_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfa_full_mcast_mask)
+ cute.copy(tma_atom_sfb, tBgSFB_k, tBsSFB_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfb_full_mcast_mask, tma_desc_ptr=desc_ptr_sfb)
+
+ ab_producer_state.advance()
+
+ # Advance to next tile
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ ab_pipeline.producer_tail(ab_producer_state)
+
+ # ---------------------------------------------------------------
+ # Specialized MMA warp
+ # ---------------------------------------------------------------
+ if warp_idx == self.mma_warp_id:
+ tmem.wait_for_alloc()
+ acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
+
+ sfa_tmem_ptr = cute.recast_ptr(acc_tmem_ptr + self.num_accumulator_tmem_cols, dtype=self.sf_dtype)
+ tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
+
+ sfb_tmem_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
+
+ tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA)
+ tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
+
+ ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage)
+ acd_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ while is_valid_tile:
+ k_tile_cnt = tile_info[3]
+ ab_consumer_state.reset_count()
+ peek_ab_full_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
+
+ acd_producer_state.reset_count()
+ peek_acc_empty_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
+ peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state)
+
+ mma_tile_coord_mnl = (
+ tile_info[1] // cute.size(tiled_mma.thr_id.shape),
+ tile_info[2],
+ tile_info[0],
+ )
+
+ if cutlass.const_expr(self.overlapping_accum):
+ acc_stage_index = acd_producer_state.phase ^ 1
+ else:
+ acc_stage_index = acd_producer_state.index
+
+ tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)]
+ tCtSFB_mma = tCtSFB
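+                # For 64/192-wide CTA tiles, odd N tile indices read SFB from a shifted TMEM column offset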
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
+ offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+ elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+
+ if is_leader_cta:
+ acc_pipeline.producer_acquire(acd_producer_state, peek_acc_empty_status)
+
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
+
+ for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+ if is_leader_cta:
+ ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
+ s2t_stage_coord = (None, None, None, None, ab_consumer_state.index)
+ tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord]
+ tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord]
+ cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t_staged, tCtSFA_compact_s2t)
+ cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t)
+
+ num_kblocks = cute.size(tCrA, mode=[2])
+ ab_consumer_state_next = ab_consumer_state.clone()
+ ab_consumer_state_next.advance()
+ if ab_consumer_state_next.count < k_tile_cnt:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state_next)
+
+ for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
+ kblock_coord = (None, None, kblock_idx, ab_consumer_state.index)
+ sf_kblock_coord = (None, None, kblock_idx)
+ tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
+ tiled_mma.set(tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator)
+ cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc)
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
+
+ ab_pipeline.consumer_release(ab_consumer_state)
+ ab_consumer_state = ab_consumer_state_next
+
+ if is_leader_cta:
+ acc_pipeline.producer_commit(acd_producer_state)
+
+ acd_producer_state.advance()
+ if acd_producer_state.count < k_tile_cnt:
+ if is_leader_cta:
+ peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state)
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ acc_pipeline.producer_tail(acd_producer_state)
+
+ # ---------------------------------------------------------------
+ # Specialized bias load warp
+ # ---------------------------------------------------------------
+ if cutlass.const_expr(self.enable_bias):
+ if warp_idx == self.bias_load_warp_id and total_token > 0:
+ bias_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_bias_stage)
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+
+ bias_elems_per_thread = 128 // self.bias_dtype.width
+ bias_g2s_atom = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ self.bias_dtype,
+ num_bits_per_copy=128,
+ )
+ bias_g2s_tiled = cute.make_tiled_copy_tv(
+ bias_g2s_atom,
+ cute.make_layout((self.threads_per_warp,)),
+ cute.make_layout((bias_elems_per_thread,)),
+ )
+ thr_bias_g2s = bias_g2s_tiled.get_slice(cute.arch.lane_idx())
+ tBs_sBias = thr_bias_g2s.partition_D(sBias)
+
+ bias_n_total = mBias_nl.shape[0]
+ tBpBias = cute.make_rmem_tensor(cute.make_layout((1,)), cutlass.Boolean)
+
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ while is_valid_tile:
+ bias_producer_state.reset_count()
+ mma_n_coord = tile_info[2]
+ expert_idx = tile_info[0]
+ gBias_tile = gBias_nl[(None, mma_n_coord, expert_idx)]
+ tBs_gBias = thr_bias_g2s.partition_S(gBias_tile)
+ tBpBias[0] = mma_n_coord * self.mma_tiler[1] + cute.arch.lane_idx() * bias_elems_per_thread < bias_n_total
+ bias_pipeline.producer_acquire(bias_producer_state)
+ cute.copy(
+ bias_g2s_tiled,
+ tBs_gBias[(None, 0)],
+ tBs_sBias[(None, 0, bias_producer_state.index)],
+ pred=tBpBias,
+ )
+ bias_pipeline.producer_commit(bias_producer_state)
+ bias_producer_state.advance()
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ bias_pipeline.producer_tail(bias_producer_state)
+
+ # ---------------------------------------------------------------
+ # Specialized ACT epilogue warps (0-3): TMEM→regs, alpha, GLU activation,
+ # C store, hadamard_in
+ # ---------------------------------------------------------------
+ if warp_idx < self.epilog_rht_store_warp_id[0] and total_token > 0:
+ epi_tidx = tidx
+
+ #
+ # Alloc tensor memory buffer
+ #
+ tmem.allocate(self.num_tmem_alloc_cols)
+
+ #
+            # Barrier sync so the tensor memory ptr can be retrieved from shared memory
+ #
+ tmem.wait_for_alloc()
+
+ #
+            # Retrieve the tensor memory ptr and build the accumulator tensor
+ #
+ tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ # (MMA, MMA_M, MMA_N, STAGE)
+ tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
+
+ #
+            # Partition for epilogue (mD_mnl used shape-only to set up tile-invariant partitions)
+ #
+ thr_mma_epi = tiled_mma.get_slice(mma_tile_coord_v)
+ gD_mnl_shape = cute.local_tile(mD_mnl, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None))
+ tCgD_shape = thr_mma_epi.partition_C(gD_mnl_shape)
+
+ (
+ tiled_copy_t2r,
+ tTR_tAcc_base,
+ tTR_rAcc_gate,
+ tTR_rAcc_up,
+ ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgD_shape, epi_tile, use_2cta_instrs)
+
+ tTR_rC = cute.make_rmem_tensor(tTR_rAcc_gate.shape, self.c_dtype)
+ tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rC, epi_tidx, sC)
+
+ #
+ # Create per-expert extension (for C/prob tensors inside tile loop)
+ #
+ epi_ext = self._make_extension(workspace_ptr)
+
+ #
+ # Persistent tile scheduling state
+ #
+ acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
+
+ #
+ # Pingpong producer state
+ #
+ pingpong_act_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_pingpong_stage)
+
+ # Threads/warps participating in TMA store pipeline for C
+ c_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ self.threads_per_warp * len(self.epilog_act_warp_id),
+ )
+ c_pipeline = pipeline.PipelineTmaStore.create(
+ num_stages=self.num_c_stage,
+ producer_group=c_producer_group,
+ )
+
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+
+ if cutlass.const_expr(self.enable_bias):
+ bias_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_bias_stage)
+ bias_s2r_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.bias_dtype, num_bits_per_copy=128)
+ tTR_rBias_gate = cute.make_rmem_tensor(cute.make_layout(self.epi_tile[1]), self.bias_dtype)
+ tTR_rBias_up = cute.make_rmem_tensor(cute.make_layout(self.epi_tile[1]), self.bias_dtype)
+
+ # Get the first tile info
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ num_prev_subtiles = cutlass.Int32(0)
+ while is_valid_tile:
+ # sInfo format: (expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt)
+ epi_work_tile_info = MoEWorkTileInfo(
+ expert_idx=tile_info[0],
+ tile_m_idx=tile_info[1],
+ tile_n_idx=tile_info[2],
+ k_tile_cnt=tile_info[3],
+ )
+ mma_tile_coord_mnl = (
+ epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape),
+ epi_work_tile_info.tile_n_idx,
+ cutlass.Int32(0),
+ )
+
+ expert_idx = epi_work_tile_info.expert_idx
+ alpha_val = alpha[expert_idx]
+ epi_ext.update_expert_info(padded_offsets, epi_work_tile_info.expert_idx)
+
+ if cutlass.const_expr(self.enable_bias):
+ bias_consumer_state.reset_count()
+ bias_pipeline.consumer_wait(bias_consumer_state)
+ sBias_stage = sBias[(None, bias_consumer_state.index)]
+ sBias_subtiles = cute.flat_divide(sBias_stage, cute.make_layout(2 * self.epi_tile[1]))
+
+ #
+ # Get per-expert C tensor inside tile loop
+ #
+ real_c, _ = epi_ext.get_gmem_tensor("c", mC_mnl, padded_offsets, epi_work_tile_info)
+ gC_mnl = cute.local_tile(real_c, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
+ thr_mma_epi_loop = tiled_mma.get_slice(mma_tile_coord_v)
+ tCgC = thr_mma_epi_loop.partition_C(gC_mnl)
+ _, bSG_sC, bSG_gC_partitioned = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, self.epi_tile_c, sC)
+ bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)]
+ bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
+
+ #
+ # Get per-expert prob tensor inside tile loop
+ #
+ real_prob, _ = epi_ext.get_gmem_tensor("prob", prob, padded_offsets, epi_work_tile_info)
+ mPosition = (
+ (epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape)) * self.mma_tiler[0]
+ + mma_tile_coord_v * (self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape))
+ + tidx
+ )
+ mProb = real_prob[mPosition, 0, 0]
+
+ #
+ # Get accumulator stage index
+ #
+ if cutlass.const_expr(self.overlapping_accum):
+ acc_stage_index = acc_consumer_state.phase
+ reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False)
+ else:
+ acc_stage_index = acc_consumer_state.index
+
+ # Set tensor memory buffer for current tile
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
+ tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
+
+ #
+ # Wait for accumulator buffer full
+ #
+ acc_pipeline.consumer_wait(acc_consumer_state)
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
+
+ #
+ # Store accumulator to global memory in subtiles
+ #
+ subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
+ for subtile_idx in cutlass.range(0, subtile_cnt, 2, unroll=1):
+ real_subtile_idx = subtile_idx // 2
+ if cutlass.const_expr(self.overlapping_accum):
+ if reverse_subtile:
+ real_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx // 2
+
+ #
+ # Load accumulator from tensor memory buffer to register
+ #
+ tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, real_subtile_idx * 2)]
+ tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, real_subtile_idx * 2 + 1)]
+
+ cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate)
+ cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up)
+
+                    #
+                    # Signal the accumulator buffer empty early (async arrive) when overlapping_accum is enabled
+                    #
+ if cutlass.const_expr(self.overlapping_accum):
+ if subtile_idx == self.iter_acc_early_release_in_epilogue:
+ cute.arch.fence_view_async_tmem_load()
+ with cute.arch.elect_one():
+ acc_pipeline.consumer_release(acc_consumer_state)
+ acc_consumer_state.advance()
+
+ #
+ # Notify pingpong consumer for subtile > 0
+ #
+ if subtile_idx != 0:
+ pingpong_pipeline.producer_commit(pingpong_act_producer_state)
+ pingpong_act_producer_state.advance()
+
+ #
+ # Apply alpha (+ bias when enabled)
+ #
+ if cutlass.const_expr(self.enable_bias):
+ sBias_pair = sBias_subtiles[(None, real_subtile_idx)]
+ sBias_sub = cute.flat_divide(sBias_pair, cute.make_layout(self.epi_tile[1]))
+ cute.copy(bias_s2r_atom, sBias_sub[(None, 0)], tTR_rBias_gate)
+ bias_vec_gate = tTR_rBias_gate.load()
+ cute.copy(bias_s2r_atom, sBias_sub[(None, 1)], tTR_rBias_up)
+ bias_vec_up = tTR_rBias_up.load()
+
+ if cutlass.const_expr(self.vectorized_f32):
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_gate), 2):
+ bias_gate_f32_0 = bias_vec_gate[i].to(cutlass.Float32)
+ bias_gate_f32_1 = bias_vec_gate[i + 1].to(cutlass.Float32)
+ bias_up_f32_0 = bias_vec_up[i].to(cutlass.Float32)
+ bias_up_f32_1 = bias_vec_up[i + 1].to(cutlass.Float32)
+ tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1] = cute.arch.fma_packed_f32x2(
+ (tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1]),
+ (
+ cutlass.Float32(alpha_val),
+ cutlass.Float32(alpha_val),
+ ),
+ (bias_gate_f32_0, bias_gate_f32_1),
+ rnd="rn",
+ ftz=False,
+ )
+ tTR_rAcc_up[i], tTR_rAcc_up[i + 1] = cute.arch.fma_packed_f32x2(
+ (tTR_rAcc_up[i], tTR_rAcc_up[i + 1]),
+ (
+ cutlass.Float32(alpha_val),
+ cutlass.Float32(alpha_val),
+ ),
+ (bias_up_f32_0, bias_up_f32_1),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc_gate)):
+ tTR_rAcc_gate[i] = tTR_rAcc_gate[i] * cutlass.Float32(alpha_val) + bias_vec_gate[i].to(cutlass.Float32)
+ tTR_rAcc_up[i] = tTR_rAcc_up[i] * cutlass.Float32(alpha_val) + bias_vec_up[i].to(cutlass.Float32)
+
+ if subtile_idx == subtile_cnt - 2:
+ bias_pipeline.consumer_release(bias_consumer_state)
+ bias_consumer_state.advance()
+ else:
+ if cutlass.const_expr(self.vectorized_f32):
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_gate), 2):
+ tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1] = cute.arch.mul_packed_f32x2(
+ (tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1]),
+ (
+ cutlass.Float32(alpha_val),
+ cutlass.Float32(alpha_val),
+ ),
+ rnd="rn",
+ ftz=False,
+ )
+ tTR_rAcc_up[i], tTR_rAcc_up[i + 1] = cute.arch.mul_packed_f32x2(
+ (tTR_rAcc_up[i], tTR_rAcc_up[i + 1]),
+ (
+ cutlass.Float32(alpha_val),
+ cutlass.Float32(alpha_val),
+ ),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc_gate)):
+ tTR_rAcc_gate[i] = tTR_rAcc_gate[i] * cutlass.Float32(alpha_val)
+ tTR_rAcc_up[i] = tTR_rAcc_up[i] * cutlass.Float32(alpha_val)
+
+ #
+ # Store gate+up to C tensor (pre-activation for residual)
+ #
+ self.store_c(
+ tiled_copy_r2s,
+ tma_atom_c,
+ warp_idx,
+ tTR_rAcc_gate,
+ tTR_rAcc_up,
+ tRS_rC,
+ tRS_sC,
+ bSG_gC,
+ bSG_sC,
+ c_pipeline,
+ num_prev_subtiles,
+ real_subtile_idx,
+ )
+ num_prev_subtiles = num_prev_subtiles + 1
+
+ #
+                    # GeGLU clamp on gate/up before the activation
+ #
+ if cutlass.const_expr(self.act_func == "geglu"):
+ geglu_max_val = cutlass.Float32(7.0)
+ geglu_min_val = cutlass.Float32(-7.0)
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)):
+ tTR_rAcc_gate[i] = fmin(tTR_rAcc_gate[i], geglu_max_val)
+ tTR_rAcc_up[i] = fmin(tTR_rAcc_up[i], geglu_max_val)
+ tTR_rAcc_up[i] = fmax(tTR_rAcc_up[i], geglu_min_val)
+
+ acc_vec_gate = tTR_rAcc_gate.load()
+ acc_vec_up = tTR_rAcc_up.load()
+
+ #
+ # Compute GLU activation (SwiGLU or GeGLU)
+ #
+ tCompute = cute.make_rmem_tensor(acc_vec_gate.shape, self.acc_dtype)
+ if cutlass.const_expr(self.act_func == "geglu"):
+ self.geglu_act(tCompute, acc_vec_up, acc_vec_gate, mProb, linear_offset)
+ elif cutlass.const_expr(self.act_func == "swiglu"):
+ self.swiglu_act(tCompute, acc_vec_up, acc_vec_gate, mProb)
+
+ #
+ # Convert to BF16 and write to TMEM for Hadamard
+ #
+ tCompute_hadamard = cute.make_rmem_tensor(tCompute.layout, cutlass.BFloat16)
+ tCompute_hadamard.store(tCompute.load().to(tCompute_hadamard.element_type))
+
+ #
+ # Pingpong producer acquire for current subtile
+ #
+ pingpong_pipeline.producer_acquire(pingpong_act_producer_state)
+ hadamard_in(
+ tCompute_hadamard,
+ HADAMARD_SIZE,
+ self.query_hadamard_tmem_a_ptr(subtile_idx, reverse_subtile, tmem_ptr),
+ epi_tidx,
+ )
+
+ #
+                # Delayed TMA store acquire + group sync (when delay_tma_store_acquire_sync is enabled)
+ #
+ if cutlass.const_expr(self.delay_tma_store_acquire_sync):
+ if warp_idx == self.epilog_act_warp_id[0]:
+ c_pipeline.producer_acquire()
+ self.epilog_sync_barrier_group0.arrive_and_wait()
+
+ #
+ # Pingpong producer commit last subtile
+ #
+ pingpong_pipeline.producer_commit(pingpong_act_producer_state)
+ pingpong_act_producer_state.advance()
+
+ #
+ # Full epilogue barrier (ACT + RHT must both arrive)
+ #
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ #
+                # Signal the accumulator buffer empty (async arrive)
+ #
+ if cutlass.const_expr(not self.overlapping_accum):
+ with cute.arch.elect_one():
+ acc_pipeline.consumer_release(acc_consumer_state)
+ acc_consumer_state.advance()
+
+ #
+ # Advance to next tile
+ #
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ #
+ # Dealloc the tensor memory buffer
+ #
+ tmem.relinquish_alloc_permit()
+ self.epilog_sync_barrier_group0.arrive_and_wait()
+ tmem.free(tmem_ptr)
+ #
+ # Wait for C store / pingpong complete
+ #
+ c_pipeline.producer_tail()
+ pingpong_pipeline.producer_tail(pingpong_act_producer_state)
+
+ # ---------------------------------------------------------------
+ # Specialized RHT store warps (4-7): hadamard_compute + D store
+ # ---------------------------------------------------------------
+ if warp_idx < self.mma_warp_id and warp_idx >= self.epilog_rht_store_warp_id[0] and total_token > 0:
+ epi_tidx = tidx % 128
+
+ #
+ # Alloc tensor memory buffer
+ #
+ tmem.allocate(self.num_tmem_alloc_cols)
+
+ #
+ # Bar sync for retrieve tensor memory ptr from shared memory
+ #
+ tmem.wait_for_alloc()
+
+ #
+            # Retrieve the tensor memory ptr
+ #
+ tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+
+ #
+ # Hadamard setup (loads Hadamard matrix from GMEM to SMEM)
+ #
+ tiled_hmma = hadamard_setup(hadamard_tensor_local, sHadamard, epi_tidx)
+
+ #
+ # Partition for epilogue (shape-only via mD_mnl)
+ #
+ tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
+ thr_mma_epi = tiled_mma.get_slice(mma_tile_coord_v)
+ gD_mnl_shape = cute.local_tile(mD_mnl, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None))
+ tCgD_shape = thr_mma_epi.partition_C(gD_mnl_shape)
+
+ (
+ tiled_copy_t2r,
+ tTR_tAcc_base,
+ tTR_rAcc_gate,
+ tTR_rAcc_up,
+ ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgD_shape, epi_tile, use_2cta_instrs)
+
+ #
+ # Pingpong consumer state
+ #
+ pingpong_rht_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_pingpong_stage)
+
+ #
+ # D register and smem partition
+ #
+ tTR_rD = cute.make_rmem_tensor(tTR_rAcc_gate.shape, self.d_dtype)
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rD, epi_tidx, sD)
+
+ #
+ # Create per-expert extension (for D tensor inside tile loop)
+ #
+ epi_ext = self._make_extension(workspace_ptr)
+
+ # Threads/warps participating in TMA store pipeline for D
+ d_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ self.threads_per_warp * len(self.epilog_rht_store_warp_id),
+ )
+ d_pipeline = pipeline.PipelineTmaStore.create(
+ num_stages=self.num_d_stage,
+ producer_group=d_producer_group,
+ )
+
+ # Get the first tile info
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ #
+ # Accumulator stage index (for overlapping_accum)
+ #
+ if cutlass.const_expr(self.overlapping_accum):
+ acc_stage_index = 0
+ reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False)
+ else:
+ acc_stage_index = 0
+
+ num_prev_subtiles = cutlass.Int32(0)
+ while is_valid_tile:
+ # sInfo format: (expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt)
+ epi_work_tile_info = MoEWorkTileInfo(
+ expert_idx=tile_info[0],
+ tile_m_idx=tile_info[1],
+ tile_n_idx=tile_info[2],
+ k_tile_cnt=tile_info[3],
+ )
+ mma_tile_coord_mnl = (
+ epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape),
+ epi_work_tile_info.tile_n_idx,
+ cutlass.Int32(0),
+ )
+ expert_idx = epi_work_tile_info.expert_idx
+ epi_ext.update_expert_info(padded_offsets, epi_work_tile_info.expert_idx)
+
+ if cutlass.const_expr(self.generate_amax):
+ thread_tile_amax = cutlass.Float32(0.0)
+
+ #
+ # Get per-expert D tensor inside tile loop
+ #
+ real_d, _ = epi_ext.get_gmem_tensor("d", mD_mnl, padded_offsets, epi_work_tile_info)
+ gD_mnl_loop = cute.local_tile(real_d, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None))
+ thr_mma_epi_loop = tiled_mma.get_slice(mma_tile_coord_v)
+ tCgD_loop = thr_mma_epi_loop.partition_C(gD_mnl_loop)
+ _, bSG_sD, bSG_gD_partitioned = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_d, tCgD_loop, epi_tile, sD)
+ bSG_gD = bSG_gD_partitioned[(None, None, None, *mma_tile_coord_mnl)]
+ bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
+
+ #
+ # Set tensor memory buffer for current tile
+ # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, STAGE)
+ #
+ tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
+
+ #
+ # Store accumulator to global memory in subtiles
+ #
+ subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
+ for subtile_idx in cutlass.range(0, subtile_cnt, 2, unroll=1):
+ real_subtile_idx = subtile_idx // 2
+ if cutlass.const_expr(self.overlapping_accum):
+ if reverse_subtile:
+ real_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx // 2
+
+ #
+ # Get Hadamard TMEM pointers for this subtile
+ #
+ hadamard_tmem_a_ptr = self.query_hadamard_tmem_a_ptr(subtile_idx, reverse_subtile, tmem_ptr)
+ hadamard_tmem_acc_ptr = self.query_hadamard_tmem_acc_ptr(subtile_idx, reverse_subtile, tmem_ptr)
+ tCompute = cute.make_rmem_tensor(tTR_rAcc_gate.layout, cutlass.Float32)
+
+ #
+ # Cross-CTA sync (prerequisite for Hadamard: both CTAs must have
+ # written hadamard_in before either starts hadamard_compute)
+ #
+ hadamard_prerequisite_empty = hadamard_prerequisite_producer.acquire_and_advance()
+ hadamard_prerequisite_empty.commit()
+ hadamard_prerequisite_empty = hadamard_prerequisite_consumer.wait_and_advance()
+ hadamard_prerequisite_consumer.release(hadamard_prerequisite_empty)
+
+ #
+ # Wait for pingpong producer (ACT warp) ready
+ #
+ pingpong_pipeline.consumer_wait(pingpong_rht_consumer_state)
+
+ #
+ # Apply Hadamard transform (reads from TMEM a_ptr, writes to TMEM acc_ptr)
+ #
+ hadamard_compute(
+ tiled_hmma,
+ hadamard_tmem_a_ptr,
+ hadamard_tmem_acc_ptr,
+ storage.sHadamard,
+ epi_tile,
+ epi_tidx,
+ hadamard_producer,
+ )
+ hadamard_consumer.wait_and_advance()
+
+ #
+ # Get Hadamard result from TMEM to registers
+ #
+ hadamard_out(tCompute, epi_tile[1], hadamard_tmem_acc_ptr, epi_tidx)
+
+ #
+ # Release pingpong consumer slot
+ #
+ pingpong_pipeline.consumer_release(pingpong_rht_consumer_state)
+ pingpong_rht_consumer_state.advance()
+
+ #
+ # Amax accumulation per subtile
+ #
+ if cutlass.const_expr(self.generate_amax):
+ thread_tile_amax = self.amax_reduction_per_thread(tCompute, thread_tile_amax)
+
+ #
+ # Convert to D dtype and store D to shared memory
+ #
+ acc_vec = tiled_copy_r2s.retile(tCompute).load()
+ tRS_rD.store(acc_vec.to(self.d_dtype))
+
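+                # Rotate through the num_d_stage shared-memory D buffers across subtiles.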
+ d_buffer = num_prev_subtiles % self.num_d_stage
+ num_prev_subtiles = num_prev_subtiles + 1
+ cute.copy(
+ tiled_copy_r2s,
+ tRS_rD,
+ tRS_sD[(None, None, None, d_buffer)],
+ )
+ # Fence and barrier to make sure shared memory store is visible to TMA
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.epilog_sync_barrier_group1.arrive_and_wait()
+ #
+ # TMA store D to global memory
+ #
+ if warp_idx == self.epilog_rht_store_warp_id[0]:
+ cute.copy(
+ tma_atom_d,
+ bSG_sD[(None, d_buffer)],
+ bSG_gD[(None, real_subtile_idx)],
+ )
+ d_pipeline.producer_commit()
+ d_pipeline.producer_acquire()
+ self.epilog_sync_barrier_group1.arrive_and_wait()
+
+ #
+ # Full epilogue barrier (ACT + RHT must both arrive)
+ #
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ #
+ # Update overlapping_accum stage index for next tile
+ #
+ if cutlass.const_expr(self.overlapping_accum):
+ acc_stage_index = acc_stage_index ^ 1
+ reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False)
+ else:
+ acc_stage_index = 0
+
+ #
+ # Advance to next tile
+ #
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ #
+ # Amax reduction per tile (across warps and CTAs)
+ #
+ if cutlass.const_expr(self.generate_amax):
+ gAmax = mAmax_tensor[(expert_idx, None)].iterator.llvm_ptr # First element
+ self.amax_reduction_per_warp_and_cta(thread_tile_amax, warp_idx, sAmax, gAmax)
+
+ #
+ # Dealloc the tensor memory buffer
+ #
+ tmem.relinquish_alloc_permit()
+ self.epilog_sync_barrier_group1.arrive_and_wait()
+ tmem.free(tmem_ptr)
+ #
+ # Wait for D store complete
+ #
+ d_pipeline.producer_tail()
+
+ # END OF KERNEL
+
+
+class BlockScaledMoEGroupedGemmGluHadamardCompatKernel(BlockScaledMoEGroupedGemmGluHadamardKernel):
+ """Compatibility adapter for the existing grouped GLU frontend wrapper shape."""
+
+ def __init__(
+ self,
+ sf_vec_size: int,
+ acc_dtype: Type[cutlass.Numeric],
+ use_2cta_instrs: bool,
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ vectorized_f32: bool,
+ generate_sfd: bool,
+ discrete_col_sfd: bool,
+ expert_cnt: int,
+ weight_mode: MoEWeightMode = MoEWeightMode.DISCRETE,
+ use_dynamic_sched: bool = False,
+ act_func: str = "swiglu",
+ enable_bias: bool = False,
+ ):
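+        # generate_sfd / discrete_col_sfd are not used by the Hadamard kernel; they are
+        # accepted only so the existing grouped GLU frontend wrapper can construct this class.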
+ del generate_sfd, discrete_col_sfd
+ super().__init__(
+ sf_vec_size=sf_vec_size,
+ acc_dtype=acc_dtype,
+ use_2cta_instrs=use_2cta_instrs,
+ mma_tiler_mn=mma_tiler_mn,
+ cluster_shape_mn=cluster_shape_mn,
+ vectorized_f32=vectorized_f32,
+ expert_cnt=expert_cnt,
+ weight_mode=weight_mode,
+ use_dynamic_sched=use_dynamic_sched,
+ act_func=act_func,
+ enable_bias=enable_bias,
+ )
+
+ def __call__(
+ self,
+ a: cute.Tensor,
+ b,
+ sfb,
+ n: Int32,
+ k: Int32,
+ b_stride_size: cutlass.Int64,
+ b_major_mode: cutlass.Constexpr,
+ workspace_ptr,
+ c: cute.Tensor,
+ d: cute.Tensor,
+ d_col: cute.Tensor,
+ sfa: cute.Tensor,
+ sfd_row_tensor: Optional[cute.Tensor],
+ sfd_col_tensor: Optional[cute.Tensor],
+ amax_tensor: Optional[cute.Tensor],
+ norm_const_tensor: Optional[cute.Tensor],
+ padded_offsets: cute.Tensor,
+ alpha: cute.Tensor,
+ prob: cute.Tensor,
+ bias: Optional[cute.Tensor],
+ max_active_clusters: cutlass.Constexpr,
+ stream: cuda.CUstream,
+ epilogue_op: cutlass.Constexpr = lambda x: x,
+ linear_offset: cutlass.Float32 = 0.0,
+ ):
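+        # Outputs not consumed by the Hadamard kernel; accept and drop them so the
+        # existing wrapper call sites need no changes.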
+ del d_col, sfd_row_tensor, sfd_col_tensor, norm_const_tensor
+ return super().__call__(
+ a=a,
+ b=b,
+ sfa=sfa,
+ sfb=sfb,
+ n=n,
+ k=k,
+ b_stride_size=b_stride_size,
+ b_major_mode=b_major_mode,
+ workspace_ptr=workspace_ptr,
+ c=c,
+ d=d,
+ amax_tensor=amax_tensor,
+ padded_offsets=padded_offsets,
+ alpha=alpha,
+ prob=prob,
+ hadamard_tensor=None,
+ bias=bias,
+ max_active_clusters=max_active_clusters,
+ stream=stream,
+ epilogue_op=epilogue_op,
+ linear_offset=linear_offset,
+ )
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_quant/api.py b/python/cudnn/grouped_gemm/grouped_gemm_quant/api.py
index be30580f..f2db2cfa 100644
--- a/python/cudnn/grouped_gemm/grouped_gemm_quant/api.py
+++ b/python/cudnn/grouped_gemm/grouped_gemm_quant/api.py
@@ -75,7 +75,7 @@ def __init__(
sample_padded_offsets: torch.Tensor,
sample_alpha: torch.Tensor,
sample_d: torch.Tensor,
- sample_d_col: torch.Tensor,
+ sample_d_col: Optional[torch.Tensor] = None,
# Dense mode (contiguous) -- provide these:
sample_b: Optional[torch.Tensor] = None,
sample_sfb: Optional[torch.Tensor] = None,
@@ -90,8 +90,6 @@ def __init__(
sample_amax: Optional[torch.Tensor] = None,
sample_norm_const: Optional[torch.Tensor] = None,
sample_prob: Optional[torch.Tensor] = None,
- # Internal: C tensor placeholder (kernel compilation requires it)
- sample_c: Optional[torch.Tensor] = None,
# Configuration
acc_dtype: torch.dtype = torch.float32,
mma_tiler_mn: Tuple[int, int] = (256, 256),
@@ -110,7 +108,7 @@ def __init__(
:param sample_padded_offsets: End offset for each expert after padding, shape (expert_cnt,)
:param sample_alpha: Per-group alpha scaling factors
:param sample_d: Sample D output tensor (valid_m, n, 1)
- :param sample_d_col: Column-quantized D tensor (required for quant kernel)
+ :param sample_d_col: Optional column-quantized D tensor. Required only when SFD outputs are generated.
:param sample_b: (Dense) Sample B tensor (n, k, l)
:param sample_sfb: (Dense) Sample scale factor B tensor
:param sample_bias: Optional bias tensor with shape (n, l) or (n, expert_cnt), stride (1, n).
@@ -123,7 +121,6 @@ def __init__(
:param sample_amax: Optional amax tensor for quantization
:param sample_norm_const: Optional normalization constant
:param sample_prob: Optional probability tensor for gating
- :param sample_c: Internal C tensor placeholder (kernel requires it for dtype inference)
:param acc_dtype: Accumulator data type
:param mma_tiler_mn: MMA tiler shape (M, N)
:param cluster_shape_mn: Cluster shape (M, N)
@@ -136,7 +133,7 @@ def __init__(
"""
super().__init__()
- self._logger.warning("GroupedGemmQuantSm100 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
# ---- Weight mode auto-detection ----
@@ -157,7 +154,17 @@ def __init__(
self.padded_offsets_desc = self._make_tensor_desc(sample_padded_offsets, name="sample_padded_offsets")
self.alpha_desc = self._make_tensor_desc(sample_alpha, name="sample_alpha")
+ self._has_d_col = sample_d_col is not None
self.d_col_desc = self._make_tensor_desc(sample_d_col, name="sample_d_col")
+ if self.d_col_desc is None:
+ self.d_col_desc = TensorDesc(
+ dtype=self.d_desc.dtype,
+ shape=self.d_desc.shape,
+ stride=self.d_desc.stride,
+ stride_order=self.d_desc.stride_order,
+ device=self.d_desc.device,
+ name="sample_d_col",
+ )
self.sfd_row_desc = self._make_tensor_desc(sample_sfd_row, name="sample_sfd_row")
self.sfd_col_desc = self._make_tensor_desc(sample_sfd_col, name="sample_sfd_col")
self.amax_desc = self._make_tensor_desc(sample_amax, name="sample_amax")
@@ -169,20 +176,6 @@ def __init__(
self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob")
self.bias_desc = self._make_tensor_desc(sample_bias, name="sample_bias")
- # C tensor: required by kernel for dtype inference but never written to (generate_c=False).
- # If not provided, derive from D descriptor with bfloat16 dtype.
- if sample_c is not None:
- self.c_desc = self._make_tensor_desc(sample_c, name="sample_c")
- else:
- self.c_desc = TensorDesc(
- dtype=torch.bfloat16,
- shape=self.d_desc.shape,
- stride=self.d_desc.stride,
- stride_order=self.d_desc.stride_order,
- device=self.d_desc.device,
- name="sample_c",
- )
-
if self.weight_mode == MoEWeightMode.DENSE:
self.b_desc = self._make_tensor_desc(sample_b, name="sample_b")
self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb")
@@ -218,8 +211,9 @@ def __init__(
self._kernel = BlockScaledMoEGroupedGemmQuantKernel
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
- print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
self._workspace = None
+ self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
self._logger.debug("__init__ completed")
def check_support(self) -> bool:
@@ -236,6 +230,10 @@ def check_support(self) -> bool:
"sfd_row_desc, sfd_col_desc, and norm_const_desc must be all None or all not None",
)
self.generate_sfd = all_provided
+ self._value_error_if(
+ self.generate_sfd and not self._has_d_col,
+ "sample_d_col is required when SFD outputs are generated",
+ )
if self.discrete_col_sfd and not self.generate_sfd:
self._logger.warning("discrete_col_sfd is True but generate_sfd is False, discrete_col_sfd will be ignored")
self.discrete_col_sfd = False
@@ -253,13 +251,11 @@ def check_support(self) -> bool:
self._value_error_if(b_k != k, f"B K dimension ({b_k}) must match A K dimension ({k})")
l = self.expert_cnt
- _, _, _one = self._tensor_shape(self.c_desc, name="sample_c")
_, _, _one = self._tensor_shape(self.d_desc, name="sample_d")
self._check_tensor_shape(self.a_desc, (tensor_m, k, 1), "A")
if self.weight_mode == MoEWeightMode.DENSE:
self._check_tensor_shape(self.b_desc, (n, k, l), "B")
- self._check_tensor_shape(self.c_desc, (tensor_m, n, 1), "C")
self._check_tensor_shape(self.d_desc, (tensor_m, n, 1), "D")
self._check_tensor_shape(self.d_col_desc, (tensor_m, n, 1), "D_col")
@@ -302,11 +298,6 @@ def check_support(self) -> bool:
stride=[(k, 1, n * k)],
extra_error_msg="For fp4 ab_dtype, B must have k-major layout",
)
- _ = self._check_tensor_stride(
- self.c_desc,
- stride=[(n, 1, tensor_m * n)],
- extra_error_msg="C must have n-major layout",
- )
_ = self._check_tensor_stride(
self.d_desc,
stride=[(n, 1, tensor_m * n)],
@@ -402,19 +393,6 @@ def check_support(self) -> bool:
name="Accumulator",
extra_error_msg="Accumulator must be float32",
)
- self.c_dtype = self._check_dtype(
- self.c_desc,
- dtype=[
- torch.float32,
- torch.float16,
- torch.bfloat16,
- torch.float8_e4m3fn,
- torch.float8_e5m2,
- torch.float4_e2m1fn_x2,
- ],
- name="C",
- )
-
if self._is_fp4x2(self.ab_dtype):
self.d_dtype = self._check_dtype(
self.d_desc,
@@ -536,11 +514,6 @@ def check_contigous_16B_alignment(dtype, stride_order, tensor_shape):
"Invalid configuration: fp8 ab_dtype and sf_vec_size 32 with mma_tiler_mn[1] == 128 and fp8 d_dtype is not supported. "
"Please use mma_tiler_mn[1] == 256 instead",
)
- self._not_implemented_error_if(
- self._is_fp4x2(self.ab_dtype) and (self.c_dtype not in [torch.float16, torch.bfloat16]),
- f"Invalid configuration: for fp4 ab_dtype, c_dtype must be float16 or bfloat16, got {self.c_dtype}",
- )
-
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
device = torch.cuda.current_device()
@@ -564,6 +537,8 @@ def compile(self) -> None:
self._logger.debug("sample valid_m is zero, skipping kernel compilation")
return
+ self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
+
gemm_quant = self._kernel(
sf_vec_size=self.sf_vec_size,
acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype),
@@ -573,7 +548,6 @@ def compile(self) -> None:
vectorized_f32=self.vector_f32,
generate_sfd=self.generate_sfd,
discrete_col_sfd=self.discrete_col_sfd,
- generate_c=False,
enable_bias=self._has_bias,
expert_cnt=self.expert_cnt,
weight_mode=self.weight_mode,
@@ -607,7 +581,7 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None:
)
self._logger.debug("Compiling grouped_gemm_quant kernel")
- use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
+ use_full_dynamic = self._use_full_dynamic_mnkl
if not use_full_dynamic:
valid_m = cute.sym_int(divisibility=256)
@@ -618,11 +592,6 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None:
stride_order=self.a_desc.stride_order,
)
b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16)
- c_cute_fake = self._make_fake_cute_compact_tensor(
- dtype=self.c_desc.dtype,
- shape=(valid_m, *self.c_desc.shape[1:]),
- stride_order=self.c_desc.stride_order,
- )
d_cute_fake = self._make_fake_cute_compact_tensor(
dtype=self.d_desc.dtype,
shape=(valid_m, *self.d_desc.shape[1:]),
@@ -677,7 +646,8 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None:
bias_cute_fake = self._make_fake_cute_tensor_from_desc(self.bias_desc, assumed_align=16)
else:
valid_m = cute.sym_int(divisibility=256)
- n_sym = cute.sym_int()
+ n_sym_divisibility = 128 // _convert_to_cutlass_data_type(self.bias_desc.dtype).width if self.bias_desc is not None else 1
+ n_sym = cute.sym_int(divisibility=n_sym_divisibility)
k_sym = cute.sym_int()
l_sym = cute.sym_int()
@@ -695,13 +665,6 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None:
dynamic_mode=self.b_desc.stride_order[0],
divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
)
- c_cute_fake = self._make_fake_cute_compact_tensor(
- dtype=self.c_desc.dtype,
- shape=(valid_m, n_sym, 1),
- stride_order=self.c_desc.stride_order,
- dynamic_mode=self.c_desc.stride_order[0],
- divisibility=8 if self._is_f16(self.c_desc.dtype) else 16,
- )
d_cute_fake = self._make_fake_cute_compact_tensor(
dtype=self.d_desc.dtype,
shape=(valid_m, n_sym, 1),
@@ -790,7 +753,6 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None:
b_stride_size=cutlass.Int64(0),
b_major_mode=OperandMajorMode.K,
workspace_ptr=fake_workspace_ptr,
- c=c_cute_fake,
d=d_cute_fake,
d_col=d_col_cute_fake,
sfa=sfa_cute_fake,
@@ -812,7 +774,6 @@ def _compile_dense(self, gemm_quant, max_active_clusters, fake_stream) -> None:
def tensor_api(
a_tensor: torch.Tensor,
b_tensor: torch.Tensor,
- c_tensor: torch.Tensor,
d_tensor: torch.Tensor,
d_col_tensor: Optional[torch.Tensor],
sfa_tensor: torch.Tensor,
@@ -836,7 +797,6 @@ def tensor_api(
cutlass.Int32(0),
cutlass.Int64(0),
cached_workspace_ptr,
- c_tensor,
d_tensor,
d_col_tensor,
sfa_tensor,
@@ -873,11 +833,6 @@ def _compile_discrete(self, gemm_quant, max_active_clusters, fake_stream) -> Non
stride_order=self.a_desc.stride_order,
assumed_align=align,
)
- c_tensor = self._make_fake_cute_compact_tensor(
- dtype=self.c_desc.dtype,
- shape=(valid_m, *self.c_desc.shape[1:]),
- stride_order=self.c_desc.stride_order,
- )
d_tensor = self._make_fake_cute_compact_tensor(
dtype=self.d_desc.dtype,
shape=(valid_m, *self.d_desc.shape[1:]),
@@ -942,29 +897,28 @@ def _compile_discrete(self, gemm_quant, max_active_clusters, fake_stream) -> Non
self._logger.debug("Compiling discrete grouped_gemm_quant kernel")
_compiled_kernel = cute.compile(
gemm_quant,
- a_tensor,
- b_ptrs_cute,
- sfb_ptrs_cute,
- cutlass.Int32(n),
- cutlass.Int32(k),
- cutlass.Int64(b_stride_size),
- b_major_mode,
- workspace_ptr_cute,
- c_tensor,
- d_tensor,
- d_col_tensor,
- sfa_tensor,
- sfd_row_tensor,
- sfd_col_tensor,
- amax_tensor,
- norm_const_tensor_cute,
- padded_offsets_tensor,
- alpha_tensor,
- bias_cute_fake,
- prob_tensor,
- max_active_clusters,
- fake_stream,
- lambda x: x,
+ a=a_tensor,
+ b=b_ptrs_cute,
+ sfb=sfb_ptrs_cute,
+ n=cutlass.Int32(n),
+ k=cutlass.Int32(k),
+ b_stride_size=cutlass.Int64(b_stride_size),
+ b_major_mode=b_major_mode,
+ workspace_ptr=workspace_ptr_cute,
+ d=d_tensor,
+ d_col=d_col_tensor,
+ sfa=sfa_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor_cute,
+ padded_offsets=padded_offsets_tensor,
+ alpha=alpha_tensor,
+ bias=bias_cute_fake,
+ prob=prob_tensor,
+ max_active_clusters=max_active_clusters,
+ stream=fake_stream,
+ epilogue_op=lambda x: x,
options="--enable-tvm-ffi",
)
@@ -977,7 +931,6 @@ def tensor_api(
a_tensor: torch.Tensor,
b_ptrs_device: torch.Tensor,
sfb_ptrs_device: torch.Tensor,
- c_tensor: torch.Tensor,
d_tensor: torch.Tensor,
d_col_tensor: Optional[torch.Tensor],
sfa_tensor: torch.Tensor,
@@ -1002,7 +955,6 @@ def tensor_api(
cached_k,
cached_b_stride,
cached_workspace_ptr,
- c_tensor,
d_tensor,
d_col_tensor,
sfa_tensor,
@@ -1033,7 +985,6 @@ def execute(
# Discrete mode:
b_ptrs: Optional[torch.Tensor] = None,
sfb_ptrs: Optional[torch.Tensor] = None,
- c_tensor: Optional[torch.Tensor] = None,
d_col_tensor: Optional[torch.Tensor] = None,
sfd_row_tensor: Optional[torch.Tensor] = None,
sfd_col_tensor: Optional[torch.Tensor] = None,
@@ -1056,8 +1007,6 @@ def execute(
at construction, ``bias_tensor`` must also be omitted at execute time.
:param b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers
:param sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers
- :param c_tensor: Optional C tensor placeholder (kernel requires it but never writes to it;
- a minimal dummy is created automatically if not provided)
:param d_col_tensor: Optional column-quantized output
:param sfd_row_tensor: Optional row scale factor D
:param sfd_col_tensor: Optional column scale factor D
@@ -1077,13 +1026,12 @@ def execute(
"Kernel not compiled; call compile() first",
)
- if c_tensor is None:
- c_tensor = torch.empty_strided(
- self.c_desc.shape,
- self.c_desc.stride,
- dtype=self.c_desc.dtype,
- device=d_tensor.device,
+ if d_col_tensor is None:
+ self._value_error_if(
+ self.generate_sfd,
+ "d_col_tensor is required when SFD outputs are generated",
)
+ d_col_tensor = d_tensor
self._value_error_if(
prob_tensor is None,
"prob_tensor is required: the kernel unconditionally multiplies output by per-row gating probability. "
@@ -1105,7 +1053,6 @@ def execute(
self._compiled_kernel(
a_tensor=a_tensor,
b_tensor=b_tensor,
- c_tensor=c_tensor,
d_tensor=d_tensor,
d_col_tensor=d_col_tensor,
sfa_tensor=sfa_tensor,
@@ -1125,7 +1072,6 @@ def execute(
a_tensor=a_tensor,
b_ptrs_device=b_ptrs,
sfb_ptrs_device=sfb_ptrs,
- c_tensor=c_tensor,
d_tensor=d_tensor,
d_col_tensor=d_col_tensor,
sfa_tensor=sfa_tensor,
@@ -1165,7 +1111,6 @@ def grouped_gemm_quant_wrapper_sm100(
norm_const_tensor: Optional[torch.Tensor] = None,
prob_tensor: Optional[torch.Tensor] = None,
acc_dtype: torch.dtype = torch.float32,
- c_dtype: torch.dtype = torch.bfloat16,
d_dtype: torch.dtype = torch.bfloat16,
cd_major: str = "n",
mma_tiler_mn: Tuple[int, int] = (256, 256),
@@ -1202,7 +1147,6 @@ def grouped_gemm_quant_wrapper_sm100(
prob_tensor: Probability tensor for per-row gating (shape `(valid_m, 1, 1)`).
This argument is required. Pass a tensor of ones when no gating is needed.
acc_dtype: Accumulator data type
- c_dtype: Internal C tensor data type (not user-visible)
d_dtype: Output D tensor data type
cd_major: CD major dimension (only "n"-major layout is supported)
mma_tiler_mn: MMA tiler shape
@@ -1217,7 +1161,7 @@ def grouped_gemm_quant_wrapper_sm100(
TupleDict: A dictionary-like object containing output tensors that can also be unpacked as a tuple.
Dictionary keys (also the unpacking order):
- **d_tensor** (torch.Tensor): Final output tensor
- - **d_col_tensor** (torch.Tensor): Column-wise output tensor
+ - **d_col_tensor** (torch.Tensor or None): Column-wise output tensor for low-precision D output
- **amax_tensor** (torch.Tensor or None): Absolute maximum values (for quantization)
- **sfd_row_tensor** (torch.Tensor or None): Row-wise scale factors for D (FP8 only)
- **sfd_col_tensor** (torch.Tensor or None): Column-wise scale factors for D (FP8 only)
@@ -1264,19 +1208,6 @@ def grouped_gemm_quant_wrapper_sm100(
if bias_tensor is not None and tuple(bias_tensor.shape) != (n_out, num_experts):
raise ValueError(f"bias_tensor must have shape {(n_out, num_experts)}, got {tuple(bias_tensor.shape)}")
- _logger.debug("grouped_gemm_quant_wrapper_sm100: Creating output tensors d_tensor, d_col_tensor")
-
- if cd_major == "n":
- c_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=c_dtype, device=a_tensor.device)
- d_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device)
- d_col_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device)
- else:
- raise ValueError(f"cd_major must be 'n', got {cd_major}")
-
- sfd_row_tensor = None
- sfd_col_tensor = None
- amax_tensor = None
-
is_fp8_input_config = a_tensor.dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
@@ -1284,23 +1215,39 @@ def grouped_gemm_quant_wrapper_sm100(
torch.float8_e8m0fnu,
torch.float8_e4m3fn,
]
- is_fp8_output_config = d_dtype in [
+ is_low_precision_output_config = d_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float4_e2m1fn_x2,
]
- if is_fp8_input_config and is_fp8_output_config and norm_const_tensor is None:
+ _logger.debug("grouped_gemm_quant_wrapper_sm100: Creating output tensors")
+
+ if cd_major == "n":
+ d_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device)
+ d_col_tensor = (
+ torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device)
+ if is_low_precision_output_config
+ else None
+ )
+ else:
+ raise ValueError(f"cd_major must be 'n', got {cd_major}")
+
+ sfd_row_tensor = None
+ sfd_col_tensor = None
+ amax_tensor = None
+
+ if is_fp8_input_config and is_low_precision_output_config and norm_const_tensor is None:
raise ValueError(
"norm_const_tensor is required when FP8 inputs are used with FP8 output "
"(a_tensor is FP8 and sfa_tensor is FP8 and d_dtype is FP8). "
"Pass a tensor with shape (1,), e.g. torch.tensor([0.01], dtype=torch.float32, device=a_tensor.device)."
)
- if not is_fp8_output_config:
+ if not is_low_precision_output_config:
norm_const_tensor = None
- if is_fp8_input_config and is_fp8_output_config:
+ if is_fp8_input_config and is_low_precision_output_config:
_logger.debug("grouped_gemm_quant_wrapper_sm100: Detected fp8 a_dtype and sfa_dtype, constructing sfd_row_tensor and sfd_col_tensor")
sf_dtype = sfa_tensor.dtype
@@ -1369,7 +1316,7 @@ def dynamic_m_tensor_signature(
stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride()))
return static_shape_suffix, stride_signature, tensor.dtype
- use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
+ use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
if is_dense:
cache_key = (
@@ -1381,9 +1328,8 @@ def dynamic_m_tensor_signature(
b_tensor.dtype,
stride_order(a_tensor),
stride_order(b_tensor),
- c_tensor.shape[1:] if not use_full_dynamic else None,
- stride_order(c_tensor),
- c_tensor.dtype,
+ d_tensor.shape[1:] if not use_full_dynamic else None,
+ stride_order(d_tensor),
*(
dynamic_tensor_signature(sfa_tensor)
if use_full_dynamic
@@ -1398,7 +1344,6 @@ def dynamic_m_tensor_signature(
tuple(padded_offsets.stride()),
padded_offsets.dtype,
acc_dtype,
- c_dtype,
d_dtype,
cd_major,
mma_tiler_mn,
@@ -1417,9 +1362,8 @@ def dynamic_m_tensor_signature(
a_tensor.dtype,
b_shape,
b_dtype,
- c_tensor.shape[1:],
- stride_order(c_tensor),
- c_tensor.dtype,
+ d_tensor.shape[1:],
+ stride_order(d_tensor),
*dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,)),
*tensor_signature(bias_tensor),
*tensor_signature(alpha_tensor),
@@ -1435,7 +1379,6 @@ def dynamic_m_tensor_signature(
tuple(padded_offsets.stride()),
padded_offsets.dtype,
acc_dtype,
- c_dtype,
d_dtype,
cd_major,
mma_tiler_mn,
@@ -1470,7 +1413,6 @@ def dynamic_m_tensor_signature(
sample_sfd_col=sfd_col_tensor,
sample_norm_const=norm_const_tensor,
sample_prob=prob_tensor,
- sample_c=c_tensor,
acc_dtype=acc_dtype,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
@@ -1497,7 +1439,6 @@ def dynamic_m_tensor_signature(
sample_sfd_col=sfd_col_tensor,
sample_norm_const=norm_const_tensor,
sample_prob=prob_tensor,
- sample_c=c_tensor,
acc_dtype=acc_dtype,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
@@ -1522,7 +1463,6 @@ def dynamic_m_tensor_signature(
d_tensor=d_tensor,
b_tensor=b_tensor,
sfb_tensor=sfb_tensor,
- c_tensor=c_tensor,
d_col_tensor=d_col_tensor,
sfd_row_tensor=sfd_row_tensor,
sfd_col_tensor=sfd_col_tensor,
@@ -1541,7 +1481,6 @@ def dynamic_m_tensor_signature(
d_tensor=d_tensor,
b_ptrs=b_ptrs,
sfb_ptrs=sfb_ptrs,
- c_tensor=c_tensor,
d_col_tensor=d_col_tensor,
sfd_row_tensor=sfd_row_tensor,
sfd_col_tensor=sfd_col_tensor,
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py b/python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py
index 6619baa8..8773accf 100644
--- a/python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py
+++ b/python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py
@@ -9,7 +9,6 @@
- Dense (contiguous 3-D B) / Discrete (per-expert pointer array B) weight layout
- FP8/FP4 output quantization with row/column scale factors (SFD)
- Optional bias and routing-probability (prob) fusion
- - Optional C output (generate_c)
- AMAX reduction for FP8 calibration
This module contains only the kernel class.
@@ -73,7 +72,6 @@ class BlockScaledMoEGroupedGemmQuantKernel:
:param vectorized_f32: Use packed FP32 arithmetic.
:param generate_sfd: Generate output scale factors.
:param discrete_col_sfd: Use discrete column SFD layout.
- :param generate_c: Generate C output tensor.
:param enable_bias: Fuse bias addition.
:param expert_cnt: Number of experts.
:param weight_mode: ``MoEWeightMode.DENSE`` or ``MoEWeightMode.DISCRETE``.
@@ -131,7 +129,6 @@ def __init__(
vectorized_f32: bool,
generate_sfd: bool,
discrete_col_sfd: bool,
- generate_c: bool,
enable_bias: bool,
expert_cnt: int,
weight_mode: MoEWeightMode = MoEWeightMode.DENSE,
@@ -199,7 +196,6 @@ def __init__(
self.vectorized_f32 = vectorized_f32
self.generate_sfd = generate_sfd
self.discrete_col_sfd = discrete_col_sfd
- self.generate_c = generate_c
self.enable_bias = enable_bias
self.weight_mode = weight_mode
@@ -295,7 +291,6 @@ def _setup_attributes(self):
(
self.num_acc_stage,
self.num_ab_stage,
- self.num_c_stage,
self.num_d_stage,
self.num_tile_stage,
self.num_bias_stage,
@@ -305,8 +300,6 @@ def _setup_attributes(self):
self.a_dtype,
self.b_dtype,
self.epi_tile,
- self.c_dtype,
- self.c_layout,
self.d_dtype,
self.d_layout,
self.sf_dtype,
@@ -314,7 +307,6 @@ def _setup_attributes(self):
self.num_smem_capacity,
self.occupancy,
self.generate_sfd,
- self.generate_c,
self.bias_dtype if self.enable_bias else None,
)
@@ -342,12 +334,6 @@ def _setup_attributes(self):
self.sf_vec_size,
self.num_ab_stage,
)
- self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
- self.c_dtype,
- self.c_layout,
- self.epi_tile,
- self.num_c_stage,
- )
self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
self.d_dtype,
self.d_layout,
@@ -388,8 +374,6 @@ def _compute_stages(
a_dtype,
b_dtype,
epi_tile,
- c_dtype,
- c_layout,
d_dtype,
d_layout,
sf_dtype,
@@ -397,11 +381,9 @@ def _compute_stages(
num_smem_capacity,
occupancy,
generate_sfd,
- generate_c,
bias_dtype,
):
num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
- num_c_stage = 2 if generate_sfd else 1
num_d_stage = 2 if generate_sfd else 1
num_tile_stage = 2
@@ -409,7 +391,6 @@ def _compute_stages(
b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, b_dtype, 1)
sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(tiled_mma, mma_tiler_mnk, sf_vec_size, 1)
sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(tiled_mma, mma_tiler_mnk, sf_vec_size, 1)
- c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1)
d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
ab_bytes_per_stage = (
@@ -420,8 +401,6 @@ def _compute_stages(
)
mbar_helpers_bytes = 1024
sinfo_bytes = 4 * 4 * num_tile_stage
- c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
- c_bytes = c_bytes_per_stage * num_c_stage
d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
d_bytes = d_bytes_per_stage * num_d_stage * (2 if generate_sfd else 1)
amax_bytes = 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,))) if d_dtype == cutlass.BFloat16 else 0
@@ -434,10 +413,10 @@ def _compute_stages(
num_bias_stage = 0
bias_bytes = 0
- epi_bytes = c_bytes + d_bytes + amax_bytes + bias_bytes
+ epi_bytes = d_bytes + amax_bytes + bias_bytes
num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes + sinfo_bytes)) // ab_bytes_per_stage
- return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage, num_bias_stage
+ return num_acc_stage, num_ab_stage, num_d_stage, num_tile_stage, num_bias_stage
# ------------------------------------------------------------------
# Workspace helpers
@@ -575,7 +554,6 @@ def __call__(
b_stride_size: cutlass.Int64, # Ignored for dense mode
b_major_mode: cutlass.Constexpr, # Ignored for dense mode
workspace_ptr,
- c: cute.Tensor,
d: cute.Tensor,
d_col: Optional[cute.Tensor],
sfa: cute.Tensor,
@@ -600,11 +578,9 @@ def __call__(
"""
self.a_dtype: Type[cutlass.Numeric] = a.element_type
self.b_dtype: Type[cutlass.Numeric] = a.element_type
- self.c_dtype: Type[cutlass.Numeric] = c.element_type
self.d_dtype: Type[cutlass.Numeric] = d.element_type
self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type
self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
- self.c_layout = utils.LayoutEnum.from_tensor(c)
self.d_layout = utils.LayoutEnum.from_tensor(d)
self.bias_dtype = bias.element_type if cutlass.const_expr(self.enable_bias) else cutlass.BFloat16
@@ -754,13 +730,6 @@ def __call__(
sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size
- c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
- tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
- cpasync.CopyBulkTensorTileS2GOp(),
- c,
- c_smem_layout,
- self.epi_tile,
- )
d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
@@ -832,10 +801,6 @@ class SharedStorage:
bias_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_bias_stage * 2]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_holding_buf: cutlass.Int32
- sC: cute.struct.Align[
- cute.struct.MemRange[self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer)],
- self.buffer_align_bytes,
- ]
sD: cute.struct.Align[
cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)],
self.buffer_align_bytes,
@@ -884,8 +849,6 @@ class SharedStorage:
tma_tensor_sfa,
tma_atom_sfb,
tma_tensor_sfb,
- tma_atom_c,
- tma_tensor_c,
tma_atom_d,
tma_tensor_d,
tma_atom_d_col,
@@ -905,7 +868,6 @@ class SharedStorage:
self.b_smem_layout_staged,
self.sfa_smem_layout_staged,
self.sfb_smem_layout_staged,
- self.c_smem_layout_staged,
self.d_smem_layout_staged,
self.bias_smem_layout_staged,
self.epi_tile,
@@ -955,32 +917,6 @@ def amax_reduction_per_warp_and_cta(self, amax_fp32, warp_idx, amax_smem, amax_g
block_amax = cute.arch.fmax(block_amax, warp_amax_val)
_ = atomic_max_float32(ptr=amax_gmem, value=block_amax)
- @cute.jit
- def store_c(
- self,
- tiled_copy_r2s,
- tma_atom_c,
- warp_idx,
- tTR_rAcc,
- tRS_rC,
- tRS_sC,
- bSG_gC,
- bSG_sC,
- c_pipeline,
- prev_subtile_idx,
- real_subtile_idx,
- ):
- c_buffer = prev_subtile_idx % self.num_c_stage
- tRS_rC.store(tTR_rAcc.load().to(self.c_dtype))
- cute.copy(tiled_copy_r2s, tRS_rC[(None, None, 0)], tRS_sC[(None, None, 0, c_buffer)])
- cute.arch.fence_proxy("async.shared", space="cta")
- self.epilog_sync_barrier.arrive_and_wait()
- if warp_idx == self.epilog_warp_id[0]:
- cute.copy(tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, real_subtile_idx)])
- c_pipeline.producer_commit()
- c_pipeline.producer_acquire()
- self.epilog_sync_barrier.arrive_and_wait()
-
@cute.jit
def quant_sfd_row(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD):
tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size))
@@ -1147,8 +1083,6 @@ def kernel(
mSFA_mkl: cute.Tensor,
tma_atom_sfb: cute.CopyAtom,
mSFB_nkl: cute.Tensor,
- tma_atom_c: cute.CopyAtom,
- mC_mnl: cute.Tensor,
tma_atom_d: cute.CopyAtom,
mD_mnl: cute.Tensor,
tma_atom_d_col: cute.CopyAtom,
@@ -1168,7 +1102,6 @@ def kernel(
b_smem_layout_staged: cute.ComposedLayout,
sfa_smem_layout_staged: cute.Layout,
sfb_smem_layout_staged: cute.Layout,
- c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
bias_smem_layout_staged: Optional[cute.Layout],
epi_tile: cute.Tile,
@@ -1189,8 +1122,6 @@ def kernel(
cpasync.prefetch_descriptor(tma_atom_d)
if cutlass.const_expr(self.generate_sfd):
cpasync.prefetch_descriptor(tma_atom_d_col)
- if cutlass.const_expr(self.generate_c):
- cpasync.prefetch_descriptor(tma_atom_c)
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
total_token = padded_offsets[self.expert_cnt - 1]
@@ -1274,7 +1205,6 @@ def kernel(
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_arrive_relaxed()
- sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner)
sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
sD_col = sD
if cutlass.const_expr(self.generate_sfd):
@@ -1744,13 +1674,6 @@ def kernel(
use_2cta_instrs,
)
- tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
- tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
- tiled_copy_t2r,
- tTR_rC,
- epi_tidx,
- sC,
- )
tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype)
tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
tiled_copy_t2r,
@@ -1793,8 +1716,6 @@ def kernel(
epi_ext = self._make_extension(workspace_ptr)
acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
- c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id))
- c_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_producer_group)
d_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id))
d_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_d_stage, producer_group=d_producer_group)
@@ -1833,7 +1754,6 @@ def kernel(
sBias_subtiles = cute.flat_divide(sBias_stage, cute.make_layout(self.epi_tile[1]))
real_d, _ = epi_ext.get_gmem_tensor("d", mD_mnl, padded_offsets, epi_work_tile_info)
- real_c, _ = epi_ext.get_gmem_tensor("c", mC_mnl, padded_offsets, epi_work_tile_info)
real_d_col = real_d
if cutlass.const_expr(self.generate_sfd):
real_d_col, _ = epi_ext.get_gmem_tensor("d_col", mD_col_mnl, padded_offsets, epi_work_tile_info)
@@ -1850,16 +1770,6 @@ def kernel(
sD,
)
- gC_mnl_loop = cute.local_tile(real_c, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
- tCgC_loop = thr_mma_epi_loop.partition_C(gC_mnl_loop)
- _, bSG_sC, bSG_gC_partitioned = epilog_gmem_copy_and_partition(
- epi_tidx,
- tma_atom_c,
- tCgC_loop,
- epi_tile,
- sC,
- )
-
gD_col_mnl_loop = gD_mnl_loop
tCgD_col_loop = tCgD_loop
if cutlass.const_expr(self.generate_sfd):
@@ -1878,10 +1788,8 @@ def kernel(
epi_work_tile_info.tile_n_idx,
0,
)
- bSG_gC = bSG_gC_partitioned[(None, None, None, *epi_mma_tile_coord)]
bSG_gD = bSG_gD_partitioned[(None, None, None, *epi_mma_tile_coord)]
bSG_gD_col = bSG_gD_col_partitioned[(None, None, None, *epi_mma_tile_coord)]
- bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
bSG_gD_col = cute.group_modes(bSG_gD_col, 1, cute.rank(bSG_gD_col))
@@ -1968,21 +1876,6 @@ def kernel(
for i in cutlass.range_constexpr(cute.size(tTR_rAcc)):
tTR_rAcc[i] = tTR_rAcc[i] * cutlass.Float32(alpha_val)
- if cutlass.const_expr(self.generate_c):
- self.store_c(
- tiled_copy_r2s,
- tma_atom_c,
- warp_idx,
- tTR_rAcc,
- tRS_rC,
- tRS_sC,
- bSG_gC,
- bSG_sC,
- c_pipeline,
- num_prev_subtiles,
- real_subtile_idx,
- )
-
acc_vec = tTR_rAcc.load()
if cutlass.const_expr(not self.enable_bias):
tCompute = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype)
@@ -2089,8 +1982,6 @@ def kernel(
tmem.relinquish_alloc_permit()
self.epilog_sync_barrier.arrive_and_wait()
tmem.free(tmem_ptr)
- if cutlass.const_expr(self.generate_c):
- c_pipeline.producer_tail()
d_pipeline.producer_tail()
# ------------------------------------------------------------------
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_srelu/__init__.py b/python/cudnn/grouped_gemm/grouped_gemm_srelu/__init__.py
new file mode 100644
index 00000000..70b2bfe7
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_srelu/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+Grouped GEMM SReLU Kernel Module
+
+This module provides the forward grouped GEMM with SReLU activation
+for MoE (Mixture of Experts) workloads on SM100+ GPUs.
+"""
+
+from .api import (
+ GroupedGemmSreluSm100,
+ grouped_gemm_srelu_wrapper_sm100,
+)
+
+__all__ = [
+ "GroupedGemmSreluSm100",
+ "grouped_gemm_srelu_wrapper_sm100",
+]
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_srelu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_srelu/api.py
new file mode 100644
index 00000000..02f49939
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_srelu/api.py
@@ -0,0 +1,1572 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""
+Unified API for Grouped GEMM SReLU Kernel (SM100+)
+
+This module provides a single API class that supports both dense (contiguous)
+and discrete weight modes for grouped block-scaled GEMM with output
+SReLU output quantization in MoE (Mixture of Experts) workloads.
+"""
+
+import os
+from typing import Optional, Tuple
+
+import cutlass
+import cutlass.cute as cute
+import torch
+from cuda.bindings import driver as cuda
+from cutlass.cute.runtime import make_fake_stream
+
+from cudnn.api_base import APIBase, TensorDesc, TupleDict, ceil_div, is_power_of_2
+from cudnn.datatypes import _convert_to_cutlass_data_type
+
+from .moe_blockscaled_grouped_gemm_srelu_quant import (
+ BlockScaledMoEGroupedGemmQuantKernel,
+ EpilogueType,
+)
+from ..moe_utils import MoEWeightMode
+from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
+from cutlass.cute.runtime import from_dlpack
+
+
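+# Grouped FP4 weights may arrive as raw uint8 storage; reinterpret such tensors as
+# packed Float4E2M1FN CuTe tensors and pass every other dtype through unchanged.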
+def _reinterpret_raw_grouped_fp4_tensor(tensor: torch.Tensor) -> torch.Tensor:
+ if tensor.dtype == torch.uint8:
+ cute_tensor = from_dlpack(tensor, assumed_align=16, enable_tvm_ffi=True).mark_layout_dynamic(leading_dim=1)
+ cute_tensor.element_type = cutlass.Float4E2M1FN
+ return cute_tensor
+ return tensor
+
+
+class GroupedGemmSreluSm100(APIBase):
+ """Unified API for grouped GEMM SReLU operation on SM100+ GPUs.
+
+    This kernel performs block-scaled grouped GEMM with SReLU activation and
+    output quantization (D = srelu(alpha * A @ B)), designed for MoE workloads.
+    It supports both
+ dense (contiguous) and discrete (per-expert pointer) weight layouts
+ through ``BlockScaledMoEGroupedGemmQuantKernel``.
+
+ Weight mode is auto-detected from the constructor arguments:
+
+ - Dense: provide ``sample_b`` and ``sample_sfb``.
+ - Discrete: provide ``num_experts``, ``b_shape``, and ``b_dtype``.
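+
+    A minimal construction sketch (``a``, ``b``, ``sfa`` and friends are
+    placeholder ``torch`` tensors; see ``check_support`` for the exact shape,
+    stride, and dtype requirements)::
+
+        # Dense mode: weights as a contiguous (n, k, num_experts) tensor.
+        op = GroupedGemmSreluSm100(
+            sample_a=a, sample_b=b, sample_sfb=sfb, sample_sfa=sfa,
+            sample_c=c, sample_d=d, sample_padded_offsets=offsets,
+            sample_alpha=alpha, sample_prob=prob,
+        )
+
+        # Discrete mode: describe per-expert B tensors by shape and dtype.
+        op = GroupedGemmSreluSm100(
+            sample_a=a, num_experts=8, b_shape=(n, k),
+            b_dtype=torch.float8_e4m3fn, sample_sfa=sfa,
+            sample_c=c, sample_d=d, sample_padded_offsets=offsets,
+            sample_alpha=alpha, sample_prob=prob,
+        )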
+ """
+
+ def __init__(
+ self,
+ sample_a: torch.Tensor,
+ # Dense mode (contiguous) -- provide these:
+ sample_b: Optional[torch.Tensor] = None,
+ sample_c: Optional[torch.Tensor] = None,
+ sample_d: Optional[torch.Tensor] = None,
+ sample_sfa: Optional[torch.Tensor] = None,
+ sample_sfb: Optional[torch.Tensor] = None,
+ sample_padded_offsets: Optional[torch.Tensor] = None,
+ sample_alpha: Optional[torch.Tensor] = None,
+ sample_d_col: Optional[torch.Tensor] = None,
+ sample_bias: Optional[torch.Tensor] = None,
+ # Discrete mode -- provide these instead:
+ num_experts: Optional[int] = None,
+ b_shape: Optional[Tuple[int, ...]] = None,
+ b_dtype: Optional[torch.dtype] = None,
+        # Optional SReLU output quantization and gating arguments
+ sample_sfd_row: Optional[torch.Tensor] = None,
+ sample_sfd_col: Optional[torch.Tensor] = None,
+ sample_amax: Optional[torch.Tensor] = None,
+ sample_norm_const: Optional[torch.Tensor] = None,
+ sample_prob: Optional[torch.Tensor] = None,
+ # Configuration
+ acc_dtype: torch.dtype = torch.float32,
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ m_aligned: int = 256,
+ discrete_col_sfd: bool = False,
+ b_major: str = "k",
+ use_dynamic_sched: bool = False,
+ ):
+ """Initialize the GroupedGemmSreluSm100 API.
+
+ :param sample_a: Sample A tensor (valid_m, k, 1)
+ :param sample_sfa: Sample scale factor A tensor
+ :param sample_padded_offsets: End offset for each expert after padding, shape (expert_cnt,)
+ :param sample_alpha: Per-group alpha scaling factors
+ :param sample_c: Sample C output tensor (valid_m, n, 1) before SReLU
+ :param sample_d: Sample D output tensor (valid_m, n, 1)
+    :param sample_d_col: Optional column-quantized D tensor. Required only when SFD outputs are generated.
+ :param sample_b: (Dense) Sample B tensor (n, k, l)
+ :param sample_sfb: (Dense) Sample scale factor B tensor
+ :param sample_bias: Optional bias tensor with shape (n, l) or (n, expert_cnt), stride (1, n).
+ Dense mode supports fp16/bfloat16/float32 bias; discrete mode supports fp16/bfloat16 bias.
+ :param num_experts: (Discrete) Number of experts
+ :param b_shape: (Discrete) Shape of a single expert B tensor, e.g. (n, k)
+ :param b_dtype: (Discrete) Data type of B tensors
+ :param sample_sfd_row: Optional row scale factor for D
+ :param sample_sfd_col: Optional column scale factor for D
+ :param sample_amax: Optional amax tensor for SReLU output quantization
+ :param sample_norm_const: Optional normalization constant
+ :param sample_prob: Optional probability tensor for gating
+ :param acc_dtype: Accumulator data type
+ :param mma_tiler_mn: MMA tiler shape (M, N)
+ :param cluster_shape_mn: Cluster shape (M, N)
+ :param sf_vec_size: Scale factor vector size
+ :param vector_f32: Use vectorized f32 operations
+ :param m_aligned: Alignment for group M dimension
+ :param discrete_col_sfd: Enable discrete col-major scale factor tensor
+ :param b_major: Major dimension for B tensor, one of "k" or "n"
+ :param use_dynamic_sched: Enable dynamic tile scheduling for load balancing
+ """
+ super().__init__()
+
+ self._warn_experimental_api()
+ self._logger.debug("Entering __init__")
+
+ # ---- Weight mode auto-detection ----
+ if sample_b is not None and num_experts is None:
+ self.weight_mode = MoEWeightMode.DENSE
+ if sample_sfb is None:
+ raise ValueError("sample_sfb is required when sample_b is provided (dense mode)")
+ elif num_experts is not None and sample_b is None:
+ self.weight_mode = MoEWeightMode.DISCRETE
+ if b_shape is None or b_dtype is None:
+ raise ValueError("b_shape and b_dtype are required in discrete mode")
+ else:
+ raise ValueError("Provide either (sample_b, sample_sfb) for dense mode " "or (num_experts, b_shape, b_dtype) for discrete mode, but not both.")
+
+ self._sample_a_tensor = sample_a
+ self._sample_b_tensor = sample_b
+
+ self.a_desc = self._make_tensor_desc(sample_a, name="sample_a", interpret_uint8_as_fp4x2=False)
+ self.c_desc = self._make_tensor_desc(sample_c, name="sample_c")
+ self.d_desc = self._make_tensor_desc(sample_d, name="sample_d")
+ self.sfa_desc = self._make_tensor_desc(sample_sfa, name="sample_sfa")
+ self.padded_offsets_desc = self._make_tensor_desc(sample_padded_offsets, name="sample_padded_offsets")
+ self.alpha_desc = self._make_tensor_desc(sample_alpha, name="sample_alpha")
+
+ self._has_d_col = sample_d_col is not None
+ self.d_col_desc = self._make_tensor_desc(sample_d_col, name="sample_d_col")
+ if self.d_col_desc is None:
+ self.d_col_desc = TensorDesc(
+ dtype=self.d_desc.dtype,
+ shape=self.d_desc.shape,
+ stride=self.d_desc.stride,
+ stride_order=self.d_desc.stride_order,
+ device=self.d_desc.device,
+ name="sample_d_col",
+ )
+ self.sfd_row_desc = self._make_tensor_desc(sample_sfd_row, name="sample_sfd_row")
+ self.sfd_col_desc = self._make_tensor_desc(sample_sfd_col, name="sample_sfd_col")
+ self.amax_desc = self._make_tensor_desc(sample_amax, name="sample_amax")
+ self.norm_const_desc = self._unpad_tensor_to_ndim(
+ self._make_tensor_desc(sample_norm_const, name="sample_norm_const"),
+ 1,
+ "norm_const",
+ )
+ self.prob_desc = self._make_tensor_desc(sample_prob, name="sample_prob")
+ self.bias_desc = self._make_tensor_desc(sample_bias, name="sample_bias")
+
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self.b_desc = self._make_tensor_desc(sample_b, name="sample_b", interpret_uint8_as_fp4x2=False)
+ self.sfb_desc = self._make_tensor_desc(sample_sfb, name="sample_sfb")
+ self.expert_cnt = self.padded_offsets_desc.shape[0]
+ else:
+ self._value_error_if(num_experts == 0, "num_experts must be > 0")
+ self.expert_cnt = num_experts
+ self.b_shape = b_shape
+ self.b_dtype = b_dtype
+ self.b_major = b_major
+ self._value_error_if(
+ self.padded_offsets_desc.shape[0] != self.expert_cnt,
+ f"padded_offsets length ({self.padded_offsets_desc.shape[0]}) " f"must equal num_experts ({self.expert_cnt})",
+ )
+
+ self.acc_dtype = acc_dtype
+ self.mma_tiler_mn = mma_tiler_mn
+ self.use_2cta_instrs = mma_tiler_mn[0] == 256
+ if cluster_shape_mn is None:
+ self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1)
+ else:
+ self.cluster_shape_mn = cluster_shape_mn
+ self.sf_vec_size = sf_vec_size
+ self.vector_f32 = vector_f32
+ self.m_aligned = m_aligned
+ self.discrete_col_sfd = discrete_col_sfd
+ self.use_dynamic_sched = use_dynamic_sched
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self.b_major = b_major
+
+ self._interpret_uint8_as_fp4x2 = True
+ self._has_bias = self.bias_desc is not None
+ self._kernel = BlockScaledMoEGroupedGemmQuantKernel
+
+ self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._workspace = None
+ self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
+ self._logger.debug("__init__ completed")
+
+ def check_support(self) -> bool:
+ """Check if the kernel configuration is supported.
+
+ :return: True if supported, raises exception otherwise
+ """
+ self._logger.debug("Entering check_support")
+
+ all_none = all(x is None for x in [self.sfd_row_desc, self.sfd_col_desc, self.norm_const_desc])
+ all_provided = all(x is not None for x in [self.sfd_row_desc, self.sfd_col_desc, self.norm_const_desc])
+ self._value_error_if(
+ not (all_none or all_provided),
+ "sfd_row_desc, sfd_col_desc, and norm_const_desc must be all None or all not None",
+ )
+ self.generate_sfd = all_provided
+ self._value_error_if(
+ self.generate_sfd and not self._has_d_col,
+ "sample_d_col is required when SFD outputs are generated",
+ )
+ if self.discrete_col_sfd and not self.generate_sfd:
+ self._logger.warning("discrete_col_sfd is True but generate_sfd is False, discrete_col_sfd will be ignored")
+ self.discrete_col_sfd = False
+
+ self._logger.debug("Checking tensor shapes and strides")
+ tensor_m, k, _one = self._tensor_shape(self.a_desc, name="sample_a")
+
+ if self.weight_mode == MoEWeightMode.DENSE:
+ n, _, l = self._tensor_shape(self.b_desc, name="sample_b")
+ else:
+ if len(self.b_shape) == 2:
+ n, b_k = self.b_shape
+ else:
+ n, b_k, _ = self.b_shape
+ self._value_error_if(b_k != k, f"B K dimension ({b_k}) must match A K dimension ({k})")
+ l = self.expert_cnt
+
+ _, _, _one = self._tensor_shape(self.d_desc, name="sample_d")
+
+ self._check_tensor_shape(self.a_desc, (tensor_m, k, 1), "A")
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._check_tensor_shape(self.b_desc, (n, k, l), "B")
+ self._check_tensor_shape(self.c_desc, (tensor_m, n, 1), "C")
+ self._check_tensor_shape(self.d_desc, (tensor_m, n, 1), "D")
+ self._check_tensor_shape(self.d_col_desc, (tensor_m, n, 1), "D_col")
+
+ rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4)
+ self._check_tensor_shape(self.sfa_desc, (32, 4, ceil_div(tensor_m, 128), 4, rest_k, 1), "SFA")
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._check_tensor_shape(self.sfb_desc, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB")
+ rest_n = ceil_div(ceil_div(n, self.sf_vec_size), 4)
+ self._check_tensor_shape(self.sfd_row_desc, (32, 4, ceil_div(tensor_m, 128), 4, rest_n, 1), "SFD_row")
+ rest_m = ceil_div(ceil_div(tensor_m, self.sf_vec_size), 4)
+ self._check_tensor_shape(self.sfd_col_desc, (32, 4, ceil_div(n, 128), 4, rest_m, 1), "SFD_col")
+
+ self._check_tensor_shape(self.alpha_desc, (self.expert_cnt,), "alpha")
+ self._value_error_if(
+ self.prob_desc is None,
+ "prob_tensor is required: the kernel unconditionally multiplies output by per-row gating probability. "
+ "Pass a tensor of ones with shape (valid_m, 1, 1) if no gating is needed.",
+ )
+ self._check_tensor_shape(self.prob_desc, (tensor_m, 1, 1), "prob")
+ self._check_tensor_shape(self.bias_desc, (n, l), "bias")
+ self._check_tensor_shape(self.amax_desc, (self.expert_cnt, 1), "amax")
+ self._check_tensor_shape(self.norm_const_desc, (1,), "norm_const")
+ self._check_tensor_shape(self.padded_offsets_desc, (self.expert_cnt,), "padded_offsets")
+
+ _ = self._check_tensor_stride(
+ self.a_desc,
+ stride=[(k, 1, tensor_m * k)],
+ extra_error_msg="A must have k-major layout",
+ )
+ if self.weight_mode == MoEWeightMode.DENSE:
+ if self._is_fp8(self.a_desc):
+ _ = self._check_tensor_stride(
+ self.b_desc,
+ stride=[(k, 1, n * k), (1, n, n * k)],
+ extra_error_msg="For fp8 ab_dtype, B must have k- or n-major layout",
+ )
+ else:
+ _ = self._check_tensor_stride(
+ self.b_desc,
+ stride=[(k, 1, n * k)],
+ extra_error_msg="For fp4 ab_dtype, B must have k-major layout",
+ )
+ _ = self._check_tensor_stride(
+ self.c_desc,
+ stride=[(n, 1, tensor_m * n)],
+ extra_error_msg="C must have n-major layout",
+ )
+ _ = self._check_tensor_stride(
+ self.d_desc,
+ stride=[(n, 1, tensor_m * n)],
+ extra_error_msg="D must have n-major layout",
+ )
+ _ = self._check_tensor_stride(
+ self.d_col_desc,
+ stride=[(n, 1, tensor_m * n)],
+ extra_error_msg="D_col must have n-major layout",
+ )
+ _ = self._check_tensor_stride(
+ self.bias_desc,
+ stride=[(1, n)],
+ )
+
+ self._logger.debug("Checking data types")
+ self.ab_dtype = self._check_dtype(
+ self.a_desc,
+ dtype=[
+ torch.float4_e2m1fn_x2,
+ torch.uint8,
+ torch.float8_e5m2,
+ torch.float8_e4m3fn,
+ ],
+ name="A/B",
+ )
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._check_dtype(
+ self.b_desc,
+ dtype=self.ab_dtype,
+ name="B",
+ extra_error_msg="B must have the same dtype as A",
+ )
+ self._check_dtype(
+ self.bias_desc,
+ dtype=[torch.bfloat16, torch.float16, torch.float32],
+ name="bias",
+ extra_error_msg="bias must be fp16, bfloat16, or float32",
+ )
+ else:
+ self._value_error_if(
+ self.b_dtype != self.ab_dtype,
+ f"b_dtype ({self.b_dtype}) must match A dtype ({self.ab_dtype})",
+ )
+ self._check_dtype(
+ self.bias_desc,
+ dtype=[torch.bfloat16, torch.float16],
+ name="bias",
+ extra_error_msg="bias must be fp16 or bfloat16 in discrete mode",
+ )
+
+ self.sf_dtype = self._check_dtype(
+ self.sfa_desc,
+ dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn],
+ name="SFA/SFB/SFD_row/SFD_col",
+ )
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._check_dtype(
+ self.sfb_desc,
+ dtype=self.sf_dtype,
+ name="SFB",
+ extra_error_msg="SFB must have the same dtype as SFA",
+ )
+ self._check_dtype(
+ self.sfd_row_desc,
+ dtype=self.sf_dtype,
+ name="SFD_row",
+ extra_error_msg="SFD_row must have the same dtype as SFA",
+ )
+ self._check_dtype(
+ self.sfd_col_desc,
+ dtype=self.sf_dtype,
+ name="SFD_col",
+ extra_error_msg="SFD_col must have the same dtype as SFA",
+ )
+
+ self._value_error_if(
+ self.sf_vec_size not in [16, 32],
+ f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}",
+ )
+ self._value_error_if(
+ self.sf_dtype in [torch.float8_e4m3fn] and self.sf_vec_size == 32,
+ f"sf_dtype {self.sf_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported",
+ )
+ self._value_error_if(
+ self._is_fp8(self.ab_dtype) and self.sf_vec_size == 16,
+ f"ab_dtype {self.ab_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported",
+ )
+
+ self._check_dtype(
+ self.acc_dtype,
+ dtype=torch.float32,
+ name="Accumulator",
+ extra_error_msg="Accumulator must be float32",
+ )
+ self.c_dtype = self._check_dtype(
+ self.c_desc,
+ dtype=[torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2],
+ name="C",
+ )
+ if self._is_fp4x2(self.ab_dtype):
+ self.d_dtype = self._check_dtype(
+ self.d_desc,
+ dtype=[torch.float16, torch.bfloat16, torch.float32],
+ name="D",
+ extra_error_msg="D must be fp16, bf16, or float32 when ab_dtype is fp4",
+ )
+ else:
+ self.d_dtype = self._check_dtype(
+ self.d_desc,
+ dtype=[
+ torch.float16,
+ torch.bfloat16,
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ torch.float4_e2m1fn_x2,
+ ],
+ name="D",
+ )
+ self._check_dtype(
+ self.d_col_desc,
+ dtype=self.d_dtype,
+ name="D_col",
+ extra_error_msg="D_col must have the same dtype as D",
+ )
+
+ self._not_implemented_error_if(
+ self._is_fp4x2(self.ab_dtype) and self.sf_vec_size == 16 and self.d_dtype == torch.float32,
+ "Invalid configuration: fp4 ab_dtype, sf_vec_size 16, d_dtype float32 is not supported. Please use sf_vec_size 32 or d_dtype bf16 instead",
+ )
+
+ if self.weight_mode == MoEWeightMode.DISCRETE:
+ self._value_error_if(
+ self.b_major not in ["k", "n"],
+ f"b_major must be 'k' or 'n', got {self.b_major}",
+ )
+ self._value_error_if(
+ self._is_fp4x2(self.ab_dtype) and self.b_major != "k",
+ "b_major must be 'k' when ab_dtype is fp4",
+ )
+
+ self._logger.debug("Checking MMA tile shape and cluster shape")
+ self._value_error_if(
+ not self.use_2cta_instrs and self.mma_tiler_mn[0] != 128,
+ f"MMA tiler M must be 128 when use_2cta_instrs=False, got {self.mma_tiler_mn[0]}",
+ )
+ self._value_error_if(
+ self.use_2cta_instrs and self.mma_tiler_mn[0] != 256,
+ f"MMA tiler M must be 256 when use_2cta_instrs=True, got {self.mma_tiler_mn[0]}",
+ )
+ self._value_error_if(
+ self.mma_tiler_mn[1] != 256,
+ f"MMA tiler N must be 256, got {self.mma_tiler_mn[1]}",
+ )
+ self._value_error_if(
+ self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0,
+ f"cluster_shape_mn[0] must be divisible by 2 when use_2cta_instrs=True, got {self.cluster_shape_mn[0]}",
+ )
+ self._value_error_if(
+ not (
+ self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16
+ and self.cluster_shape_mn[0] > 0
+ and self.cluster_shape_mn[1] > 0
+ and self.cluster_shape_mn[0] <= 4
+ and self.cluster_shape_mn[1] <= 4
+ and is_power_of_2(self.cluster_shape_mn[0])
+ and is_power_of_2(self.cluster_shape_mn[1])
+ ),
+ "Invalid cluster shape: expected values to be powers of 2 and " f"cluster_shape_mn[0] * cluster_shape_mn[1] <= 16, got {self.cluster_shape_mn}",
+ )
+ cluster_tiler_m = (self.cluster_shape_mn[0] // (2 if self.use_2cta_instrs else 1)) * self.mma_tiler_mn[0]
+ self._value_error_if(
+ cluster_tiler_m not in [128, 256],
+ f"Invalid cluster tiler shape: expected cluster_tiler_m in {{128, 256}}, got {cluster_tiler_m}",
+ )
+ self._value_error_if(
+ self.m_aligned % self.mma_tiler_mn[0] != 0,
+ f"Invalid m_aligned: expected m_aligned to be divisible by mma_tiler_mn[0], got {self.m_aligned} % {self.mma_tiler_mn[0]} != 0",
+ )
+ self._value_error_if(
+ self.m_aligned != BlockScaledMoEGroupedGemmQuantKernel.FIX_PAD_SIZE,
+ f"m_aligned must be {BlockScaledMoEGroupedGemmQuantKernel.FIX_PAD_SIZE} (FIX_PAD_SIZE), got {self.m_aligned}",
+ )
+
+ self._logger.debug("Checking tensor alignment")
+
+        def check_contiguous_16B_alignment(dtype, stride_order, tensor_shape):
+            # The contiguous (major) extent must span a whole number of 16-byte chunks,
+            # i.e. be a multiple of 16 * 8 / element_width elements
+            # (16 elements for 8-bit dtypes, 32 for packed 4-bit dtypes).
+            is_mode0_major = stride_order == (0, 1, 2)
+            major_mode_idx = 0 if is_mode0_major else 1
+            num_major_elements = tensor_shape[major_mode_idx]
+            num_contiguous_elements = 16 * 8 // (_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2).width)
+            return num_major_elements % num_contiguous_elements == 0
+
+        if self.weight_mode == MoEWeightMode.DENSE:
+            b_stride_order_for_check = self.b_desc.stride_order
+            b_shape_for_check = (n, k, l)
+        else:
+            b_stride_order_for_check = (0, 1, 2) if self.b_major == "n" else (1, 0, 2)
+            b_shape_for_check = (n, k, 1)
+
+        self._value_error_if(
+            not (
+                check_contiguous_16B_alignment(self.ab_dtype, self.a_desc.stride_order, (tensor_m, k, l))
+                and check_contiguous_16B_alignment(self.ab_dtype, b_stride_order_for_check, b_shape_for_check)
+                and check_contiguous_16B_alignment(self.d_dtype, self.d_desc.stride_order, (tensor_m, n, 1))
+            ),
+            "Invalid tensor alignment: tensors must be 16B aligned",
+        )
+
+ self._value_error_if(
+ self.expert_cnt > 1024,
+ f"expert_cnt must be <= 1024, got {self.expert_cnt}",
+ )
+
+ self._not_implemented_error_if(self._has_bias and self.mma_tiler_mn[1] != 256, "Bias fusion currently requires mma_tiler_mn[1] == 256")
+
+ self._not_implemented_error_if(
+ (self._is_fp8(self.ab_dtype)) and (self.mma_tiler_mn[1] == 128) and (self._is_fp8(self.d_dtype)),
+ "Invalid configuration: fp8 ab_dtype and sf_vec_size 32 with mma_tiler_mn[1] == 128 and fp8 d_dtype is not supported. "
+ "Please use mma_tiler_mn[1] == 256 instead",
+ )
+ if not torch.cuda.is_available():
+ raise RuntimeError("CUDA is not available")
+ device = torch.cuda.current_device()
+ major, minor = torch.cuda.get_device_capability(device)
+ compute_capability = major * 10 + minor
+ if compute_capability < 100:
+ raise RuntimeError(f"GroupedGemmSrelu requires SM100+ compute capability, but found SM{compute_capability} on device {device}")
+
+ self._is_supported = True
+ self._logger.debug("check_support completed successfully")
+ return True
+
+ def compile(self) -> None:
+ """Compile the kernel."""
+ self._logger.debug("Entering compile")
+ self._ensure_support_checked()
+ if self._compiled_kernel is not None:
+ self._logger.debug("Kernel already compiled; skipping recompilation")
+ return
+ if self.a_desc.shape[0] == 0:
+ self._logger.debug("sample valid_m is zero, skipping kernel compilation")
+ return
+
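+        # CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL (default "1") selects fully dynamic
+        # M/N/K/L compilation for the dense path; set it to "0" to bake N/K/L from
+        # the sample descriptors and keep only M dynamic (see _compile_dense).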
+ self._use_full_dynamic_mnkl = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
+
+ gemm_srelu = self._kernel(
+ sf_vec_size=self.sf_vec_size,
+ acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype),
+ use_2cta_instrs=self.use_2cta_instrs,
+ mma_tiler_mn=self.mma_tiler_mn,
+ cluster_shape_mn=self.cluster_shape_mn,
+ vectorized_f32=self.vector_f32,
+ generate_sfd=self.generate_sfd,
+ discrete_col_sfd=self.discrete_col_sfd,
+ generate_c=True,
+ enable_bias=self._has_bias,
+ expert_cnt=self.expert_cnt,
+ weight_mode=self.weight_mode,
+ use_dynamic_sched=self.use_dynamic_sched,
+ epilogue_type=EpilogueType.SRELU.value,
+ )
+
+ hardware_info = cutlass.utils.HardwareInfo()
+ max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1])
+ max_active_clusters -= self.num_cluster_overlap_margin
+ self._value_error_if(
+ max_active_clusters <= 0,
+ "max_active_clusters must be > 0 after applying overlap margin; reduce CUDNNFE_CLUSTER_OVERLAP_MARGIN",
+ )
+ fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False)
+
+ workspace_bytes = gemm_srelu.get_workspace_bytes()
+ self._workspace = torch.empty(max(workspace_bytes, 1), dtype=torch.uint8, device="cuda")
+
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._compile_dense(gemm_srelu, max_active_clusters, fake_stream)
+ else:
+ self._compile_discrete(gemm_srelu, max_active_clusters, fake_stream)
+
+ self._logger.debug("Kernel compiled successfully")
+
+ def _compile_dense(self, gemm_srelu, max_active_clusters, fake_stream) -> None:
+ """Compile for dense (contiguous) weight mode."""
+ fake_workspace_ptr = cute.runtime.nullptr(
+ dtype=cutlass.Uint8,
+ assumed_align=128,
+ )
+
+ self._logger.debug("Compiling grouped_gemm_srelu kernel")
+ use_full_dynamic = self._use_full_dynamic_mnkl
+
+ if not use_full_dynamic:
+ valid_m = cute.sym_int(divisibility=256)
+
+ a_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, *self.a_desc.shape[1:]),
+ stride_order=self.a_desc.stride_order,
+ )
+ b_cute_fake = self._make_fake_cute_tensor_from_desc(self.b_desc, assumed_align=16)
+ c_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, *self.c_desc.shape[1:]),
+ stride_order=self.c_desc.stride_order,
+ )
+ d_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_desc.dtype,
+ shape=(valid_m, *self.d_desc.shape[1:]),
+ stride_order=self.d_desc.stride_order,
+ )
+ d_col_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_col_desc.dtype,
+ shape=(valid_m, *self.d_col_desc.shape[1:]),
+ stride_order=self.d_col_desc.stride_order,
+ )
+
+ tensor_m_128 = cute.sym_int()
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_shape = list(self.sfa_desc.shape)
+ sfa_shape[2] = tensor_m_128
+ sfa_stride = list(self.sfa_desc.stride)
+ sfa_stride[5] = stride_tensor_m_128
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=tuple(sfa_shape),
+ stride=tuple(sfa_stride),
+ )
+
+ sfb_cute_fake = self._make_fake_cute_tensor_from_desc(self.sfb_desc, assumed_align=16)
+
+ prob_cute_fake = None
+ if self.prob_desc is not None:
+ prob_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, *self.prob_desc.shape[1:]),
+ stride=self.prob_desc.stride,
+ )
+
+ sfd_row_fake = None
+ sfd_col_fake = None
+ if self.sfd_row_desc is not None:
+ stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_row_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_row_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfd_row_desc.shape[4], 1),
+ stride=(16, 4, self.sfd_row_desc.stride[2], 1, 512, stride_sfd_m),
+ )
+ if self.sfd_col_desc is not None:
+ rest_m = cute.sym_int(divisibility=1)
+ stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_rest_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_col_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_col_desc.dtype,
+ shape=(32, 4, self.sfd_col_desc.shape[2], 4, rest_m, 1),
+ stride=(16, 4, stride_rest_m, 1, 512, stride_sfd_n),
+ )
+ bias_cute_fake = self._make_fake_cute_tensor_from_desc(self.bias_desc, assumed_align=16)
+ else:
+ valid_m = cute.sym_int(divisibility=256)
+ n_sym = cute.sym_int()
+ k_sym = cute.sym_int()
+ l_sym = cute.sym_int()
+
+ a_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, k_sym, 1),
+ stride_order=self.a_desc.stride_order,
+ dynamic_mode=self.a_desc.stride_order[0],
+ divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
+ )
+ b_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.b_desc.dtype,
+ shape=(n_sym, k_sym, l_sym),
+ stride_order=self.b_desc.stride_order,
+ dynamic_mode=self.b_desc.stride_order[0],
+ divisibility=32 if self._is_fp4x2(self.ab_dtype) else 16,
+ )
+ c_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, n_sym, 1),
+ stride_order=self.c_desc.stride_order,
+ dynamic_mode=self.c_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.c_desc.dtype) else 16,
+ )
+ d_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_desc.dtype,
+ shape=(valid_m, n_sym, 1),
+ stride_order=self.d_desc.stride_order,
+ dynamic_mode=self.d_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.d_desc.dtype) else 16,
+ )
+ d_col_cute_fake = self._make_fake_cute_compact_tensor(
+ dtype=self.d_col_desc.dtype,
+ shape=(valid_m, n_sym, 1),
+ stride_order=self.d_col_desc.stride_order,
+ dynamic_mode=self.d_col_desc.stride_order[0],
+ divisibility=8 if self._is_f16(self.d_col_desc.dtype) else 16,
+ )
+
+ tensor_m_128 = cute.sym_int()
+ rest_k = cute.sym_int()
+ stride_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_shape = list(self.sfa_desc.shape)
+ sfa_shape[2] = tensor_m_128
+ sfa_shape[4] = rest_k
+ sfa_stride = list(self.sfa_desc.stride)
+ sfa_stride[2] = stride_rest_k
+ sfa_stride[5] = stride_tensor_m_128
+ sfa_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=tuple(sfa_shape),
+ stride=tuple(sfa_stride),
+ )
+
+ tensor_n_128 = cute.sym_int()
+ stride_sfb_rest_k = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfb_tensor_n_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfb_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.sfb_desc.dtype,
+ shape=(32, 4, tensor_n_128, 4, rest_k, l_sym),
+ stride=(16, 4, stride_sfb_tensor_n_128, 1, 512, stride_sfb_rest_k),
+ )
+
+ prob_cute_fake = None
+ if self.prob_desc is not None:
+ prob_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, *self.prob_desc.shape[1:]),
+ stride=self.prob_desc.stride,
+ )
+
+ sfd_row_fake = None
+ sfd_col_fake = None
+ if self.sfd_row_desc is not None:
+ rest_n = cute.sym_int()
+ stride_sfd_rest_n = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfd_rest_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_row_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_row_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, rest_n, 1),
+ stride=(16, 4, stride_sfd_rest_n, 1, 512, stride_sfd_rest_tensor_m_128),
+ )
+ if self.sfd_col_desc is not None:
+ tensor_n_128 = cute.sym_int()
+ rest_m_dyn = cute.sym_int()
+ stride_sfd_rest_m = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_col_fake = self._make_fake_cute_tensor(
+ dtype=self.sfd_col_desc.dtype,
+ shape=(32, 4, tensor_n_128, 4, rest_m_dyn, 1),
+ stride=(16, 4, stride_sfd_rest_m, 1, 512, stride_sfd_n),
+ )
+
+ bias_cute_fake = None
+ if self.bias_desc is not None:
+ bias_cute_fake = self._make_fake_cute_tensor(
+ dtype=self.bias_desc.dtype,
+ shape=(n_sym, l_sym),
+ stride=(1, n_sym),
+ )
+
+ _compiled_kernel = cute.compile(
+ gemm_srelu,
+ a=_reinterpret_raw_grouped_fp4_tensor(self._sample_a_tensor) if self.a_desc.dtype == torch.uint8 else a_cute_fake,
+ b=_reinterpret_raw_grouped_fp4_tensor(self._sample_b_tensor) if self.b_desc.dtype == torch.uint8 else b_cute_fake,
+ sfb=sfb_cute_fake,
+ n=cutlass.Int32(0),
+ k=cutlass.Int32(0),
+ b_stride_size=cutlass.Int64(0),
+ b_major_mode=OperandMajorMode.K,
+ workspace_ptr=fake_workspace_ptr,
+ c=c_cute_fake,
+ d=d_cute_fake,
+ d_col=d_col_cute_fake,
+ sfa=sfa_cute_fake,
+ sfd_row_tensor=sfd_row_fake,
+ sfd_col_tensor=sfd_col_fake,
+ amax_tensor=self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16),
+ norm_const_tensor=self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16),
+ padded_offsets=self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16),
+ alpha=self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16),
+ bias=bias_cute_fake,
+ prob=prob_cute_fake,
+ max_active_clusters=max_active_clusters,
+ stream=fake_stream,
+ options="--enable-tvm-ffi",
+ )
+
+ cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator
+
+ def tensor_api(
+ a_tensor: torch.Tensor,
+ b_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_tensor: torch.Tensor,
+ d_col_tensor: Optional[torch.Tensor],
+ sfa_tensor: torch.Tensor,
+ sfb_tensor: torch.Tensor,
+ sfd_row_tensor: Optional[torch.Tensor],
+ sfd_col_tensor: Optional[torch.Tensor],
+ amax_tensor: Optional[torch.Tensor],
+ norm_const_tensor: Optional[torch.Tensor],
+ padded_offsets: torch.Tensor,
+ alpha_tensor: torch.Tensor,
+ prob_tensor: Optional[torch.Tensor],
+ bias_tensor: Optional[torch.Tensor],
+ stream: cuda.CUstream,
+ ) -> None:
+ norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const")
+ _compiled_kernel(
+ _reinterpret_raw_grouped_fp4_tensor(a_tensor),
+ _reinterpret_raw_grouped_fp4_tensor(b_tensor),
+ sfb_tensor,
+ cutlass.Int32(0),
+ cutlass.Int32(0),
+ cutlass.Int64(0),
+ cached_workspace_ptr,
+ c_tensor,
+ d_tensor,
+ d_col_tensor,
+ sfa_tensor,
+ sfd_row_tensor,
+ sfd_col_tensor,
+ amax_tensor,
+ norm_const_tensor,
+ padded_offsets,
+ alpha_tensor,
+ bias_tensor,
+ prob_tensor,
+ stream,
+ )
+
+ self._compiled_kernel = tensor_api
+
+ def _compile_discrete(self, gemm_srelu, max_active_clusters, fake_stream) -> None:
+ """Compile for discrete (per-expert pointer) weight mode."""
+ if len(self.b_shape) == 2:
+ n, k = self.b_shape
+ else:
+ n, k, _ = self.b_shape
+
+ b_major_mode = OperandMajorMode.K if self.b_major == "k" else OperandMajorMode.MN
+ b_stride_size = k if self.b_major == "k" else n
+
+ ab_cutlass_dtype = _convert_to_cutlass_data_type(self.a_desc.dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2)
+ align = 32 if ab_cutlass_dtype.width == 4 else 16
+
+ valid_m = cute.sym_int(divisibility=256)
+ a_tensor = self._make_fake_cute_compact_tensor(
+ dtype=self.a_desc.dtype,
+ shape=(valid_m, *self.a_desc.shape[1:]),
+ stride_order=self.a_desc.stride_order,
+ assumed_align=align,
+ )
+ c_tensor = self._make_fake_cute_compact_tensor(
+ dtype=self.c_desc.dtype,
+ shape=(valid_m, *self.c_desc.shape[1:]),
+ stride_order=self.c_desc.stride_order,
+ )
+ d_tensor = self._make_fake_cute_compact_tensor(
+ dtype=self.d_desc.dtype,
+ shape=(valid_m, *self.d_desc.shape[1:]),
+ stride_order=self.d_desc.stride_order,
+ )
+ d_col_tensor = self._make_fake_cute_compact_tensor(
+ dtype=self.d_col_desc.dtype,
+ shape=(valid_m, *self.d_col_desc.shape[1:]),
+ stride_order=self.d_col_desc.stride_order,
+ )
+
+ tensor_m_128 = cute.sym_int()
+ stride_tensor_m_128 = cute.sym_int(divisibility=32 * 4 * 4)
+ sfa_shape = list(self.sfa_desc.shape)
+ sfa_shape[2] = tensor_m_128
+ sfa_stride = list(self.sfa_desc.stride)
+ sfa_stride[5] = stride_tensor_m_128
+ sfa_tensor = self._make_fake_cute_tensor(
+ dtype=self.sfa_desc.dtype,
+ shape=tuple(sfa_shape),
+ stride=tuple(sfa_stride),
+ assumed_align=16,
+ )
+ sfd_row_tensor = None
+ if self.sfd_row_desc is not None:
+ stride_sfd_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_row_tensor = self._make_fake_cute_tensor(
+ dtype=self.sfd_row_desc.dtype,
+ shape=(32, 4, tensor_m_128, 4, self.sfd_row_desc.shape[4], 1),
+ stride=(16, 4, self.sfd_row_desc.stride[2], 1, 512, stride_sfd_m),
+ assumed_align=16,
+ )
+ sfd_col_tensor = None
+ if self.sfd_col_desc is not None:
+ rest_m = cute.sym_int(divisibility=1)
+ stride_sfd_n = cute.sym_int(divisibility=32 * 4 * 4)
+ stride_rest_m = cute.sym_int(divisibility=32 * 4 * 4)
+ sfd_col_tensor = self._make_fake_cute_tensor(
+ dtype=self.sfd_col_desc.dtype,
+ shape=(32, 4, self.sfd_col_desc.shape[2], 4, rest_m, 1),
+ stride=(16, 4, stride_rest_m, 1, 512, stride_sfd_n),
+ assumed_align=16,
+ )
+ amax_tensor = self._make_fake_cute_tensor_from_desc(self.amax_desc, assumed_align=16)
+ norm_const_tensor_cute = self._make_fake_cute_tensor_from_desc(self.norm_const_desc, assumed_align=16)
+ padded_offsets_tensor = self._make_fake_cute_tensor_from_desc(self.padded_offsets_desc, assumed_align=16)
+ alpha_tensor = self._make_fake_cute_tensor_from_desc(self.alpha_desc, assumed_align=16)
+ prob_tensor = self._make_fake_cute_tensor(
+ dtype=self.prob_desc.dtype,
+ shape=(valid_m, *self.prob_desc.shape[1:]),
+ stride=self.prob_desc.stride,
+ assumed_align=16,
+ )
+ bias_cute_fake = self._make_fake_cute_tensor_from_desc(self.bias_desc, assumed_align=16)
+
+ b_ptrs_placeholder = torch.empty((self.expert_cnt,), dtype=torch.int64, device="cuda")
+ sfb_ptrs_placeholder = torch.empty((self.expert_cnt,), dtype=torch.int64, device="cuda")
+ b_ptrs_cute = from_dlpack(b_ptrs_placeholder, assumed_align=8).iterator
+ sfb_ptrs_cute = from_dlpack(sfb_ptrs_placeholder, assumed_align=8).iterator
+ workspace_ptr_cute = from_dlpack(self._workspace, assumed_align=128).iterator
+
+ self._logger.debug("Compiling discrete grouped_gemm_srelu kernel")
+ _compiled_kernel = cute.compile(
+ gemm_srelu,
+ a=a_tensor,
+ b=b_ptrs_cute,
+ sfb=sfb_ptrs_cute,
+ n=cutlass.Int32(n),
+ k=cutlass.Int32(k),
+ b_stride_size=cutlass.Int64(b_stride_size),
+ b_major_mode=b_major_mode,
+ workspace_ptr=workspace_ptr_cute,
+ c=c_tensor,
+ d=d_tensor,
+ d_col=d_col_tensor,
+ sfa=sfa_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor_cute,
+ padded_offsets=padded_offsets_tensor,
+ alpha=alpha_tensor,
+ bias=bias_cute_fake,
+ prob=prob_tensor,
+ max_active_clusters=max_active_clusters,
+ stream=fake_stream,
+ epilogue_op=lambda x: x,
+ options="--enable-tvm-ffi",
+ )
+
+ cached_workspace_ptr = from_dlpack(self._workspace, assumed_align=128).iterator
+ cached_n = cutlass.Int32(n)
+ cached_k = cutlass.Int32(k)
+ cached_b_stride = cutlass.Int64(b_stride_size)
+
+ def tensor_api(
+ a_tensor: torch.Tensor,
+ b_ptrs_device: torch.Tensor,
+ sfb_ptrs_device: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_tensor: torch.Tensor,
+ d_col_tensor: Optional[torch.Tensor],
+ sfa_tensor: torch.Tensor,
+ sfd_row_tensor: Optional[torch.Tensor],
+ sfd_col_tensor: Optional[torch.Tensor],
+ amax_tensor: Optional[torch.Tensor],
+ norm_const_tensor: Optional[torch.Tensor],
+ padded_offsets: torch.Tensor,
+ alpha_tensor: torch.Tensor,
+ prob_tensor: Optional[torch.Tensor],
+ bias_tensor: Optional[torch.Tensor],
+ stream: cuda.CUstream,
+ ) -> None:
+ norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const")
+ b_ptrs_addr = int(b_ptrs_device.data_ptr())
+ sfb_ptrs_addr = int(sfb_ptrs_device.data_ptr())
+ _compiled_kernel(
+ a_tensor,
+ b_ptrs_addr,
+ sfb_ptrs_addr,
+ cached_n,
+ cached_k,
+ cached_b_stride,
+ cached_workspace_ptr,
+ c_tensor,
+ d_tensor,
+ d_col_tensor,
+ sfa_tensor,
+ sfd_row_tensor,
+ sfd_col_tensor,
+ amax_tensor,
+ norm_const_tensor,
+ padded_offsets,
+ alpha_tensor,
+ bias_tensor,
+ prob_tensor,
+ stream,
+ )
+
+ self._compiled_kernel = tensor_api
+
+ def execute(
+ self,
+ a_tensor: torch.Tensor,
+ sfa_tensor: torch.Tensor,
+ padded_offsets: torch.Tensor,
+ alpha_tensor: torch.Tensor,
+ c_tensor: torch.Tensor,
+ d_tensor: torch.Tensor,
+ # Dense mode:
+ b_tensor: Optional[torch.Tensor] = None,
+ sfb_tensor: Optional[torch.Tensor] = None,
+ bias_tensor: Optional[torch.Tensor] = None,
+ # Discrete mode:
+ b_ptrs: Optional[torch.Tensor] = None,
+ sfb_ptrs: Optional[torch.Tensor] = None,
+ d_col_tensor: Optional[torch.Tensor] = None,
+ sfd_row_tensor: Optional[torch.Tensor] = None,
+ sfd_col_tensor: Optional[torch.Tensor] = None,
+ amax_tensor: Optional[torch.Tensor] = None,
+ norm_const_tensor: Optional[torch.Tensor] = None,
+ prob_tensor: Optional[torch.Tensor] = None,
+ current_stream: Optional[cuda.CUstream] = None,
+ ) -> None:
+ """Execute the compiled kernel.
+
+ :param a_tensor: Input A tensor
+ :param sfa_tensor: Scale factor A
+ :param padded_offsets: End offset per expert after padding
+ :param alpha_tensor: Per-group scaling factors
+ :param c_tensor: Output C tensor before SReLU
+ :param d_tensor: Output D tensor
+ :param b_tensor: (Dense) Input B tensor (weights)
+ :param sfb_tensor: (Dense) Scale factor B
+ :param bias_tensor: Optional bias tensor with shape (n, l) and stride (1, n).
+ Bias fusion is specialized at compile time: if ``sample_bias`` was omitted
+ at construction, ``bias_tensor`` must also be omitted at execute time.
+ :param b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers
+ :param sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers
+        :param d_col_tensor: Optional column-wise output tensor (used for low-precision D output)
+ :param sfd_row_tensor: Optional row scale factor D
+ :param sfd_col_tensor: Optional column scale factor D
+ :param amax_tensor: Optional amax tensor
+ :param norm_const_tensor: Optional normalization constant
+ :param prob_tensor: Probability tensor for per-row gating. Required.
+ :param current_stream: CUDA stream
+ """
+ self._logger.debug("Entering execute")
+ current_stream = self._get_default_stream(current_stream)
+
+ if a_tensor.shape[0] == 0:
+ self._logger.debug("execute: valid_m is zero, skipping kernel execution")
+ return
+ self._runtime_error_if(
+ self._compiled_kernel is None,
+ "Kernel not compiled; call compile() first",
+ )
+
+ if d_col_tensor is None:
+ self._value_error_if(
+ self.generate_sfd,
+ "d_col_tensor is required when SFD outputs are generated",
+ )
+ d_col_tensor = d_tensor
+ self._value_error_if(
+ prob_tensor is None,
+ "prob_tensor is required: the kernel unconditionally multiplies output by per-row gating probability. "
+ "Pass a tensor of ones with shape (valid_m, 1, 1) if no gating is needed.",
+ )
+ if self._has_bias:
+ self._value_error_if(
+ bias_tensor is None,
+ "bias_tensor must be provided at execute() when the API was compiled with sample_bias",
+ )
+ else:
+ self._value_error_if(
+ bias_tensor is not None,
+ "bias_tensor must be omitted at execute() when the API was compiled without sample_bias",
+ )
+
+ self._logger.debug("Executing grouped_gemm_srelu kernel")
+ if self.weight_mode == MoEWeightMode.DENSE:
+ self._compiled_kernel(
+ a_tensor=a_tensor,
+ b_tensor=b_tensor,
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ d_col_tensor=d_col_tensor,
+ sfa_tensor=sfa_tensor,
+ sfb_tensor=sfb_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ prob_tensor=prob_tensor,
+ bias_tensor=bias_tensor,
+ stream=current_stream,
+ )
+ else:
+ self._compiled_kernel(
+ a_tensor=a_tensor,
+ b_ptrs_device=b_ptrs,
+ sfb_ptrs_device=sfb_ptrs,
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ d_col_tensor=d_col_tensor,
+ sfa_tensor=sfa_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ prob_tensor=prob_tensor,
+ bias_tensor=bias_tensor,
+ stream=current_stream,
+ )
+
+ self._logger.debug("Execute completed")
+
+
+import logging
+
+_logger = logging.getLogger(__name__)
+_cache_of_GroupedGemmSreluSm100Objects = {}
+
+
+def grouped_gemm_srelu_wrapper_sm100(
+ a_tensor: torch.Tensor,
+ b_tensor: Optional[torch.Tensor] = None,
+ sfa_tensor: Optional[torch.Tensor] = None,
+ sfb_tensor: Optional[torch.Tensor] = None,
+ padded_offsets: Optional[torch.Tensor] = None,
+ alpha_tensor: Optional[torch.Tensor] = None,
+ bias_tensor: Optional[torch.Tensor] = None,
+ b_ptrs: Optional[torch.Tensor] = None,
+ sfb_ptrs: Optional[torch.Tensor] = None,
+ n: Optional[int] = None,
+ b_dtype: Optional[torch.dtype] = None,
+ b_major: str = "k",
+ norm_const_tensor: Optional[torch.Tensor] = None,
+ prob_tensor: Optional[torch.Tensor] = None,
+ acc_dtype: torch.dtype = torch.float32,
+ c_dtype: torch.dtype = torch.bfloat16,
+ d_dtype: torch.dtype = torch.bfloat16,
+ cd_major: str = "n",
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
+ cluster_shape_mn: Optional[Tuple[int, int]] = None,
+ sf_vec_size: int = 16,
+ vector_f32: bool = False,
+ m_aligned: int = 256,
+ discrete_col_sfd: bool = False,
+ use_dynamic_sched: bool = False,
+ current_stream: Optional[cuda.CUstream] = None,
+) -> TupleDict:
+ """Convenience wrapper for grouped GEMM SReLU operation.
+
+ This function creates the API, compiles, and executes in one call.
+ Compiled kernels are cached for reuse when called with the same configuration.
+
+ Args:
+ a_tensor: Input A tensor (valid_m, k, 1)
+ sfa_tensor: Scale factor A
+ padded_offsets: End offset per expert after padding (l,)
+ alpha_tensor: Per-group scaling
+ b_tensor: (Dense) Weight B tensor (n, k, l)
+ sfb_tensor: (Dense) Scale factor B
+ bias_tensor: Optional per-expert bias, shape ``(n, l)`` in dense mode or ``(n, num_experts)``
+ in discrete mode, stride ``(1, n)``. Bias fusion requires ``mma_tiler_mn[1] == 256``.
+ b_ptrs: (Discrete) 1-D int64 device tensor of per-expert B data pointers
+ sfb_ptrs: (Discrete) 1-D int64 device tensor of per-expert SFB data pointers
+ n: (Discrete) B weight N dimension
+ b_dtype: (Discrete) B weight data type
+ b_major: (Discrete) B tensor major dimension ("k" or "n")
+        norm_const_tensor: Optional normalization constant. Required when FP8 input
+            configurations (a_tensor.dtype is FP8 and sfa_tensor.dtype is FP8) are combined
+            with a low-precision (FP8/FP4) d_dtype. Ignored when d_dtype is not low precision;
+            should be None for FP4/BF16 input configurations.
+ prob_tensor: Probability tensor for per-row gating (shape `(valid_m, 1, 1)`).
+ This argument is required. Pass a tensor of ones when no gating is needed.
+ acc_dtype: Accumulator data type
+ c_dtype: Output C tensor data type
+ d_dtype: Output D tensor data type
+ cd_major: CD major dimension (only "n"-major layout is supported)
+ mma_tiler_mn: MMA tiler shape
+ cluster_shape_mn: Cluster shape
+ sf_vec_size: Scale factor vector size
+ vector_f32: Use vectorized f32
+ m_aligned: M alignment (must be 256)
+        discrete_col_sfd: Enable discrete col-major scale factor tensor
+        use_dynamic_sched: Enable dynamic tile scheduling
+        current_stream: CUDA stream
+
+ Returns:
+ TupleDict: A dictionary-like object containing output tensors that can also be unpacked as a tuple.
+ Dictionary keys (also the unpacking order):
+ - **c_tensor** (torch.Tensor): Accumulator output tensor
+ - **d_tensor** (torch.Tensor): Final output tensor
+ - **d_col_tensor** (torch.Tensor or None): Column-wise output tensor for low-precision D output
+ - **amax_tensor** (torch.Tensor or None): Absolute maximum values (for SReLU output quantization)
+ - **sfd_row_tensor** (torch.Tensor or None): Row-wise scale factors for D (FP8 only)
+ - **sfd_col_tensor** (torch.Tensor or None): Column-wise scale factors for D (FP8 only)
+
+ Example usage::
+
+ # Dictionary-style access
+ result = grouped_gemm_srelu_wrapper_sm100(...)
+ c = result["c_tensor"]
+ d = result["d_tensor"]
+
+ # Tuple unpacking
+ c, d, d_col, amax, sfd_row, sfd_col = grouped_gemm_srelu_wrapper_sm100(...)
+
+ # Integer indexing
+ c = result[0] # c_tensor
+ d = result[1] # d_tensor
+ """
+ from cudnn.discrete_grouped_gemm.discrete_kernel_utils import _require_pointer_tensor
+
+ is_dense = b_tensor is not None
+ is_discrete = b_ptrs is not None
+
+ if is_dense and is_discrete:
+ raise ValueError("Provide either (b_tensor, sfb_tensor) or (b_ptrs, sfb_ptrs), not both")
+ if not is_dense and not is_discrete:
+ raise ValueError("Must provide either (b_tensor, sfb_tensor) or (b_ptrs, sfb_ptrs)")
+
+ valid_m, k_physical, _ = a_tensor.shape
+ if is_dense:
+ weight_mode = MoEWeightMode.DENSE
+ n_out, _, l = b_tensor.shape
+ if bias_tensor is not None and tuple(bias_tensor.shape) != (n_out, l):
+ raise ValueError(f"bias_tensor must have shape {(n_out, l)}, got {tuple(bias_tensor.shape)}")
+ else:
+ weight_mode = MoEWeightMode.DISCRETE
+ _require_pointer_tensor(b_ptrs, "b_ptrs")
+ num_experts = b_ptrs.shape[0]
+ _require_pointer_tensor(sfb_ptrs, "sfb_ptrs", num_experts)
+ if n is None or b_dtype is None:
+ raise ValueError("n and b_dtype are required for discrete mode")
+ k_logical = k_physical * 2 if b_dtype in (torch.float4_e2m1fn_x2, torch.uint8) else k_physical
+ b_shape = (n, k_logical)
+ n_out = n
+ l = num_experts
+ if bias_tensor is not None and tuple(bias_tensor.shape) != (n_out, num_experts):
+ raise ValueError(f"bias_tensor must have shape {(n_out, num_experts)}, got {tuple(bias_tensor.shape)}")
+
+ is_fp8_input_config = a_tensor.dtype in [
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ ] and sfa_tensor.dtype in [
+ torch.float8_e8m0fnu,
+ torch.float8_e4m3fn,
+ ]
+ is_low_precision_output_config = d_dtype in [
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ torch.float4_e2m1fn_x2,
+ ]
+
+ _logger.debug("grouped_gemm_srelu_wrapper_sm100: Creating output tensors")
+
+ if cd_major == "n":
+ c_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=c_dtype, device=a_tensor.device)
+ d_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device)
+ d_col_tensor = (
+ torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device)
+ if is_low_precision_output_config
+ else None
+ )
+ else:
+ raise ValueError(f"cd_major must be 'n', got {cd_major}")
+
+ sfd_row_tensor = None
+ sfd_col_tensor = None
+ amax_tensor = None
+
+ if is_fp8_input_config and is_low_precision_output_config and norm_const_tensor is None:
+ raise ValueError(
+ "norm_const_tensor is required when FP8 inputs are used with FP8 output "
+ "(a_tensor is FP8 and sfa_tensor is FP8 and d_dtype is FP8). "
+ "Pass a tensor with shape (1,), e.g. torch.tensor([0.01], dtype=torch.float32, device=a_tensor.device)."
+ )
+
+ if not is_low_precision_output_config:
+ norm_const_tensor = None
+
+ if is_fp8_input_config and is_low_precision_output_config:
+ _logger.debug("grouped_gemm_srelu_wrapper_sm100: Detected fp8 a_dtype and sfa_dtype, constructing sfd_row_tensor and sfd_col_tensor")
+
+ sf_dtype = sfa_tensor.dtype
+ mma_permute_order = (3, 4, 1, 5, 2, 0)
+
+ sf_k_row = ceil_div(n_out, sf_vec_size)
+ mma_shape_row = (
+ 1,
+ ceil_div(valid_m, 128),
+ ceil_div(sf_k_row, 4),
+ 32,
+ 4,
+ 4,
+ )
+ sfd_row_tensor = torch.empty(mma_shape_row, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order)
+
+ sf_k_col = ceil_div(valid_m, sf_vec_size)
+ mma_shape_col = (
+ 1,
+ ceil_div(n_out, 128),
+ ceil_div(sf_k_col, 4),
+ 32,
+ 4,
+ 4,
+ )
+ sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order)
+
+ if d_dtype in [torch.bfloat16, torch.float16]:
+ _logger.debug("grouped_gemm_srelu_wrapper_sm100: Detected bf16/float16 d_dtype, constructing amax_tensor")
+ amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device)
+
+ if prob_tensor is None:
+ raise ValueError(
+ "prob_tensor is required: the kernel unconditionally multiplies output by per-row gating probability. "
+ "Pass a tensor of ones with shape (valid_m, 1, 1) if no gating is needed."
+ )
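+    # If no gating is needed, callers can pass a tensor of ones, as sketched below
+    # (the dtype is illustrative; use whatever floating-point dtype the gating
+    # probabilities are produced in):
+    #     prob_tensor = torch.ones((valid_m, 1, 1), dtype=torch.float32, device=a_tensor.device)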
+
+ if valid_m == 0:
+ _logger.debug("grouped_gemm_srelu_wrapper_sm100: valid_m is zero, skipping kernel execution")
+ return TupleDict(
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ d_col_tensor=d_col_tensor,
+ amax_tensor=amax_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ )
+
+ def tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return tuple(tensor.shape), tuple(tensor.stride()), tensor.dtype
+
+ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
+ return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1]))
+
+ def dynamic_tensor_signature(tensor: Optional[torch.Tensor]) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ return None, stride_order(tensor), tensor.dtype
+
+ def dynamic_m_tensor_signature(
+ tensor: Optional[torch.Tensor], static_shape_suffix: Tuple[int, ...], dynamic_stride_dims: Tuple[int, ...] = ()
+ ) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]], Optional[torch.dtype]]:
+ if tensor is None:
+ return None, None, None
+ stride_signature = tuple(None if i in dynamic_stride_dims else s for i, s in enumerate(tensor.stride()))
+ return static_shape_suffix, stride_signature, tensor.dtype
+
+ use_full_dynamic = is_dense and os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
+
+ if is_dense:
+ cache_key = (
+ weight_mode,
+ use_full_dynamic,
+ a_tensor.shape[1:] if not use_full_dynamic else None,
+ b_tensor.shape[2] if use_full_dynamic else tuple(b_tensor.shape),
+ a_tensor.dtype,
+ b_tensor.dtype,
+ stride_order(a_tensor),
+ stride_order(b_tensor),
+ c_tensor.shape[1:] if not use_full_dynamic else None,
+ stride_order(c_tensor),
+ c_tensor.dtype,
+ d_tensor.shape[1:] if not use_full_dynamic else None,
+ stride_order(d_tensor),
+ *(
+ dynamic_tensor_signature(sfa_tensor)
+ if use_full_dynamic
+ else dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,))
+ ),
+ *(dynamic_tensor_signature(sfb_tensor) if use_full_dynamic else tensor_signature(sfb_tensor)),
+ *(dynamic_tensor_signature(bias_tensor) if use_full_dynamic else tensor_signature(bias_tensor)),
+ *tensor_signature(alpha_tensor),
+ *tensor_signature(norm_const_tensor),
+ *dynamic_m_tensor_signature(prob_tensor, (1, 1)),
+ tuple(padded_offsets.shape),
+ tuple(padded_offsets.stride()),
+ padded_offsets.dtype,
+ acc_dtype,
+ d_dtype,
+ cd_major,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ sf_vec_size,
+ vector_f32,
+ m_aligned,
+ discrete_col_sfd,
+ use_dynamic_sched,
+ )
+ else:
+ cache_key = (
+ weight_mode,
+ a_tensor.shape[1:],
+ stride_order(a_tensor),
+ a_tensor.dtype,
+ b_shape,
+ b_dtype,
+ c_tensor.shape[1:],
+ stride_order(c_tensor),
+ c_tensor.dtype,
+ d_tensor.shape[1:],
+ stride_order(d_tensor),
+ *dynamic_m_tensor_signature(sfa_tensor, (sfa_tensor.shape[4], 1) if sfa_tensor is not None else None, dynamic_stride_dims=(5,)),
+ *tensor_signature(bias_tensor),
+ *tensor_signature(alpha_tensor),
+ *tensor_signature(norm_const_tensor),
+ *dynamic_m_tensor_signature(prob_tensor, (1, 1)),
+ tuple(b_ptrs.shape),
+ tuple(b_ptrs.stride()),
+ b_ptrs.dtype,
+ tuple(sfb_ptrs.shape),
+ tuple(sfb_ptrs.stride()),
+ sfb_ptrs.dtype,
+ tuple(padded_offsets.shape),
+ tuple(padded_offsets.stride()),
+ padded_offsets.dtype,
+ acc_dtype,
+ d_dtype,
+ cd_major,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ sf_vec_size,
+ vector_f32,
+ m_aligned,
+ discrete_col_sfd,
+ use_dynamic_sched,
+ b_major,
+ num_experts,
+ )
+
+ if cache_key in _cache_of_GroupedGemmSreluSm100Objects:
+ _logger.debug("grouped_gemm_srelu_wrapper_sm100: Using previously cached GroupedGemmSreluSm100 object")
+ grouped_gemm_srelu = _cache_of_GroupedGemmSreluSm100Objects[cache_key]
+ else:
+ _logger.debug("grouped_gemm_srelu_wrapper_sm100: No previously cached object found, creating new GroupedGemmSreluSm100 object")
+ if is_dense:
+ grouped_gemm_srelu = GroupedGemmSreluSm100(
+ sample_a=a_tensor,
+ sample_sfa=sfa_tensor,
+ sample_padded_offsets=padded_offsets,
+ sample_alpha=alpha_tensor,
+ sample_c=c_tensor,
+ sample_d=d_tensor,
+ sample_d_col=d_col_tensor,
+ sample_b=b_tensor,
+ sample_sfb=sfb_tensor,
+ sample_bias=bias_tensor,
+ sample_amax=amax_tensor,
+ sample_sfd_row=sfd_row_tensor,
+ sample_sfd_col=sfd_col_tensor,
+ sample_norm_const=norm_const_tensor,
+ sample_prob=prob_tensor,
+ acc_dtype=acc_dtype,
+ mma_tiler_mn=mma_tiler_mn,
+ cluster_shape_mn=cluster_shape_mn,
+ sf_vec_size=sf_vec_size,
+ vector_f32=vector_f32,
+ m_aligned=m_aligned,
+ discrete_col_sfd=discrete_col_sfd,
+ use_dynamic_sched=use_dynamic_sched,
+ )
+ else:
+ grouped_gemm_srelu = GroupedGemmSreluSm100(
+ sample_a=a_tensor,
+ sample_sfa=sfa_tensor,
+ sample_padded_offsets=padded_offsets,
+ sample_alpha=alpha_tensor,
+ sample_c=c_tensor,
+ sample_d=d_tensor,
+ sample_d_col=d_col_tensor,
+ num_experts=num_experts,
+ b_shape=b_shape,
+ b_dtype=b_dtype,
+ sample_bias=bias_tensor,
+ sample_amax=amax_tensor,
+ sample_sfd_row=sfd_row_tensor,
+ sample_sfd_col=sfd_col_tensor,
+ sample_norm_const=norm_const_tensor,
+ sample_prob=prob_tensor,
+ acc_dtype=acc_dtype,
+ mma_tiler_mn=mma_tiler_mn,
+ cluster_shape_mn=cluster_shape_mn,
+ sf_vec_size=sf_vec_size,
+ vector_f32=vector_f32,
+ m_aligned=m_aligned,
+ discrete_col_sfd=discrete_col_sfd,
+ use_dynamic_sched=use_dynamic_sched,
+ b_major=b_major,
+ )
+
+ assert grouped_gemm_srelu.check_support(), "Unsupported configuration"
+ grouped_gemm_srelu.compile()
+ _cache_of_GroupedGemmSreluSm100Objects[cache_key] = grouped_gemm_srelu
+
+ if is_dense:
+ grouped_gemm_srelu.execute(
+ a_tensor=a_tensor,
+ sfa_tensor=sfa_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ b_tensor=b_tensor,
+ sfb_tensor=sfb_tensor,
+ d_col_tensor=d_col_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ prob_tensor=prob_tensor,
+ bias_tensor=bias_tensor,
+ current_stream=current_stream,
+ )
+ else:
+ grouped_gemm_srelu.execute(
+ a_tensor=a_tensor,
+ sfa_tensor=sfa_tensor,
+ padded_offsets=padded_offsets,
+ alpha_tensor=alpha_tensor,
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ b_ptrs=b_ptrs,
+ sfb_ptrs=sfb_ptrs,
+ d_col_tensor=d_col_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ amax_tensor=amax_tensor,
+ norm_const_tensor=norm_const_tensor,
+ prob_tensor=prob_tensor,
+ bias_tensor=bias_tensor,
+ current_stream=current_stream,
+ )
+
+ return TupleDict(
+ c_tensor=c_tensor,
+ d_tensor=d_tensor,
+ d_col_tensor=d_col_tensor,
+ amax_tensor=amax_tensor,
+ sfd_row_tensor=sfd_row_tensor,
+ sfd_col_tensor=sfd_col_tensor,
+ )
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_srelu/moe_blockscaled_grouped_gemm_srelu_quant.py b/python/cudnn/grouped_gemm/grouped_gemm_srelu/moe_blockscaled_grouped_gemm_srelu_quant.py
new file mode 100644
index 00000000..e1c3245b
--- /dev/null
+++ b/python/cudnn/grouped_gemm/grouped_gemm_srelu/moe_blockscaled_grouped_gemm_srelu_quant.py
@@ -0,0 +1,2139 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""
+MoE Block-Scaled Grouped GEMM Kernel with Quantization and SReLU Support.
+
+Supports:
+ - Static / Dynamic persistent tile scheduling (MoEPersistentTileScheduler)
+ - Dense (contiguous 3-D B) / Discrete (per-expert pointer array B) weight layout
+ - FP8/FP4 output quantization with row/column scale factors (SFD)
+ - Optional bias and routing-probability (prob) fusion
+ - Optional C output (generate_c)
+ - AMAX reduction for FP8 calibration
+ - SReLU epilogue activation fusion (max(x,0)^2)
+
+This module contains only the kernel class.
+MoE scheduler components live in moe_persistent_scheduler.py / moe_sched_extension.py / moe_utils.py.
+"""
+
+from enum import Enum
+from typing import Type, Tuple, Union, Optional
+
+import cuda.bindings.driver as cuda
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
+import cutlass.utils as utils
+import cutlass.pipeline as pipeline
+import cutlass.utils.blackwell_helpers as sm100_utils
+import cutlass.utils.blockscaled_layout as blockscaled_utils
+from cutlass._mlir.dialects.nvvm import ReduxKind
+from cutlass.cute.typing import Float32, Int32, AddressSpace
+from ..moe_persistent_scheduler import (
+ MoEPersistentTileScheduler,
+ MoESchedulerParams,
+ MoEWorkTileInfo,
+)
+from ..moe_utils import (
+ compute_expert_token_range,
+ MoEWeightMode,
+ TensormapWorkspace,
+ store_tma_desc,
+)
+from ..moe_sched_extension import (
+ DiscreteWeightScaledGemmSchedExtension,
+ ContiguousAndConsistentGroupedGemmSchedExtension,
+)
+from ..moe_kernel_helpers import (
+ fmin,
+ fmax,
+ warp_redux_sync,
+ atomic_max_float32,
+ compute_stages,
+ compute_grid,
+ can_implement,
+ amax_reduction_per_thread,
+ epilog_gmem_copy_and_partition,
+ get_dtype_rcp_limits,
+)
+
+
+class EpilogueType(Enum):
+ NONE = 0
+ SRELU = 1
+
+
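+# For reference, the SReLU epilogue computes max(x, 0)^2 element-wise (see the module
+# docstring). An unfused PyTorch sketch of the same math, for illustration only:
+#
+#     def srelu_reference(x):
+#         return torch.clamp(x, min=0.0) ** 2
+
+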
+class BlockScaledMoEGroupedGemmQuantKernel:
+ """Block-scaled grouped GEMM kernel with MoE tile scheduling and quantization.
+
+ Supports both dense and discrete weight layouts, static and dynamic
+ scheduling, and quantized output with row/column scale factors.
+
+ :param sf_vec_size: Scale-factor vector size (16 or 32).
+ :param acc_dtype: Accumulator data type (Float32).
+ :param use_2cta_instrs: Use 2-CTA MMA instructions.
+ :param mma_tiler_mn: MMA tile shape (M, N).
+ :param cluster_shape_mn: Cluster shape (M, N).
+ :param vectorized_f32: Use packed FP32 arithmetic.
+ :param generate_sfd: Generate output scale factors.
+ :param discrete_col_sfd: Use discrete column SFD layout.
+ :param generate_c: Generate C output tensor.
+ :param enable_bias: Fuse bias addition.
+ :param expert_cnt: Number of experts.
+ :param weight_mode: ``MoEWeightMode.DENSE`` or ``MoEWeightMode.DISCRETE``.
+ :param use_dynamic_sched: Enable dynamic tile scheduling.
+    :param epilogue_type: Epilogue activation type, passed as the enum's integer value (``EpilogueType.NONE.value`` or ``EpilogueType.SRELU.value``).
+ """
+
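+    # Fixed per-expert M padding/alignment granularity. Host-side callers are expected
+    # to pass m_aligned == FIX_PAD_SIZE, and __init__ checks that it is divisible by
+    # mma_tiler_mn[0].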
+ FIX_PAD_SIZE = 256
+
+ @staticmethod
+ def can_implement(
+ ab_dtype: Type[cutlass.Numeric],
+ sf_dtype: Type[cutlass.Numeric],
+ sf_vec_size: int,
+ acc_dtype: Type[cutlass.Numeric],
+ d_dtype: Type[cutlass.Numeric],
+ use_2cta_instrs: bool,
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ m: int,
+ n: int,
+ k: int,
+ l: int,
+ a_major: str,
+ b_major: str,
+ cd_major: str,
+ m_aligned: int,
+ ) -> bool:
+ return can_implement(
+ ab_dtype,
+ sf_dtype,
+ sf_vec_size,
+ acc_dtype,
+ d_dtype,
+ use_2cta_instrs,
+ mma_tiler_mn,
+ cluster_shape_mn,
+ m,
+ n,
+ k,
+ l,
+ a_major,
+ b_major,
+ cd_major,
+ m_aligned,
+ fix_pad_size=BlockScaledMoEGroupedGemmQuantKernel.FIX_PAD_SIZE,
+ )
+
+ def __init__(
+ self,
+ sf_vec_size: int,
+ acc_dtype: Type[cutlass.Numeric],
+ use_2cta_instrs: bool,
+ mma_tiler_mn: Tuple[int, int],
+ cluster_shape_mn: Tuple[int, int],
+ vectorized_f32: bool,
+ generate_sfd: bool,
+ discrete_col_sfd: bool,
+ generate_c: bool,
+ enable_bias: bool,
+ expert_cnt: int,
+ weight_mode: MoEWeightMode = MoEWeightMode.DENSE,
+ use_dynamic_sched: bool = False,
+ epilogue_type: int = EpilogueType.NONE.value,
+ ):
+ mma_tile_m = mma_tiler_mn[0]
+ if self.FIX_PAD_SIZE % mma_tile_m != 0:
+ raise ValueError(
+ f"FIX_PAD_SIZE ({self.FIX_PAD_SIZE}) must be divisible by " f"mma_tiler_mn[0] ({mma_tile_m}). " f"Supported mma_tiler_mn[0] values: 128, 256."
+ )
+ if expert_cnt > 1024:
+ raise ValueError("Expert count > 1024 is not supported.")
+ if not isinstance(weight_mode, MoEWeightMode):
+ raise TypeError(f"weight_mode must be a MoEWeightMode, got {type(weight_mode)}")
+
+ self.sf_vec_size = sf_vec_size
+ self.expert_cnt = expert_cnt
+ self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
+ self.use_2cta_instrs = use_2cta_instrs
+ self.cluster_shape_mn = cluster_shape_mn
+ self.mma_tiler = (*mma_tiler_mn, 1)
+
+ self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
+
+ self.occupancy = 1
+ self.epilog_warp_id = (0, 1, 2, 3)
+ self.mma_warp_id = 4
+ self.tma_warp_id = 5
+ self.sched_warp_id = 6
+ self.bias_load_warp_id = 7 if enable_bias else None
+ self.threads_per_warp = 32
+ all_warps = [
+ *self.epilog_warp_id,
+ self.mma_warp_id,
+ self.tma_warp_id,
+ self.sched_warp_id,
+ ]
+ warps_wo_sched = [*self.epilog_warp_id, self.mma_warp_id, self.tma_warp_id]
+ if enable_bias:
+ all_warps.append(self.bias_load_warp_id)
+ warps_wo_sched.append(self.bias_load_warp_id)
+
+ self.threads_per_cta = self.threads_per_warp * len(all_warps)
+ self.threads_wo_sched = self.threads_per_warp * len(warps_wo_sched)
+
+ self.cta_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=1,
+ num_threads=self.threads_per_cta,
+ )
+ self.epilog_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=2,
+ num_threads=32 * len(self.epilog_warp_id),
+ )
+ self.tmem_alloc_barrier = pipeline.NamedBarrier(
+ barrier_id=3,
+ num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)),
+ )
+ self.sched_sync_barrier = pipeline.NamedBarrier(
+ barrier_id=4,
+ num_threads=self.threads_per_warp,
+ )
+ self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
+ SM100_TMEM_CAPACITY_COLUMNS = 512
+ self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
+
+ self.vectorized_f32 = vectorized_f32
+ self.generate_sfd = generate_sfd
+ self.discrete_col_sfd = discrete_col_sfd
+ self.generate_c = generate_c
+ self.enable_bias = enable_bias
+
+ self.weight_mode = weight_mode
+ self.use_dynamic_sched = use_dynamic_sched
+
+ self.epilogue_use_functor = False
+ self.epilogue_type = epilogue_type
+
+ self.num_epilog_warps = len(self.epilog_warp_id)
+
+ # ------------------------------------------------------------------
+ # _setup_attributes
+ # ------------------------------------------------------------------
+
+ def _setup_attributes(self):
+ """Configure MMA / tile / stage / SMEM layouts from GEMM inputs."""
+
+ self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1])
+ self.mma_inst_shape_mn_sfb = (
+ self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
+ cute.round_up(self.mma_inst_shape_mn[1], 128),
+ )
+
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+
+ mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
+ mma_inst_tile_k = 4
+ self.mma_tiler = (
+ self.mma_tiler[0],
+ self.mma_tiler[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.mma_tiler_sfb = (
+ self.mma_inst_shape_mn_sfb[0],
+ self.mma_inst_shape_mn_sfb[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+
+ self.cta_tile_shape_mnk = (
+ self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler[1],
+ self.mma_tiler[2],
+ )
+ self.cta_tile_shape_mnk_sfb = (
+ self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler_sfb[1],
+ self.mma_tiler_sfb[2],
+ )
+
+ self.mma_tiler_d = (
+ self.mma_inst_shape_mn[0],
+ self.mma_inst_shape_mn[1],
+ mma_inst_shape_k * mma_inst_tile_k,
+ )
+ self.cta_tile_shape_mnk_d = (
+ self.mma_tiler_d[0] // cute.size(tiled_mma.thr_id.shape),
+ self.mma_tiler_d[1],
+ self.mma_tiler_d[2],
+ )
+
+ self.cluster_layout_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma.thr_id.shape,),
+ )
+ self.cluster_layout_sfb_vmnk = cute.tiled_divide(
+ cute.make_layout((*self.cluster_shape_mn, 1)),
+ (tiled_mma_sfb.thr_id.shape,),
+ )
+
+ self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+ self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
+ self.is_a_mcast = self.num_mcast_ctas_a > 1
+ self.is_b_mcast = self.num_mcast_ctas_b > 1
+
+ self.epi_tile = (128, 32)
+
+ (
+ self.num_acc_stage,
+ self.num_ab_stage,
+ self.num_c_stage,
+ self.num_d_stage,
+ self.num_tile_stage,
+ self.num_bias_stage,
+ ) = self._compute_stages(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.b_dtype,
+ self.epi_tile,
+ self.c_dtype,
+ self.c_layout,
+ self.d_dtype,
+ self.d_layout,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.num_smem_capacity,
+ self.occupancy,
+ self.generate_sfd,
+ self.generate_c,
+ self.bias_dtype if self.enable_bias else None,
+ )
+
+ self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
+ tiled_mma,
+ self.mma_tiler,
+ self.a_dtype,
+ self.num_ab_stage,
+ )
+ self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
+ tiled_mma,
+ self.mma_tiler,
+ self.b_dtype,
+ self.num_ab_stage,
+ )
+ self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ self.num_ab_stage,
+ )
+ self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.c_dtype,
+ self.c_layout,
+ self.epi_tile,
+ self.num_c_stage,
+ )
+ self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi(
+ self.d_dtype,
+ self.d_layout,
+ self.epi_tile,
+ self.num_d_stage,
+ )
+
+ if self.enable_bias:
+ self.bias_smem_layout_staged = cute.make_layout(
+ (self.mma_tiler[1], self.num_bias_stage),
+ stride=(1, self.mma_tiler[1]),
+ )
+ else:
+ self.bias_smem_layout_staged = cute.make_layout((1, 1))
+
+ self.overlapping_accum = self.num_acc_stage == 1 and self.mma_tiler[1] == 256
+
+ sf_atom_mn = 32
+ self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
+ self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
+ self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
+ self.num_accumulator_tmem_cols = (
+ self.cta_tile_shape_mnk[1] * self.num_acc_stage if not self.overlapping_accum else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
+ )
+
+ self.epi_tile_n_required = cute.size(self.epi_tile[1])
+ self.iter_acc_early_release_in_epilogue = (self.num_sf_tmem_cols + self.epi_tile_n_required - 1) // self.epi_tile_n_required - 1
+
+ # ------------------------------------------------------------------
+ # _compute_stages (with bias support)
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def _compute_stages(
+ tiled_mma,
+ mma_tiler_mnk,
+ a_dtype,
+ b_dtype,
+ epi_tile,
+ c_dtype,
+ c_layout,
+ d_dtype,
+ d_layout,
+ sf_dtype,
+ sf_vec_size,
+ num_smem_capacity,
+ occupancy,
+ generate_sfd,
+ generate_c,
+ bias_dtype,
+ ):
+ num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2
+ num_c_stage = 2 if generate_sfd else 1
+ num_d_stage = 2 if generate_sfd else 1
+ num_tile_stage = 2
+
+ a_smem_layout_stage_one = sm100_utils.make_smem_layout_a(tiled_mma, mma_tiler_mnk, a_dtype, 1)
+ b_smem_layout_staged_one = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler_mnk, b_dtype, 1)
+ sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa(tiled_mma, mma_tiler_mnk, sf_vec_size, 1)
+ sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb(tiled_mma, mma_tiler_mnk, sf_vec_size, 1)
+ c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1)
+ d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1)
+
+ ab_bytes_per_stage = (
+ cute.size_in_bytes(a_dtype, a_smem_layout_stage_one)
+ + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one)
+ + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one)
+ + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one)
+ )
+ mbar_helpers_bytes = 1024
+ sinfo_bytes = 4 * 4 * num_tile_stage
+ c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one)
+ c_bytes = c_bytes_per_stage * num_c_stage
+ d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
+ d_bytes = d_bytes_per_stage * num_d_stage * (2 if generate_sfd else 1)
+ amax_bytes = 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,))) if d_dtype == cutlass.BFloat16 else 0
+
+ if bias_dtype is not None:
+ num_bias_stage = 2
+ bias_epi_tile_n = mma_tiler_mnk[1]
+ bias_bytes = bias_epi_tile_n * num_bias_stage * (bias_dtype.width // 8)
+ else:
+ num_bias_stage = 0
+ bias_bytes = 0
+
+ epi_bytes = c_bytes + d_bytes + amax_bytes + bias_bytes
+ num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes + sinfo_bytes)) // ab_bytes_per_stage
+
+ # cute.printf("num_acc_stage: %d, num_ab_stage: %d, num_c_stage: %d, num_d_stage: %d, num_tile_stage: %d, num_bias_stage: %d\n", num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage, num_bias_stage)
+ return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage, num_bias_stage
+
+ # ------------------------------------------------------------------
+ # Workspace helpers
+ # ------------------------------------------------------------------
+
+ def get_desc_workspace_bytes(self) -> int:
+ if self.weight_mode == MoEWeightMode.DISCRETE:
+ from ..moe_utils import DiscreteWeightTensormapConstructor
+
+ return DiscreteWeightTensormapConstructor.get_workspace_size(self.expert_cnt)
+ return 0
+
+ def get_workspace_bytes(self) -> int:
+ desc_workspace_bytes = self.get_desc_workspace_bytes()
+ dynamic_sched_bytes = 4 if self.use_dynamic_sched else 0
+ return desc_workspace_bytes + dynamic_sched_bytes
+
+ @cute.jit
+ def _get_sched_counter_ptr(self, workspace_ptr):
+ counter_addr = workspace_ptr.toint() + self.get_desc_workspace_bytes()
+ return cute.make_ptr(
+ cutlass.Int32,
+ counter_addr,
+ AddressSpace.gmem,
+ assumed_align=4,
+ )
+
+ # ------------------------------------------------------------------
+ # helper_kernel: pre-main-kernel initialization
+ # - discrete weight: build per-expert B/SFB TMA descriptors
+ # - dynamic sched: reset the atomic tile counter
+ # ------------------------------------------------------------------
+
+ @cute.kernel
+ def helper_kernel(
+ self,
+ # Discrete-only params (unused in dense mode, but must be present for signature)
+ ptrs_b: cute.Pointer,
+ ptrs_sfb: cute.Pointer,
+ n: Int32,
+ k: Int32,
+ b_stride_size: cutlass.Int64,
+ b_major_mode: cutlass.Constexpr,
+ workspace_ptr,
+ tiled_mma_arg: cute.TiledMma,
+ tiled_mma_sfb_arg: cute.TiledMma,
+ b_smem_layout_arg,
+ sfb_smem_layout_arg,
+ cluster_layout_vmnk_shape_arg: cutlass.Constexpr,
+ cluster_layout_sfb_vmnk_shape_arg: cutlass.Constexpr,
+ ):
+ """Pre-main-kernel initialization.
+
+ Launched with grid=(expert_cnt, 1, 1) for discrete mode, or
+ grid=(1, 1, 1) for dense+dynamic mode.
+
+ Discrete weight: each block builds B/SFB TMA descriptors for one expert.
+ Dynamic sched: block 0 resets the atomic tile counter to 0.
+ """
+ expert_idx = cute.arch.block_idx()[0]
+
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE):
+ b_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_arg.thr_id)
+ sfb_tma_op_arg = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma_arg.thr_id)
+
+ b_ptr_tensor = cute.make_tensor(
+ cute.make_ptr(cutlass.Int64, ptrs_b.toint(), AddressSpace.gmem, assumed_align=8), cute.make_layout((self.expert_cnt,))
+ )
+ sfb_ptr_tensor = cute.make_tensor(
+ cute.make_ptr(cutlass.Int64, ptrs_sfb.toint(), AddressSpace.gmem, assumed_align=8), cute.make_layout((self.expert_cnt,))
+ )
+
+ c1 = cutlass.Int32(1)
+ c0 = cutlass.Int64(0)
+ c1_64 = 1
+ if cutlass.const_expr(b_major_mode == OperandMajorMode.K):
+ stride_n = b_stride_size
+ stride_k = c1_64
+ else:
+ stride_n = c1_64
+ stride_k = b_stride_size
+
+ b_ptr_val = b_ptr_tensor[expert_idx]
+ b_ptr = cute.make_ptr(self.b_dtype, b_ptr_val, AddressSpace.gmem)
+ b_tensor_i = cute.make_tensor(
+ b_ptr,
+ cute.make_layout((n, k, c1), stride=(stride_n, stride_k, c0)),
+ )
+ tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
+ b_tma_op_arg,
+ b_tensor_i,
+ b_smem_layout_arg,
+ self.mma_tiler,
+ tiled_mma_arg,
+ cluster_layout_vmnk_shape_arg,
+ )
+ workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"])
+ store_tma_desc(tma_atom_b, workspace.get_ptr("b", expert_idx))
+
+ sfb_ptr_val = sfb_ptr_tensor[expert_idx]
+ sfb_ptr = cute.make_ptr(self.sf_dtype, sfb_ptr_val, AddressSpace.gmem)
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size)
+ sfb_tensor_i = cute.make_tensor(sfb_ptr, sfb_layout)
+ tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
+ sfb_tma_op_arg,
+ sfb_tensor_i,
+ sfb_smem_layout_arg,
+ self.mma_tiler_sfb,
+ tiled_mma_sfb_arg,
+ cluster_layout_sfb_vmnk_shape_arg,
+ internal_type=cutlass.Uint64,
+ )
+ store_tma_desc(tma_atom_sfb, workspace.get_ptr("sfb", expert_idx))
+
+ if cutlass.const_expr(self.use_dynamic_sched):
+ if expert_idx == cutlass.Int32(0):
+ sched_counter = cute.make_tensor(
+ self._get_sched_counter_ptr(workspace_ptr),
+ cute.make_layout(1),
+ )
+ sched_counter[0] = cutlass.Int32(0)
+
+ # ------------------------------------------------------------------
+ # __call__
+ # ------------------------------------------------------------------
+
+ @cute.jit
+ def __call__(
+ self,
+ a: cute.Tensor,
+ b, # Dense: cute.Tensor (N,K,L) | Discrete: cute.Pointer to int64[]
+ sfb, # Dense: cute.Tensor | Discrete: cute.Pointer to int64[]
+ n: Int32, # Ignored for dense mode
+ k: Int32, # Ignored for dense mode
+ b_stride_size: cutlass.Int64, # Ignored for dense mode
+ b_major_mode: cutlass.Constexpr, # Ignored for dense mode
+ workspace_ptr,
+ c: cute.Tensor,
+ d: cute.Tensor,
+ d_col: Optional[cute.Tensor],
+ sfa: cute.Tensor,
+ sfd_row_tensor: Optional[cute.Tensor],
+ sfd_col_tensor: Optional[cute.Tensor],
+ amax_tensor: Optional[cute.Tensor],
+ norm_const_tensor: Optional[cute.Tensor],
+ padded_offsets: cute.Tensor,
+ alpha: cute.Tensor,
+ bias: Optional[cute.Tensor],
+ prob: cute.Tensor,
+ max_active_clusters: cutlass.Constexpr,
+ stream: cuda.CUstream,
+ epilogue_op: cutlass.Constexpr = lambda x: x,
+ ):
+ """Execute the GEMM.
+
+ Dense mode: ``b`` and ``sfb`` are 3-D cute.Tensor (N, K, L).
+ Discrete mode: ``b`` and ``sfb`` are cute.Pointer to device int64[]
+ arrays of per-expert base addresses; ``n``, ``k``, ``b_stride_size``,
+ ``b_major_mode`` describe the uniform per-expert layout.
+ """
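+        # Rough flow of this method: infer dtypes/layouts, build TMA atoms for
+        # A/B/SFA/SFB and C/D, optionally launch helper_kernel (discrete TMA
+        # descriptors and/or the dynamic-sched counter reset), compute the
+        # persistent grid via the MoE scheduler, declare SharedStorage, then
+        # launch the main kernel().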
+        self.a_dtype: Type[cutlass.Numeric] = a.element_type
+        self.c_dtype: Type[cutlass.Numeric] = c.element_type
+        self.d_dtype: Type[cutlass.Numeric] = d.element_type
+        self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type
+        self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
+        self.c_layout = utils.LayoutEnum.from_tensor(c)
+        self.d_layout = utils.LayoutEnum.from_tensor(d)
+        self.bias_dtype = bias.element_type if cutlass.const_expr(self.enable_bias) else cutlass.BFloat16
+
+        if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+            self.b_dtype: Type[cutlass.Numeric] = b.element_type
+            self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
+        else:
+            # Discrete mode passes B as a raw pointer array, so B is assumed to
+            # share A's element type; the layout comes from the b_major_mode arg.
+            self.b_dtype: Type[cutlass.Numeric] = a.element_type
+            self.b_major_mode = b_major_mode
+
+        if cutlass.const_expr(self.a_dtype != self.b_dtype):
+            raise TypeError(f"A/B dtype must match: {self.a_dtype} != {self.b_dtype}")
+
+ self._setup_attributes()
+
+ # ---- SFA layout ----
+ sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, self.sf_vec_size)
+ sfa = cute.make_tensor(sfa.iterator, sfa_layout)
+
+ # ---- B / SFB setup (mode-dependent) ----
+ # Save the call-arg b/sfb before the discrete branch overwrites them
+ # with template tensors. helper_kernel needs the original Pointers.
+ b_from_call_arg = b
+ sfb_from_call_arg = sfb
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size)
+ sfb = cute.make_tensor(sfb.iterator, sfb_layout)
+ else:
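+            # Discrete mode: build a "template" (N, K, 1) tensor from the uniform
+            # per-expert layout; the real per-expert base addresses are supplied by
+            # the TMA descriptors written in helper_kernel. Example (assumed
+            # shapes): a K-major expert weight of shape (N, K) stored contiguously
+            # has b_stride_size == K, giving stride (K, 1, 0) below.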
+ c1 = cutlass.Int32(1)
+ c0 = cutlass.Int64(0)
+ c1_64 = 1
+ if cutlass.const_expr(b_major_mode == OperandMajorMode.K):
+ b_template_stride = (b_stride_size, c1_64, c0)
+ else:
+ b_template_stride = (c1_64, b_stride_size, c0)
+ b_template_layout = cute.make_layout((n, k, c1), stride=b_template_stride)
+ b_ptr_typed = cute.make_ptr(self.b_dtype, b.toint(), AddressSpace.gmem, assumed_align=16)
+ b = cute.make_tensor(b_ptr_typed, b_template_layout)
+
+ sfb_ptr_typed = cute.make_ptr(self.sf_dtype, sfb.toint(), AddressSpace.gmem, assumed_align=16)
+ sfb_layout = blockscaled_utils.tile_atom_to_shape_SF((n, k, c1), self.sf_vec_size)
+ sfb = cute.make_tensor(sfb_ptr_typed, sfb_layout)
+
+ # ---- SFD setup ----
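+        # Row-wise SFD uses the block-scaled SF tiling over d.shape; column-wise
+        # SFD uses an MN-major BlockScaledBasicChunk tiled to d.shape, or simply
+        # reuses the row layout when discrete_col_sfd is set.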
+ self.generate_sfd = sfd_row_tensor is not None and norm_const_tensor is not None
+        if cutlass.const_expr(not self.generate_sfd):
+            self.discrete_col_sfd = False
+ if cutlass.const_expr(self.generate_sfd):
+ sfd_row_layout = blockscaled_utils.tile_atom_to_shape_SF(d.shape, self.sf_vec_size)
+ sfd_row_tensor = cute.make_tensor(sfd_row_tensor.iterator, sfd_row_layout)
+ sfd_col_layout = cute.tile_to_shape(
+ blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, OperandMajorMode.MN).layout,
+ d.shape,
+ (1, 2, 3),
+ )
+ if cutlass.const_expr(self.discrete_col_sfd):
+ sfd_col_layout = sfd_row_layout
+ sfd_col_tensor = cute.make_tensor(sfd_col_tensor.iterator, sfd_col_layout)
+
+ self.generate_amax = amax_tensor is not None
+
+ # ---- TMA atoms ----
+ tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ self.cta_group,
+ self.mma_inst_shape_mn,
+ )
+ tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
+ self.a_dtype,
+ self.a_major_mode,
+ self.b_major_mode,
+ self.sf_dtype,
+ self.sf_vec_size,
+ cute.nvgpu.tcgen05.CtaGroup.ONE,
+ self.mma_inst_shape_mn_sfb,
+ )
+ atom_thr_size = cute.size(tiled_mma.thr_id.shape)
+
+ a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
+ tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
+ a_op,
+ a,
+ a_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
+ b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
+ tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
+ b_op,
+ b,
+ b_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ )
+
+ sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
+ sfa_op,
+ sfa,
+ sfa_smem_layout,
+ self.mma_tiler,
+ tiled_mma,
+ self.cluster_layout_vmnk.shape,
+ internal_type=cutlass.Int16,
+ )
+
+ sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
+ sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
+ tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
+ sfb_op,
+ sfb,
+ sfb_smem_layout,
+ self.mma_tiler_sfb,
+ tiled_mma_sfb,
+ self.cluster_layout_sfb_vmnk.shape,
+ internal_type=cutlass.Uint64,
+ )
+
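+        # For 192-wide CTA tiles, re-group the SFB TMA tensor's inner N sub-mode
+        # into ((2, 2), ceil(orig / 4)) with strides ((x, x), 3 * x), where x is
+        # the original sub-mode stride; this is presumably what the 192-N
+        # scale-factor access pattern in the MMA warp (see the tCtSFB pointer
+        # offset below) expects.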
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
+ x = tma_tensor_sfb.stride[0][1]
+ y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4)
+ new_shape = (
+ (tma_tensor_sfb.shape[0][0], ((2, 2), y)),
+ tma_tensor_sfb.shape[1],
+ tma_tensor_sfb.shape[2],
+ )
+ x_times_3 = 3 * x
+ new_stride = (
+ (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)),
+ tma_tensor_sfb.stride[1],
+ tma_tensor_sfb.stride[2],
+ )
+ tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride)
+ tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout)
+
+ a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
+ b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
+ sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout)
+ sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout)
+ self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size
+
+ c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
+ tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ c,
+ c_smem_layout,
+ self.epi_tile,
+ )
+ d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0))
+ tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ d,
+ d_smem_layout,
+ self.epi_tile,
+ )
+ tma_atom_d_col, tma_tensor_d_col = cpasync.make_tiled_tma_atom(
+ cpasync.CopyBulkTensorTileS2GOp(),
+ d_col,
+ d_smem_layout,
+ self.epi_tile,
+ )
+
+ # ---- Helper kernel: TMA desc init (discrete) + sched counter reset (dynamic) ----
+ _need_helper = cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE or self.use_dynamic_sched)
+ if cutlass.const_expr(_need_helper):
+ _helper_grid_x = self.expert_cnt if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else 1
+ _helper_args = (
+ b_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem),
+ sfb_from_call_arg if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cute.make_ptr(cutlass.Int64, 0, AddressSpace.gmem),
+ n if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0),
+ k if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int32(0),
+ b_stride_size if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else cutlass.Int64(0),
+ b_major_mode if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE) else self.b_major_mode,
+ workspace_ptr,
+ tiled_mma,
+ tiled_mma_sfb,
+ b_smem_layout,
+ sfb_smem_layout,
+ self.cluster_layout_vmnk.shape,
+ self.cluster_layout_sfb_vmnk.shape,
+ )
+ self.helper_kernel(*_helper_args).launch(
+ grid=(_helper_grid_x, 1, 1),
+ block=(1, 1, 1),
+ stream=stream,
+ min_blocks_per_mp=1,
+ )
+
+ # ---- Grid computation via MoE scheduler ----
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ b_n, b_k, b_l = cute.shape(b) # B is (N, K, L)
+ sched_expert_shape = (self.expert_cnt, b_n, b_k)
+ else:
+ sched_expert_shape = (self.expert_cnt, n, k)
+
+ sched_params = MoESchedulerParams(
+ scenario="2Dx3D",
+ expert_shape=sched_expert_shape,
+ cta_tile_shape_mnk=self.cta_tile_shape_mnk,
+ cluster_shape_mn=self.cluster_shape_mn,
+ use_dynamic_sched=self.use_dynamic_sched,
+ )
+ self.sched_params, grid = compute_grid(sched_params, max_active_clusters, self.use_2cta_instrs)
+
+ self.buffer_align_bytes = 1024
+
+ # ---- Shared storage ----
+ sD_col_size = cute.cosize(self.d_smem_layout_staged.outer) if self.generate_sfd else 0
+ SchedulerStorage = MoEPersistentTileScheduler.make_storage_struct(self.num_tile_stage, self.use_dynamic_sched)
+
+ @cute.struct
+ class SharedStorage:
+ ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
+ acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
+ scheduler: SchedulerStorage
+ if cutlass.const_expr(self.enable_bias):
+ bias_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_bias_stage * 2]
+ tmem_dealloc_mbar_ptr: cutlass.Int64
+ tmem_holding_buf: cutlass.Int32
+ sC: cute.struct.Align[
+ cute.struct.MemRange[self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sD: cute.struct.Align[
+ cute.struct.MemRange[self.d_dtype, cute.cosize(self.d_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sD_col: cute.struct.Align[
+ cute.struct.MemRange[self.d_dtype, sD_col_size],
+ self.buffer_align_bytes,
+ ]
+ sA: cute.struct.Align[
+ cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sB: cute.struct.Align[
+ cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
+ self.buffer_align_bytes,
+ ]
+ sSFA: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+ sSFB: cute.struct.Align[
+ cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)],
+ self.buffer_align_bytes,
+ ]
+ if cutlass.const_expr(self.enable_bias):
+ sBias: cute.struct.Align[
+ cute.struct.MemRange[self.bias_dtype, cute.cosize(self.bias_smem_layout_staged)],
+ 16,
+ ]
+ sAmax: cute.struct.Align[
+ cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps],
+ 4,
+ ]
+
+ self.shared_storage = SharedStorage
+
+ # ---- Launch ----
+ self.kernel(
+ tiled_mma,
+ tiled_mma_sfb,
+ tma_atom_a,
+ tma_tensor_a,
+ tma_atom_b,
+ tma_tensor_b,
+ tma_atom_sfa,
+ tma_tensor_sfa,
+ tma_atom_sfb,
+ tma_tensor_sfb,
+ tma_atom_c,
+ tma_tensor_c,
+ tma_atom_d,
+ tma_tensor_d,
+ tma_atom_d_col,
+ tma_tensor_d_col,
+ sfd_row_tensor,
+ sfd_col_tensor,
+ norm_const_tensor,
+ amax_tensor,
+ padded_offsets,
+ alpha,
+ bias,
+ prob,
+ workspace_ptr,
+ self.cluster_layout_vmnk,
+ self.cluster_layout_sfb_vmnk,
+ self.a_smem_layout_staged,
+ self.b_smem_layout_staged,
+ self.sfa_smem_layout_staged,
+ self.sfb_smem_layout_staged,
+ self.c_smem_layout_staged,
+ self.d_smem_layout_staged,
+ self.bias_smem_layout_staged,
+ self.epi_tile,
+ self.sched_params,
+ epilogue_op,
+ ).launch(
+ grid=grid,
+ block=[self.threads_per_cta, 1, 1],
+ cluster=(*self.cluster_shape_mn, 1),
+ max_number_threads=[self.threads_per_cta, 1, 1],
+ smem=self.shared_storage.size_in_bytes(),
+ stream=stream,
+ min_blocks_per_mp=1,
+ )
+ return
+
+ # ------------------------------------------------------------------
+ # Helper methods
+ # ------------------------------------------------------------------
+
+ def mainloop_s2t_copy_and_partition(self, sSF, tSF):
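+        # Build the SMEM -> TMEM (s2t) copy for scale factors: filter out
+        # broadcast zero-stride modes, create a Cp4x32x128b copy atom, and
+        # partition the source (SMEM descriptor tensor) and destination (TMEM).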
+ tCsSF_compact = cute.filter_zeros(sSF)
+ tCtSF_compact = cute.filter_zeros(tSF)
+ copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype)
+ tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
+ thr_copy_s2t = tiled_copy_s2t.get_slice(0)
+ tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
+ tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_)
+ tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
+ return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
+
+ @cute.jit
+ def amax_reduction_per_warp_and_cta(self, amax_fp32, warp_idx, amax_smem, amax_gmem):
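+        # Two-level amax reduction: redux.max within the warp, lane 0 stages the
+        # warp result in SMEM, then (after the epilogue barrier) warp 0 / lane 0
+        # folds the warp results and atomically maxes the block value into gmem.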
+ warp_amax = warp_redux_sync(
+ value=amax_fp32,
+ kind=ReduxKind.MAX,
+ mask_and_clamp=0xFFFFFFFF,
+ nan=True,
+ )
+ if cute.arch.lane_idx() == 0:
+ amax_smem[warp_idx] = cutlass.Float32(warp_amax)
+ self.epilog_sync_barrier.arrive_and_wait()
+ if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0:
+ block_amax = cutlass.Float32(0.0)
+ for i in cutlass.range(self.num_epilog_warps):
+ warp_amax_val = amax_smem[i]
+ block_amax = cute.arch.fmax(block_amax, warp_amax_val)
+ _ = atomic_max_float32(ptr=amax_gmem, value=block_amax)
+
+ @cute.jit
+ def store_c(
+ self,
+ tiled_copy_r2s,
+ tma_atom_c,
+ warp_idx,
+ tTR_rAcc,
+ tRS_rC,
+ tRS_sC,
+ bSG_gC,
+ bSG_sC,
+ c_pipeline,
+ prev_subtile_idx,
+ real_subtile_idx,
+ ):
+ c_buffer = prev_subtile_idx % self.num_c_stage
+ tRS_rC.store(tTR_rAcc.load().to(self.c_dtype))
+ cute.copy(tiled_copy_r2s, tRS_rC[(None, None, 0)], tRS_sC[(None, None, 0, c_buffer)])
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.epilog_sync_barrier.arrive_and_wait()
+ if warp_idx == self.epilog_warp_id[0]:
+ cute.copy(tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, real_subtile_idx)])
+ c_pipeline.producer_commit()
+ c_pipeline.producer_acquire()
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ @cute.jit
+ def quant_sfd_row(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD):
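+        # Row-wise quantization sketch: per sf_vec_size-wide group, take
+        # amax(|acc|) * rcp_limit * norm_const as the scale, round-trip it through
+        # sf_dtype to get the representable scale, then multiply the accumulator
+        # by norm_const * rcp(decoded scale) before the final cast to d_dtype.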
+ tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size))
+ acc_frg = tTR_rAcc_frg.load()
+ abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value())
+ abs_acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype)
+ pvscale_f32x4 = cute.make_rmem_tensor(4, cutlass.Float32)
+ sfd_f8x4 = cute.make_rmem_tensor(4, self.sf_dtype)
+ tmp_f32 = abs_acc_frg[None, 0].reduce(cute.ReductionOp.MAX, cutlass.Float32(0.0), 0) * rcp_limit * norm_const
+ if tile_idx == 0:
+ pvscale[0] = tmp_f32
+ elif tile_idx == 1:
+ pvscale[1] = tmp_f32
+ elif tile_idx == 2:
+ pvscale[2] = tmp_f32
+ elif tile_idx == 3:
+ pvscale[3] = tmp_f32
+ pvscale_f32x4[0] = tmp_f32
+ sfd_f8x4.store(pvscale_f32x4.load().to(self.sf_dtype))
+ pvscale_f32x4.store(sfd_f8x4.load().to(cutlass.Float32))
+ qpvscale_up = pvscale_f32x4[0]
+ fp32_max = cutlass.Float32(3.40282346638528859812e38)
+ acc_scale = norm_const * cute.arch.rcp_approx(qpvscale_up)
+ acc_scale = fmin(acc_scale, fp32_max, nan=True)
+ if cutlass.const_expr(self.vectorized_f32):
+ vec = tTR_rAcc_frg[None, 0]
+ for ei in cutlass.range_constexpr(0, self.sf_vec_size, 2):
+ vec[ei], vec[ei + 1] = cute.arch.mul_packed_f32x2(
+ (vec[ei], vec[ei + 1]),
+ (acc_scale, acc_scale),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ vec = tTR_rAcc_frg[None, 0]
+ for ei in cutlass.range_constexpr(self.sf_vec_size):
+ vec[ei] = vec[ei] * acc_scale
+ acc_vec = tiled_copy_r2s.retile(src).load()
+ tRSrD.store(acc_vec.to(self.d_dtype))
+
+ @cute.jit
+ def quant_sfd_col(self, tile_idx, tiled_copy_r2s, src, pvscale, norm_const, rcp_limit, tRSrD):
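+        # Column-wise variant: |acc| amax is taken across the warp (lanes cover
+        # the M direction) per fragment index via redux.max, the scale is
+        # round-tripped through Float8E8M0FNU, and each fragment is rescaled
+        # before the cast to d_dtype; lane vi keeps the per-column scale in
+        # pvscale.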
+ tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size))
+ acc_frg = tTR_rAcc_frg.load()
+ abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value())
+ acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype)
+ tmp_f32 = cutlass.Float32(0.0)
+ for vi in cutlass.range_constexpr(acc_frg.shape[0]):
+ max_value_original = (
+ cutlass.Float32(
+ warp_redux_sync(
+ value=acc_frg[vi, 0],
+ kind=ReduxKind.MAX,
+ mask_and_clamp=0xFFFFFFFF,
+ nan=True,
+ )
+ )
+ * rcp_limit
+ * norm_const
+ )
+ max_value_vec = cute.full(4, max_value_original, dtype=cutlass.Float32)
+ max_value_vec_f8 = max_value_vec.to(cutlass.Float8E8M0FNU)
+ max_value_vec_f32_chunked = max_value_vec_f8.to(cutlass.Float32)
+ max_value = max_value_vec_f32_chunked[0]
+ tidx = cute.arch.thread_idx()[0]
+ if tidx % 32 == vi:
+ tmp_f32 = max_value
+ acc_scale_col = cutlass.Float32(0.0)
+            if max_value_vec_f32_chunked[0] == 0.0:
+ acc_scale_col = cutlass.Float32(0.0)
+ else:
+ acc_scale_col = norm_const * cute.arch.rcp_approx(max_value_vec_f32_chunked[0])
+ fp32_max = cutlass.Float32(3.40282346638528859812e38)
+ acc_scale_col = fmin(acc_scale_col, fp32_max)
+ tTR_rAcc_frg[vi] = tTR_rAcc_frg[vi] * acc_scale_col
+ pvscale[None, None, tile_idx][0] = tmp_f32
+ acc_vec = tiled_copy_r2s.retile(src).load()
+ tRSrD.store(acc_vec.to(self.d_dtype))
+
+ @cute.jit
+ def tile_info_to_mn_idx(self, tile_info: cute.Tensor):
+ m_idx = tile_info[1] * cute.size(self.cta_tile_shape_mnk[0])
+ n_idx = tile_info[2] * cute.size(self.cta_tile_shape_mnk[1])
+ return m_idx, n_idx
+
+ @cute.jit
+ def create_and_partition_new_SFDCol(self, tile_info, mSFDCol_mnl, padded_offsets):
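+        # Discrete per-expert column SFD: rebase the output at this expert's
+        # padded token offset, rebuild the SF layout over (tokens_this_group, N),
+        # and partition it for 8-bit universal copies with a (4, 32) thread layout.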
+ m_idx, n_idx = self.tile_info_to_mn_idx(tile_info)
+ expert_idx = tile_info[0]
+ cumsum_tokens, tokens_this_group = compute_expert_token_range(padded_offsets, expert_idx)
+ n_total = cute.size(mSFDCol_mnl.shape[1])
+
+ sf_tile_idx_begin = cumsum_tokens // cute.size(mSFDCol_mnl.shape[0][0])
+ mSFDCol_mnl_new_ptr = mSFDCol_mnl[(None, sf_tile_idx_begin), None, 0].iterator
+
+ sfd_col_quant_layout = cute.tile_to_shape(
+ blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, OperandMajorMode.MN).layout,
+ (tokens_this_group, n_total, mSFDCol_mnl.shape[2]),
+ (1, 2, 3),
+ )
+ regPerSubtile = 4
+ sfd_tile = (cute.make_layout(128), cute.make_layout(32 * regPerSubtile))
+ mSFDCol_mnl_new = cute.make_tensor(mSFDCol_mnl_new_ptr, sfd_col_quant_layout)
+ gSFDCol_mnl_new = cute.local_tile(mSFDCol_mnl_new, sfd_tile, (None, None, None))
+
+ thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
+ val_layout = cute.make_ordered_layout((1,), order=(0,))
+ copy_atom_sfd_col_quant = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gSFDCol_mnl_new.element_type,
+ num_bits_per_copy=8,
+ )
+ tiled_copy_sfd_col_quant = cute.make_tiled_copy_tv(
+ copy_atom_sfd_col_quant,
+ thr_layout,
+ val_layout,
+ )
+ tidx = cute.arch.thread_idx()[0]
+ thr_copy_sfd_col_quant = tiled_copy_sfd_col_quant.get_slice(tidx)
+ tCgSFDCol_mnl = thr_copy_sfd_col_quant.partition_D(cute.filter_zeros(gSFDCol_mnl_new))
+ tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl)
+ return tCgSFDCol_mnl
+
+ def epilog_tmem_copy_and_partition(self, tidx, tAcc, gD_mnl, epi_tile, use_2cta_instrs):
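+        # Partition the TMEM accumulator for TMEM -> register loads over epi_tile
+        # subtiles, and allocate a matching register fragment (tTR_rAcc) shaped by
+        # the per-subtile D partition.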
+ copy_atom_t2r = sm100_utils.get_tmem_load_op(
+ self.cta_tile_shape_mnk,
+ self.d_layout,
+ self.d_dtype,
+ self.acc_dtype,
+ epi_tile,
+ use_2cta_instrs,
+ )
+ tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile)
+ tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
+ thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
+ tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
+ gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile)
+ tTR_gC = thr_copy_t2r.partition_D(gD_mnl_epi)
+ tTR_rAcc = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype)
+ return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
+
+ def epilog_smem_copy_and_partition(self, tiled_copy_t2r, tTR_rD, tidx, sD):
+ copy_atom_r2s = sm100_utils.get_smem_store_op(self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r)
+ tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
+ thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
+ tRS_sD = thr_copy_r2s.partition_D(sD)
+ tRS_rD = tiled_copy_r2s.retile(tTR_rD)
+ return tiled_copy_r2s, tRS_rD, tRS_sD
+
+ # ------------------------------------------------------------------
+ # GPU device kernel
+ # ------------------------------------------------------------------
+
+ @cute.kernel
+ def kernel(
+ self,
+ tiled_mma: cute.TiledMma,
+ tiled_mma_sfb: cute.TiledMma,
+ tma_atom_a: cute.CopyAtom,
+ mA_mkl: cute.Tensor,
+ tma_atom_b: cute.CopyAtom,
+ mB_nkl: cute.Tensor,
+ tma_atom_sfa: cute.CopyAtom,
+ mSFA_mkl: cute.Tensor,
+ tma_atom_sfb: cute.CopyAtom,
+ mSFB_nkl: cute.Tensor,
+ tma_atom_c: cute.CopyAtom,
+ mC_mnl: cute.Tensor,
+ tma_atom_d: cute.CopyAtom,
+ mD_mnl: cute.Tensor,
+ tma_atom_d_col: cute.CopyAtom,
+ mD_col_mnl: cute.Tensor,
+ mSFDRow_mnl: Optional[cute.Tensor],
+ mSFDCol_mnl: Optional[cute.Tensor],
+ norm_const_tensor: Optional[cute.Tensor],
+ mAmax_tensor: Optional[cute.Tensor],
+ padded_offsets: cute.Tensor,
+ alpha: cute.Tensor,
+ mBias_nl: Optional[cute.Tensor],
+ prob: cute.Tensor,
+ workspace_ptr,
+ cluster_layout_vmnk: cute.Layout,
+ cluster_layout_sfb_vmnk: cute.Layout,
+ a_smem_layout_staged: cute.ComposedLayout,
+ b_smem_layout_staged: cute.ComposedLayout,
+ sfa_smem_layout_staged: cute.Layout,
+ sfb_smem_layout_staged: cute.Layout,
+ c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
+ d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
+ bias_smem_layout_staged: Optional[cute.Layout],
+ epi_tile: cute.Tile,
+ sched_params: MoESchedulerParams,
+ epilogue_op: cutlass.Constexpr,
+ ):
+ """GPU device kernel for persistent MoE grouped GEMM with quantization."""
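+        # Warp specialization in this kernel: a scheduler warp produces tile_info
+        # into SMEM, an optional bias-load warp stages per-expert bias via
+        # cp.async, a TMA warp feeds the A/B/SFA/SFB pipeline, an MMA warp issues
+        # blockscaled tcgen05 MMAs into TMEM, and the epilogue warps scale,
+        # quantize, and store C/D (plus optional SFD and amax).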
+ warp_idx = cute.arch.warp_idx()
+ warp_idx = cute.arch.make_warp_uniform(warp_idx)
+ lane_idx = cute.arch.lane_idx()
+
+ if warp_idx == self.tma_warp_id:
+ cpasync.prefetch_descriptor(tma_atom_a)
+ cpasync.prefetch_descriptor(tma_atom_sfa)
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DENSE):
+ cpasync.prefetch_descriptor(tma_atom_b)
+ cpasync.prefetch_descriptor(tma_atom_sfb)
+ cpasync.prefetch_descriptor(tma_atom_d)
+ if cutlass.const_expr(self.generate_sfd):
+ cpasync.prefetch_descriptor(tma_atom_d_col)
+ if cutlass.const_expr(self.generate_c):
+ cpasync.prefetch_descriptor(tma_atom_c)
+
+ use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
+ total_token = padded_offsets[self.expert_cnt - 1]
+
+ bidx, bidy, bidz = cute.arch.block_idx()
+ mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
+ is_leader_cta = mma_tile_coord_v == 0
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
+ block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster)
+ tidx, _, _ = cute.arch.thread_idx()
+
+ smem = utils.SmemAllocator()
+ storage = smem.allocate(self.shared_storage)
+ sched_storage = storage.scheduler
+
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
+ ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer)
+ ab_pipeline = pipeline.PipelineTmaUmma.create(
+ barrier_storage=storage.ab_mbar_ptr.data_ptr(),
+ num_stages=self.num_ab_stage,
+ producer_group=ab_pipeline_producer_group,
+ consumer_group=ab_pipeline_consumer_group,
+ tx_count=self.num_tma_load_bytes,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ )
+
+ acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+ num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1)
+ acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads)
+ acc_pipeline = pipeline.PipelineUmmaAsync.create(
+ barrier_storage=storage.acc_mbar_ptr.data_ptr(),
+ num_stages=self.num_acc_stage,
+ producer_group=acc_pipeline_producer_group,
+ consumer_group=acc_pipeline_consumer_group,
+ cta_layout_vmnk=cluster_layout_vmnk,
+ )
+
+ tile_info_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_per_warp * 1)
+ tile_info_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_wo_sched)
+ tile_info_pipeline = pipeline.PipelineAsync.create(
+ barrier_storage=sched_storage.tile_info_mbar.data_ptr(),
+ num_stages=self.num_tile_stage,
+ producer_group=tile_info_pipeline_producer_group,
+ consumer_group=tile_info_pipeline_consumer_group,
+ )
+
+ if cutlass.const_expr(self.enable_bias):
+ bias_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.threads_per_warp)
+ bias_pipeline_consumer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread,
+ self.threads_per_warp * len(self.epilog_warp_id),
+ )
+ bias_pipeline = pipeline.PipelineCpAsync.create(
+ barrier_storage=storage.bias_mbar_ptr.data_ptr(),
+ num_stages=self.num_bias_stage,
+ producer_group=bias_pipeline_producer_group,
+ consumer_group=bias_pipeline_consumer_group,
+ )
+ sBias = storage.sBias.get_tensor(bias_smem_layout_staged)
+ # (MMA_N, loopN, loopL) — tiled over mma_tile_n in N, full L; indexed
+ # directly by expert_idx (ref glu_bias line 1771-1773 pattern).
+ gBias_nl = cute.local_tile(mBias_nl, cute.slice_(self.mma_tiler[:2], (0, None)), (None, None))
+
+ scheduler = MoEPersistentTileScheduler.create(
+ sched_params,
+ padded_offsets,
+ cute.arch.block_idx(),
+ cute.arch.grid_dim(),
+ counter_ptr=self._get_sched_counter_ptr(workspace_ptr),
+ sched_storage=sched_storage,
+ )
+ scheduler.internal_init()
+
+ tmem = utils.TmemAllocator(
+ storage.tmem_holding_buf,
+ barrier_for_retrieve=self.tmem_alloc_barrier,
+ allocator_warp_id=self.epilog_warp_id[0],
+ is_two_cta=use_2cta_instrs,
+ two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
+ )
+
+ if cute.size(self.cluster_shape_mn) > 1:
+ cute.arch.cluster_arrive_relaxed()
+
+ sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner)
+ sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
+ sD_col = sD
+ if cutlass.const_expr(self.generate_sfd):
+ sD_col = storage.sD_col.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner)
+ sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
+ sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
+ sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
+ sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
+ amax_layout = cute.make_layout((self.num_epilog_warps,))
+ sAmax = storage.sAmax.get_tensor(amax_layout)
+ info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4))
+ sInfo = sched_storage.sInfo.get_tensor(info_layout)
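+        # sInfo protocol: each tile_info stage holds 4 Int32s
+        # [expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt]; the scheduler warp
+        # writes expert_idx = -1 as the termination sentinel that consumers test
+        # via tile_info[0] >= 0.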
+
+ # Multicast masks — must create ALL when any mcast or 2CTA is active
+ a_full_mcast_mask = None
+ b_full_mcast_mask = None
+ sfa_full_mcast_mask = None
+ sfb_full_mcast_mask = None
+ if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
+ a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1)
+ sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2)
+ sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1)
+
+ # MMA partition (for tCtAcc_fake shape computation only)
+ thr_mma_common = tiled_mma.get_slice(0)
+ tCsA_common = thr_mma_common.partition_A(sA)
+ tCsB_common = thr_mma_common.partition_B(sB)
+ tCsA_common = cute.filter_zeros(tCsA_common)
+ tCsB_common = cute.filter_zeros(tCsB_common)
+
+ # SMEM fragments for MMA (used by MMA warp)
+ tCrA = tiled_mma.make_fragment_A(sA)
+ tCrB = tiled_mma.make_fragment_B(sB)
+
+ # TMEM accumulator shape
+ acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
+ if cutlass.const_expr(self.overlapping_accum):
+ num_acc_stage_overlapped = 2
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage_overlapped))
+ tCtAcc_fake = cute.make_tensor(
+ tCtAcc_fake.iterator,
+ cute.make_layout(
+ tCtAcc_fake.shape,
+ stride=(
+ tCtAcc_fake.stride[0],
+ tCtAcc_fake.stride[1],
+ tCtAcc_fake.stride[2],
+ (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1],
+ ),
+ ),
+ )
+ else:
+ tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
+
+ # Cluster sync before warp specialization
+ if cute.size(self.cluster_shape_mn) > 1:
+ cute.arch.cluster_wait()
+ else:
+ self.cta_sync_barrier.arrive_and_wait()
+
+ if total_token <= 0:
+ cute.arch.nvvm.exit()
+
+ # ==============================================================
+ # Scheduler warp (MoE Persistent Tile Scheduler)
+ # ==============================================================
+ if warp_idx == self.sched_warp_id:
+ work_tile_info = scheduler.initial_work_tile_info()
+ tile_info_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_tile_stage)
+ while work_tile_info.is_valid_tile:
+ tile_info_pipeline.producer_acquire(tile_info_producer_state)
+ with cute.arch.elect_one():
+ sInfo[(0, tile_info_producer_state.index)] = work_tile_info.expert_idx
+ sInfo[(1, tile_info_producer_state.index)] = work_tile_info.tile_m_idx
+ sInfo[(2, tile_info_producer_state.index)] = work_tile_info.tile_n_idx
+ sInfo[(3, tile_info_producer_state.index)] = work_tile_info.k_tile_cnt
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.sched_sync_barrier.arrive_and_wait()
+ tile_info_pipeline.producer_commit(tile_info_producer_state)
+ tile_info_producer_state.advance()
+ work_tile_info = scheduler.advance_to_next_work()
+
+ tile_info_pipeline.producer_acquire(tile_info_producer_state)
+ with cute.arch.elect_one():
+ sInfo[(0, tile_info_producer_state.index)] = cutlass.Int32(-1)
+ sInfo[(1, tile_info_producer_state.index)] = cutlass.Int32(0)
+ sInfo[(2, tile_info_producer_state.index)] = cutlass.Int32(0)
+ sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.sched_sync_barrier.arrive_and_wait()
+ tile_info_pipeline.producer_commit(tile_info_producer_state)
+ tile_info_producer_state.advance()
+ tile_info_pipeline.producer_tail(tile_info_producer_state)
+
+ # ==============================================================
+ # Bias load warp
+ # ==============================================================
+ if cutlass.const_expr(self.enable_bias):
+ if warp_idx == self.bias_load_warp_id:
+ # Ported from ref glu_bias bias_load_warp (lines 2422-2483).
+ # No extension needed: mBias_nl has shape (N, L) and we index L
+ # directly by expert_idx from sInfo — much simpler than the old
+ # `bias_ext.get_gmem_tensor("bias", ...)` path (which reshaped the
+ # tensor and was unnecessary for a global, per-expert-sliced bias).
+ bias_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_bias_stage)
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+
+ # 128-bit cp.async: threads_per_warp × (128 / bias_dtype bits) = tile_N
+ bias_elems_per_thread = 128 // self.bias_dtype.width
+ bias_g2s_atom = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ self.bias_dtype,
+ num_bits_per_copy=128,
+ )
+ bias_g2s_tiled = cute.make_tiled_copy_tv(
+ bias_g2s_atom,
+ cute.make_layout((self.threads_per_warp,)),
+ cute.make_layout((bias_elems_per_thread,)),
+ )
+ thr_bias_g2s = bias_g2s_tiled.get_slice(cute.arch.lane_idx())
+ tBs_sBias = thr_bias_g2s.partition_D(sBias)
+
+ # Per-thread N predicate tensor
+ bias_n_total = mBias_nl.shape[0]
+ tBpBias = cute.make_rmem_tensor(cute.make_layout((1,)), cutlass.Boolean)
+
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ while is_valid_tile:
+ bias_producer_state.reset_count()
+ mma_n_coord = tile_info[2]
+ expert_idx = tile_info[0]
+
+ # Direct L-indexing — no ext, no domain_offset
+ gBias_tile = gBias_nl[(None, mma_n_coord, expert_idx)]
+ tBs_gBias = thr_bias_g2s.partition_S(gBias_tile)
+
+ # Predicate: this thread's chunk must be within valid N
+ tBpBias[0] = mma_n_coord * self.mma_tiler[1] + cute.arch.lane_idx() * bias_elems_per_thread < bias_n_total
+
+ bias_pipeline.producer_acquire(bias_producer_state)
+ cute.copy(
+ bias_g2s_tiled,
+ tBs_gBias[(None, 0)],
+ tBs_sBias[(None, 0, bias_producer_state.index)],
+ pred=tBpBias,
+ )
+ bias_pipeline.producer_commit(bias_producer_state)
+ bias_producer_state.advance()
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+ bias_pipeline.producer_tail(bias_producer_state)
+
+ # ==============================================================
+ # DMA / TMA load warp
+ # ==============================================================
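+        # Producer loop: consume tile_info, resolve this expert's gmem views (and
+        # TMA descriptor pointers in discrete mode) through the extension, then
+        # TMA-load A/B/SFA/SFB k-tiles into the AB pipeline stages.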
+ if warp_idx == self.tma_warp_id:
+ ext = self._make_extension(workspace_ptr)
+ ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage)
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ while is_valid_tile:
+ work_tile_info = MoEWorkTileInfo(
+ expert_idx=tile_info[0],
+ tile_m_idx=tile_info[1],
+ tile_n_idx=tile_info[2],
+ k_tile_cnt=tile_info[3],
+ )
+ k_tile_cnt = work_tile_info.k_tile_cnt
+ ext.update_expert_info(padded_offsets, work_tile_info.expert_idx)
+
+ real_a, _ = ext.get_gmem_tensor("a", mA_mkl, padded_offsets, work_tile_info)
+ real_b, desc_ptr_b = ext.get_gmem_tensor("b", mB_nkl, padded_offsets, work_tile_info)
+ real_sfa, _ = ext.get_gmem_tensor("sfa", mSFA_mkl, padded_offsets, work_tile_info)
+ real_sfb, desc_ptr_sfb = ext.get_gmem_tensor("sfb", mSFB_nkl, padded_offsets, work_tile_info)
+
+ gA_mkl = cute.local_tile(real_a, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ gB_nkl = cute.local_tile(real_b, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
+ gSFA_mkl = cute.local_tile(real_sfa, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
+ gSFB_nkl = cute.local_tile(real_sfb, cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None))
+
+ # MMA partition on gmem tensors
+ thr_mma_dma = tiled_mma.get_slice(mma_tile_coord_v)
+ thr_mma_sfb_dma = tiled_mma_sfb.get_slice(mma_tile_coord_v)
+ tCgA = thr_mma_dma.partition_A(gA_mkl)
+ tCgB = thr_mma_dma.partition_B(gB_nkl)
+ tCgSFA = thr_mma_dma.partition_A(gSFA_mkl)
+ tCgSFB = thr_mma_sfb_dma.partition_B(gSFB_nkl)
+
+ # TMA partition A
+ a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
+ tAsA, tAgA = cpasync.tma_partition(
+ tma_atom_a,
+ block_in_cluster_coord_vmnk[2],
+ a_cta_layout,
+ cute.group_modes(sA, 0, 3),
+ cute.group_modes(tCgA, 0, 3),
+ )
+ # TMA partition B
+ b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
+ tBsB, tBgB = cpasync.tma_partition(
+ tma_atom_b,
+ block_in_cluster_coord_vmnk[1],
+ b_cta_layout,
+ cute.group_modes(sB, 0, 3),
+ cute.group_modes(tCgB, 0, 3),
+ )
+ # TMA partition SFA
+ sfa_cta_layout = a_cta_layout
+ tAsSFA, tAgSFA = cpasync.tma_partition(
+ tma_atom_sfa,
+ block_in_cluster_coord_vmnk[2],
+ sfa_cta_layout,
+ cute.group_modes(sSFA, 0, 3),
+ cute.group_modes(tCgSFA, 0, 3),
+ )
+ tAsSFA = cute.filter_zeros(tAsSFA)
+ tAgSFA = cute.filter_zeros(tAgSFA)
+ # TMA partition SFB
+ sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
+ tBsSFB, tBgSFB = cpasync.tma_partition(
+ tma_atom_sfb,
+ block_in_cluster_coord_sfb_vmnk[1],
+ sfb_cta_layout,
+ cute.group_modes(sSFB, 0, 3),
+ cute.group_modes(tCgSFB, 0, 3),
+ )
+ tBsSFB = cute.filter_zeros(tBsSFB)
+ tBgSFB = cute.filter_zeros(tBgSFB)
+
+ mma_tile_coord_m = work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape)
+ mma_tile_coord_n = work_tile_info.tile_n_idx
+ tAgA_slice = tAgA[(None, mma_tile_coord_m, None, 0)]
+ tBgB_slice = tBgB[(None, mma_tile_coord_n, None, 0)]
+ tAgSFA_slice = tAgSFA[(None, mma_tile_coord_m, None, 0)]
+ slice_n = mma_tile_coord_n
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ slice_n = mma_tile_coord_n // 2
+ tBgSFB_slice = tBgSFB[(None, slice_n, None, 0)]
+
+ ab_producer_state.reset_count()
+ peek_ab_empty_status = cutlass.Boolean(1)
+ if ab_producer_state.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+
+ for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+ tAgA_k = tAgA_slice[(None, ab_producer_state.count)]
+ tBgB_k = tBgB_slice[(None, ab_producer_state.count)]
+ tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)]
+ tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)]
+ tAsA_pipe = tAsA[(None, ab_producer_state.index)]
+ tBsB_pipe = tBsB[(None, ab_producer_state.index)]
+ tAsSFA_pipe = tAsSFA[(None, ab_producer_state.index)]
+ tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)]
+
+ tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state)
+ ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status)
+
+ cute.copy(tma_atom_a, tAgA_k, tAsA_pipe, tma_bar_ptr=tma_bar, mcast_mask=a_full_mcast_mask)
+ cute.copy(tma_atom_b, tBgB_k, tBsB_pipe, tma_bar_ptr=tma_bar, mcast_mask=b_full_mcast_mask, tma_desc_ptr=desc_ptr_b)
+ cute.copy(tma_atom_sfa, tAgSFA_k, tAsSFA_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfa_full_mcast_mask)
+ cute.copy(tma_atom_sfb, tBgSFB_k, tBsSFB_pipe, tma_bar_ptr=tma_bar, mcast_mask=sfb_full_mcast_mask, tma_desc_ptr=desc_ptr_sfb)
+
+ ab_producer_state.advance()
+ peek_ab_empty_status = cutlass.Boolean(1)
+ if ab_producer_state.count < k_tile_cnt:
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+ ab_pipeline.producer_tail(ab_producer_state)
+
+ # ==============================================================
+ # MMA warp
+ # ==============================================================
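+        # Consumer/producer loop: wait for TMEM allocation, copy SFA/SFB from SMEM
+        # to TMEM per k-tile, issue blockscaled MMAs into the TMEM accumulator for
+        # every k-block, then signal the accumulator pipeline for the epilogue.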
+ if warp_idx == self.mma_warp_id:
+ tmem.wait_for_alloc()
+ acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
+
+ # SFA TMEM tensor
+ sfa_tmem_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols,
+ dtype=self.sf_dtype,
+ )
+ tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout)
+
+ # SFB TMEM tensor
+ sfb_tmem_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb(
+ tiled_mma,
+ self.mma_tiler,
+ self.sf_vec_size,
+ cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)),
+ )
+ tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
+
+ # S2T copy partition for SFA/SFB
+ (
+ tiled_copy_s2t_sfa,
+ tCsSFA_compact_s2t,
+ tCtSFA_compact_s2t,
+ ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA)
+ (
+ tiled_copy_s2t_sfb,
+ tCsSFB_compact_s2t,
+ tCtSFB_compact_s2t,
+ ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
+
+ ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage)
+ acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
+
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ while is_valid_tile:
+ k_tile_cnt = tile_info[3]
+
+ # Peek AB buffer full
+ ab_consumer_state.reset_count()
+ peek_ab_full_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state)
+
+ # Peek Acc buffer empty
+ acc_producer_state.reset_count()
+ peek_acc_empty_status = cutlass.Boolean(1)
+ if ab_consumer_state.count < k_tile_cnt and is_leader_cta:
+ peek_acc_empty_status = acc_pipeline.producer_try_acquire(acc_producer_state)
+
+ mma_tile_coord_mnl = (
+ tile_info[1] // cute.size(tiled_mma.thr_id.shape),
+ tile_info[2],
+ tile_info[0],
+ )
+
+ if cutlass.const_expr(self.overlapping_accum):
+ acc_stage_index = acc_producer_state.phase ^ 1
+ else:
+ acc_stage_index = acc_producer_state.index
+
+ tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)]
+
+ tCtSFB_mma = tCtSFB
+ if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192):
+ offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+ elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64):
+ offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2)
+ shifted_ptr = cute.recast_ptr(
+ acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset,
+ dtype=self.sf_dtype,
+ )
+ tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout)
+
+ if is_leader_cta:
+ acc_pipeline.producer_acquire(acc_producer_state, peek_acc_empty_status)
+
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
+
+ for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
+ if is_leader_cta:
+ ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status)
+
+ s2t_stage_coord = (None, None, None, None, ab_consumer_state.index)
+ cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t[s2t_stage_coord], tCtSFA_compact_s2t)
+ cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t[s2t_stage_coord], tCtSFB_compact_s2t)
+
+ num_kblocks = cute.size(tCrA, mode=[2])
+ ab_consumer_state_next = ab_consumer_state.clone()
+ ab_consumer_state_next.advance()
+ if ab_consumer_state_next.count < k_tile_cnt:
+ peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state_next)
+
+ for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
+ kblock_coord = (None, None, kblock_idx, ab_consumer_state.index)
+ sf_kblock_coord = (None, None, kblock_idx)
+ tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
+ tiled_mma.set(tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator)
+ cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc)
+ tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
+
+ ab_pipeline.consumer_release(ab_consumer_state)
+ ab_consumer_state = ab_consumer_state_next
+
+ if is_leader_cta:
+ acc_pipeline.producer_commit(acc_producer_state)
+
+ acc_producer_state.advance()
+ if acc_producer_state.count < k_tile_cnt:
+ if is_leader_cta:
+ peek_acc_empty_status = acc_pipeline.producer_try_acquire(acc_producer_state)
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ acc_pipeline.producer_tail(acc_producer_state)
+
+ # ==============================================================
+ # Epilogue warps
+ # ==============================================================
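+        # Epilogue loop: allocate TMEM, copy the accumulator TMEM -> registers per
+        # epi subtile, apply alpha/bias/prob scaling (and optional squared-ReLU),
+        # optionally quantize row/column SFD and track amax, stage results through
+        # SMEM, then TMA-store C/D (and D_col) to gmem.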
+ if warp_idx < self.mma_warp_id:
+ tmem.allocate(self.num_tmem_alloc_cols)
+ tmem.wait_for_alloc()
+ tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+ tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
+
+ epi_tidx = tidx
+ thr_mma_epi = tiled_mma.get_slice(mma_tile_coord_v)
+
+ # Shape-only partition on global tensor (invariant setup for t2r copy atom)
+ gD_mnl_shape = cute.local_tile(mD_mnl, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None))
+ tCgD_shape = thr_mma_epi.partition_C(gD_mnl_shape)
+
+ tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
+ epi_tidx,
+ tCtAcc_base,
+ tCgD_shape,
+ epi_tile,
+ use_2cta_instrs,
+ )
+
+ tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
+ tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
+ tiled_copy_t2r,
+ tTR_rC,
+ epi_tidx,
+ sC,
+ )
+ tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype)
+ tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(
+ tiled_copy_t2r,
+ tTR_rD,
+ epi_tidx,
+ sD,
+ )
+ tTR_rD_col = cute.make_rmem_tensor(tTR_rAcc.shape, self.d_dtype)
+ tiled_copy_r2s, tRS_rD_col, tRS_sD_col = self.epilog_smem_copy_and_partition(
+ tiled_copy_t2r,
+ tTR_rD_col,
+ epi_tidx,
+ sD_col,
+ )
+
+ if cutlass.const_expr(self.generate_sfd):
+ norm_const = norm_const_tensor[0]
+ regPerSubtile = 4
+ sfd_row_tile = (cute.make_layout(128), cute.make_layout(32 * regPerSubtile))
+ gSFDRow_mnl = cute.local_tile(mSFDRow_mnl, sfd_row_tile, (None, None, None))
+ thr_copy_t2r_local = tiled_copy_t2r.get_slice(tidx)
+ tCgSFDRow_mnl = thr_copy_t2r_local.partition_D(gSFDRow_mnl)
+ tCgSFDRow_mnl = cute.filter_zeros(tCgSFDRow_mnl)
+ tCrSFDRow = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype)
+ tCrSFDRow_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32)
+ d_rcp_limits = get_dtype_rcp_limits(self.d_dtype)
+
+ sfd_col_tile = sfd_row_tile
+ gSFDCol_mnl = cute.local_tile(mSFDCol_mnl, sfd_col_tile, (None, None, None))
+ thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
+ val_layout = cute.make_ordered_layout((1,), order=(0,))
+ copy_atom_sfd_col = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gSFDCol_mnl.element_type, num_bits_per_copy=8)
+ tiled_copy_sfd_col = cute.make_tiled_copy_tv(copy_atom_sfd_col, thr_layout, val_layout)
+ thr_copy_sfd_col = tiled_copy_sfd_col.get_slice(tidx)
+ tCgSFDCol_mnl = thr_copy_sfd_col.partition_D(cute.filter_zeros(gSFDCol_mnl))
+ tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl)
+ tCrSFDCol = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].shape, self.sf_dtype)
+ tCrSFDCol_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32)
+
+ epi_ext = self._make_extension(workspace_ptr)
+
+ acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
+ c_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id))
+ c_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_producer_group)
+ d_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilog_warp_id))
+ d_pipeline = pipeline.PipelineTmaStore.create(num_stages=self.num_d_stage, producer_group=d_producer_group)
+
+ tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage)
+ tile_info = cute.make_rmem_tensor((4,), cutlass.Int32)
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ if cutlass.const_expr(self.enable_bias):
+ bias_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_bias_stage)
+ bias_s2r_tom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.bias_dtype, num_bits_per_copy=128)
+ tTR_rBias = cute.make_rmem_tensor(cute.make_layout(self.epi_tile[1]), self.bias_dtype)
+
+ num_prev_subtiles = cutlass.Int32(0)
+ while is_valid_tile:
+ epi_work_tile_info = MoEWorkTileInfo(
+ expert_idx=tile_info[0],
+ tile_m_idx=tile_info[1],
+ tile_n_idx=tile_info[2],
+ k_tile_cnt=tile_info[3],
+ )
+ expert_idx = epi_work_tile_info.expert_idx
+ epi_ext.update_expert_info(padded_offsets, expert_idx)
+
+ alpha_val = alpha[expert_idx]
+
+ if cutlass.const_expr(self.enable_bias):
+ bias_consumer_state.reset_count()
+ bias_pipeline.consumer_wait(bias_consumer_state)
+ sBias_stage = sBias[(None, bias_consumer_state.index)]
+ sBias_subtiles = cute.flat_divide(sBias_stage, cute.make_layout(self.epi_tile[1]))
+
+ real_d, _ = epi_ext.get_gmem_tensor("d", mD_mnl, padded_offsets, epi_work_tile_info)
+ real_c, _ = epi_ext.get_gmem_tensor("c", mC_mnl, padded_offsets, epi_work_tile_info)
+ if cutlass.const_expr(self.generate_sfd):
+ real_d_col, _ = epi_ext.get_gmem_tensor("d_col", mD_col_mnl, padded_offsets, epi_work_tile_info)
+
+ thr_mma_epi_loop = tiled_mma.get_slice(mma_tile_coord_v)
+
+ gD_mnl_loop = cute.local_tile(real_d, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None))
+ tCgD_loop = thr_mma_epi_loop.partition_C(gD_mnl_loop)
+ _, bSG_sD, bSG_gD_partitioned = epilog_gmem_copy_and_partition(
+ epi_tidx,
+ tma_atom_d,
+ tCgD_loop,
+ epi_tile,
+ sD,
+ )
+
+ gC_mnl_loop = cute.local_tile(real_c, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None))
+ tCgC_loop = thr_mma_epi_loop.partition_C(gC_mnl_loop)
+ _, bSG_sC, bSG_gC_partitioned = epilog_gmem_copy_and_partition(
+ epi_tidx,
+ tma_atom_c,
+ tCgC_loop,
+ epi_tile,
+ sC,
+ )
+
+ if cutlass.const_expr(self.generate_sfd):
+ gD_col_mnl_loop = cute.local_tile(real_d_col, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None))
+ tCgD_col_loop = thr_mma_epi_loop.partition_C(gD_col_mnl_loop)
+ _, bSG_sD_col, bSG_gD_col_partitioned = epilog_gmem_copy_and_partition(
+ epi_tidx,
+ tma_atom_d_col,
+ tCgD_col_loop,
+ epi_tile,
+ sD_col,
+ )
+
+ epi_mma_tile_coord = (
+ epi_work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape),
+ epi_work_tile_info.tile_n_idx,
+ 0,
+ )
+ bSG_gC = bSG_gC_partitioned[(None, None, None, *epi_mma_tile_coord)]
+ bSG_gD = bSG_gD_partitioned[(None, None, None, *epi_mma_tile_coord)]
+ bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
+ bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD))
+ if cutlass.const_expr(self.generate_sfd):
+ bSG_gD_col = bSG_gD_col_partitioned[(None, None, None, *epi_mma_tile_coord)]
+ bSG_gD_col = cute.group_modes(bSG_gD_col, 1, cute.rank(bSG_gD_col))
+
+ if cutlass.const_expr(self.generate_sfd):
+ tCgSFDRow_mn = tCgSFDRow_mnl[(None, None, None, None, None, 0)]
+ tCgSFDCol_mnl_new = tCgSFDCol_mnl
+ if cutlass.const_expr(self.discrete_col_sfd):
+ tCgSFDCol_mnl_new = self.create_and_partition_new_SFDCol(tile_info, mSFDCol_mnl, padded_offsets)
+ tCgSFDCol_mn = tCgSFDCol_mnl_new[(None, None, None, None, None, 0)]
+
+ if cutlass.const_expr(self.generate_amax):
+ thread_tile_amax = cutlass.Float32(0.0)
+
+ mPosition = epi_work_tile_info.tile_m_idx * self.cta_tile_shape_mnk[0] + tidx
+ real_prob, _ = epi_ext.get_gmem_tensor("prob", prob, padded_offsets, epi_work_tile_info)
+ mProb = real_prob[mPosition, 0, 0]
+
+ # C1 fix: phase-based acc stage indexing for overlapping_accum
+ if cutlass.const_expr(self.overlapping_accum):
+ acc_stage_index = acc_consumer_state.phase
+ reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False)
+ else:
+ acc_stage_index = acc_consumer_state.index
+
+ tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
+ tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
+
+ acc_pipeline.consumer_wait(acc_consumer_state)
+
+ subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
+ for subtile_idx in cutlass.range(0, subtile_cnt, 1, unroll=1):
+ real_subtile_idx = subtile_idx
+ if cutlass.const_expr(self.overlapping_accum):
+ if reverse_subtile:
+ real_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx
+
+ if cutlass.const_expr(self.overlapping_accum):
+ if subtile_idx == self.iter_acc_early_release_in_epilogue:
+ cute.arch.fence_view_async_tmem_load()
+ with cute.arch.elect_one():
+ acc_pipeline.consumer_release(acc_consumer_state)
+ acc_consumer_state.advance()
+
+ tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)]
+ cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
+
+ if cutlass.const_expr(self.enable_bias):
+ # m7 fix: use real_subtile_idx directly (matches contiguous)
+ sBias_sub = sBias_subtiles[(None, real_subtile_idx)]
+ cute.copy(bias_s2r_tom, sBias_sub, tTR_rBias)
+ bias_vec = tTR_rBias.load()
+ if cutlass.const_expr(self.vectorized_f32):
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
+ bias_f32_0 = bias_vec[i].to(cutlass.Float32)
+ bias_f32_1 = bias_vec[i + 1].to(cutlass.Float32)
+ bias_f32_0, bias_f32_1 = cute.arch.mul_packed_f32x2(
+ (mProb, mProb),
+ (bias_f32_0, bias_f32_1),
+ rnd="rn",
+ ftz=False,
+ )
+ tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.fma_packed_f32x2(
+ (tTR_rAcc[i], tTR_rAcc[i + 1]),
+ (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)),
+ (bias_f32_0, bias_f32_1),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc)):
+ tTR_rAcc[i] = tTR_rAcc[i] * cutlass.Float32(alpha_val) + bias_vec[i].to(cutlass.Float32) * mProb
+ else:
+ if cutlass.const_expr(self.vectorized_f32):
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
+ tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2(
+ (tTR_rAcc[i], tTR_rAcc[i + 1]),
+ (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc)):
+ tTR_rAcc[i] = tTR_rAcc[i] * cutlass.Float32(alpha_val)
+
+ if cutlass.const_expr(self.generate_c):
+ self.store_c(
+ tiled_copy_r2s,
+ tma_atom_c,
+ warp_idx,
+ tTR_rAcc,
+ tRS_rC,
+ tRS_sC,
+ bSG_gC,
+ bSG_sC,
+ c_pipeline,
+ num_prev_subtiles,
+ real_subtile_idx,
+ )
+
+ acc_vec = tTR_rAcc.load()
+
+ if cutlass.const_expr(self.epilogue_type == EpilogueType.SRELU.value):
+ acc_relu = cute.where(acc_vec > 0, acc_vec, cute.full_like(acc_vec, 0))
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
+ tTR_rAcc[i], tTR_rAcc[i + 1] = cute.arch.mul_packed_f32x2(
+ (acc_relu[i], acc_relu[i + 1]),
+ (acc_relu[i], acc_relu[i + 1]),
+ rnd="rn",
+ ftz=False,
+ )
+ acc_vec = tTR_rAcc.load()
+
+ if cutlass.const_expr(not self.enable_bias):
+ tCompute = cute.make_rmem_tensor(acc_vec.shape, self.acc_dtype)
+ if cutlass.const_expr(self.vectorized_f32):
+ for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2):
+ tCompute[i], tCompute[i + 1] = cute.arch.mul_packed_f32x2(
+ (acc_vec[i], acc_vec[i + 1]),
+ (mProb, mProb),
+ rnd="rn",
+ ftz=False,
+ )
+ else:
+ for i in cutlass.range_constexpr(cute.size(tTR_rAcc)):
+ tCompute[i] = acc_vec[i] * mProb
+ else:
+ tCompute = tTR_rAcc
+
+ if cutlass.const_expr(self.generate_amax):
+ thread_tile_amax = amax_reduction_per_thread(tCompute, thread_tile_amax)
+
+ if cutlass.const_expr(self.generate_sfd):
+ tCompute_col = cute.make_rmem_tensor(tCompute.layout, tCompute.element_type)
+ tCompute_col.store(tCompute.load())
+ self.quant_sfd_row(
+ real_subtile_idx % 4,
+ tiled_copy_r2s,
+ tCompute,
+ tCrSFDRow_pvscale,
+ norm_const,
+ d_rcp_limits,
+ tRS_rD,
+ )
+ self.quant_sfd_col(
+ real_subtile_idx % 4,
+ tiled_copy_r2s,
+ tCompute_col,
+ tCrSFDCol_pvscale,
+ norm_const,
+ d_rcp_limits,
+ tRS_rD_col,
+ )
+ # SFD M tile = cta_tile_m = 128; tile_m_idx is CTA-level per-expert
+ global_sfd_m = epi_work_tile_info.tile_m_idx + epi_ext.token_offset // self.cta_tile_shape_mnk[0]
+ if cutlass.const_expr(self.mma_tiler[1] == 256):
+ sfd_n = epi_work_tile_info.tile_n_idx * 2 + (real_subtile_idx >> 2)
+ else:
+ sfd_n = epi_work_tile_info.tile_n_idx
+ sfd_row_idx_mn = (global_sfd_m, sfd_n)
+ sfd_col_idx_mn = sfd_row_idx_mn
+ if cutlass.const_expr(self.discrete_col_sfd):
+ sfd_col_idx_mn = (
+ epi_work_tile_info.tile_m_idx,
+ sfd_n,
+ )
+ tCgSFDRow = tCgSFDRow_mn[(None, None, None, *sfd_row_idx_mn)]
+ tCgSFDCol = tCgSFDCol_mn[(None, None, None, *sfd_col_idx_mn)]
+ if subtile_idx == 3 or subtile_idx == 7:
+ if sfd_row_idx_mn[1] * 32 * regPerSubtile < cute.size(cute.shape(mSFDRow_mnl.layout, mode=[1])):
+ tCrSFDRow.store(tCrSFDRow_pvscale.load().to(self.sf_dtype))
+ cute.autovec_copy(tCrSFDRow, tCgSFDRow)
+ if sfd_col_idx_mn[1] * 32 * regPerSubtile < cute.size(cute.shape(mSFDCol_mnl.layout, mode=[1])):
+ tCrSFDCol.store(tCrSFDCol_pvscale.load().to(self.sf_dtype))
+ cute.autovec_copy(tCrSFDCol, tCgSFDCol)
+ else:
+ acc_vec = tiled_copy_r2s.retile(tCompute).load()
+ tRS_rD.store(acc_vec.to(self.d_dtype))
+
+ d_buffer = num_prev_subtiles % self.num_d_stage
+ num_prev_subtiles = num_prev_subtiles + 1
+ cute.copy(tiled_copy_r2s, tRS_rD, tRS_sD[(None, None, None, d_buffer)])
+ if cutlass.const_expr(self.generate_sfd):
+ cute.copy(tiled_copy_r2s, tRS_rD_col, tRS_sD_col[(None, None, None, d_buffer)])
+ cute.arch.fence_proxy("async.shared", space="cta")
+ self.epilog_sync_barrier.arrive_and_wait()
+ if warp_idx == self.epilog_warp_id[0]:
+ cute.copy(tma_atom_d, bSG_sD[(None, d_buffer)], bSG_gD[(None, real_subtile_idx)])
+ if cutlass.const_expr(self.generate_sfd):
+ cute.copy(tma_atom_d_col, bSG_sD_col[(None, d_buffer)], bSG_gD_col[(None, real_subtile_idx)])
+ d_pipeline.producer_commit()
+ d_pipeline.producer_acquire()
+ self.epilog_sync_barrier.arrive_and_wait()
+
+ if cutlass.const_expr(not self.overlapping_accum):
+ with cute.arch.elect_one():
+ acc_pipeline.consumer_release(acc_consumer_state)
+ acc_consumer_state.advance()
+
+ if cutlass.const_expr(self.enable_bias):
+ bias_pipeline.consumer_release(bias_consumer_state)
+ bias_consumer_state.advance()
+
+ tile_info_pipeline.consumer_wait(tile_info_consumer_state)
+ for idx in cutlass.range(4, unroll_full=True):
+ tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)]
+ is_valid_tile = tile_info[0] >= cutlass.Int32(0)
+ cute.arch.fence_proxy("async.shared", space="cta")
+ tile_info_pipeline.consumer_release(tile_info_consumer_state)
+ tile_info_consumer_state.advance()
+
+ if cutlass.const_expr(self.generate_amax):
+ gAmax = mAmax_tensor[(expert_idx, None)].iterator.llvm_ptr
+ self.amax_reduction_per_warp_and_cta(thread_tile_amax, warp_idx, sAmax, gAmax)
+
+ tmem.relinquish_alloc_permit()
+ self.epilog_sync_barrier.arrive_and_wait()
+ tmem.free(tmem_ptr)
+ if cutlass.const_expr(self.generate_c):
+ c_pipeline.producer_tail()
+ d_pipeline.producer_tail()
+
+ # ------------------------------------------------------------------
+ # Internal: create extension based on weight_mode
+ # ------------------------------------------------------------------
+
+ @cute.jit
+ def _make_extension(self, workspace_ptr):
+ if cutlass.const_expr(self.weight_mode == MoEWeightMode.DISCRETE):
+ desc_workspace = TensormapWorkspace(workspace_ptr, ["b", "sfb"])
+ return DiscreteWeightScaledGemmSchedExtension(
+ tensormap_ctor=desc_workspace,
+ sf_vec_size=self.sf_vec_size,
+ )
+ else:
+ return ContiguousAndConsistentGroupedGemmSchedExtension(
+ sf_vec_size=self.sf_vec_size,
+ )
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py
index 91f5431a..7f876e1e 100644
--- a/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py
+++ b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py
@@ -123,7 +123,7 @@ def __init__(
"""
super().__init__()
- self._logger.warning("GroupedGemmSwigluSm100 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
# Store sample tensor descriptors
@@ -168,7 +168,7 @@ def __init__(
self._kernel = BlockScaledContiguousGroupedGemmKernel
self.num_cluster_overlap_margin = int(os.getenv("CUDNNFE_CLUSTER_OVERLAP_MARGIN", "0"))
- print(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
+ self._logger.debug(f"setting num_cluster_overlap_margin: {self.num_cluster_overlap_margin}")
self._logger.debug(f"__init__ completed")
def check_support(self) -> bool:
@@ -482,7 +482,7 @@ def compile(self) -> None:
fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False)
self._logger.debug("Compiling grouped_gemm_swiglu kernel")
- use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
+ use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
if not use_full_dynamic: # only mark the m dimension as dynamic
valid_m = cute.sym_int(divisibility=256)
@@ -910,11 +910,10 @@ def grouped_gemm_swiglu_wrapper_sm100(
)
sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order)
- if d_dtype in [torch.bfloat16, torch.float16]:
- _logger.debug("grouped_gemm_swiglu_wrapper_sm100: Detected bf16/float16 d_dtype, constructing amax_tensor")
- amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device)
-
if valid_m == 0:
+ if d_dtype in [torch.bfloat16, torch.float16]:
+ amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device)
+
_logger.debug("grouped_gemm_swiglu_wrapper_sm100: valid_m is zero, skipping kernel execution")
return TupleDict(
c_tensor=c_tensor,
@@ -925,7 +924,7 @@ def grouped_gemm_swiglu_wrapper_sm100(
sfd_col_tensor=sfd_col_tensor,
)
- use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL") is not None
+ use_full_dynamic = os.environ.get("CUDNN_FE_GROUPED_GEMM_DYNAMIC_MNKL", "1") != "0"
def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
return tuple(i for i, s in sorted(enumerate(tensor.stride()), key=lambda x: x[1]))
@@ -962,7 +961,10 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
if cache_key in _cache_of_GroupedGemmSwigluSm100Objects:
_logger.debug("group_gemm_swiglu_wrapper_sm100: Using previously cached GroupedGemmSwigluSm100 object")
- grouped_gemm_swiglu = _cache_of_GroupedGemmSwigluSm100Objects[cache_key]
+ grouped_gemm_swiglu, amax_tensor = _cache_of_GroupedGemmSwigluSm100Objects[cache_key]
+ # The cuDNN graph API binds data pointers at execute time, not plan-build time.
+ # During CUDA graph capture, padded_offsets is allocated in the graph pool
+ # (stable address across replays), so passing it directly is graph-safe.
grouped_gemm_swiglu.execute(
a_tensor=a_tensor,
b_tensor=b_tensor,
@@ -982,6 +984,10 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
)
else:
_logger.debug("group_gemm_swiglu_wrapper_sm100: No previously cached GroupedGemmSwigluSm100 object found, creating new GroupedGemmSwigluSm100 object")
+ # Allocate amax_tensor once here; cache-hit calls reuse this buffer so
+ # the FillFunctor (torch.full) only fires during warmup, not every step.
+ if d_dtype in [torch.bfloat16, torch.float16]:
+ amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device)
grouped_gemm_swiglu = GroupedGemmSwigluSm100(
sample_a=a_tensor,
sample_b=b_tensor,
@@ -1025,7 +1031,7 @@ def stride_order(tensor: torch.Tensor) -> Tuple[int, ...]:
prob_tensor=prob_tensor,
current_stream=current_stream,
)
- _cache_of_GroupedGemmSwigluSm100Objects[cache_key] = grouped_gemm_swiglu
+ _cache_of_GroupedGemmSwigluSm100Objects[cache_key] = (grouped_gemm_swiglu, amax_tensor)
return TupleDict(
c_tensor=c_tensor,
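Reviewer note (not part of the patch): the hunk above moves the amax allocation onto the cache-miss path and stores the buffer next to the compiled op, so `torch.full` only fires during warmup and cache-hit calls reuse both objects. A minimal sketch of that pattern, with illustrative names (`_cache`, `make_op`, `run_cached`):

import torch

_cache = {}

def run_cached(key, l, device, make_op):
    # Cache hit: reuse the compiled op and the previously allocated amax buffer,
    # so no fill kernel or recompilation runs in steady state.
    if key in _cache:
        op, amax = _cache[key]
    else:
        # Warmup path: allocate the side buffer once and compile once.
        amax = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=device)
        op = make_op()
        _cache[key] = (op, amax)
    return op, amax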
diff --git a/python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py b/python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py
index cb5808f8..c880295e 100644
--- a/python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py
+++ b/python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py
@@ -44,13 +44,13 @@ def __init__(
sample_global_scale_a: Optional[torch.Tensor] = None,
sample_global_scale_b: Optional[torch.Tensor] = None,
acc_dtype: torch.dtype = torch.float32,
- mma_tiler_mn: Tuple[int, int] = (128, 128),
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
cluster_shape_mn: Optional[Tuple[int, int]] = None,
sf_vec_size: int = 16,
accumulate_on_output: bool = False,
):
super().__init__()
- self._logger.warning("GroupedGemmWgradSm100 is an experimental API")
+ self._warn_experimental_api()
if sample_wgrad is not None and num_experts is None:
self.weight_mode = MoEWeightMode.DENSE
@@ -79,7 +79,7 @@ def __init__(
f"sample_a and sample_b token dimensions must match, got {tokens_sum_a} and {tokens_sum_b}",
)
self._offset_values = self._validate_offsets(sample_offsets, tokens_sum_a, name="sample_offsets")
- self._scale_cols = self._compute_scale_cols(self._offset_values)
+ self._scale_cols = _round_up(ceil_div(tokens_sum_a, self.sf_vec_size), 4)
if self.weight_mode == MoEWeightMode.DENSE:
self.wgrad_desc = self._make_tensor_desc(sample_wgrad, name="sample_wgrad")
@@ -136,23 +136,14 @@ def _validate_offsets(self, offsets_tensor: torch.Tensor, tokens_sum: int, name:
if offset_values:
self._value_error_if(
- offset_values[-1] != tokens_sum,
- f"{name} last value must equal total tokens {tokens_sum}, got {offset_values[-1]}",
+ offset_values[-1] > tokens_sum,
+ f"{name} last value must not exceed total tokens {tokens_sum}, got {offset_values[-1]}",
)
else:
self._value_error_if(tokens_sum != 0, f"{name} cannot be empty when total tokens is {tokens_sum}")
return offset_values
- def _compute_scale_cols(self, offset_values: Tuple[int, ...]) -> int:
- prev_offset = 0
- scale_cols = 0
- for offset in offset_values:
- group_k = offset - prev_offset
- scale_cols += _round_up(ceil_div(group_k, self.sf_vec_size), 4)
- prev_offset = offset
- return scale_cols
-
def check_support(self) -> bool:
m, tokens_sum = self._tensor_shape(self.a_desc, name="sample_a")
_, n = self._tensor_shape(self.b_desc, name="sample_b")
@@ -572,11 +563,13 @@ def grouped_gemm_wgrad_wrapper_sm100(
sfb_tensor: torch.Tensor,
offsets_tensor: torch.Tensor,
output_mode: str = "dense",
+ wgrad_tensor: Optional[torch.Tensor] = None,
+ wgrad_ptrs: Optional[torch.Tensor] = None,
global_scale_a: Optional[torch.Tensor] = None,
global_scale_b: Optional[torch.Tensor] = None,
acc_dtype: torch.dtype = torch.float32,
wgrad_dtype: torch.dtype = torch.bfloat16,
- mma_tiler_mn: Tuple[int, int] = (128, 128),
+ mma_tiler_mn: Tuple[int, int] = (256, 256),
cluster_shape_mn: Optional[Tuple[int, int]] = None,
sf_vec_size: int = 16,
accumulate_on_output: bool = False,
@@ -585,15 +578,18 @@ def grouped_gemm_wgrad_wrapper_sm100(
"""Compile and execute grouped GEMM wgrad in one call."""
hidden, _ = a_tensor.shape
_, intermediate = b_tensor.shape
+ wgrad_shape = (hidden, intermediate)
expert_cnt = offsets_tensor.shape[0]
if output_mode not in {"dense", "discrete"}:
raise ValueError(f"output_mode must be 'dense' or 'discrete', got {output_mode}")
- if accumulate_on_output:
- wgrad_tensor = torch.zeros((expert_cnt, hidden, intermediate), dtype=wgrad_dtype, device=a_tensor.device)
- else:
- wgrad_tensor = torch.empty((expert_cnt, hidden, intermediate), dtype=wgrad_dtype, device=a_tensor.device)
+ if wgrad_tensor is None and wgrad_ptrs is None:
+ # Backward compatibility: Dense mode.
+ if accumulate_on_output:
+ wgrad_tensor = torch.zeros((expert_cnt, *wgrad_shape), dtype=wgrad_dtype, device=a_tensor.device)
+ else:
+ wgrad_tensor = torch.empty((expert_cnt, *wgrad_shape), dtype=wgrad_dtype, device=a_tensor.device)
cache_key = (
output_mode,
@@ -604,6 +600,7 @@ def grouped_gemm_wgrad_wrapper_sm100(
tuple(offsets_tensor.shape),
tuple(offsets_tensor.stride()),
offsets_tensor.dtype,
+ *_dynamic_dim_tensor_signature(wgrad_tensor, dynamic_dims=()),
tuple(global_scale_a.shape) if global_scale_a is not None else None,
global_scale_a.dtype if global_scale_a is not None else None,
tuple(global_scale_b.shape) if global_scale_b is not None else None,
@@ -636,13 +633,14 @@ def grouped_gemm_wgrad_wrapper_sm100(
accumulate_on_output=accumulate_on_output,
)
else:
+ sample_expert = torch.empty(wgrad_shape, dtype=wgrad_dtype, device=a_tensor.device)
op = GroupedGemmWgradSm100(
sample_a=a_tensor,
sample_b=b_tensor,
sample_sfa=sfa_tensor,
sample_sfb=sfb_tensor,
sample_offsets=offsets_tensor,
- sample_wgrad_expert=wgrad_tensor[0],
+ sample_wgrad_expert=sample_expert,
num_experts=expert_cnt,
wgrad_shape=(hidden, intermediate),
wgrad_dtype=wgrad_dtype,
@@ -665,6 +663,7 @@ def grouped_gemm_wgrad_wrapper_sm100(
sfb_tensor=sfb_tensor,
offsets_tensor=offsets_tensor,
wgrad_tensor=wgrad_tensor,
+ wgrad_ptrs=wgrad_ptrs,
global_scale_a=global_scale_a,
global_scale_b=global_scale_b,
current_stream=current_stream,
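Reviewer note (not part of the patch): `_scale_cols` is now derived from the total token count instead of summing per-group contributions. A small worked sketch of the new formula; `ceil_div` / `_round_up` are re-implemented here for illustration only:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def round_up(x: int, multiple: int) -> int:
    return ceil_div(x, multiple) * multiple

# Example: tokens_sum_a = 1000, sf_vec_size = 16.
# 1000 tokens need ceil(1000 / 16) = 63 scale columns, padded to a multiple of 4 -> 64.
assert round_up(ceil_div(1000, 16), 4) == 64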
diff --git a/python/cudnn/native_sparse_attention/compression/api.py b/python/cudnn/native_sparse_attention/compression/api.py
index 7433eff5..fc89c004 100644
--- a/python/cudnn/native_sparse_attention/compression/api.py
+++ b/python/cudnn/native_sparse_attention/compression/api.py
@@ -40,7 +40,7 @@ def __init__(
super().__init__()
self._kernel = BlackwellFusedMultiHeadAttentionForward
- self._logger.warning("CompressionAttention is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
self.q_desc = self._make_tensor_desc(sample_q, name="sample_q")
diff --git a/python/cudnn/native_sparse_attention/selection/api.py b/python/cudnn/native_sparse_attention/selection/api.py
index daf0f7a6..03ca47c2 100644
--- a/python/cudnn/native_sparse_attention/selection/api.py
+++ b/python/cudnn/native_sparse_attention/selection/api.py
@@ -33,7 +33,7 @@ def __init__(
super().__init__()
self._kernel = HopperSelectAttentionFwd
- self._logger.warning("SelectionAttention is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
self.q_desc = self._make_tensor_desc(sample_q, name="sample_q")
diff --git a/python/cudnn/native_sparse_attention/top_k/api.py b/python/cudnn/native_sparse_attention/top_k/api.py
index 392afdc5..3d0a5f63 100644
--- a/python/cudnn/native_sparse_attention/top_k/api.py
+++ b/python/cudnn/native_sparse_attention/top_k/api.py
@@ -47,7 +47,7 @@ def __init__(
super().__init__()
self._kernel = FineGrainedReductionQK
- self._logger.warning("TopKReduction is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
self.q_desc = self._make_tensor_desc(sample_q, name="sample_q")
diff --git a/python/cudnn/ops/__init__.py b/python/cudnn/ops/__init__.py
new file mode 100644
index 00000000..54a48c07
--- /dev/null
+++ b/python/cudnn/ops/__init__.py
@@ -0,0 +1 @@
+from .causal_conv1d import causal_conv1d
diff --git a/python/cudnn/ops/causal_conv1d.py b/python/cudnn/ops/causal_conv1d.py
new file mode 100644
index 00000000..bab76197
--- /dev/null
+++ b/python/cudnn/ops/causal_conv1d.py
@@ -0,0 +1,231 @@
+from typing import List, Optional
+
+import torch
+from torch import Tensor
+
+_TORCH_DTYPE_TO_CUDNN = {
+ torch.float32: 0, # CUDNN_DATA_FLOAT
+ torch.float16: 2, # CUDNN_DATA_HALF
+ torch.bfloat16: 9, # CUDNN_DATA_BFLOAT16
+}
+
+_ACTIVATION_TO_INT = {
+ "identity": 0, # CUDNN_CAUSAL_CONV1D_ACTIVATION_IDENTITY
+ "silu": 1, # CUDNN_CAUSAL_CONV1D_ACTIVATION_SILU
+}
+
+
+def _dtype_to_int(dtype: torch.dtype) -> int:
+ if dtype not in _TORCH_DTYPE_TO_CUDNN:
+ raise ValueError(f"Unsupported dtype {dtype}. Supported: float32, float16, bfloat16.")
+ return _TORCH_DTYPE_TO_CUDNN[dtype]
+
+
+def _activation_to_int(activation: str) -> int:
+ if activation not in _ACTIVATION_TO_INT:
+ raise ValueError(f"Unsupported activation '{activation}'. Supported: 'identity', 'silu'.")
+ return _ACTIVATION_TO_INT[activation]
+
+
+# ---------------------------------------------------------------------------
+# Forward primitive
+# ---------------------------------------------------------------------------
+
+
+@torch.library.custom_op(
+ "cudnn::causal_conv1d_fwd_primitive",
+ mutates_args=(),
+ device_types="cuda",
+)
+def _fwd_primitive(x: Tensor, weight: Tensor, bias: Tensor, activation: str) -> Tensor:
+ if x.dim() != 3 or weight.dim() != 2 or bias.dim() != 1:
+ raise ValueError(f"Expected x(3D), weight(2D), bias(1D); got {x.shape}, {weight.shape}, {bias.shape}")
+
+ if not (x.is_cuda and weight.is_cuda and bias.is_cuda):
+ raise ValueError(f"All tensors must be on CUDA: x.device={x.device}, " f"weight.device={weight.device}, bias.device={bias.device}")
+ if not (x.device == weight.device == bias.device):
+ raise ValueError(f"All tensors must be on the same device: x.device={x.device}, " f"weight.device={weight.device}, bias.device={bias.device}")
+
+ if not (x.dtype == weight.dtype == bias.dtype):
+ raise TypeError(f"Dtype mismatch: x.dtype={x.dtype}, weight.dtype={weight.dtype}, " f"bias.dtype={bias.dtype} (all must match)")
+
+ x = x.contiguous()
+ weight = weight.contiguous()
+ bias = bias.contiguous()
+
+ batch, dim, seq_len = x.shape
+ kernel_size = weight.shape[1]
+
+ if weight.shape[0] != dim:
+ raise ValueError(f"Channel mismatch: x has dim={dim} but weight has shape {weight.shape} " f"(expected weight.shape[0]={dim})")
+
+ if bias.shape[0] != dim:
+ raise ValueError(f"Bias mismatch: x has dim={dim} but bias has shape {bias.shape} " f"(expected bias.shape[0]={dim})")
+
+ y = torch.empty_like(x)
+
+ import cudnn
+
+ cudnn.causal_conv1d_forward(
+ torch.cuda.current_stream().cuda_stream,
+ x.data_ptr(),
+ weight.data_ptr(),
+ bias.data_ptr(),
+ y.data_ptr(),
+ batch,
+ dim,
+ seq_len,
+ kernel_size,
+ _dtype_to_int(x.dtype),
+ _activation_to_int(activation),
+ )
+ return y
+
+
+@torch.library.register_fake("cudnn::causal_conv1d_fwd_primitive")
+def _fwd_fake(x: Tensor, weight: Tensor, bias: Tensor, activation: str) -> Tensor:
+ return torch.empty_like(x)
+
+
+# ---------------------------------------------------------------------------
+# Backward primitive
+# ---------------------------------------------------------------------------
+
+
+@torch.library.custom_op(
+ "cudnn::causal_conv1d_bwd_primitive",
+ mutates_args=(),
+ device_types="cuda",
+)
+def _bwd_primitive(grad_out: Tensor, x: Tensor, weight: Tensor, bias: Tensor, activation: str) -> List[Tensor]:
+ if x.dim() != 3 or weight.dim() != 2 or bias.dim() != 1:
+ raise ValueError(f"Expected x(3D), weight(2D), bias(1D); got {x.shape}, {weight.shape}, {bias.shape}")
+ if grad_out.shape != x.shape:
+ raise ValueError(f"Shape mismatch: dy has shape {grad_out.shape} but x has shape {x.shape} " f"(expected dy.shape == x.shape)")
+ if not grad_out.is_cuda:
+ raise ValueError(f"grad_out must be on CUDA: grad_out.device={grad_out.device}")
+ if grad_out.device != x.device:
+ raise ValueError(f"Device mismatch: grad_out.device={grad_out.device}, x.device={x.device}")
+ if grad_out.dtype != x.dtype:
+ raise ValueError(f"Dtype mismatch: grad_out.dtype={grad_out.dtype}, x.dtype={x.dtype}")
+
+ if not (x.is_cuda and weight.is_cuda and bias.is_cuda):
+ raise ValueError(f"All tensors must be on CUDA: x.device={x.device}, " f"weight.device={weight.device}, bias.device={bias.device}")
+ if not (x.device == weight.device == bias.device):
+ raise ValueError(f"All tensors must be on the same device: x.device={x.device}, " f"weight.device={weight.device}, bias.device={bias.device}")
+
+ if not (x.dtype == weight.dtype == bias.dtype):
+ raise TypeError(f"Dtype mismatch: x.dtype={x.dtype}, weight.dtype={weight.dtype}, " f"bias.dtype={bias.dtype} (all must match)")
+
+ batch, dim, seq_len = x.shape
+
+ if weight.shape[0] != dim:
+ raise ValueError(f"Channel mismatch: x has dim={dim} but weight has shape {weight.shape} " f"(expected weight.shape[0]={dim})")
+
+ if bias.shape[0] != dim:
+ raise ValueError(f"Bias mismatch: x has dim={dim} but bias has shape {bias.shape} " f"(expected bias.shape[0]={dim})")
+
+ x = x.contiguous()
+ weight = weight.contiguous()
+ bias = bias.contiguous()
+ grad_out = grad_out.contiguous()
+
+ kernel_size = weight.shape[1]
+
+ dx = torch.empty_like(x)
+ dweight = torch.zeros(weight.shape, device=x.device, dtype=torch.float32)
+ dbias = torch.zeros(bias.shape, device=x.device, dtype=torch.float32)
+
+ import cudnn
+
+ cudnn.causal_conv1d_backward(
+ torch.cuda.current_stream().cuda_stream,
+ x.data_ptr(),
+ weight.data_ptr(),
+ bias.data_ptr(),
+ grad_out.data_ptr(),
+ dx.data_ptr(),
+ dweight.data_ptr(),
+ dbias.data_ptr(),
+ batch,
+ dim,
+ seq_len,
+ kernel_size,
+ _dtype_to_int(x.dtype),
+ _dtype_to_int(torch.float32),
+ _activation_to_int(activation),
+ )
+ return [dx, dweight.to(x.dtype), dbias.to(x.dtype)]
+
+
+@torch.library.register_fake("cudnn::causal_conv1d_bwd_primitive")
+def _bwd_fake(grad_out: Tensor, x: Tensor, weight: Tensor, bias: Tensor, activation: str) -> List[Tensor]:
+ return [torch.empty_like(x), torch.empty_like(weight), torch.empty_like(bias)]
+
+
+# ---------------------------------------------------------------------------
+# Autograd glue
+# ---------------------------------------------------------------------------
+
+
+def _setup_context(ctx, inputs, output):
+ x, weight, bias, activation = inputs
+ ctx.save_for_backward(x, weight, bias)
+ ctx.activation = activation
+
+
+@torch.compiler.allow_in_graph
+def _autograd_bwd(ctx, grad_out):
+ x, weight, bias = ctx.saved_tensors
+ dx, dw, db = torch.ops.cudnn.causal_conv1d_bwd_primitive(grad_out, x, weight, bias, ctx.activation)
+ return dx, dw, db, None
+
+
+torch.library.register_autograd(
+ "cudnn::causal_conv1d_fwd_primitive",
+ _autograd_bwd,
+ setup_context=_setup_context,
+)
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+
+def causal_conv1d(
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor] = None,
+ activation: str = "identity",
+) -> Tensor:
+ r"""Depthwise causal 1D convolution with optional activation.
+
+ Computes a depthwise 1D convolution with causal (left-only) padding
+ and optional fused activation::
+
+ y = activation(conv1d_causal(x, weight) + bias)
+
+ Causal padding: ``(kernel_size - 1)`` on the left, ``0`` on the right.
+ Each channel is convolved independently with its own 1D filter.
+
+ Supports ``torch.compile`` and ``torch.autograd`` — backward is handled
+ automatically when inputs require gradients.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape ``(batch, dim, seq_len)``.
+ Must be BF16, FP16, or FP32. Must be contiguous and on CUDA.
+ weight (torch.Tensor): Filter tensor of shape ``(dim, kernel_size)``.
+ Same dtype as *x*.
+ bias (torch.Tensor | None): Optional bias of shape ``(dim,)``.
+ Same dtype as *x*. Defaults to zeros if ``None``.
+ activation (str): ``"identity"`` (default) or ``"silu"``.
+
+ Returns:
+ torch.Tensor: Output of shape ``(batch, dim, seq_len)``, same dtype as *x*.
+ """
+ if activation not in _ACTIVATION_TO_INT:
+ raise ValueError(f"Unsupported activation '{activation}'. Supported: 'identity', 'silu'.")
+ if bias is None:
+ bias = torch.zeros(weight.shape[0], device=x.device, dtype=x.dtype)
+ return torch.ops.cudnn.causal_conv1d_fwd_primitive(x, weight, bias, activation)
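Reviewer note (not part of the patch): a usage sketch for the new op, assuming the package is importable as `cudnn.ops` per the `__init__.py` added above; shapes and dtypes follow the docstring.

import torch
from cudnn.ops import causal_conv1d

batch, dim, seq_len, kernel_size = 2, 64, 128, 4
x = torch.randn(batch, dim, seq_len, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(dim, kernel_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Depthwise causal conv with fused SiLU; output shape and dtype match x.
y = causal_conv1d(x, w, bias=b, activation="silu")
# Backward is routed to the registered bwd primitive via the torch.library autograd glue.
y.sum().backward()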
diff --git a/python/cudnn/rmsnorm_rht_amax/__init__.py b/python/cudnn/rmsnorm_rht_amax/__init__.py
new file mode 100644
index 00000000..9da5cae1
--- /dev/null
+++ b/python/cudnn/rmsnorm_rht_amax/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+from .api import (
+ RmsNormRhtAmaxSm100,
+ best_num_threads,
+ pick_rows_per_cta,
+ rmsnorm_rht_amax_wrapper_sm100,
+)
+
+__all__ = [
+ "RmsNormRhtAmaxSm100",
+ "best_num_threads",
+ "pick_rows_per_cta",
+ "rmsnorm_rht_amax_wrapper_sm100",
+]
diff --git a/python/cudnn/rmsnorm_rht_amax/api.py b/python/cudnn/rmsnorm_rht_amax/api.py
new file mode 100644
index 00000000..0294a3a2
--- /dev/null
+++ b/python/cudnn/rmsnorm_rht_amax/api.py
@@ -0,0 +1,302 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""FE API for fused RMSNorm + RHT + per-CTA amax."""
+
+import logging
+from typing import Optional
+
+from cuda.bindings import driver as cuda
+import cutlass
+import cutlass.cute as cute
+import torch
+from cutlass import Float32
+from cutlass.cute.runtime import make_fake_stream
+
+from cudnn.api_base import APIBase, TupleDict
+
+from .kernel import RMSNormRHTAmaxKernel
+
+DEFAULT_NUM_THREADS_BY_N = {
+ 2048: 128,
+ 4096: 256,
+ 7168: 128,
+ 8192: 512,
+ 16384: 1024,
+ 32768: 512,
+}
+RPC_CANDIDATES = (2, 4, 8)
+TARGET_MIN_CTAS = 148
+
+
+def best_num_threads(n: int) -> Optional[int]:
+ for num_threads in (1024, 512, 256, 128, 64):
+ if n % num_threads != 0:
+ continue
+ ept = n // num_threads
+ if ept >= 8 and ept % 8 == 0:
+ return num_threads
+ return None
+
+
+def pick_rows_per_cta(m: int) -> int:
+ for rows_per_cta in reversed(RPC_CANDIDATES):
+ if m % rows_per_cta != 0:
+ continue
+ num_ctas = m // rows_per_cta
+ if num_ctas >= TARGET_MIN_CTAS:
+ return rows_per_cta
+ return RPC_CANDIDATES[0]
+
+
+class RmsNormRhtAmaxSm100(APIBase):
+ """Class API for the RMSNorm + RHT + amax kernel."""
+
+ def __init__(
+ self,
+ sample_x: torch.Tensor,
+ sample_w: torch.Tensor,
+ sample_o: torch.Tensor,
+ sample_amax: torch.Tensor,
+ eps: float = 1e-5,
+ num_threads: Optional[int] = None,
+ rows_per_cta: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self._warn_experimental_api()
+
+ self.x_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_x, name="sample_x"), 2, "sample_x")
+ self.w_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_w, name="sample_w"), 1, "sample_w")
+ self.o_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_o, name="sample_o"), 2, "sample_o")
+ self.amax_desc = self._unpad_tensor_to_ndim(self._make_tensor_desc(sample_amax, name="sample_amax"), 1, "sample_amax")
+
+ self.eps = eps
+ self.requested_num_threads = num_threads
+ self.requested_rows_per_cta = rows_per_cta
+ self.num_threads = None
+ self.rows_per_cta = None
+ self.n = None
+
+ def check_support(self) -> bool:
+ m, n = self._tensor_shape(self.x_desc, name="sample_x")
+ w_n = self._tensor_shape(self.w_desc, name="sample_w")[0]
+ o_m, o_n = self._tensor_shape(self.o_desc, name="sample_o")
+
+ self._check_tensor_shape(self.x_desc, (m, n), "X")
+ self._check_tensor_shape(self.w_desc, (n,), "W")
+ self._check_tensor_shape(self.o_desc, (m, n), "O")
+ self._value_error_if(w_n != n, f"W length must match X hidden dimension, got {w_n} and {n}")
+ self._value_error_if((n % 16) != 0, f"N must be divisible by 16 for the Hadamard block size, got {n}")
+ self._value_error_if(o_m != m or o_n != n, f"O shape must match X shape, got {(o_m, o_n)} and {(m, n)}")
+
+ self._check_tensor_stride(self.x_desc, stride=(n, 1), name="X", extra_error_msg="X must be row-major contiguous")
+ self._check_tensor_stride(self.w_desc, stride=(1,), name="W", extra_error_msg="W must be contiguous")
+ self._check_tensor_stride(self.o_desc, stride=(n, 1), name="O", extra_error_msg="O must be row-major contiguous")
+
+ self._check_dtype(self.x_desc, dtype=torch.bfloat16, name="X")
+ self._check_dtype(self.w_desc, dtype=torch.bfloat16, name="W")
+ self._check_dtype(self.o_desc, dtype=torch.bfloat16, name="O")
+ self._check_dtype(self.amax_desc, dtype=torch.float32, name="Amax")
+
+ resolved_num_threads = self.requested_num_threads
+ if resolved_num_threads is None:
+ resolved_num_threads = DEFAULT_NUM_THREADS_BY_N.get(n, best_num_threads(n))
+ self._value_error_if(resolved_num_threads is None, f"No valid num_threads found for N={n}")
+ self._value_error_if(resolved_num_threads <= 0, f"num_threads must be positive, got {resolved_num_threads}")
+ self._value_error_if(
+ (resolved_num_threads % 32) != 0,
+ f"num_threads must be warp-aligned, got {resolved_num_threads}",
+ )
+ self._value_error_if(
+ resolved_num_threads > 1024,
+ f"num_threads must not exceed the CUDA block size limit, got {resolved_num_threads}",
+ )
+
+ resolved_rows_per_cta = self.requested_rows_per_cta
+ if resolved_rows_per_cta is None:
+ resolved_rows_per_cta = pick_rows_per_cta(m)
+
+ self._value_error_if(m % resolved_rows_per_cta != 0, f"M must be divisible by rows_per_cta, got M={m}, rows_per_cta={resolved_rows_per_cta}")
+ self._value_error_if(n % resolved_num_threads != 0, f"N={n} must be divisible by num_threads={resolved_num_threads}")
+
+ ept = n // resolved_num_threads
+ self._value_error_if(ept < 8 or ept % 8 != 0, f"EPT={ept} must be >= 8 and divisible by 8")
+
+ expected_num_ctas = m // resolved_rows_per_cta
+ self._check_tensor_shape(self.amax_desc, (expected_num_ctas,), "Amax")
+
+ self._runtime_error_if(not torch.cuda.is_available(), "CUDA is not available")
+ major, minor = torch.cuda.get_device_capability(self.x_desc.device)
+ compute_capability = major * 10 + minor
+ self._runtime_error_if(
+ compute_capability < 100,
+ f"RmsNormRhtAmaxSm100 requires SM100+, found SM{compute_capability}",
+ )
+
+ self.num_threads = resolved_num_threads
+ self.rows_per_cta = resolved_rows_per_cta
+ self.n = n
+ self._is_supported = True
+ return True
+
+ def compile(self) -> None:
+ self._ensure_support_checked()
+ if self._compiled_kernel is not None:
+ return
+
+ kernel = RMSNormRHTAmaxKernel(
+ n=self.n,
+ num_threads=self.num_threads,
+ eps=self.eps,
+ rows_per_cta=self.rows_per_cta,
+ )
+
+ valid_m = cute.sym_int(divisibility=self.rows_per_cta)
+
+ fake_x_tensor = self._make_fake_cute_compact_tensor(
+ dtype=self.x_desc.dtype,
+ shape=(valid_m, self.n),
+ stride_order=self.x_desc.stride_order,
+ dynamic_mode=None,
+ divisibility=self.rows_per_cta,
+ )
+ fake_w_tensor = self._make_fake_cute_tensor_from_desc(self.w_desc, assumed_align=16)
+ fake_o_tensor = self._make_fake_cute_compact_tensor(
+ dtype=self.o_desc.dtype,
+ shape=(valid_m, self.n),
+ stride_order=self.o_desc.stride_order,
+ dynamic_mode=None,
+ divisibility=self.rows_per_cta,
+ )
+ fake_num_ctas = cute.sym_int()
+ fake_amax_tensor = self._make_fake_cute_tensor(
+ dtype=self.amax_desc.dtype,
+ shape=(fake_num_ctas,),
+ stride=self.amax_desc.stride,
+ assumed_align=16,
+ )
+ fake_stream = make_fake_stream(use_tvm_ffi_env_stream=False)
+
+ compiled_kernel = cute.compile(
+ kernel,
+ fake_x_tensor,
+ fake_w_tensor,
+ fake_o_tensor,
+ fake_amax_tensor,
+ Float32(self.eps),
+ fake_stream,
+ options="--enable-tvm-ffi",
+ )
+
+ def tensor_api(
+ x_tensor: torch.Tensor,
+ w_tensor: torch.Tensor,
+ o_tensor: torch.Tensor,
+ amax_tensor: torch.Tensor,
+ stream: cuda.CUstream,
+ ) -> None:
+ compiled_kernel(
+ x_tensor,
+ w_tensor,
+ o_tensor,
+ amax_tensor,
+ Float32(self.eps),
+ stream,
+ )
+
+ self._compiled_kernel = tensor_api
+
+ def execute(
+ self,
+ x_tensor: torch.Tensor,
+ w_tensor: torch.Tensor,
+ o_tensor: torch.Tensor,
+ amax_tensor: torch.Tensor,
+ current_stream: Optional[cuda.CUstream] = None,
+ ) -> None:
+ self._runtime_error_if(self._compiled_kernel is None, "RmsNormRhtAmaxSm100 kernel not compiled; call compile() first")
+
+ x_tensor = self._unpad_tensor_to_ndim(x_tensor, 2, "x_tensor")
+ w_tensor = self._unpad_tensor_to_ndim(w_tensor, 1, "w_tensor")
+ o_tensor = self._unpad_tensor_to_ndim(o_tensor, 2, "o_tensor")
+ amax_tensor = self._unpad_tensor_to_ndim(amax_tensor, 1, "amax_tensor")
+
+ if current_stream is None:
+ current_stream = cuda.CUstream(torch.cuda.current_stream(x_tensor.device).cuda_stream)
+
+ self._compiled_kernel(
+ x_tensor=x_tensor,
+ w_tensor=w_tensor,
+ o_tensor=o_tensor,
+ amax_tensor=amax_tensor,
+ stream=current_stream,
+ )
+
+
+_logger = logging.getLogger(__name__)
+_cache_of_RmsNormRhtAmaxSm100Objects = {}
+
+
+def rmsnorm_rht_amax_wrapper_sm100(
+ x_tensor: torch.Tensor,
+ w_tensor: torch.Tensor,
+ eps: float = 1e-5,
+ num_threads: Optional[int] = None,
+ rows_per_cta: Optional[int] = None,
+ current_stream: Optional[cuda.CUstream] = None,
+) -> TupleDict:
+ """High-level wrapper for the RMSNorm + RHT + per-CTA amax kernel."""
+
+ x_tensor = x_tensor.squeeze(-1) if x_tensor.ndim == 3 and x_tensor.shape[-1] == 1 else x_tensor
+ w_tensor = w_tensor.squeeze(-1) if w_tensor.ndim == 2 and w_tensor.shape[-1] == 1 else w_tensor
+
+ m, n = x_tensor.shape
+ resolved_num_threads = num_threads if num_threads is not None else DEFAULT_NUM_THREADS_BY_N.get(n, best_num_threads(n))
+ if resolved_num_threads is None:
+ raise ValueError(f"No valid num_threads found for N={n}")
+ resolved_rows_per_cta = rows_per_cta if rows_per_cta is not None else pick_rows_per_cta(m)
+ if m % resolved_rows_per_cta != 0:
+ raise ValueError(f"M must be divisible by rows_per_cta, got M={m}, rows_per_cta={resolved_rows_per_cta}")
+
+ o_tensor = torch.empty_like(x_tensor)
+ amax_tensor = torch.full((m // resolved_rows_per_cta,), float("-inf"), dtype=torch.float32, device=x_tensor.device)
+
+ cache_key = (
+ n,
+ x_tensor.dtype,
+ w_tensor.dtype,
+ o_tensor.dtype,
+ tuple(x_tensor.stride()),
+ tuple(w_tensor.stride()),
+ tuple(o_tensor.stride()),
+ eps,
+ resolved_num_threads,
+ resolved_rows_per_cta,
+ )
+
+ if cache_key in _cache_of_RmsNormRhtAmaxSm100Objects:
+ api = _cache_of_RmsNormRhtAmaxSm100Objects[cache_key]
+ else:
+ api = RmsNormRhtAmaxSm100(
+ sample_x=x_tensor,
+ sample_w=w_tensor,
+ sample_o=o_tensor,
+ sample_amax=amax_tensor,
+ eps=eps,
+ num_threads=resolved_num_threads,
+ rows_per_cta=resolved_rows_per_cta,
+ )
+ assert api.check_support(), "Unsupported configuration"
+ api.compile()
+ _cache_of_RmsNormRhtAmaxSm100Objects[cache_key] = api
+
+ api.execute(
+ x_tensor=x_tensor,
+ w_tensor=w_tensor,
+ o_tensor=o_tensor,
+ amax_tensor=amax_tensor,
+ current_stream=current_stream,
+ )
+
+ return TupleDict(o_tensor=o_tensor, amax_tensor=amax_tensor)
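Reviewer note (not part of the patch): a usage sketch of the wrapper with shapes that satisfy the checks above (N divisible by 16 and by the resolved thread count, M divisible by rows_per_cta); field access on the returned `TupleDict` is assumed to work by attribute name, mirroring how it is constructed here.

import torch
from cudnn.rmsnorm_rht_amax import rmsnorm_rht_amax_wrapper_sm100, pick_rows_per_cta

m, n = 4096, 7168                # DEFAULT_NUM_THREADS_BY_N resolves 7168 -> 128 threads
x = torch.randn(m, n, device="cuda", dtype=torch.bfloat16)
w = torch.randn(n, device="cuda", dtype=torch.bfloat16)

out = rmsnorm_rht_amax_wrapper_sm100(x, w, eps=1e-5)
o = out.o_tensor                 # (m, n) bf16: RMSNorm + Hadamard-transformed activations
amax = out.amax_tensor           # (m // rows_per_cta,) fp32 per-CTA running amax
assert amax.shape[0] == m // pick_rows_per_cta(m)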
diff --git a/python/cudnn/rmsnorm_rht_amax/kernel.py b/python/cudnn/rmsnorm_rht_amax/kernel.py
new file mode 100644
index 00000000..2445cae9
--- /dev/null
+++ b/python/cudnn/rmsnorm_rht_amax/kernel.py
@@ -0,0 +1,253 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+"""CUTE DSL kernel for fused RMSNorm + RHT + per-CTA amax."""
+
+import math
+import operator
+
+import cuda.bindings.driver as cuda
+import cutlass
+import cutlass.cute as cute
+import cutlass.utils as utils
+from cutlass import Float32, Int32
+from cutlass._mlir.dialects import llvm
+from cutlass.cute.arch import shuffle_sync_bfly
+from cutlass.cutlass_dsl import T, dsl_user_op
+
+
+@dsl_user_op
+def fabs_f32(val, *, loc=None, ip=None):
+ val_ir = val.ir_value(loc=loc, ip=ip)
+ result = llvm.inline_asm(
+ T.f32(),
+ [val_ir],
+ "abs.f32 $0, $1;",
+ "=f,f",
+ has_side_effects=False,
+ is_align_stack=False,
+ asm_dialect=llvm.AsmDialect.AD_ATT,
+ loc=loc,
+ ip=ip,
+ )
+ return Float32(result)
+
+
+@dsl_user_op
+def fmax_f32(a, b, *, loc=None, ip=None):
+ a_ir = a.ir_value(loc=loc, ip=ip)
+ b_ir = b.ir_value(loc=loc, ip=ip)
+ result = llvm.inline_asm(
+ T.f32(),
+ [a_ir, b_ir],
+ "max.f32 $0, $1, $2;",
+ "=f,f,f",
+ has_side_effects=False,
+ is_align_stack=False,
+ asm_dialect=llvm.AsmDialect.AD_ATT,
+ loc=loc,
+ ip=ip,
+ )
+ return Float32(result)
+
+
+@dsl_user_op
+def redux_sync_max_f32(val, *, loc=None, ip=None):
+ val_ir = val.ir_value(loc=loc, ip=ip)
+ result = llvm.inline_asm(
+ T.f32(),
+ [val_ir],
+ "redux.sync.max.f32 $0, $1, 0xffffffff;",
+ "=f,f",
+ has_side_effects=False,
+ is_align_stack=False,
+ asm_dialect=llvm.AsmDialect.AD_ATT,
+ loc=loc,
+ ip=ip,
+ )
+ return Float32(result)
+
+
+class RMSNormRHTAmaxKernel:
+ """Fused RMSNorm + block-diagonal Hadamard + running per-CTA amax."""
+
+ COPY_BITS = 128
+ HAD_BLOCK = 16
+
+ def __init__(self, n, num_threads=256, eps=1e-5, rows_per_cta=8):
+ self.n = n
+ self.num_threads = num_threads
+ self.eps = eps
+ self.rows_per_cta = rows_per_cta
+ self.vec_size = self.COPY_BITS // 16
+ self.ept = n // num_threads
+
+ assert n % num_threads == 0, f"N={n} must be divisible by num_threads={num_threads}"
+ assert self.ept % self.vec_size == 0, f"EPT={self.ept} must be a multiple of vec_size={self.vec_size}"
+ assert self.ept >= self.vec_size, f"EPT={self.ept} must be >= vec_size={self.vec_size}"
+
+ self.num_vec_blocks = self.ept // self.vec_size
+ self.warps_per_row = num_threads // 32
+ self.inv_sqrt_had = 1.0 / math.sqrt(self.HAD_BLOCK)
+ self.num_intra_stages = int(math.log2(self.vec_size))
+ self.num_cross_stages = 1
+
+ self.tv_shape = ((num_threads, 1), (self.vec_size, self.num_vec_blocks))
+ self.tv_stride = ((self.vec_size, 1), (1, self.vec_size * num_threads))
+ self.tiler_mn = (1, n)
+
+ tile_bytes = n * 2
+ reduce_bytes = self.warps_per_row * 4
+ amax_bytes = self.warps_per_row * 4
+ self.smem_bytes = tile_bytes + reduce_bytes + amax_bytes + 128
+
+ self.intra_butterfly_pairs = []
+ for stage in range(self.num_intra_stages):
+ delta = 1 << stage
+ pairs = []
+ for pair_idx in range(self.vec_size // 2):
+ i_idx = (pair_idx // delta) * 2 * delta + (pair_idx % delta)
+ j_idx = i_idx + delta
+ pairs.append((i_idx, j_idx))
+ self.intra_butterfly_pairs.append(pairs)
+
+ @cute.kernel
+ def kernel(self, m_x: cute.Tensor, m_w: cute.Tensor, m_o: cute.Tensor, m_amax: cute.Tensor, eps: Float32, tv_layout: cute.Layout, tiler_mn: cute.Shape):
+ cfg = self
+ tid = cute.arch.thread_idx()[0]
+ bid = cute.arch.block_idx()[0]
+ inv_sqrt_had = cutlass.Float32(cfg.inv_sqrt_had)
+
+ smem = utils.SmemAllocator()
+ s_x = smem.allocate_tensor(
+ cutlass.BFloat16,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer = smem.allocate_tensor(Float32, cute.make_layout((1, cfg.warps_per_row)), byte_alignment=4)
+ amax_buffer = smem.allocate_tensor(Float32, cute.make_layout((1, cfg.warps_per_row)), byte_alignment=4)
+
+ copy_atom_g2s = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), cutlass.BFloat16, num_bits_per_copy=cfg.COPY_BITS)
+ copy_atom_load_w = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.BFloat16, num_bits_per_copy=cfg.COPY_BITS)
+ copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.BFloat16, num_bits_per_copy=cfg.COPY_BITS)
+
+ tiled_copy_load = cute.make_tiled_copy(copy_atom_g2s, tv_layout, tiler_mn)
+ tiled_copy_w = cute.make_tiled_copy(copy_atom_load_w, tv_layout, tiler_mn)
+ tiled_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn)
+
+ thr_load = tiled_copy_load.get_slice(tid)
+ thr_w = tiled_copy_w.get_slice(tid)
+ thr_store = tiled_copy_store.get_slice(tid)
+
+ t_xs_x = thr_load.partition_D(s_x)
+
+ m_w_layout = cute.prepend(m_w.layout, cute.make_layout((1,), stride=(0,)))
+ m_w_2d = cute.make_tensor(m_w.iterator, m_w_layout)
+ g_w = cute.local_tile(m_w_2d, tiler_mn, (0, 0))
+ t_wg_w = thr_w.partition_S(g_w)
+ t_wr_w = cute.make_fragment_like(t_wg_w)
+ cute.copy(copy_atom_load_w, t_wg_w, t_wr_w)
+ t_xr_w = thr_load.retile(t_wr_w)
+
+ row_base = bid * cfg.rows_per_cta
+ g_x_first = cute.local_tile(m_x, tiler_mn, (row_base, 0))
+ t_xg_x_first = thr_load.partition_S(g_x_first)
+ t_xr_x = cute.make_fragment_like(t_xg_x_first)
+
+ cute.copy(copy_atom_g2s, t_xg_x_first, t_xs_x)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ reg = cute.make_rmem_tensor(cute.make_layout((cfg.ept,)), cutlass.Float32)
+ lane_id = cute.arch.lane_idx()
+ warp_id = cute.arch.warp_idx()
+ running_max = cutlass.Float32(0.0)
+
+ for row_idx in cutlass.range_constexpr(cfg.rows_per_cta):
+ cute.autovec_copy(t_xs_x, t_xr_x)
+
+ if row_idx < cfg.rows_per_cta - 1:
+ g_x_next = cute.local_tile(m_x, tiler_mn, (row_base + (row_idx + 1), 0))
+ t_xg_x_next = thr_load.partition_S(g_x_next)
+ cute.copy(copy_atom_g2s, t_xg_x_next, t_xs_x)
+ cute.arch.cp_async_commit_group()
+
+ x = t_xr_x.load().to(Float32)
+ x_sq = x * x
+ local_sum = x_sq.reduce(cute.ReductionOp.ADD, init_val=Float32(0.0), reduction_profile=0)
+ warp_sum = cute.arch.warp_reduction(local_sum, operator.add)
+ if lane_id == 0:
+ reduction_buffer[0, warp_id] = warp_sum
+ cute.arch.barrier()
+
+ block_val = Float32(0.0)
+ if lane_id < cfg.warps_per_row:
+ block_val = reduction_buffer[0, lane_id]
+ sum_sq = cute.arch.warp_reduction(block_val, operator.add)
+
+ mean_sq = sum_sq / cfg.n
+ rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True)
+
+ w = t_xr_w.load().to(Float32)
+ y = x * rstd * w
+
+ for elem_idx in cutlass.range_constexpr(cfg.ept):
+ reg[elem_idx] = y[elem_idx]
+
+ for block_idx in cutlass.range_constexpr(cfg.num_vec_blocks):
+ block_offset = block_idx * cfg.vec_size
+ for stage_idx in cutlass.range_constexpr(cfg.num_intra_stages):
+ for pair_idx in cutlass.range_constexpr(cfg.vec_size // 2):
+ i_idx = block_offset + cfg.intra_butterfly_pairs[stage_idx][pair_idx][0]
+ j_idx = block_offset + cfg.intra_butterfly_pairs[stage_idx][pair_idx][1]
+ a_val = reg[i_idx]
+ b_val = reg[j_idx]
+ reg[i_idx] = a_val + b_val
+ reg[j_idx] = a_val - b_val
+
+ for cross_stage in cutlass.range_constexpr(cfg.num_cross_stages):
+ xor_mask = cutlass.Int32(1 << cross_stage)
+ is_lower = (tid & xor_mask) == cutlass.Int32(0)
+ for elem_idx in cutlass.range_constexpr(cfg.ept):
+ partner = shuffle_sync_bfly(reg[elem_idx], offset=xor_mask)
+ if is_lower:
+ reg[elem_idx] = reg[elem_idx] + partner
+ else:
+ reg[elem_idx] = partner - reg[elem_idx]
+
+ for elem_idx in cutlass.range_constexpr(cfg.ept):
+ scaled = reg[elem_idx] * inv_sqrt_had
+ abs_val = fabs_f32(scaled)
+ running_max = fmax_f32(running_max, abs_val)
+ t_xr_x[elem_idx] = scaled.to(cutlass.BFloat16)
+
+ g_o_r = cute.local_tile(m_o, tiler_mn, (row_base + row_idx, 0))
+ t_xg_o_r = thr_store.partition_D(g_o_r)
+ cute.copy(copy_atom_store, t_xr_x, t_xg_o_r)
+
+ if row_idx < cfg.rows_per_cta - 1:
+ cute.arch.cp_async_wait_group(0)
+
+ warp_max = redux_sync_max_f32(running_max)
+ if lane_id == 0:
+ amax_buffer[0, warp_id] = warp_max
+ cute.arch.barrier()
+
+ amax_val = cutlass.Float32(0.0)
+ if lane_id < cfg.warps_per_row:
+ amax_val = amax_buffer[0, lane_id]
+ cta_max = redux_sync_max_f32(amax_val)
+ if tid == cutlass.Int32(0):
+ m_amax[bid] = cta_max
+
+ @cute.jit
+ def __call__(self, x_tensor: cute.Tensor, w_tensor: cute.Tensor, o_tensor: cute.Tensor, amax_tensor: cute.Tensor, eps: Float32, stream: cuda.CUstream):
+ m = x_tensor.shape[0]
+ num_ctas = m // self.rows_per_cta
+ tv_layout = cute.make_layout(self.tv_shape, stride=self.tv_stride)
+ self.kernel(x_tensor, w_tensor, o_tensor, amax_tensor, eps, tv_layout, self.tiler_mn).launch(
+ grid=(num_ctas, 1, 1),
+ block=(self.num_threads, 1, 1),
+ smem=self.smem_bytes,
+ stream=stream,
+ )
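Reviewer note (not part of the patch): the `intra_butterfly_pairs` table built in `__init__` is the standard in-place fast Walsh-Hadamard recursion over each thread's vector, followed by one cross-thread shuffle stage to complete the 16-point block. A self-contained sketch (pure Python/NumPy, illustrative names) reproduces the index pattern for vec_size = 8 and checks it against an explicit Sylvester Hadamard matrix:

import numpy as np

vec_size = 8
stages = int(np.log2(vec_size))

# Same index construction as the kernel's intra_butterfly_pairs.
pairs_per_stage = []
for stage in range(stages):
    delta = 1 << stage
    pairs = []
    for pair_idx in range(vec_size // 2):
        i = (pair_idx // delta) * 2 * delta + (pair_idx % delta)
        pairs.append((i, i + delta))
    pairs_per_stage.append(pairs)

# Apply the butterflies to a random vector and compare against H_8 built by Sylvester construction.
rng = np.random.default_rng(0)
v = rng.standard_normal(vec_size)
y = v.copy()
for pairs in pairs_per_stage:
    for i, j in pairs:
        y[i], y[j] = y[i] + y[j], y[i] - y[j]

H = np.array([[1.0]])
for _ in range(stages):
    H = np.block([[H, H], [H, -H]])
assert np.allclose(y, H @ v)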
diff --git a/python/cudnn/sdpa/bwd/api.py b/python/cudnn/sdpa/bwd/api.py
index 125f79d7..bd99fe3e 100644
--- a/python/cudnn/sdpa/bwd/api.py
+++ b/python/cudnn/sdpa/bwd/api.py
@@ -76,7 +76,7 @@ def __init__(
super().__init__()
self._kernel = BlackwellFusedMultiHeadAttentionBackward
- self._logger.warning("SdpabwdSm100D256 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
if sample_cum_seqlen_q is not None:
diff --git a/python/cudnn/sdpa/fwd/api.py b/python/cudnn/sdpa/fwd/api.py
index 421d8745..90f4c76d 100644
--- a/python/cudnn/sdpa/fwd/api.py
+++ b/python/cudnn/sdpa/fwd/api.py
@@ -44,7 +44,7 @@ def __init__(
super().__init__()
self._kernel = BlackwellFusedMultiHeadAttentionForward
- self._logger.warning("SdpafwdSm100D256 is an experimental API")
+ self._warn_experimental_api()
self._logger.debug("Entering __init__")
if sample_cum_seqlen_q is not None:
diff --git a/python/properties.cpp b/python/properties.cpp
index e81dd480..c3f6c84f 100644
--- a/python/properties.cpp
+++ b/python/properties.cpp
@@ -148,6 +148,14 @@ init_properties(py::module_& m) {
.value("F16x16", cudnn_frontend::TensorReordering_t::F16x16)
.value("F8_128x4", cudnn_frontend::TensorReordering_t::F8_128x4);
+ py::enum_(m, "scalar_type")
+ .value("RUNTIME_PARAM", cudnn_frontend::graph::ScalarType::RUNTIME_PARAM)
+ .value("COMPILE_TIME_CONST", cudnn_frontend::graph::ScalarType::COMPILE_TIME_CONST);
+
+ py::enum_(m, "reshape_mode")
+ .value("VIEW_ONLY", cudnn_frontend::ReshapeMode_t::VIEW_ONLY)
+ .value("LOGICAL", cudnn_frontend::ReshapeMode_t::LOGICAL);
+
py::class_<cudnn_frontend::graph::Tensor_attributes, std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>>(
m, "tensor")
.def(py::init<>())
@@ -171,6 +179,7 @@ init_properties(py::module_& m) {
py::return_value_policy::reference) // NOTICE THAT IT'S JUST ANOTHER NAME FOR SET_IS_VIRTUAL
.def("get_is_pass_by_value", &cudnn_frontend::graph::Tensor_attributes::get_is_pass_by_value)
.def("set_is_pass_by_value", &cudnn_frontend::graph::Tensor_attributes::set_is_pass_by_value)
+ .def("get_has_compile_time_constant", &cudnn_frontend::graph::Tensor_attributes::get_has_compile_time_constant)
.def("get_uid", &cudnn_frontend::graph::Tensor_attributes::get_uid)
.def("set_uid", &cudnn_frontend::graph::Tensor_attributes::set_uid)
.def("get_reordering_type", &cudnn_frontend::graph::Tensor_attributes::get_reordering_type)
diff --git a/python/pycudnn.cpp b/python/pycudnn.cpp
index a1895607..9655f691 100644
--- a/python/pycudnn.cpp
+++ b/python/pycudnn.cpp
@@ -92,6 +92,70 @@ PYBIND11_MODULE(_compiled_module, m) {
m.def("_set_dlhandle_cudnn", &set_dlhandle_cudnn);
py::register_exception(m, "cudnnGraphNotSupportedError");
+
+#if CUDNN_VERSION >= 92200
+ m.def("causal_conv1d_forward",
+ [](std::intptr_t stream,
+ std::intptr_t x_ptr,
+ std::intptr_t weight_ptr,
+ std::intptr_t bias_ptr,
+ std::intptr_t out_ptr,
+ int batch,
+ int dim,
+ int seq_len,
+ int kernel_size,
+ int data_type,
+ int activation) {
+ auto status = detail::causal_conv1d_forward(reinterpret_cast(stream),
+ reinterpret_cast(x_ptr),
+ reinterpret_cast(weight_ptr),
+ reinterpret_cast(bias_ptr),
+ reinterpret_cast(out_ptr),
+ batch,
+ dim,
+ seq_len,
+ kernel_size,
+ static_cast(data_type),
+ static_cast(activation));
+ if (status != 0)
+ throw std::runtime_error("cudnnCausalConv1dForward failed with status " + std::to_string(status));
+ });
+
+ m.def("causal_conv1d_backward",
+ [](std::intptr_t stream,
+ std::intptr_t x_ptr,
+ std::intptr_t weight_ptr,
+ std::intptr_t bias_ptr,
+ std::intptr_t dy_ptr,
+ std::intptr_t dx_ptr,
+ std::intptr_t dweight_ptr,
+ std::intptr_t dbias_ptr,
+ int batch,
+ int dim,
+ int seq_len,
+ int kernel_size,
+ int data_type,
+ int dw_data_type,
+ int activation) {
+ auto status = detail::causal_conv1d_backward(reinterpret_cast(stream),
+ reinterpret_cast(x_ptr),
+ reinterpret_cast(weight_ptr),
+ reinterpret_cast(bias_ptr),
+ reinterpret_cast(dy_ptr),
+ reinterpret_cast(dx_ptr),
+ reinterpret_cast(dweight_ptr),
+ reinterpret_cast(dbias_ptr),
+ batch,
+ dim,
+ seq_len,
+ kernel_size,
+ static_cast(data_type),
+ static_cast(dw_data_type),
+ static_cast(activation));
+ if (status != 0)
+ throw std::runtime_error("cudnnCausalConv1dBackward failed with status " + std::to_string(status));
+ });
+#endif
}
} // namespace python_bindings
diff --git a/python/pygraph/pygraph.cpp b/python/pygraph/pygraph.cpp
index 772f64af..721c9448 100644
--- a/python/pygraph/pygraph.cpp
+++ b/python/pygraph/pygraph.cpp
@@ -1,3 +1,4 @@
+#include <optional>
#include
#include
#include
@@ -185,16 +186,21 @@ PyGraph::slice(std::shared_ptr& input,
auto input_dim = input->get_dim();
std::vector<std::pair<int64_t, int64_t>> start_end_indices;
+ std::vector<int64_t> steps;
+ steps.reserve(slices.size());
for (size_t i = 0; i < slices.size(); ++i) {
int64_t start, stop, step, length;
if (!slices[i].compute(input_dim[i], &start, &stop, &step, &length)) {
throw std::runtime_error("Invalid slice");
}
+ CUDNN_FRONTEND_UNUSED(length);
start_end_indices.push_back({start, stop});
+ steps.push_back(step);
}
auto attributes = cudnn_frontend::graph::Slice_attributes()
.set_slices(start_end_indices)
+ .set_strides(steps)
.set_compute_data_type(compute_data_type)
.set_name(name);
@@ -350,13 +356,60 @@ PyGraph::reduction(std::shared_ptr& in
}
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
-PyGraph::reshape(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& input, std::string const& name) {
- auto attributes = cudnn_frontend::graph::Reshape_attributes().set_name(name);
+PyGraph::reshape(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& input,
+ std::string const& name,
+ cudnn_frontend::ReshapeMode_t reshape_mode) {
+ auto attributes = cudnn_frontend::graph::Reshape_attributes().set_name(name).set_reshape_mode(reshape_mode);
auto OUT_0 = graph->reshape(input, attributes);
return OUT_0;
}
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::transpose(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& input,
+ std::vector<int64_t> const& permutation,
+ cudnn_frontend::DataType_t const& compute_data_type,
+ std::string const& name) {
+ auto attributes = cudnn_frontend::graph::Transpose_attributes()
+ .set_name(name)
+ .set_permutation(permutation)
+ .set_compute_data_type(compute_data_type);
+
+ return graph->transpose(input, attributes);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::concatenate(std::vector<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>> inputs,
+ int64_t axis,
+ std::optional<int64_t> in_place_index,
+ std::string const& name) {
+ auto attributes = cudnn_frontend::graph::Concatenate_attributes().set_axis(axis).set_name(name);
+ if (in_place_index.has_value()) {
+ attributes.set_in_place_index(in_place_index.value());
+ }
+ return graph->concatenate(std::move(inputs), attributes);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::tensor_scalar(float const& value, cudnn_frontend::graph::ScalarType scalar_type) {
+ return graph->tensor(value, scalar_type);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::tensor_scalar(double const& value, cudnn_frontend::graph::ScalarType scalar_type) {
+ return graph->tensor(value, scalar_type);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::tensor_scalar(int32_t const& value, cudnn_frontend::graph::ScalarType scalar_type) {
+ return graph->tensor(value, scalar_type);
+}
+
+std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+PyGraph::tensor_scalar(int64_t const& value, cudnn_frontend::graph::ScalarType scalar_type) {
+ return graph->tensor(value, scalar_type);
+}
+
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
PyGraph::moe_grouped_matmul(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& token,
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& weight,
@@ -772,6 +825,8 @@ init_pygraph_submodule(py::module_& m) {
Args:
input (cudnn_tensor): The input tensor to be sliced.
slices (List[slice]): A list of Python slice objects, one for each dimension.
+ The per-axis step is taken from each slice's ``step`` (default 1), normalized
+ against that dimension's extent (the same semantics as indexing a Python sequence of that length).
compute_data_type (Optional[cudnn.data_type]): The data type for computation.
Default is NOT_SET.
name (Optional[str]): A name for the slice operation.
@@ -781,7 +836,7 @@ init_pygraph_submodule(py::module_& m) {
Example:
>>> input_tensor = graph.tensor([4, 8, 16])
- >>> sliced_tensor = graph.slice(input_tensor, [slice(0, 2), slice(1, 5), slice(0, 16)])
+ >>> sliced_tensor = graph.slice(input_tensor, [slice(0, 2), slice(1, 8, 2), slice(0, 16)])
)pbdoc")
.def(
"conv_fprop",
@@ -1014,6 +1069,7 @@ init_pygraph_submodule(py::module_& m) {
&PyGraph::reshape,
py::arg("input"),
py::arg_v("name", ""),
+ py::arg_v("reshape_mode", cudnn_frontend::ReshapeMode_t::VIEW_ONLY),
R"pbdoc(
Reshape an input tensor to other dimensions without changing the actual memory layout.
These dimensions to reshape to are inferred from output tensor shape.
@@ -1021,10 +1077,73 @@ init_pygraph_submodule(py::module_& m) {
Args:
input (cudnn_tensor): The input tensor.
name (Optional[str]): A name for the operation to be performed.
+ reshape_mode (cudnn.reshape_mode): VIEW_ONLY (default) or LOGICAL for lexicographic logical reshapes.
Returns:
cudnn_tensor: The result of reshape operation. Please set the dims for the output tensor.
)pbdoc")
+ .def("transpose",
+ &PyGraph::transpose,
+ py::arg("input"),
+ py::arg("permutation"),
+ py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET),
+ py::arg_v("name", ""),
+ R"pbdoc(
+ Permute tensor dimensions using a permutation vector (output axis i reads input axis permutation[i]).
+
+ Args:
+ input (cudnn_tensor): The input tensor.
+ permutation (List[int]): Permutation of axis indices.
+ compute_data_type (Optional[cudnn.data_type]): Optional compute type; default NOT_SET.
+ name (Optional[str]): Operation name.
+
+ Returns:
+ cudnn_tensor: Transposed tensor (dims/strides inferred when not set).
+ )pbdoc")
+ .def("concatenate",
+ &PyGraph::concatenate,
+ py::arg("inputs"),
+ py::arg("axis"),
+ py::arg_v("in_place_index", std::optional{}),
+ py::arg_v("name", ""),
+ R"pbdoc(
+ Concatenate tensors along an axis.
+
+ Args:
+ inputs (List[cudnn_tensor]): Tensors to concatenate.
+ axis (int): Concatenation axis.
+ in_place_index (Optional[int]): When set, optional in-place concat per backend semantics.
+ name (Optional[str]): Operation name.
+
+ Returns:
+ cudnn_tensor: Concatenated output tensor.
+ )pbdoc")
+ .def("tensor_scalar",
+ py::overload_cast(&PyGraph::tensor_scalar),
+ py::arg("value"),
+ py::arg("scalar_type"),
+ R"pbdoc(
+ Create a rank-1 scalar tensor from a Python float, marked runtime or compile-time.
+
+ Args:
+ value (float): Scalar value.
+ scalar_type (cudnn.scalar_type): RUNTIME_PARAM or COMPILE_TIME_CONST.
+
+ Returns:
+ cudnn_tensor: Scalar tensor (set dim/stride/name as needed for your graph).
+ )pbdoc")
+ .def("tensor_scalar",
+ py::overload_cast(&PyGraph::tensor_scalar),
+ py::arg("value"),
+ py::arg("scalar_type"))
+ .def("tensor_scalar",
+ py::overload_cast(&PyGraph::tensor_scalar),
+ py::arg("value"),
+ py::arg("scalar_type"))
+ .def("tensor_scalar",
+ py::overload_cast(&PyGraph::tensor_scalar),
+ py::arg("value"),
+ py::arg("scalar_type"))
.def("moe_grouped_matmul",
&PyGraph::moe_grouped_matmul,
py::arg("token"),
diff --git a/python/pygraph/pygraph.h b/python/pygraph/pygraph.h
index fd617b62..ae25927f 100644
--- a/python/pygraph/pygraph.h
+++ b/python/pygraph/pygraph.h
@@ -1,3 +1,4 @@
+#include <optional>
#include
#include
#include
@@ -335,7 +336,33 @@ class PyGraph {
std::string const& name);
std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
- reshape(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& input, std::string const& name);
+ reshape(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& input,
+ std::string const& name,
+ cudnn_frontend::ReshapeMode_t reshape_mode = cudnn_frontend::ReshapeMode_t::VIEW_ONLY);
+
+ std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+ transpose(std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>& input,
+ std::vector<int64_t> const& permutation,
+ cudnn_frontend::DataType_t const& compute_data_type,
+ std::string const& name);
+
+ std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>
+ concatenate(std::vector<std::shared_ptr<cudnn_frontend::graph::Tensor_attributes>> inputs,
+ int64_t axis,
+ std::optional