diff --git a/CMakeLists.txt b/CMakeLists.txt index 3eafa47d..ff0438a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.23) -project(cudnn_frontend VERSION 1.17.0) +project(cudnn_frontend VERSION 1.18.0) option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF) option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON) diff --git a/benchmark/Llama-3.2-1B-Training/Dockerfile b/benchmark/Llama-3.2-1B-Training/Dockerfile deleted file mode 100644 index 12e9fb5d..00000000 --- a/benchmark/Llama-3.2-1B-Training/Dockerfile +++ /dev/null @@ -1,6 +0,0 @@ -FROM nvcr.io/nvidia/pytorch:25.04-py3 -RUN pip install --upgrade pip && \ - pip install seaborn transformers -COPY training_perf.py / -WORKDIR /workspace -CMD ["python", "/training_perf.py"] diff --git a/benchmark/Llama-3.2-1B-Training/README.md b/benchmark/Llama-3.2-1B-Training/README.md deleted file mode 100644 index 75d36e4b..00000000 --- a/benchmark/Llama-3.2-1B-Training/README.md +++ /dev/null @@ -1,152 +0,0 @@ -# Introduction - -Llama-3.2-1B is a 1.3 billion parameter language model based on the Llama architecture. It is a smaller variant of the Llama model family, designed to provide good performance while being more efficient and requiring less computational resources than larger models. In particular, this model is one of the few in the Llama family that can comfortably run on a single consumer-grade GPU. The model was trained on a large corpus of text data using the standard Llama training methodology. - -The easiest way to get this model is through the Hugging Face transformers library. It is a PyTorch implementation using the standard `torch.nn.Module` interface. Specifically, the attention in the model is implemented using `torch.nn.functional.scaled_dot_product_attention`. PyTorch provided multiple backends for this attention, swappable at runtime by setting the context manager `sdpa_kernel()` with backend set to `SDPBackend.CUDNN_ATTENTION`, `SDPBackend.EFFICIENT_ATTENTION`, or `SDPBackend.FLASH_ATTENTION` respectively. - -This benchmark focuses on training. The model architecture is from the transformers library but the model is randomly initialized without using any pre-trained weights. The benchmark script simulates a training loop by running multiple iterations of forward and backward passes to the model. The goal is to measure the per-iteration time of training under different batch sizes, sequence lengths, and SDPA backends. - -## Software versions - -This benchmark code should run on any decently modern Python environment with CUDA-enabled GPU. In particular, it has been tested using the PyTorch docker image [from the NGC catalog](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), `nvcr.io/nvidia/pytorch:25.04-py3`. Specifically, the following are included in the image: - -| Software | Version | -|----------|---------| -| Python | 3.12.9 | -| CUDA | 12.9 | -| cuDNN | 9.9.0 | -| PyTorch | 2.7.0 | - -## Steps to reproduce - -This benchmark is conducted by running the `training_perf.py` script and uses the first CUDA-enabled GPU in your system. This Python script depends on several other Python packages, including: - -- `torch` -- `transformers` -- `pandas` -- `seaborn` -- `matplotlib` - -The script uses the Hugging Face transformers library to create the Llama model. You also need to configure the environment for the Hugging Face library to run, including the `HF_TOKEN` and `HF_HOME` environment variables if necessary. - -Note that the Llama model is a gated model in the Hugging Face Hub. Your token must have access to this model. Please visit to request access. - -For more accurate results, it is recommended to lock the clock frequency of the GPU using the following command: - -```bash -nvidia-smi -i 0 -lgc , -``` - -which the min and max clock frequency are in MHz and should be the same value. You need to check the supported clock frequency of your specific GPU. - -Once the environment is configured, you can run the script by executing the following command: - -```bash -python training_perf.py -``` - -This will measure the mean time spent per training loop for 50 iterations under different batch sizes, sequence lengths, and SDPA backends. The script will output the timing data to a CSV file called `training_timing.csv`, and plot the results in `iteration_time.png` and `speedup.png`. - -For your convenience, a `Dockerfile` is also provided that bootstraps all of the above. You can build the image by executing the following command: - -```bash -docker build -t benchmark-llama-3.2-1b . -``` - -Then you can run the benchmark by executing the following command, which set the necessary environment variables and mount the current directory to the working directory of the container: - -```bash -nvidia-smi -i 0 -lgc CLOCK_MHZ,CLOCK_MHZ # optional, clock must be locked outside the container -docker run --gpus all --rm -e HF_TOKEN="hf_XXXXXXX" -e HF_HOME="/tmp/huggingface" -v $(pwd):/workspace benchmark-llama-3.2-1b -``` - -where `hf_XXXXXXX` is your Hugging Face token, and the clock frequency should match your GPU. Note the argument `-v :/workspace` in the docker command is to mount a local directory to the `/workspace` directory in the container. This is how you can have the script output stored in your host machine. - -Running this benchmark will produce the following artifacts: - -- `training_timing.csv`: the timing results -- `iteration_time.png`: the iteration time plot -- `speedup.png`: the speed-up plot - -An example is show in the next section. - -## Results - -Below are the result of the benchmark running on a single B200 GPU or H200 GPU. - -For both runs, the following software versions are used: - -- CUDA: 12.9 (from NGC container) -- cuDNN: 9.9.0 (from NGC container) -- PyTorch: 2.7.0 (from NGC container) -- transformers: 4.35.0 - -### B200 - -The hardware configuration is: - -- GPU: NVIDIA B200 -- CPU: INTEL(R) XEON(R) PLATINUM 8570 -- RAM: 2TB - -The benchmark is conducted after the clock is locked at 1665 MHz, namely, the base clock of B200. - -cuDNN attention is the fastest SDPA backend compared to flash attention and efficient attention, in the six setups we tested. - -The left figure compares cuDNN and flash attention backends against efficient attention, which is the slowest SDPA backend among these three. The speed-up is the per-iteration speed-up, not the speed-up of the attention operation itself. The right figure displays per-iteration times (as recorded in `iteration_time.png`), where lower values indicate better performance. - -| ![Speed-up](artifacts/b200_speedup.png) | ![Iteration time](artifacts/b200_iteration_time.png) | -|--------------------------|--------------------------| - -Note that the exact implementation of the various backends are specific to PyTorch. We notice other implementation of Flash Attention from the [flash-attn](https://github.com/Dao-AILab/flash-attention) Python library, for example. But this comparison is limited to the off-the-shelf options from PyTorch. - -For completeness, the exact timing results are shown below, which the script will output to [`training_timing.csv`](artifacts/b200_training_timing.csv): - -| (sec per iter) | CUDNN_ATTENTION | EFFICIENT_ATTENTION | FLASH_ATTENTION | -|--|-----------------|---------------------|-----------------| -| BS=1 SL=16384 | 0.285 | 0.716 | 0.374 | -| BS=2 SL=8192 | 0.231 | 0.444 | 0.276 | -| BS=3 SL=4096 | 0.155 | 0.236 | 0.172 | -| BS=6 SL=2048 | 0.144 | 0.184 | 0.154 | -| BS=12 SL=1024 | 0.139 | 0.160 | 0.145 | -| BS=24 SL=768 | 0.200 | 0.222 | 0.208 | - -### H200 - -The hardware configuration is: - -- GPU: NVIDIA H200 -- CPU: Intel(R) Xeon(R) Silver 4314 CPU @ 2.40GHz -- RAM: 256GB - -Below is another run performed on an H200 GPU, with the clock frequency locked at 1500 MHz to match the H200’s base clock. -The figure on the left shows the per-iteration speed-up, while the figure on the right displays the per-iteration timing. - -| ![Speed-up](artifacts/h200_speedup.png) | ![Iteration time](artifacts/h200_iteration_time.png) | -|--------------------------|--------------------------| - -and the exact timing results are shown below, the raw data is stored [here](artifacts/h200_training_timing.csv): - -| (sec per iter) | CUDNN_ATTENTION | EFFICIENT_ATTENTION | FLASH_ATTENTION | -|--|-----------------|---------------------|-----------------| -| BS=1 SL=16384 | 0.424 | 0.850 | 0.500 | -| BS=2 SL=8192 | 0.357 | 0.566 | 0.397 | -| BS=3 SL=4096 | 0.243 | 0.322 | 0.260 | -| BS=6 SL=2048 | 0.229 | 0.270 | 0.240 | -| BS=12 SL=1024 | 0.222 | 0.243 | 0.228 | -| BS=24 SL=768 | 0.329 | 0.352 | 0.336 | - -### How much faster is Blackwell vs Hopper? - -Comparing the timing data we collected above between B200 and H200, this is what we get: - -| ![cuDNN time](artifacts/cudnn_attention_time.png) | ![Flash attention time](artifacts/flash_attention_time.png) | -|--------------------------|--------------------------| - -Keep it concise, we do not plot the efficient attention time here. While B200 is always faster than the previous generation H200, the actual difference depends on the backend. It is easier to compare if we plot the speed-up of B200 vs H200, as shown below: - - - -It is easy to see that B200 can achieve around 1.2x to 1.6x speed-up over H200 for the entire training loop, depends on the batch size, sequence length, and SDPA backend. The speed-up is most significant when cuDNN attention backend is used because of the optimization implemented for Blackwell architecture. The other backends show some speed-up but not as much as cuDNN attention. - -As a final note, the timing measured here involves the entire 16 layers of the Llama model and each iteration includes a forward pass and a backward pass. Attention operation, although important for the model, only accounts for a portion of the total operations. Should the other part of the model be optimized (such as the RMS norm used in the Llama model), the speed-up will be even more significant. \ No newline at end of file diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/b200_h200_speedup.png b/benchmark/Llama-3.2-1B-Training/artifacts/b200_h200_speedup.png deleted file mode 100644 index 1acbd454..00000000 Binary files a/benchmark/Llama-3.2-1B-Training/artifacts/b200_h200_speedup.png and /dev/null differ diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/b200_iteration_time.png b/benchmark/Llama-3.2-1B-Training/artifacts/b200_iteration_time.png deleted file mode 100644 index 396209f5..00000000 Binary files a/benchmark/Llama-3.2-1B-Training/artifacts/b200_iteration_time.png and /dev/null differ diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/b200_run_plot.b200.txt b/benchmark/Llama-3.2-1B-Training/artifacts/b200_run_plot.b200.txt deleted file mode 100644 index 47a7a1fb..00000000 --- a/benchmark/Llama-3.2-1B-Training/artifacts/b200_run_plot.b200.txt +++ /dev/null @@ -1,1960 +0,0 @@ -+ hostname -umbriel-b200-035 -+ nvidia-smi -q - -==============NVSMI LOG============== - -Timestamp : Fri May 9 20:19:46 2025 -Driver Version : 570.124.06 -CUDA Version : 12.9 - -Attached GPUs : 8 -GPU 00000000:1B:00.0 - Product Name : NVIDIA B200 - Product Brand : NVIDIA - Product Architecture : Blackwell - Display Mode : Disabled - Display Active : Disabled - Persistence Mode : Enabled - Addressing Mode : HMM - MIG Mode - Current : Disabled - Pending : Disabled - Accounting Mode : Disabled - Accounting Mode Buffer Size : 4000 - Driver Model - Current : N/A - Pending : N/A - Serial Number : 1650325040841 - GPU UUID : GPU-6702348a-a951-10cb-9ce8-50ded871996d - Minor Number : 0 - VBIOS Version : 97.00.9A.00.0F - MultiGPU Board : No - Board ID : 0x1b00 - Board Part Number : 692-2G525-0220-000 - GPU Part Number : 2901-886-A1 - FRU Part Number : N/A - Platform Info - Chassis Serial Number : - Slot Number : N/A - Tray Index : N/A - Host ID : 1 - Peer Type : Switch Connected - Module Id : 4 - GPU Fabric GUID : 0x2329e979a57f2519 - Inforom Version - Image Version : G525.0220.00.03 - OEM Object : 2.1 - ECC Object : 7.16 - Power Management Object : N/A - Inforom BBX Object Flush - Latest Timestamp : 2025/05/09 16:56:19.523 - Latest Duration : 54975 us - GPU Operation Mode - Current : N/A - Pending : N/A - GPU C2C Mode : Disabled - GPU Virtualization Mode - Virtualization Mode : None - Host VGPU Mode : N/A - vGPU Heterogeneous Mode : N/A - GPU Reset Status - Reset Required : Requested functionality has been deprecated - Drain and Reset Recommended : Requested functionality has been deprecated - GPU Recovery Action : None - GSP Firmware Version : 570.124.06 - IBMNPU - Relaxed Ordering Mode : N/A - PCI - Bus : 0x1B - Device : 0x00 - Domain : 0x0000 - Base Classcode : 0x3 - Sub Classcode : 0x2 - Device Id : 0x290110DE - Bus Id : 00000000:1B:00.0 - Sub System Id : 0x199910DE - GPU Link Info - PCIe Generation - Max : 5 - Current : 5 - Device Current : 5 - Device Max : 5 - Host Max : 5 - Link Width - Max : 16x - Current : 16x - Bridge Chip - Type : N/A - Firmware : N/A - Replays Since Reset : 0 - Replay Number Rollovers : 0 - Tx Throughput : 14809 KB/s - Rx Throughput : 806 KB/s - Atomic Caps Outbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Atomic Caps Inbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Fan Speed : N/A - Performance State : P0 - Clocks Event Reasons - Idle : Active - Applications Clocks Setting : Not Active - SW Power Cap : Not Active - HW Slowdown : Not Active - HW Thermal Slowdown : Not Active - HW Power Brake Slowdown : Not Active - Sync Boost : Not Active - SW Thermal Slowdown : Not Active - Display Clock Setting : Not Active - Sparse Operation Mode : N/A - FB Memory Usage - Total : 183359 MiB - Reserved : 717 MiB - Used : 1 MiB - Free : 182643 MiB - BAR1 Memory Usage - Total : 262144 MiB - Used : 1 MiB - Free : 262143 MiB - Conf Compute Protected Memory Usage - Total : 0 MiB - Used : 0 MiB - Free : 0 MiB - Compute Mode : Default - Utilization - GPU : 0 % - Memory : 0 % - Encoder : 0 % - Decoder : 0 % - JPEG : 0 % - OFA : 0 % - Encoder Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - FBC Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - DRAM Encryption Mode - Current : N/A - Pending : N/A - ECC Mode - Current : Enabled - Pending : Enabled - ECC Errors - Volatile - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - Aggregate - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - SRAM Threshold Exceeded : No - Aggregate Uncorrectable SRAM Sources - SRAM L2 : 0 - SRAM SM : 0 - SRAM Microcontroller : 0 - SRAM PCIE : 0 - SRAM Other : 0 - Retired Pages - Single Bit ECC : N/A - Double Bit ECC : N/A - Pending Page Blacklist : N/A - Remapped Rows - Correctable Error : 0 - Uncorrectable Error : 0 - Pending : No - Remapping Failure Occurred : No - Bank Remap Availability Histogram - Max : 3840 bank(s) - High : 0 bank(s) - Partial : 0 bank(s) - Low : 0 bank(s) - None : 0 bank(s) - Temperature - GPU Current Temp : 37 C - GPU T.Limit Temp : 51 C - GPU Shutdown T.Limit Temp : -5 C - GPU Slowdown T.Limit Temp : -3 C - GPU Max Operating T.Limit Temp : 0 C - GPU Target Temperature : N/A - Memory Current Temp : 36 C - Memory Max Operating T.Limit Temp : 0 C - GPU Power Readings - Average Power Draw : 145.00 W - Instantaneous Power Draw : 144.47 W - Current Power Limit : 1000.00 W - Requested Power Limit : 1000.00 W - Default Power Limit : 1000.00 W - Min Power Limit : 200.00 W - Max Power Limit : 1000.00 W - GPU Memory Power Readings - Average Power Draw : 20.82 W - Instantaneous Power Draw : N/A - Module Power Readings - Average Power Draw : N/A - Instantaneous Power Draw : N/A - Current Power Limit : N/A - Requested Power Limit : N/A - Default Power Limit : N/A - Min Power Limit : N/A - Max Power Limit : N/A - Power Smoothing : Insufficient Permissions - Workload Power Profiles - Requested Profiles : N/A - Enforced Profiles : N/A - Clocks - Graphics : 120 MHz - SM : 120 MHz - Memory : 3996 MHz - Video : 600 MHz - Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Default Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Deferred Clocks - Memory : N/A - Max Clocks - Graphics : 1965 MHz - SM : 1965 MHz - Memory : 3996 MHz - Video : 1965 MHz - Max Customer Boost Clocks - Graphics : 1965 MHz - Clock Policy - Auto Boost : N/A - Auto Boost Default : N/A - Voltage - Graphics : N/A - Fabric - State : Completed - Status : Success - CliqueId : 0 - ClusterUUID : 00000000-0000-0000-0000-000000000000 - Health - Bandwidth : N/A - Route Recovery in progress : N/A - Route Unhealthy : N/A - Access Timeout Recovery : False - Processes : None - Capabilities - EGM : disabled - -GPU 00000000:43:00.0 - Product Name : NVIDIA B200 - Product Brand : NVIDIA - Product Architecture : Blackwell - Display Mode : Disabled - Display Active : Disabled - Persistence Mode : Enabled - Addressing Mode : HMM - MIG Mode - Current : Disabled - Pending : Disabled - Accounting Mode : Disabled - Accounting Mode Buffer Size : 4000 - Driver Model - Current : N/A - Pending : N/A - Serial Number : 1650225107604 - GPU UUID : GPU-0fa39766-ffdf-6d00-2b80-5c16746b6fc1 - Minor Number : 1 - VBIOS Version : 97.00.9A.00.0F - MultiGPU Board : No - Board ID : 0x4300 - Board Part Number : 692-2G525-0220-000 - GPU Part Number : 2901-886-A1 - FRU Part Number : N/A - Platform Info - Chassis Serial Number : - Slot Number : N/A - Tray Index : N/A - Host ID : 1 - Peer Type : Switch Connected - Module Id : 1 - GPU Fabric GUID : 0x4460412f30ae288b - Inforom Version - Image Version : G525.0220.00.03 - OEM Object : 2.1 - ECC Object : 7.16 - Power Management Object : N/A - Inforom BBX Object Flush - Latest Timestamp : 2025/05/09 16:57:08.601 - Latest Duration : 51255 us - GPU Operation Mode - Current : N/A - Pending : N/A - GPU C2C Mode : Disabled - GPU Virtualization Mode - Virtualization Mode : None - Host VGPU Mode : N/A - vGPU Heterogeneous Mode : N/A - GPU Reset Status - Reset Required : Requested functionality has been deprecated - Drain and Reset Recommended : Requested functionality has been deprecated - GPU Recovery Action : None - GSP Firmware Version : 570.124.06 - IBMNPU - Relaxed Ordering Mode : N/A - PCI - Bus : 0x43 - Device : 0x00 - Domain : 0x0000 - Base Classcode : 0x3 - Sub Classcode : 0x2 - Device Id : 0x290110DE - Bus Id : 00000000:43:00.0 - Sub System Id : 0x199910DE - GPU Link Info - PCIe Generation - Max : 5 - Current : 5 - Device Current : 5 - Device Max : 5 - Host Max : 5 - Link Width - Max : 16x - Current : 16x - Bridge Chip - Type : N/A - Firmware : N/A - Replays Since Reset : 0 - Replay Number Rollovers : 0 - Tx Throughput : 958 KB/s - Rx Throughput : 816 KB/s - Atomic Caps Outbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Atomic Caps Inbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Fan Speed : N/A - Performance State : P0 - Clocks Event Reasons - Idle : Active - Applications Clocks Setting : Not Active - SW Power Cap : Not Active - HW Slowdown : Not Active - HW Thermal Slowdown : Not Active - HW Power Brake Slowdown : Not Active - Sync Boost : Not Active - SW Thermal Slowdown : Not Active - Display Clock Setting : Not Active - Sparse Operation Mode : N/A - FB Memory Usage - Total : 183359 MiB - Reserved : 717 MiB - Used : 1 MiB - Free : 182643 MiB - BAR1 Memory Usage - Total : 262144 MiB - Used : 1 MiB - Free : 262143 MiB - Conf Compute Protected Memory Usage - Total : 0 MiB - Used : 0 MiB - Free : 0 MiB - Compute Mode : Default - Utilization - GPU : 0 % - Memory : 0 % - Encoder : 0 % - Decoder : 0 % - JPEG : 0 % - OFA : 0 % - Encoder Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - FBC Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - DRAM Encryption Mode - Current : N/A - Pending : N/A - ECC Mode - Current : Enabled - Pending : Enabled - ECC Errors - Volatile - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - Aggregate - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - SRAM Threshold Exceeded : No - Aggregate Uncorrectable SRAM Sources - SRAM L2 : 0 - SRAM SM : 0 - SRAM Microcontroller : 0 - SRAM PCIE : 0 - SRAM Other : 0 - Retired Pages - Single Bit ECC : N/A - Double Bit ECC : N/A - Pending Page Blacklist : N/A - Remapped Rows - Correctable Error : 0 - Uncorrectable Error : 0 - Pending : No - Remapping Failure Occurred : No - Bank Remap Availability Histogram - Max : 3840 bank(s) - High : 0 bank(s) - Partial : 0 bank(s) - Low : 0 bank(s) - None : 0 bank(s) - Temperature - GPU Current Temp : 37 C - GPU T.Limit Temp : 51 C - GPU Shutdown T.Limit Temp : -5 C - GPU Slowdown T.Limit Temp : -3 C - GPU Max Operating T.Limit Temp : 0 C - GPU Target Temperature : N/A - Memory Current Temp : 36 C - Memory Max Operating T.Limit Temp : 0 C - GPU Power Readings - Average Power Draw : 143.17 W - Instantaneous Power Draw : 143.27 W - Current Power Limit : 1000.00 W - Requested Power Limit : 1000.00 W - Default Power Limit : 1000.00 W - Min Power Limit : 200.00 W - Max Power Limit : 1000.00 W - GPU Memory Power Readings - Average Power Draw : 19.21 W - Instantaneous Power Draw : N/A - Module Power Readings - Average Power Draw : N/A - Instantaneous Power Draw : N/A - Current Power Limit : N/A - Requested Power Limit : N/A - Default Power Limit : N/A - Min Power Limit : N/A - Max Power Limit : N/A - Power Smoothing : Insufficient Permissions - Workload Power Profiles - Requested Profiles : N/A - Enforced Profiles : N/A - Clocks - Graphics : 120 MHz - SM : 120 MHz - Memory : 3996 MHz - Video : 600 MHz - Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Default Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Deferred Clocks - Memory : N/A - Max Clocks - Graphics : 1965 MHz - SM : 1965 MHz - Memory : 3996 MHz - Video : 1965 MHz - Max Customer Boost Clocks - Graphics : 1965 MHz - Clock Policy - Auto Boost : N/A - Auto Boost Default : N/A - Voltage - Graphics : N/A - Fabric - State : Completed - Status : Success - CliqueId : 0 - ClusterUUID : 00000000-0000-0000-0000-000000000000 - Health - Bandwidth : N/A - Route Recovery in progress : N/A - Route Unhealthy : N/A - Access Timeout Recovery : False - Processes : None - Capabilities - EGM : disabled - -GPU 00000000:52:00.0 - Product Name : NVIDIA B200 - Product Brand : NVIDIA - Product Architecture : Blackwell - Display Mode : Disabled - Display Active : Disabled - Persistence Mode : Enabled - Addressing Mode : HMM - MIG Mode - Current : Disabled - Pending : Disabled - Accounting Mode : Disabled - Accounting Mode Buffer Size : 4000 - Driver Model - Current : N/A - Pending : N/A - Serial Number : 1650225107867 - GPU UUID : GPU-df23009f-b21e-6a8f-4bcc-17b2bf64c84e - Minor Number : 2 - VBIOS Version : 97.00.9A.00.0F - MultiGPU Board : No - Board ID : 0x5200 - Board Part Number : 692-2G525-0220-000 - GPU Part Number : 2901-886-A1 - FRU Part Number : N/A - Platform Info - Chassis Serial Number : - Slot Number : N/A - Tray Index : N/A - Host ID : 1 - Peer Type : Switch Connected - Module Id : 3 - GPU Fabric GUID : 0x87716ad1224cdb9d - Inforom Version - Image Version : G525.0220.00.03 - OEM Object : 2.1 - ECC Object : 7.16 - Power Management Object : N/A - Inforom BBX Object Flush - Latest Timestamp : 2025/05/09 16:58:15.025 - Latest Duration : 51783 us - GPU Operation Mode - Current : N/A - Pending : N/A - GPU C2C Mode : Disabled - GPU Virtualization Mode - Virtualization Mode : None - Host VGPU Mode : N/A - vGPU Heterogeneous Mode : N/A - GPU Reset Status - Reset Required : Requested functionality has been deprecated - Drain and Reset Recommended : Requested functionality has been deprecated - GPU Recovery Action : None - GSP Firmware Version : 570.124.06 - IBMNPU - Relaxed Ordering Mode : N/A - PCI - Bus : 0x52 - Device : 0x00 - Domain : 0x0000 - Base Classcode : 0x3 - Sub Classcode : 0x2 - Device Id : 0x290110DE - Bus Id : 00000000:52:00.0 - Sub System Id : 0x199910DE - GPU Link Info - PCIe Generation - Max : 5 - Current : 5 - Device Current : 5 - Device Max : 5 - Host Max : 5 - Link Width - Max : 16x - Current : 16x - Bridge Chip - Type : N/A - Firmware : N/A - Replays Since Reset : 0 - Replay Number Rollovers : 0 - Tx Throughput : 14766 KB/s - Rx Throughput : 853 KB/s - Atomic Caps Outbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Atomic Caps Inbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Fan Speed : N/A - Performance State : P0 - Clocks Event Reasons - Idle : Active - Applications Clocks Setting : Not Active - SW Power Cap : Not Active - HW Slowdown : Not Active - HW Thermal Slowdown : Not Active - HW Power Brake Slowdown : Not Active - Sync Boost : Not Active - SW Thermal Slowdown : Not Active - Display Clock Setting : Not Active - Sparse Operation Mode : N/A - FB Memory Usage - Total : 183359 MiB - Reserved : 717 MiB - Used : 1 MiB - Free : 182643 MiB - BAR1 Memory Usage - Total : 262144 MiB - Used : 1 MiB - Free : 262143 MiB - Conf Compute Protected Memory Usage - Total : 0 MiB - Used : 0 MiB - Free : 0 MiB - Compute Mode : Default - Utilization - GPU : 0 % - Memory : 0 % - Encoder : 0 % - Decoder : 0 % - JPEG : 0 % - OFA : 0 % - Encoder Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - FBC Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - DRAM Encryption Mode - Current : N/A - Pending : N/A - ECC Mode - Current : Enabled - Pending : Enabled - ECC Errors - Volatile - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - Aggregate - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - SRAM Threshold Exceeded : No - Aggregate Uncorrectable SRAM Sources - SRAM L2 : 0 - SRAM SM : 0 - SRAM Microcontroller : 0 - SRAM PCIE : 0 - SRAM Other : 0 - Retired Pages - Single Bit ECC : N/A - Double Bit ECC : N/A - Pending Page Blacklist : N/A - Remapped Rows - Correctable Error : 0 - Uncorrectable Error : 0 - Pending : No - Remapping Failure Occurred : No - Bank Remap Availability Histogram - Max : 3840 bank(s) - High : 0 bank(s) - Partial : 0 bank(s) - Low : 0 bank(s) - None : 0 bank(s) - Temperature - GPU Current Temp : 38 C - GPU T.Limit Temp : 50 C - GPU Shutdown T.Limit Temp : -5 C - GPU Slowdown T.Limit Temp : -3 C - GPU Max Operating T.Limit Temp : 0 C - GPU Target Temperature : N/A - Memory Current Temp : 37 C - Memory Max Operating T.Limit Temp : 0 C - GPU Power Readings - Average Power Draw : 141.07 W - Instantaneous Power Draw : 140.81 W - Current Power Limit : 1000.00 W - Requested Power Limit : 1000.00 W - Default Power Limit : 1000.00 W - Min Power Limit : 200.00 W - Max Power Limit : 1000.00 W - GPU Memory Power Readings - Average Power Draw : 17.16 W - Instantaneous Power Draw : N/A - Module Power Readings - Average Power Draw : N/A - Instantaneous Power Draw : N/A - Current Power Limit : N/A - Requested Power Limit : N/A - Default Power Limit : N/A - Min Power Limit : N/A - Max Power Limit : N/A - Power Smoothing : Insufficient Permissions - Workload Power Profiles - Requested Profiles : N/A - Enforced Profiles : N/A - Clocks - Graphics : 120 MHz - SM : 120 MHz - Memory : 3996 MHz - Video : 600 MHz - Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Default Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Deferred Clocks - Memory : N/A - Max Clocks - Graphics : 1965 MHz - SM : 1965 MHz - Memory : 3996 MHz - Video : 1965 MHz - Max Customer Boost Clocks - Graphics : 1965 MHz - Clock Policy - Auto Boost : N/A - Auto Boost Default : N/A - Voltage - Graphics : N/A - Fabric - State : Completed - Status : Success - CliqueId : 0 - ClusterUUID : 00000000-0000-0000-0000-000000000000 - Health - Bandwidth : N/A - Route Recovery in progress : N/A - Route Unhealthy : N/A - Access Timeout Recovery : False - Processes : None - Capabilities - EGM : disabled - -GPU 00000000:61:00.0 - Product Name : NVIDIA B200 - Product Brand : NVIDIA - Product Architecture : Blackwell - Display Mode : Disabled - Display Active : Disabled - Persistence Mode : Enabled - Addressing Mode : HMM - MIG Mode - Current : Disabled - Pending : Disabled - Accounting Mode : Disabled - Accounting Mode Buffer Size : 4000 - Driver Model - Current : N/A - Pending : N/A - Serial Number : 1650225106441 - GPU UUID : GPU-b6372752-aa9d-b964-f614-a79eb75b0352 - Minor Number : 3 - VBIOS Version : 97.00.9A.00.0F - MultiGPU Board : No - Board ID : 0x6100 - Board Part Number : 692-2G525-0220-000 - GPU Part Number : 2901-886-A1 - FRU Part Number : N/A - Platform Info - Chassis Serial Number : - Slot Number : N/A - Tray Index : N/A - Host ID : 1 - Peer Type : Switch Connected - Module Id : 2 - GPU Fabric GUID : 0xcba981891c2f40ab - Inforom Version - Image Version : G525.0220.00.03 - OEM Object : 2.1 - ECC Object : 7.16 - Power Management Object : N/A - Inforom BBX Object Flush - Latest Timestamp : 2025/05/09 16:56:38.913 - Latest Duration : 50844 us - GPU Operation Mode - Current : N/A - Pending : N/A - GPU C2C Mode : Disabled - GPU Virtualization Mode - Virtualization Mode : None - Host VGPU Mode : N/A - vGPU Heterogeneous Mode : N/A - GPU Reset Status - Reset Required : Requested functionality has been deprecated - Drain and Reset Recommended : Requested functionality has been deprecated - GPU Recovery Action : None - GSP Firmware Version : 570.124.06 - IBMNPU - Relaxed Ordering Mode : N/A - PCI - Bus : 0x61 - Device : 0x00 - Domain : 0x0000 - Base Classcode : 0x3 - Sub Classcode : 0x2 - Device Id : 0x290110DE - Bus Id : 00000000:61:00.0 - Sub System Id : 0x199910DE - GPU Link Info - PCIe Generation - Max : 5 - Current : 5 - Device Current : 5 - Device Max : 5 - Host Max : 5 - Link Width - Max : 16x - Current : 16x - Bridge Chip - Type : N/A - Firmware : N/A - Replays Since Reset : 0 - Replay Number Rollovers : 0 - Tx Throughput : 15339 KB/s - Rx Throughput : 851 KB/s - Atomic Caps Outbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Atomic Caps Inbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Fan Speed : N/A - Performance State : P0 - Clocks Event Reasons - Idle : Active - Applications Clocks Setting : Not Active - SW Power Cap : Not Active - HW Slowdown : Not Active - HW Thermal Slowdown : Not Active - HW Power Brake Slowdown : Not Active - Sync Boost : Not Active - SW Thermal Slowdown : Not Active - Display Clock Setting : Not Active - Sparse Operation Mode : N/A - FB Memory Usage - Total : 183359 MiB - Reserved : 717 MiB - Used : 1 MiB - Free : 182643 MiB - BAR1 Memory Usage - Total : 262144 MiB - Used : 1 MiB - Free : 262143 MiB - Conf Compute Protected Memory Usage - Total : 0 MiB - Used : 0 MiB - Free : 0 MiB - Compute Mode : Default - Utilization - GPU : 0 % - Memory : 0 % - Encoder : 0 % - Decoder : 0 % - JPEG : 0 % - OFA : 0 % - Encoder Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - FBC Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - DRAM Encryption Mode - Current : N/A - Pending : N/A - ECC Mode - Current : Enabled - Pending : Enabled - ECC Errors - Volatile - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - Aggregate - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - SRAM Threshold Exceeded : No - Aggregate Uncorrectable SRAM Sources - SRAM L2 : 0 - SRAM SM : 0 - SRAM Microcontroller : 0 - SRAM PCIE : 0 - SRAM Other : 0 - Retired Pages - Single Bit ECC : N/A - Double Bit ECC : N/A - Pending Page Blacklist : N/A - Remapped Rows - Correctable Error : 0 - Uncorrectable Error : 0 - Pending : No - Remapping Failure Occurred : No - Bank Remap Availability Histogram - Max : 3840 bank(s) - High : 0 bank(s) - Partial : 0 bank(s) - Low : 0 bank(s) - None : 0 bank(s) - Temperature - GPU Current Temp : 37 C - GPU T.Limit Temp : 51 C - GPU Shutdown T.Limit Temp : -5 C - GPU Slowdown T.Limit Temp : -3 C - GPU Max Operating T.Limit Temp : 0 C - GPU Target Temperature : N/A - Memory Current Temp : 37 C - Memory Max Operating T.Limit Temp : 0 C - GPU Power Readings - Average Power Draw : 142.13 W - Instantaneous Power Draw : 142.11 W - Current Power Limit : 1000.00 W - Requested Power Limit : 1000.00 W - Default Power Limit : 1000.00 W - Min Power Limit : 200.00 W - Max Power Limit : 1000.00 W - GPU Memory Power Readings - Average Power Draw : 21.22 W - Instantaneous Power Draw : N/A - Module Power Readings - Average Power Draw : N/A - Instantaneous Power Draw : N/A - Current Power Limit : N/A - Requested Power Limit : N/A - Default Power Limit : N/A - Min Power Limit : N/A - Max Power Limit : N/A - Power Smoothing : Insufficient Permissions - Workload Power Profiles - Requested Profiles : N/A - Enforced Profiles : N/A - Clocks - Graphics : 120 MHz - SM : 120 MHz - Memory : 3996 MHz - Video : 600 MHz - Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Default Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Deferred Clocks - Memory : N/A - Max Clocks - Graphics : 1965 MHz - SM : 1965 MHz - Memory : 3996 MHz - Video : 1965 MHz - Max Customer Boost Clocks - Graphics : 1965 MHz - Clock Policy - Auto Boost : N/A - Auto Boost Default : N/A - Voltage - Graphics : N/A - Fabric - State : Completed - Status : Success - CliqueId : 0 - ClusterUUID : 00000000-0000-0000-0000-000000000000 - Health - Bandwidth : N/A - Route Recovery in progress : N/A - Route Unhealthy : N/A - Access Timeout Recovery : False - Processes : None - Capabilities - EGM : disabled - -GPU 00000000:9D:00.0 - Product Name : NVIDIA B200 - Product Brand : NVIDIA - Product Architecture : Blackwell - Display Mode : Disabled - Display Active : Disabled - Persistence Mode : Enabled - Addressing Mode : HMM - MIG Mode - Current : Disabled - Pending : Disabled - Accounting Mode : Disabled - Accounting Mode Buffer Size : 4000 - Driver Model - Current : N/A - Pending : N/A - Serial Number : 1650325041547 - GPU UUID : GPU-27e9f59b-f968-8f87-dac2-95ef5d829a26 - Minor Number : 4 - VBIOS Version : 97.00.9A.00.0F - MultiGPU Board : No - Board ID : 0x9d00 - Board Part Number : 692-2G525-0220-000 - GPU Part Number : 2901-886-A1 - FRU Part Number : N/A - Platform Info - Chassis Serial Number : - Slot Number : N/A - Tray Index : N/A - Host ID : 1 - Peer Type : Switch Connected - Module Id : 8 - GPU Fabric GUID : 0x898e9f309cf4edc - Inforom Version - Image Version : G525.0220.00.03 - OEM Object : 2.1 - ECC Object : 7.16 - Power Management Object : N/A - Inforom BBX Object Flush - Latest Timestamp : 2025/05/09 16:56:36.763 - Latest Duration : 47584 us - GPU Operation Mode - Current : N/A - Pending : N/A - GPU C2C Mode : Disabled - GPU Virtualization Mode - Virtualization Mode : None - Host VGPU Mode : N/A - vGPU Heterogeneous Mode : N/A - GPU Reset Status - Reset Required : Requested functionality has been deprecated - Drain and Reset Recommended : Requested functionality has been deprecated - GPU Recovery Action : None - GSP Firmware Version : 570.124.06 - IBMNPU - Relaxed Ordering Mode : N/A - PCI - Bus : 0x9D - Device : 0x00 - Domain : 0x0000 - Base Classcode : 0x3 - Sub Classcode : 0x2 - Device Id : 0x290110DE - Bus Id : 00000000:9D:00.0 - Sub System Id : 0x199910DE - GPU Link Info - PCIe Generation - Max : 5 - Current : 5 - Device Current : 5 - Device Max : 5 - Host Max : 5 - Link Width - Max : 16x - Current : 16x - Bridge Chip - Type : N/A - Firmware : N/A - Replays Since Reset : 0 - Replay Number Rollovers : 0 - Tx Throughput : 926 KB/s - Rx Throughput : 807 KB/s - Atomic Caps Outbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Atomic Caps Inbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Fan Speed : N/A - Performance State : P0 - Clocks Event Reasons - Idle : Active - Applications Clocks Setting : Not Active - SW Power Cap : Not Active - HW Slowdown : Not Active - HW Thermal Slowdown : Not Active - HW Power Brake Slowdown : Not Active - Sync Boost : Not Active - SW Thermal Slowdown : Not Active - Display Clock Setting : Not Active - Sparse Operation Mode : N/A - FB Memory Usage - Total : 183359 MiB - Reserved : 717 MiB - Used : 1 MiB - Free : 182643 MiB - BAR1 Memory Usage - Total : 262144 MiB - Used : 1 MiB - Free : 262143 MiB - Conf Compute Protected Memory Usage - Total : 0 MiB - Used : 0 MiB - Free : 0 MiB - Compute Mode : Default - Utilization - GPU : 0 % - Memory : 0 % - Encoder : 0 % - Decoder : 0 % - JPEG : 0 % - OFA : 0 % - Encoder Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - FBC Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - DRAM Encryption Mode - Current : N/A - Pending : N/A - ECC Mode - Current : Enabled - Pending : Enabled - ECC Errors - Volatile - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - Aggregate - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - SRAM Threshold Exceeded : No - Aggregate Uncorrectable SRAM Sources - SRAM L2 : 0 - SRAM SM : 0 - SRAM Microcontroller : 0 - SRAM PCIE : 0 - SRAM Other : 0 - Retired Pages - Single Bit ECC : N/A - Double Bit ECC : N/A - Pending Page Blacklist : N/A - Remapped Rows - Correctable Error : 0 - Uncorrectable Error : 0 - Pending : No - Remapping Failure Occurred : No - Bank Remap Availability Histogram - Max : 3840 bank(s) - High : 0 bank(s) - Partial : 0 bank(s) - Low : 0 bank(s) - None : 0 bank(s) - Temperature - GPU Current Temp : 36 C - GPU T.Limit Temp : 52 C - GPU Shutdown T.Limit Temp : -5 C - GPU Slowdown T.Limit Temp : -3 C - GPU Max Operating T.Limit Temp : 0 C - GPU Target Temperature : N/A - Memory Current Temp : 35 C - Memory Max Operating T.Limit Temp : 0 C - GPU Power Readings - Average Power Draw : 148.58 W - Instantaneous Power Draw : 149.07 W - Current Power Limit : 1000.00 W - Requested Power Limit : 1000.00 W - Default Power Limit : 1000.00 W - Min Power Limit : 200.00 W - Max Power Limit : 1000.00 W - GPU Memory Power Readings - Average Power Draw : 21.38 W - Instantaneous Power Draw : N/A - Module Power Readings - Average Power Draw : N/A - Instantaneous Power Draw : N/A - Current Power Limit : N/A - Requested Power Limit : N/A - Default Power Limit : N/A - Min Power Limit : N/A - Max Power Limit : N/A - Power Smoothing : Insufficient Permissions - Workload Power Profiles - Requested Profiles : N/A - Enforced Profiles : N/A - Clocks - Graphics : 120 MHz - SM : 120 MHz - Memory : 3996 MHz - Video : 600 MHz - Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Default Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Deferred Clocks - Memory : N/A - Max Clocks - Graphics : 1965 MHz - SM : 1965 MHz - Memory : 3996 MHz - Video : 1965 MHz - Max Customer Boost Clocks - Graphics : 1965 MHz - Clock Policy - Auto Boost : N/A - Auto Boost Default : N/A - Voltage - Graphics : N/A - Fabric - State : Completed - Status : Success - CliqueId : 0 - ClusterUUID : 00000000-0000-0000-0000-000000000000 - Health - Bandwidth : N/A - Route Recovery in progress : N/A - Route Unhealthy : N/A - Access Timeout Recovery : False - Processes : None - Capabilities - EGM : disabled - -GPU 00000000:C3:00.0 - Product Name : NVIDIA B200 - Product Brand : NVIDIA - Product Architecture : Blackwell - Display Mode : Disabled - Display Active : Disabled - Persistence Mode : Enabled - Addressing Mode : HMM - MIG Mode - Current : Disabled - Pending : Disabled - Accounting Mode : Disabled - Accounting Mode Buffer Size : 4000 - Driver Model - Current : N/A - Pending : N/A - Serial Number : 1650225108129 - GPU UUID : GPU-520cf485-ff7c-d45e-99dc-47467c40e22b - Minor Number : 5 - VBIOS Version : 97.00.9A.00.0F - MultiGPU Board : No - Board ID : 0xc300 - Board Part Number : 692-2G525-0220-000 - GPU Part Number : 2901-886-A1 - FRU Part Number : N/A - Platform Info - Chassis Serial Number : - Slot Number : N/A - Tray Index : N/A - Host ID : 1 - Peer Type : Switch Connected - Module Id : 5 - GPU Fabric GUID : 0xad92835ec8b851c6 - Inforom Version - Image Version : G525.0220.00.03 - OEM Object : 2.1 - ECC Object : 7.16 - Power Management Object : N/A - Inforom BBX Object Flush - Latest Timestamp : 2025/05/09 16:56:33.156 - Latest Duration : 59063 us - GPU Operation Mode - Current : N/A - Pending : N/A - GPU C2C Mode : Disabled - GPU Virtualization Mode - Virtualization Mode : None - Host VGPU Mode : N/A - vGPU Heterogeneous Mode : N/A - GPU Reset Status - Reset Required : Requested functionality has been deprecated - Drain and Reset Recommended : Requested functionality has been deprecated - GPU Recovery Action : None - GSP Firmware Version : 570.124.06 - IBMNPU - Relaxed Ordering Mode : N/A - PCI - Bus : 0xC3 - Device : 0x00 - Domain : 0x0000 - Base Classcode : 0x3 - Sub Classcode : 0x2 - Device Id : 0x290110DE - Bus Id : 00000000:C3:00.0 - Sub System Id : 0x199910DE - GPU Link Info - PCIe Generation - Max : 5 - Current : 5 - Device Current : 5 - Device Max : 5 - Host Max : 5 - Link Width - Max : 16x - Current : 16x - Bridge Chip - Type : N/A - Firmware : N/A - Replays Since Reset : 0 - Replay Number Rollovers : 0 - Tx Throughput : 896 KB/s - Rx Throughput : 872 KB/s - Atomic Caps Outbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Atomic Caps Inbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Fan Speed : N/A - Performance State : P0 - Clocks Event Reasons - Idle : Active - Applications Clocks Setting : Not Active - SW Power Cap : Not Active - HW Slowdown : Not Active - HW Thermal Slowdown : Not Active - HW Power Brake Slowdown : Not Active - Sync Boost : Not Active - SW Thermal Slowdown : Not Active - Display Clock Setting : Not Active - Sparse Operation Mode : N/A - FB Memory Usage - Total : 183359 MiB - Reserved : 717 MiB - Used : 1 MiB - Free : 182643 MiB - BAR1 Memory Usage - Total : 262144 MiB - Used : 1 MiB - Free : 262143 MiB - Conf Compute Protected Memory Usage - Total : 0 MiB - Used : 0 MiB - Free : 0 MiB - Compute Mode : Default - Utilization - GPU : 0 % - Memory : 0 % - Encoder : 0 % - Decoder : 0 % - JPEG : 0 % - OFA : 0 % - Encoder Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - FBC Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - DRAM Encryption Mode - Current : N/A - Pending : N/A - ECC Mode - Current : Enabled - Pending : Enabled - ECC Errors - Volatile - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - Aggregate - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - SRAM Threshold Exceeded : No - Aggregate Uncorrectable SRAM Sources - SRAM L2 : 0 - SRAM SM : 0 - SRAM Microcontroller : 0 - SRAM PCIE : 0 - SRAM Other : 0 - Retired Pages - Single Bit ECC : N/A - Double Bit ECC : N/A - Pending Page Blacklist : N/A - Remapped Rows - Correctable Error : 0 - Uncorrectable Error : 0 - Pending : No - Remapping Failure Occurred : No - Bank Remap Availability Histogram - Max : 3840 bank(s) - High : 0 bank(s) - Partial : 0 bank(s) - Low : 0 bank(s) - None : 0 bank(s) - Temperature - GPU Current Temp : 37 C - GPU T.Limit Temp : 51 C - GPU Shutdown T.Limit Temp : -5 C - GPU Slowdown T.Limit Temp : -3 C - GPU Max Operating T.Limit Temp : 0 C - GPU Target Temperature : N/A - Memory Current Temp : 37 C - Memory Max Operating T.Limit Temp : 0 C - GPU Power Readings - Average Power Draw : 147.45 W - Instantaneous Power Draw : 148.37 W - Current Power Limit : 1000.00 W - Requested Power Limit : 1000.00 W - Default Power Limit : 1000.00 W - Min Power Limit : 200.00 W - Max Power Limit : 1000.00 W - GPU Memory Power Readings - Average Power Draw : 22.17 W - Instantaneous Power Draw : N/A - Module Power Readings - Average Power Draw : N/A - Instantaneous Power Draw : N/A - Current Power Limit : N/A - Requested Power Limit : N/A - Default Power Limit : N/A - Min Power Limit : N/A - Max Power Limit : N/A - Power Smoothing : Insufficient Permissions - Workload Power Profiles - Requested Profiles : N/A - Enforced Profiles : N/A - Clocks - Graphics : 120 MHz - SM : 120 MHz - Memory : 3996 MHz - Video : 600 MHz - Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Default Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Deferred Clocks - Memory : N/A - Max Clocks - Graphics : 1965 MHz - SM : 1965 MHz - Memory : 3996 MHz - Video : 1965 MHz - Max Customer Boost Clocks - Graphics : 1965 MHz - Clock Policy - Auto Boost : N/A - Auto Boost Default : N/A - Voltage - Graphics : N/A - Fabric - State : Completed - Status : Success - CliqueId : 0 - ClusterUUID : 00000000-0000-0000-0000-000000000000 - Health - Bandwidth : N/A - Route Recovery in progress : N/A - Route Unhealthy : N/A - Access Timeout Recovery : False - Processes : None - Capabilities - EGM : disabled - -GPU 00000000:D1:00.0 - Product Name : NVIDIA B200 - Product Brand : NVIDIA - Product Architecture : Blackwell - Display Mode : Disabled - Display Active : Disabled - Persistence Mode : Enabled - Addressing Mode : HMM - MIG Mode - Current : Disabled - Pending : Disabled - Accounting Mode : Disabled - Accounting Mode Buffer Size : 4000 - Driver Model - Current : N/A - Pending : N/A - Serial Number : 1650325040834 - GPU UUID : GPU-f7648327-f6f3-a75e-8af9-93e5612f8570 - Minor Number : 6 - VBIOS Version : 97.00.9A.00.0F - MultiGPU Board : No - Board ID : 0xd100 - Board Part Number : 692-2G525-0220-000 - GPU Part Number : 2901-886-A1 - FRU Part Number : N/A - Platform Info - Chassis Serial Number : - Slot Number : N/A - Tray Index : N/A - Host ID : 1 - Peer Type : Switch Connected - Module Id : 7 - GPU Fabric GUID : 0x93bed9161da5c2a0 - Inforom Version - Image Version : G525.0220.00.03 - OEM Object : 2.1 - ECC Object : 7.16 - Power Management Object : N/A - Inforom BBX Object Flush - Latest Timestamp : 2025/05/09 16:56:35.807 - Latest Duration : 57057 us - GPU Operation Mode - Current : N/A - Pending : N/A - GPU C2C Mode : Disabled - GPU Virtualization Mode - Virtualization Mode : None - Host VGPU Mode : N/A - vGPU Heterogeneous Mode : N/A - GPU Reset Status - Reset Required : Requested functionality has been deprecated - Drain and Reset Recommended : Requested functionality has been deprecated - GPU Recovery Action : None - GSP Firmware Version : 570.124.06 - IBMNPU - Relaxed Ordering Mode : N/A - PCI - Bus : 0xD1 - Device : 0x00 - Domain : 0x0000 - Base Classcode : 0x3 - Sub Classcode : 0x2 - Device Id : 0x290110DE - Bus Id : 00000000:D1:00.0 - Sub System Id : 0x199910DE - GPU Link Info - PCIe Generation - Max : 5 - Current : 5 - Device Current : 5 - Device Max : 5 - Host Max : 5 - Link Width - Max : 16x - Current : 16x - Bridge Chip - Type : N/A - Firmware : N/A - Replays Since Reset : 0 - Replay Number Rollovers : 0 - Tx Throughput : 916 KB/s - Rx Throughput : 789 KB/s - Atomic Caps Outbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Atomic Caps Inbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Fan Speed : N/A - Performance State : P0 - Clocks Event Reasons - Idle : Active - Applications Clocks Setting : Not Active - SW Power Cap : Not Active - HW Slowdown : Not Active - HW Thermal Slowdown : Not Active - HW Power Brake Slowdown : Not Active - Sync Boost : Not Active - SW Thermal Slowdown : Not Active - Display Clock Setting : Not Active - Sparse Operation Mode : N/A - FB Memory Usage - Total : 183359 MiB - Reserved : 717 MiB - Used : 1 MiB - Free : 182643 MiB - BAR1 Memory Usage - Total : 262144 MiB - Used : 1 MiB - Free : 262143 MiB - Conf Compute Protected Memory Usage - Total : 0 MiB - Used : 0 MiB - Free : 0 MiB - Compute Mode : Default - Utilization - GPU : 0 % - Memory : 0 % - Encoder : 0 % - Decoder : 0 % - JPEG : 0 % - OFA : 0 % - Encoder Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - FBC Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - DRAM Encryption Mode - Current : N/A - Pending : N/A - ECC Mode - Current : Enabled - Pending : Enabled - ECC Errors - Volatile - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - Aggregate - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - SRAM Threshold Exceeded : No - Aggregate Uncorrectable SRAM Sources - SRAM L2 : 0 - SRAM SM : 0 - SRAM Microcontroller : 0 - SRAM PCIE : 0 - SRAM Other : 0 - Retired Pages - Single Bit ECC : N/A - Double Bit ECC : N/A - Pending Page Blacklist : N/A - Remapped Rows - Correctable Error : 0 - Uncorrectable Error : 0 - Pending : No - Remapping Failure Occurred : No - Bank Remap Availability Histogram - Max : 3840 bank(s) - High : 0 bank(s) - Partial : 0 bank(s) - Low : 0 bank(s) - None : 0 bank(s) - Temperature - GPU Current Temp : 37 C - GPU T.Limit Temp : 51 C - GPU Shutdown T.Limit Temp : -5 C - GPU Slowdown T.Limit Temp : -3 C - GPU Max Operating T.Limit Temp : 0 C - GPU Target Temperature : N/A - Memory Current Temp : 36 C - Memory Max Operating T.Limit Temp : 0 C - GPU Power Readings - Average Power Draw : 143.87 W - Instantaneous Power Draw : 144.77 W - Current Power Limit : 1000.00 W - Requested Power Limit : 1000.00 W - Default Power Limit : 1000.00 W - Min Power Limit : 200.00 W - Max Power Limit : 1000.00 W - GPU Memory Power Readings - Average Power Draw : 22.68 W - Instantaneous Power Draw : N/A - Module Power Readings - Average Power Draw : N/A - Instantaneous Power Draw : N/A - Current Power Limit : N/A - Requested Power Limit : N/A - Default Power Limit : N/A - Min Power Limit : N/A - Max Power Limit : N/A - Power Smoothing : Insufficient Permissions - Workload Power Profiles - Requested Profiles : N/A - Enforced Profiles : N/A - Clocks - Graphics : 120 MHz - SM : 120 MHz - Memory : 3996 MHz - Video : 600 MHz - Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Default Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Deferred Clocks - Memory : N/A - Max Clocks - Graphics : 1965 MHz - SM : 1965 MHz - Memory : 3996 MHz - Video : 1965 MHz - Max Customer Boost Clocks - Graphics : 1965 MHz - Clock Policy - Auto Boost : N/A - Auto Boost Default : N/A - Voltage - Graphics : N/A - Fabric - State : Completed - Status : Success - CliqueId : 0 - ClusterUUID : 00000000-0000-0000-0000-000000000000 - Health - Bandwidth : N/A - Route Recovery in progress : N/A - Route Unhealthy : N/A - Access Timeout Recovery : False - Processes : None - Capabilities - EGM : disabled - -GPU 00000000:DF:00.0 - Product Name : NVIDIA B200 - Product Brand : NVIDIA - Product Architecture : Blackwell - Display Mode : Disabled - Display Active : Disabled - Persistence Mode : Enabled - Addressing Mode : HMM - MIG Mode - Current : Disabled - Pending : Disabled - Accounting Mode : Disabled - Accounting Mode Buffer Size : 4000 - Driver Model - Current : N/A - Pending : N/A - Serial Number : 1650225108107 - GPU UUID : GPU-29c2cbe8-0703-a473-1f02-007cbfa6ae7b - Minor Number : 7 - VBIOS Version : 97.00.9A.00.0F - MultiGPU Board : No - Board ID : 0xdf00 - Board Part Number : 692-2G525-0220-000 - GPU Part Number : 2901-886-A1 - FRU Part Number : N/A - Platform Info - Chassis Serial Number : - Slot Number : N/A - Tray Index : N/A - Host ID : 1 - Peer Type : Switch Connected - Module Id : 6 - GPU Fabric GUID : 0xc677503be1032394 - Inforom Version - Image Version : G525.0220.00.03 - OEM Object : 2.1 - ECC Object : 7.16 - Power Management Object : N/A - Inforom BBX Object Flush - Latest Timestamp : 2025/05/09 16:56:38.281 - Latest Duration : 24160 us - GPU Operation Mode - Current : N/A - Pending : N/A - GPU C2C Mode : Disabled - GPU Virtualization Mode - Virtualization Mode : None - Host VGPU Mode : N/A - vGPU Heterogeneous Mode : N/A - GPU Reset Status - Reset Required : Requested functionality has been deprecated - Drain and Reset Recommended : Requested functionality has been deprecated - GPU Recovery Action : None - GSP Firmware Version : 570.124.06 - IBMNPU - Relaxed Ordering Mode : N/A - PCI - Bus : 0xDF - Device : 0x00 - Domain : 0x0000 - Base Classcode : 0x3 - Sub Classcode : 0x2 - Device Id : 0x290110DE - Bus Id : 00000000:DF:00.0 - Sub System Id : 0x199910DE - GPU Link Info - PCIe Generation - Max : 5 - Current : 5 - Device Current : 5 - Device Max : 5 - Host Max : 5 - Link Width - Max : 16x - Current : 16x - Bridge Chip - Type : N/A - Firmware : N/A - Replays Since Reset : 0 - Replay Number Rollovers : 0 - Tx Throughput : 14738 KB/s - Rx Throughput : 851 KB/s - Atomic Caps Outbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Atomic Caps Inbound : FETCHADD_32 FETCHADD_64 SWAP_32 SWAP_64 CAS_32 CAS_64 - Fan Speed : N/A - Performance State : P0 - Clocks Event Reasons - Idle : Active - Applications Clocks Setting : Not Active - SW Power Cap : Not Active - HW Slowdown : Not Active - HW Thermal Slowdown : Not Active - HW Power Brake Slowdown : Not Active - Sync Boost : Not Active - SW Thermal Slowdown : Not Active - Display Clock Setting : Not Active - Sparse Operation Mode : N/A - FB Memory Usage - Total : 183359 MiB - Reserved : 717 MiB - Used : 1 MiB - Free : 182643 MiB - BAR1 Memory Usage - Total : 262144 MiB - Used : 1 MiB - Free : 262143 MiB - Conf Compute Protected Memory Usage - Total : 0 MiB - Used : 0 MiB - Free : 0 MiB - Compute Mode : Default - Utilization - GPU : 0 % - Memory : 0 % - Encoder : 0 % - Decoder : 0 % - JPEG : 0 % - OFA : 0 % - Encoder Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - FBC Stats - Active Sessions : 0 - Average FPS : 0 - Average Latency : 0 - DRAM Encryption Mode - Current : N/A - Pending : N/A - ECC Mode - Current : Enabled - Pending : Enabled - ECC Errors - Volatile - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - Aggregate - SRAM Correctable : 0 - SRAM Uncorrectable Parity : 0 - SRAM Uncorrectable SEC-DED : 0 - DRAM Correctable : 0 - DRAM Uncorrectable : 0 - SRAM Threshold Exceeded : No - Aggregate Uncorrectable SRAM Sources - SRAM L2 : 0 - SRAM SM : 0 - SRAM Microcontroller : 0 - SRAM PCIE : 0 - SRAM Other : 0 - Retired Pages - Single Bit ECC : N/A - Double Bit ECC : N/A - Pending Page Blacklist : N/A - Remapped Rows - Correctable Error : 0 - Uncorrectable Error : 0 - Pending : No - Remapping Failure Occurred : No - Bank Remap Availability Histogram - Max : 3840 bank(s) - High : 0 bank(s) - Partial : 0 bank(s) - Low : 0 bank(s) - None : 0 bank(s) - Temperature - GPU Current Temp : 37 C - GPU T.Limit Temp : 51 C - GPU Shutdown T.Limit Temp : -5 C - GPU Slowdown T.Limit Temp : -3 C - GPU Max Operating T.Limit Temp : 0 C - GPU Target Temperature : N/A - Memory Current Temp : 37 C - Memory Max Operating T.Limit Temp : 0 C - GPU Power Readings - Average Power Draw : 142.09 W - Instantaneous Power Draw : 142.02 W - Current Power Limit : 1000.00 W - Requested Power Limit : 1000.00 W - Default Power Limit : 1000.00 W - Min Power Limit : 200.00 W - Max Power Limit : 1000.00 W - GPU Memory Power Readings - Average Power Draw : 22.49 W - Instantaneous Power Draw : N/A - Module Power Readings - Average Power Draw : N/A - Instantaneous Power Draw : N/A - Current Power Limit : N/A - Requested Power Limit : N/A - Default Power Limit : N/A - Min Power Limit : N/A - Max Power Limit : N/A - Power Smoothing : Insufficient Permissions - Workload Power Profiles - Requested Profiles : N/A - Enforced Profiles : N/A - Clocks - Graphics : 120 MHz - SM : 120 MHz - Memory : 3996 MHz - Video : 600 MHz - Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Default Applications Clocks - Graphics : 1965 MHz - Memory : 3996 MHz - Deferred Clocks - Memory : N/A - Max Clocks - Graphics : 1965 MHz - SM : 1965 MHz - Memory : 3996 MHz - Video : 1965 MHz - Max Customer Boost Clocks - Graphics : 1965 MHz - Clock Policy - Auto Boost : N/A - Auto Boost Default : N/A - Voltage - Graphics : N/A - Fabric - State : Completed - Status : Success - CliqueId : 0 - ClusterUUID : 00000000-0000-0000-0000-000000000000 - Health - Bandwidth : N/A - Route Recovery in progress : N/A - Route Unhealthy : N/A - Access Timeout Recovery : False - Processes : None - Capabilities - EGM : disabled - -+ sudo nvidia-smi -i 0 -pm 1 -Persistence mode is already Enabled for GPU 00000000:1B:00.0. -All done. -+ sudo nvidia-smi -i 0 -lgc 1665,1665 -The current user does not have permission to change clocks for GPU 00000000:1B:00.0. -Terminating early due to previous errors. -+ echo 'You should run '\''nvidia-smi dmon -i 0'\'' on a terminal to ensure device 0 is running in pclk=1365MHz' -You should run 'nvidia-smi dmon -i 0' on a terminal to ensure device 0 is running in pclk=1365MHz -+ sleep 5 -+ export HF_TOKEN_PATH=/home/adriant/.ssh/huggingface_token -+ HF_TOKEN_PATH=/home/adriant/.ssh/huggingface_token -+ export HF_HOME=/tmp/huggingface -+ HF_HOME=/tmp/huggingface -+ export CUDA_VISIBLE_DEVICES=0 -+ CUDA_VISIBLE_DEVICES=0 -+ python -u training_perf.py -torch.__version__ = '2.7.0a0+79aa17489c.nv25.04' -torch.version.cuda = '12.9' -torch.cuda.is_available() = True -torch.cuda.device_count() = 1 -torch.cuda.current_device() = 0 -torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA B200' -torch.backends.cudnn.version() = 90900 -torch.backends.cudnn.enabled = True -Timing CUDNN_ATTENTION with batch_size=24 and seq_len=768 -Timing EFFICIENT_ATTENTION with batch_size=24 and seq_len=768 -Timing FLASH_ATTENTION with batch_size=24 and seq_len=768 -Timing CUDNN_ATTENTION with batch_size=12 and seq_len=1024 -Timing EFFICIENT_ATTENTION with batch_size=12 and seq_len=1024 -Timing FLASH_ATTENTION with batch_size=12 and seq_len=1024 -Timing CUDNN_ATTENTION with batch_size=6 and seq_len=2048 -Timing EFFICIENT_ATTENTION with batch_size=6 and seq_len=2048 -Timing FLASH_ATTENTION with batch_size=6 and seq_len=2048 -Timing CUDNN_ATTENTION with batch_size=3 and seq_len=4096 -Timing EFFICIENT_ATTENTION with batch_size=3 and seq_len=4096 -Timing FLASH_ATTENTION with batch_size=3 and seq_len=4096 -Timing CUDNN_ATTENTION with batch_size=2 and seq_len=8192 -Timing EFFICIENT_ATTENTION with batch_size=2 and seq_len=8192 -Timing FLASH_ATTENTION with batch_size=2 and seq_len=8192 -Timing CUDNN_ATTENTION with batch_size=1 and seq_len=16384 -Timing EFFICIENT_ATTENTION with batch_size=1 and seq_len=16384 -Timing FLASH_ATTENTION with batch_size=1 and seq_len=16384 -+ sudo nvidia-smi -i 1 -rgc -The current user does not have permission to change clocks for GPU 00000000:43:00.0. -Terminating early due to previous errors. diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/b200_speedup.png b/benchmark/Llama-3.2-1B-Training/artifacts/b200_speedup.png deleted file mode 100644 index c7314af8..00000000 Binary files a/benchmark/Llama-3.2-1B-Training/artifacts/b200_speedup.png and /dev/null differ diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/b200_training_timing.csv b/benchmark/Llama-3.2-1B-Training/artifacts/b200_training_timing.csv deleted file mode 100644 index d740973c..00000000 --- a/benchmark/Llama-3.2-1B-Training/artifacts/b200_training_timing.csv +++ /dev/null @@ -1,19 +0,0 @@ -backend,batch_size,seq_len,time,label,speedup_label,speedup -CUDNN_ATTENTION,24,768,0.1999587368965149,BS=24 SL=768,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.1114361360693807 -EFFICIENT_ATTENTION,24,768,0.2222413659095764,BS=24 SL=768,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,24,768,0.2080540633201599,BS=24 SL=768,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.0681904614743556 -CUDNN_ATTENTION,12,1024,0.13925427675247193,BS=12 SL=1024,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.1457744870575388 -EFFICIENT_ATTENTION,12,1024,0.15955399751663207,BS=12 SL=1024,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,12,1024,0.14496242761611938,BS=12 SL=1024,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.1006575989411076 -CUDNN_ATTENTION,6,2048,0.1443575620651245,BS=6 SL=2048,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.2769321339533377 -EFFICIENT_ATTENTION,6,2048,0.18433480978012085,BS=6 SL=2048,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,6,2048,0.1538383936882019,BS=6 SL=2048,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.198236703860343 -CUDNN_ATTENTION,3,4096,0.15531346797943116,BS=3 SL=4096,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.5222432152923429 -EFFICIENT_ATTENTION,3,4096,0.23642487287521363,BS=3 SL=4096,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,3,4096,0.17227792978286743,BS=3 SL=4096,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.3723456810352583 -CUDNN_ATTENTION,2,8192,0.2312054991722107,BS=2 SL=8192,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.9215099745433377 -EFFICIENT_ATTENTION,2,8192,0.4442636728286743,BS=2 SL=8192,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,2,8192,0.2755130219459534,BS=2 SL=8192,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.6124960979732719 -CUDNN_ATTENTION,1,16384,0.2849757242202759,BS=1 SL=16384,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,2.513904291901673 -EFFICIENT_ATTENTION,1,16384,0.7164016962051392,BS=1 SL=16384,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,1,16384,0.374487898349762,BS=1 SL=16384,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.9130169475758028 diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/cudnn_attention_time.png b/benchmark/Llama-3.2-1B-Training/artifacts/cudnn_attention_time.png deleted file mode 100644 index 6110664c..00000000 Binary files a/benchmark/Llama-3.2-1B-Training/artifacts/cudnn_attention_time.png and /dev/null differ diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/efficient_attention_time.png b/benchmark/Llama-3.2-1B-Training/artifacts/efficient_attention_time.png deleted file mode 100644 index 5f9fc495..00000000 Binary files a/benchmark/Llama-3.2-1B-Training/artifacts/efficient_attention_time.png and /dev/null differ diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/flash_attention_time.png b/benchmark/Llama-3.2-1B-Training/artifacts/flash_attention_time.png deleted file mode 100644 index 4d073c87..00000000 Binary files a/benchmark/Llama-3.2-1B-Training/artifacts/flash_attention_time.png and /dev/null differ diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/h200_iteration_time.png b/benchmark/Llama-3.2-1B-Training/artifacts/h200_iteration_time.png deleted file mode 100644 index 5156f008..00000000 Binary files a/benchmark/Llama-3.2-1B-Training/artifacts/h200_iteration_time.png and /dev/null differ diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/h200_speedup.png b/benchmark/Llama-3.2-1B-Training/artifacts/h200_speedup.png deleted file mode 100644 index f7e2be94..00000000 Binary files a/benchmark/Llama-3.2-1B-Training/artifacts/h200_speedup.png and /dev/null differ diff --git a/benchmark/Llama-3.2-1B-Training/artifacts/h200_training_timing.csv b/benchmark/Llama-3.2-1B-Training/artifacts/h200_training_timing.csv deleted file mode 100644 index a75e809c..00000000 --- a/benchmark/Llama-3.2-1B-Training/artifacts/h200_training_timing.csv +++ /dev/null @@ -1,19 +0,0 @@ -backend,batch_size,seq_len,time,label,speedup_label,speedup -CUDNN_ATTENTION,24,768,0.32912482261657716,BS=24 SL=768,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.0699159337629112 -EFFICIENT_ATTENTION,24,768,0.35213589191436767,BS=24 SL=768,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,24,768,0.33675609350204466,BS=24 SL=768,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.045670438364999 -CUDNN_ATTENTION,12,1024,0.222159903049469,BS=12 SL=1024,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.0930867199559324 -EFFICIENT_ATTENTION,12,1024,0.242840039730072,BS=12 SL=1024,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,12,1024,0.22822560071945192,BS=12 SL=1024,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.0640350555088909 -CUDNN_ATTENTION,6,2048,0.22936189889907838,BS=6 SL=2048,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.1772111978508244 -EFFICIENT_ATTENTION,6,2048,0.27000739574432375,BS=6 SL=2048,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,6,2048,0.23996224880218506,BS=6 SL=2048,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.1252078070284575 -CUDNN_ATTENTION,3,4096,0.24294065952301025,BS=3 SL=4096,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.3253087553435736 -EFFICIENT_ATTENTION,3,4096,0.3219713830947876,BS=3 SL=4096,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,3,4096,0.2598997974395752,BS=3 SL=4096,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.2388289112447024 -CUDNN_ATTENTION,2,8192,0.35719603300094604,BS=2 SL=8192,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,1.5836347656461844 -EFFICIENT_ATTENTION,2,8192,0.5656680560112,BS=2 SL=8192,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,2,8192,0.39713253974914553,BS=2 SL=8192,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.4243810299919324 -CUDNN_ATTENTION,1,16384,0.4235471987724304,BS=1 SL=16384,CUDNN_ATTENTION vs EFFICIENT_ATTENTION,2.007099764399456 -EFFICIENT_ATTENTION,1,16384,0.8501014828681945,BS=1 SL=16384,EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION,1.0 -FLASH_ATTENTION,1,16384,0.5001843190193176,BS=1 SL=16384,FLASH_ATTENTION vs EFFICIENT_ATTENTION,1.6995764372120648 diff --git a/benchmark/Llama-3.2-1B-Training/training_perf.py b/benchmark/Llama-3.2-1B-Training/training_perf.py deleted file mode 100644 index aa5651d0..00000000 --- a/benchmark/Llama-3.2-1B-Training/training_perf.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Training Performance Measurement - -This script is to measure the training performance of different backends in a Llama model. -Different batch sizes and sequence lengths are tested. Multiple training iterations are -run to collect timing data. Geometric mean of the timing is visualized afterwards. - -This code uses models from Hugging Face Hub. You need to run with a valid token. -Consider to set the environment variables HF_TOKEN and HF_HOME appropriately. -Only the first GPU is used. You may set the environment variable CUDA_VISIBLE_DEVICES -before running this code to use a different GPU. - -For more accurate results, it is recommended to lock the clock frequency of the GPU -using the following command: - - nvidia-smi -i 0 -lgc , -""" - -import time - -import matplotlib.pyplot as plt -import seaborn as sns -import pandas as pd -import torch -import transformers -from torch.nn.attention import SDPBackend, sdpa_kernel -from transformers.models.llama.modeling_llama import LlamaForCausalLM - -# print system info -print(f"{torch.__version__ = }") -print(f"{torch.version.cuda = }") -print(f"{torch.cuda.is_available() = }") -print(f"{torch.cuda.device_count() = }") -print(f"{torch.cuda.current_device() = }") -print(f"{torch.cuda.get_device_name(torch.cuda.current_device()) = }") -print(f"{torch.backends.cudnn.version() = }") -print(f"{torch.backends.cudnn.enabled = }") - -dtype = torch.bfloat16 -device = torch.device("cuda:0") -torch.set_default_device(device) -torch.set_default_dtype(dtype) - -model_name = "meta-llama/Llama-3.2-1B" -config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True) -tokenizer = transformers.AutoTokenizer.from_pretrained( - model_name, trust_remote_code=True -) -tokenizer.pad_token = tokenizer.eos_token -model = LlamaForCausalLM(config).to(device).train() # set norm layers to training mode -loss_fct = torch.nn.CrossEntropyLoss() -optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) - -# Configuration matrix to test -batch_seqlen = [(24, 768), (12, 1024), (6, 2048), (3, 4096), (2, 8192), (1, 16384)] -backends = [ - SDPBackend.CUDNN_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.FLASH_ATTENTION, -] - -# Run timing experiments -warmup_iterations = 5 # num of training iterations to run for warmup -measure_iterations = 100 # num of training iterations to run to measure for timing -data = [] -for batch_size, seq_len in batch_seqlen: - assert ( - seq_len < tokenizer.model_max_length - ), "seqlen must be less than the model max length" - # create random tensors - # - input embedding tensor to simulate a batch of input token sequences converted into embeddings - # - attention mask of all ones for full attention - # - random target to compute cross entropy loss in training loop - shape = (batch_size, seq_len, config.hidden_size) - inputs_embeds = torch.randn(*shape, dtype=dtype, device=device) - attention_mask = torch.ones(*shape[:2], dtype=torch.int64, device=device) - target = torch.randint( - 2, config.vocab_size - 2, shape[:2], dtype=torch.int64, device=device - ) - for backend in backends: - backend_name = str(backend).split(".")[-1] - print( - f"Timing {backend_name} with batch_size={batch_size} and seq_len={seq_len}" - ) - with sdpa_kernel(backends=[backend]): - # warmup iterations: to minimize the effect of system cache - for _ in range(warmup_iterations): - output = model.forward( - inputs_embeds=inputs_embeds, attention_mask=attention_mask - ) - loss = loss_fct( - output.logits.view(-1, config.vocab_size), target.view(-1) - ) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - torch.cuda.synchronize() - start = time.time() - # measure iterations: per-iteration time obtained by averaging - for _ in range(measure_iterations): - output = model.forward( - inputs_embeds=inputs_embeds, attention_mask=attention_mask - ) - loss = loss_fct( - output.logits.view(-1, config.vocab_size), target.view(-1) - ) - optimizer.zero_grad() - loss.backward() - optimizer.step() - torch.cuda.synchronize() # wait for all kernels to finish for accurate timing - duration = time.time() - start - data.append( - (backend_name, batch_size, seq_len, duration / measure_iterations) - ) - -# Process stats -df = pd.DataFrame(data, columns=["backend", "batch_size", "seq_len", "time"]) -df["label"] = "BS=" + df["batch_size"].astype(str) + " SL=" + df["seq_len"].astype(str) -# compute the speedup w.r.t. CUDNN_ATTENTION -df["speedup_label"] = df["backend"] + " vs EFFICIENT_ATTENTION" -df["speedup"] = df.apply( - lambda row: df.loc[ - (df["backend"] == "EFFICIENT_ATTENTION") - & (df["batch_size"] == row["batch_size"]) - & (df["seq_len"] == row["seq_len"]), - "time", - ].values[0] - / row["time"], - axis=1, -) -df.to_csv("training_timing.csv", index=False) - -# Create plots -label_order = [f"BS={b} SL={s}" for b, s in batch_seqlen] # x-axis order -hue_order = ["CUDNN_ATTENTION", "FLASH_ATTENTION", "EFFICIENT_ATTENTION"] -g = sns.barplot( - data=df, - x="label", - y="time", - hue="backend", - palette=["#76B900", "orchid", "royalblue"], - order=label_order, - hue_order=hue_order, -) -g.set_title("\nTraining Iteration Time") -g.set( - xlabel="Batch size and sequence length", - ylabel="Mean iteration time (s), lower is better", -) -g.get_legend().set_title("") -plt.legend(fontsize=8) -plt.xticks(rotation=10, size=8) -plt.tight_layout() -plt.savefig("iteration_time.png", dpi=300) - -plt.clf() -hue_order = [ - "CUDNN_ATTENTION vs EFFICIENT_ATTENTION", - "FLASH_ATTENTION vs EFFICIENT_ATTENTION", -] -g = sns.barplot( - data=df[df["speedup_label"] != "EFFICIENT_ATTENTION vs EFFICIENT_ATTENTION"], - x="label", - y="speedup", - hue="speedup_label", - palette=["#76B900", "orchid"], - order=label_order, - hue_order=hue_order, -) -for container in g.containers: - g.bar_label(container, fmt="%.2f", fontsize=6) -g.set_title( - "Per-iteration Speed-up of\ncuDNN/Flash Attention Backend vs Efficient Attention" -) -g.set( - xlabel="Batch size and sequence length", ylabel="Speed-up ratio, higher is better" -) -g.get_legend().set_title("") -plt.legend(fontsize=8) -plt.xticks(rotation=10, size=8) -plt.tight_layout() -plt.savefig("speedup.png", dpi=300) diff --git a/benchmark/sdpa_benchmark/Dockerfile b/benchmark/sdpa_benchmark/Dockerfile deleted file mode 100755 index c39cda0b..00000000 --- a/benchmark/sdpa_benchmark/Dockerfile +++ /dev/null @@ -1,20 +0,0 @@ -FROM nvcr.io/nvidia/pytorch:24.12-py3 - -RUN apt-get update && \ - wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \ - dpkg -i cuda-keyring_1.1-1_all.deb && \ - apt-get update && \ - apt-get -y install cudnn - - -RUN pip uninstall -y cudnn - -RUN pip install nvidia-cudnn-frontend - -COPY benchmark_flash_attention.py . - -ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH - -CMD ["python", "benchmark_flash_attention.py"] - -WORKDIR /workspace diff --git a/benchmark/sdpa_benchmark/README.md b/benchmark/sdpa_benchmark/README.md deleted file mode 100644 index d8ad48ea..00000000 --- a/benchmark/sdpa_benchmark/README.md +++ /dev/null @@ -1,50 +0,0 @@ -# Attention-benchmark - -## Contents -- Dockerfile to create a docker container for the dependencies. -- benchmark_flash_attention.py which runs cudnn, pytorch upto 64k sequence length. -- Pytorch native sdpa operation vs cudnn sdpa operation. - -## Steps to run -Lock the clocks. -For eg. in H200, use `nvidia-smi -q -d SUPPORTED_CLOCKS` to get the supported clocks - -``` -sudo nvidia-smi -pm 1 -sudo nvidia-smi -ac 3201,1980 -sudo nvidia-smi -pl 700 -``` - -Launch the docker build and run. -``` -docker build -t cudnn_attention_benchmark . && docker run --gpus=all --rm --shm-size=16g -it cudnn_attention_benchmark -``` - -## Sample output - -``` -$ python benchmark_flash_attention.py -Is flash sdp enabled in Pytorch : True -cudnn backend version : 90100 -### causal=False, headdim=128, batch_size=32, seqlen=512 ### -Pytorch fwd: 302.38 TFLOPs/s, bwd: 169.09 TFLOPs/s, fwd + bwd: 193.45 TFLOPs/s -cudnn_bf16 fwd: 501.13 TFLOPs/s, bwd: 351.28 TFLOPs/s, fwd + bwd: 384.09 TFLOPs/s -cudnn_fp8 fwd: 678.07 TFLOPs/s, bwd: 418.37 TFLOPs/s, fwd + bwd: 469.78 TFLOPs/s -``` - -Please refer to the [benchmark_results.csv](benchmark_results.csv) for sample output. - -## Results - -#### Forward -![Comparison of pytorch and cudnn](images/forward.png) - -#### Bprop -![Comparison of pytorch and cudnn](images/bprop.png) - -#### Fwd + Bprop -![Comparison of pytorch and cudnn](images/fwd_bprop.png) - -## Pytorch adoption - -cuDNN v9 can achieve over 2x the performance of the comparable PyTorch eager implementation, as detailed in [(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)](https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html) PyTorch eager mode SDPA doesn't use cuDNN today, but adding a cuDNN-based implementation is in progress (see the PyTorch PRs for [Fprop](https://github.com/pytorch/pytorch/pull/115663), and [Bprop](https://github.com/pytorch/pytorch/pull/122510)). \ No newline at end of file diff --git a/benchmark/sdpa_benchmark/benchmark_flash_attention.py b/benchmark/sdpa_benchmark/benchmark_flash_attention.py deleted file mode 100755 index fc468a0d..00000000 --- a/benchmark/sdpa_benchmark/benchmark_flash_attention.py +++ /dev/null @@ -1,787 +0,0 @@ -import pickle -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.benchmark as benchmark -import os -import csv -import itertools - -from einops import rearrange, repeat - - -# benchmarking functions from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/benchmark.py -def benchmark_forward( - fn, - *inputs, - repeats=10, - desc="", - verbose=True, - amp=False, - amp_dtype=torch.float16, - **kwinputs, -): - """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" - if verbose: - print(desc, "- Forward pass") - - def amp_wrapper(*inputs, **kwinputs): - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - fn(*inputs, **kwinputs) - - t = benchmark.Timer( - stmt="fn_amp(*inputs, **kwinputs)", - globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m - - -def benchmark_backward( - fn, - *inputs, - grad=None, - repeats=10, - desc="", - verbose=True, - amp=False, - amp_dtype=torch.float16, - **kwinputs, -): - """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" - if verbose: - print(desc, "- Backward pass") - with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): - y = fn(*inputs, **kwinputs) - if type(y) is tuple: - y = y[0] - if grad is None: - grad = torch.randn_like(y) - else: - if grad.shape != y.shape: - raise RuntimeError("Grad shape does not match output shape") - - def f(*inputs, y, grad): - # Set .grad to None to avoid extra operation of gradient accumulation - for x in inputs: - if isinstance(x, torch.Tensor): - x.grad = None - y.backward(grad, retain_graph=True) - - t = benchmark.Timer( - stmt="f(*inputs, y=y, grad=grad)", - globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m - - -def benchmark_fwd_bwd( - fn, - *inputs, - grad=None, - repeats=10, - desc="", - verbose=True, - amp=False, - amp_dtype=torch.float16, - **kwinputs, -): - """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" - return ( - benchmark_forward( - fn, - *inputs, - repeats=repeats, - desc=desc, - verbose=verbose, - amp=amp, - amp_dtype=amp_dtype, - **kwinputs, - ), - benchmark_backward( - fn, - *inputs, - grad=grad, - repeats=repeats, - desc=desc, - verbose=verbose, - amp=amp, - amp_dtype=amp_dtype, - **kwinputs, - ), - ) - - -try: - import cudnn -except ImportError: - cudnn = None -assert cudnn is not None - - -def convert_to_cudnn_type(torch_type): - if torch_type == torch.float16: - return cudnn.data_type.HALF - elif torch_type == torch.bfloat16: - return cudnn.data_type.BFLOAT16 - elif torch_type == torch.float32: - return cudnn.data_type.FLOAT - elif torch_type == torch.int32: - return cudnn.data_type.INT32 - elif torch_type == torch.int64: - return cudnn.data_type.INT64 - else: - raise ValueError("Unsupported tensor data type.") - - -def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): - assert mode in ["fwd", "bwd", "fwd_bwd"] - f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) - return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) - - -def efficiency(flop, time): - return (flop / time / 10**12) if not math.isnan(time) else 0.0 - - -def attention_pytorch(qkv, dropout_p=0.0, causal=True): - # batch_size, seqlen, _, nheads, d = qkv.shape - q, k, v = qkv.unbind(dim=2) - q = rearrange(q, "b t h d -> b h t d") - k = rearrange(k, "b s h d -> b h s d") - v = rearrange(v, "b s h d -> b h s d") - out = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p - ) - return out - - -def time_fwd_bwd(func, *args, **kwargs): - time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs) - return time_f[1].mean, time_b[1].mean - - -def time_fwd(func, *args, **kwargs): - time_f = benchmark_forward(func, *args, **kwargs) - return time_f[1].mean - - -print("Is flash sdp enabled in Pytorch : " + str(torch._C._get_flash_sdp_enabled())) -print("cudnn backend version : " + str(cudnn.backend_version())) - -filename = "benchmark_results.csv" -csvfile = open(filename, "w") -csvwriter = csv.writer(csvfile) - -repeats = 30 -device = "cuda" -dtype = torch.bfloat16 - -bs_seqlen_vals = [ - # (32, 512), - # (16, 1024), - # (8, 2048), - (4, 4096), - (2, 8192), - (1, 16384), - (1, 32768), - (1, 65536), - # (1, 262144), -] -causal_vals = [False, True] -headdim_vals = [128] -# headdim_vals = [128, 256] -# n_heads = 16, 32, 64 -n_heads = [16] -dropout_p = 0.0 - -fields = [ - "Batch", - "Number of heads", - "Sequence length", - "Head dim", - "causal", - "dropout_p", - "pytorch (TFlops/s fwd)", - "pytorch (TFlops/s bwd)", - "pytorch (TFlops/s fwd + bwd)", - "cudnn BF16 (TFlops/s fwd)", - "cudnn BF16 (TFlops/s bwd)", - "cudnn BF16 (TFlops/s fwd + bwd)", -] - -if cudnn.backend_version() >= 90100: - fields += [ - "cudnn FP8 (TFlops/s fwd)", - "cudnn FP8 (TFlops/s bwd)", - "cudnn FP8 (TFlops/s fwd + bwd)", - ] -csvwriter.writerow(fields) - -methods = ["Pytorch"] -if cudnn is not None: - methods += ["cudnn_bf16"] - if cudnn.backend_version() >= 90100: - methods += ["cudnn_fp8"] - -time_f = {} -time_b = {} -time_f_b = {} -speed_f = {} -speed_b = {} -speed_f_b = {} - -for causal, headdim, bs_seqlen, nheads in itertools.product( - causal_vals, headdim_vals, bs_seqlen_vals, n_heads -): - batch_size, seqlen = bs_seqlen - config = (causal, headdim, batch_size, seqlen) - # nheads = dim // headdim - - if (seqlen >= 262144) and (nheads > 16): - continue - - if (seqlen >= 262144) and (headdim > 128): - continue - - print( - "Running bs={}, seqlen={}, d={}, h={}, causal={}".format( - batch_size, seqlen, headdim, nheads, causal - ) - ) - - if "Pytorch" in methods: - qkv = torch.randn( - batch_size, - seqlen, - 3, - nheads, - headdim, - device=device, - dtype=dtype, - requires_grad=True, - ) - try: - qkv = qkv.detach().requires_grad_(True) - f, b = time_fwd_bwd( - attention_pytorch, - qkv, - dropout_p, - causal=causal, - repeats=repeats, - verbose=False, - ) - except: # Skip if OOM - f, b = float("nan"), float("nan") - time_f[config, "Pytorch"] = f - time_b[config, "Pytorch"] = b - - if ( - ("cudnn_fp16" in methods or "cudnn_bf16" in methods) - and device == "cuda" - and cudnn is not None - ): - is_causal = causal - is_dropout = False if (abs(dropout_p - 0.0) < 1e-6) else True - is_infer = False - input_type = dtype - attn_scale = headdim ** (-0.5) - dropout_prob = dropout_p if is_dropout else 0.0 - - shape_qkvo = (batch_size, nheads, seqlen, headdim) - stride_qkv = (seqlen * 3 * nheads * headdim, headdim, 3 * nheads * headdim, 1) - stride_o = (seqlen * nheads * headdim, headdim, nheads * headdim, 1) - offset_q, offset_k, offset_v = [nheads * headdim * i for i in range(3)] - - qkv_gpu = ( - torch.randn( - batch_size * seqlen * 3 * nheads * headdim, - dtype=input_type, - device="cuda", - ) - - 0.5 - ) - q_gpu, k_gpu, v_gpu = [ - torch.as_strided(qkv_gpu, shape_qkvo, stride_qkv, storage_offset=offset) - for offset in [offset_q, offset_k, offset_v] - ] - o_gpu = torch.empty(*shape_qkvo, dtype=input_type, device="cuda").as_strided( - shape_qkvo, stride_o - ) - dQ_gpu, dK_gpu, dV_gpu = [ - torch.empty_like(tensor) for tensor in [q_gpu, k_gpu, v_gpu] - ] - dO_gpu = torch.randn_like(o_gpu) - 0.5 - - stats_gpu = ( - torch.empty( - batch_size, nheads, seqlen, 1, dtype=torch.float32, device="cuda" - ) - if not is_infer - else None - ) - - if is_dropout: - seed_gpu = torch.full( - (1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda" - ) - offset_gpu = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") - - # cuDNN graph forward - graph_fwd = cudnn.pygraph( - io_data_type=convert_to_cudnn_type(input_type), - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - ) - - q_fwd = graph_fwd.tensor_like(q_gpu) - k_fwd = graph_fwd.tensor_like(k_gpu) - v_fwd = graph_fwd.tensor_like(v_gpu) - - if is_dropout: - seed_fwd = graph_fwd.tensor_like(seed_gpu) - offset_fwd = graph_fwd.tensor_like(offset_gpu) - dropout_tuple = (dropout_prob, seed_fwd, offset_fwd) - - o_fwd, stats_fwd = graph_fwd.sdpa( - q=q_fwd, - k=k_fwd, - v=v_fwd, - generate_stats=not is_infer, - attn_scale=attn_scale, - use_causal_mask=is_causal, - dropout=dropout_tuple if is_dropout else None, - ) - - o_fwd.set_output(True).set_dim(o_gpu.size()).set_stride(o_gpu.stride()) - ( - stats_fwd.set_output(True) - .set_dim(stats_gpu.size()) - .set_stride(stats_gpu.stride()) - .set_data_type(cudnn.data_type.FLOAT) - if not is_infer - else None - ) - - graph_fwd.validate() - graph_fwd.build_operation_graph() - graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph_fwd.check_support() - graph_fwd.build_plans() - - # cuDNN graph backward - graph_bwd = cudnn.pygraph( - io_data_type=cudnn.data_type.HALF, - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - ) - - q_bwd = graph_bwd.tensor_like(q_gpu) - k_bwd = graph_bwd.tensor_like(k_gpu) - v_bwd = graph_bwd.tensor_like(v_gpu) - o_bwd = graph_bwd.tensor_like(o_gpu) - dO_bwd = graph_bwd.tensor_like(dO_gpu) - stats_bwd = graph_bwd.tensor_like(stats_gpu) - - if is_dropout: - seed_bwd = graph_fwd.tensor_like(seed_gpu) - offset_bwd = graph_fwd.tensor_like(offset_gpu) - dropout_tuple = (dropout_prob, seed_bwd, offset_bwd) - - dQ_bwd, dK_bwd, dV_bwd = graph_bwd.sdpa_backward( - q=q_bwd, - k=k_bwd, - v=v_bwd, - o=o_bwd, - dO=dO_bwd, - stats=stats_bwd, - attn_scale=attn_scale, - use_causal_mask=is_causal, - dropout=dropout_tuple if is_dropout else None, - ) - - dQ_bwd.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride()) - dK_bwd.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride()) - dV_bwd.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride()) - - # cuDNN Flash Attention doesn't support bprop for d=256 - if headdim != 256: - graph_bwd.validate() - graph_bwd.build_operation_graph() - graph_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph_bwd.check_support() - graph_bwd.build_plans() - - variant_pack_fwd = { - q_fwd: q_gpu, - k_fwd: k_gpu, - v_fwd: v_gpu, - o_fwd: o_gpu, - stats_fwd: stats_gpu, - } - variant_pack_bwd = { - q_bwd: q_gpu, - k_bwd: k_gpu, - v_bwd: v_gpu, - o_bwd: o_gpu, - dO_bwd: dO_gpu, - stats_bwd: stats_gpu, - dQ_bwd: dQ_gpu, - dK_bwd: dK_gpu, - dV_bwd: dV_gpu, - } - if is_dropout: - variant_pack_fwd[seed_fwd] = seed_gpu - variant_pack_fwd[offset_fwd] = offset_gpu - variant_pack_bwd[seed_bwd] = seed_gpu - variant_pack_bwd[offset_bwd] = offset_gpu - - workspace = torch.empty( - max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), - device="cuda", - dtype=torch.uint8, - ) - - f = time_fwd( - graph_fwd.execute, - variant_pack_fwd, - workspace, - repeats=repeats, - verbose=False, - ) - if headdim != 256: - b = time_fwd( - graph_bwd.execute, - variant_pack_bwd, - workspace, - repeats=repeats, - verbose=False, - ) - else: - b = 100000 - - time_f[config, "cudnn_bf16"] = f - time_b[config, "cudnn_bf16"] = b - - print("cudnn_fp16 done") - if "cudnn_fp8" in methods and device == "cuda" and cudnn is not None: - is_causal = causal - is_dropout = False if (abs(dropout_p - 0.0) < 1e-6) else True - is_infer = False - input_type = dtype - attn_scale = headdim ** (-0.5) - dropout_prob = dropout_p if is_dropout else 0.0 - - shape_qkvo = (batch_size, nheads, seqlen, headdim) - stride_qkv = (seqlen * 3 * nheads * headdim, headdim, 3 * nheads * headdim, 1) - stride_o = (seqlen * nheads * headdim, headdim, nheads * headdim, 1) - offset_q, offset_k, offset_v = [nheads * headdim * i for i in range(3)] - - qkv_gpu = torch.randint( - 256, - (batch_size * seqlen * 3 * nheads * headdim,), - dtype=torch.uint8, - device="cuda", - ) - q_gpu, k_gpu, v_gpu = [ - torch.as_strided(qkv_gpu, shape_qkvo, stride_qkv, storage_offset=offset) - for offset in [offset_q, offset_k, offset_v] - ] - o_gpu = torch.empty(*shape_qkvo, dtype=torch.uint8, device="cuda").as_strided( - shape_qkvo, stride_o - ) - dQ_gpu, dK_gpu, dV_gpu = [ - torch.empty_like(tensor) for tensor in [q_gpu, k_gpu, v_gpu] - ] - dO_gpu = torch.randint_like(o_gpu, 256) - - stats_gpu = ( - torch.empty( - batch_size, nheads, seqlen, 1, dtype=torch.float32, device="cuda" - ) - if not is_infer - else None - ) - - descale_q_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - descale_k_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - descale_v_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - descale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - descale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - descale_dO_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - descale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - - scale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - scale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - scale_dQ_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - scale_dK_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - scale_dV_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - scale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device="cuda") - - amax_s_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") - amax_o_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") - amax_dQ_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") - amax_dK_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") - amax_dV_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") - amax_dP_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device="cuda") - - # cudnn graph forward - graph_fwd = cudnn.pygraph( - io_data_type=cudnn.data_type.FP8_E4M3, - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - ) - - q_fwd = graph_fwd.tensor_like(q_gpu).set_data_type(cudnn.data_type.FP8_E4M3) - k_fwd = graph_fwd.tensor_like(k_gpu).set_data_type(cudnn.data_type.FP8_E4M3) - v_fwd = graph_fwd.tensor_like(v_gpu).set_data_type(cudnn.data_type.FP8_E4M3) - - descale_q_fwd = graph_fwd.tensor_like(descale_q_gpu) - descale_k_fwd = graph_fwd.tensor_like(descale_k_gpu) - descale_v_fwd = graph_fwd.tensor_like(descale_v_gpu) - descale_s_fwd = graph_fwd.tensor_like(descale_s_gpu) - - scale_s_fwd = graph_fwd.tensor_like(scale_s_gpu) - scale_o_fwd = graph_fwd.tensor_like(scale_o_gpu) - - o_fwd, stats_fwd, amax_s_fwd, amax_o_fwd = graph_fwd.sdpa_fp8( - q=q_fwd, - k=k_fwd, - v=v_fwd, - descale_q=descale_q_fwd, - descale_k=descale_k_fwd, - descale_v=descale_v_fwd, - descale_s=descale_s_fwd, - scale_s=scale_s_fwd, - scale_o=scale_o_fwd, - generate_stats=not is_infer, - attn_scale=attn_scale, - use_causal_mask=is_causal, - use_padding_mask=False, - ) - - o_fwd.set_output(True).set_dim(o_gpu.size()).set_stride( - o_gpu.stride() - ).set_data_type(cudnn.data_type.FP8_E4M3) - ( - stats_fwd.set_output(True) - .set_dim(stats_gpu.size()) - .set_stride(stats_gpu.stride()) - .set_data_type(cudnn.data_type.FLOAT) - if not is_infer - else None - ) - amax_s_fwd.set_output(True).set_dim(amax_s_gpu.size()).set_stride( - amax_s_gpu.stride() - ).set_data_type(cudnn.data_type.FLOAT) - amax_o_fwd.set_output(True).set_dim(amax_o_gpu.size()).set_stride( - amax_o_gpu.stride() - ).set_data_type(cudnn.data_type.FLOAT) - - graph_fwd.validate() - graph_fwd.build_operation_graph() - graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph_fwd.check_support() - graph_fwd.build_plans() - - # cudnn graph backward - graph_bwd = cudnn.pygraph( - io_data_type=cudnn.data_type.FP8_E4M3, - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - ) - - q_bwd = graph_bwd.tensor_like(q_gpu).set_data_type(cudnn.data_type.FP8_E4M3) - k_bwd = graph_bwd.tensor_like(k_gpu).set_data_type(cudnn.data_type.FP8_E4M3) - v_bwd = graph_bwd.tensor_like(v_gpu).set_data_type(cudnn.data_type.FP8_E4M3) - o_bwd = graph_bwd.tensor_like(o_gpu).set_data_type(cudnn.data_type.FP8_E4M3) - dO_bwd = graph_bwd.tensor_like(dO_gpu).set_data_type(cudnn.data_type.FP8_E4M3) - stats_bwd = graph_bwd.tensor_like(stats_gpu) - - descale_q_bwd = graph_bwd.tensor_like(descale_q_gpu) - descale_k_bwd = graph_bwd.tensor_like(descale_k_gpu) - descale_v_bwd = graph_bwd.tensor_like(descale_v_gpu) - descale_o_bwd = graph_bwd.tensor_like(descale_o_gpu) - descale_dO_bwd = graph_bwd.tensor_like(descale_dO_gpu) - descale_s_bwd = graph_bwd.tensor_like(descale_s_gpu) - descale_dP_bwd = graph_bwd.tensor_like(descale_dP_gpu) - - scale_s_bwd = graph_bwd.tensor_like(scale_s_gpu) - scale_dQ_bwd = graph_bwd.tensor_like(scale_dQ_gpu) - scale_dK_bwd = graph_bwd.tensor_like(scale_dK_gpu) - scale_dV_bwd = graph_bwd.tensor_like(scale_dV_gpu) - scale_dP_bwd = graph_bwd.tensor_like(scale_dP_gpu) - - dQ_bwd, dK_bwd, dV_bwd, amax_dQ_bwd, amax_dK_bwd, amax_dV_bwd, amax_dP_bwd = ( - graph_bwd.sdpa_fp8_backward( - q=q_bwd, - k=k_bwd, - v=v_bwd, - o=o_bwd, - dO=dO_bwd, - stats=stats_bwd, - descale_q=descale_q_bwd, - descale_k=descale_k_bwd, - descale_v=descale_v_bwd, - descale_o=descale_o_bwd, - descale_dO=descale_dO_bwd, - descale_s=descale_s_bwd, - descale_dP=descale_dP_bwd, - scale_s=scale_s_bwd, - scale_dQ=scale_dQ_bwd, - scale_dK=scale_dK_bwd, - scale_dV=scale_dV_bwd, - scale_dP=scale_dP_bwd, - attn_scale=attn_scale, - use_causal_mask=is_causal, - ) - ) - - dQ_bwd.set_output(True).set_dim(dQ_gpu.size()).set_stride( - dQ_gpu.stride() - ).set_data_type(cudnn.data_type.FP8_E4M3) - dK_bwd.set_output(True).set_dim(dK_gpu.size()).set_stride( - dK_gpu.stride() - ).set_data_type(cudnn.data_type.FP8_E4M3) - dV_bwd.set_output(True).set_dim(dV_gpu.size()).set_stride( - dV_gpu.stride() - ).set_data_type(cudnn.data_type.FP8_E4M3) - amax_dQ_bwd.set_output(True).set_dim(amax_dQ_gpu.size()).set_stride( - amax_dQ_gpu.stride() - ).set_data_type(cudnn.data_type.FLOAT) - amax_dK_bwd.set_output(True).set_dim(amax_dK_gpu.size()).set_stride( - amax_dK_gpu.stride() - ).set_data_type(cudnn.data_type.FLOAT) - amax_dV_bwd.set_output(True).set_dim(amax_dV_gpu.size()).set_stride( - amax_dV_gpu.stride() - ).set_data_type(cudnn.data_type.FLOAT) - amax_dP_bwd.set_output(True).set_dim(amax_dP_gpu.size()).set_stride( - amax_dP_gpu.stride() - ).set_data_type(cudnn.data_type.FLOAT) - - # cuDNN Flash Attention fp8 only support bprop for d=128 - if headdim == 128: - graph_bwd.validate() - graph_bwd.build_operation_graph() - graph_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph_bwd.check_support() - graph_bwd.build_plans() - - variant_pack_fwd = { - q_fwd: q_gpu, - k_fwd: k_gpu, - v_fwd: v_gpu, - o_fwd: o_gpu, - stats_fwd: stats_gpu, - descale_q_fwd: descale_q_gpu, - descale_k_fwd: descale_k_gpu, - descale_v_fwd: descale_v_gpu, - descale_s_fwd: descale_s_gpu, - scale_s_fwd: scale_s_gpu, - scale_o_fwd: scale_o_gpu, - amax_s_fwd: amax_s_gpu, - amax_o_fwd: amax_o_gpu, - } - - variant_pack_bwd = { - q_bwd: q_gpu, - k_bwd: k_gpu, - v_bwd: v_gpu, - o_bwd: o_gpu, - dQ_bwd: dQ_gpu, - dK_bwd: dK_gpu, - dV_bwd: dV_gpu, - dO_bwd: dO_gpu, - stats_bwd: stats_gpu, - descale_q_bwd: descale_q_gpu, - descale_k_bwd: descale_k_gpu, - descale_v_bwd: descale_v_gpu, - descale_o_bwd: descale_o_gpu, - descale_s_bwd: descale_s_gpu, - descale_dP_bwd: descale_dP_gpu, - descale_dO_bwd: descale_dO_gpu, - scale_s_bwd: scale_s_gpu, - scale_dQ_bwd: scale_dQ_gpu, - scale_dK_bwd: scale_dK_gpu, - scale_dV_bwd: scale_dV_gpu, - scale_dP_bwd: scale_dP_gpu, - amax_dQ_bwd: amax_dQ_gpu, - amax_dK_bwd: amax_dK_gpu, - amax_dV_bwd: amax_dV_gpu, - amax_dP_bwd: amax_dP_gpu, - } - - workspace = torch.empty( - max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), - device="cuda", - dtype=torch.uint8, - ) - - f = time_fwd( - graph_fwd.execute, - variant_pack_fwd, - workspace, - repeats=repeats, - verbose=False, - ) - # cuDNN Flash Attention doesn't support bprop for d=256 - if headdim == 128: - b = time_fwd( - graph_bwd.execute, - variant_pack_bwd, - workspace, - repeats=repeats, - verbose=False, - ) - else: - b = 100000 - - time_f[config, "cudnn_fp8"] = f - time_b[config, "cudnn_fp8"] = b - - row = [] - row.append(str(batch_size)) - row.append(str(nheads)) - row.append(str(seqlen)) - row.append(str(headdim)) - row.append(str(causal)) - row.append(str(dropout_p)) - - print( - f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###" - ) - for method in methods: - time_f_b[config, method] = time_f[config, method] + time_b[config, method] - speed_f[config, method] = efficiency( - flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), - time_f[config, method], - ) - speed_b[config, method] = efficiency( - flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"), - time_b[config, method], - ) - speed_f_b[config, method] = efficiency( - flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"), - time_f_b[config, method], - ) - print( - f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, " - f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, " - f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" - ) - row.append(str(speed_f[config, method])) - row.append(str(speed_b[config, method])) - row.append(str(speed_f_b[config, method])) - csvwriter.writerow(row) - - print(row) - -csvfile.close() diff --git a/benchmark/sdpa_benchmark/benchmark_results.csv b/benchmark/sdpa_benchmark/benchmark_results.csv deleted file mode 100644 index 4ad15c73..00000000 --- a/benchmark/sdpa_benchmark/benchmark_results.csv +++ /dev/null @@ -1,11 +0,0 @@ -Batch,Heads,Sequence length,Head dim,Causal,Pytorch FWD (TFLOPs/s),Pytorch BWD (TFLOPs/s),Pytorch FWD+BWD (TFLOPs/s),cuDNN BF16 FWD (TFLOPs/s),cuDNN BF16 BWD (TFLOPs/s),cuDNN BF16 FWD+BWD (TFLOPs/s),cuDNN FP8 FWD (TFLOPs/s),cuDNN FP8 BWD (TFLOPs/s),cuDNN FP8 FWD+BWD (TFLOPs/s) -4,16,4096,128,False,362.03,290.47,307.86,704.41,535.06,574.53,1020.77,696.09,765.67 -2,16,8192,128,False,364.03,305.47,320.18,696.52,536.95,574.56,1042.20,715.65,786.02 -1,16,16384,128,False,360.50,317.34,328.58,687.97,546.53,580.64,1037.43,736.88,803.38 -1,16,32768,128,False,356.15,320.44,329.89,648.63,546.31,572.09,992.24,731.45,790.84 -1,16,65536,128,False,360.62,330.07,338.26,635.00,549.29,571.32,972.50,736.03,790.99 -4,16,4096,128,True,296.52,245.18,257.94,613.49,438.45,477.36,889.73,434.33,508.73 -2,16,8192,128,True,321.94,275.40,287.26,645.95,461.89,502.82,985.71,602.80,678.05 -1,16,16384,128,True,327.17,305.19,311.17,649.99,488.56,525.88,970.39,634.72,704.33 -1,16,32768,128,True,345.64,316.92,324.62,636.08,504.14,535.90,984.84,641.10,712.11 -1,16,65536,128,True,338.85,323.98,328.09,631.20,508.07,538.06,962.72,631.44,700.29 \ No newline at end of file diff --git a/benchmark/sdpa_benchmark/images/bprop.png b/benchmark/sdpa_benchmark/images/bprop.png deleted file mode 100644 index 32a7de1f..00000000 Binary files a/benchmark/sdpa_benchmark/images/bprop.png and /dev/null differ diff --git a/benchmark/sdpa_benchmark/images/forward.png b/benchmark/sdpa_benchmark/images/forward.png deleted file mode 100644 index 45eb64f6..00000000 Binary files a/benchmark/sdpa_benchmark/images/forward.png and /dev/null differ diff --git a/benchmark/sdpa_benchmark/images/fwd_bprop.png b/benchmark/sdpa_benchmark/images/fwd_bprop.png deleted file mode 100644 index d3eb66e0..00000000 Binary files a/benchmark/sdpa_benchmark/images/fwd_bprop.png and /dev/null differ diff --git a/benchmark/sdpa_benchmark_training/Dockerfile b/benchmark/sdpa_benchmark_training/Dockerfile index 1ae6e6de..8a0efd6d 100644 --- a/benchmark/sdpa_benchmark_training/Dockerfile +++ b/benchmark/sdpa_benchmark_training/Dockerfile @@ -1,46 +1,28 @@ -FROM nvcr.io/nvidia/pytorch:25.11-py3 +FROM nvcr.io/nvidia/pytorch:25.12-py3 -RUN pip install --upgrade pip && \ - pip install seaborn +# Set working directory +WORKDIR /workspace -RUN apt-get update && \ - wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \ +# Update libcudnn9-cuda-13 +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb && \ dpkg -i cuda-keyring_1.1-1_all.deb && \ + apt-get remove -y *cudnn9* && \ apt-get update && \ - apt-get -y install cudnn - -RUN pip uninstall -y cudnn - -COPY benchmark_bf16_sdpa.py . - -COPY benchmark_fp8_sdpa.py . - -COPY benchmark_single_sdpa.py . - -ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH - -WORKDIR /workspace - -RUN pip install nvidia-cutlass-dsl - -RUN pip uninstall -y flash_attn - -ENV CUDA_HOME=/usr/local/cuda -ENV MAX_JOBS=32 - -RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive - -WORKDIR /workspace/flash-attention - -ENV FLASH_ATTENTION_DISABLE_BACKWARD=TRUE -ENV FLASH_ATTENTION_DISABLE_APPENDKV=TRUE -ENV FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE -ENV FLASH_ATTENTION_DISABLE_FP8=TRUE -ENV FLASH_ATTENTION_DISABLE_CLUSTER=TRUE -ENV FLASH_ATTENTION_DISABLE_HDIM96=TRUE -ENV FLASH_ATTENTION_DISABLE_HDIM192=TRUE -ENV FLASH_ATTENTION_DISABLE_SM80=TRUE - -RUN python3 setup.py install - -WORKDIR /workspace \ No newline at end of file + apt-get -y install cudnn && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Clone cudnn_frontend and install latest cudnn +RUN git clone https://github.com/NVIDIA/cudnn-frontend.git +RUN pip install -v cudnn-frontend + +# Clone flash-attention +RUN pip uninstall -y flash-attn && \ + git clone https://github.com/Dao-AILab/flash-attention.git && \ + cd flash-attention && \ + sed -i 's/^ import flash_attn_2_cuda as flash_attn_gpu$/ pass/' /workspace/flash-attention/flash_attn/flash_attn_interface.py +RUN pip install nvidia-cutlass-dsl apache-tvm-ffi quack-kernels +ENV PYTHONPATH=/workspace/flash-attention + +# Install additional dependencies for benchmarking +RUN pip install seaborn \ No newline at end of file diff --git a/benchmark/sdpa_benchmark_training/README.md b/benchmark/sdpa_benchmark_training/README.md index 91aa6c25..87368634 100644 --- a/benchmark/sdpa_benchmark_training/README.md +++ b/benchmark/sdpa_benchmark_training/README.md @@ -1,170 +1,234 @@ # Scaled Dot Product Attention Benchmark -## Introduction -The benchmarking script in this current directory profiles scaled dot product attention (SDPA) from various backends. Here we benchmark attention layer dimensions inspired by [Llama-3.1-405B](https://ai.meta.com/blog/meta-llama-3-1/) with sequence lengths ranging from 512 to 131,072. +## Introduction -The provided benchmark targets training use cases--causal masking is enabled for grouped query attention (GQA). Layer dimensions and causal masking can be altered by modifying the preset parameters in `benchmark_{bf16,fp8}_sdpa.py`. Inference-specific attention optimizations such as paged attention are not benchmarked at this time. +This directory contains benchmarking tools for Scaled Dot Product Attention (SDPA) operations across various backends. The benchmarks target training use cases with support for causal masking and grouped query attention (GQA). ## Contents -- `Dockerfile` to create a Docker container for the dependencies and run the benchmark. -- `benchmark_bf16_sdpa.py` which runs cudnn, pytorch, and other backends up to 128k sequence length. -- `benchmark_fp8_sdpa.py` which runs cudnn on fp8 along with bf16 up to 128k sequence length. -- Sample benchmark output and results on GB200 and GB300 in the `artifacts` directory. -- Useful Python scripts for running single attention layers: - - `benchmark_single_sdpa.py` for benchmarking a single flash attention instance from various backends. - - See below for usage example. - -## Software versions +- `Dockerfile` - Docker container setup for running benchmarks +- `benchmark_single_sdpa.py` - Single SDPA benchmark script +- `configs/` - Benchmark configuration files + - `llama.py` - Llama 3.1 GQA benchmarks (causal + non-causal) + - `dsv3.py` - DeepSeek V3 MHA benchmarks (causal only) +- `runner.py` - Configuration-based benchmark runner +- `config_types.py` - Data types for benchmark configuration +- `charts.py` - Chart generation utilities +- `../results/` - Benchmark outputs (CSV and charts) -This benchmark code should run on any decently modern Python environment with CUDA-enabled GPU. The results in `artifacts` were collected using the PyTorch docker image [from the NVIDIA GPU CLOUD (NGC) catalog](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), `nvcr.io/nvidia/pytorch:25.11-py3`, where cuDNN 9.17.0 was used. We provide a `Dockerfile` to reproduce the environment with the following library versions +## Quick Start +### 1. Build Docker Container -| Software | Version | -|----------------|---------| -| Python | 3.12.3 | -| CUDA | 13.0.0 | -| cuDNN | 9.17.0 | -| PyTorch | 2.10.0 | -| FlashAttention 2 | 2.8.3 | - - -## Steps to run -### 0. *Optional*: Lock Clocks -Although the benchmarking code inserts dynamically-sized delays to avoid GPU throttling, most reproducible results can be obtained when clocks are locked. For example, use `nvidia-smi -q -d SUPPORTED_CLOCKS` to get the supported clocks +```bash +docker build -t cudnn_attention_benchmark . -``` -sudo nvidia-smi -pm 1 -nvidia-smi -lgc , +docker run -it --gpus all --rm cudnn_attention_benchmark ``` -### 1. Build docker container -Launch the docker build and run. We prodivde a simple `Dockerfile` to help run the benchmark -``` -docker build -t cudnn_attention_benchmark . -docker run -it --gpus all --rm -v $(pwd):/workspace cudnn_attention_benchmark -``` +### 2. Run Benchmarks -### 2. Run Benchmark script -The `benchmark_{bf16,fp8}_sdpa.py` scripts execute a predefined set of attention layers of various sequence lengths, where the transformer dimensions are inspired by [Llama-3.1-405B](https://ai.meta.com/blog/meta-llama-3-1/) (`num_q_heads=128; num_kv_heads=8; head_dim=128; is_causal=True; dtype=bfloat16`) +```bash +# Run Llama 3.1 benchmark suite +python -m benchmark.sdpa_benchmark_training.runner --config llama -The following scaled dot product attention backends are benchmarked: -- [PyTorch's SDPA backends](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html): - - FlashAttention-2 (`FLASH_ATTENTION`; PyTorch FAv2 ) -- cuDNN Frontend (bfloat16) -- cuDNN Frontend (fp8) +# Run DeepSeek V3 benchmark suite +python -m benchmark.sdpa_benchmark_training.runner --config dsv3 -Please note that FlashAttention-3 is currently not supported on NVIDIA's Blackwell generation GPUs. +# Dry run (show what would be executed) +python -m benchmark.sdpa_benchmark_training.runner --config llama --dry-run -Sample outputs: -``` -$ python3 benchmark_bf16_sdpa.py -[INFO] torch.__version__ = '2.10.0a0+b558c986e8.nv25.11' -[INFO] torch.version.cuda = '13.0' -[INFO] torch.cuda.is_available() = True -[INFO] torch.cuda.device_count() = 4 -[INFO] torch.cuda.current_device() = 0 -[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA GB200' -[INFO] cuDNN Backend Version: cudnn.backend_version() = 91700 -[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.16.0' -[INFO] torch.backends.cudnn.enabled = True -[INFO] flash_attn.__version__ = '2.8.3' -[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim) -[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128)] -[INFO] Running layer (1, 512, 512, 128, 8, 128) -... -[INFO] Saving results to ./artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.csv -[INFO] Saving plot to ./artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.png -``` +# Filter by backend +python -m benchmark.sdpa_benchmark_training.runner --config llama --backend cudnn +# Filter by data type +python -m benchmark.sdpa_benchmark_training.runner --config llama --dtype bfloat16 ``` -$ python3 benchmark_sdpa_fp8.py -[INFO] cuDNN Backend Version: cudnn.backend_version() = 91700 -[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.16.0' -[INFO] torch.__version__ = '2.10.0a0+b558c986e8.nv25.11' -[INFO] torch.version.cuda = '13.0' -[INFO] torch.cuda.is_available() = True -[INFO] torch.cuda.device_count() = 4 -[INFO] torch.cuda.current_device() = 0 -[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA GB200' -[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim) -[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128)] -[INFO] Running layer (1, 512, 512, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -... -[INFO] Saving results to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.csv -[INFO] Saving plot to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.png -``` - -Benchmarked performance numbers are stored in the [artifacts](artifacts) directory as csv and png files. -## Results -Below are the result of the benchmark running on a single GB200 GPU and a single GB300 GPU. - -For both runs, the following software versions are used: - -- CUDA: 13.0 (from NGC container) -- PyTorch: 2.10.0 (from NGC container) -- cuDNN: 9.17.0 (Installed via `apt-get`; see `Dockerfile`) +## Configuration-Based Benchmarking + +### Creating Custom Configurations + +1. Copy the template: + ```bash + cp configs/llama.py configs/my_config.py + ``` + +2. Edit your config: + ```python + from ..config_types import ModelPreset, BenchmarkConfig + + MY_MODEL = ModelPreset( + name="my_model", + num_q_heads=32, + num_kv_heads=8, + head_dim=128, + ) + + CONFIG = BenchmarkConfig( + name="my_benchmark", + models=[MY_MODEL], + seqlens=[(4096, 4096), (8192, 8192)], + backends=["cudnn", "flash_attention_4"], + data_types=["bfloat16", "fp8"], + attn_masks=["top_left", "no_mask"], + profile_pass="fwd", # "fwd", "bwd", or "both" + num_iterations=10, + ) + ``` + +3. Run: + ```bash + python -m benchmark.sdpa_benchmark_training.runner --config my_config + ``` + +### Configuration Options + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `models` | List of `ModelPreset` to benchmark | Required | +| `seqlens` | List of `(q_seqlen, kv_seqlen)` tuples | Required | +| `backends` | Backends to compare | `["cudnn"]` | +| `data_types` | Data types to test | `["bfloat16"]` | +| `attn_masks` | Attention masks (`top_left`, `no_mask`, `bottom_right`) | `["top_left"]` | +| `profile_pass` | Which pass to profile (`fwd`, `bwd`, `both`) | `"fwd"` | +| `batch_size` | Batch size | `1` | +| `num_iterations` | Iterations per benchmark | `10` | +| `deterministic_bwd` | Deterministic modes for backward | `[False]` | + +### Model Presets + +Standard model: +```python +LLAMA3_1 = ModelPreset( + name="llama3.1", + num_q_heads=64, + num_kv_heads=8, + head_dim=128, +) +``` +Asymmetric head dimensions (DeepSeek V3): +```python +DSV3 = ModelPreset( + name="dsv3", + num_q_heads=128, + num_kv_heads=128, + head_dim_qk=192, # Q/K head dimension + head_dim_vo=128, # V/O head dimension +) +``` -### GB200 - BF16 Performance Comparison between Backends -![Comparison of pytorch and cudnn](artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.png) -- SDPA parameters were used `batch=1; num_q_heads=128; num_kv_heads=8; head_dim=128; is_causal=True; dtype=bfloat16`. -- Sequence lengths are shown in the x-axis. -- Results were obtained on an NVIDIA GB200 GPU with free clock. +### Output + +The runner produces (in `benchmark/results/`): +- **CSV**: `_.csv` +- **Charts**: Separate chart per mask type: + - `_top_left.png` (causal) + - `_no_mask.png` (non-causal) +- Charts show backends side-by-side with distinct colors for BF16 vs FP8 + +## Single Benchmark Script + +For running individual benchmarks: + +```bash +# cuDNN Frontend (BF16) +python benchmark_single_sdpa.py \ + --batch_size 1 --q_seqlen 8192 --kv_seqlen 8192 \ + --num_q_heads 64 --num_kv_heads 8 --head_dim 128 \ + --sdpa_backend cudnn --data_type bfloat16 \ + --attn_mask top_left --fwd_bwd + +# cuDNN Frontend (FP8) +python benchmark_single_sdpa.py \ + --batch_size 1 --q_seqlen 8192 --kv_seqlen 8192 \ + --num_q_heads 64 --num_kv_heads 8 --head_dim 128 \ + --sdpa_backend cudnn --data_type fp8 \ + --attn_mask top_left --fwd_bwd + +# FlashAttention 4 +python benchmark_single_sdpa.py \ + --batch_size 1 --q_seqlen 8192 --kv_seqlen 8192 \ + --num_q_heads 64 --num_kv_heads 8 --head_dim 128 \ + --sdpa_backend flash_attention_4 --data_type bfloat16 \ + --attn_mask top_left --fwd_bwd +``` -### GB200 - cuDNN's FP8 Performance Relative to BF16 -![Comparison of pytorch and cudnn](artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.png) -- SDPA parameters were used `batch=1; num_q_heads=128; num_kv_heads=8; head_dim=128; is_causal=True; dtype=bfloat16`. -- Sequence lengths are shown in the x-axis. -- Results were obtained on an NVIDIA GB200 GPU with free clock. +Run `python benchmark_single_sdpa.py --help` for all options. -### GB300 - BF16 Performance Comparison between Backends -![Comparison of pytorch and cudnn](artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.png) -- SDPA parameters were used `batch=1; num_q_heads=128; num_kv_heads=8; head_dim=128; is_causal=True; dtype=bfloat16`. -- Sequence lengths are shown in the x-axis. -- Results were obtained on an NVIDIA GB300 GPU with free clock. +## Programmatic Usage -### GB300 - cuDNN's FP8 Performance Relative to BF16 -![Comparison of pytorch and cudnn](artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.png) -- SDPA parameters were used `batch=1; num_q_heads=128; num_kv_heads=8; head_dim=128; is_causal=True; dtype=bfloat16`. -- Sequence lengths are shown in the x-axis. -- Results were obtained on an NVIDIA GB300 GPU with free clock. +```python +from benchmark.sdpa_benchmark_training import ( + BenchmarkRunner, + BenchmarkConfig, + ModelPreset, + load_config, +) -## Pytorch adoption -As demonstrated can be seen from the results, cuDNN v9 can achieve over 2x the performance of the comparable PyTorch eager implementation. Refer to [PyTorch's scaled_dot_product_attention()](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and [sdpa_kernel](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) context manager documentations for enabling the cuDNN backend for scaled dot product attention. +# Load existing config +config = load_config("llama") -## `benchmark_single_sdpa.py` -`benchmark_single_sdpa.py` is provided to conveniently run a single SDPA operation. Try running `python benchmark_single_sdpa.py --help` to see available flags. +# Or create programmatically +config = BenchmarkConfig( + name="custom", + models=[ModelPreset("test", 64, 8, 128)], + seqlens=[(4096, 4096)], + backends=["cudnn"], +) -Example commands and outputs: +runner = BenchmarkRunner() +results = runner.run_config(config) +runner.save_csv(results, config) ``` -## For running various PyTorch backends (FlashAttention, cuDNN, ...) or FlashAttention-2: -$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --attn_mask top_left --data_type bfloat16 --num_iterations 10 --sdpa_backend pyt_cudnn --data_type bfloat16 --fwd_bwd -pyt_cudnn:: Median (fwd, bwd) Execution Times: 24.645 ms (1428 TFLOPS), 78.674 ms (1118 TFLOPS) (max difference vs. pyt_reference: 0.000000 from 10 iterations) -## For directly running cuDNN via cuDNN Frontend -$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --attn_mask top_left --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --data_type bfloat16 --fwd_bwd -cudnn_fe:: Median (fwd, bwd) Execution Times: 24.543 ms (1434 TFLOPS), 73.210 ms (1201 TFLOPS) (max difference vs. pyt_reference: 0.000000 from 10 iterations) +## Supported Backends + +| Backend | Description | +|---------|-------------| +| `cudnn` | cuDNN (native, via cuDNN Frontend) | +| `flash_attention_4` | FlashAttention 4 | +| `flash_attention_3` | FlashAttention 3 | +| `pyt_flash_attention` | PyTorch FlashAttention | +| `pyt_cudnn` | PyTorch cuDNN backend | +| `pyt_efficient_attention` | PyTorch xFormers | + +## Benchmark Results + +### GB200 - Llama 3.1 Causal (top_left) +![Llama 3.1 Causal on GB200](results/gb200_918_only_cudnn/llama3.1_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB200 GPU + +### GB200 - Llama 3.1 Non-Causal (no_mask) +![Llama 3.1 Non-Causal on GB200](results/gb200_918_only_cudnn/llama3.1_no_mask.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=False` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB200 GPU + +### GB200 - DeepSeek V3 Causal (top_left) +![DeepSeek V3 Causal on GB200](results/gb200_918_only_cudnn/dsv3_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=128; num_kv_heads=128; head_dim_qk=192; head_dim_vo=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB200 GPU + +### GB300 - Llama 3.1 Causal (top_left) +![Llama 3.1 Causal on GB300](results/gb300_918_only_cudnn/llama3.1_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB300 GPU + +### GB300 - Llama 3.1 Non-Causal (no_mask) +![Llama 3.1 Non-Causal on GB300](results/gb300_918_only_cudnn/llama3.1_no_mask.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=False` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB300 GPU + +### GB300 - DeepSeek V3 Causal (top_left) +![DeepSeek V3 Causal on GB300](results/gb300_918_only_cudnn/dsv3_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=128; num_kv_heads=128; head_dim_qk=192; head_dim_vo=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB300 GPU -## For running cuDNN FP8 -$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 32768 --kv_seqlen 32768 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --attn_mask top_left --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --data_type fp8 --fwd_bwd -cudnn_fe:: Median (fwd, bwd) Execution Times: 21.334 ms (1649 TFLOPS), 56.373 ms (1560 TFLOPS) (max difference vs. pyt_reference: 0.000000 from 10 iterations) -``` - -The cuDNN version used in the benchmark can be replaced by setting the `LD_LIBRARY_PATH` environment variable. -``` -$ export LD_LIBRARY_PATH= -$ python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 16384 --kv_seqlen 16384 --num_q_heads 128 --num_kv_heads 8 --head_dim 128 --attn_mask top_left --data_type bfloat16 --num_iterations 10 --sdpa_backend cudnn_fe --fwd_bwd --data_type fp8 --verbose -[INFO] cuDNN Backend Version: cudnn.backend_version() = 91002 -[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.12.0' -[INFO] torch.__version__ = '2.8.0a0+5228986c39.nv25.06' -[INFO] torch.version.cuda = '12.9' -[INFO] torch.cuda.is_available() = True -[INFO] torch.cuda.device_count() = 1 -[INFO] torch.cuda.current_device() = 0 -[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA B200' -cudnn_fe:: Median (fwd, bwd) Execution Times: 5.634 ms (1561 TFLOPS), 15.282 ms (1439 TFLOPS) (max difference vs. pyt_reference: 0.000000 from 10 iterations) -``` \ No newline at end of file diff --git a/benchmark/sdpa_benchmark_training/__init__.py b/benchmark/sdpa_benchmark_training/__init__.py new file mode 100644 index 00000000..dd665e2d --- /dev/null +++ b/benchmark/sdpa_benchmark_training/__init__.py @@ -0,0 +1,40 @@ +""" +SDPA Benchmark Training Package + +This package provides a flexible benchmark configuration system for +Scaled Dot Product Attention (SDPA) operations. + +Usage: + # Run benchmarks from command line + python -m benchmark.sdpa_benchmark_training.runner --config mlperf + + # Dry run to see what would be executed + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --dry-run + + # Import and use programmatically + from benchmark.sdpa_benchmark_training import ( + BenchmarkRunner, + BenchmarkConfig, + BenchmarkResult, + ModelPreset, + load_config, + ) + + config = load_config("mlperf") + runner = BenchmarkRunner() + results = runner.run_config(config) + runner.save_csv(results, config) +""" + +from .config_types import ModelPreset, BenchmarkConfig, BenchmarkResult +from .configs import load_config, list_configs +from .runner import BenchmarkRunner + +__all__ = [ + "ModelPreset", + "BenchmarkConfig", + "BenchmarkResult", + "BenchmarkRunner", + "load_config", + "list_configs", +] diff --git a/benchmark/sdpa_benchmark_training/artifacts/sample_gb200_bf16_run.txt b/benchmark/sdpa_benchmark_training/artifacts/sample_gb200_bf16_run.txt deleted file mode 100644 index abeb7ecc..00000000 --- a/benchmark/sdpa_benchmark_training/artifacts/sample_gb200_bf16_run.txt +++ /dev/null @@ -1,52 +0,0 @@ -[INFO] torch.__version__ = '2.10.0a0+b558c986e8.nv25.11' -[INFO] torch.version.cuda = '13.0' -[INFO] torch.cuda.is_available() = True -[INFO] torch.cuda.device_count() = 4 -[INFO] torch.cuda.current_device() = 0 -[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA GB200' -[INFO] cuDNN Backend Version: cudnn.backend_version() = 91700 -[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.16.0' -[INFO] torch.backends.cudnn.enabled = True -[INFO] flash_attn.__version__ = '2.8.3' -[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim) -[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128)] -[INFO] Running layer (1, 512, 512, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_bf16_sdpa.py:240: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation. - data_df = pd.concat([data_df, pd.DataFrame([result])], ignore_index=True) -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 1024, 1024, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 2048, 2048, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 4096, 4096, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 8192, 8192, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 16384, 16384, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 32768, 32768, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 65536, 65536, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 131072, 131072, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Saving results to ./artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.csv -[INFO] Saving plot to ./artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.png \ No newline at end of file diff --git a/benchmark/sdpa_benchmark_training/artifacts/sample_gb200_fp8_run.txt b/benchmark/sdpa_benchmark_training/artifacts/sample_gb200_fp8_run.txt deleted file mode 100644 index 78f68afc..00000000 --- a/benchmark/sdpa_benchmark_training/artifacts/sample_gb200_fp8_run.txt +++ /dev/null @@ -1,41 +0,0 @@ -[INFO] cuDNN Backend Version: cudnn.backend_version() = 91700 -[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.16.0' -[INFO] torch.__version__ = '2.10.0a0+b558c986e8.nv25.11' -[INFO] torch.version.cuda = '13.0' -[INFO] torch.cuda.is_available() = True -[INFO] torch.cuda.device_count() = 4 -[INFO] torch.cuda.current_device() = 0 -[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA GB200' -[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim) -[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128)] -[INFO] Running layer (1, 512, 512, 128, 8, 128) -[INFO] Benchmarking data type fp8 -/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_fp8_sdpa.py:227: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation. - data_df = pd.concat([data_df, pd.DataFrame([result])], ignore_index=True) -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 1024, 1024, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 2048, 2048, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 4096, 4096, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 8192, 8192, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 16384, 16384, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 32768, 32768, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 65536, 65536, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 131072, 131072, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Saving results to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.csv -[INFO] Saving plot to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.png \ No newline at end of file diff --git a/benchmark/sdpa_benchmark_training/artifacts/sample_gb300_bf16_run.txt b/benchmark/sdpa_benchmark_training/artifacts/sample_gb300_bf16_run.txt deleted file mode 100644 index d6e5116d..00000000 --- a/benchmark/sdpa_benchmark_training/artifacts/sample_gb300_bf16_run.txt +++ /dev/null @@ -1,52 +0,0 @@ -[INFO] torch.__version__ = '2.10.0a0+b558c986e8.nv25.11' -[INFO] torch.version.cuda = '13.0' -[INFO] torch.cuda.is_available() = True -[INFO] torch.cuda.device_count() = 4 -[INFO] torch.cuda.current_device() = 0 -[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA GB300' -[INFO] cuDNN Backend Version: cudnn.backend_version() = 91700 -[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.16.0' -[INFO] torch.backends.cudnn.enabled = True -[INFO] flash_attn.__version__ = '2.8.3' -[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim) -[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128)] -[INFO] Running layer (1, 512, 512, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_bf16_sdpa.py:240: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation. - data_df = pd.concat([data_df, pd.DataFrame([result])], ignore_index=True) -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 1024, 1024, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 2048, 2048, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 4096, 4096, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 8192, 8192, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 16384, 16384, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 32768, 32768, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 65536, 65536, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Running layer (1, 131072, 131072, 128, 8, 128) -[INFO] Benchmarking backend cudnn_fe -[INFO] Benchmarking backend pyt_flash_attention -[INFO] Benchmarking backend cudnn_fe_fp8 -[INFO] Saving results to ./artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.csv -[INFO] Saving plot to ./artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.png \ No newline at end of file diff --git a/benchmark/sdpa_benchmark_training/artifacts/sample_gb300_fp8_run.txt b/benchmark/sdpa_benchmark_training/artifacts/sample_gb300_fp8_run.txt deleted file mode 100644 index 482cccc1..00000000 --- a/benchmark/sdpa_benchmark_training/artifacts/sample_gb300_fp8_run.txt +++ /dev/null @@ -1,41 +0,0 @@ -[INFO] cuDNN Backend Version: cudnn.backend_version() = 91700 [16/1484] -[INFO] cuDNN Frontend Version: cudnn.__version__ = '1.16.0' -[INFO] torch.__version__ = '2.10.0a0+b558c986e8.nv25.11' -[INFO] torch.version.cuda = '13.0' -[INFO] torch.cuda.is_available() = True -[INFO] torch.cuda.device_count() = 4 -[INFO] torch.cuda.current_device() = 0 -[INFO] torch.cuda.get_device_name(torch.cuda.current_device()) = 'NVIDIA GB300' -[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim) -[INFO] sdpa_configs = [(1, 512, 512, 128, 8, 128), (1, 1024, 1024, 128, 8, 128), (1, 2048, 2048, 128, 8, 128), (1, 4096, 4096, 128, 8, 128), (1, 8192, 8192, 128, 8, 128), (1, 16384, 16384, 128, 8, 128), (1, 32768, 32768, 128, 8, 128), (1, 65536, 65536, 128, 8, 128), (1, 131072, 131072, 128, 8, 128)] -[INFO] Running layer (1, 512, 512, 128, 8, 128) -[INFO] Benchmarking data type fp8 -/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_fp8_sdpa.py:227: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation. - data_df = pd.concat([data_df, pd.DataFrame([result])], ignore_index=True) -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 1024, 1024, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 2048, 2048, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 4096, 4096, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 8192, 8192, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 16384, 16384, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 32768, 32768, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 65536, 65536, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Running layer (1, 131072, 131072, 128, 8, 128) -[INFO] Benchmarking data type fp8 -[INFO] Benchmarking data type bf16 -[INFO] Saving results to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.csv -[INFO] Saving plot to ./artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.png \ No newline at end of file diff --git a/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.csv b/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.csv deleted file mode 100644 index 6d9837c8..00000000 --- a/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.csv +++ /dev/null @@ -1,28 +0,0 @@ -batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim,is_causal,backend,forward_time,backward_time,fwd_tflops_per_sec,bwd_tflops_per_sec -1,512,512,128,8,128,True,cudnn_fe,0.021,0.065,408.000,329.000 -1,512,512,128,8,128,True,pyt_flash_attention,0.053,0.147,163.000,146.000 -1,512,512,128,8,128,True,cudnn_fe_fp8,0.018,0.072,466.000,300.000 -1,1024,1024,128,8,128,True,cudnn_fe,0.047,0.163,733.000,527.000 -1,1024,1024,128,8,128,True,pyt_flash_attention,0.135,0.416,255.000,207.000 -1,1024,1024,128,8,128,True,cudnn_fe_fp8,0.040,0.156,870.000,551.000 -1,2048,2048,128,8,128,True,cudnn_fe,0.129,0.525,1069.000,655.000 -1,2048,2048,128,8,128,True,pyt_flash_attention,0.409,1.284,336.000,268.000 -1,2048,2048,128,8,128,True,cudnn_fe_fp8,0.104,0.511,1321.000,672.000 -1,4096,4096,128,8,128,True,cudnn_fe,0.397,1.304,1386.000,1054.000 -1,4096,4096,128,8,128,True,pyt_flash_attention,1.422,4.414,387.000,311.000 -1,4096,4096,128,8,128,True,cudnn_fe_fp8,0.331,1.264,1661.000,1088.000 -1,8192,8192,128,8,128,True,cudnn_fe,1.389,4.542,1584.000,1211.000 -1,8192,8192,128,8,128,True,pyt_flash_attention,5.308,16.249,414.000,338.000 -1,8192,8192,128,8,128,True,cudnn_fe_fp8,1.162,4.359,1893.000,1261.000 -1,16384,16384,128,8,128,True,cudnn_fe,5.273,16.863,1668.000,1304.000 -1,16384,16384,128,8,128,True,pyt_flash_attention,20.445,62.397,430.000,352.000 -1,16384,16384,128,8,128,True,cudnn_fe_fp8,4.326,16.051,2033.000,1370.000 -1,32768,32768,128,8,128,True,cudnn_fe,20.937,65.277,1681.000,1348.000 -1,32768,32768,128,8,128,True,pyt_flash_attention,80.322,244.331,438.000,360.000 -1,32768,32768,128,8,128,True,cudnn_fe_fp8,16.563,61.567,2124.000,1429.000 -1,65536,65536,128,8,128,True,cudnn_fe,89.454,286.341,1573.000,1229.000 -1,65536,65536,128,8,128,True,pyt_flash_attention,318.301,969.457,442.000,363.000 -1,65536,65536,128,8,128,True,cudnn_fe_fp8,64.533,242.633,2181.000,1450.000 -1,131072,131072,128,8,128,True,cudnn_fe,383.713,1184.584,1467.000,1188.000 -1,131072,131072,128,8,128,True,pyt_flash_attention,1267.708,3872.480,444.000,363.000 -1,131072,131072,128,8,128,True,cudnn_fe_fp8,259.939,977.151,2166.000,1440.000 diff --git a/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.png b/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.png deleted file mode 100644 index f02ebf08..00000000 Binary files a/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB200.png and /dev/null differ diff --git a/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.csv b/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.csv deleted file mode 100644 index c0796b06..00000000 --- a/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.csv +++ /dev/null @@ -1,28 +0,0 @@ -batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim,is_causal,backend,forward_time,backward_time,fwd_tflops_per_sec,bwd_tflops_per_sec -1,512,512,128,8,128,True,cudnn_fe,0.019,0.065,461.000,330.000 -1,512,512,128,8,128,True,pyt_flash_attention,0.052,0.145,165.000,149.000 -1,512,512,128,8,128,True,cudnn_fe_fp8,0.015,0.069,564.000,310.000 -1,1024,1024,128,8,128,True,cudnn_fe,0.041,0.165,839.000,522.000 -1,1024,1024,128,8,128,True,pyt_flash_attention,0.134,0.409,256.000,210.000 -1,1024,1024,128,8,128,True,cudnn_fe_fp8,0.032,0.148,1090.000,580.000 -1,2048,2048,128,8,128,True,cudnn_fe,0.105,0.543,1312.000,633.000 -1,2048,2048,128,8,128,True,pyt_flash_attention,0.406,1.267,339.000,271.000 -1,2048,2048,128,8,128,True,cudnn_fe_fp8,0.081,0.496,1707.000,693.000 -1,4096,4096,128,8,128,True,cudnn_fe,0.327,1.239,1682.000,1110.000 -1,4096,4096,128,8,128,True,pyt_flash_attention,1.416,4.349,388.000,316.000 -1,4096,4096,128,8,128,True,cudnn_fe_fp8,0.251,1.168,2194.000,1177.000 -1,8192,8192,128,8,128,True,cudnn_fe,1.158,4.256,1899.000,1292.000 -1,8192,8192,128,8,128,True,pyt_flash_attention,5.259,16.023,418.000,343.000 -1,8192,8192,128,8,128,True,cudnn_fe_fp8,0.857,3.957,2565.000,1390.000 -1,16384,16384,128,8,128,True,cudnn_fe,4.448,15.721,1978.000,1399.000 -1,16384,16384,128,8,128,True,pyt_flash_attention,20.296,61.516,433.000,357.000 -1,16384,16384,128,8,128,True,cudnn_fe_fp8,3.131,14.340,2809.000,1534.000 -1,32768,32768,128,8,128,True,cudnn_fe,17.536,61.187,2007.000,1438.000 -1,32768,32768,128,8,128,True,pyt_flash_attention,79.734,240.847,441.000,365.000 -1,32768,32768,128,8,128,True,cudnn_fe_fp8,11.856,54.986,2968.000,1600.000 -1,65536,65536,128,8,128,True,cudnn_fe,72.127,251.530,1951.000,1399.000 -1,65536,65536,128,8,128,True,pyt_flash_attention,315.773,955.743,446.000,368.000 -1,65536,65536,128,8,128,True,cudnn_fe_fp8,45.963,216.262,3062.000,1627.000 -1,131072,131072,128,8,128,True,cudnn_fe,309.147,1034.802,1821.000,1360.000 -1,131072,131072,128,8,128,True,pyt_flash_attention,1257.150,3815.728,448.000,369.000 -1,131072,131072,128,8,128,True,cudnn_fe_fp8,188.234,876.367,2991.000,1606.000 diff --git a/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.png b/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.png deleted file mode 100644 index 560844f1..00000000 Binary files a/benchmark/sdpa_benchmark_training/artifacts/sdpa_bf16_benchmark_results_NVIDIA_GB300.png and /dev/null differ diff --git a/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.csv b/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.csv deleted file mode 100644 index 3e94ae06..00000000 --- a/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.csv +++ /dev/null @@ -1,19 +0,0 @@ -batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim,is_causal,precision,forward_time,backward_time,fwd_tflops_per_sec,bwd_tflops_per_sec -1,512,512,128,8,128,True,fp8,0.018,0.072,470.000,301.000 -1,512,512,128,8,128,True,bf16,0.021,0.066,408.000,328.000 -1,1024,1024,128,8,128,True,fp8,0.040,0.157,870.000,547.000 -1,1024,1024,128,8,128,True,bf16,0.047,0.163,733.000,528.000 -1,2048,2048,128,8,128,True,fp8,0.104,0.512,1324.000,671.000 -1,2048,2048,128,8,128,True,bf16,0.129,0.524,1068.000,656.000 -1,4096,4096,128,8,128,True,fp8,0.331,1.264,1663.000,1088.000 -1,4096,4096,128,8,128,True,bf16,0.399,1.304,1378.000,1054.000 -1,8192,8192,128,8,128,True,fp8,1.161,4.352,1894.000,1263.000 -1,8192,8192,128,8,128,True,bf16,1.389,4.539,1583.000,1211.000 -1,16384,16384,128,8,128,True,fp8,4.327,16.020,2033.000,1373.000 -1,16384,16384,128,8,128,True,bf16,5.257,16.875,1673.000,1303.000 -1,32768,32768,128,8,128,True,fp8,16.564,61.559,2124.000,1429.000 -1,32768,32768,128,8,128,True,bf16,20.855,67.092,1687.000,1311.000 -1,65536,65536,128,8,128,True,fp8,64.518,242.637,2181.000,1450.000 -1,65536,65536,128,8,128,True,bf16,88.377,282.172,1592.000,1247.000 -1,131072,131072,128,8,128,True,fp8,259.183,977.197,2172.000,1440.000 -1,131072,131072,128,8,128,True,bf16,382.750,1185.117,1471.000,1188.000 diff --git a/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.png b/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.png deleted file mode 100644 index 3fe61067..00000000 Binary files a/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB200.png and /dev/null differ diff --git a/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.csv b/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.csv deleted file mode 100644 index c1d49f33..00000000 --- a/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.csv +++ /dev/null @@ -1,19 +0,0 @@ -batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim,is_causal,precision,forward_time,backward_time,fwd_tflops_per_sec,bwd_tflops_per_sec -1,512,512,128,8,128,True,fp8,0.015,0.069,557.000,311.000 -1,512,512,128,8,128,True,bf16,0.019,0.066,461.000,327.000 -1,1024,1024,128,8,128,True,fp8,0.032,0.149,1083.000,576.000 -1,1024,1024,128,8,128,True,bf16,0.041,0.165,838.000,522.000 -1,2048,2048,128,8,128,True,fp8,0.081,0.494,1700.000,696.000 -1,2048,2048,128,8,128,True,bf16,0.105,0.542,1310.000,634.000 -1,4096,4096,128,8,128,True,fp8,0.250,1.167,2197.000,1178.000 -1,4096,4096,128,8,128,True,bf16,0.327,1.238,1680.000,1110.000 -1,8192,8192,128,8,128,True,fp8,0.857,3.926,2565.000,1400.000 -1,8192,8192,128,8,128,True,bf16,1.158,4.253,1899.000,1293.000 -1,16384,16384,128,8,128,True,fp8,3.130,14.336,2810.000,1534.000 -1,16384,16384,128,8,128,True,bf16,4.453,15.717,1975.000,1399.000 -1,32768,32768,128,8,128,True,fp8,11.857,54.790,2967.000,1605.000 -1,32768,32768,128,8,128,True,bf16,17.538,61.182,2006.000,1438.000 -1,65536,65536,128,8,128,True,fp8,45.918,216.393,3065.000,1626.000 -1,65536,65536,128,8,128,True,bf16,71.471,251.354,1969.000,1400.000 -1,131072,131072,128,8,128,True,fp8,188.393,876.942,2988.000,1605.000 -1,131072,131072,128,8,128,True,bf16,309.487,1034.596,1819.000,1360.000 diff --git a/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.png b/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.png deleted file mode 100644 index ae8c0b88..00000000 Binary files a/benchmark/sdpa_benchmark_training/artifacts/sdpa_fp8_benchmark_results_NVIDIA_GB300.png and /dev/null differ diff --git a/benchmark/sdpa_benchmark_training/benchmark_bf16_sdpa.py b/benchmark/sdpa_benchmark_training/benchmark_bf16_sdpa.py deleted file mode 100644 index 00b3e0e4..00000000 --- a/benchmark/sdpa_benchmark_training/benchmark_bf16_sdpa.py +++ /dev/null @@ -1,356 +0,0 @@ -""" -Scaled Dot Product Attention (SDPA) benchmark - -This script benchmarks several SDPA backends including cuDNN using torch profiler. -Output csv and png files are saved in the artifacts directory. - -""" - -import torch -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns -import subprocess -import os -import sys -import cudnn - -###### SDPA Benchmark -- Setup ###### -## Define constants for benchmarking -verbose = True -data_type = "bfloat16" # Data type for benchmarking -num_iters = 30 # Number of iterations to run for each config; take median time -dry_run_iters = 0 # Number of iterations to dry run for warmup -attn_mask = "top_left" # Causal mask type (top_left is equivalent to is_causal=True) -backends = [ - "cudnn_fe", - # 'pyt_efficient_attention', # Disabled for GQA - "pyt_flash_attention", - "cudnn_fe_fp8", # cuDNN FE with FP8 precision -] - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -## Define SDPA configs -# Add or remove configs to benchmark; results will be included in output csv. -# Note: Altering the configs may result in incorrectly generated plots. - -# (batch_size, q_seqlen, kv_seqlen, num_q_heads, num_kv_heads, head_dim) -sdpa_configs = [ - (1, 512, 512, 128, 8, 128), - (1, 1024, 1024, 128, 8, 128), - (1, 2048, 2048, 128, 8, 128), - (1, 4096, 4096, 128, 8, 128), - (1, 8192, 8192, 128, 8, 128), - (1, 16384, 16384, 128, 8, 128), - (1, 32768, 32768, 128, 8, 128), - (1, 65536, 65536, 128, 8, 128), - (1, 131072, 131072, 128, 8, 128), -] - -## Helper function to run benchmark_single_sdpa.py and parse its output -def run_single_benchmark(config, backend): - """ - Run benchmark_single_sdpa.py for a single configuration and backend. - - Args: - config: Tuple of (batch_size, q_seqlen, kv_seqlen, num_q_heads, num_kv_heads, head_dim) - backend: Backend name (e.g., "pyt_cudnn", "flash_attention_4") - - Returns: - Dictionary with benchmark results or None if failed - """ - batch_size, q_seqlen, kv_seqlen, num_q_heads, num_kv_heads, head_dim = config - - # Get the directory of the current script - script_dir = os.path.dirname(os.path.abspath(__file__)) - benchmark_script = os.path.join(script_dir, "benchmark_single_sdpa.py") - - # Handle cudnn_fe_fp8 specially: use cudnn_fe backend with fp8 data type - if backend == "cudnn_fe_fp8": - actual_backend = "cudnn_fe" - actual_data_type = "fp8" - else: - actual_backend = backend - actual_data_type = data_type - - # Build command - cmd = [ - sys.executable, # Use the same Python interpreter - benchmark_script, - "--batch_size", str(batch_size), - "--q_seqlen", str(q_seqlen), - "--kv_seqlen", str(kv_seqlen), - "--num_q_heads", str(num_q_heads), - "--num_kv_heads", str(num_kv_heads), - "--head_dim", str(head_dim), - "--data_type", actual_data_type, - "--num_iterations", str(num_iters), - "--num_warmup_iterations", str(dry_run_iters), - "--sdpa_backend", actual_backend, - "--attn_mask", attn_mask, - "--fwd_bwd", # Run both forward and backward - "--format_output", # Get CSV-formatted output - "--skip_ref", # Skip reference check for speed - ] - - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - check=False, # Don't raise exception on non-zero exit - ) - - if result.returncode != 0: - if verbose: - print(f" [WARNING] Benchmark failed with return code {result.returncode}") - print(f" stderr: {result.stderr}") - # Return entry with infinite times and 0 TFLOPs for failed benchmarks - return { - 'batch_size': batch_size, - 'q_seqlen': q_seqlen, - 'kv_seqlen': kv_seqlen, - 'num_q_heads': num_q_heads, - 'num_kv_heads': num_kv_heads, - 'head_dim': head_dim, - 'is_causal': attn_mask == "top_left", - 'backend': backend, - 'forward_time': np.inf, - 'backward_time': np.inf, - 'fwd_tflops_per_sec': 0.0, - 'bwd_tflops_per_sec': 0.0, - } - - # Parse output - format is: - # case_tag,backend,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim,fwd_time,bwd_time,fwd_tflops,bwd_tflops,max_diff,num_iters - output_line = result.stdout.strip().split('\n')[-1] # Get last line - parts = output_line.split(',') - - if len(parts) < 12: - if verbose: - print(f" [WARNING] Unexpected output format: {output_line}") - # Return entry with infinite times and 0 TFLOPs for parse failures - return { - 'batch_size': batch_size, - 'q_seqlen': q_seqlen, - 'kv_seqlen': kv_seqlen, - 'num_q_heads': num_q_heads, - 'num_kv_heads': num_kv_heads, - 'head_dim': head_dim, - 'is_causal': attn_mask == "top_left", - 'backend': backend, - 'forward_time': np.inf, - 'backward_time': np.inf, - 'fwd_tflops_per_sec': 0.0, - 'bwd_tflops_per_sec': 0.0, - } - - return { - 'batch_size': int(parts[2]), - 'q_seqlen': int(parts[3]), - 'kv_seqlen': int(parts[4]), - 'num_q_heads': int(parts[5]), - 'num_kv_heads': int(parts[6]), - 'head_dim': int(parts[7]), - 'is_causal': attn_mask == "top_left", - 'backend': backend, # Use original backend name for plotting - 'forward_time': float(parts[8]), - 'backward_time': float(parts[9]), - 'fwd_tflops_per_sec': float(parts[10]), - 'bwd_tflops_per_sec': float(parts[11]), - } - except Exception as e: - if verbose: - print(f" [ERROR] Failed to run benchmark: {e}") - # Return entry with infinite times and 0 TFLOPs for exceptions - return { - 'batch_size': batch_size, - 'q_seqlen': q_seqlen, - 'kv_seqlen': kv_seqlen, - 'num_q_heads': num_q_heads, - 'num_kv_heads': num_kv_heads, - 'head_dim': head_dim, - 'is_causal': attn_mask == "top_left", - 'backend': backend, - 'forward_time': np.inf, - 'backward_time': np.inf, - 'fwd_tflops_per_sec': 0.0, - 'bwd_tflops_per_sec': 0.0, - } - - -###### SDPA Benchmark -- Run ###### -## Print System Info -print(f"[INFO] {torch.__version__ = }") -print(f"[INFO] {torch.version.cuda = }") -print(f"[INFO] {torch.cuda.is_available() = }") -print(f"[INFO] {torch.cuda.device_count() = }") -print(f"[INFO] {torch.cuda.current_device() = }") -print(f"[INFO] {torch.cuda.get_device_name(torch.cuda.current_device()) = }") -print(f"[INFO] cuDNN Backend Version: {cudnn.backend_version() = }") -print(f"[INFO] cuDNN Frontend Version: {cudnn.__version__ = }") -print(f"[INFO] {torch.backends.cudnn.enabled = }") -try: - import flash_attn - print(f"[INFO] {flash_attn.__version__ = }") -except ImportError: - pass - -## Begin Benchmark -# Define dataframe to store results -data_df = pd.DataFrame( - columns=[ - "batch_size", - "q_seqlen", - "kv_seqlen", - "num_q_heads", - "num_kv_heads", - "head_dim", - "is_causal", - "backend", - "forward_time", - "backward_time", - "fwd_tflops_per_sec", - "bwd_tflops_per_sec", - ] -) - -if verbose: - print( - f"[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)" - ) - print(f"[INFO] {sdpa_configs = }") - -# Iterate over each SDPA config -for sdpa_config in sdpa_configs: - batch_size, q_seqlen, kv_seqlen, num_q_heads, num_kv_heads, head_dim = sdpa_config - if verbose: - print(f"[INFO] Running layer {sdpa_config}") - - # Iterate over each backend - for cur_backend in backends: - print(f"[INFO] Benchmarking backend {cur_backend}") - - # Run benchmark via subprocess - result = run_single_benchmark(sdpa_config, cur_backend) - - # Append data to table (result is always a dict, never None) - data_df = pd.concat([data_df, pd.DataFrame([result])], ignore_index=True) - -## Save results to a csv file -gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()).replace(" ", "_") -output_file_name = f"./artifacts/sdpa_bf16_benchmark_results_{gpu_name}.csv" -if verbose: - print(f"[INFO] Saving results to {output_file_name}") -try: - data_df.to_csv(output_file_name, float_format="%.3f", index=False) -except Exception as e: - print(f"[ERROR] Failed to save results to {output_file_name}: {e}") - print(f"[INFO] Printing results to console instead") - print(data_df.to_csv(float_format="%.3f", index=False)) - print(f"[INFO] Printing results to console done") - -###### SDPA Benchmark -- Plot ###### -## Generate plots for (num_q_heads=128, num_kv_heads=8, head_dim=128, is_causal=True) - -# Configurations for bar plots -backend_ordering = { - "pyt_math": 0, - "pyt_efficient_attention": 1, - "pyt_flash_attention": 2, - "flash_attention": 3, - "cudnn_fe": 4, - "cudnn_fe_fp8": 5, # cuDNN FE with FP8 precision -} -backend_name = { - "pyt_math": "Standard Attention", - "pyt_efficient_attention": "xFormers (PyTorch)", - "pyt_flash_attention": "FAv2 (PyTorch)", - "flash_attention": "FAv2 (Native)", - "cudnn_fe": "cuDNN BF16 (Native)", - "cudnn_fe_fp8": "cuDNN FP8 (Native)", # cuDNN FE with FP8 precision -} -backend_barplot_color = { - backend_name["pyt_math"]: "darkorange", - backend_name["pyt_efficient_attention"]: "magenta", - backend_name["pyt_flash_attention"]: "royalblue", - backend_name["flash_attention"]: "lightcoral", - backend_name["cudnn_fe"]: "#76b900", - backend_name["cudnn_fe_fp8"]: "gold", # cuDNN FE with FP8 precision -} -LABEL_FONT_SIZE = 8 -LEGEND_FONT_SIZE = 6 -TITLE_FONT_SIZE = 9 -# Select desired cases -plot_df = data_df[ - (data_df["is_causal"] == True) - & (data_df["num_q_heads"] == 128) - & (data_df["num_kv_heads"] == 8) - & (data_df["q_seqlen"] == data_df["kv_seqlen"]) - & (data_df["head_dim"] == 128) -].copy() - -plot_df["backend_rank"] = plot_df["backend"].map(backend_ordering) -plot_df["backend_name"] = plot_df["backend"].map(backend_name) -plot_df.sort_values(["q_seqlen", "backend_rank"], inplace=True) - -# Generate plots: forward on left subplot and backward on right subplot -YLIM_MAX = ( - np.max([plot_df["fwd_tflops_per_sec"].max(), plot_df["bwd_tflops_per_sec"].max()]) - * 1.1 -) - -plt.figure(figsize=(10, 4), dpi=200) -plt.subplot(1, 2, 1) -cur_plot_df = plot_df[plot_df.fwd_tflops_per_sec > 0] -ax = sns.barplot( - data=cur_plot_df, - x="q_seqlen", - y="fwd_tflops_per_sec", - hue="backend_name", - edgecolor="black", - linewidth=0.5, - palette=backend_barplot_color, -) -for container in ax.containers: - ax.bar_label(container, fmt="%.f", fontsize=6) -plt.xticks(rotation=45) -plt.xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) -plt.ylabel("Speed (TFLOPs/sec)", fontsize=LABEL_FONT_SIZE) -plt.title("SDPA Forward", fontsize=TITLE_FONT_SIZE) -plt.tick_params(axis="y", which="major", labelsize=LABEL_FONT_SIZE) -plt.tick_params(axis="x", which="major", labelsize=LABEL_FONT_SIZE) -plt.ylim(0, YLIM_MAX) -plt.legend(fontsize=LEGEND_FONT_SIZE, loc="upper left") -plt.subplot(1, 2, 2) -cur_plot_df = plot_df[plot_df.bwd_tflops_per_sec > 0] -ax = sns.barplot( - data=cur_plot_df, - x="q_seqlen", - y="bwd_tflops_per_sec", - hue="backend_name", - edgecolor="black", - linewidth=0.5, - palette=backend_barplot_color, -) -for container in ax.containers: - ax.bar_label(container, fmt="%.f", fontsize=6) -plt.xticks(rotation=45) -plt.xlabel("SequenceLength", fontsize=LABEL_FONT_SIZE) -plt.ylabel("Speed (TFLOPs/sec)", fontsize=LABEL_FONT_SIZE) -plt.title("SDPA Backward", fontsize=TITLE_FONT_SIZE) -plt.tick_params(axis="y", which="major", labelsize=LABEL_FONT_SIZE) -plt.tick_params(axis="x", which="major", labelsize=LABEL_FONT_SIZE) -plt.ylim(0, YLIM_MAX) -plt.legend(fontsize=LEGEND_FONT_SIZE, loc="upper left") -# Save plot -plt.tight_layout() -png_file_name = f"./artifacts/sdpa_bf16_benchmark_results_{gpu_name}.png" -if verbose: - print(f"[INFO] Saving plot to {png_file_name}") -try: - plt.savefig(png_file_name) -except Exception as e: - print(f"[ERROR] Failed to save plot to {png_file_name}: {e}") diff --git a/benchmark/sdpa_benchmark_training/benchmark_fp8_sdpa.py b/benchmark/sdpa_benchmark_training/benchmark_fp8_sdpa.py deleted file mode 100644 index f8ff2262..00000000 --- a/benchmark/sdpa_benchmark_training/benchmark_fp8_sdpa.py +++ /dev/null @@ -1,366 +0,0 @@ -""" -Scaled Dot Product Attention (SDPA) benchmark - -This script benchmarks several SDPA backends including cuDNN using torch profiler. -Output csv and png files are saved in the artifacts directory. - -""" - -import torch -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns -import subprocess -import os -import sys - -###### SDPA Benchmark -- Setup ###### -## Define constants for benchmarking -verbose = True -num_iters = 10 # Number of iterations to run for each config; take median time -dry_run_iters = 5 # Number of iterations to dry run for warmup -attn_mask = "top_left" # Causal mask type (top_left is equivalent to is_causal=True) -backend = "cudnn_fe" # Backend to use for benchmarking -precisions = ["fp8", "bf16"] - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -## Define SDPA configs -# Add or remove configs to benchmark; results will be included in output csv. -# Note: Altering the configs may result in incorrectly generated plots. - -# (batch_size, q_seqlen, kv_seqlen, num_q_heads, num_kv_heads, head_dim) -sdpa_configs = [ - (1, 512, 512, 128, 8, 128), - (1, 1024, 1024, 128, 8, 128), - (1, 2048, 2048, 128, 8, 128), - (1, 4096, 4096, 128, 8, 128), - (1, 8192, 8192, 128, 8, 128), - (1, 16384, 16384, 128, 8, 128), - (1, 32768, 32768, 128, 8, 128), - (1, 65536, 65536, 128, 8, 128), - (1, 131072, 131072, 128, 8, 128), -] - -## Helper function to run benchmark_single_sdpa.py and parse its output -def run_single_benchmark(config, precision): - """ - Run benchmark_single_sdpa.py for a single configuration and precision. - - Args: - config: Tuple of (batch_size, q_seqlen, kv_seqlen, num_q_heads, num_kv_heads, head_dim) - precision: Data type (e.g., "fp8", "bf16") - - Returns: - Dictionary with benchmark results or None if failed - """ - batch_size, q_seqlen, kv_seqlen, num_q_heads, num_kv_heads, head_dim = config - - # Get the directory of the current script - script_dir = os.path.dirname(os.path.abspath(__file__)) - benchmark_script = os.path.join(script_dir, "benchmark_single_sdpa.py") - - # Map precision names - data_type = "bfloat16" if precision == "bf16" else precision - - # Build command - cmd = [ - sys.executable, # Use the same Python interpreter - benchmark_script, - "--batch_size", str(batch_size), - "--q_seqlen", str(q_seqlen), - "--kv_seqlen", str(kv_seqlen), - "--num_q_heads", str(num_q_heads), - "--num_kv_heads", str(num_kv_heads), - "--head_dim", str(head_dim), - "--data_type", data_type, - "--num_iterations", str(num_iters), - "--num_warmup_iterations", str(dry_run_iters), - "--sdpa_backend", backend, - "--attn_mask", attn_mask, - "--fwd_bwd", # Run both forward and backward - "--format_output", # Get CSV-formatted output - "--skip_ref", # Skip reference check for speed - ] - - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - check=False, # Don't raise exception on non-zero exit - ) - - if result.returncode != 0: - if verbose: - print(f" [WARNING] Benchmark failed with return code {result.returncode}") - print(f" stderr: {result.stderr}") - # Return entry with infinite times and 0 TFLOPs for failed benchmarks - return { - 'batch_size': batch_size, - 'q_seqlen': q_seqlen, - 'kv_seqlen': kv_seqlen, - 'num_q_heads': num_q_heads, - 'num_kv_heads': num_kv_heads, - 'head_dim': head_dim, - 'is_causal': attn_mask == "top_left", - 'precision': precision, - 'forward_time': np.inf, - 'backward_time': np.inf, - 'fwd_tflops_per_sec': 0.0, - 'bwd_tflops_per_sec': 0.0, - } - - # Parse output - format is: - # case_tag,backend,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim,fwd_time,bwd_time,fwd_tflops,bwd_tflops,max_diff,num_iters - output_line = result.stdout.strip().split('\n')[-1] # Get last line - parts = output_line.split(',') - - if len(parts) < 12: - if verbose: - print(f" [WARNING] Unexpected output format: {output_line}") - # Return entry with infinite times and 0 TFLOPs for parse failures - return { - 'batch_size': batch_size, - 'q_seqlen': q_seqlen, - 'kv_seqlen': kv_seqlen, - 'num_q_heads': num_q_heads, - 'num_kv_heads': num_kv_heads, - 'head_dim': head_dim, - 'is_causal': attn_mask == "top_left", - 'precision': precision, - 'forward_time': np.inf, - 'backward_time': np.inf, - 'fwd_tflops_per_sec': 0.0, - 'bwd_tflops_per_sec': 0.0, - } - - return { - 'batch_size': int(parts[2]), - 'q_seqlen': int(parts[3]), - 'kv_seqlen': int(parts[4]), - 'num_q_heads': int(parts[5]), - 'num_kv_heads': int(parts[6]), - 'head_dim': int(parts[7]), - 'is_causal': attn_mask == "top_left", - 'precision': precision, - 'forward_time': float(parts[8]), - 'backward_time': float(parts[9]), - 'fwd_tflops_per_sec': float(parts[10]), - 'bwd_tflops_per_sec': float(parts[11]), - } - except Exception as e: - if verbose: - print(f" [ERROR] Failed to run benchmark: {e}") - # Return entry with infinite times and 0 TFLOPs for exceptions - return { - 'batch_size': batch_size, - 'q_seqlen': q_seqlen, - 'kv_seqlen': kv_seqlen, - 'num_q_heads': num_q_heads, - 'num_kv_heads': num_kv_heads, - 'head_dim': head_dim, - 'is_causal': attn_mask == "top_left", - 'precision': precision, - 'forward_time': np.inf, - 'backward_time': np.inf, - 'fwd_tflops_per_sec': 0.0, - 'bwd_tflops_per_sec': 0.0, - } - -###### SDPA Benchmark -- Run ###### -## Print System Info -try: - import cudnn - print(f"[INFO] cuDNN Backend Version: {cudnn.backend_version() = }") - print(f"[INFO] cuDNN Frontend Version: {cudnn.__version__ = }") -except ImportError: - pass - -print(f"[INFO] {torch.__version__ = }") -print(f"[INFO] {torch.version.cuda = }") -print(f"[INFO] {torch.cuda.is_available() = }") -print(f"[INFO] {torch.cuda.device_count() = }") -print(f"[INFO] {torch.cuda.current_device() = }") -print(f"[INFO] {torch.cuda.get_device_name(torch.cuda.current_device()) = }") - -## Begin Benchmark -# Define dataframe to store results -data_df = pd.DataFrame( - columns=[ - "batch_size", - "q_seqlen", - "kv_seqlen", - "num_q_heads", - "num_kv_heads", - "head_dim", - "is_causal", - "precision", - "forward_time", - "backward_time", - "fwd_tflops_per_sec", - "bwd_tflops_per_sec", - ] -) - -if verbose: - print( - f"[INFO] Begin benchmark for layers (batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim)" - ) - print(f"[INFO] {sdpa_configs = }") - -# Iterate over each SDPA config -for sdpa_config in sdpa_configs: - batch_size, q_seqlen, kv_seqlen, num_q_heads, num_kv_heads, head_dim = sdpa_config - if verbose: - print(f"[INFO] Running layer {sdpa_config}") - - # Iterate over each precision - for cur_precision in precisions: - print(f"[INFO] Benchmarking data type {cur_precision}") - - # Run benchmark via subprocess - result = run_single_benchmark(sdpa_config, cur_precision) - - # Append data to table (result is always a dict, never None) - data_df = pd.concat([data_df, pd.DataFrame([result])], ignore_index=True) - -## Save results to a csv file -gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()).replace(" ", "_") -output_file_name = f"./artifacts/sdpa_fp8_benchmark_results_{gpu_name}.csv" -if verbose: - print(f"[INFO] Saving results to {output_file_name}") -try: - data_df.to_csv(output_file_name, float_format="%.3f", index=False) -except Exception as e: - print(f"[ERROR] Failed to save results to {output_file_name}: {e}") - print(f"[INFO] Printing results to console instead") - print(data_df.to_csv(float_format="%.3f", index=False)) - print(f"[INFO] Printing results to console done") - -###### SDPA Benchmark -- Plot ###### -## Generate plots for (num_q_heads=128, num_kv_heads=8, head_dim=128, is_causal=True) -baseline_df = data_df[ - (data_df["precision"] == "bf16") & (data_df["q_seqlen"] >= 4000) -].copy() -baseline_df.drop( - columns=[ - "precision", - "fwd_tflops_per_sec", - "bwd_tflops_per_sec", - ], - inplace=True, -) -baseline_df.rename( - columns={ - "forward_time": "baseline_forward_time", - "backward_time": "baseline_backward_time", - }, - inplace=True, -) - -merged_df = baseline_df.merge( - data_df, - on=[ - "batch_size", - "q_seqlen", - "kv_seqlen", - "num_q_heads", - "num_kv_heads", - "head_dim", - "is_causal", - ], -) -merged_df["fwd_speedup"] = ( - merged_df["baseline_forward_time"] / merged_df["forward_time"] -) -merged_df["bwd_speedup"] = ( - merged_df["baseline_backward_time"] / merged_df["backward_time"] -) - - -# Configurations for bar plots -precision_ordering = {"bf16": 0, "fp8": 1} -precision_name = {"bf16": "BFloat16", "fp8": "FP8"} -precision_barplot_color = { - precision_name["bf16"]: "#76b900", - precision_name["fp8"]: "darkgreen", -} -LABEL_FONT_SIZE = 8 -LEGEND_FONT_SIZE = 6 -TITLE_FONT_SIZE = 9 -# Select desired cases -plot_df = merged_df[ - (merged_df["is_causal"] == True) - & (merged_df["num_q_heads"] == 128) - & (merged_df["num_kv_heads"] == 8) - & (merged_df["q_seqlen"] == merged_df["kv_seqlen"]) - & (merged_df["head_dim"] == 128) -].copy() - -plot_df["precision_rank"] = plot_df["precision"].map(precision_ordering) -plot_df["precision_name"] = plot_df["precision"].map(precision_name) -plot_df.sort_values(["q_seqlen", "precision_rank"], inplace=True) - -# Generate plots: forward on left subplot and backward on right subplot -YLIM_MAX = np.max([plot_df["fwd_speedup"].max(), plot_df["bwd_speedup"].max()]) * 1.1 - -plt.figure(figsize=(10, 4), dpi=200) -plt.subplot(1, 2, 1) -cur_plot_df = plot_df[plot_df.fwd_tflops_per_sec > 0] -ax = sns.barplot( - data=cur_plot_df, - x="q_seqlen", - y="fwd_speedup", - hue="precision_name", - edgecolor="black", - linewidth=0.5, - palette=precision_barplot_color, - width=0.6, -) -ax.legend_.set_title(None) -for container in ax.containers: - ax.bar_label(container, fmt="%.2fx", fontsize=6) -plt.xticks(rotation=45) -plt.xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) -plt.ylabel("Speedup", fontsize=LABEL_FONT_SIZE) -plt.title("SDPA Forward", fontsize=TITLE_FONT_SIZE) -plt.tick_params(axis="y", which="major", labelsize=LABEL_FONT_SIZE) -plt.tick_params(axis="x", which="major", labelsize=LABEL_FONT_SIZE) -plt.ylim(0.5, YLIM_MAX) -plt.legend(fontsize=LEGEND_FONT_SIZE, loc="upper left") - -plt.subplot(1, 2, 2) -cur_plot_df = plot_df[plot_df.bwd_tflops_per_sec > 0] -ax = sns.barplot( - data=cur_plot_df, - x="q_seqlen", - y="bwd_speedup", - hue="precision_name", - edgecolor="black", - linewidth=0.5, - palette=precision_barplot_color, - width=0.6, -) -ax.legend_.set_title(None) -for container in ax.containers: - ax.bar_label(container, fmt="%.2fx", fontsize=6) -plt.xticks(rotation=45) -plt.xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) -plt.ylabel("Speedup", fontsize=LABEL_FONT_SIZE) -plt.title("SDPA Backward", fontsize=TITLE_FONT_SIZE) -plt.tick_params(axis="y", which="major", labelsize=LABEL_FONT_SIZE) -plt.tick_params(axis="x", which="major", labelsize=LABEL_FONT_SIZE) -plt.ylim(0.5, YLIM_MAX) -plt.legend(fontsize=LEGEND_FONT_SIZE, loc="upper left") - -# Save plot -plt.tight_layout() -png_file_name = f"./artifacts/sdpa_fp8_benchmark_results_{gpu_name}.png" -if verbose: - print(f"[INFO] Saving plot to {png_file_name}") -try: - plt.savefig(png_file_name) -except Exception as e: - print(f"[ERROR] Failed to save plot to {png_file_name}: {e}") diff --git a/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py b/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py index 8c167ecb..b9d83810 100644 --- a/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py +++ b/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py @@ -4,6 +4,14 @@ This script benchmarks a single SDPA compute instance. The SDPA backend can be chosen. Performance is measured using torch profiler. +Can be used as CLI or imported as a module: + + # CLI usage + python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 8192 ... + + # Module usage + from benchmark_single_sdpa import run_benchmark + result = run_benchmark(batch_size=1, q_seqlen=8192, ...) """ import argparse @@ -15,941 +23,653 @@ import functools import time import math +from typing import Optional, Dict, Any from torch.profiler import profile, record_function, ProfilerActivity -###### SDPA Benchmark -- Parse input arguments ###### -parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("--batch_size", default=1, type=int, help="Batch size to input to the layer") -parser.add_argument("--q_seqlen", default=8192, type=int, help="Sequence length to input to the layer") -parser.add_argument("--kv_seqlen", default=8192, type=int, help="Sequence length to input to the layer") -parser.add_argument( - "--num_q_heads", - default=16, - type=int, - help="Number of query heads to input to the layer", -) -parser.add_argument( - "--num_kv_heads", - default=8, - type=int, - help="Number of key/value heads to input to the layer", -) -parser.add_argument("--head_dim", default=128, type=int, help="Head dimension to input to the layer") -parser.add_argument( - "--head_dim_qk", - default=None, - type=int, - help="Optional: head dimension for Q/K. If set, must also set --head_dim_vo", -) -parser.add_argument( - "--head_dim_vo", - default=None, - type=int, - help="Optional: head dimension for V/O. If set, must also set --head_dim_qk", -) -parser.add_argument( - "--data_type", - default="bfloat16", - type=str, - help="Data type to input to the layer. Can be bfloat16, float16, or fp8", -) -parser.add_argument( - "--num_iterations", - default=20, - type=int, - help="Number of iterations to run the layer for performance measurement", -) -parser.add_argument( - "--num_warmup_iterations", - default=0, - type=int, - help="Number of warmup iterations to run before measuring performance", -) -parser.add_argument("--verbose", action="store_true", help="Verbose output") -parser.add_argument( - "--fwd_bwd", - action="store_true", - help="Run both forward and backward pass (fwd only by default)", -) -parser.add_argument( - "--attn_mask", - default="no_mask", - type=str, - help="Attn mask to use. Can be 'top_left', 'bottom_right', or 'no_mask'.", - choices=["top_left", "bottom_right", "no_mask"], -) -parser.add_argument( - "--sdpa_backend", - default="pyt_cudnn", - type=str, - help="SDPA backend to use", - choices=[ - "pyt_math", - "pyt_cudnn", - "pyt_efficient_attention", - "pyt_flash_attention", - "flash_attention", - "flash_attention_3", - "flash_attention_4", - "cudnn_fe", - ], -) -parser.add_argument("--format_output", action="store_true", help="Format output to be used in benchmark") -parser.add_argument( - "--case_tag", - default="", - type=str, - help="Tag to identify the case. Not used in calculations. Only for formatted output", -) -# skip ref -parser.add_argument( - "--skip_ref", - action="store_true", - help="Skip reference SDPA implementation", -) - -args = parser.parse_args() - -if args.data_type == "bfloat16": - target_dtype = torch.bfloat16 -elif args.data_type == "float16": - target_dtype = torch.float16 -elif args.data_type == "float": - target_dtype = torch.float -elif args.data_type == "fp8": - target_dtype = None -else: - raise ValueError(f"Invalid data type: {args.data_type}") - -if args.data_type == "fp8": - if args.sdpa_backend not in ["cudnn_fe", "flash_attention_3"]: - raise ValueError(f"FP8 is only supported for cudnn_fe and flash_attention_3 backends") - -# Parse input arguments -num_iters = args.num_iterations -dry_run_iters = args.num_warmup_iterations -batch_size = args.batch_size -q_seqlen = args.q_seqlen -kv_seqlen = args.kv_seqlen -num_q_heads = args.num_q_heads -num_kv_heads = args.num_kv_heads -if args.head_dim_qk is None and args.head_dim_vo is None: - head_dim_qk = args.head_dim - head_dim_vo = args.head_dim -elif args.head_dim_qk is not None and args.head_dim_vo is not None: - head_dim_qk = args.head_dim_qk - head_dim_vo = args.head_dim_vo -else: - raise ValueError("Both --head_dim_qk and --head_dim_vo must be provided together when using asymmetric head dims.") -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -assert device.type == "cuda", "Requires CUDA device" -enable_gqa = num_q_heads != num_kv_heads -assert ( - args.attn_mask != "bottom_right" or q_seqlen <= kv_seqlen -), "Bottom right causal mask not supported when q_seqlen > kv_seqlen" -# if args.sdpa_backend in ["flash_attention", "flash_attention_3", "pyt_flash_attention"]: -# assert args.attn_mask != "top_left", "Flash Attention does not support top left causal mask" - -l2_flush_size_mb = 256 -l2_flush_size = l2_flush_size_mb * 1024 * 1024 -l2_flush_buffer = torch.empty(l2_flush_size, device=device, dtype=torch.int8) - -############################################################# -########### Set up SDPA function for each backend ########### - -## If using cuDNN FE, set up cuDNN graph. -if args.sdpa_backend == "cudnn_fe": - is_dropout = False # Hard coded - dropout_prob = dropout_p if is_dropout else 0.0 # Hard coded to 0 - is_infer = False # Hard coded - attn_scale = head_dim_qk ** (-0.5) - - try: - import cudnn - except ImportError: - cudnn = None - assert cudnn is not None - - if args.verbose: - print(f"[INFO] cuDNN Backend Version: {cudnn.backend_version() = }") - print(f"[INFO] cuDNN Frontend Version: {cudnn.__version__ = }") - - # Helper function: Convert torch type to cuDNN type - def convert_to_cudnn_type(torch_type): - if torch_type == torch.float16: - return cudnn.data_type.HALF - elif torch_type == torch.bfloat16: - return cudnn.data_type.BFLOAT16 - elif torch_type == torch.float32: - return cudnn.data_type.FLOAT - elif torch_type == torch.int32: - return cudnn.data_type.INT32 - elif torch_type == torch.int64: - return cudnn.data_type.INT64 - else: - raise ValueError("Unsupported tensor data type.") - ## Will define tensors to set up cuDNN graph once. - if args.data_type == "fp8": - query = torch.randint( - 256, - (batch_size, q_seqlen, num_q_heads, head_dim_qk), - dtype=torch.uint8, - device=device, - ).transpose(1, 2) - key = torch.randint( - 256, - (batch_size, kv_seqlen, num_kv_heads, head_dim_qk), - dtype=torch.uint8, - device=device, - ).transpose(1, 2) - value = torch.randint( - 256, - (batch_size, kv_seqlen, num_kv_heads, head_dim_vo), - dtype=torch.uint8, - device=device, - ).transpose(1, 2) - output = torch.empty( - batch_size, - q_seqlen, - num_q_heads, - head_dim_vo, - dtype=torch.uint8, - device=device, - ).transpose(1, 2) - - descale_q_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_k_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_v_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_dO_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_dQ_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_dK_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_dV_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - amax_s_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_o_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_dQ_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_dK_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_dV_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_dP_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - else: - query = torch.randn( - batch_size, - q_seqlen, - num_q_heads, - head_dim_qk, - dtype=target_dtype, - device=device, - ).transpose(1, 2) - key = torch.randn( - batch_size, - kv_seqlen, - num_kv_heads, - head_dim_qk, - dtype=target_dtype, - device=device, - ).transpose(1, 2) - value = torch.randn( - batch_size, - kv_seqlen, - num_kv_heads, - head_dim_vo, - dtype=target_dtype, - device=device, - ).transpose(1, 2) - output = torch.empty( - batch_size, - q_seqlen, - num_q_heads, - head_dim_vo, - dtype=target_dtype, - device=device, - ).transpose(1, 2) +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--batch_size", default=1, type=int, help="Batch size to input to the layer") + parser.add_argument("--q_seqlen", default=8192, type=int, help="Sequence length to input to the layer") + parser.add_argument("--kv_seqlen", default=8192, type=int, help="Sequence length to input to the layer") + parser.add_argument( + "--num_q_heads", + default=16, + type=int, + help="Number of query heads to input to the layer", + ) + parser.add_argument( + "--num_kv_heads", + default=8, + type=int, + help="Number of key/value heads to input to the layer", + ) + parser.add_argument("--head_dim", default=128, type=int, help="Head dimension to input to the layer") + parser.add_argument( + "--head_dim_qk", + default=None, + type=int, + help="Optional: head dimension for Q/K. If set, must also set --head_dim_vo", + ) + parser.add_argument( + "--head_dim_vo", + default=None, + type=int, + help="Optional: head dimension for V/O. If set, must also set --head_dim_qk", + ) + parser.add_argument( + "--data_type", + default="bfloat16", + type=str, + help="Data type to input to the layer. Can be bfloat16, float16, or fp8", + ) + parser.add_argument( + "--num_iterations", + default=20, + type=int, + help="Number of iterations to run the layer for performance measurement", + ) + parser.add_argument( + "--num_warmup_iterations", + default=0, + type=int, + help="Number of warmup iterations to run before measuring performance", + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--fwd_bwd", + action="store_true", + help="Run both forward and backward pass (fwd only by default)", + ) + parser.add_argument( + "--profile_pass", + default=None, + type=str, + choices=["fwd", "bwd", "both"], + help="Which pass to profile (default: fwd unless --fwd_bwd is set).", + ) + parser.add_argument( + "--deterministic_bwd", + action="store_true", + help="Use deterministic algorithm for backward pass where supported (cudnn FP16/BF16/FP8)", + ) + parser.add_argument( + "--attn_mask", + default="no_mask", + type=str, + help="Attn mask to use. Can be 'top_left', 'bottom_right', or 'no_mask'.", + choices=["top_left", "bottom_right", "no_mask"], + ) + parser.add_argument( + "--sdpa_backend", + default="pyt_cudnn", + type=str, + help="SDPA backend to use", + choices=[ + "pyt_math", + "pyt_cudnn", + "pyt_efficient_attention", + "pyt_flash_attention", + "flash_attention", + "flash_attention_3", + "flash_attention_4", + "cudnn", + ], + ) + parser.add_argument("--format_output", action="store_true", help="Format output to be used in benchmark") + parser.add_argument( + "--case_tag", + default="", + type=str, + help="Tag to identify the case. Not used in calculations. Only for formatted output", + ) + parser.add_argument( + "--skip_ref", + action="store_true", + help="Skip reference SDPA implementation", + ) + return parser.parse_args() + + +def run_benchmark( + batch_size: int, + q_seqlen: int, + kv_seqlen: int, + num_q_heads: int, + num_kv_heads: int, + head_dim: int = 128, + head_dim_qk: Optional[int] = None, + head_dim_vo: Optional[int] = None, + data_type: str = "bfloat16", + backend: str = "cudnn", + attn_mask: str = "no_mask", + profile_pass: str = "fwd", + num_iterations: int = 10, + num_warmup_iterations: int = 0, + skip_ref: bool = True, + deterministic_bwd: bool = False, + verbose: bool = False, +) -> Dict[str, Any]: + """ + Run a single SDPA benchmark. + + This function can be called directly when using the module as a library. + Internally uses subprocess to call this script with the appropriate arguments. + + Args: + batch_size: Batch size + q_seqlen: Query sequence length + kv_seqlen: Key/value sequence length + num_q_heads: Number of query heads + num_kv_heads: Number of key/value heads + head_dim: Head dimension (used if head_dim_qk/vo not specified) + head_dim_qk: Head dimension for Q/K (optional, for asymmetric) + head_dim_vo: Head dimension for V/O (optional, for asymmetric) + data_type: Data type ("bfloat16", "float16", "fp8") + backend: Backend name ("cudnn", "flash_attention_4", etc.) + attn_mask: Attention mask ("no_mask", "top_left", "bottom_right") + profile_pass: Which pass to profile ("fwd", "bwd", "both") + num_iterations: Number of benchmark iterations + num_warmup_iterations: Warmup iterations before measurement + skip_ref: Skip reference validation + deterministic_bwd: Use deterministic backward algorithm + verbose: Print verbose output + + Returns: + Dict with keys: + - fwd_time_ms: Median forward time in milliseconds + - bwd_time_ms: Median backward time in milliseconds (0 if not run) + - fwd_tflops: Forward TFLOPS + - bwd_tflops: Backward TFLOPS + - max_diff: Maximum difference vs reference + - gpu_name: GPU name string + - cudnn_version: cuDNN version (if available) + + Raises: + RuntimeError: If the benchmark subprocess fails + """ + import subprocess + import sys + + # Build command + script_path = os.path.abspath(__file__) + cmd = [ + sys.executable, + script_path, + "--batch_size", + str(batch_size), + "--q_seqlen", + str(q_seqlen), + "--kv_seqlen", + str(kv_seqlen), + "--num_q_heads", + str(num_q_heads), + "--num_kv_heads", + str(num_kv_heads), + "--data_type", + data_type, + "--sdpa_backend", + backend, + "--attn_mask", + attn_mask, + "--num_iterations", + str(num_iterations), + "--num_warmup_iterations", + str(num_warmup_iterations), + "--format_output", # Get CSV-formatted output for parsing + ] - dQuery = torch.empty_like(query) - dKey = torch.empty_like(key) - dValue = torch.empty_like(value) - if args.data_type == "fp8": - # Create as bfloat16, convert to FP8, then view as uint8 to avoid DLPack issues - dOutput_bf16 = torch.randn(output.shape, dtype=torch.bfloat16, device=device) - dOutput_fp8 = dOutput_bf16.to(torch.float8_e4m3fn) - dOutput = dOutput_fp8.view(torch.uint8) + # Handle head dimensions + if head_dim_qk is not None and head_dim_vo is not None: + cmd.extend(["--head_dim_qk", str(head_dim_qk)]) + cmd.extend(["--head_dim_vo", str(head_dim_vo)]) else: - dOutput = torch.randn_like(output) - stats = torch.randn(batch_size, q_seqlen, num_q_heads, 1, dtype=torch.float32, device=device).transpose(1, 2) - if is_dropout: - dropout_seed = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") - dropout_offset = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") - - # cuDNN graph forward - graph_fwd = cudnn.pygraph( - io_data_type=(cudnn.data_type.FP8_E4M3 if args.data_type == "fp8" else convert_to_cudnn_type(target_dtype)), - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, + cmd.extend(["--head_dim", str(head_dim)]) + + # Handle profile pass + if profile_pass == "both": + cmd.append("--fwd_bwd") + elif profile_pass in ("fwd", "bwd"): + cmd.extend(["--profile_pass", profile_pass]) + + # Handle flags + if skip_ref: + cmd.append("--skip_ref") + if deterministic_bwd: + cmd.append("--deterministic_bwd") + if verbose: + cmd.append("--verbose") + + # Run benchmark + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False, ) - if is_dropout: - seed_fwd = graph_fwd.tensor_like(dropout_seed) - offset_fwd = graph_fwd.tensor_like(dropout_offset) - dropout_tuple = (dropout_prob, seed_fwd, offset_fwd) + if result.returncode != 0: + raise RuntimeError(f"Benchmark failed with return code {result.returncode}.\n" f"stderr: {result.stderr}\n" f"stdout: {result.stdout}") - if args.data_type == "fp8": - q_fwd = graph_fwd.tensor_like(query).set_data_type(cudnn.data_type.FP8_E4M3) - k_fwd = graph_fwd.tensor_like(key).set_data_type(cudnn.data_type.FP8_E4M3) - v_fwd = graph_fwd.tensor_like(value).set_data_type(cudnn.data_type.FP8_E4M3) - - descale_q_fwd = graph_fwd.tensor_like(descale_q_gpu) - descale_k_fwd = graph_fwd.tensor_like(descale_k_gpu) - descale_v_fwd = graph_fwd.tensor_like(descale_v_gpu) - descale_s_fwd = graph_fwd.tensor_like(descale_s_gpu) - scale_s_fwd = graph_fwd.tensor_like(scale_s_gpu) - scale_o_fwd = graph_fwd.tensor_like(scale_o_gpu) - - o_fwd, stats_fwd, amax_s_fwd, amax_o_fwd = graph_fwd.sdpa_fp8( - q=q_fwd, - k=k_fwd, - v=v_fwd, - descale_q=descale_q_fwd, - descale_k=descale_k_fwd, - descale_v=descale_v_fwd, - descale_s=descale_s_fwd, - scale_s=scale_s_fwd, - scale_o=scale_o_fwd, - # generate_stats=not is_infer, - is_inference=is_infer, - attn_scale=attn_scale, - diagonal_alignment=( - cudnn.diagonal_alignment.BOTTOM_RIGHT - if args.attn_mask == "bottom_right" - else cudnn.diagonal_alignment.TOP_LEFT - ), - right_bound=None if args.attn_mask == "no_mask" else 0, - # dropout=dropout_tuple if is_dropout else None, - ) - else: - q_fwd = graph_fwd.tensor_like(query) - k_fwd = graph_fwd.tensor_like(key) - v_fwd = graph_fwd.tensor_like(value) - o_fwd, stats_fwd = graph_fwd.sdpa( - q=q_fwd, - k=k_fwd, - v=v_fwd, - # generate_stats=not is_infer, - is_inference=is_infer, - attn_scale=attn_scale, - diagonal_alignment=( - cudnn.diagonal_alignment.BOTTOM_RIGHT - if args.attn_mask == "bottom_right" - else cudnn.diagonal_alignment.TOP_LEFT - ), - diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0, - dropout=dropout_tuple if is_dropout else None, - ) + # Parse CSV output + # Format: case_tag,backend,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim,fwd_time,bwd_time,fwd_tflops,bwd_tflops,max_diff,num_iters + output_line = result.stdout.strip().split("\n")[-1] + parts = output_line.split(",") - if args.fwd_bwd: - if args.data_type == "fp8": - o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()).set_data_type( - cudnn.data_type.FP8_E4M3 - ) - ( - stats_fwd.set_output(True) - .set_dim(stats.size()) - .set_stride(stats.stride()) - .set_data_type(cudnn.data_type.FLOAT) - if not is_infer - else None - ) - else: - o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()) - ( - stats_fwd.set_output(True) - .set_dim(stats.size()) - .set_stride(stats.stride()) - .set_data_type(cudnn.data_type.FLOAT) - if not is_infer - else None - ) - else: - if args.data_type == "fp8": - o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()).set_data_type( - cudnn.data_type.FP8_E4M3 - ) - else: - o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()) + if len(parts) < 12: + raise RuntimeError(f"Unexpected output format: {output_line}") - if args.data_type == "fp8": - amax_s_fwd.set_output(True).set_dim(amax_s_gpu.size()).set_stride(amax_s_gpu.stride()).set_data_type( - cudnn.data_type.FLOAT - ) - amax_o_fwd.set_output(True).set_dim(amax_o_gpu.size()).set_stride(amax_o_gpu.stride()).set_data_type( - cudnn.data_type.FLOAT - ) - graph_fwd.validate() - graph_fwd.build_operation_graph() - graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph_fwd.check_support() - graph_fwd.build_plans() - - # If backward is requested, set up backward graph. - if args.fwd_bwd: - graph_bwd = cudnn.pygraph( - io_data_type=(cudnn.data_type.FP8_E4M3 if args.data_type == "fp8" else convert_to_cudnn_type(target_dtype)), - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - ) + # Get GPU name from torch + gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()) if torch.cuda.is_available() else "Unknown" - stats_bwd = graph_bwd.tensor_like(stats) - if is_dropout: - seed_bwd = graph_bwd.tensor_like(dropout_seed) - offset_bwd = graph_bwd.tensor_like(dropout_offset) - dropout_tuple = (dropout_prob, seed_bwd, offset_bwd) + # Try to get cudnn version + cudnn_version = None + cudnn_backend_version = None + try: + import cudnn - if args.data_type == "fp8": - q_bwd = graph_bwd.tensor_like(query).set_data_type(cudnn.data_type.FP8_E4M3) - k_bwd = graph_bwd.tensor_like(key).set_data_type(cudnn.data_type.FP8_E4M3) - v_bwd = graph_bwd.tensor_like(value).set_data_type(cudnn.data_type.FP8_E4M3) - o_bwd = graph_bwd.tensor_like(output).set_data_type(cudnn.data_type.FP8_E4M3) - dO_bwd = graph_bwd.tensor_like(dOutput).set_data_type(cudnn.data_type.FP8_E4M3) - - descale_q_bwd = graph_bwd.tensor_like(descale_q_gpu) - descale_k_bwd = graph_bwd.tensor_like(descale_k_gpu) - descale_v_bwd = graph_bwd.tensor_like(descale_v_gpu) - descale_o_bwd = graph_bwd.tensor_like(descale_o_gpu) - descale_dO_bwd = graph_bwd.tensor_like(descale_dO_gpu) - descale_s_bwd = graph_bwd.tensor_like(descale_s_gpu) - descale_dP_bwd = graph_bwd.tensor_like(descale_dP_gpu) - scale_s_bwd = graph_bwd.tensor_like(scale_s_gpu) - scale_dQ_bwd = graph_bwd.tensor_like(scale_dQ_gpu) - scale_dK_bwd = graph_bwd.tensor_like(scale_dK_gpu) - scale_dV_bwd = graph_bwd.tensor_like(scale_dV_gpu) - scale_dP_bwd = graph_bwd.tensor_like(scale_dP_gpu) - - ( - dQ_bwd, - dK_bwd, - dV_bwd, - amax_dQ_bwd, - amax_dK_bwd, - amax_dV_bwd, - amax_dP_bwd, - ) = graph_bwd.sdpa_fp8_backward( - q=q_bwd, - k=k_bwd, - v=v_bwd, - o=o_bwd, - dO=dO_bwd, - stats=stats_bwd, - descale_q=descale_q_bwd, - descale_k=descale_k_bwd, - descale_v=descale_v_bwd, - descale_o=descale_o_bwd, - descale_dO=descale_dO_bwd, - descale_s=descale_s_bwd, - descale_dP=descale_dP_bwd, - scale_s=scale_s_bwd, - scale_dQ=scale_dQ_bwd, - scale_dK=scale_dK_bwd, - scale_dV=scale_dV_bwd, - scale_dP=scale_dP_bwd, - attn_scale=attn_scale, - use_causal_mask=args.attn_mask != "no_mask" and args.attn_mask != "bottom_right", - use_causal_mask_bottom_right=args.attn_mask == "bottom_right", - dropout=dropout_tuple if is_dropout else None, - ) - else: - q_bwd = graph_bwd.tensor_like(query) - k_bwd = graph_bwd.tensor_like(key) - v_bwd = graph_bwd.tensor_like(value) - o_bwd = graph_bwd.tensor_like(output) - dO_bwd = graph_bwd.tensor_like(dOutput) - - dQ_bwd, dK_bwd, dV_bwd = graph_bwd.sdpa_backward( - q=q_bwd, - k=k_bwd, - v=v_bwd, - o=o_bwd, - dO=dO_bwd, - stats=stats_bwd, - attn_scale=attn_scale, - diagonal_alignment=( - cudnn.diagonal_alignment.BOTTOM_RIGHT - if args.attn_mask == "bottom_right" - else cudnn.diagonal_alignment.TOP_LEFT - ), - diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0, - dropout=dropout_tuple if is_dropout else None, - ) + cudnn_version = cudnn.__version__ + cudnn_backend_version = cudnn.backend_version() + except ImportError: + pass + + return { + "fwd_time_ms": float(parts[8]), + "bwd_time_ms": float(parts[9]), + "fwd_tflops": float(parts[10]), + "bwd_tflops": float(parts[11]), + "max_diff": float(parts[12]) if len(parts) > 12 else 0.0, + "gpu_name": gpu_name, + "cudnn_version": cudnn_version, + "cudnn_backend_version": cudnn_backend_version, + } + + +# ============================================================================ +# Main benchmark implementation (runs when script is executed directly) +# ============================================================================ + +# Note: All code below this point is only executed when running as a script. +# When imported as a module, use the run_benchmark() function above. + +if __name__ != "__main__": + # Stop here when imported as module + pass +else: + # Parse command line arguments + args = parse_args() + + if args.data_type == "bfloat16": + target_dtype = torch.bfloat16 + elif args.data_type == "float16": + target_dtype = torch.float16 + elif args.data_type == "float": + target_dtype = torch.float + elif args.data_type == "fp8": + target_dtype = None + else: + raise ValueError(f"Invalid data type: {args.data_type}") - if args.data_type == "fp8": - dQ_bwd.set_output(True).set_dim(dQuery.size()).set_stride(dQuery.stride()).set_data_type( - cudnn.data_type.FP8_E4M3 - ) - dK_bwd.set_output(True).set_dim(dKey.size()).set_stride(dKey.stride()).set_data_type( - cudnn.data_type.FP8_E4M3 - ) - dV_bwd.set_output(True).set_dim(dValue.size()).set_stride(dValue.stride()).set_data_type( - cudnn.data_type.FP8_E4M3 - ) - amax_dQ_bwd.set_output(True).set_dim(amax_dQ_gpu.size()).set_stride(amax_dQ_gpu.stride()).set_data_type( - cudnn.data_type.FLOAT - ) - amax_dK_bwd.set_output(True).set_dim(amax_dK_gpu.size()).set_stride(amax_dK_gpu.stride()).set_data_type( - cudnn.data_type.FLOAT - ) - amax_dV_bwd.set_output(True).set_dim(amax_dV_gpu.size()).set_stride(amax_dV_gpu.stride()).set_data_type( - cudnn.data_type.FLOAT - ) - amax_dP_bwd.set_output(True).set_dim(amax_dP_gpu.size()).set_stride(amax_dP_gpu.stride()).set_data_type( - cudnn.data_type.FLOAT - ) - else: - dQ_bwd.set_output(True).set_dim(dQuery.size()).set_stride(dQuery.stride()) - dK_bwd.set_output(True).set_dim(dKey.size()).set_stride(dKey.stride()) - dV_bwd.set_output(True).set_dim(dValue.size()).set_stride(dValue.stride()) + if args.data_type == "fp8": + if args.sdpa_backend not in ["cudnn", "flash_attention_3"]: + raise ValueError(f"FP8 is only supported for cudnn and flash_attention_3 backends") + + # Parse input arguments + num_iters = args.num_iterations + dry_run_iters = args.num_warmup_iterations + batch_size = args.batch_size + q_seqlen = args.q_seqlen + kv_seqlen = args.kv_seqlen + num_q_heads = args.num_q_heads + num_kv_heads = args.num_kv_heads + if args.head_dim_qk is None and args.head_dim_vo is None: + head_dim_qk = args.head_dim + head_dim_vo = args.head_dim + elif args.head_dim_qk is not None and args.head_dim_vo is not None: + head_dim_qk = args.head_dim_qk + head_dim_vo = args.head_dim_vo + else: + raise ValueError("Both --head_dim_qk and --head_dim_vo must be provided together when using asymmetric head dims.") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + assert device.type == "cuda", "Requires CUDA device" + if args.profile_pass is not None: + run_fwd = args.profile_pass in ("fwd", "both") + run_bwd = args.profile_pass in ("bwd", "both") + elif args.fwd_bwd: + run_fwd = True + run_bwd = True + else: + run_fwd = True + run_bwd = False + enable_gqa = num_q_heads != num_kv_heads + assert args.attn_mask != "bottom_right" or q_seqlen <= kv_seqlen, "Bottom right causal mask not supported when q_seqlen > kv_seqlen" + # if args.sdpa_backend in ["flash_attention", "flash_attention_3", "pyt_flash_attention"]: + # assert args.attn_mask != "top_left", "Flash Attention does not support top left causal mask" + + l2_flush_size_mb = 256 + l2_flush_size = l2_flush_size_mb * 1024 * 1024 + l2_flush_buffer = torch.empty(l2_flush_size, device=device, dtype=torch.int8) + + ############################################################# + ########### Set up SDPA function for each backend ########### + + ## If using cuDNN FE, set up cuDNN graph. + if args.sdpa_backend == "cudnn": + is_dropout = False # Hard coded + dropout_prob = dropout_p if is_dropout else 0.0 # Hard coded to 0 + is_infer = False # Hard coded + attn_scale = head_dim_qk ** (-0.5) - graph_bwd.validate() - graph_bwd.build_operation_graph() - graph_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph_bwd.check_support() - graph_bwd.build_plans() + try: + import cudnn + except ImportError: + cudnn = None + assert cudnn is not None + + if args.verbose: + print(f"[INFO] cuDNN Backend Version: {cudnn.backend_version() = }") + print(f"[INFO] cuDNN Frontend Version: {cudnn.__version__ = }") + + # Helper function: Convert torch type to cuDNN type + def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + ## Will define tensors to set up cuDNN graph once. if args.data_type == "fp8": - variant_pack_fwd = { - q_fwd: query, - k_fwd: key, - v_fwd: value, - o_fwd: output, - stats_fwd: stats, - descale_q_fwd: descale_q_gpu, - descale_k_fwd: descale_k_gpu, - descale_v_fwd: descale_v_gpu, - descale_s_fwd: descale_s_gpu, - scale_s_fwd: scale_s_gpu, - scale_o_fwd: scale_o_gpu, - amax_s_fwd: amax_s_gpu, - amax_o_fwd: amax_o_gpu, - } - - variant_pack_bwd = { - q_fwd: query, - k_fwd: key, - v_fwd: value, - o_fwd: output, - dQ_bwd: dQuery, - dK_bwd: dKey, - dV_bwd: dValue, - dO_bwd: dOutput, - stats_bwd: stats, - descale_q_bwd: descale_q_gpu, - descale_k_bwd: descale_k_gpu, - descale_v_bwd: descale_v_gpu, - descale_o_bwd: descale_o_gpu, - descale_s_bwd: descale_s_gpu, - descale_dP_bwd: descale_dP_gpu, - descale_dO_bwd: descale_dO_gpu, - scale_s_bwd: scale_s_gpu, - scale_dQ_bwd: scale_dQ_gpu, - scale_dK_bwd: scale_dK_gpu, - scale_dV_bwd: scale_dV_gpu, - scale_dP_bwd: scale_dP_gpu, - amax_dQ_bwd: amax_dQ_gpu, - amax_dK_bwd: amax_dK_gpu, - amax_dV_bwd: amax_dV_gpu, - amax_dP_bwd: amax_dP_gpu, - } - - workspace = torch.empty( - max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), - device="cuda", + query = torch.randint( + 256, + (batch_size, q_seqlen, num_q_heads, head_dim_qk), dtype=torch.uint8, - ) - else: - variant_pack_fwd = { - q_fwd: query, - k_fwd: key, - v_fwd: value, - o_fwd: output, - stats_fwd: stats, - } - variant_pack_bwd = { - q_bwd: query, - k_bwd: key, - v_bwd: value, - o_bwd: output, - dO_bwd: dOutput, - stats_bwd: stats, - dQ_bwd: dQuery, - dK_bwd: dKey, - dV_bwd: dValue, - } - workspace = torch.empty( - max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), - device="cuda", + device=device, + ).transpose(1, 2) + key = torch.randint( + 256, + (batch_size, kv_seqlen, num_kv_heads, head_dim_qk), dtype=torch.uint8, - ) - else: - if args.data_type == "fp8": - variant_pack_fwd = { - q_fwd: query, - k_fwd: key, - v_fwd: value, - o_fwd: output, - stats_fwd: stats, - descale_q_fwd: descale_q_gpu, - descale_k_fwd: descale_k_gpu, - descale_v_fwd: descale_v_gpu, - descale_s_fwd: descale_s_gpu, - scale_s_fwd: scale_s_gpu, - scale_o_fwd: scale_o_gpu, - amax_s_fwd: amax_s_gpu, - amax_o_fwd: amax_o_gpu, - } - workspace = torch.empty(graph_fwd.get_workspace_size(), device="cuda", dtype=torch.uint8) + device=device, + ).transpose(1, 2) + value = torch.randint( + 256, + (batch_size, kv_seqlen, num_kv_heads, head_dim_vo), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + output = torch.empty( + batch_size, + q_seqlen, + num_q_heads, + head_dim_vo, + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + + descale_q_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_k_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_v_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_dO_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dQ_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dK_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dV_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + amax_s_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_o_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dQ_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dK_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dV_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dP_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) else: - variant_pack_fwd = { - q_fwd: query, - k_fwd: key, - v_fwd: value, - o_fwd: output, - } - workspace = torch.empty(graph_fwd.get_workspace_size(), device="cuda", dtype=torch.uint8) - if is_dropout: - variant_pack_fwd[seed_fwd] = dropout_seed - variant_pack_fwd[offset_fwd] = dropout_offset - variant_pack_bwd[seed_bwd] = dropout_seed - variant_pack_bwd[offset_bwd] = dropout_offset -## Done setting up cuDNN graph. - - -# For backends MATH, EFFICIENT_ATTENTION, CUDNN_ATTENTION, PYTORCH_FLASH_ATTENTION -def pyt_backend_sdpa(query, key, value, backend): - with sdpa_kernel(backends=[backend]): - return torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - enable_gqa=enable_gqa, - is_causal=args.attn_mask == "top_left", - attn_mask=causal_lower_right(q_seqlen, kv_seqlen) if args.attn_mask == "bottom_right" else None, - ) - - -if args.sdpa_backend == "flash_attention": - import flash_attn - from flash_attn import flash_attn_func - - # Flash Attention Native - def flash_attention_sdpa(query, key, value): - return flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask") - - -if args.sdpa_backend == "flash_attention_3": - import flash_attn_interface - - def flash_attention_3_sdpa(query, key, value): - output, _ = flash_attn_interface.flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask") - return output - - -if args.sdpa_backend == "flash_attention_4" or (not args.skip_ref): - import flash_attn.cute.interface as flash_attn_interface - - def flash_attention_4_sdpa(query, key, value): - output, _ = flash_attn_interface.flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask") - return output - - -def get_sdpa_function(backend): - if backend == "pyt_math": - return functools.partial(pyt_backend_sdpa, backend=SDPBackend.MATH) - elif backend == "pyt_efficient_attention": - return functools.partial(pyt_backend_sdpa, backend=SDPBackend.EFFICIENT_ATTENTION) - elif backend == "pyt_flash_attention": - return functools.partial(pyt_backend_sdpa, backend=SDPBackend.FLASH_ATTENTION) - elif backend == "pyt_cudnn": - return functools.partial(pyt_backend_sdpa, backend=SDPBackend.CUDNN_ATTENTION) - elif backend == "flash_attention": - return flash_attention_sdpa - elif backend == "flash_attention_3": - return flash_attention_3_sdpa - elif backend == "flash_attention_4": - return flash_attention_4_sdpa - elif backend == "cudnn_fe": - return None # Will be set up separately - else: - raise ValueError(f"Invalid backend: {backend}") - - -# Util function for addressing different qkv formats for each backend -def preprocess_qkv(query, key, value, backend): - if backend.startswith("pyt_") or backend == "cudnn_fe": - return query, key, value - elif backend.startswith("flash_attention"): - query = torch.swapaxes(query, 1, 2) - key = torch.swapaxes(key, 1, 2) - value = torch.swapaxes(value, 1, 2) - return query, key, value - else: - raise ValueError(f"Invalid backend: {backend}") - - -# Util function addressing different qkvo formats for each backend -def postprocess_qkvo(query, key, value, output, backend): - if backend.startswith("pyt_") or backend == "cudnn_fe": - return query, key, value, output - elif backend.startswith("flash_attention"): - output = torch.swapaxes(output, 1, 2) - query = torch.swapaxes(query, 1, 2) - key = torch.swapaxes(key, 1, 2) - value = torch.swapaxes(value, 1, 2) - return query, key, value, output - else: - raise ValueError(f"Invalid backend: {backend}") - - -def postprocess_dqdkdvdo(dQuery, dKey, dValue, dOutput, backend): - if backend.startswith("pyt_") or backend == "cudnn_fe": - return dQuery, dKey, dValue, dOutput - elif backend.startswith("flash_attention"): - dQuery = torch.swapaxes(dQuery, 1, 2) - dKey = torch.swapaxes(dKey, 1, 2) - dValue = torch.swapaxes(dValue, 1, 2) - dOutput = torch.swapaxes(dOutput, 1, 2) - return dQuery, dKey, dValue, dOutput - else: - raise ValueError(f"Invalid backend: {backend}") - - -# Util functions for calculating flops and tflops/s achieved -def flops( - batch_size, - q_seqlen, - kv_seqlen, - head_dim_qk, - head_dim_vo, - num_q_heads, - attn_mask, - mode="fwd", -): - assert mode in ["fwd", "bwd", "fwd_bwd"] - - if attn_mask == "no_mask": - num_nonmasked_elems = q_seqlen * kv_seqlen - elif attn_mask == "top_left": - num_nonmasked_elems = torch.tril(torch.ones((q_seqlen, kv_seqlen), dtype=torch.bool)).sum() - elif attn_mask == "bottom_right": - diagonal_offset = kv_seqlen - q_seqlen - num_nonmasked_elems = torch.tril( - torch.ones((q_seqlen, kv_seqlen), dtype=torch.bool), - diagonal=diagonal_offset, - ).sum() - # BMM FLOPs: 2 * M * N * K. - # Here, M*N = num_nonmasked_elems per head; add batch_size * num_q_heads multiplier. - # Forward: 2 BMMs => (1 x head_dim_qk) + (1 x head_dim_vo) - # Backward: 5 BMMs => (3 x head_dim_qk) + (2 x head_dim_vo) - base = batch_size * num_q_heads * num_nonmasked_elems * 2 - if mode == "fwd": - result = base * (head_dim_qk + head_dim_vo) - elif mode == "bwd": - result = base * (3 * head_dim_qk + 2 * head_dim_vo) - else: # fwd_bwd - result = base * (4 * head_dim_qk + 3 * head_dim_vo) - return result - - -def tflops_per_sec( - batch_size, - q_seqlen, - kv_seqlen, - head_dim_qk, - head_dim_vo, - num_q_heads, - attn_mask, - time, - mode="fwd", -): - assert mode in ["fwd", "bwd", "fwd_bwd"] - f = flops( - batch_size, - q_seqlen, - kv_seqlen, - head_dim_qk, - head_dim_vo, - num_q_heads, - attn_mask, - mode, - ) - return f / time / 1e9 if not math.isnan(time) else 0.0 # Assume time is in msec - - -###### Done setting up SDPA function for each backend ####### -############################################################# - -###### SDPA Benchmark -- Run ###### -## Print System Info -if args.verbose: - print(f"[INFO] {torch.__version__ = }") - print(f"[INFO] {torch.version.cuda = }") - print(f"[INFO] {torch.cuda.is_available() = }") - print(f"[INFO] {torch.cuda.device_count() = }") - print(f"[INFO] {torch.cuda.current_device() = }") - print(f"[INFO] {torch.cuda.get_device_name(torch.cuda.current_device()) = }") - if args.sdpa_backend == "pyt_cudnn": - print(f"[INFO] {torch.backends.cudnn.version() = }") - print(f"[INFO] {torch.backends.cudnn.enabled = }") - elif args.sdpa_backend == "flash_attention": - print(f"[INFO] {flash_attn.__version__ = }") - -forward_times = [] -backward_times = [] -forward_diffs = [] - -total_iters = num_iters + dry_run_iters - -first_error = True # For suppressing error message beyond first error -sdpa_function = get_sdpa_function(args.sdpa_backend) -for i in range(total_iters): - if args.data_type == "fp8" and args.sdpa_backend == "cudnn_fe": - query = torch.randint( - 256, - (batch_size, q_seqlen, num_q_heads, head_dim_qk), - dtype=torch.uint8, - device=device, - ).transpose(1, 2) - key = torch.randint( - 256, - (batch_size, kv_seqlen, num_kv_heads, head_dim_qk), - dtype=torch.uint8, - device=device, - ).transpose(1, 2) - value = torch.randint( - 256, - (batch_size, kv_seqlen, num_kv_heads, head_dim_vo), - dtype=torch.uint8, - device=device, - ).transpose(1, 2) - descale_q_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_k_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_v_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_dO_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - descale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_dQ_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_dK_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_dV_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - scale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) - amax_s_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_o_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_dQ_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_dK_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_dV_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - amax_dP_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) - elif args.data_type == "fp8" and args.sdpa_backend == "flash_attention_3": - query = ( - torch.randn( + query = torch.randn( batch_size, q_seqlen, num_q_heads, head_dim_qk, - dtype=torch.bfloat16, + dtype=target_dtype, device=device, - requires_grad=True, - ) - .to(torch.float8_e4m3fn) - .transpose(1, 2) - ) - key = ( - torch.randn( + ).transpose(1, 2) + key = torch.randn( batch_size, kv_seqlen, num_kv_heads, head_dim_qk, - dtype=torch.bfloat16, + dtype=target_dtype, device=device, - requires_grad=True, - ) - .to(torch.float8_e4m3fn) - .transpose(1, 2) - ) - value = ( - torch.randn( + ).transpose(1, 2) + value = torch.randn( batch_size, kv_seqlen, num_kv_heads, head_dim_vo, - dtype=torch.bfloat16, + dtype=target_dtype, device=device, - requires_grad=True, - ) - .to(torch.float8_e4m3fn) - .transpose(1, 2) - ) - else: - query = torch.randn( - batch_size, - q_seqlen, - num_q_heads, - head_dim_qk, - dtype=target_dtype, - device=device, - requires_grad=True, - ).transpose(1, 2) - key = torch.randn( - batch_size, - kv_seqlen, - num_kv_heads, - head_dim_qk, - dtype=target_dtype, - device=device, - requires_grad=True, - ).transpose(1, 2) - value = torch.randn( - batch_size, - kv_seqlen, - num_kv_heads, - head_dim_vo, - dtype=target_dtype, - device=device, - requires_grad=True, - ).transpose(1, 2) - - query, key, value = preprocess_qkv(query, key, value, args.sdpa_backend) - if args.data_type == "fp8" and args.sdpa_backend == "cudnn_fe": - # Create as bfloat16, convert to FP8, then view as uint8 to avoid DLPack issues - dOutput_bf16 = torch.randn(query.shape, dtype=torch.bfloat16, device=device) - dOutput_fp8 = dOutput_bf16.to(torch.float8_e4m3fn) - dOutput = dOutput_fp8.view(torch.uint8) - else: - dOutput = torch.randn_like(query) + ).transpose(1, 2) + output = torch.empty( + batch_size, + q_seqlen, + num_q_heads, + head_dim_vo, + dtype=target_dtype, + device=device, + ).transpose(1, 2) - if args.sdpa_backend == "cudnn_fe": - output = torch.empty( - batch_size, - q_seqlen, - num_q_heads, - head_dim_vo, - dtype=torch.uint8 if args.data_type == "fp8" else target_dtype, - device=device, - ).transpose(1, 2) dQuery = torch.empty_like(query) dKey = torch.empty_like(key) dValue = torch.empty_like(value) + if args.data_type == "fp8": + # Create as bfloat16, convert to FP8, then view as uint8 to avoid DLPack issues + dOutput_bf16 = torch.randn(output.shape, dtype=torch.bfloat16, device=device) + dOutput_fp8 = dOutput_bf16.to(torch.float8_e4m3fn) + dOutput = dOutput_fp8.view(torch.uint8) + else: + dOutput = torch.randn_like(output) stats = torch.randn(batch_size, q_seqlen, num_q_heads, 1, dtype=torch.float32, device=device).transpose(1, 2) if is_dropout: dropout_seed = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") dropout_offset = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") - # Only variant pack and workspace need to be updated for each iteration. - if args.fwd_bwd: + # cuDNN graph forward + graph_fwd = cudnn.pygraph( + io_data_type=(cudnn.data_type.FP8_E4M3 if args.data_type == "fp8" else convert_to_cudnn_type(target_dtype)), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + if is_dropout: + seed_fwd = graph_fwd.tensor_like(dropout_seed) + offset_fwd = graph_fwd.tensor_like(dropout_offset) + dropout_tuple = (dropout_prob, seed_fwd, offset_fwd) + + if args.data_type == "fp8": + q_fwd = graph_fwd.tensor_like(query).set_data_type(cudnn.data_type.FP8_E4M3) + k_fwd = graph_fwd.tensor_like(key).set_data_type(cudnn.data_type.FP8_E4M3) + v_fwd = graph_fwd.tensor_like(value).set_data_type(cudnn.data_type.FP8_E4M3) + + descale_q_fwd = graph_fwd.tensor_like(descale_q_gpu) + descale_k_fwd = graph_fwd.tensor_like(descale_k_gpu) + descale_v_fwd = graph_fwd.tensor_like(descale_v_gpu) + descale_s_fwd = graph_fwd.tensor_like(descale_s_gpu) + scale_s_fwd = graph_fwd.tensor_like(scale_s_gpu) + scale_o_fwd = graph_fwd.tensor_like(scale_o_gpu) + + o_fwd, stats_fwd, amax_s_fwd, amax_o_fwd = graph_fwd.sdpa_fp8( + q=q_fwd, + k=k_fwd, + v=v_fwd, + descale_q=descale_q_fwd, + descale_k=descale_k_fwd, + descale_v=descale_v_fwd, + descale_s=descale_s_fwd, + scale_s=scale_s_fwd, + scale_o=scale_o_fwd, + # generate_stats=not is_infer, + is_inference=is_infer, + attn_scale=attn_scale, + diagonal_alignment=(cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask == "bottom_right" else cudnn.diagonal_alignment.TOP_LEFT), + right_bound=None if args.attn_mask == "no_mask" else 0, + # dropout=dropout_tuple if is_dropout else None, + ) + else: + q_fwd = graph_fwd.tensor_like(query) + k_fwd = graph_fwd.tensor_like(key) + v_fwd = graph_fwd.tensor_like(value) + o_fwd, stats_fwd = graph_fwd.sdpa( + q=q_fwd, + k=k_fwd, + v=v_fwd, + # generate_stats=not is_infer, + is_inference=is_infer, + attn_scale=attn_scale, + diagonal_alignment=(cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask == "bottom_right" else cudnn.diagonal_alignment.TOP_LEFT), + diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0, + dropout=dropout_tuple if is_dropout else None, + ) + + if run_bwd: + if args.data_type == "fp8": + o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + (stats_fwd.set_output(True).set_dim(stats.size()).set_stride(stats.stride()).set_data_type(cudnn.data_type.FLOAT) if not is_infer else None) + else: + o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()) + (stats_fwd.set_output(True).set_dim(stats.size()).set_stride(stats.stride()).set_data_type(cudnn.data_type.FLOAT) if not is_infer else None) + else: + if args.data_type == "fp8": + o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + else: + o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()) + + if args.data_type == "fp8": + amax_s_fwd.set_output(True).set_dim(amax_s_gpu.size()).set_stride(amax_s_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + amax_o_fwd.set_output(True).set_dim(amax_o_gpu.size()).set_stride(amax_o_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + graph_fwd.validate() + graph_fwd.build_operation_graph() + graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph_fwd.check_support() + graph_fwd.build_plans() + + # If backward is requested, set up backward graph. + if run_bwd: + graph_bwd = cudnn.pygraph( + io_data_type=(cudnn.data_type.FP8_E4M3 if args.data_type == "fp8" else convert_to_cudnn_type(target_dtype)), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + stats_bwd = graph_bwd.tensor_like(stats) + if is_dropout: + seed_bwd = graph_bwd.tensor_like(dropout_seed) + offset_bwd = graph_bwd.tensor_like(dropout_offset) + dropout_tuple = (dropout_prob, seed_bwd, offset_bwd) + + if args.data_type == "fp8": + q_bwd = graph_bwd.tensor_like(query).set_data_type(cudnn.data_type.FP8_E4M3) + k_bwd = graph_bwd.tensor_like(key).set_data_type(cudnn.data_type.FP8_E4M3) + v_bwd = graph_bwd.tensor_like(value).set_data_type(cudnn.data_type.FP8_E4M3) + o_bwd = graph_bwd.tensor_like(output).set_data_type(cudnn.data_type.FP8_E4M3) + dO_bwd = graph_bwd.tensor_like(dOutput).set_data_type(cudnn.data_type.FP8_E4M3) + + descale_q_bwd = graph_bwd.tensor_like(descale_q_gpu) + descale_k_bwd = graph_bwd.tensor_like(descale_k_gpu) + descale_v_bwd = graph_bwd.tensor_like(descale_v_gpu) + descale_o_bwd = graph_bwd.tensor_like(descale_o_gpu) + descale_dO_bwd = graph_bwd.tensor_like(descale_dO_gpu) + descale_s_bwd = graph_bwd.tensor_like(descale_s_gpu) + descale_dP_bwd = graph_bwd.tensor_like(descale_dP_gpu) + scale_s_bwd = graph_bwd.tensor_like(scale_s_gpu) + scale_dQ_bwd = graph_bwd.tensor_like(scale_dQ_gpu) + scale_dK_bwd = graph_bwd.tensor_like(scale_dK_gpu) + scale_dV_bwd = graph_bwd.tensor_like(scale_dV_gpu) + scale_dP_bwd = graph_bwd.tensor_like(scale_dP_gpu) + + ( + dQ_bwd, + dK_bwd, + dV_bwd, + amax_dQ_bwd, + amax_dK_bwd, + amax_dV_bwd, + amax_dP_bwd, + ) = graph_bwd.sdpa_fp8_backward( + q=q_bwd, + k=k_bwd, + v=v_bwd, + o=o_bwd, + dO=dO_bwd, + stats=stats_bwd, + descale_q=descale_q_bwd, + descale_k=descale_k_bwd, + descale_v=descale_v_bwd, + descale_o=descale_o_bwd, + descale_dO=descale_dO_bwd, + descale_s=descale_s_bwd, + descale_dP=descale_dP_bwd, + scale_s=scale_s_bwd, + scale_dQ=scale_dQ_bwd, + scale_dK=scale_dK_bwd, + scale_dV=scale_dV_bwd, + scale_dP=scale_dP_bwd, + attn_scale=attn_scale, + use_causal_mask=args.attn_mask != "no_mask" and args.attn_mask != "bottom_right", + use_causal_mask_bottom_right=args.attn_mask == "bottom_right", + dropout=dropout_tuple if is_dropout else None, + use_deterministic_algorithm=args.deterministic_bwd, + ) + else: + q_bwd = graph_bwd.tensor_like(query) + k_bwd = graph_bwd.tensor_like(key) + v_bwd = graph_bwd.tensor_like(value) + o_bwd = graph_bwd.tensor_like(output) + dO_bwd = graph_bwd.tensor_like(dOutput) + + dQ_bwd, dK_bwd, dV_bwd = graph_bwd.sdpa_backward( + q=q_bwd, + k=k_bwd, + v=v_bwd, + o=o_bwd, + dO=dO_bwd, + stats=stats_bwd, + attn_scale=attn_scale, + diagonal_alignment=(cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask == "bottom_right" else cudnn.diagonal_alignment.TOP_LEFT), + diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0, + dropout=dropout_tuple if is_dropout else None, + use_deterministic_algorithm=args.deterministic_bwd, + ) + + if args.data_type == "fp8": + dQ_bwd.set_output(True).set_dim(dQuery.size()).set_stride(dQuery.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + dK_bwd.set_output(True).set_dim(dKey.size()).set_stride(dKey.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + dV_bwd.set_output(True).set_dim(dValue.size()).set_stride(dValue.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + amax_dQ_bwd.set_output(True).set_dim(amax_dQ_gpu.size()).set_stride(amax_dQ_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + amax_dK_bwd.set_output(True).set_dim(amax_dK_gpu.size()).set_stride(amax_dK_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + amax_dV_bwd.set_output(True).set_dim(amax_dV_gpu.size()).set_stride(amax_dV_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + amax_dP_bwd.set_output(True).set_dim(amax_dP_gpu.size()).set_stride(amax_dP_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + else: + dQ_bwd.set_output(True).set_dim(dQuery.size()).set_stride(dQuery.stride()) + dK_bwd.set_output(True).set_dim(dKey.size()).set_stride(dKey.stride()) + dV_bwd.set_output(True).set_dim(dValue.size()).set_stride(dValue.stride()) + + graph_bwd.validate() + graph_bwd.build_operation_graph() + graph_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph_bwd.check_support() + graph_bwd.build_plans() + if args.data_type == "fp8": variant_pack_fwd = { q_fwd: query, @@ -966,6 +686,7 @@ def tflops_per_sec( amax_s_fwd: amax_s_gpu, amax_o_fwd: amax_o_gpu, } + variant_pack_bwd = { q_bwd: query, k_bwd: key, @@ -993,6 +714,12 @@ def tflops_per_sec( amax_dV_bwd: amax_dV_gpu, amax_dP_bwd: amax_dP_gpu, } + + workspace = torch.empty( + max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), + device="cuda", + dtype=torch.uint8, + ) else: variant_pack_fwd = { q_fwd: query, @@ -1012,11 +739,11 @@ def tflops_per_sec( dK_bwd: dKey, dV_bwd: dValue, } - workspace = torch.empty( - max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), - device="cuda", - dtype=torch.uint8, - ) + workspace = torch.empty( + max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), + device="cuda", + dtype=torch.uint8, + ) else: if args.data_type == "fp8": variant_pack_fwd = { @@ -1034,6 +761,7 @@ def tflops_per_sec( amax_s_fwd: amax_s_gpu, amax_o_fwd: amax_o_gpu, } + workspace = torch.empty(graph_fwd.get_workspace_size(), device="cuda", dtype=torch.uint8) else: variant_pack_fwd = { q_fwd: query, @@ -1041,172 +769,598 @@ def tflops_per_sec( v_fwd: value, o_fwd: output, } - workspace = torch.empty(graph_fwd.get_workspace_size(), device="cuda", dtype=torch.uint8) - + workspace = torch.empty(graph_fwd.get_workspace_size(), device="cuda", dtype=torch.uint8) if is_dropout: variant_pack_fwd[seed_fwd] = dropout_seed variant_pack_fwd[offset_fwd] = dropout_offset - variant_pack_bwd[seed_bwd] = dropout_seed - variant_pack_bwd[offset_bwd] = dropout_offset - - l2_flush_buffer.zero_() + if run_bwd: + variant_pack_bwd[seed_bwd] = dropout_seed + variant_pack_bwd[offset_bwd] = dropout_offset + ## Done setting up cuDNN graph. + + # For backends MATH, EFFICIENT_ATTENTION, CUDNN_ATTENTION, PYTORCH_FLASH_ATTENTION + def pyt_backend_sdpa(query, key, value, backend): + with sdpa_kernel(backends=[backend]): + return torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + enable_gqa=enable_gqa, + is_causal=args.attn_mask == "top_left", + attn_mask=causal_lower_right(q_seqlen, kv_seqlen) if args.attn_mask == "bottom_right" else None, + ) - # Run kernel with profiler - with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("sdpa.forward"): # Custom marker - if args.sdpa_backend == "cudnn_fe": - graph_fwd.execute(variant_pack_fwd, workspace) - else: - output = sdpa_function(query, key, value) - torch.cuda.synchronize() # Ensure all kernels finish - - # Filter profiler results by kernel name prefix - matched_kernels = [ - item - for item in prof.key_averages() - if item.key.startswith("cudnn") - or item.key.startswith("kernel_cutlass") - or "pytorch_flash::" in item.key - or "flash::" in item.key - or "at::native::" in item.key - or "cutlass3x" in item.key - or "(anonymous namespace)::" in item.key - or item.key.startswith("fmha_") - ] - if len(matched_kernels) >= 1: - fwd_time = sum(item.device_time for item in matched_kernels) / 1000 - if i >= dry_run_iters: - forward_times.append(fwd_time) + if args.sdpa_backend == "flash_attention": + import flash_attn + from flash_attn import flash_attn_func + + # Flash Attention Native + def flash_attention_sdpa(query, key, value): + return flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask") + + if args.sdpa_backend == "flash_attention_3": + import flash_attn_interface + + def flash_attention_3_sdpa(query, key, value): + output, _ = flash_attn_interface.flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask") + return output + + if args.sdpa_backend == "flash_attention_4" or (not args.skip_ref): + import flash_attn.cute.interface as flash_attn_interface + + def flash_attention_4_sdpa(query, key, value): + output, _ = flash_attn_interface.flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask") + return output + + def get_sdpa_function(backend): + if backend == "pyt_math": + return functools.partial(pyt_backend_sdpa, backend=SDPBackend.MATH) + elif backend == "pyt_efficient_attention": + return functools.partial(pyt_backend_sdpa, backend=SDPBackend.EFFICIENT_ATTENTION) + elif backend == "pyt_flash_attention": + return functools.partial(pyt_backend_sdpa, backend=SDPBackend.FLASH_ATTENTION) + elif backend == "pyt_cudnn": + return functools.partial(pyt_backend_sdpa, backend=SDPBackend.CUDNN_ATTENTION) + elif backend == "flash_attention": + return flash_attention_sdpa + elif backend == "flash_attention_3": + return flash_attention_3_sdpa + elif backend == "flash_attention_4": + return flash_attention_4_sdpa + elif backend == "cudnn": + return None # Will be set up separately + else: + raise ValueError(f"Invalid backend: {backend}") + + # Util function for addressing different qkv formats for each backend + def preprocess_qkv(query, key, value, backend): + if backend.startswith("pyt_") or backend == "cudnn": + return query, key, value + elif backend.startswith("flash_attention"): + query = torch.swapaxes(query, 1, 2) + key = torch.swapaxes(key, 1, 2) + value = torch.swapaxes(value, 1, 2) + return query, key, value + else: + raise ValueError(f"Invalid backend: {backend}") + + # Util function addressing different qkvo formats for each backend + def postprocess_qkvo(query, key, value, output, backend): + if backend.startswith("pyt_") or backend == "cudnn": + return query, key, value, output + elif backend.startswith("flash_attention"): + output = torch.swapaxes(output, 1, 2) + query = torch.swapaxes(query, 1, 2) + key = torch.swapaxes(key, 1, 2) + value = torch.swapaxes(value, 1, 2) + return query, key, value, output + else: + raise ValueError(f"Invalid backend: {backend}") + + def postprocess_dqdkdvdo(dQuery, dKey, dValue, dOutput, backend): + if backend.startswith("pyt_") or backend == "cudnn": + return dQuery, dKey, dValue, dOutput + elif backend.startswith("flash_attention"): + dQuery = torch.swapaxes(dQuery, 1, 2) + dKey = torch.swapaxes(dKey, 1, 2) + dValue = torch.swapaxes(dValue, 1, 2) + dOutput = torch.swapaxes(dOutput, 1, 2) + return dQuery, dKey, dValue, dOutput + else: + raise ValueError(f"Invalid backend: {backend}") - # Sleep for some time proportional to fwd_time for stable measurements - sleep_time = np.min([fwd_time / 100, 1.0]) - time.sleep(sleep_time) + # Util functions for calculating flops and tflops/s achieved + def flops( + batch_size, + q_seqlen, + kv_seqlen, + head_dim_qk, + head_dim_vo, + num_q_heads, + attn_mask, + mode="fwd", + ): + assert mode in ["fwd", "bwd", "fwd_bwd"] + + if attn_mask == "no_mask": + num_nonmasked_elems = q_seqlen * kv_seqlen + elif attn_mask == "top_left": + num_nonmasked_elems = torch.tril(torch.ones((q_seqlen, kv_seqlen), dtype=torch.bool)).sum() + elif attn_mask == "bottom_right": + diagonal_offset = kv_seqlen - q_seqlen + num_nonmasked_elems = torch.tril( + torch.ones((q_seqlen, kv_seqlen), dtype=torch.bool), + diagonal=diagonal_offset, + ).sum() + # BMM FLOPs: 2 * M * N * K. + # Here, M*N = num_nonmasked_elems per head; add batch_size * num_q_heads multiplier. + # Forward: 2 BMMs => (1 x head_dim_qk) + (1 x head_dim_vo) + # Backward: 5 BMMs => (3 x head_dim_qk) + (2 x head_dim_vo) + base = batch_size * num_q_heads * num_nonmasked_elems * 2 + if mode == "fwd": + result = base * (head_dim_qk + head_dim_vo) + elif mode == "bwd": + result = base * (3 * head_dim_qk + 2 * head_dim_vo) + else: # fwd_bwd + result = base * (4 * head_dim_qk + 3 * head_dim_vo) + return result + + def tflops_per_sec( + batch_size, + q_seqlen, + kv_seqlen, + head_dim_qk, + head_dim_vo, + num_q_heads, + attn_mask, + time, + mode="fwd", + ): + assert mode in ["fwd", "bwd", "fwd_bwd"] + f = flops( + batch_size, + q_seqlen, + kv_seqlen, + head_dim_qk, + head_dim_vo, + num_q_heads, + attn_mask, + mode, + ) + return f / time / 1e9 if not math.isnan(time) else 0.0 # Assume time is in msec - if args.fwd_bwd: - # Run backward pass + ###### Done setting up SDPA function for each backend ####### + ############################################################# - l2_flush_buffer.zero_() + ###### SDPA Benchmark -- Run ###### + ## Print System Info + if args.verbose: + print(f"[INFO] {torch.__version__ = }") + print(f"[INFO] {torch.version.cuda = }") + print(f"[INFO] {torch.cuda.is_available() = }") + print(f"[INFO] {torch.cuda.device_count() = }") + print(f"[INFO] {torch.cuda.current_device() = }") + print(f"[INFO] {torch.cuda.get_device_name(torch.cuda.current_device()) = }") + if args.sdpa_backend == "pyt_cudnn": + print(f"[INFO] {torch.backends.cudnn.version() = }") + print(f"[INFO] {torch.backends.cudnn.enabled = }") + elif args.sdpa_backend == "flash_attention": + print(f"[INFO] {flash_attn.__version__ = }") + + forward_times = [] + backward_times = [] + forward_diffs = [] + + total_iters = num_iters + dry_run_iters + + first_error = True # For suppressing error message beyond first error + sdpa_function = get_sdpa_function(args.sdpa_backend) + for i in range(total_iters): + if args.data_type == "fp8" and args.sdpa_backend == "cudnn": + query = torch.randint( + 256, + (batch_size, q_seqlen, num_q_heads, head_dim_qk), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + key = torch.randint( + 256, + (batch_size, kv_seqlen, num_kv_heads, head_dim_qk), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + value = torch.randint( + 256, + (batch_size, kv_seqlen, num_kv_heads, head_dim_vo), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + descale_q_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_k_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_v_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_dO_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dQ_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dK_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dV_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + amax_s_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_o_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dQ_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dK_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dV_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dP_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + elif args.data_type == "fp8" and args.sdpa_backend == "flash_attention_3": + query = ( + torch.randn( + batch_size, + q_seqlen, + num_q_heads, + head_dim_qk, + dtype=torch.bfloat16, + device=device, + requires_grad=True, + ) + .to(torch.float8_e4m3fn) + .transpose(1, 2) + ) + key = ( + torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_qk, + dtype=torch.bfloat16, + device=device, + requires_grad=True, + ) + .to(torch.float8_e4m3fn) + .transpose(1, 2) + ) + value = ( + torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_vo, + dtype=torch.bfloat16, + device=device, + requires_grad=True, + ) + .to(torch.float8_e4m3fn) + .transpose(1, 2) + ) + else: + query = torch.randn( + batch_size, + q_seqlen, + num_q_heads, + head_dim_qk, + dtype=target_dtype, + device=device, + requires_grad=True, + ).transpose(1, 2) + key = torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_qk, + dtype=target_dtype, + device=device, + requires_grad=True, + ).transpose(1, 2) + value = torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_vo, + dtype=target_dtype, + device=device, + requires_grad=True, + ).transpose(1, 2) + + query, key, value = preprocess_qkv(query, key, value, args.sdpa_backend) + if args.data_type == "fp8" and args.sdpa_backend == "cudnn": + # Create as bfloat16, convert to FP8, then view as uint8 to avoid DLPack issues + dOutput_bf16 = torch.randn(query.shape, dtype=torch.bfloat16, device=device) + dOutput_fp8 = dOutput_bf16.to(torch.float8_e4m3fn) + dOutput = dOutput_fp8.view(torch.uint8) + else: + dOutput = torch.randn_like(query) - with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("sdpa.backward"): # Custom marker - if args.sdpa_backend == "cudnn_fe": - graph_bwd.execute(variant_pack_bwd, workspace) + if args.sdpa_backend == "cudnn": + output = torch.empty( + batch_size, + q_seqlen, + num_q_heads, + head_dim_vo, + dtype=torch.uint8 if args.data_type == "fp8" else target_dtype, + device=device, + ).transpose(1, 2) + dQuery = torch.empty_like(query) + dKey = torch.empty_like(key) + dValue = torch.empty_like(value) + stats = torch.randn(batch_size, q_seqlen, num_q_heads, 1, dtype=torch.float32, device=device).transpose(1, 2) + if is_dropout: + dropout_seed = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") + dropout_offset = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") + + # Only variant pack and workspace need to be updated for each iteration. + if run_bwd: + if args.data_type == "fp8": + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + stats_fwd: stats, + descale_q_fwd: descale_q_gpu, + descale_k_fwd: descale_k_gpu, + descale_v_fwd: descale_v_gpu, + descale_s_fwd: descale_s_gpu, + scale_s_fwd: scale_s_gpu, + scale_o_fwd: scale_o_gpu, + amax_s_fwd: amax_s_gpu, + amax_o_fwd: amax_o_gpu, + } + variant_pack_bwd = { + q_bwd: query, + k_bwd: key, + v_bwd: value, + o_bwd: output, + dQ_bwd: dQuery, + dK_bwd: dKey, + dV_bwd: dValue, + dO_bwd: dOutput, + stats_bwd: stats, + descale_q_bwd: descale_q_gpu, + descale_k_bwd: descale_k_gpu, + descale_v_bwd: descale_v_gpu, + descale_o_bwd: descale_o_gpu, + descale_s_bwd: descale_s_gpu, + descale_dP_bwd: descale_dP_gpu, + descale_dO_bwd: descale_dO_gpu, + scale_s_bwd: scale_s_gpu, + scale_dQ_bwd: scale_dQ_gpu, + scale_dK_bwd: scale_dK_gpu, + scale_dV_bwd: scale_dV_gpu, + scale_dP_bwd: scale_dP_gpu, + amax_dQ_bwd: amax_dQ_gpu, + amax_dK_bwd: amax_dK_gpu, + amax_dV_bwd: amax_dV_gpu, + amax_dP_bwd: amax_dP_gpu, + } else: - query.retain_grad() - key.retain_grad() - value.retain_grad() - output.backward(dOutput) + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + stats_fwd: stats, + } + variant_pack_bwd = { + q_bwd: query, + k_bwd: key, + v_bwd: value, + o_bwd: output, + dO_bwd: dOutput, + stats_bwd: stats, + dQ_bwd: dQuery, + dK_bwd: dKey, + dV_bwd: dValue, + } + workspace = torch.empty( + max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), + device="cuda", + dtype=torch.uint8, + ) + else: + if args.data_type == "fp8": + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + stats_fwd: stats, + descale_q_fwd: descale_q_gpu, + descale_k_fwd: descale_k_gpu, + descale_v_fwd: descale_v_gpu, + descale_s_fwd: descale_s_gpu, + scale_s_fwd: scale_s_gpu, + scale_o_fwd: scale_o_gpu, + amax_s_fwd: amax_s_gpu, + amax_o_fwd: amax_o_gpu, + } + else: + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + } + workspace = torch.empty(graph_fwd.get_workspace_size(), device="cuda", dtype=torch.uint8) + + if is_dropout: + variant_pack_fwd[seed_fwd] = dropout_seed + variant_pack_fwd[offset_fwd] = dropout_offset + if run_bwd: + variant_pack_bwd[seed_bwd] = dropout_seed + variant_pack_bwd[offset_bwd] = dropout_offset - dQuery = query.grad - dKey = key.grad - dValue = value.grad + l2_flush_buffer.zero_() - query.grad = None - key.grad = None - value.grad = None + # Run kernel with profiler for forward if requested, else run unprofiled to prep for backward + if run_fwd: + with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("sdpa.forward"): # Custom marker + if args.sdpa_backend == "cudnn": + graph_fwd.execute(variant_pack_fwd, workspace) + else: + output = sdpa_function(query, key, value) + torch.cuda.synchronize() # Ensure all kernels finish + + # Filter profiler results by kernel name prefix + matched_kernels = [ + item + for item in prof.key_averages() + if item.key.startswith("cudnn") + or item.key.startswith("kernel_cutlass") + or "pytorch_flash::" in item.key + or "flash::" in item.key + or "at::native::" in item.key + or "cutlass3x" in item.key + or "(anonymous namespace)::" in item.key + or item.key.startswith("fmha_") + ] + if len(matched_kernels) >= 1: + fwd_time = sum(item.device_time for item in matched_kernels) / 1000 + if i >= dry_run_iters: + forward_times.append(fwd_time) + else: + if args.sdpa_backend == "cudnn": + graph_fwd.execute(variant_pack_fwd, workspace) + else: + output = sdpa_function(query, key, value) torch.cuda.synchronize() - matched_kernels = [ - item - for item in prof.key_averages() - if "cudnn" in item.key - or item.key.startswith("kernel_cutlass") - or "pytorch_flash::" in item.key - or "flash::" in item.key - or "at::native::" in item.key - or "cutlass3x" in item.key - or "(anonymous namespace)::" in item.key - or item.key.startswith("fmha_") - ] - if len(matched_kernels) >= 1: - bwd_time = sum(item.device_time for item in matched_kernels) / 1000 - if i >= dry_run_iters: - backward_times.append(bwd_time) - - sleep_time = np.min([bwd_time / 100, 1.0]) + # Sleep for some time proportional to fwd_time for stable measurements + sleep_time = np.min([fwd_time / 100, 1.0]) if run_fwd and len(matched_kernels) >= 1 else 0.0 time.sleep(sleep_time) - dQuery, dKey, dValue, dOutput = postprocess_dqdkdvdo(dQuery, dKey, dValue, dOutput, args.sdpa_backend) - - ( - query, - key, - value, - output, - ) = postprocess_qkvo(query, key, value, output, args.sdpa_backend) - if args.data_type != "fp8" and not args.skip_ref: - try: - output_ref = flash_attention_4_sdpa(query, key, value) - if args.fwd_bwd: - query.retain_grad() - key.retain_grad() - value.retain_grad() - output_ref.backward(dOutput) - - torch.testing.assert_close(dQuery, query.grad, rtol=2e-2, atol=2e-2) - torch.testing.assert_close(dKey, key.grad, rtol=2e-2, atol=2e-2) - torch.testing.assert_close(dValue, value.grad, rtol=2e-2, atol=2e-2) - - torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2) - forward_diffs.append(torch.max(torch.abs(output.detach() - output_ref.detach())).item()) - except Exception as e: - if first_error: - print( - f"[WARN] Failed reference check. Target backend has been run, but output has not been validated. Failure may be due to incorrect output or reference function failure." - ) - print(f"[WARN] See error message: {e}") - first_error = False + if run_bwd: + # Run backward pass + + l2_flush_buffer.zero_() + + with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("sdpa.backward"): # Custom marker + if args.sdpa_backend == "cudnn": + graph_bwd.execute(variant_pack_bwd, workspace) + else: + query.retain_grad() + key.retain_grad() + value.retain_grad() + output.backward(dOutput) + + dQuery = query.grad + dKey = key.grad + dValue = value.grad + + query.grad = None + key.grad = None + value.grad = None + torch.cuda.synchronize() + + matched_kernels = [ + item + for item in prof.key_averages() + if "cudnn" in item.key + or item.key.startswith("kernel_cutlass") + or "pytorch_flash::" in item.key + or "flash::" in item.key + or "at::native::" in item.key + or "cutlass3x" in item.key + or "(anonymous namespace)::" in item.key + or item.key.startswith("fmha_") + ] + if len(matched_kernels) >= 1: + bwd_time = sum(item.device_time for item in matched_kernels) / 1000 + if i >= dry_run_iters: + backward_times.append(bwd_time) + + sleep_time = np.min([bwd_time / 100, 1.0]) if run_bwd and len(matched_kernels) >= 1 else 0.0 + time.sleep(sleep_time) + + dQuery, dKey, dValue, dOutput = postprocess_dqdkdvdo(dQuery, dKey, dValue, dOutput, args.sdpa_backend) + + ( + query, + key, + value, + output, + ) = postprocess_qkvo(query, key, value, output, args.sdpa_backend) + if args.data_type != "fp8" and not args.skip_ref and run_fwd: + try: + output_ref = flash_attention_4_sdpa(query, key, value) + if run_bwd: + query.retain_grad() + key.retain_grad() + value.retain_grad() + output_ref.backward(dOutput) + + torch.testing.assert_close(dQuery, query.grad, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(dKey, key.grad, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(dValue, value.grad, rtol=2e-2, atol=2e-2) + + torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2) + forward_diffs.append(torch.max(torch.abs(output.detach() - output_ref.detach())).item()) + except Exception as e: + if first_error: + print( + f"[WARN] Failed reference check. Target backend has been run, but output has not been validated. Failure may be due to incorrect output or reference function failure." + ) + print(f"[WARN] See error message: {e}") + first_error = False + forward_diffs.append(0.0) + else: forward_diffs.append(0.0) - else: - forward_diffs.append(0.0) - time.sleep(sleep_time) + time.sleep(sleep_time) + + if args.sdpa_backend == "cudnn": + del query, key, value, output, dQuery, dKey, dValue, dOutput, stats + else: + del query, key, value, output - if args.sdpa_backend == "cudnn_fe": - del query, key, value, output, dQuery, dKey, dValue, dOutput, stats - else: - del query, key, value, output - -## print results -fwd_median_time = np.median(np.array(forward_times[5:])) -fwd_tflops = tflops_per_sec( - args.batch_size, - args.q_seqlen, - args.kv_seqlen, - head_dim_qk, - head_dim_vo, - args.num_q_heads, - args.attn_mask, - fwd_median_time, - "fwd", -) -if args.fwd_bwd: - bwd_median_time = np.median(np.array(backward_times[5:])) - bwd_tflops = tflops_per_sec( - args.batch_size, - args.q_seqlen, - args.kv_seqlen, - head_dim_qk, - head_dim_vo, - args.num_q_heads, - args.attn_mask, - bwd_median_time, - "bwd", + ## print results + fwd_median_time = ( + np.median(np.array(forward_times[5:])) if len(forward_times) > 5 else (np.median(np.array(forward_times)) if len(forward_times) > 0 else 0.0) ) - if args.format_output: - print( - f"{args.case_tag},{args.sdpa_backend},{args.batch_size},{args.q_seqlen},{args.kv_seqlen},{args.num_q_heads},{args.num_kv_heads},{head_dim_qk},{fwd_median_time:.3f},{bwd_median_time:.3f},{fwd_tflops:.0f},{bwd_tflops:.0f},{np.max(np.array(forward_diffs[5:])):.6f},{num_iters}" + fwd_tflops = 0.0 + if run_fwd and fwd_median_time > 0: + fwd_tflops = tflops_per_sec( + args.batch_size, + args.q_seqlen, + args.kv_seqlen, + head_dim_qk, + head_dim_vo, + args.num_q_heads, + args.attn_mask, + fwd_median_time, + "fwd", ) - else: - print( - f"{args.sdpa_backend}:: Median (fwd, bwd) Execution Times: {fwd_median_time:.3f} ms ({fwd_tflops:.0f} TFLOPS), {bwd_median_time:.3f} ms ({bwd_tflops:.0f} TFLOPS) (max difference vs. pyt_reference: {np.max(np.array(forward_diffs[5:])):.6f} from {num_iters} iterations)" + + bwd_median_time = ( + np.median(np.array(backward_times[5:])) if len(backward_times) > 5 else (np.median(np.array(backward_times)) if len(backward_times) > 0 else 0.0) + ) + bwd_tflops = 0.0 + if run_bwd and bwd_median_time > 0: + bwd_tflops = tflops_per_sec( + args.batch_size, + args.q_seqlen, + args.kv_seqlen, + head_dim_qk, + head_dim_vo, + args.num_q_heads, + args.attn_mask, + bwd_median_time, + "bwd", ) -else: + if args.format_output: print( - f"{args.case_tag},{args.sdpa_backend},{args.batch_size},{args.q_seqlen},{args.kv_seqlen},{args.num_q_heads},{args.num_kv_heads},{head_dim_qk},{fwd_median_time:.3f},0,{fwd_tflops:.0f},0,{np.max(np.array(forward_diffs[5:])):.6f},{num_iters}" + f"{args.case_tag},{args.sdpa_backend},{args.batch_size},{args.q_seqlen},{args.kv_seqlen},{args.num_q_heads},{args.num_kv_heads},{head_dim_qk},{fwd_median_time:.3f},{bwd_median_time:.3f},{fwd_tflops:.0f},{bwd_tflops:.0f},{(np.max(np.array(forward_diffs[5:])) if len(forward_diffs) > 5 else (np.max(np.array(forward_diffs)) if len(forward_diffs) > 0 else 0.0)):.6f},{num_iters}" ) else: - print( - f"{args.sdpa_backend}:: Median (fwd) Execution Times: {fwd_median_time:.3f} ms ({fwd_tflops:.0f} TFLOPS) (max difference vs. pyt_reference: {np.max(np.array(forward_diffs[5:])):.6f} from {num_iters} iterations)" - ) + if run_fwd and run_bwd: + print( + f"{args.sdpa_backend}:: Median (fwd, bwd) Execution Times: {fwd_median_time:.3f} ms ({fwd_tflops:.0f} TFLOPS), {bwd_median_time:.3f} ms ({bwd_tflops:.0f} TFLOPS)" + ) + elif run_fwd: + print(f"{args.sdpa_backend}:: Median (fwd) Execution Time: {fwd_median_time:.3f} ms ({fwd_tflops:.0f} TFLOPS)") + elif run_bwd: + print(f"{args.sdpa_backend}:: Median (bwd) Execution Time: {bwd_median_time:.3f} ms ({bwd_tflops:.0f} TFLOPS)") diff --git a/benchmark/sdpa_benchmark_training/charts.py b/benchmark/sdpa_benchmark_training/charts.py new file mode 100644 index 00000000..7c9872d5 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/charts.py @@ -0,0 +1,441 @@ +""" +Chart generation for SDPA benchmark results. + +Generates comparison bar charts showing backend performance side-by-side. +""" + +from pathlib import Path +from typing import Optional, TYPE_CHECKING +import logging + +if TYPE_CHECKING: + import pandas as pd + from .config_types import BenchmarkConfig + +logger = logging.getLogger(__name__) + +# Backend display configuration +# Each backend has a base color; FP8 variants get a darker/different shade +BACKEND_CONFIG = { + "cudnn": {"name": "cudnn", "color": "#76b900", "color_fp8": "#4a7500", "order": 0}, + "pyt_cudnn": {"name": "cuDNN (PyTorch)", "color": "#90EE90", "color_fp8": "#228B22", "order": 1}, + "pyt_flash_attention": {"name": "FAv2 (PyTorch)", "color": "#6495ED", "color_fp8": "#0000CD", "order": 2}, + "pyt_efficient_attention": {"name": "xFormers (PyTorch)", "color": "#FF00FF", "color_fp8": "#8B008B", "order": 3}, + "pyt_math": {"name": "Standard Attention", "color": "#FF8C00", "color_fp8": "#D2691E", "order": 4}, + "flash_attention": {"name": "FAv2 (Native)", "color": "#F08080", "color_fp8": "#CD5C5C", "order": 5}, + "flash_attention_3": {"name": "FAv3", "color": "#FFA500", "color_fp8": "#FF6600", "order": 6}, + "flash_attention_4": {"name": "FAv4", "color": "#FFD700", "color_fp8": "#DAA520", "order": 7}, +} + +# Font sizes for plot elements +LABEL_FONT_SIZE = 10 +LEGEND_FONT_SIZE = 8 +TITLE_FONT_SIZE = 12 +BAR_LABEL_FONT_SIZE = 6 + + +def get_backend_display_name(backend: str, data_type: str) -> str: + """ + Get display name for backend+dtype combination. + + Args: + backend: Backend name (e.g., "cudnn") + data_type: Data type (e.g., "bfloat16", "fp8") + + Returns: + Display name for legend (e.g., "cuDNN FE (FP8)") + """ + base_name = BACKEND_CONFIG.get(backend, {}).get("name", backend) + if data_type == "fp8": + return f"{base_name} (FP8)" + elif data_type == "float16": + return f"{base_name} (FP16)" + return base_name + + +def get_backend_color(backend: str, data_type: str) -> str: + """ + Get color for backend+dtype combination. + + Args: + backend: Backend name + data_type: Data type + + Returns: + Color string for matplotlib + """ + config = BACKEND_CONFIG.get(backend, {}) + if data_type == "fp8" and "color_fp8" in config: + return config["color_fp8"] + return config.get("color", "gray") + + +def generate_comparison_chart( + df: "pd.DataFrame", + config: "BenchmarkConfig", + output_path: Optional[Path] = None, +) -> Path: + """ + Generate comparison bar chart with multiple backends side-by-side. + + Creates a figure with: + - Left subplot: Forward pass TFLOPS by configuration + - Right subplot: Backward pass TFLOPS by configuration + - Each backend+dtype combo as a separate bar group + + Args: + df: DataFrame with benchmark results (from BenchmarkRunner.results_to_dataframe) + config: BenchmarkConfig used for the run + output_path: Optional path for output file. If None, uses config.output_dir + + Returns: + Path to the saved chart file + """ + import matplotlib.pyplot as plt + import seaborn as sns + import numpy as np + + # Filter to successful results only + df = df[df["success"] == True].copy() + + if df.empty: + raise ValueError("No successful results to plot") + + # Create backend+dtype display name for legend + df["backend_display"] = df.apply(lambda r: get_backend_display_name(r["backend"], r["data_type"]), axis=1) + + # Create config label for x-axis (model/seqlen/mask) + df["config_label"] = df.apply( + lambda r: f"{r['model_name']}\n{r['q_seqlen']}x{r['kv_seqlen']}\n{r['attn_mask']}", + axis=1, + ) + + # Sort by backend order for consistent legend + df["backend_order"] = df["backend"].map(lambda b: BACKEND_CONFIG.get(b, {}).get("order", 99)) + df.sort_values(["model_name", "q_seqlen", "attn_mask", "backend_order"], inplace=True) + + # Build color palette based on unique backend+dtype combinations + # Get unique (backend, data_type, backend_display) tuples to map colors correctly + unique_combos = df[["backend", "data_type", "backend_display"]].drop_duplicates() + palette = {} + for _, row in unique_combos.iterrows(): + palette[row["backend_display"]] = get_backend_color(row["backend"], row["data_type"]) + + # Determine if we have fwd/bwd data + has_fwd = (df["fwd_tflops"] > 0).any() + has_bwd = (df["bwd_tflops"] > 0).any() + + if has_fwd and has_bwd: + fig, axes = plt.subplots(1, 2, figsize=(14, 6), dpi=150) + ax_fwd, ax_bwd = axes + elif has_fwd: + fig, ax_fwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_bwd = None + elif has_bwd: + fig, ax_bwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_fwd = None + else: + raise ValueError("No forward or backward TFLOPS data to plot") + + # Calculate y-axis limit + max_tflops = max( + df["fwd_tflops"].max() if has_fwd else 0, + df["bwd_tflops"].max() if has_bwd else 0, + ) + ylim_max = max_tflops * 1.15 # Add 15% headroom for labels + + # Plot forward pass + if ax_fwd is not None: + fwd_df = df[df["fwd_tflops"] > 0] + if not fwd_df.empty: + sns.barplot( + data=fwd_df, + x="config_label", + y="fwd_tflops", + hue="backend_display", + ax=ax_fwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_fwd.set_xlabel("Configuration", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_title("SDPA Forward Pass", fontsize=TITLE_FONT_SIZE) + ax_fwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_fwd.tick_params(axis="x", rotation=45, labelsize=8) + ax_fwd.tick_params(axis="y", labelsize=LABEL_FONT_SIZE) + ax_fwd.set_ylim(0, ylim_max) + + # Add value labels on bars + for container in ax_fwd.containers: + ax_fwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + # Plot backward pass + if ax_bwd is not None: + bwd_df = df[df["bwd_tflops"] > 0] + if not bwd_df.empty: + sns.barplot( + data=bwd_df, + x="config_label", + y="bwd_tflops", + hue="backend_display", + ax=ax_bwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_bwd.set_xlabel("Configuration", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_title("SDPA Backward Pass", fontsize=TITLE_FONT_SIZE) + ax_bwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_bwd.tick_params(axis="x", rotation=45, labelsize=8) + ax_bwd.tick_params(axis="y", labelsize=LABEL_FONT_SIZE) + ax_bwd.set_ylim(0, ylim_max) + + # Add value labels on bars + for container in ax_bwd.containers: + ax_bwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + plt.tight_layout() + + # Determine output path + if output_path is None: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"{config.name}_comparison.png" + + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + + logger.info(f"Chart saved to {output_path}") + return output_path + + +def generate_charts_by_mask( + df: "pd.DataFrame", + config: "BenchmarkConfig", + output_dir: Optional[Path] = None, +) -> list: + """ + Generate separate charts for each mask type. + + This creates cleaner charts when benchmarking both causal and non-causal masks. + Each chart shows seqlen on x-axis and backends as grouped bars. + + Args: + df: DataFrame with benchmark results + config: BenchmarkConfig used for the run + output_dir: Directory for output files + + Returns: + List of paths to saved chart files + """ + import matplotlib.pyplot as plt + import seaborn as sns + + df = df[df["success"] == True].copy() + + if df.empty: + raise ValueError("No successful results to plot") + + if output_dir is None: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + saved_paths = [] + masks = df["attn_mask"].unique() + + for mask in masks: + mask_df = df[df["attn_mask"] == mask].copy() + + # Create display names + mask_df["backend_display"] = mask_df.apply(lambda r: get_backend_display_name(r["backend"], r["data_type"]), axis=1) + mask_df["seqlen_label"] = mask_df.apply(lambda r: f"{r['q_seqlen']}x{r['kv_seqlen']}", axis=1) + + # Build palette + unique_combos = mask_df[["backend", "data_type", "backend_display"]].drop_duplicates() + palette = {} + for _, row in unique_combos.iterrows(): + palette[row["backend_display"]] = get_backend_color(row["backend"], row["data_type"]) + + # Sort + mask_df["backend_order"] = mask_df["backend"].map(lambda b: BACKEND_CONFIG.get(b, {}).get("order", 99)) + mask_df.sort_values(["q_seqlen", "backend_order"], inplace=True) + + has_fwd = (mask_df["fwd_tflops"] > 0).any() + has_bwd = (mask_df["bwd_tflops"] > 0).any() + + if has_fwd and has_bwd: + fig, (ax_fwd, ax_bwd) = plt.subplots(1, 2, figsize=(14, 6), dpi=150) + elif has_fwd: + fig, ax_fwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_bwd = None + else: + fig, ax_bwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_fwd = None + + mask_title = "Causal" if mask == "top_left" else "Non-Causal" if mask == "no_mask" else mask + + if ax_fwd is not None: + fwd_df = mask_df[mask_df["fwd_tflops"] > 0] + if not fwd_df.empty: + sns.barplot( + data=fwd_df, + x="seqlen_label", + y="fwd_tflops", + hue="backend_display", + ax=ax_fwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_fwd.set_xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_title(f"{config.name} Forward ({mask_title})", fontsize=TITLE_FONT_SIZE) + ax_fwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_fwd.tick_params(axis="x", rotation=45) + for container in ax_fwd.containers: + ax_fwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + if ax_bwd is not None: + bwd_df = mask_df[mask_df["bwd_tflops"] > 0] + if not bwd_df.empty: + sns.barplot( + data=bwd_df, + x="seqlen_label", + y="bwd_tflops", + hue="backend_display", + ax=ax_bwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_bwd.set_xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_title(f"{config.name} Backward ({mask_title})", fontsize=TITLE_FONT_SIZE) + ax_bwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_bwd.tick_params(axis="x", rotation=45) + for container in ax_bwd.containers: + ax_bwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + plt.tight_layout() + output_path = output_dir / f"{config.name}_{mask}.png" + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + saved_paths.append(output_path) + logger.info(f"Chart saved to {output_path}") + + return saved_paths + + +def generate_seqlen_scaling_chart( + df: "pd.DataFrame", + config: "BenchmarkConfig", + output_path: Optional[Path] = None, +) -> Path: + """ + Generate a chart showing performance scaling with sequence length. + + This chart is useful when benchmarking multiple sequence lengths with + the same model configuration. + + Args: + df: DataFrame with benchmark results + config: BenchmarkConfig used for the run + output_path: Optional path for output file + + Returns: + Path to the saved chart file + """ + import matplotlib.pyplot as plt + import seaborn as sns + + # Filter to successful results only + df = df[df["success"] == True].copy() + + if df.empty: + raise ValueError("No successful results to plot") + + # Create backend+dtype display name + df["backend_display"] = df.apply(lambda r: get_backend_display_name(r["backend"], r["data_type"]), axis=1) + + # Use q_seqlen for x-axis (assuming symmetric seqlens for this chart) + df["seqlen"] = df["q_seqlen"] + + # Build color palette based on unique backend+dtype combinations + unique_combos = df[["backend", "data_type", "backend_display"]].drop_duplicates() + palette = {} + for _, row in unique_combos.iterrows(): + palette[row["backend_display"]] = get_backend_color(row["backend"], row["data_type"]) + + # Create figure + has_fwd = (df["fwd_tflops"] > 0).any() + has_bwd = (df["bwd_tflops"] > 0).any() + + if has_fwd and has_bwd: + fig, axes = plt.subplots(1, 2, figsize=(14, 6), dpi=150) + ax_fwd, ax_bwd = axes + elif has_fwd: + fig, ax_fwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_bwd = None + else: + fig, ax_bwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_fwd = None + + # Plot forward + if ax_fwd is not None and has_fwd: + fwd_df = df[df["fwd_tflops"] > 0] + sns.barplot( + data=fwd_df, + x="seqlen", + y="fwd_tflops", + hue="backend_display", + ax=ax_fwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_fwd.set_xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_title("SDPA Forward Pass", fontsize=TITLE_FONT_SIZE) + ax_fwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_fwd.tick_params(axis="x", rotation=45) + + for container in ax_fwd.containers: + ax_fwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + # Plot backward + if ax_bwd is not None and has_bwd: + bwd_df = df[df["bwd_tflops"] > 0] + sns.barplot( + data=bwd_df, + x="seqlen", + y="bwd_tflops", + hue="backend_display", + ax=ax_bwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_bwd.set_xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_title("SDPA Backward Pass", fontsize=TITLE_FONT_SIZE) + ax_bwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_bwd.tick_params(axis="x", rotation=45) + + for container in ax_bwd.containers: + ax_bwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + plt.tight_layout() + + # Determine output path + if output_path is None: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"{config.name}_seqlen_scaling.png" + + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + + logger.info(f"Chart saved to {output_path}") + return output_path diff --git a/benchmark/sdpa_benchmark_training/config_types.py b/benchmark/sdpa_benchmark_training/config_types.py new file mode 100644 index 00000000..ee9df6e5 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/config_types.py @@ -0,0 +1,183 @@ +""" +Core types for the SDPA benchmark configuration system. + +This module defines the dataclasses used to configure and collect results +from SDPA benchmarks. +""" + +from dataclasses import dataclass, field +from typing import Optional, List, Tuple + + +@dataclass +class ModelPreset: + """ + Represents a named model configuration preset. + + Defines the attention head configuration for a specific model architecture. + Can use either symmetric head dimensions (head_dim) or asymmetric + (head_dim_qk, head_dim_vo) for models like DeepSeek V3. + + Attributes: + name: Identifier for this preset (e.g., "llama3.1", "dsv3") + num_q_heads: Number of query heads + num_kv_heads: Number of key/value heads (differs from num_q_heads for GQA) + head_dim: Head dimension (used if head_dim_qk/vo not specified) + head_dim_qk: Head dimension for Q/K tensors (optional, for asymmetric) + head_dim_vo: Head dimension for V/O tensors (optional, for asymmetric) + + Example: + # Symmetric head dimensions (Llama 3.1) + LLAMA3_1 = ModelPreset( + name="llama3.1", + num_q_heads=64, + num_kv_heads=8, + head_dim=128, + ) + + # Asymmetric head dimensions (DeepSeek V3) + DSV3 = ModelPreset( + name="dsv3", + num_q_heads=128, + num_kv_heads=128, + head_dim_qk=192, + head_dim_vo=128, + ) + """ + + name: str + num_q_heads: int + num_kv_heads: int + head_dim: int = 128 + head_dim_qk: Optional[int] = None + head_dim_vo: Optional[int] = None + + def __post_init__(self): + """Resolve head dimensions after initialization.""" + if self.head_dim_qk is None: + self.head_dim_qk = self.head_dim + if self.head_dim_vo is None: + self.head_dim_vo = self.head_dim + + +@dataclass +class BenchmarkConfig: + """ + Configuration for a benchmark suite. + + Defines a set of benchmarks to run. The runner will expand this into + individual benchmark cases via cartesian product of: + models x seqlens x backends x data_types x attn_masks x deterministic_bwd + + Attributes: + name: Identifier for this config (used in output filenames) + models: List of ModelPreset to benchmark + seqlens: List of (q_seqlen, kv_seqlen) tuples + backends: List of backend names (e.g., ["cudnn", "flash_attention_4"]) + data_types: List of data types (e.g., ["bfloat16", "fp8"]) + attn_masks: List of attention masks (e.g., ["top_left", "no_mask"]) + profile_pass: Which pass to profile ("fwd", "bwd", or "both") + batch_size: Batch size for all benchmarks + num_iterations: Number of iterations per benchmark + num_warmup_iterations: Warmup iterations before measurement + skip_ref: Skip reference validation + deterministic_bwd: List of deterministic modes to test for backward pass + output_dir: Directory for output files + + Example: + CONFIG = BenchmarkConfig( + name="my_benchmark", + models=[LLAMA3_1, DSV3], + seqlens=[(4096, 4096), (8192, 8192)], + backends=["cudnn", "flash_attention_4"], + data_types=["bfloat16", "fp8"], + attn_masks=["top_left", "no_mask"], + profile_pass="fwd", + ) + """ + + name: str + models: List[ModelPreset] + seqlens: List[Tuple[int, int]] + backends: List[str] = field(default_factory=lambda: ["cudnn"]) + data_types: List[str] = field(default_factory=lambda: ["bfloat16"]) + attn_masks: List[str] = field(default_factory=lambda: ["top_left"]) + profile_pass: str = "fwd" + batch_size: int = 1 + num_iterations: int = 10 + num_warmup_iterations: int = 0 + skip_ref: bool = True + deterministic_bwd: List[bool] = field(default_factory=lambda: [False]) + output_dir: str = "../results" + + +@dataclass +class BenchmarkResult: + """ + Result from a single benchmark execution. + + Contains both the configuration that was run and the measured results. + + Attributes: + config_name: Name of the BenchmarkConfig this result belongs to + model_name: Name of the ModelPreset used + backend: Backend that was used + data_type: Data type that was used + attn_mask: Attention mask that was used + batch_size: Batch size + q_seqlen: Query sequence length + kv_seqlen: Key/value sequence length + num_q_heads: Number of query heads + num_kv_heads: Number of key/value heads + head_dim_qk: Head dimension for Q/K + head_dim_vo: Head dimension for V/O + profile_pass: Which pass was profiled + deterministic_bwd: Whether deterministic backward was used + fwd_time_ms: Forward pass time in milliseconds + bwd_time_ms: Backward pass time in milliseconds (0 if not run) + fwd_tflops: Forward pass throughput in TFLOPS + bwd_tflops: Backward pass throughput in TFLOPS + max_diff: Maximum difference vs reference (if validated) + num_iterations: Number of iterations run + success: Whether the benchmark completed successfully + error_message: Error message if benchmark failed + gpu_name: Name of the GPU used + cudnn_version: cuDNN version string + """ + + # Config identification + config_name: str + model_name: str + backend: str + data_type: str + attn_mask: str + + # Dimensions + batch_size: int + q_seqlen: int + kv_seqlen: int + num_q_heads: int + num_kv_heads: int + head_dim_qk: int + head_dim_vo: int + + # Execution options + profile_pass: str + deterministic_bwd: bool + + # Results + fwd_time_ms: float + bwd_time_ms: float + fwd_tflops: float + bwd_tflops: float + max_diff: float + num_iterations: int + + # Status + success: bool = True + error_message: Optional[str] = None + + # Metadata + gpu_name: Optional[str] = None + cudnn_version: Optional[str] = None + cudnn_backend_version: Optional[int] = None diff --git a/benchmark/sdpa_benchmark_training/configs/__init__.py b/benchmark/sdpa_benchmark_training/configs/__init__.py new file mode 100644 index 00000000..37bf2d67 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/configs/__init__.py @@ -0,0 +1,62 @@ +""" +Benchmark configuration loading utilities. + +This module provides functions to load benchmark configurations by name. +""" + +import importlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..config_types import BenchmarkConfig + + +def load_config(name: str) -> "BenchmarkConfig": + """ + Load a benchmark configuration by name. + + Configurations are Python modules in the configs directory. + Each module should define a CONFIG variable of type BenchmarkConfig. + + Args: + name: Name of the config (without .py extension) + + Returns: + BenchmarkConfig instance + + Raises: + ValueError: If config not found or doesn't define CONFIG + + Example: + config = load_config("mlperf") + print(config.name) # "mlperf" + """ + try: + module = importlib.import_module(f".{name}", package=__package__) + except ModuleNotFoundError: + raise ValueError(f"Config '{name}' not found. " f"Create a file at configs/{name}.py with a CONFIG variable.") + + if not hasattr(module, "CONFIG"): + raise ValueError(f"Config module '{name}' must define a CONFIG variable of type BenchmarkConfig") + + return module.CONFIG + + +def list_configs() -> list: + """ + List available config names. + + Returns: + List of config names (without .py extension) + """ + import os + from pathlib import Path + + configs_dir = Path(__file__).parent + configs = [] + + for f in configs_dir.iterdir(): + if f.suffix == ".py" and f.stem != "__init__": + configs.append(f.stem) + + return sorted(configs) diff --git a/benchmark/sdpa_benchmark_training/configs/dsv3.py b/benchmark/sdpa_benchmark_training/configs/dsv3.py new file mode 100644 index 00000000..404842c7 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/configs/dsv3.py @@ -0,0 +1,41 @@ +""" +DeepSeek V3 SDPA Benchmark Configuration + +Benchmarks DeepSeek V3-style MHA with asymmetric head dimensions. +Only causal (top_left) mask - no non-causal benchmarks needed. +Includes forward and backward pass benchmarking with deterministic mode options. + +Usage: + python -m benchmark.sdpa_benchmark_training.runner --config dsv3 + python -m benchmark.sdpa_benchmark_training.runner --config dsv3 --dry-run +""" + +from ..config_types import ModelPreset, BenchmarkConfig + +DSV3 = ModelPreset( + name="dsv3", + num_q_heads=128, + num_kv_heads=128, + head_dim_qk=192, + head_dim_vo=128, +) + +CONFIG = BenchmarkConfig( + name="dsv3", + models=[DSV3], + seqlens=[ + (32768, 32768), + (16384, 16384), + (8192, 8192), + (4096, 4096), + (2048, 2048), + ], + backends=["cudnn", "flash_attention_4"], + data_types=["bfloat16", "fp8"], + attn_masks=["top_left"], # Causal only + profile_pass="both", # Forward and backward + deterministic_bwd=[True], + batch_size=1, + num_iterations=10, + output_dir="results", +) diff --git a/benchmark/sdpa_benchmark_training/configs/llama.py b/benchmark/sdpa_benchmark_training/configs/llama.py new file mode 100644 index 00000000..803db837 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/configs/llama.py @@ -0,0 +1,39 @@ +""" +Llama 3.1 SDPA Benchmark Configuration + +Benchmarks Llama 3.1 405B-style GQA attention with both causal and non-causal masks. +Includes forward and backward pass benchmarking with deterministic mode options. + +Usage: + python -m benchmark.sdpa_benchmark_training.runner --config llama + python -m benchmark.sdpa_benchmark_training.runner --config llama --dry-run +""" + +from ..config_types import ModelPreset, BenchmarkConfig + +LLAMA3_1 = ModelPreset( + name="llama3.1", + num_q_heads=64, + num_kv_heads=8, + head_dim=128, +) + +CONFIG = BenchmarkConfig( + name="llama3.1", + models=[LLAMA3_1], + seqlens=[ + (32768, 32768), + (16384, 16384), + (8192, 8192), + (4096, 4096), + (2048, 2048), + ], + backends=["cudnn", "flash_attention_4"], + data_types=["bfloat16", "fp8"], + attn_masks=["top_left", "no_mask"], # Both causal and non-causal + profile_pass="both", # Forward and backward + deterministic_bwd=[False], + batch_size=1, + num_iterations=10, + output_dir="results", +) diff --git a/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_20260126_110621.csv b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_20260126_110621.csv new file mode 100644 index 00000000..cabd5912 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_20260126_110621.csv @@ -0,0 +1,41 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,fwd_time_ms,bwd_time_ms,fwd_tflops,bwd_tflops,max_diff,num_iterations,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +dsv3,dsv3,cudnn,bfloat16,top_left,1,32768,32768,128,128,192,128,both,True,24.538,87.230,1792.000,1311.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,32768,32768,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,16384,16384,128,128,192,128,both,True,6.476,22.025,1698.000,1298.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,16384,16384,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,8192,8192,128,128,192,128,both,True,1.831,5.875,1501.000,1217.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,8192,8192,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,4096,4096,128,128,192,128,both,True,0.519,1.650,1324.000,1083.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,4096,4096,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,2048,2048,128,128,192,128,both,True,0.163,0.520,1053.000,859.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,2048,2048,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, diff --git a/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_top_left_causal.png b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_top_left_causal.png new file mode 100644 index 00000000..a79e4809 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_top_left_causal.png differ diff --git a/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_20260126_110503.csv b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_20260126_110503.csv new file mode 100644 index 00000000..11f9fa9f --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_20260126_110503.csv @@ -0,0 +1,21 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,fwd_time_ms,bwd_time_ms,fwd_tflops,bwd_tflops,max_diff,num_iterations,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,32768,32768,64,8,128,128,both,False,10.436,30.513,1686.000,1441.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,32768,32768,64,8,128,128,both,False,20.041,59.879,1756.000,1469.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,32768,32768,64,8,128,128,both,False,8.317,25.675,2115.000,1713.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,32768,32768,64,8,128,128,both,False,16.521,49.482,2130.000,1778.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,16384,16384,64,8,128,128,both,False,2.672,8.018,1646.000,1371.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,16384,16384,64,8,128,128,both,False,5.037,15.384,1746.000,1429.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,16384,16384,64,8,128,128,both,False,2.182,6.730,2016.000,1634.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,16384,16384,64,8,128,128,both,False,4.240,12.707,2075.000,1731.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,8192,8192,64,8,128,128,both,False,0.704,2.150,1563.000,1279.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,8192,8192,64,8,128,128,both,False,1.313,3.980,1675.000,1381.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,8192,8192,64,8,128,128,both,False,0.591,1.851,1862.000,1485.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,8192,8192,64,8,128,128,both,False,1.133,3.385,1941.000,1624.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,4096,4096,64,8,128,128,both,False,0.212,0.622,1297.000,1105.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,4096,4096,64,8,128,128,both,False,0.350,1.090,1569.000,1261.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,4096,4096,64,8,128,128,both,False,0.172,0.555,1602.000,1239.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,4096,4096,64,8,128,128,both,False,0.299,0.941,1841.000,1461.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,2048,2048,64,8,128,128,both,False,0.067,0.209,1022.000,824.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,2048,2048,64,8,128,128,both,False,0.112,0.321,1232.000,1070.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,2048,2048,64,8,128,128,both,False,0.057,0.190,1215.000,905.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,2048,2048,64,8,128,128,both,False,0.090,0.284,1521.000,1210.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 diff --git a/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_no_mask.png b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_no_mask.png new file mode 100644 index 00000000..ac4bced3 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_top_left_causal.png b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_top_left_causal.png new file mode 100644 index 00000000..f5f3a306 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_top_left_causal.png differ diff --git a/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_20260126_110622.csv b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_20260126_110622.csv new file mode 100644 index 00000000..9dac38b4 --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_20260126_110622.csv @@ -0,0 +1,41 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,fwd_time_ms,bwd_time_ms,fwd_tflops,bwd_tflops,max_diff,num_iterations,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +dsv3,dsv3,cudnn,bfloat16,top_left,1,32768,32768,128,128,192,128,both,True,21.319,80.520,2063.000,1420.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,32768,32768,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,16384,16384,128,128,192,128,both,True,5.584,20.381,1969.000,1403.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,16384,16384,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,8192,8192,128,128,192,128,both,True,1.518,5.412,1811.000,1321.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,8192,8192,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,4096,4096,128,128,192,128,both,True,0.438,1.541,1570.000,1160.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,4096,4096,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,2048,2048,128,128,192,128,both,True,0.148,0.493,1158.000,906.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,2048,2048,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, diff --git a/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_top_left_causal.png b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_top_left_causal.png new file mode 100644 index 00000000..39cdbe44 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_top_left_causal.png differ diff --git a/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_20260126_110426.csv b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_20260126_110426.csv new file mode 100644 index 00000000..433b02ce --- /dev/null +++ b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_20260126_110426.csv @@ -0,0 +1,21 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,fwd_time_ms,bwd_time_ms,fwd_tflops,bwd_tflops,max_diff,num_iterations,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,32768,32768,64,8,128,128,both,False,8.663,28.331,2031.000,1552.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,32768,32768,64,8,128,128,both,False,17.400,56.680,2022.000,1552.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,32768,32768,64,8,128,128,both,False,5.942,23.707,2961.000,1855.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,32768,32768,64,8,128,128,both,False,11.782,45.618,2986.000,1928.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,16384,16384,64,8,128,128,both,False,2.202,7.361,1998.000,1494.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,16384,16384,64,8,128,128,both,False,4.396,14.124,2001.000,1557.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,16384,16384,64,8,128,128,both,False,1.577,6.233,2789.000,1764.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,16384,16384,64,8,128,128,both,False,3.025,11.772,2907.000,1868.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,8192,8192,64,8,128,128,both,False,0.571,1.976,1927.000,1391.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,8192,8192,64,8,128,128,both,False,1.118,3.670,1967.000,1498.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,8192,8192,64,8,128,128,both,False,0.434,1.728,2534.000,1591.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,8192,8192,64,8,128,128,both,False,0.807,3.154,2724.000,1743.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,4096,4096,64,8,128,128,both,False,0.164,0.574,1679.000,1198.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,4096,4096,64,8,128,128,both,False,0.289,1.016,1901.000,1352.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,4096,4096,64,8,128,128,both,False,0.129,0.527,2136.000,1305.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,4096,4096,64,8,128,128,both,False,0.213,0.884,2580.000,1554.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,2048,2048,64,8,128,128,both,False,0.054,0.191,1265.000,900.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,2048,2048,64,8,128,128,both,False,0.088,0.299,1559.000,1151.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,2048,2048,64,8,128,128,both,False,0.044,0.181,1574.000,947.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,2048,2048,64,8,128,128,both,False,0.066,0.275,2086.000,1251.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 diff --git a/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_no_mask.png b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_no_mask.png new file mode 100644 index 00000000..312bab6f Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_no_mask.png differ diff --git a/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_top_left_causal.png b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_top_left_causal.png new file mode 100644 index 00000000..dc6fd4d4 Binary files /dev/null and b/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_top_left_causal.png differ diff --git a/benchmark/sdpa_benchmark_training/runner.py b/benchmark/sdpa_benchmark_training/runner.py new file mode 100644 index 00000000..e1201eae --- /dev/null +++ b/benchmark/sdpa_benchmark_training/runner.py @@ -0,0 +1,505 @@ +""" +Benchmark runner with configuration expansion, execution, and result collection. + +This module provides the BenchmarkRunner class for running SDPA benchmarks +from configuration files, and a CLI entry point. + +Usage: + # Run from command line + python -m benchmark.sdpa_benchmark_training.runner --config mlperf + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --dry-run + + # Import and use programmatically + from benchmark.sdpa_benchmark_training.runner import BenchmarkRunner + from benchmark.sdpa_benchmark_training.configs import load_config + + config = load_config("mlperf") + runner = BenchmarkRunner() + results = runner.run_config(config) + runner.save_csv(results, config) +""" + +import itertools +import logging +import sys +from dataclasses import asdict +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Iterator, List, Optional + +from .config_types import BenchmarkConfig, BenchmarkResult, ModelPreset + +logger = logging.getLogger(__name__) + + +def log_environment_info(): + """Log environment information (torch, CUDA, cuDNN, flash_attn versions).""" + try: + import torch + + logger.info(f"torch.__version__ = '{torch.__version__}'") + logger.info(f"torch.version.cuda = '{torch.version.cuda}'") + logger.info(f"torch.cuda.is_available() = {torch.cuda.is_available()}") + if torch.cuda.is_available(): + logger.info(f"torch.cuda.device_count() = {torch.cuda.device_count()}") + logger.info(f"torch.cuda.current_device() = {torch.cuda.current_device()}") + logger.info(f"torch.cuda.get_device_name(torch.cuda.current_device()) = '{torch.cuda.get_device_name(torch.cuda.current_device())}'") + logger.info(f"torch.backends.cudnn.enabled = {torch.backends.cudnn.enabled}") + except ImportError: + logger.warning("torch not available") + + try: + import cudnn + + logger.info(f"cuDNN Backend Version: cudnn.backend_version() = {cudnn.backend_version()}") + logger.info(f"cuDNN Frontend Version: cudnn.__version__ = '{cudnn.__version__}'") + except ImportError: + logger.warning("cudnn not available") + + try: + import flash_attn + + logger.info(f"flash_attn.__version__ = '{flash_attn.__version__}'") + except ImportError: + pass # flash_attn is optional + + +class BenchmarkRunner: + """ + Runs benchmarks from configurations with cartesian product expansion. + + The runner takes a BenchmarkConfig and expands it into individual benchmark + cases via cartesian product of all configuration dimensions. Each case is + then executed and results are collected. + + Attributes: + verbose: Whether to print progress information + + Example: + runner = BenchmarkRunner(verbose=True) + config = load_config("mlperf") + + # Dry run to see what would be executed + for case in runner.expand_config(config): + print(case) + + # Actually run the benchmarks + results = runner.run_config(config) + runner.save_csv(results, config) + """ + + def __init__(self, verbose: bool = True): + """ + Initialize the runner. + + Args: + verbose: Whether to print progress information + """ + self.verbose = verbose + self._setup_logging() + + def _setup_logging(self): + """Configure logging based on verbosity setting.""" + level = logging.INFO if self.verbose else logging.WARNING + logging.basicConfig( + level=level, + format="[%(levelname)s] %(message)s", + stream=sys.stderr, + ) + + def expand_config(self, config: BenchmarkConfig) -> Iterator[Dict[str, Any]]: + """ + Expand a BenchmarkConfig into individual benchmark cases. + + Performs cartesian product expansion over: + models x seqlens x backends x data_types x attn_masks x deterministic_bwd + + Args: + config: BenchmarkConfig to expand + + Yields: + Dict containing all parameters for a single benchmark run + """ + for model, (q_seqlen, kv_seqlen), backend, data_type, attn_mask, det_bwd in itertools.product( + config.models, + config.seqlens, + config.backends, + config.data_types, + config.attn_masks, + config.deterministic_bwd, + ): + # Skip deterministic mode for forward-only runs + if det_bwd and config.profile_pass == "fwd": + continue + + yield { + "config_name": config.name, + "model": model, + "q_seqlen": q_seqlen, + "kv_seqlen": kv_seqlen, + "backend": backend, + "data_type": data_type, + "attn_mask": attn_mask, + "profile_pass": config.profile_pass, + "batch_size": config.batch_size, + "num_iterations": config.num_iterations, + "num_warmup_iterations": config.num_warmup_iterations, + "skip_ref": config.skip_ref, + "deterministic_bwd": det_bwd, + } + + def run_single(self, case: Dict[str, Any]) -> BenchmarkResult: + """ + Run a single benchmark case. + + Calls the run_benchmark() function from benchmark_single_sdpa.py + and wraps the result in a BenchmarkResult. + + Args: + case: Dict containing benchmark parameters (from expand_config) + + Returns: + BenchmarkResult with timing data or error information + """ + model: ModelPreset = case["model"] + + try: + # Import here to avoid circular imports and allow the module to be + # used even if torch/cudnn aren't installed (for dry-run mode) + from .benchmark_single_sdpa import run_benchmark + + result = run_benchmark( + batch_size=case["batch_size"], + q_seqlen=case["q_seqlen"], + kv_seqlen=case["kv_seqlen"], + num_q_heads=model.num_q_heads, + num_kv_heads=model.num_kv_heads, + head_dim_qk=model.head_dim_qk, + head_dim_vo=model.head_dim_vo, + data_type=case["data_type"], + backend=case["backend"], + attn_mask=case["attn_mask"], + profile_pass=case["profile_pass"], + num_iterations=case["num_iterations"], + num_warmup_iterations=case["num_warmup_iterations"], + skip_ref=case["skip_ref"], + deterministic_bwd=case["deterministic_bwd"], + ) + + return BenchmarkResult( + config_name=case["config_name"], + model_name=model.name, + backend=case["backend"], + data_type=case["data_type"], + attn_mask=case["attn_mask"], + batch_size=case["batch_size"], + q_seqlen=case["q_seqlen"], + kv_seqlen=case["kv_seqlen"], + num_q_heads=model.num_q_heads, + num_kv_heads=model.num_kv_heads, + head_dim_qk=model.head_dim_qk, + head_dim_vo=model.head_dim_vo, + profile_pass=case["profile_pass"], + deterministic_bwd=case["deterministic_bwd"], + fwd_time_ms=result["fwd_time_ms"], + bwd_time_ms=result["bwd_time_ms"], + fwd_tflops=result["fwd_tflops"], + bwd_tflops=result["bwd_tflops"], + max_diff=result["max_diff"], + num_iterations=case["num_iterations"], + success=True, + gpu_name=result.get("gpu_name"), + cudnn_version=result.get("cudnn_version"), + cudnn_backend_version=result.get("cudnn_backend_version"), + ) + + except Exception as e: + logger.error(f"Benchmark failed: {e}") + return BenchmarkResult( + config_name=case["config_name"], + model_name=model.name, + backend=case["backend"], + data_type=case["data_type"], + attn_mask=case["attn_mask"], + batch_size=case["batch_size"], + q_seqlen=case["q_seqlen"], + kv_seqlen=case["kv_seqlen"], + num_q_heads=model.num_q_heads, + num_kv_heads=model.num_kv_heads, + head_dim_qk=model.head_dim_qk, + head_dim_vo=model.head_dim_vo, + profile_pass=case["profile_pass"], + deterministic_bwd=case["deterministic_bwd"], + fwd_time_ms=float("inf"), + bwd_time_ms=float("inf"), + fwd_tflops=0.0, + bwd_tflops=0.0, + max_diff=0.0, + num_iterations=case["num_iterations"], + success=False, + error_message=str(e), + ) + + def run_config( + self, + config: BenchmarkConfig, + filter_model: Optional[str] = None, + filter_backend: Optional[str] = None, + filter_dtype: Optional[str] = None, + ) -> List[BenchmarkResult]: + """ + Run all benchmarks from a configuration. + + Args: + config: BenchmarkConfig to run + filter_model: Optional model name filter (substring match) + filter_backend: Optional backend filter (exact match) + filter_dtype: Optional data type filter (exact match) + + Returns: + List of BenchmarkResult for all executed cases + """ + # Log environment info at the start + log_environment_info() + logger.info("") # Blank line for readability + + results = [] + cases = list(self.expand_config(config)) + + # Apply filters + if filter_model: + cases = [c for c in cases if filter_model in c["model"].name] + if filter_backend: + cases = [c for c in cases if c["backend"] == filter_backend] + if filter_dtype: + cases = [c for c in cases if c["data_type"] == filter_dtype] + + if not cases: + logger.warning("No benchmark cases to run after applying filters") + return results + + logger.info(f"Running {len(cases)} benchmark cases from config '{config.name}'") + + for i, case in enumerate(cases, 1): + model = case["model"] + det_str = "det" if case["deterministic_bwd"] else "non-det" + logger.info( + f"[{i}/{len(cases)}] {model.name} | " + f"seq={case['q_seqlen']}x{case['kv_seqlen']} | " + f"{case['backend']} | {case['data_type']} | " + f"{case['attn_mask']} | {det_str}" + ) + + result = self.run_single(case) + results.append(result) + + if result.success: + fwd_info = f"fwd: {result.fwd_time_ms:.3f}ms ({result.fwd_tflops:.0f} TFLOPS)" + bwd_info = f"bwd: {result.bwd_time_ms:.3f}ms ({result.bwd_tflops:.0f} TFLOPS)" + logger.info(f" -> {fwd_info}, {bwd_info}") + else: + logger.warning(f" -> FAILED: {result.error_message}") + + return results + + def results_to_dataframe(self, results: List[BenchmarkResult]): + """ + Convert results to a pandas DataFrame. + + Args: + results: List of BenchmarkResult + + Returns: + pandas DataFrame with all result fields as columns + """ + import pandas as pd + + return pd.DataFrame([asdict(r) for r in results]) + + def save_csv( + self, + results: List[BenchmarkResult], + config: BenchmarkConfig, + output_path: Optional[Path] = None, + ) -> Path: + """ + Save results to a CSV file. + + Args: + results: List of BenchmarkResult + config: BenchmarkConfig (used for default filename) + output_path: Optional explicit output path + + Returns: + Path to the saved CSV file + """ + import pandas as pd + + df = self.results_to_dataframe(results) + + if output_path is None: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = output_dir / f"{config.name}_{timestamp}.csv" + + df.to_csv(output_path, index=False, float_format="%.3f") + logger.info(f"Results saved to {output_path}") + + return output_path + + +def main(): + """CLI entry point.""" + import argparse + + parser = argparse.ArgumentParser( + description="Run SDPA benchmarks from configuration files", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Run all benchmarks from mlperf config + python -m benchmark.sdpa_benchmark_training.runner --config mlperf + + # Dry run (show what would be executed) + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --dry-run + + # Filter by model name + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --filter llama3.1 + + # Filter by backend + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --backend cudnn + + # Skip chart generation + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --no-chart + """, + ) + + parser.add_argument( + "--config", + required=True, + help="Config name (e.g., 'mlperf'). Must be a Python file in configs/", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print benchmark cases without executing", + ) + parser.add_argument( + "--filter", + dest="filter_model", + help="Filter by model name (substring match)", + ) + parser.add_argument( + "--backend", + dest="filter_backend", + help="Filter by backend (exact match)", + ) + parser.add_argument( + "--dtype", + dest="filter_dtype", + help="Filter by data type (exact match)", + ) + parser.add_argument( + "--output", + type=Path, + help="Output path for CSV (default: artifacts/_.csv)", + ) + parser.add_argument( + "--no-chart", + action="store_true", + help="Skip chart generation", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List available configurations and exit", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Reduce output verbosity", + ) + + args = parser.parse_args() + + # Handle --list-configs + if args.list_configs: + from .configs import list_configs + + configs = list_configs() + print("Available configurations:") + for name in configs: + print(f" {name}") + return + + # Load config + from .configs import load_config + + try: + config = load_config(args.config) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + runner = BenchmarkRunner(verbose=not args.quiet) + + # Dry run mode + if args.dry_run: + cases = list(runner.expand_config(config)) + + # Apply filters for display + if args.filter_model: + cases = [c for c in cases if args.filter_model in c["model"].name] + if args.filter_backend: + cases = [c for c in cases if c["backend"] == args.filter_backend] + if args.filter_dtype: + cases = [c for c in cases if c["data_type"] == args.filter_dtype] + + print(f"Would run {len(cases)} benchmark cases from config '{config.name}':") + print() + for i, case in enumerate(cases, 1): + model = case["model"] + det_str = "det" if case["deterministic_bwd"] else "non-det" + print( + f" [{i}] {model.name} | " + f"seq={case['q_seqlen']}x{case['kv_seqlen']} | " + f"{case['backend']} | {case['data_type']} | " + f"{case['attn_mask']} | {det_str}" + ) + return + + # Run benchmarks + results = runner.run_config( + config, + filter_model=args.filter_model, + filter_backend=args.filter_backend, + filter_dtype=args.filter_dtype, + ) + + if not results: + print("No results to save", file=sys.stderr) + sys.exit(1) + + # Save CSV + csv_path = runner.save_csv(results, config, args.output) + + # Generate charts (separate chart per mask type for clarity) + if not args.no_chart: + try: + from .charts import generate_charts_by_mask + + df = runner.results_to_dataframe(results) + chart_paths = generate_charts_by_mask(df, config) + for path in chart_paths: + print(f"Chart saved to {path}") + except ImportError as e: + logger.warning(f"Could not generate chart (missing dependency): {e}") + except Exception as e: + logger.warning(f"Could not generate chart: {e}") + + print(f"Results saved to {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h index 9a7200aa..fe1a3500 100644 --- a/include/cudnn_frontend.h +++ b/include/cudnn_frontend.h @@ -22,6 +22,12 @@ #pragma once +// Suppress MSVC warning C4756 (overflow in constant arithmetic) that occurs +// in MSVC's header with certain compiler versions +#ifdef _MSC_VER +#pragma warning(disable : 4756) +#endif + /*! \mainpage CUDNN FRONTEND API * * \section Introduction diff --git a/include/cudnn_frontend/backend/execution_helpers.h b/include/cudnn_frontend/backend/execution_helpers.h index ac727c5e..cfc139e6 100644 --- a/include/cudnn_frontend/backend/execution_helpers.h +++ b/include/cudnn_frontend/backend/execution_helpers.h @@ -43,4 +43,57 @@ create_variant_pack(backend_descriptor& variant_pack, return {error_code_t::OK, ""}; } +inline error_t +create_variant_pack(backend_descriptor& variant_pack, + std::vector& device_ptrs, + std::vector const& uids, + void* workspace_ptr, + std::vector const& override_uids, + std::vector> const& override_shapes, + std::vector> const& override_strides) { + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Dynamic shapes requires cuDNN v9.18.0"}; + + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(91800, cudnn_ver_error); + + CUDNN_FRONTEND_UNUSED(override_uids); + CUDNN_FRONTEND_UNUSED(override_shapes); + CUDNN_FRONTEND_UNUSED(override_strides); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + variant_pack.get_ptr(), CUDNN_ATTR_VARIANT_PACK_WORKSPACE, CUDNN_TYPE_VOID_PTR, 1, &workspace_ptr)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(variant_pack.get_ptr(), + CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, + CUDNN_TYPE_VOID_PTR, + device_ptrs.size(), + device_ptrs.data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + variant_pack.get_ptr(), CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, CUDNN_TYPE_INT64, uids.size(), uids.data())); + +#if (CUDNN_VERSION >= 91800) + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(variant_pack.get_ptr(), + CUDNN_ATTR_VARIANT_PACK_OVERRIDE_UNIQUE_IDS, + CUDNN_TYPE_INT64, + override_uids.size(), + override_uids.data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(variant_pack.get_ptr(), + CUDNN_ATTR_VARIANT_PACK_OVERRIDE_SHAPES, + CUDNN_TYPE_VOID_PTR, + 1, + (void*)&override_shapes)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(variant_pack.get_ptr(), + CUDNN_ATTR_VARIANT_PACK_OVERRIDE_STRIDES, + CUDNN_TYPE_VOID_PTR, + 1, + (void*)&override_strides)); +#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(variant_pack.get_ptr())); + + return {error_code_t::OK, ""}; +} + } // namespace cudnn_frontend::detail diff --git a/include/cudnn_frontend/context.h b/include/cudnn_frontend/context.h index a743c412..8c894b22 100644 --- a/include/cudnn_frontend/context.h +++ b/include/cudnn_frontend/context.h @@ -10,6 +10,7 @@ class Context { DataType_t io_data_type = DataType_t::NOT_SET; int32_t target_sm_count = -1; int32_t target_sm_version = -1; + bool is_dynamic_shape_enabled = false; std::string name = ""; @@ -70,6 +71,17 @@ class Context { return *this; } + Context& + set_dynamic_shape_enabled(bool is_enabled) { + is_dynamic_shape_enabled = is_enabled; + return *this; + } + + bool + get_dynamic_shape_enabled() const { + return is_dynamic_shape_enabled; + } + int32_t get_target_sm_count() const { return target_sm_count; diff --git a/include/cudnn_frontend/cudnn_interface.h b/include/cudnn_frontend/cudnn_interface.h index 1874d082..c142c61f 100644 --- a/include/cudnn_frontend/cudnn_interface.h +++ b/include/cudnn_frontend/cudnn_interface.h @@ -181,7 +181,10 @@ class ICudnn { execute_cudnn_plan_with_uid(cudnnHandle_t handle, std::unordered_map const& tensor_uid_to_pointer_map, void* workspace_ptr, - int64_t plan_index) const { + int64_t plan_index, + std::vector const& override_uids, + std::vector> const& override_shapes, + std::vector> const& override_strides) const { // Make sure device pointer is provided for all uids expected for this plan std::vector device_ptrs; std::vector uids; @@ -196,10 +199,22 @@ class ICudnn { CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(plan_index)); - CUDNN_FE_LOG_LABEL_ENDL("INFO: Executing plan at index " << plan_index << "."); - - CHECK_CUDNN_FRONTEND_ERROR( - detail::execute(handle, plans.execution_plans[plan_index].get(), device_ptrs, uids, workspace_ptr)); + CUDNN_FE_LOG_LABEL_ENDL("INFO: Executing plan at index " << plan_index + << " with override uids: " << override_uids.size()); + + if (override_uids.size() == 0) { + CHECK_CUDNN_FRONTEND_ERROR( + detail::execute(handle, plans.execution_plans[plan_index].get(), device_ptrs, uids, workspace_ptr)); + } else { + CHECK_CUDNN_FRONTEND_ERROR(detail::execute(handle, + plans.execution_plans[plan_index].get(), + device_ptrs, + uids, + workspace_ptr, + override_uids, + override_shapes, + override_strides)); + } return {error_code_t::OK, ""}; } diff --git a/include/cudnn_frontend/graph_interface.h b/include/cudnn_frontend/graph_interface.h index 5e0acea0..5a5da0ad 100644 --- a/include/cudnn_frontend/graph_interface.h +++ b/include/cudnn_frontend/graph_interface.h @@ -47,6 +47,11 @@ class Graph : public ICudnn, public INode { std::unordered_map deserialized_pass_by_value; std::unordered_map>> deserialized_workspace_modifications; + // Cached values computed during build/deserialize, used during execute to avoid repeated collection. + // These are mutable because execute() is const but needs non-const access for pointer extraction. + mutable std::unordered_map cached_pass_by_value; + mutable std::unordered_map>> cached_workspace_modifications; + // char: 'x'=hex, 'd'=decimal, 'b'=base64 std::vector, char>> tensors_to_dump; @@ -82,10 +87,10 @@ class Graph : public ICudnn, public INode { error_t pre_validate_node() const override final { RETURN_CUDNN_FRONTEND_ERROR_IF( - (is_dynamic_shape_enabled || kernel_cache != nullptr) && detail::get_backend_version() < 90400, + (context.get_dynamic_shape_enabled() || kernel_cache != nullptr) && detail::get_backend_version() < 90400, error_code_t::GRAPH_NOT_SUPPORTED, "Dynamic shapes or kernel caching enabled, but cuDNN version < 9.4!"); - RETURN_CUDNN_FRONTEND_ERROR_IF(((is_dynamic_shape_enabled == false) && (kernel_cache != nullptr)), + RETURN_CUDNN_FRONTEND_ERROR_IF(((context.get_dynamic_shape_enabled() == false) && (kernel_cache != nullptr)), error_code_t::GRAPH_NOT_SUPPORTED, "Kernel caching enabled but dynamic shapes is disabled"); if (detail::get_backend_version() != detail::get_compiled_version()) { @@ -355,26 +360,19 @@ class Graph : public ICudnn, public INode { /////////////////////////////////////// //// PASS BY VALUE TENSOR HANDLING //// /////////////////////////////////////// - // Add pass_by_value data pointers to uid_to_pointer map - // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid while - // making the cuda graph. cuda graph will then keep a copy of the kernel parameters, meaning that at the time of - // launching the cuda_graph executable, tensor_to_pass_by_value being deallocated does not affect these cpu - // value's. + // Add pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. + // cuda graph will keep a copy of the kernel parameters, meaning that at the time of + // launching the cuda_graph executable, cached values being deallocated does not affect these cpu values. // No cuda graph nodes are required for handling fe owned pass by value tensors. - std::unordered_map tensor_to_pass_by_value; - CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); CHECK_CUDNN_FRONTEND_ERROR( - extend_tensor_map_with_pass_by_value_tensors_(uid_to_device_ptrs, tensor_to_pass_by_value)); + extend_tensor_map_with_pass_by_value_tensors_(uid_to_device_ptrs, cached_pass_by_value)); //////////////////////////// //// WORKSPACE HANDLING //// //////////////////////////// - // Get all types of extra calls that FE has to do on user workspace. - std::unordered_map>> workspace_modifications; - int64_t workspace_offset = 0; - CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset)); - - for (auto const &[uid, data] : workspace_modifications) { + // Using cached workspace modifications to avoid repeated tree traversal. + for (auto const &[uid, data] : cached_workspace_modifications) { const auto &[operation_type, offset, vec_data] = data; uid_to_device_ptrs[uid] = static_cast(workspace) + offset; @@ -505,26 +503,19 @@ class Graph : public ICudnn, public INode { /////////////////////////////////////// //// PASS BY VALUE TENSOR HANDLING //// /////////////////////////////////////// - // Add pass_by_value data pointers to uid_to_pointer map - // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid while - // making the cuda graph. cuda graph will then keep a copy of the kernel parameters, meaning that at the time of - // launching the cuda_graph executable, tensor_to_pass_by_value being deallocated does not affect these cpu - // value's. + // Add pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. + // cuda graph will keep a copy of the kernel parameters, meaning that at the time of + // launching the cuda_graph executable, cached values being deallocated does not affect these cpu values. // No cuda graph nodes are required for handling fe owned pass by value tensors. - std::unordered_map tensor_to_pass_by_value; - CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); CHECK_CUDNN_FRONTEND_ERROR( - extend_tensor_map_with_pass_by_value_tensors_(uid_to_device_ptrs, tensor_to_pass_by_value)); + extend_tensor_map_with_pass_by_value_tensors_(uid_to_device_ptrs, cached_pass_by_value)); ///////////////////////////////// //// WORKSPACE HANDLING //// ///////////////////////////////// - // Get all types of extra calls that FE has to do on user workspace. - std::unordered_map>> workspace_modifications; - int64_t workspace_offset = 0; - CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset)); - - for (auto const &[uid, data] : workspace_modifications) { + // Using cached workspace modifications to avoid repeated tree traversal. + for (auto const &[uid, data] : cached_workspace_modifications) { const auto &[operation_type, offset, vec_data] = data; uid_to_device_ptrs[uid] = static_cast(workspace) + offset; @@ -696,12 +687,22 @@ class Graph : public ICudnn, public INode { fe_workspace_size = get_fe_workspace_size_subtree(); + // Cache pass_by_value tensors and workspace modifications for fast execution. + // These are collected once here and reused in every execute() call to avoid + // repeated tree traversal and map allocation overhead. + CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(cached_pass_by_value)); + { + int64_t temp_offset = 0; + CHECK_CUDNN_FRONTEND_ERROR( + collect_tensors_in_workspace_subtree(cached_workspace_modifications, temp_offset)); + } + CUDNN_FE_LOG_BANNER(" 4/4 LOWERING TO BACKEND OPERATION GRAPH "); // The method here fuses all operations. There will be 1 operation graph in total. CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_operation_graph(handle)); - if (is_dynamic_shape_enabled && kernel_cache && !kernel_cache->is_finalized()) { + if (context.get_dynamic_shape_enabled() && kernel_cache && !kernel_cache->is_finalized()) { CUDNN_FE_LOG_BANNER(" BUILD KERNEL CACHE "); CHECK_CUDNN_FRONTEND_ERROR(kernel_cache->build(operation_graph->get_raw_desc())); } @@ -768,26 +769,18 @@ class Graph : public ICudnn, public INode { std::unordered_map &tensor_uid_to_pointer_map, void *workspace, void *user_impl = nullptr) { - // Add pass_by_value data pointers to tensor_uid_to_pointer map - // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid during - // execute. - std::unordered_map tensor_to_pass_by_value; - CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); - + // Add pass_by_value data pointers to tensor_uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. CHECK_CUDNN_FRONTEND_ERROR( - extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, tensor_to_pass_by_value)); + extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, cached_pass_by_value)); CHECK_CUDNN_FRONTEND_ERROR( make_variant_pack_replacements(tensor_uid_to_pointer_map, variant_pack_replacements)); - std::unordered_map>> workspace_modifications; - int64_t workspace_offset = 0; - CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset)); - - CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, workspace_modifications)); + CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, cached_workspace_modifications)); - CHECK_CUDNN_FRONTEND_ERROR( - extend_tensor_map_with_workspace_tensors_(tensor_uid_to_pointer_map, workspace, workspace_modifications)); + CHECK_CUDNN_FRONTEND_ERROR(extend_tensor_map_with_workspace_tensors_( + tensor_uid_to_pointer_map, workspace, cached_workspace_modifications)); // offset workspace by the already used fe graph workspace // this is where cudnn backend can start using workspace for its execution plans @@ -840,33 +833,29 @@ class Graph : public ICudnn, public INode { return execute(handle, tensor_uid_to_pointer_map, workspace); } - error_t execute_plan_at_index(cudnnHandle_t handle, std::unordered_map &tensor_uid_to_pointer_map, void *workspace, - int64_t plan_index) const { - // Add pass_by_value data pointers to uid_to_pointer map - // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid during - // execute. + int64_t plan_index, + std::vector const &override_uids, + std::vector> const &override_shapes, + std::vector> const &override_strides) const { + // Add pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. + // Object lifetime is controlled by cached_pass_by_value which persists for the Graph's lifetime. CUDNN_FE_LOG_BANNER(" EXECUTE PLAN AT INDEX for plan index " << plan_index << " "); - std::unordered_map tensor_to_pass_by_value; - CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); CHECK_CUDNN_FRONTEND_ERROR( - extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, tensor_to_pass_by_value)); + extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, cached_pass_by_value)); CHECK_CUDNN_FRONTEND_ERROR( make_variant_pack_replacements(tensor_uid_to_pointer_map, variant_pack_replacements)); - std::unordered_map>> workspace_modifications; - int64_t workspace_offset = 0; - CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset)); - - CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, workspace_modifications)); + CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, cached_workspace_modifications)); - CHECK_CUDNN_FRONTEND_ERROR( - extend_tensor_map_with_workspace_tensors_(tensor_uid_to_pointer_map, workspace, workspace_modifications)); + CHECK_CUDNN_FRONTEND_ERROR(extend_tensor_map_with_workspace_tensors_( + tensor_uid_to_pointer_map, workspace, cached_workspace_modifications)); // offset workspace by the already used fe graph workspace // this is where cudnn backend can start using workspace for its execution plans void *cudnn_workspace = static_cast(workspace) + fe_workspace_size; @@ -890,8 +879,13 @@ class Graph : public ICudnn, public INode { } } - CHECK_CUDNN_FRONTEND_ERROR( - execute_cudnn_plan_with_uid(handle, tensor_uid_to_pointer_map, cudnn_workspace, plan_index)); + CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plan_with_uid(handle, + tensor_uid_to_pointer_map, + cudnn_workspace, + plan_index, + override_uids, + override_shapes, + override_strides)); CUDNN_FE_LOG_BANNER(" EXECUTE PLAN AT INDEX ALL OK for plan index " << plan_index << " "); return {error_code_t::OK, ""}; @@ -900,27 +894,23 @@ class Graph : public ICudnn, public INode { error_t execute(cudnnHandle_t handle, std::unordered_map &tensor_uid_to_pointer_map, - void *workspace) const { - // Add pass_by_value data pointers to uid_to_pointer map - // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid during - // execute. + void *workspace, + std::vector const &override_uids, + std::vector> const &override_shapes, + std::vector> const &override_strides) const { + // Add pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. CUDNN_FE_LOG_BANNER(" EXECUTE PLAN "); - std::unordered_map tensor_to_pass_by_value; - CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); CHECK_CUDNN_FRONTEND_ERROR( - extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, tensor_to_pass_by_value)); + extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, cached_pass_by_value)); CHECK_CUDNN_FRONTEND_ERROR( make_variant_pack_replacements(tensor_uid_to_pointer_map, variant_pack_replacements)); - std::unordered_map>> workspace_modifications; - int64_t workspace_offset = 0; - CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset)); - - CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, workspace_modifications)); + CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, cached_workspace_modifications)); - CHECK_CUDNN_FRONTEND_ERROR( - extend_tensor_map_with_workspace_tensors_(tensor_uid_to_pointer_map, workspace, workspace_modifications)); + CHECK_CUDNN_FRONTEND_ERROR(extend_tensor_map_with_workspace_tensors_( + tensor_uid_to_pointer_map, workspace, cached_workspace_modifications)); // offset workspace by the already used fe graph workspace // this is where cudnn backend can start using workspace for its execution plans void *cudnn_workspace = static_cast(workspace) + fe_workspace_size; @@ -944,8 +934,54 @@ class Graph : public ICudnn, public INode { } } + CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plan_with_uid(handle, + tensor_uid_to_pointer_map, + cudnn_workspace, + plans.candidate, + override_uids, + override_shapes, + override_strides)); + + CUDNN_FE_LOG_BANNER(" EXECUTE PLAN ALL OK "); + return {error_code_t::OK, ""}; + } + + error_t + execute_plan_at_index(cudnnHandle_t handle, + std::unordered_map &tensor_uid_to_pointer_map, + void *workspace, + int64_t plan_index) const { + // Add pass_by_value data pointers to uid_to_pointer map + // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid during + // execute. + CHECK_CUDNN_FRONTEND_ERROR( + execute_plan_at_index(handle, tensor_uid_to_pointer_map, workspace, plan_index, {}, {}, {})); + return {error_code_t::OK, ""}; + } + + error_t + execute(cudnnHandle_t handle, + std::unordered_map &tensor_uid_to_pointer_map, + void *workspace) const { + // Add pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. + CUDNN_FE_LOG_BANNER(" EXECUTE PLAN "); + CHECK_CUDNN_FRONTEND_ERROR( - execute_cudnn_plan_with_uid(handle, tensor_uid_to_pointer_map, cudnn_workspace, plans.candidate)); + extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, cached_pass_by_value)); + CHECK_CUDNN_FRONTEND_ERROR( + make_variant_pack_replacements(tensor_uid_to_pointer_map, variant_pack_replacements)); + + CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, cached_workspace_modifications)); + + CHECK_CUDNN_FRONTEND_ERROR(extend_tensor_map_with_workspace_tensors_( + tensor_uid_to_pointer_map, workspace, cached_workspace_modifications)); + // offset workspace by the already used fe graph workspace + // this is where cudnn backend can start using workspace for its execution plans + void *cudnn_workspace = static_cast(workspace) + fe_workspace_size; + + CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plan_with_uid( + handle, tensor_uid_to_pointer_map, cudnn_workspace, plans.candidate, {}, {}, {})); CUDNN_FE_LOG_BANNER(" EXECUTE PLAN ALL OK "); return {error_code_t::OK, ""}; @@ -1120,6 +1156,10 @@ class Graph : public ICudnn, public INode { fe_workspace_size = j["fe_workspace_size"]; + // Initialize the execution caches from deserialized data + cached_pass_by_value = deserialized_pass_by_value; + cached_workspace_modifications = deserialized_workspace_modifications; + if (j.contains("tensors_to_dump")) { auto dump_uids = j["tensors_to_dump"].get>>(); for (auto const &[uid, fmt] : dump_uids) { @@ -1484,11 +1524,12 @@ class Graph : public ICudnn, public INode { // Go over each subnode and serialize them. json full_json; - full_json["context"]["name"] = context.get_name(); - full_json["context"]["compute_data_type"] = context.get_compute_data_type(); - full_json["context"]["intermediate_data_type"] = context.get_intermediate_data_type(); - full_json["context"]["io_data_type"] = context.get_io_data_type(); - full_json["context"]["sm_count"] = context.get_target_sm_count(); + full_json["context"]["name"] = context.get_name(); + full_json["context"]["compute_data_type"] = context.get_compute_data_type(); + full_json["context"]["intermediate_data_type"] = context.get_intermediate_data_type(); + full_json["context"]["io_data_type"] = context.get_io_data_type(); + full_json["context"]["sm_count"] = context.get_target_sm_count(); + full_json["context"]["is_dynamic_shape_enabled"] = context.get_dynamic_shape_enabled(); full_json.update(R"( {"tag": "GRAPH"})"_json); full_json["nodes"]; @@ -1575,7 +1616,7 @@ class Graph : public ICudnn, public INode { size_t key() override final { - return key(is_dynamic_shape_enabled); + return key(context.get_dynamic_shape_enabled()); } // TODO: temparorily placed in graphs class. This function needs to be a free standing function. @@ -1599,6 +1640,9 @@ class Graph : public ICudnn, public INode { if (j_context.contains("sm_count") && !j_context["sm_count"].is_null()) { context.set_target_sm_count(j_context["sm_count"].get()); } + if (j_context.contains("is_dynamic_shape_enabled") && !j_context["is_dynamic_shape_enabled"].is_null()) { + context.set_dynamic_shape_enabled(j_context["is_dynamic_shape_enabled"].get()); + } } std::map> created_tensors; @@ -1943,7 +1987,8 @@ Graph::set_compute_data_type(DataType_t const type) { inline Graph & Graph::set_dynamic_shape_enabled(bool is_enabled) { - is_dynamic_shape_enabled = is_enabled; + context.set_dynamic_shape_enabled(is_enabled); + this->is_dynamic_shape_enabled = is_enabled; return *this; } diff --git a/include/cudnn_frontend/graph_properties.h b/include/cudnn_frontend/graph_properties.h index 3aac8382..03b31b56 100644 --- a/include/cudnn_frontend/graph_properties.h +++ b/include/cudnn_frontend/graph_properties.h @@ -2183,9 +2183,10 @@ class SDPA_fp8_backward_attributes : public Attributes dropout_probability; std::optional attn_scale_value; @@ -2235,7 +2236,8 @@ class SDPA_fp8_backward_attributes : public Attributes value) { @@ -2304,6 +2306,12 @@ class SDPA_fp8_backward_attributes : public Attributes { managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { CUDNN_FRONTEND_UNUSED(raw_operations); - CUDNN_FE_LOG_LABEL("INFO: Building BatchNormNode operations " << attributes.name); + CUDNN_FE_LOG_LABEL("INFO: Building BatchNormNode operations " << attributes.name << " "); - std::vector peer_stats; - for (auto const& peer_stat : attributes.peer_stats) { - peer_stats.emplace_back(std::move(*(tensors[peer_stat->get_uid()]))); - } + // Create operation by directly calling cuDNN backend API + Operation_v8 batchnorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + batchnorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode to BATCH_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::BATCH_NORM, cudnn_norm_mode)); - auto&& batchnorm_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); - batchnorm_operation_builder.setNormalizationMode(NormMode_t::BATCH_NORM) - .setNormFwdPhase(NormFwdPhase_t::TRAINING); + // Set forward phase to TRAINING + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormFwdPhase_t::TRAINING, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Batchnorm_attributes::input_names::X); - batchnorm_operation_builder.setxDesc(*(tensors[X->second->get_uid()])); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set saved mean and inv_variance CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(MEAN, Batchnorm_attributes::output_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, Batchnorm_attributes::output_names::INV_VARIANCE); - batchnorm_operation_builder.setSavedMeanAndInvVar(*(tensors[MEAN->second->get_uid()]), - *(tensors[INV_VARIANCE->second->get_uid()])); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + // Set scale and bias tensors CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Batchnorm_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, Batchnorm_attributes::input_names::BIAS); - batchnorm_operation_builder.setScaleAndBias(*(tensors[SCALE->second->get_uid()]), - *(tensors[BIAS->second->get_uid()])); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + + // Set epsilon tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Batchnorm_attributes::input_names::EPSILON); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + // Check for running stats bool has_running_stats = true; auto it = attributes.inputs.find(Batchnorm_attributes::input_names::PREV_RUNNING_MEAN); if (it == attributes.inputs.end() || it->second == nullptr) { @@ -108,51 +168,86 @@ class BatchNormNode : public NodeCRTP { } if (has_running_stats) { + // Set momentum (exp decay factor) + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MOMENTUM, Batchnorm_attributes::input_names::MOMENTUM); + auto momentum_desc = tensors.at(MOMENTUM->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EXP_AVG_FACTOR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &momentum_desc)); + + // Set prev running mean and var CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(PREV_RUNNING_MEAN, Batchnorm_attributes::input_names::PREV_RUNNING_MEAN); + auto prev_mean_desc = tensors.at(PREV_RUNNING_MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &prev_mean_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(PREV_RUNNING_VAR, Batchnorm_attributes::input_names::PREV_RUNNING_VAR); - batchnorm_operation_builder.setPrevRunningMeanAndVar(*(tensors[PREV_RUNNING_MEAN->second->get_uid()]), - *(tensors[PREV_RUNNING_VAR->second->get_uid()])); + auto prev_var_desc = tensors.at(PREV_RUNNING_VAR->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_VAR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &prev_var_desc)); + // Set next running mean and var CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(NEXT_RUNNING_MEAN, Batchnorm_attributes::output_names::NEXT_RUNNING_MEAN); + auto next_mean_desc = tensors.at(NEXT_RUNNING_MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &next_mean_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(NEXT_RUNNING_VAR, Batchnorm_attributes::output_names::NEXT_RUNNING_VAR); - batchnorm_operation_builder.setNextRunningMeanAndVar(*(tensors[NEXT_RUNNING_MEAN->second->get_uid()]), - *(tensors[NEXT_RUNNING_VAR->second->get_uid()])); + auto next_var_desc = tensors.at(NEXT_RUNNING_VAR->second->get_uid())->get_raw_desc(); - CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MOMENTUM, Batchnorm_attributes::input_names::MOMENTUM); - batchnorm_operation_builder.setExpDecayFactorTensor(*(tensors[MOMENTUM->second->get_uid()])); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_VAR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &next_var_desc)); } - CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Batchnorm_attributes::input_names::EPSILON); - batchnorm_operation_builder.setEpsilonTensor(*(tensors[EPSILON->second->get_uid()])); - + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Batchnorm_attributes::output_names::Y); - batchnorm_operation_builder.setyDesc(*(tensors[Y->second->get_uid()])); - - batchnorm_operation_builder.setPeerStatTensor(peer_stats); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = batchnorm_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = batchnorm_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + // Set peer stat tensors if any + if (!attributes.peer_stats.empty()) { + std::vector peer_stat_descs; + for (auto const& peer_stat : attributes.peer_stats) { + peer_stat_descs.push_back(tensors.at(peer_stat->get_uid())->get_raw_desc()); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PEER_STAT_DESCS, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + peer_stat_descs.size(), + peer_stat_descs.data())); } -#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(batchnorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(batchnorm_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/batchnorm_inference.h b/include/cudnn_frontend/node/batchnorm_inference.h index aef83c23..9fb433f9 100644 --- a/include/cudnn_frontend/node/batchnorm_inference.h +++ b/include/cudnn_frontend/node/batchnorm_inference.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -47,49 +44,98 @@ class BatchnormInferenceNode : public NodeCRTP { managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { CUDNN_FRONTEND_UNUSED(raw_operations); - CUDNN_FE_LOG_LABEL("INFO: Building BatchnormInferenceNode operations " << attributes.name << "..."); + CUDNN_FE_LOG_LABEL("INFO: Building BatchnormInferenceNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 batchnorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + batchnorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode to BATCH_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::BATCH_NORM, cudnn_norm_mode)); - auto&& batchnorm_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR); - batchnorm_operation_builder.setNormalizationMode(NormMode_t::BATCH_NORM) - .setNormFwdPhase(NormFwdPhase_t::INFERENCE); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + // Set forward phase to INFERENCE + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormFwdPhase_t::INFERENCE, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Batchnorm_inference_attributes::input_names::X); - batchnorm_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set mean and inv_variance (as inputs for inference) CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, Batchnorm_inference_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, Batchnorm_inference_attributes::input_names::INV_VARIANCE); - batchnorm_operation_builder.setSavedMeanAndInvVar(*(tensors.at(MEAN->second->get_uid())), - *(tensors.at(INV_VARIANCE->second->get_uid()))); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + // Set scale and bias tensors CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Batchnorm_inference_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, Batchnorm_inference_attributes::input_names::BIAS); - batchnorm_operation_builder.setScaleAndBias(*(tensors.at(SCALE->second->get_uid())), - *(tensors.at(BIAS->second->get_uid()))); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Batchnorm_inference_attributes::output_names::Y); - batchnorm_operation_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = batchnorm_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = batchnorm_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(batchnorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(batchnorm_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/bn_finalize.h b/include/cudnn_frontend/node/bn_finalize.h index c0b97c4e..99d3ec0c 100644 --- a/include/cudnn_frontend/node/bn_finalize.h +++ b/include/cudnn_frontend/node/bn_finalize.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -63,77 +60,189 @@ class BatchNormFinalizeNode : public NodeCRTP { managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { CUDNN_FRONTEND_UNUSED(raw_operations); - CUDNN_FE_LOG_LABEL("INFO:Building BatchNormFinalizeNode operations " << attributes.name << " "); + CUDNN_FE_LOG_LABEL("INFO: " << "Building BatchNormFinalizeNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 bn_finalize_operation; + + _CUDNN_CHECK_CUDNN_ERROR(bn_finalize_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR)); + + // Set BN finalize mode + cudnnBnFinalizeStatsMode_t bn_finalize_mode = CUDNN_BN_FINALIZE_STATISTICS_TRAINING; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE, + CUDNN_TYPE_BN_FINALIZE_STATS_MODE, + 1, + &bn_finalize_mode)); - // Create the batchnorm operation. - auto&& batchnorm_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR); - batchnorm_operation_builder.setComputeType(CUDNN_DATA_FLOAT) - .setBNFinalizeMode(CUDNN_BN_FINALIZE_STATISTICS_TRAINING); + // Set compute type (math precision) + cudnnDataType_t compute_type = CUDNN_DATA_FLOAT; + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC, + CUDNN_TYPE_DATA_TYPE, + 1, + &compute_type)); + + // Set SUM input tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SUM, BN_finalize_attributes::input_names::SUM); - batchnorm_operation_builder.setSumDesc(*(tensors.at(SUM->second->get_uid()))); + auto sum_desc = tensors.at(SUM->second->get_uid())->get_raw_desc(); - CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SQ_SUM, BN_finalize_attributes::input_names::SQ_SUM); - batchnorm_operation_builder.setSqSumDesc(*(tensors.at(SQ_SUM->second->get_uid()))); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &sum_desc)); - CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_SCALE, BN_finalize_attributes::output_names::EQ_SCALE); - CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_BIAS, BN_finalize_attributes::output_names::EQ_BIAS); - batchnorm_operation_builder.setEqScaleAndBias(*(tensors.at(EQ_SCALE->second->get_uid())), - *(tensors.at(EQ_BIAS->second->get_uid()))); + // Set SQ_SUM input tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SQ_SUM, BN_finalize_attributes::input_names::SQ_SUM); + auto sq_sum_desc = tensors.at(SQ_SUM->second->get_uid())->get_raw_desc(); - CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(MEAN, BN_finalize_attributes::output_names::MEAN); - CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, BN_finalize_attributes::output_names::INV_VARIANCE); - batchnorm_operation_builder.setSavedMeanAndInvVar(*(tensors.at(MEAN->second->get_uid())), - *(tensors.at(INV_VARIANCE->second->get_uid()))); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &sq_sum_desc)); + // Set SCALE input tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, BN_finalize_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set BIAS input tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, BN_finalize_attributes::input_names::BIAS); - batchnorm_operation_builder.setScaleAndBias(*(tensors.at(SCALE->second->get_uid())), - *(tensors.at(BIAS->second->get_uid()))); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + + // Set EQ_SCALE output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_SCALE, BN_finalize_attributes::output_names::EQ_SCALE); + auto eq_scale_desc = tensors.at(EQ_SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_scale_desc)); + + // Set EQ_BIAS output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_BIAS, BN_finalize_attributes::output_names::EQ_BIAS); + auto eq_bias_desc = tensors.at(EQ_BIAS->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_bias_desc)); + + // Set PREV_RUNNING_MEAN input tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(PREV_RUNNING_MEAN, BN_finalize_attributes::input_names::PREV_RUNNING_MEAN); + auto prev_running_mean_desc = tensors.at(PREV_RUNNING_MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &prev_running_mean_desc)); + + // Set PREV_RUNNING_VAR input tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(PREV_RUNNING_VAR, BN_finalize_attributes::input_names::PREV_RUNNING_VAR); - batchnorm_operation_builder.setPrevRunningMeanAndVar(*(tensors.at(PREV_RUNNING_MEAN->second->get_uid())), - *(tensors.at(PREV_RUNNING_VAR->second->get_uid()))); + auto prev_running_var_desc = tensors.at(PREV_RUNNING_VAR->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &prev_running_var_desc)); + + // Set NEXT_RUNNING_MEAN output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(NEXT_RUNNING_MEAN, BN_finalize_attributes::output_names::NEXT_RUNNING_MEAN); + auto next_running_mean_desc = tensors.at(NEXT_RUNNING_MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &next_running_mean_desc)); + + // Set NEXT_RUNNING_VAR output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(NEXT_RUNNING_VAR, BN_finalize_attributes::output_names::NEXT_RUNNING_VAR); - batchnorm_operation_builder.setNextRunningMeanAndVar(*(tensors.at(NEXT_RUNNING_MEAN->second->get_uid())), - *(tensors.at(NEXT_RUNNING_VAR->second->get_uid()))); + auto next_running_var_desc = tensors.at(NEXT_RUNNING_VAR->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &next_running_var_desc)); + + // Set MEAN output tensor (saved mean) + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(MEAN, BN_finalize_attributes::output_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + // Set INV_VARIANCE output tensor (saved inv std) + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, BN_finalize_attributes::output_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set EPSILON tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, BN_finalize_attributes::input_names::EPSILON); - batchnorm_operation_builder.setEpsilonTensor(*(tensors.at(EPSILON->second->get_uid()))); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + + // Set MOMENTUM tensor (exp average factor) CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MOMENTUM, BN_finalize_attributes::input_names::MOMENTUM); - batchnorm_operation_builder.setExpDecayFactorTensor(*(tensors.at(MOMENTUM->second->get_uid()))); + auto momentum_desc = tensors.at(MOMENTUM->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &momentum_desc)); + // Set ACCUM_COUNT tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(ACCUM_COUNT, BN_finalize_attributes::input_names::ACCUM_COUNT); - batchnorm_operation_builder.setAccumCountTensor(*(tensors.at(ACCUM_COUNT->second->get_uid()))); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = batchnorm_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = batchnorm_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto accum_count_desc = tensors.at(ACCUM_COUNT->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &accum_count_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(bn_finalize_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(bn_finalize_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/conv_dgrad.h b/include/cudnn_frontend/node/conv_dgrad.h index 7d382bb7..a2e2e4e8 100644 --- a/include/cudnn_frontend/node/conv_dgrad.h +++ b/include/cudnn_frontend/node/conv_dgrad.h @@ -1,9 +1,5 @@ #pragma once -#include "../../cudnn_frontend_ConvDesc.h" -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -77,52 +73,120 @@ class DgradNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: Building DgradNode operations " << attributes.name << " "); - // dgrad descriptor + // Create dgrad descriptor by directly calling cuDNN backend API + ConvDesc_v8 dgrad_descriptor; int64_t const spatial_dim_count = attributes.get_pre_padding().size(); - auto dgrad_descriptor = cudnn_frontend::ConvDescBuilder() - .setComputeType(attributes.compute_data_type) - .setMathMode(attributes.math_mode) - .setSpatialDimCount(spatial_dim_count) - .setSpatialStride(spatial_dim_count, attributes.get_stride().data()) - .setPrePadding(spatial_dim_count, attributes.get_pre_padding().data()) - .setPostPadding(spatial_dim_count, attributes.get_post_padding().data()) - .setDilation(spatial_dim_count, attributes.get_dilation().data()) - .build(); - - // Create the dgrad operation. - auto&& dgrad_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); + + _CUDNN_CHECK_CUDNN_ERROR( + dgrad_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR)); + + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + cudnnConvolutionMode_t mode = detail::convert_to_cudnn_type(attributes.math_mode); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + dgrad_descriptor.get_raw_desc(), CUDNN_ATTR_CONVOLUTION_CONV_MODE, CUDNN_TYPE_CONVOLUTION_MODE, 1, &mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, + CUDNN_TYPE_INT64, + 1, + &spatial_dim_count)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_pre_padding().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_post_padding().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_DILATIONS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_dilation().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_stride().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(dgrad_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(dgrad_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 dgrad_operation; + + _CUDNN_CHECK_CUDNN_ERROR(dgrad_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)); CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Conv_dgrad_attributes::output_names::DX); - dgrad_operation_builder.setdxDesc(*(tensors.at(DX->second->get_uid()))); + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(W, Conv_dgrad_attributes::input_names::W); - dgrad_operation_builder.setwDesc(*(tensors.at(W->second->get_uid()))); + auto w_desc = tensors.at(W->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &w_desc)); CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Conv_dgrad_attributes::input_names::DY); - dgrad_operation_builder.setdyDesc(*(tensors.at(DY->second->get_uid()))); - - dgrad_operation_builder.setcDesc(dgrad_descriptor).setAlpha(1.f).setBeta(0.f); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = dgrad_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = dgrad_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + auto conv_desc_ptr = dgrad_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &conv_desc_ptr)); + + float alpha = 1.0f; + float beta = 0.0f; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA, + CUDNN_TYPE_FLOAT, + 1, + &alpha)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA, + CUDNN_TYPE_FLOAT, + 1, + &beta)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(dgrad_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(dgrad_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/conv_fprop.h b/include/cudnn_frontend/node/conv_fprop.h index bcdf9ca1..f8f22ec7 100644 --- a/include/cudnn_frontend/node/conv_fprop.h +++ b/include/cudnn_frontend/node/conv_fprop.h @@ -1,9 +1,5 @@ #pragma once -#include "../../cudnn_frontend_ConvDesc.h" -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -89,52 +85,140 @@ class ConvolutionNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: Building ConvolutionNode operations " << attributes.name << " "); - // convolution descriptor + // Create convolution descriptor by directly calling cuDNN backend API + ConvDesc_v8 convolution_descriptor; int64_t const spatial_dim_count = attributes.get_pre_padding().size(); - auto convolution_descriptor = cudnn_frontend::ConvDescBuilder() - .setComputeType(attributes.compute_data_type) - .setMathMode(attributes.math_mode) - .setSpatialDimCount(spatial_dim_count) - .setSpatialStride(spatial_dim_count, attributes.get_stride().data()) - .setPrePadding(spatial_dim_count, attributes.get_pre_padding().data()) - .setPostPadding(spatial_dim_count, attributes.get_post_padding().data()) - .setDilation(spatial_dim_count, attributes.get_dilation().data()) - .build(); - - // Create the convolution operation. - auto&& convolution_operation_builder = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR); + _CUDNN_CHECK_CUDNN_ERROR( + convolution_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR)); + + // Set compute type + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + // Set convolution mode + cudnnConvolutionMode_t mode = detail::convert_to_cudnn_type(attributes.math_mode); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_CONV_MODE, + CUDNN_TYPE_CONVOLUTION_MODE, + 1, + &mode)); + + // Set spatial dimensions + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, + CUDNN_TYPE_INT64, + 1, + &spatial_dim_count)); + + // Set pre-padding + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_pre_padding().data())); + + // Set post-padding + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_post_padding().data())); + + // Set dilation + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_DILATIONS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_dilation().data())); + + // Set strides + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_stride().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(convolution_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(convolution_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 convolution_operation; + + _CUDNN_CHECK_CUDNN_ERROR(convolution_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Conv_fprop_attributes::input_names::X); - convolution_operation_builder.setxDesc(*(tensors[X->second->get_uid()])); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + // Set weight tensor W CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(W, Conv_fprop_attributes::input_names::W); - convolution_operation_builder.setwDesc(*(tensors[W->second->get_uid()])); + auto w_desc = tensors.at(W->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &w_desc)); + + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Conv_fprop_attributes::output_names::Y); - convolution_operation_builder.setyDesc(*(tensors[Y->second->get_uid()])); - - convolution_operation_builder.setcDesc(convolution_descriptor).setAlpha(1.f).setBeta(0.f); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = convolution_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = convolution_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + // Set convolution descriptor + auto conv_desc_ptr = convolution_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &conv_desc_ptr)); + + // Set alpha and beta + float alpha = 1.0f; + float beta = 0.0f; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA, + CUDNN_TYPE_FLOAT, + 1, + &alpha)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA, + CUDNN_TYPE_FLOAT, + 1, + &beta)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(convolution_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(convolution_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/conv_wgrad.h b/include/cudnn_frontend/node/conv_wgrad.h index 2082f1b0..2f9b478c 100644 --- a/include/cudnn_frontend/node/conv_wgrad.h +++ b/include/cudnn_frontend/node/conv_wgrad.h @@ -1,9 +1,5 @@ #pragma once -#include "../../cudnn_frontend_ConvDesc.h" -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -73,52 +69,120 @@ class WgradNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: Building WgradNode operations " << attributes.name << " "); - // wgrad descriptor + // Create wgrad descriptor by directly calling cuDNN backend API + ConvDesc_v8 wgrad_descriptor; int64_t const spatial_dim_count = attributes.get_pre_padding().size(); - auto wgrad_descriptor = cudnn_frontend::ConvDescBuilder() - .setComputeType(attributes.compute_data_type) - .setMathMode(attributes.math_mode) - .setSpatialDimCount(spatial_dim_count) - .setSpatialStride(spatial_dim_count, attributes.get_stride().data()) - .setPrePadding(spatial_dim_count, attributes.get_pre_padding().data()) - .setPostPadding(spatial_dim_count, attributes.get_post_padding().data()) - .setDilation(spatial_dim_count, attributes.get_dilation().data()) - .build(); - - // Create the wgrad operation. - auto&& wgrad_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); + + _CUDNN_CHECK_CUDNN_ERROR( + wgrad_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR)); + + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + cudnnConvolutionMode_t mode = detail::convert_to_cudnn_type(attributes.math_mode); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + wgrad_descriptor.get_raw_desc(), CUDNN_ATTR_CONVOLUTION_CONV_MODE, CUDNN_TYPE_CONVOLUTION_MODE, 1, &mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, + CUDNN_TYPE_INT64, + 1, + &spatial_dim_count)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_pre_padding().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_post_padding().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_DILATIONS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_dilation().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_stride().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(wgrad_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(wgrad_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 wgrad_operation; + + _CUDNN_CHECK_CUDNN_ERROR(wgrad_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR)); CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Conv_wgrad_attributes::input_names::X); - wgrad_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Conv_wgrad_attributes::input_names::DY); - wgrad_operation_builder.setdyDesc(*(tensors.at(DY->second->get_uid()))); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DW, Conv_wgrad_attributes::output_names::DW); - wgrad_operation_builder.setdwDesc(*(tensors.at(DW->second->get_uid()))); - - wgrad_operation_builder.setcDesc(wgrad_descriptor).setAlpha(1.f).setBeta(0.f); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = wgrad_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = wgrad_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto dw_desc = tensors.at(DW->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dw_desc)); + + auto conv_desc_ptr = wgrad_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &conv_desc_ptr)); + + float alpha = 1.0f; + float beta = 0.0f; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA, + CUDNN_TYPE_FLOAT, + 1, + &alpha)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA, + CUDNN_TYPE_FLOAT, + 1, + &beta)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(wgrad_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(wgrad_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/dbn.h b/include/cudnn_frontend/node/dbn.h index 35cd4b66..9f2a23e7 100644 --- a/include/cudnn_frontend/node/dbn.h +++ b/include/cudnn_frontend/node/dbn.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -76,61 +73,119 @@ class DBNNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: " << "Building DBNNode operations " << attributes.name << " "); - std::vector peer_stats; - for (auto const& peer_stat : attributes.peer_stats) { - peer_stats.emplace_back(std::move(*(tensors.at(peer_stat->get_uid())))); - } + // Create operation by directly calling cuDNN backend API + Operation_v8 dbn_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + dbn_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)); - // Create the DBN operation. - auto&& DBN_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_BACKWARD_DESCRIPTOR); + // Set norm mode to BATCH_NORM + cudnnBackendNormMode_t cudnn_norm_mode; - DBN_operation_builder.setNormalizationMode(NormMode_t::BATCH_NORM); + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::BATCH_NORM, cudnn_norm_mode)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Batchnorm_backward_attributes::input_names::X); - DBN_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + // Set DY tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Batchnorm_backward_attributes::input_names::DY); - DBN_operation_builder.setdyDesc(*(tensors.at(DY->second->get_uid()))); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + // Set scale tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Batchnorm_backward_attributes::input_names::SCALE); - DBN_operation_builder.setScale(*(tensors.at(SCALE->second->get_uid()))); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + // Set mean and inv_variance tensors CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, Batchnorm_backward_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, Batchnorm_backward_attributes::input_names::INV_VARIANCE); - DBN_operation_builder.setSavedMeanAndInvVar(*(tensors.at(MEAN->second->get_uid())), - *(tensors.at(INV_VARIANCE->second->get_uid()))); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set DSCALE and DBIAS output tensors CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, Batchnorm_backward_attributes::output_names::DSCALE); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, Batchnorm_backward_attributes::output_names::DBIAS); - DBN_operation_builder.setDScaleAndDBias(*(tensors.at(DSCALE->second->get_uid())), - *(tensors.at(DBIAS->second->get_uid()))); + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); + + // Set DX output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Batchnorm_backward_attributes::output_names::DX); - DBN_operation_builder.setdxDesc(*(tensors.at(DX->second->get_uid()))); - - DBN_operation_builder.setPeerStatTensor(peer_stats); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = DBN_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = DBN_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + + // Set peer stat tensors if any + if (!attributes.peer_stats.empty()) { + std::vector peer_stat_descs; + for (auto const& peer_stat : attributes.peer_stats) { + peer_stat_descs.push_back(tensors.at(peer_stat->get_uid())->get_raw_desc()); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_PEER_STAT_DESCS, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + peer_stat_descs.size(), + peer_stat_descs.data())); } -#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(dbn_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(dbn_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/dbn_weight.h b/include/cudnn_frontend/node/dbn_weight.h index 8d56f0d6..bfad0dfb 100644 --- a/include/cudnn_frontend/node/dbn_weight.h +++ b/include/cudnn_frontend/node/dbn_weight.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -78,59 +75,126 @@ class DBNWeightNode : public NodeCRTP { managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { CUDNN_FRONTEND_UNUSED(raw_operations); - CUDNN_FE_LOG_LABEL("INFO:Building DBNWeightNode operations " << attributes.name << " "); + CUDNN_FE_LOG_LABEL("INFO: " << "Building DBNWeightNode operations " << attributes.name << " "); - // Create the batchnorm operation. - auto&& batchnorm_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR); + // Create operation by directly calling cuDNN backend API + Operation_v8 bn_bwd_weight_operation; - batchnorm_operation_builder.setComputeType(CUDNN_DATA_FLOAT); + _CUDNN_CHECK_CUDNN_ERROR(bn_bwd_weight_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR)); - CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_SCALE_DY, DBN_weight_attributes::output_names::EQ_SCALE_DY); - CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_SCALE_X, DBN_weight_attributes::output_names::EQ_SCALE_X); - CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_BIAS, DBN_weight_attributes::output_names::EQ_BIAS); - batchnorm_operation_builder.setEqScalesAndBias(*(tensors.at(EQ_SCALE_DY->second->get_uid())), - *(tensors.at(EQ_SCALE_X->second->get_uid())), - *(tensors.at(EQ_BIAS->second->get_uid()))); + // Set compute type (math precision) + cudnnDataType_t compute_type = CUDNN_DATA_FLOAT; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC, + CUDNN_TYPE_DATA_TYPE, + 1, + &compute_type)); + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, DBN_weight_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set DY tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, DBN_weight_attributes::input_names::DY); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + // Set mean tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, DBN_weight_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + // Set inv_variance tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, DBN_weight_attributes::input_names::INV_VARIANCE); - batchnorm_operation_builder.setSavedMeanAndInvVar(*(tensors.at(MEAN->second->get_uid())), - *(tensors.at(INV_VARIANCE->second->get_uid()))); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); - CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, DBN_weight_attributes::input_names::SCALE); - batchnorm_operation_builder.setScale(*(tensors.at(SCALE->second->get_uid()))); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); - CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, DBN_weight_attributes::input_names::X); - batchnorm_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + // Set scale tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, DBN_weight_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); - CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, DBN_weight_attributes::input_names::DY); - batchnorm_operation_builder.setdyDesc(*(tensors.at(DY->second->get_uid()))); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + // Set DSCALE output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, DBN_weight_attributes::output_names::DSCALE); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + + // Set DBIAS output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, DBN_weight_attributes::output_names::DBIAS); - batchnorm_operation_builder.setDScaleAndDBias(*(tensors.at(DSCALE->second->get_uid())), - *(tensors.at(DBIAS->second->get_uid()))); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = batchnorm_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = batchnorm_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); + + // Set EQ_SCALE_DY output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_SCALE_DY, DBN_weight_attributes::output_names::EQ_SCALE_DY); + auto eq_scale_dy_desc = tensors.at(EQ_SCALE_DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_scale_dy_desc)); + + // Set EQ_SCALE_X output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_SCALE_X, DBN_weight_attributes::output_names::EQ_SCALE_X); + auto eq_scale_x_desc = tensors.at(EQ_SCALE_X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_scale_x_desc)); + + // Set EQ_BIAS output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_BIAS, DBN_weight_attributes::output_names::EQ_BIAS); + auto eq_bias_desc = tensors.at(EQ_BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_bias_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(bn_bwd_weight_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(bn_bwd_weight_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/dln.h b/include/cudnn_frontend/node/dln.h index 339b53ca..4dd508a5 100644 --- a/include/cudnn_frontend/node/dln.h +++ b/include/cudnn_frontend/node/dln.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -99,58 +96,117 @@ class DLNNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: " << "Building DLNNode operations " << attributes.name << " "); - // Create the DLN operation. - auto&& DLN_op_builder = cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_BACKWARD_DESCRIPTOR); + // Create operation by directly calling cuDNN backend API + Operation_v8 dln_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + dln_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)); + + // Set norm mode to LAYER_NORM + cudnnBackendNormMode_t cudnn_norm_mode; - DLN_op_builder.setNormalizationMode(NormMode_t::LAYER_NORM); + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::LAYER_NORM, cudnn_norm_mode)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Layernorm_backward_attributes::input_names::X); - DLN_op_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + // Set DY tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Layernorm_backward_attributes::input_names::DY); - DLN_op_builder.setdyDesc(*(tensors.at(DY->second->get_uid()))); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + // Set scale tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Layernorm_backward_attributes::input_names::SCALE); - DLN_op_builder.setScale(*(tensors.at(SCALE->second->get_uid()))); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + // Set mean and inv_variance tensors CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, Layernorm_backward_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, Layernorm_backward_attributes::input_names::INV_VARIANCE); - DLN_op_builder.setSavedMeanAndInvVar(*(tensors.at(MEAN->second->get_uid())), - *(tensors.at(INV_VARIANCE->second->get_uid()))); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set DSCALE and DBIAS output tensors CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, Layernorm_backward_attributes::output_names::DSCALE); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, Layernorm_backward_attributes::output_names::DBIAS); - DLN_op_builder.setDScaleAndDBias(*(tensors.at(DSCALE->second->get_uid())), - *(tensors.at(DBIAS->second->get_uid()))); + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); + + // Set DX output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Layernorm_backward_attributes::output_names::DX); - DLN_op_builder.setdxDesc(*(tensors.at(DX->second->get_uid()))); + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + // Set epsilon tensor for older backend versions if (detail::get_backend_version() < 8906) { CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Layernorm_backward_attributes::input_names::EPSILON); - DLN_op_builder.setEpsilonTensor(*(tensors.at(EPSILON->second->get_uid()))); - } + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = DLN_op_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = DLN_op_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); } -#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(dln_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(dln_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/genstats.h b/include/cudnn_frontend/node/genstats.h index 499b76df..8f918975 100644 --- a/include/cudnn_frontend/node/genstats.h +++ b/include/cudnn_frontend/node/genstats.h @@ -1,7 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -72,38 +70,63 @@ class GenstatsNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: " << "Building GenstatsNode operations " << attributes.name << " "); - auto&& genstats_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_GEN_STATS_DESCRIPTOR); + // Create operation by directly calling cuDNN backend API + Operation_v8 genstats_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + genstats_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR)); + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Genstats_attributes::input_names::X); - genstats_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set gen stats mode + cudnnGenStatsMode_t genstats_mode = CUDNN_GENSTATS_SUM_SQSUM; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_MODE, + CUDNN_TYPE_GENSTATS_MODE, + 1, + &genstats_mode)); + + // Set math precision based on X tensor data type + cudnnDataType_t math_prec = static_cast(tensors.at(X->second->get_uid())->getDataType()); - genstats_operation_builder.setGenStatsMode(CUDNN_GENSTATS_SUM_SQSUM); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC, + CUDNN_TYPE_DATA_TYPE, + 1, + &math_prec)); + // Set SUM output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(SUM, Genstats_attributes::output_names::SUM); - genstats_operation_builder.setSumDesc(*(tensors.at(SUM->second->get_uid()))); + auto sum_desc = tensors.at(SUM->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &sum_desc)); + + // Set SQ_SUM output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(SQ_SUM, Genstats_attributes::output_names::SQ_SUM); - genstats_operation_builder.setSqSumDesc(*(tensors.at(SQ_SUM->second->get_uid()))); -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = genstats_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = genstats_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto sq_sum_desc = tensors.at(SQ_SUM->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &sq_sum_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(genstats_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(genstats_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/instancenorm.h b/include/cudnn_frontend/node/instancenorm.h index b6713616..1b71f4ab 100644 --- a/include/cudnn_frontend/node/instancenorm.h +++ b/include/cudnn_frontend/node/instancenorm.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -91,52 +88,108 @@ class InstanceNormNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: Building InstanceNormNode operations " << attributes.name << " "); - auto&& op_builder = cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR); + // Create operation by directly calling cuDNN backend API + Operation_v8 instancenorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + instancenorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode to INSTANCE_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::INSTANCE_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); - op_builder.setNormalizationMode(NormMode_t::INSTANCE_NORM); + // Set forward phase + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; - op_builder.setNormFwdPhase(attributes.forward_phase); + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.forward_phase, cudnn_norm_fwd_phase)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Instancenorm_attributes::input_names::X); - op_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + // Set scale and bias tensors CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Instancenorm_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, Instancenorm_attributes::input_names::BIAS); - op_builder.setScaleAndBias(*(tensors.at(SCALE->second->get_uid())), *(tensors.at(BIAS->second->get_uid()))); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + + // Set epsilon tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Instancenorm_attributes::input_names::EPSILON); - op_builder.setEpsilonTensor(*(tensors.at(EPSILON->second->get_uid()))); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Instancenorm_attributes::output_names::Y); - op_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + // Set mean and inv_variance for training phase if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(MEAN, Instancenorm_attributes::output_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, Instancenorm_attributes::output_names::INV_VARIANCE); - op_builder.setSavedMeanAndInvVar(*(tensors.at(MEAN->second->get_uid())), - *(tensors.at(INV_VARIANCE->second->get_uid()))); - } + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = op_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = op_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); } -#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(instancenorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(instancenorm_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); @@ -240,56 +293,107 @@ class DINNode : public NodeCRTP { managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { CUDNN_FRONTEND_UNUSED(raw_operations); - CUDNN_FE_LOG_LABEL("INFO: Building DINode operations " << attributes.name << " "); + CUDNN_FE_LOG_LABEL("INFO: Building DINNode operations " << attributes.name << " "); - // Create the DIN operation. - auto&& DIN_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_BACKWARD_DESCRIPTOR); + // Create operation by directly calling cuDNN backend API + Operation_v8 din_operation; - DIN_operation_builder.setNormalizationMode(NormMode_t::INSTANCE_NORM); + _CUDNN_CHECK_CUDNN_ERROR( + din_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)); + // Set norm mode to INSTANCE_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::INSTANCE_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Instancenorm_backward_attributes::input_names::X); - DIN_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set DY tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Instancenorm_backward_attributes::input_names::DY); - DIN_operation_builder.setdyDesc(*(tensors.at(DY->second->get_uid()))); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + // Set scale tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Instancenorm_backward_attributes::input_names::SCALE); - DIN_operation_builder.setScale(*(tensors.at(SCALE->second->get_uid()))); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set mean and inv_variance tensors CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, Instancenorm_backward_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, Instancenorm_backward_attributes::input_names::INV_VARIANCE); - DIN_operation_builder.setSavedMeanAndInvVar(*(tensors.at(MEAN->second->get_uid())), - *(tensors.at(INV_VARIANCE->second->get_uid()))); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + // Set DSCALE and DBIAS output tensors CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, Instancenorm_backward_attributes::output_names::DSCALE); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, Instancenorm_backward_attributes::output_names::DBIAS); - DIN_operation_builder.setDScaleAndDBias(*(tensors.at(DSCALE->second->get_uid())), - *(tensors.at(DBIAS->second->get_uid()))); + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); + + // Set DX output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Instancenorm_backward_attributes::output_names::DX); - DIN_operation_builder.setdxDesc(*(tensors.at(DX->second->get_uid()))); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = DIN_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = DIN_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(din_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(din_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/layernorm.h b/include/cudnn_frontend/node/layernorm.h index 72bd2668..46420aa2 100644 --- a/include/cudnn_frontend/node/layernorm.h +++ b/include/cudnn_frontend/node/layernorm.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -141,50 +138,107 @@ class LayerNormNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: " << "Building LayerNormNode operations " << attributes.name << " "); - auto&& layernorm_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR); - layernorm_operation_builder.setNormalizationMode(NormMode_t::LAYER_NORM) - .setNormFwdPhase(attributes.forward_phase); + // Create operation by directly calling cuDNN backend API + Operation_v8 layernorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + layernorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::LAYER_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set forward phase + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.forward_phase, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Layernorm_attributes::input_names::X); - layernorm_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + // Set scale and bias tensors CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Layernorm_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, Layernorm_attributes::input_names::BIAS); - layernorm_operation_builder.setScaleAndBias(*(tensors.at(SCALE->second->get_uid())), - *(tensors.at(BIAS->second->get_uid()))); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + + // Set epsilon tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Layernorm_attributes::input_names::EPSILON); - layernorm_operation_builder.setEpsilonTensor(*(tensors.at(EPSILON->second->get_uid()))); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Layernorm_attributes::output_names::Y); - layernorm_operation_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + // Set mean and inv_variance for training phase if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(MEAN, Layernorm_attributes::output_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, Layernorm_attributes::output_names::INV_VARIANCE); - layernorm_operation_builder.setSavedMeanAndInvVar(*(tensors.at(MEAN->second->get_uid())), - *(tensors.at(INV_VARIANCE->second->get_uid()))); - } -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = layernorm_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = layernorm_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); } -#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(layernorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(layernorm_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/matmul.h b/include/cudnn_frontend/node/matmul.h index 592ada34..f09d3415 100644 --- a/include/cudnn_frontend/node/matmul.h +++ b/include/cudnn_frontend/node/matmul.h @@ -1,9 +1,5 @@ #pragma once -#include "../../cudnn_frontend_MatMulDesc.h" -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -71,59 +67,117 @@ class MatmulNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: " << "Building MatmulNode operations " << attributes.name << " "); - // matmul descriptor - auto matmul_descriptor = cudnn_frontend::MatMulDescBuilder() - .setComputeType(attributes.compute_data_type) - .setPaddingValue(attributes.padding_value) - .build(); + // Create matmul descriptor by directly calling cuDNN backend API + MatMulDesc_v8 matmul_descriptor; + + _CUDNN_CHECK_CUDNN_ERROR(matmul_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_MATMUL_DESCRIPTOR)); + + // Set compute type + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + matmul_descriptor.get_raw_desc(), CUDNN_ATTR_MATMUL_COMP_TYPE, CUDNN_TYPE_DATA_TYPE, 1, &cudnn_data_type)); + + // Set padding value if specified +#if (CUDNN_VERSION >= 8900) + if (attributes.padding_value != 0.0) { + double padding_value = attributes.padding_value; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_descriptor.get_raw_desc(), + CUDNN_ATTR_MATMUL_PADDING_VALUE, + CUDNN_TYPE_DOUBLE, + 1, + &padding_value)); + } +#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(matmul_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(matmul_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 matmul_operation; - auto&& matmul_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_MATMUL_DESCRIPTOR); + _CUDNN_CHECK_CUDNN_ERROR( + matmul_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)); + // Set input tensor A CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(A, Matmul_attributes::input_names::A); - matmul_operation_builder.setaMatDesc(*tensors.at(A->second->get_uid())); + auto a_desc = tensors.at(A->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_ADESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &a_desc)); + + // Set input tensor B CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(B, Matmul_attributes::input_names::B); - matmul_operation_builder.setbMatDesc(*tensors.at(B->second->get_uid())); + auto b_desc = tensors.at(B->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_BDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &b_desc)); + // Set output tensor C CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(C, Matmul_attributes::output_names::C); - matmul_operation_builder.setcMatDesc(*tensors.at(C->second->get_uid())); - matmul_operation_builder.setmatmulDesc(matmul_descriptor); + auto c_desc = tensors.at(C->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_CDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &c_desc)); + // Set matmul descriptor + auto matmul_desc_ptr = matmul_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &matmul_desc_ptr)); + + // Set optional override tensors auto M_override = attributes.inputs.find(Matmul_attributes::input_names::M_override); if ((M_override != attributes.inputs.end()) && (M_override->second != nullptr)) { - matmul_operation_builder.setmOverrideDesc(*tensors.at(M_override->second->get_uid())); + auto m_override_desc = tensors.at(M_override->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_GEMM_M_OVERRIDE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &m_override_desc)); } auto N_override = attributes.inputs.find(Matmul_attributes::input_names::N_override); if ((N_override != attributes.inputs.end()) && (N_override->second != nullptr)) { - matmul_operation_builder.setnOverrideDesc(*tensors.at(N_override->second->get_uid())); + auto n_override_desc = tensors.at(N_override->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_GEMM_N_OVERRIDE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &n_override_desc)); } auto K_override = attributes.inputs.find(Matmul_attributes::input_names::K_override); if ((K_override != attributes.inputs.end()) && (K_override->second != nullptr)) { - matmul_operation_builder.setkOverrideDesc(*tensors.at(K_override->second->get_uid())); - } + auto k_override_desc = tensors.at(K_override->second->get_uid())->get_raw_desc(); -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = matmul_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = matmul_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_GEMM_K_OVERRIDE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &k_override_desc)); } -#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(matmul_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(matmul_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/matmul_fp8.h b/include/cudnn_frontend/node/matmul_fp8.h index 8b173427..d3fe5e58 100644 --- a/include/cudnn_frontend/node/matmul_fp8.h +++ b/include/cudnn_frontend/node/matmul_fp8.h @@ -1,9 +1,5 @@ #pragma once -#include "../../cudnn_frontend_MatMulDesc.h" -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" diff --git a/include/cudnn_frontend/node/moe_grouped_matmul.h b/include/cudnn_frontend/node/moe_grouped_matmul.h index 16e0fdb1..e12acde0 100644 --- a/include/cudnn_frontend/node/moe_grouped_matmul.h +++ b/include/cudnn_frontend/node/moe_grouped_matmul.h @@ -1,9 +1,5 @@ #pragma once -#include "../../cudnn_frontend_MatMulDesc.h" -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" diff --git a/include/cudnn_frontend/node/paged_cache_load.h b/include/cudnn_frontend/node/paged_cache_load.h index 5b00649f..9fc86109 100644 --- a/include/cudnn_frontend/node/paged_cache_load.h +++ b/include/cudnn_frontend/node/paged_cache_load.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -30,46 +27,69 @@ class PagedCacheLoadNode : public NodeCRTP { managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building PagedCacheLoadNode operations " << attributes.name << " "); + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Paged cache load requires cuDNN v9.5.0"}; - auto&& paged_cache_load_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR); +#if (CUDNN_VERSION >= 90500) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90500, cudnn_ver_error); - CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(container, PagedCacheLoad_attributes::input_names::container); - paged_cache_load_operation_builder.setcontainerDesc(*(tensors.at(container->second->get_uid()))); + // Create operation by directly calling cuDNN backend API + Operation_v8 paged_cache_load_operation; - CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(pageTable, PagedCacheLoad_attributes::input_names::pageTable); - paged_cache_load_operation_builder.setpageTableDesc(*(tensors.at(pageTable->second->get_uid()))); + _CUDNN_CHECK_CUDNN_ERROR(paged_cache_load_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR)); + // Set container tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(container, PagedCacheLoad_attributes::input_names::container); + auto container_desc = tensors.at(container->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(paged_cache_load_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_CONTAINER_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &container_desc)); + + // Set page table tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(pageTable, PagedCacheLoad_attributes::input_names::pageTable); + auto page_table_desc = tensors.at(pageTable->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(paged_cache_load_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_PAGE_TABLE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &page_table_desc)); + + // Set sequence length tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(seqLen, PagedCacheLoad_attributes::input_names::seqLen); - paged_cache_load_operation_builder.setsequenceDesc(*(tensors.at(seqLen->second->get_uid()))); - + auto seq_len_desc = tensors.at(seqLen->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(paged_cache_load_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_SEQUENCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &seq_len_desc)); + + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(yOut, PagedCacheLoad_attributes::output_names::yOut); - paged_cache_load_operation_builder.setyDesc(*(tensors.at(yOut->second->get_uid()))); + auto y_desc = tensors.at(yOut->second->get_uid())->get_raw_desc(); -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = paged_cache_load_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = paged_cache_load_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(paged_cache_load_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(paged_cache_load_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(paged_cache_load_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FRONTEND_UNUSED(tensors); + return cudnn_ver_error; +#endif } error_t diff --git a/include/cudnn_frontend/node/pointwise.h b/include/cudnn_frontend/node/pointwise.h index 80603402..d67ab6a3 100644 --- a/include/cudnn_frontend/node/pointwise.h +++ b/include/cudnn_frontend/node/pointwise.h @@ -1,9 +1,5 @@ #pragma once -#include "../../cudnn_frontend_PointWiseDesc.h" -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -81,92 +77,204 @@ class PointwiseNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: " << "Building PointwiseNode operations " << attributes.name << " "); - auto&& pointwise_descriptor_builder = cudnn_frontend::PointwiseDescBuilder(); + // Create pointwise descriptor by directly calling cuDNN backend API + PointWiseDesc_v8 pointwise_descriptor; - if (attributes.get_axis().has_value()) { - pointwise_descriptor_builder.setAxis(attributes.get_axis().value()); - } + _CUDNN_CHECK_CUDNN_ERROR( + pointwise_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_POINTWISE_DESCRIPTOR)); - if (attributes.relu_lower_clip_slope.has_value()) { - pointwise_descriptor_builder.setReluLowerClipSlope(attributes.relu_lower_clip_slope.value()); - } + // Set pointwise mode + cudnnPointwiseMode_t cudnn_pointwise_mode; - if (attributes.relu_lower_clip.has_value()) { - pointwise_descriptor_builder.setReluLowerClip(attributes.relu_lower_clip.value()); - } + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.mode, cudnn_pointwise_mode)); - if (attributes.relu_upper_clip.has_value()) { - pointwise_descriptor_builder.setReluUpperClip(attributes.relu_upper_clip.value()); - } + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_MODE, + CUDNN_TYPE_POINTWISE_MODE, + 1, + &cudnn_pointwise_mode)); - if (attributes.swish_beta.has_value()) { - pointwise_descriptor_builder.setSwishBeta(attributes.swish_beta.value()); - } + // Set compute type + cudnnDataType_t cudnn_data_type; - if (attributes.elu_alpha.has_value()) { - pointwise_descriptor_builder.setEluAlpha(attributes.elu_alpha.value()); - } + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_MATH_PREC, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + // Set mode-specific attributes + if (attributes.mode == PointwiseMode_t::RELU_FWD || attributes.mode == PointwiseMode_t::RELU_BWD) { + cudnnNanPropagation_t nan_propagation = CUDNN_PROPAGATE_NAN; - if (attributes.softplus_beta.has_value()) { - pointwise_descriptor_builder.setSoftplusBeta(attributes.softplus_beta.value()); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_NAN_PROPAGATION, + CUDNN_TYPE_NAN_PROPOGATION, + 1, + &nan_propagation)); + + double lower_clip = attributes.relu_lower_clip.value_or(0.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP, + CUDNN_TYPE_DOUBLE, + 1, + &lower_clip)); + + double upper_clip = attributes.relu_upper_clip.value_or(std::numeric_limits::max()); + if (attributes.compute_data_type == DataType_t::FLOAT) { + upper_clip = std::min(upper_clip, std::numeric_limits::max()); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP, + CUDNN_TYPE_DOUBLE, + 1, + &upper_clip)); + + double lower_clip_slope = attributes.relu_lower_clip_slope.value_or(0.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE, + CUDNN_TYPE_DOUBLE, + 1, + &lower_clip_slope)); + } else if (attributes.mode == PointwiseMode_t::ELU_FWD || attributes.mode == PointwiseMode_t::ELU_BWD) { + double elu_alpha = attributes.elu_alpha.value_or(1.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + pointwise_descriptor.get_raw_desc(), CUDNN_ATTR_POINTWISE_ELU_ALPHA, CUDNN_TYPE_DOUBLE, 1, &elu_alpha)); + } else if (attributes.mode == PointwiseMode_t::SOFTPLUS_FWD || + attributes.mode == PointwiseMode_t::SOFTPLUS_BWD) { + double softplus_beta = attributes.softplus_beta.value_or(1.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA, + CUDNN_TYPE_DOUBLE, + 1, + &softplus_beta)); + } else if (attributes.mode == PointwiseMode_t::SWISH_FWD || attributes.mode == PointwiseMode_t::SWISH_BWD) { + double swish_beta = attributes.swish_beta.value_or(1.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_SWISH_BETA, + CUDNN_TYPE_DOUBLE, + 1, + &swish_beta)); + } else if (attributes.mode == PointwiseMode_t::GEN_INDEX) { + int64_t axis = attributes.get_axis().value_or(-1); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + pointwise_descriptor.get_raw_desc(), CUDNN_ATTR_POINTWISE_AXIS, CUDNN_TYPE_INT64, 1, &axis)); } - pointwise_descriptor_builder.setComputeType(attributes.compute_data_type); - pointwise_descriptor_builder.setMode(attributes.mode); - auto pointwise_descriptor = pointwise_descriptor_builder.build(); + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(pointwise_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(pointwise_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 pointwise_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + pointwise_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)); + + // Set the pointwise descriptor + auto pw_desc_ptr = pointwise_descriptor.get_raw_desc(); - auto const port_count = get_pointwise_mode_port_count(attributes.mode); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &pw_desc_ptr)); - auto&& pointwise_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_POINTWISE_DESCRIPTOR); - pointwise_operation_builder.setpwDesc(pointwise_descriptor); + auto const port_count = get_pointwise_mode_port_count(attributes.mode); + bool const is_activation_bwd = detail::is_activation_backward_mode(attributes.mode); - if (detail::is_activation_backward_mode(attributes.mode)) { + if (is_activation_bwd) { + // Backward mode: IN_0 is dy, IN_1 is x, OUT_0 is dx CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_0, Pointwise_attributes::input_names::IN_0); - pointwise_operation_builder.setdyDesc(*(tensors.at(IN_0->second->get_uid()))); + auto dy_desc = tensors.at(IN_0->second->get_uid())->get_raw_desc(); CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_1, Pointwise_attributes::input_names::IN_1); - pointwise_operation_builder.setxDesc(*(tensors.at(IN_1->second->get_uid()))); + auto x_desc = tensors.at(IN_1->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(OUT_0, Pointwise_attributes::output_names::OUT_0); - pointwise_operation_builder.setdxDesc(*(tensors.at(OUT_0->second->get_uid()))); + auto dx_desc = tensors.at(OUT_0->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); } else { + // Forward mode CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_0, Pointwise_attributes::input_names::IN_0); - pointwise_operation_builder.setxDesc(*(tensors.at(IN_0->second->get_uid()))); + auto x_desc = tensors.at(IN_0->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(OUT_0, Pointwise_attributes::output_names::OUT_0); + auto y_desc = tensors.at(OUT_0->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); if (port_count >= 3) { CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_1, Pointwise_attributes::input_names::IN_1); - pointwise_operation_builder.setbDesc(*(tensors.at(IN_1->second->get_uid()))); + auto b_desc = tensors.at(IN_1->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_BDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &b_desc)); } if (port_count >= 4) { CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_2, Pointwise_attributes::input_names::IN_2); - pointwise_operation_builder.settDesc(*(tensors.at(IN_2->second->get_uid()))); - } + auto t_desc = tensors.at(IN_2->second->get_uid())->get_raw_desc(); - CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(OUT_0, Pointwise_attributes::output_names::OUT_0); - pointwise_operation_builder.setyDesc(*(tensors.at(OUT_0->second->get_uid()))); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_TDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &t_desc)); + } } -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = pointwise_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = pointwise_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + // Set alpha scaling factors (always set to 1.0) + float alpha1 = 1.0f; + float alpha2 = 1.0f; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + pointwise_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1, CUDNN_TYPE_FLOAT, 1, &alpha1)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + pointwise_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2, CUDNN_TYPE_FLOAT, 1, &alpha2)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(pointwise_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(pointwise_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/reduction.h b/include/cudnn_frontend/node/reduction.h index cf212061..193b9d37 100644 --- a/include/cudnn_frontend/node/reduction.h +++ b/include/cudnn_frontend/node/reduction.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_ReductionDesc.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -65,43 +62,99 @@ class ReductionNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: " << "Building ReductionNode operations " << attributes.name << " "); - auto reduction_descriptor = cudnn_frontend::ReductionDescBuilder() - .setComputeType(attributes.compute_data_type) - .setReductionOp(attributes.get_mode().value()) - .setIsDeterministic(attributes.get_is_deterministic()) - .build(); + // Create reduction descriptor by directly calling cuDNN backend API + ReductionDesc_v8 reduction_descriptor; - auto&& reduction_operation_builder = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR); + // 1. Create the backend descriptor - CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Reduction_attributes::input_names::X); - reduction_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + _CUDNN_CHECK_CUDNN_ERROR( + reduction_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_REDUCTION_DESCRIPTOR)); - CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Reduction_attributes::output_names::Y); - reduction_operation_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); - - reduction_operation_builder.setreductionDesc(reduction_descriptor); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = reduction_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = reduction_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + // 2. Set compute type attribute + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_descriptor.get_raw_desc(), + CUDNN_ATTR_REDUCTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + // 3. Set reduction operator attribute + cudnnReduceTensorOp_t cudnn_reduction_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.get_mode().value(), cudnn_reduction_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_descriptor.get_raw_desc(), + CUDNN_ATTR_REDUCTION_OPERATOR, + CUDNN_TYPE_REDUCTION_OPERATOR_TYPE, + 1, + &cudnn_reduction_mode)); + + // 4. Set deterministic mode if supported +#if (CUDNN_VERSION >= 91100) + if (detail::get_backend_version() >= 91100) { + bool is_deterministic = attributes.get_is_deterministic(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_descriptor.get_raw_desc(), + CUDNN_ATTR_REDUCTION_IS_DETERMINISTIC, + CUDNN_TYPE_BOOLEAN, + 1, + &is_deterministic)); } #endif + // 5. Finalize the descriptor + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(reduction_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(reduction_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 reduction_operation; + + // Validate input tensors are set + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Reduction_attributes::input_names::X); + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Reduction_attributes::output_names::Y); + + // 1. Create the backend operation descriptor + + _CUDNN_CHECK_CUDNN_ERROR( + reduction_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)); + + // 2. Set the reduction descriptor attribute + auto reduction_desc_ptr = reduction_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_REDUCTION_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &reduction_desc_ptr)); + + // 3. Set the input tensor (X) descriptor attribute + auto x_backend_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_REDUCTION_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_backend_desc)); + + // 4. Set the output tensor (Y) descriptor attribute + auto y_backend_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_REDUCTION_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_backend_desc)); + + // 5. Finalize the operation descriptor + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(reduction_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(reduction_operation))); + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/resample.h b/include/cudnn_frontend/node/resample.h index fbcd8319..34e6031f 100644 --- a/include/cudnn_frontend/node/resample.h +++ b/include/cudnn_frontend/node/resample.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Resample.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -105,53 +102,159 @@ class ResampleNode : public NodeCRTP { auto number_of_spatial_dim = static_cast(attributes.window.size()); - // Define the resample descriptor - auto resample_descriptor = cudnn_frontend::ResampleDescBuilder_v8() - .setComputeType(attributes.compute_data_type) - .setNanPropagation(CUDNN_PROPAGATE_NAN) - .setResampleMode(attributes.resample_mode) - .setPaddingMode(attributes.padding_mode) - .setSpatialDim(number_of_spatial_dim, attributes.window.data()) - .setSpatialStride(number_of_spatial_dim, attributes.stride.data()) - .setPrePadding(number_of_spatial_dim, attributes.pre_padding.data()) - .setPostPadding(number_of_spatial_dim, attributes.post_padding.data()) - .build(); - - auto&& resample_op_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_RESAMPLE_FWD_DESCRIPTOR); + // Create resample descriptor by directly calling cuDNN backend API + ResampleDesc_v8 resample_descriptor; + + _CUDNN_CHECK_CUDNN_ERROR( + resample_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_RESAMPLE_DESCRIPTOR)); + + // Set resample mode + cudnnResampleMode_t cudnn_resample_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.resample_mode, cudnn_resample_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_MODE, + CUDNN_TYPE_RESAMPLE_MODE, + 1, + &cudnn_resample_mode)); + + // Set compute type + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + // Set nan propagation + cudnnNanPropagation_t nan_opt = CUDNN_PROPAGATE_NAN; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_NAN_PROPAGATION, + CUDNN_TYPE_NAN_PROPOGATION, + 1, + &nan_opt)); + + // Set padding mode + cudnnPaddingMode_t cudnn_padding_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.padding_mode, cudnn_padding_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_PADDING_MODE, + CUDNN_TYPE_PADDING_MODE, + 1, + &cudnn_padding_mode)); + + // Set spatial dimensions + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_SPATIAL_DIMS, + CUDNN_TYPE_INT64, + 1, + &number_of_spatial_dim)); + + // Set window dimensions + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_WINDOW_DIMS, + CUDNN_TYPE_FRACTION, + number_of_spatial_dim, + attributes.window.data())); + + // Set pre padding + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_PRE_PADDINGS, + CUDNN_TYPE_FRACTION, + number_of_spatial_dim, + attributes.pre_padding.data())); + // Set post padding + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_POST_PADDINGS, + CUDNN_TYPE_FRACTION, + number_of_spatial_dim, + attributes.post_padding.data())); + + // Set strides + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_STRIDES, + CUDNN_TYPE_FRACTION, + number_of_spatial_dim, + attributes.stride.data())); + + // Finalize the descriptor + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(resample_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(resample_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 resample_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + resample_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_RESAMPLE_FWD_DESCRIPTOR)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Resample_attributes::input_names::X); - resample_op_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESAMPLE_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Resample_attributes::output_names::Y); - resample_op_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); - resample_op_builder.setResampleDesc(resample_descriptor); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESAMPLE_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + // Set alpha and beta + double alpha = 1.0; + double beta = 0.0; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + resample_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_RESAMPLE_FWD_ALPHA, CUDNN_TYPE_DOUBLE, 1, &alpha)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + resample_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_RESAMPLE_FWD_BETA, CUDNN_TYPE_DOUBLE, 1, &beta)); + + // Set resample descriptor + auto resample_raw_desc = resample_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESAMPLE_FWD_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &resample_raw_desc)); + + // Set index tensor if available auto index = attributes.outputs.find(Resample_attributes::output_names::Index); if ((index != attributes.outputs.end()) && (index->second != nullptr)) { - resample_op_builder.setidxDesc(*tensors.at(index->second->get_uid())); - } + auto idx_desc = tensors.at(index->second->get_uid())->get_raw_desc(); -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = resample_op_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = resample_op_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESAMPLE_FWD_IDXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &idx_desc)); } -#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(resample_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(resample_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); diff --git a/include/cudnn_frontend/node/reshape.h b/include/cudnn_frontend/node/reshape.h index 30d96ac9..39b08c79 100644 --- a/include/cudnn_frontend/node/reshape.h +++ b/include/cudnn_frontend/node/reshape.h @@ -1,7 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -62,35 +60,36 @@ class ReshapeNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: " << "Building ReshapeNode operations " << attributes.name << " "); - auto&& reshape_op_builder = cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_RESHAPE_DESCRIPTOR); + // Create operation by directly calling cuDNN backend API + Operation_v8 reshape_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + reshape_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)); + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Reshape_attributes::input_names::X); - reshape_op_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reshape_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESHAPE_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Reshape_attributes::output_names::Y); - reshape_op_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); - - reshape_op_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = reshape_op_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = reshape_op_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reshape_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESHAPE_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(reshape_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(reshape_operation))); + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/rmsnorm.h b/include/cudnn_frontend/node/rmsnorm.h index 01af69a8..bc1f37d2 100644 --- a/include/cudnn_frontend/node/rmsnorm.h +++ b/include/cudnn_frontend/node/rmsnorm.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Heuristics.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -90,51 +87,101 @@ class RMSNormNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: Building RMSNormNode operations " << attributes.name << " "); - auto&& rmsnorm_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR); + // Create operation by directly calling cuDNN backend API + Operation_v8 rmsnorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + rmsnorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode to RMS_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::RMS_NORM, cudnn_norm_mode)); - rmsnorm_operation_builder.setNormalizationMode(NormMode_t::RMS_NORM).setNormFwdPhase(attributes.forward_phase); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + // Set forward phase + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.forward_phase, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Rmsnorm_attributes::input_names::X); - rmsnorm_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + // Set scale tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Rmsnorm_attributes::input_names::SCALE); - rmsnorm_operation_builder.setScale(*(tensors.at(SCALE->second->get_uid()))); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set epsilon tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Rmsnorm_attributes::input_names::EPSILON); - rmsnorm_operation_builder.setEpsilonTensor(*(tensors.at(EPSILON->second->get_uid()))); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Rmsnorm_attributes::output_names::Y); - rmsnorm_operation_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + // Set inv_variance for training phase if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, Rmsnorm_attributes::output_names::INV_VARIANCE); - rmsnorm_operation_builder.setSavedInvVar(*(tensors.at(INV_VARIANCE->second->get_uid()))); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); } + // Set optional bias tensor auto BIAS = attributes.inputs.find(Rmsnorm_attributes::input_names::BIAS); if ((BIAS != attributes.inputs.end()) && (BIAS->second != nullptr)) { - rmsnorm_operation_builder.setBias(*(tensors.at(BIAS->second->get_uid()))); - } -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = rmsnorm_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = rmsnorm_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); } -#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(rmsnorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(rmsnorm_operation))); auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); @@ -247,52 +294,99 @@ class DRMSNormNode : public NodeCRTP { CUDNN_FRONTEND_UNUSED(raw_operations); CUDNN_FE_LOG_LABEL("INFO: Building DRMSNormNode operations " << attributes.name << " "); - auto&& DRMSNorm_operation_builder = - cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_BACKWARD_DESCRIPTOR); + // Create operation by directly calling cuDNN backend API + Operation_v8 drmsnorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + drmsnorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)); + + // Set norm mode to RMS_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::RMS_NORM, cudnn_norm_mode)); - DRMSNorm_operation_builder.setNormalizationMode(NormMode_t::RMS_NORM); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + // Set input tensor X CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Rmsnorm_backward_attributes::input_names::X); - DRMSNorm_operation_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set DY tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Rmsnorm_backward_attributes::input_names::DY); - DRMSNorm_operation_builder.setdyDesc(*(tensors.at(DY->second->get_uid()))); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + // Set scale tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Rmsnorm_backward_attributes::input_names::SCALE); - DRMSNorm_operation_builder.setScale(*(tensors.at(SCALE->second->get_uid()))); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set inv_variance tensor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, Rmsnorm_backward_attributes::input_names::INV_VARIANCE); - DRMSNorm_operation_builder.setSavedInvVar(*(tensors.at(INV_VARIANCE->second->get_uid()))); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + // Set DSCALE output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, Rmsnorm_backward_attributes::output_names::DSCALE); - DRMSNorm_operation_builder.setDScale(*(tensors.at(DSCALE->second->get_uid()))); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + + // Set optional DBIAS output tensor if (attributes.use_dbias.value()) { CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, Rmsnorm_backward_attributes::output_names::DBIAS); - DRMSNorm_operation_builder.setDBias(*(tensors.at(DBIAS->second->get_uid()))); + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); } + // Set DX output tensor CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Rmsnorm_backward_attributes::output_names::DX); - DRMSNorm_operation_builder.setdxDesc(*(tensors.at(DX->second->get_uid()))); -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = DRMSNorm_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = DRMSNorm_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(drmsnorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(drmsnorm_operation))); + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/rng.h b/include/cudnn_frontend/node/rng.h index e0404cd2..09b762d3 100644 --- a/include/cudnn_frontend/node/rng.h +++ b/include/cudnn_frontend/node/rng.h @@ -1,8 +1,5 @@ #pragma once -#include "../../cudnn_frontend_Rng.h" -#include "../../cudnn_frontend_Logging.h" - #include "../graph_helpers.h" #include "../node_interface.h" @@ -61,53 +58,96 @@ class RngNode : public NodeCRTP { managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { CUDNN_FRONTEND_UNUSED(raw_operations); - CUDNN_FE_LOG_LABEL("INFO: Building RngNode operations " << attributes.name << " "); + CUDNN_FE_LOG_LABEL("INFO: " << "Building RngNode operations " << attributes.name << " "); RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.get_distribution() != RngDistribution_t::BERNOULLI, error_code_t::ATTRIBUTE_NOT_SET, "no other distribution except bernoulli supported."); - auto rng_descriptor = cudnn_frontend::RngDescBuilder() - .setRngDistribution(attributes.get_distribution()) - .setBernoulliDistProbability(attributes.get_bernoulli_probability().value()) - .build(); + // Create RNG descriptor by directly calling cuDNN backend API + RngDesc_v8 rng_descriptor; + + _CUDNN_CHECK_CUDNN_ERROR(rng_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_RNG_DESCRIPTOR)); + + // Set distribution type + cudnnRngDistribution_t cudnn_rng_distribution; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.get_distribution(), cudnn_rng_distribution)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_descriptor.get_raw_desc(), + CUDNN_ATTR_RNG_DISTRIBUTION, + CUDNN_TYPE_RNG_DISTRIBUTION, + 1, + &cudnn_rng_distribution)); + + // Set Bernoulli distribution probability + double bernoulli_prob = attributes.get_bernoulli_probability().value(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_descriptor.get_raw_desc(), + CUDNN_ATTR_RNG_BERNOULLI_DIST_PROBABILITY, + CUDNN_TYPE_DOUBLE, + 1, + &bernoulli_prob)); + + // Finalize the descriptor + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(rng_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(rng_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 rng_operation; - auto&& Rng_operation_builder = cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_RNG_DESCRIPTOR); + _CUDNN_CHECK_CUDNN_ERROR( + rng_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)); + // Set output tensor Y CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Rng_attributes::output_names::Y); - Rng_operation_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); - Rng_operation_builder.setRngDesc(rng_descriptor); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + rng_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_RNG_YDESC, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &y_desc)); + + // Set RNG descriptor + auto rng_raw_desc = rng_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RNG_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &rng_raw_desc)); if (attributes.seed.has_value()) { - Rng_operation_builder.setSeed(attributes.get_seed().value()); + // Set seed as int64_t value + int64_t seed_value = attributes.get_seed().value(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + rng_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_RNG_SEED, CUDNN_TYPE_INT64, 1, &seed_value)); } else { + // Set seed tensor descriptor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(Seed, Rng_attributes::input_names::Seed); - Rng_operation_builder.setSeedDesc(*(tensors.at(Seed->second->get_uid()))); + auto seed_desc = tensors.at(Seed->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RNG_SEED, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &seed_desc)); + + // Set offset tensor descriptor CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(Offset, Rng_attributes::input_names::Offset); - Rng_operation_builder.setOffsetDesc(*(tensors.at(Offset->second->get_uid()))); - } + auto offset_desc = tensors.at(Offset->second->get_uid())->get_raw_desc(); -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto operation = Rng_operation_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - operation.get_error()); - operations.push_back(std::make_shared(std::move(operation))); -#else - // build() can throw - // wrap in try catch - try { - auto operation = Rng_operation_builder.build(); - operations.push_back(std::make_shared(std::move(operation))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RNG_OFFSET_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &offset_desc)); } -#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(rng_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(rng_operation))); + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index a6480c02..55d635a4 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -969,6 +969,12 @@ class CompositeSDPABackwardNode : public NodeCRTP { is_deterministic_algorithm_supported_on_blackwell = true; } + if(detail::get_backend_version() >= 91801) { + RETURN_CUDNN_FRONTEND_ERROR_IF(is_ragged && (8 == prop.major || 12 == prop.major) && attributes.is_deterministic_algorithm, + error_code_t::GRAPH_NOT_SUPPORTED, + "Deterministic algorithm is not supported for bprop thd on SM8X and SM12X GPUs"); + } + // version specific validation RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8906 && ((s_kv % 64 != 0) || (d_qk % 64 != 0)), error_code_t::GRAPH_NOT_SUPPORTED, @@ -1042,6 +1048,18 @@ class CompositeSDPABackwardNode : public NodeCRTP { attributes.max_total_seq_len_q.reset(); attributes.max_total_seq_len_kv.reset(); } + + + if(detail::get_backend_version() >= 91801) { + cudaDeviceProp prop; + int device; + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device(&device)); + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device_properties(&prop, device)); + if((8 == prop.major || 12 == prop.major) && (attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value())) { + attributes.max_total_seq_len_q.reset(); + attributes.max_total_seq_len_kv.reset(); + } + } // clang-format on return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/sdpa_fp8_bwd.h b/include/cudnn_frontend/node/sdpa_fp8_bwd.h index e08b5928..225ecf1f 100644 --- a/include/cudnn_frontend/node/sdpa_fp8_bwd.h +++ b/include/cudnn_frontend/node/sdpa_fp8_bwd.h @@ -17,6 +17,9 @@ class SDPAFP8BackwardNode : public NodeCRTP { using input_names = SDPA_fp8_backward_attributes::input_names; using output_names = SDPA_fp8_backward_attributes::output_names; + private: + mutable bool is_deterministic_algorithm_supported_on_blackwell = false; // Will be edited in pre_validate_node() + public: SDPA_fp8_backward_attributes attributes; @@ -106,9 +109,13 @@ class SDPAFP8BackwardNode : public NodeCRTP { // validate basic dimension requirements if(prop.major >= 10) { - RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 128) || (d_qk % 16 != 0) || (d_v > 128) || (d_v % 16 != 0), + RETURN_CUDNN_FRONTEND_ERROR_IF(((d_qk > 128) || (d_qk % 16 != 0)) && !(d_qk == 192 && d_v == 128), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim d_qk shoud be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(((d_v > 128) || (d_v % 16 != 0)), error_code_t::GRAPH_NOT_SUPPORTED, - "hidden_dim shoud be less than 128 and hidden_dim should be multiple of 16"); + "hidden_dim d_v shoud be less than or equal to 128 and hidden_dim d_v should be multiple of 16"); } else { RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk != 128) || (d_qk % 16 != 0) || (d_v != 128) || (d_v % 16 != 0), @@ -186,6 +193,20 @@ class SDPAFP8BackwardNode : public NodeCRTP { error_code_t::ATTRIBUTE_NOT_SET, "Intermediate tensor data type needs to be set as internal tensors require it."); + // validate options for deterministic algorithm + if (attributes.is_deterministic_algorithm && (prop.major == 10)) { + RETURN_CUDNN_FRONTEND_ERROR_IF((detail::get_backend_version() < 91900), + error_code_t::GRAPH_NOT_SUPPORTED, + "FP8 deterministic algorithm is not supported on blackwell architecture with cudnn version below 9.19.0"); + + // dbias bias rng/dropout alibi + RETURN_CUDNN_FRONTEND_ERROR_IF(is_dropout, + error_code_t::GRAPH_NOT_SUPPORTED, + "FP8 deterministic algorithm is not supported on blackwell architecture when dropout is enabled"); + + is_deterministic_algorithm_supported_on_blackwell = true; + } + // if output data type is half or bfloat16 for any of dq, dk, dv, and version is below 9.13 or is not blackwell, return NOT_SUPPORTED RETURN_CUDNN_FRONTEND_ERROR_IF( (dq_data_type == DataType_t::HALF || dq_data_type == DataType_t::BFLOAT16 || @@ -607,6 +628,15 @@ class SDPAFP8BackwardNode : public NodeCRTP { return {error_code_t::OK, ""}; } + std::pair> + override_heuristics_query() const { + if (is_deterministic_algorithm_supported_on_blackwell) { + return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + } else { + return {-1, {}}; + } + } + #ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { diff --git a/include/cudnn_frontend/node/sdpa_support_surface.h b/include/cudnn_frontend/node/sdpa_support_surface.h index 2a58ec95..6486943a 100644 --- a/include/cudnn_frontend/node/sdpa_support_surface.h +++ b/include/cudnn_frontend/node/sdpa_support_surface.h @@ -158,11 +158,18 @@ SDPA_attributes::validate_sdpa_support_surface(const detail::Context& context, "consider using a newer architecture."); // validate basic dimension requirements + // d_qk=192 with d_v=128 is only supported starting from cuDNN 9.19 + bool const d192_v128_supported = (detail::get_backend_version() >= 91900); if (prop.major >= 10) { RETURN_CUDNN_FRONTEND_ERROR_IF( - (d_qk > 128) || (d_qk % 16 != 0) || (d_v > 128) || (d_v % 16 != 0), + ((d_qk > 128) || (d_qk % 16 != 0)) && !(d192_v128_supported && d_qk == 192 && d_v == 128), error_code_t::GRAPH_NOT_SUPPORTED, - "hidden_dim shoud be less than or equal to 128 and hidden_dim should be multiple of 16"); + "hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk " + "should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+)"); + RETURN_CUDNN_FRONTEND_ERROR_IF( + ((d_v > 128) || (d_v % 16 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim d_v should be less than or equal to 128 and hidden_dim d_v should be multiple of 16"); } else { RETURN_CUDNN_FRONTEND_ERROR_IF( (d_qk > 256) || (d_qk % 16 != 0) || (d_v > 256) || (d_v % 16 != 0), @@ -409,6 +416,10 @@ SDPA_attributes::verify_sdpa_support_surface_for_implementation(const detail::Co error_code_t::GRAPH_NOT_SUPPORTED, "Unified SDPA node requires cuDNN 9.13.1"); + RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_dynamic_shape_enabled(), + error_code_t::GRAPH_NOT_SUPPORTED, + "Unified SDPA node doesn't yet support dynamic shape"); + // TODO: Provide smarter error messages that provide the required cuDNN version for each input. std::unordered_set allowed_input_names{ input_names::Q, input_names::K, input_names::V, input_names::Attn_scale}; diff --git a/include/cudnn_frontend/plans.h b/include/cudnn_frontend/plans.h index ec940860..c30812f2 100644 --- a/include/cudnn_frontend/plans.h +++ b/include/cudnn_frontend/plans.h @@ -15,6 +15,33 @@ namespace cudnn_frontend { namespace detail { +inline error_t +execute(cudnnHandle_t handle, + ExecutionPlan* plan, + std::vector& device_ptrs, + std::vector const& uids, + void* workspace_ptr, + std::vector const& override_uids, + std::vector> const& override_shapes, + std::vector> const& override_strides) { + // TODO: below line fails with MSVC. warning C4127: conditional expression is constant + // RETURN_CUDNN_FRONTEND_ERROR_IF(!plan, error_code_t::GRAPH_EXECUTION_FAILED, "No plan found to execute!!"); + CUDNN_FE_LOG_LABEL_ENDL("INFO: Executing " << plan->getTag() << "..."); + + backend_descriptor variant_pack_descriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(variant_pack_descriptor.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create variant pack's backend descriptor."); + + CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack( + variant_pack_descriptor, device_ptrs, uids, workspace_ptr, override_uids, override_shapes, override_strides)); + _CUDNN_CHECK_CUDNN_ERROR(execute(handle, plan->get_raw_desc(), variant_pack_descriptor.get_ptr())); + + CUDNN_FE_LOG_LABEL_ENDL("INFO: Executed " << plan->getTag() << "."); + + return {error_code_t::OK, ""}; +} + inline error_t execute(cudnnHandle_t handle, ExecutionPlan* plan, diff --git a/include/cudnn_frontend_ConvDesc.h b/include/cudnn_frontend_ConvDesc.h index 7f005091..0cbd2480 100644 --- a/include/cudnn_frontend_ConvDesc.h +++ b/include/cudnn_frontend_ConvDesc.h @@ -23,6 +23,22 @@ #pragma once #include + +namespace cudnn_frontend { +namespace graph { +class ConvolutionNode; +} +} // namespace cudnn_frontend +namespace cudnn_frontend { +namespace graph { +class DgradNode; +} +} // namespace cudnn_frontend +namespace cudnn_frontend { +namespace graph { +class WgradNode; +} +} // namespace cudnn_frontend #include #include #include @@ -51,6 +67,9 @@ namespace cudnn_frontend { class ConvDesc_v8 : public BackendDescriptor { public: friend class ConvDescBuilder_v8; + friend class cudnn_frontend::graph::ConvolutionNode; + friend class cudnn_frontend::graph::DgradNode; + friend class cudnn_frontend::graph::WgradNode; std::string describe() const override { std::stringstream ss; diff --git a/include/cudnn_frontend_MatMulDesc.h b/include/cudnn_frontend_MatMulDesc.h index d79c1f3a..ee57403f 100644 --- a/include/cudnn_frontend_MatMulDesc.h +++ b/include/cudnn_frontend_MatMulDesc.h @@ -23,6 +23,12 @@ #pragma once #include + +namespace cudnn_frontend { +namespace graph { +class MatmulNode; +} +} // namespace cudnn_frontend #include #include #include @@ -44,6 +50,7 @@ namespace cudnn_frontend { class MatMulDesc_v8 : public BackendDescriptor { public: friend class MatMulDescBuilder_v8; + friend class cudnn_frontend::graph::MatmulNode; std::string describe() const override { std::stringstream ss; diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h index c394438a..6bf01019 100644 --- a/include/cudnn_frontend_Operation.h +++ b/include/cudnn_frontend_Operation.h @@ -23,6 +23,34 @@ #pragma once #include + +namespace cudnn_frontend { +namespace graph { +class ReductionNode; +class PointwiseNode; +class MatmulNode; +class ConvolutionNode; +class DgradNode; +class WgradNode; +class LayerNormNode; +class BatchNormNode; +class BatchnormInferenceNode; +class RMSNormNode; +class DRMSNormNode; +class InstanceNormNode; +class DINNode; +class DLNNode; +class DBNNode; +class DBNWeightNode; +class BatchNormFinalizeNode; +class GenstatsNode; +class ReshapeNode; +class ResampleNode; +class RngNode; +class PagedCacheLoadNode; +} // namespace graph +} // namespace cudnn_frontend + #include #include #include @@ -77,6 +105,28 @@ namespace cudnn_frontend { class Operation_v8 : public BackendDescriptor { public: friend class OperationBuilder_v8; + friend class cudnn_frontend::graph::ReductionNode; + friend class cudnn_frontend::graph::PointwiseNode; + friend class cudnn_frontend::graph::MatmulNode; + friend class cudnn_frontend::graph::ConvolutionNode; + friend class cudnn_frontend::graph::DgradNode; + friend class cudnn_frontend::graph::WgradNode; + friend class cudnn_frontend::graph::LayerNormNode; + friend class cudnn_frontend::graph::BatchNormNode; + friend class cudnn_frontend::graph::BatchnormInferenceNode; + friend class cudnn_frontend::graph::RMSNormNode; + friend class cudnn_frontend::graph::DRMSNormNode; + friend class cudnn_frontend::graph::InstanceNormNode; + friend class cudnn_frontend::graph::DINNode; + friend class cudnn_frontend::graph::DLNNode; + friend class cudnn_frontend::graph::DBNNode; + friend class cudnn_frontend::graph::DBNWeightNode; + friend class cudnn_frontend::graph::BatchNormFinalizeNode; + friend class cudnn_frontend::graph::GenstatsNode; + friend class cudnn_frontend::graph::ReshapeNode; + friend class cudnn_frontend::graph::ResampleNode; + friend class cudnn_frontend::graph::RngNode; + friend class cudnn_frontend::graph::PagedCacheLoadNode; std::string describe() const override { std::stringstream ss; diff --git a/include/cudnn_frontend_PointWiseDesc.h b/include/cudnn_frontend_PointWiseDesc.h index 89c70900..87bfbd2f 100644 --- a/include/cudnn_frontend_PointWiseDesc.h +++ b/include/cudnn_frontend_PointWiseDesc.h @@ -23,6 +23,12 @@ #pragma once #include + +namespace cudnn_frontend { +namespace graph { +class PointwiseNode; +} +} // namespace cudnn_frontend #include #include #include @@ -53,6 +59,7 @@ namespace cudnn_frontend { class PointWiseDesc_v8 : public BackendDescriptor { public: friend class PointWiseDescBuilder_v8; + friend class cudnn_frontend::graph::PointwiseNode; std::string describe() const override { std::stringstream ss; diff --git a/include/cudnn_frontend_ReductionDesc.h b/include/cudnn_frontend_ReductionDesc.h index dacda0c7..627b5d0f 100644 --- a/include/cudnn_frontend_ReductionDesc.h +++ b/include/cudnn_frontend_ReductionDesc.h @@ -23,6 +23,12 @@ #pragma once #include + +namespace cudnn_frontend { +namespace graph { +class ReductionNode; +} +} // namespace cudnn_frontend #include #include #include @@ -46,6 +52,7 @@ namespace cudnn_frontend { class ReductionDesc_v8 : public BackendDescriptor { public: friend class ReductionDescBuilder_v8; + friend class cudnn_frontend::graph::ReductionNode; std::string describe() const override { std::stringstream ss; diff --git a/include/cudnn_frontend_Resample.h b/include/cudnn_frontend_Resample.h index 57dad803..a375a458 100644 --- a/include/cudnn_frontend_Resample.h +++ b/include/cudnn_frontend_Resample.h @@ -31,6 +31,12 @@ #include "cudnn_frontend_utils.h" +namespace cudnn_frontend { +namespace graph { +class ResampleNode; +} +} // namespace cudnn_frontend + namespace cudnn_frontend { /// @@ -44,6 +50,7 @@ namespace cudnn_frontend { class ResampleDesc_v8 : public BackendDescriptor { public: friend class ResampleDescBuilder_v8; + friend class graph::ResampleNode; std::string describe() const override { std::stringstream ss; diff --git a/include/cudnn_frontend_Rng.h b/include/cudnn_frontend_Rng.h index 207861e9..829f8bde 100644 --- a/include/cudnn_frontend_Rng.h +++ b/include/cudnn_frontend_Rng.h @@ -31,6 +31,12 @@ #include "cudnn_frontend_utils.h" +namespace cudnn_frontend { +namespace graph { +class RngNode; +} +} // namespace cudnn_frontend + namespace cudnn_frontend { /// @@ -44,6 +50,7 @@ namespace cudnn_frontend { class RngDesc_v8 : public BackendDescriptor { public: friend class RngDescBuilder_v8; + friend class graph::RngNode; std::string describe() const override { std::stringstream ss; diff --git a/include/cudnn_frontend_version.h b/include/cudnn_frontend_version.h index e7c8a1e4..53e9afd7 100644 --- a/include/cudnn_frontend_version.h +++ b/include/cudnn_frontend_version.h @@ -23,7 +23,7 @@ #pragma once #define CUDNN_FRONTEND_MAJOR_VERSION 1 -#define CUDNN_FRONTEND_MINOR_VERSION 17 +#define CUDNN_FRONTEND_MINOR_VERSION 18 #define CUDNN_FRONTEND_PATCH_VERSION 0 #define CUDNN_FRONTEND_VERSION \ ((CUDNN_FRONTEND_MAJOR_VERSION * 10000) + (CUDNN_FRONTEND_MINOR_VERSION * 100) + CUDNN_FRONTEND_PATCH_VERSION) diff --git a/pyproject.toml b/pyproject.toml index abefe7db..d1a34408 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers = [ ] [tool.setuptools] -packages = ["cudnn", "include", "cudnn.native_sparse_attention", "cudnn.native_sparse_attention.selection", "cudnn.native_sparse_attention.sliding_window_attention", "cudnn.native_sparse_attention.top_k", "cudnn.gemm_swiglu", "cudnn.gemm_amax"] +packages = ["cudnn", "include", "cudnn.native_sparse_attention", "cudnn.native_sparse_attention.selection", "cudnn.native_sparse_attention.compression", "cudnn.native_sparse_attention.sliding_window_attention", "cudnn.native_sparse_attention.top_k", "cudnn.gemm_swiglu", "cudnn.gemm_amax", "cudnn.grouped_gemm", "cudnn.grouped_gemm.grouped_gemm_swiglu"] package-dir = {"" = "python", "include" = "include"} include-package-data = true @@ -30,7 +30,7 @@ include = ["**/*"] [project.optional-dependencies] cutedsl = [ - "nvidia-cutlass-dsl==4.3.1", + "nvidia-cutlass-dsl==4.3.5", "cuda-python", "torch", -] \ No newline at end of file +] diff --git a/python/cudnn/__init__.py b/python/cudnn/__init__.py index a5dea105..2e42a9a5 100644 --- a/python/cudnn/__init__.py +++ b/python/cudnn/__init__.py @@ -46,7 +46,7 @@ def is_windows(): from .datatypes import _library_type, _is_torch_tensor -__version__ = "1.17.0" +__version__ = "1.18.0" def _tensor( @@ -113,7 +113,15 @@ def _library_device_pointer(input_tensor): return _pybind_module._get_data_ptr(input_tensor) -def _execute(self, tensor_to_device_buffer, workspace, handle=None): +def _execute( + self, + tensor_to_device_buffer, + workspace, + handle=None, + override_uids=None, + override_shapes=None, + override_strides=None, +): """ Execute a cudnn graph. @@ -125,9 +133,7 @@ def _execute(self, tensor_to_device_buffer, workspace, handle=None): None """ uid_to_tensor_pointer = { - x if type(x) is int else x.get_uid(): _library_device_pointer(pointer) - for x, pointer in tensor_to_device_buffer.items() - if x is not None + x if type(x) is int else x.get_uid(): _library_device_pointer(pointer) for x, pointer in tensor_to_device_buffer.items() if x is not None } workspace_pointer = _library_device_pointer(workspace) @@ -135,7 +141,14 @@ def _execute(self, tensor_to_device_buffer, workspace, handle=None): def _execute_plan_at_index( - self, tensor_to_device_buffer, workspace, index, handle=None + self, + tensor_to_device_buffer, + workspace, + index, + handle=None, + override_uids=None, + override_shapes=None, + override_strides=None, ): """ Execute a cudnn graph. @@ -149,13 +162,19 @@ def _execute_plan_at_index( None """ uid_to_tensor_pointer = { - x if type(x) is int else x.get_uid(): _library_device_pointer(pointer) - for x, pointer in tensor_to_device_buffer.items() - if x is not None + x if type(x) is int else x.get_uid(): _library_device_pointer(pointer) for x, pointer in tensor_to_device_buffer.items() if x is not None } workspace_pointer = _library_device_pointer(workspace) - self._execute_plan_at_index(uid_to_tensor_pointer, workspace_pointer, index, handle) + self._execute_plan_at_index( + uid_to_tensor_pointer, + workspace_pointer, + index, + handle, + override_uids, + override_shapes, + override_strides, + ) pygraph.execute = _execute @@ -164,14 +183,10 @@ def _execute_plan_at_index( def load_cudnn(): # First look at python site packages - lib_path = glob.glob( - os.path.join(sysconfig.get_path("purelib"), "nvidia/cudnn/bin/cudnn64_9.dll") - ) + lib_path = glob.glob(os.path.join(sysconfig.get_path("purelib"), "nvidia/cudnn/bin/cudnn64_9.dll")) if lib_path: - assert ( - len(lib_path) == 1 - ), f"Found {len(lib_path)} libcudnn.dll.x in nvidia-cudnn-cuXX." + assert len(lib_path) == 1, f"Found {len(lib_path)} libcudnn.dll.x in nvidia-cudnn-cuXX." lib = ctypes.windll.LoadLibrary(lib_path[0]) else: # Fallback lib = ctypes.windll.LoadLibrary("cudnn64_9.dll") @@ -182,23 +197,13 @@ def load_cudnn(): def _dlopen_cudnn(): # First look at python site packages - lib_path = glob.glob( - os.path.join( - sysconfig.get_path("purelib"), "nvidia/cudnn/lib/libcudnn.so.*[0-9]" - ) - ) + lib_path = glob.glob(os.path.join(sysconfig.get_path("purelib"), "nvidia/cudnn/lib/libcudnn.so.*[0-9]")) if not lib_path: - lib_path = glob.glob( - os.path.join( - sysconfig.get_path("purelib"), "nvidia/cudnn_jit/lib/libcudnn.so.*[0-9]" - ) - ) + lib_path = glob.glob(os.path.join(sysconfig.get_path("purelib"), "nvidia/cudnn_jit/lib/libcudnn.so.*[0-9]")) if lib_path: - assert ( - len(lib_path) == 1 - ), f"Found {len(lib_path)} libcudnn.so.x in nvidia-cudnn-cuXX." + assert len(lib_path) == 1, f"Found {len(lib_path)} libcudnn.so.x in nvidia-cudnn-cuXX." lib = ctypes.CDLL(lib_path[0]) else: # Fallback try: @@ -232,9 +237,7 @@ def __getattr__(name: str) -> Any: return _NSA except Exception as e: - raise ImportError( - f"NSA requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e + raise ImportError(f"NSA requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e elif name == "GemmSwigluSm100": try: @@ -242,9 +245,7 @@ def __getattr__(name: str) -> Any: return _GemmSwigluSm100 except Exception as e: - raise ImportError( - f"GemmSwigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e + raise ImportError(f"GemmSwigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e elif name == "gemm_swiglu_wrapper_sm100": try: @@ -264,9 +265,7 @@ def __getattr__(name: str) -> Any: return _GemmAmaxSm100 except Exception as e: - raise ImportError( - f"GemmAmaxSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" - ) from e + raise ImportError(f"GemmAmaxSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e elif name == "gemm_amax_wrapper_sm100": try: @@ -275,9 +274,37 @@ def __getattr__(name: str) -> Any: ) return _gemm_amax_wrapper_sm100 + except Exception as e: + raise ImportError(f"gemm_amax_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e + + # Grouped GEMM module + elif name == "grouped_gemm": + try: + from . import grouped_gemm as _grouped_gemm + + return _grouped_gemm + except Exception as e: + raise ImportError(f"grouped_gemm requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e + + elif name == "GroupedGemmSwigluSm100": + try: + from .grouped_gemm import GroupedGemmSwigluSm100 as _GroupedGemmSwigluSm100 + + return _GroupedGemmSwigluSm100 + except Exception as e: + raise ImportError(f"GroupedGemmSwigluSm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}") from e + + elif name == "grouped_gemm_swiglu_wrapper_sm100": + try: + from .grouped_gemm import ( + grouped_gemm_swiglu_wrapper_sm100 as _grouped_gemm_swiglu_wrapper_sm100, + ) + + return _grouped_gemm_swiglu_wrapper_sm100 except Exception as e: raise ImportError( - f"gemm_amax_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" + f"grouped_gemm_swiglu_wrapper_sm100 requires optional dependencies. Install with 'pip install nvidia-cudnn-frontend[cutedsl]': {e}" ) from e + else: raise AttributeError(name) diff --git a/python/cudnn/api_base.py b/python/cudnn/api_base.py index db6ca945..7f3e544c 100644 --- a/python/cudnn/api_base.py +++ b/python/cudnn/api_base.py @@ -7,6 +7,7 @@ This module provides abstract base classes that define common interfaces for cuDNN API wrapper classes, including validation, compilation, and execution patterns. """ + from __future__ import annotations from abc import ABC, abstractmethod @@ -24,6 +25,11 @@ def ceil_div(a: int, b: int) -> int: return (a + b - 1) // b +def is_power_of_2(n: int) -> bool: + """Check if n is a power of 2.""" + return n > 0 and (n & (n - 1)) == 0 + + class APIBase(ABC): """Abstract base class for cuDNN API wrappers. @@ -218,9 +224,7 @@ def _ensure_support_checked(self) -> None: ... # ... rest of compilation """ if not self._is_supported: - self._logger.info( - f"{self.__class__.__name__}: check_support not previously called, calling now" - ) + self._logger.info(f"{self.__class__.__name__}: check_support not previously called, calling now") assert self.check_support(), "Unsupported configuration" def _get_default_stream(self, stream: Optional[cuda.CUstream]) -> cuda.CUstream: @@ -241,9 +245,7 @@ def _get_default_stream(self, stream: Optional[cuda.CUstream]) -> cuda.CUstream: ... # Now current_stream is guaranteed to be a valid stream """ if stream is None: - self._logger.debug( - f"{self.__class__.__name__}: No CUDA stream provided, using default stream" - ) + self._logger.debug(f"{self.__class__.__name__}: No CUDA stream provided, using default stream") return cutlass.cuda.default_stream() return stream @@ -292,9 +294,7 @@ def _unpad_tensor_to_ndim( for _ in range(tensor.ndim - ndim): tensor = tensor.squeeze(-1) if tensor.ndim != ndim: - self._logger.critical( - f"Unpadding {name} resulted in shape {tensor.shape}, expected {ndim}D" - ) + self._logger.critical(f"Unpadding {name} resulted in shape {tensor.shape}, expected {ndim}D") return tensor def _is_fp4x2(self, tensor_or_dtype: torch.Tensor | torch.dtype) -> bool: @@ -307,14 +307,8 @@ def _is_fp4x2(self, tensor_or_dtype: torch.Tensor | torch.dtype) -> bool: """ if tensor_or_dtype is None: return False - dtype = ( - tensor_or_dtype.dtype - if isinstance(tensor_or_dtype, torch.Tensor) - else tensor_or_dtype - ) - return (dtype == torch.float4_e2m1fn_x2) or ( - self._interpret_uint8_as_fp4x2 and dtype == torch.uint8 - ) + dtype = tensor_or_dtype.dtype if isinstance(tensor_or_dtype, torch.Tensor) else tensor_or_dtype + return (dtype == torch.float4_e2m1fn_x2) or (self._interpret_uint8_as_fp4x2 and dtype == torch.uint8) def _is_fp8(self, tensor_or_dtype: torch.Tensor | torch.dtype) -> bool: """Check if tensor or dtype is an FP8 datatype. @@ -326,11 +320,7 @@ def _is_fp8(self, tensor_or_dtype: torch.Tensor | torch.dtype) -> bool: """ if tensor_or_dtype is None: return False - dtype = ( - tensor_or_dtype.dtype - if isinstance(tensor_or_dtype, torch.Tensor) - else tensor_or_dtype - ) + dtype = tensor_or_dtype.dtype if isinstance(tensor_or_dtype, torch.Tensor) else tensor_or_dtype return dtype in {torch.float8_e5m2, torch.float8_e4m3fn} def _get_innermost_stride_dim(self, tensor: torch.Tensor, name: str = "") -> int: @@ -343,9 +333,7 @@ def _get_innermost_stride_dim(self, tensor: torch.Tensor, name: str = "") -> int self._logger.critical( f"tensor {name} has shape: {tensor.shape} stride {tensor.stride()} – innermost contiguous (stride == 1) dimension not found. " ) - raise RuntimeError( - f"tensor {name} has shape: {tensor.shape} stride {tensor.stride()} – innermost contiguous (stride == 1) dimension not found. " - ) + raise RuntimeError(f"tensor {name} has shape: {tensor.shape} stride {tensor.stride()} – innermost contiguous (stride == 1) dimension not found. ") return idx def _tensor_shape( @@ -371,13 +359,8 @@ def _tensor_shape( if self._is_fp4x2(tensor): innermost_dim_index = self._get_innermost_stride_dim(tensor, name=name) - shape = tuple( - dim * 2 if i == innermost_dim_index else dim - for i, dim in enumerate(tensor.shape) - ) - self._logger.debug( - f"FP4x2 tensor {name}: physical shape {tensor.shape} -> logical shape {shape}" - ) + shape = tuple(dim * 2 if i == innermost_dim_index else dim for i, dim in enumerate(tensor.shape)) + self._logger.debug(f"FP4x2 tensor {name}: physical shape {tensor.shape} -> logical shape {shape}") return shape else: return tensor.shape @@ -405,13 +388,8 @@ def _tensor_stride( if self._is_fp4x2(tensor): innermost_dim_index = self._get_innermost_stride_dim(tensor, name=name) - strides = tuple( - s * 2 if i != innermost_dim_index else s - for i, s in enumerate(tensor.stride()) - ) - self._logger.debug( - f"FP4x2 tensor {name}: physical stride {tensor.stride()} -> logical stride {strides}" - ) + strides = tuple(s * 2 if i != innermost_dim_index else s for i, s in enumerate(tensor.stride())) + self._logger.debug(f"FP4x2 tensor {name}: physical stride {tensor.stride()} -> logical stride {strides}") return strides else: return tensor.stride() @@ -436,21 +414,13 @@ def _check_tensor_shape( """ if tensor_or_shape is None: return None - tensor_shape = ( - self._tensor_shape(tensor_or_shape, name=name) - if isinstance(tensor_or_shape, torch.Tensor) - else tensor_or_shape - ) + tensor_shape = self._tensor_shape(tensor_or_shape, name=name) if isinstance(tensor_or_shape, torch.Tensor) else tensor_or_shape if isinstance(shape, tuple): if tensor_shape != shape: - raise ValueError( - f"{name} tensor shape mismatch: expected {shape}, got {tensor_shape}" - ) + raise ValueError(f"{name} tensor shape mismatch: expected {shape}, got {tensor_shape}") elif isinstance(shape, list): if tensor_shape not in shape: - raise ValueError( - f"{name} tensor shape mismatch: expected one of {shape}, got {tensor_shape}" - ) + raise ValueError(f"{name} tensor shape mismatch: expected one of {shape}, got {tensor_shape}") else: raise ValueError(f"Expected shape to be a tuple or list, got {type(shape)}") return tensor_shape @@ -461,6 +431,7 @@ def _check_tensor_stride( stride: Optional[Tuple[int, ...] | List[Tuple[int, ...]]] = None, stride_order: Optional[Tuple[int, ...] | List[Tuple[int, ...]]] = None, name: str = "", + extra_error_msg: str = "", ) -> Optional[Tuple[Tuple[int, ...], Tuple[int, ...]]]: """Check if the stride of a tensor matches the expected stride(s) or stride order(s). @@ -472,51 +443,53 @@ def _check_tensor_stride( :type stride_order: Tuple[int, ...] | List[Tuple[int, ...]] :param name: Logical tensor name for logging :type name: str + :param extra_error_msg: Extra error message to add to the error + :type extra_error_msg: str :raises ValueError: If the stride of the tensor does not match the expected stride order :return: The stride and stride order of the tensor :rtype: Optional[Tuple[Tuple[int, ...], Tuple[int, ...]]] """ if tensor_or_stride is None: - return None - tensor_stride = ( - self._tensor_stride(tensor_or_stride, name=name) - if isinstance(tensor_or_stride, torch.Tensor) - else tensor_or_stride - ) - tensor_stride_order = tuple( - i for i, s in sorted(enumerate(tensor_stride), key=lambda x: x[1]) - ) + return None, None + tensor_stride = self._tensor_stride(tensor_or_stride, name=name) if isinstance(tensor_or_stride, torch.Tensor) else tensor_or_stride + tensor_stride_order = tuple(i for i, s in sorted(enumerate(tensor_stride), key=lambda x: x[1])) if stride is not None: if isinstance(stride, tuple): if tensor_stride != stride: - raise ValueError( - f"{name} tensor stride mismatch: expected {stride}, got {tensor_stride}" - ) + error_msg = f"{name} tensor stride mismatch: expected {stride}, got {tensor_stride}" + if extra_error_msg: + error_msg += f": {extra_error_msg}" + raise ValueError(error_msg) elif isinstance(stride, list): if tensor_stride not in stride: - raise ValueError( - f"{name} tensor stride mismatch: expected one of {stride}, got {tensor_stride}" - ) + error_msg = f"{name} tensor stride mismatch: expected one of {stride}, got {tensor_stride}" + if extra_error_msg: + error_msg += f": {extra_error_msg}" + raise ValueError(error_msg) else: - raise ValueError( - f"Expected stride to be a tuple or list, got {type(stride)}" - ) + error_msg = f"Expected stride to be a tuple or list, got {type(stride)}" + if extra_error_msg: + error_msg += f": {extra_error_msg}" + raise ValueError(error_msg) if stride_order is not None: if isinstance(stride_order, tuple): if tensor_stride_order != stride_order: - raise ValueError( - f"{name} tensor stride order mismatch: expected {stride_order}, got {tensor_stride_order}" - ) + error_msg = f"{name} tensor stride order mismatch: expected {stride_order}, got {tensor_stride_order}" + if extra_error_msg: + error_msg += f": {extra_error_msg}" + raise ValueError(error_msg) elif isinstance(stride_order, list): if tensor_stride_order not in stride_order: - raise ValueError( - f"{name} tensor stride order mismatch: expected one of {stride_order}, got {tensor_stride_order}" - ) + error_msg = f"{name} tensor stride order mismatch: expected one of {stride_order}, got {tensor_stride_order}" + if extra_error_msg: + error_msg += f": {extra_error_msg}" + raise ValueError(error_msg) else: - raise ValueError( - f"Expected stride order to be a tuple or list, got {type(stride_order)}" - ) + error_msg = f"Expected stride order to be a tuple or list, got {type(stride_order)}" + if extra_error_msg: + error_msg += f": {extra_error_msg}" + raise ValueError(error_msg) return tensor_stride, tensor_stride_order def _check_dtype( @@ -540,16 +513,10 @@ def _check_dtype( """ if tensor_or_dtype is None: return None - tensor_dtype = ( - tensor_or_dtype.dtype - if isinstance(tensor_or_dtype, torch.Tensor) - else tensor_or_dtype - ) + tensor_dtype = tensor_or_dtype.dtype if isinstance(tensor_or_dtype, torch.Tensor) else tensor_or_dtype if isinstance(dtype, torch.dtype): if tensor_dtype != dtype: - error_msg = ( - f"{name} dtype mismatch: expected {dtype}, got {tensor_dtype}" - ) + error_msg = f"{name} dtype mismatch: expected {dtype}, got {tensor_dtype}" if extra_error_msg: error_msg += f": {extra_error_msg}" raise ValueError(error_msg) @@ -560,9 +527,7 @@ def _check_dtype( error_msg += f": {extra_error_msg}" raise ValueError(error_msg) else: - raise ValueError( - f"Expected dtype to be a torch.dtype or list, got {type(dtype)}" - ) + raise ValueError(f"Expected dtype to be a torch.dtype or list, got {type(dtype)}") return tensor_dtype def _value_error_if(self, condition: bool, error_msg: str) -> None: @@ -601,9 +566,7 @@ def _runtime_error_if(self, condition: bool, error_msg: str) -> None: if condition: raise RuntimeError(error_msg) - def _make_cute_pointer( - self, tensor: torch.Tensor, assumed_align: int = 16 - ) -> cute.Pointer: + def _make_cute_pointer(self, tensor: torch.Tensor, assumed_align: int = 16) -> cute.Pointer: """Make a cute.Pointer for a tensor. :param tensor: The tensor to make a cute.Pointer for @@ -616,9 +579,7 @@ def _make_cute_pointer( if tensor is None: return None return cute.runtime.make_ptr( - _convert_to_cutlass_data_type( - tensor.dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2 - ), + _convert_to_cutlass_data_type(tensor.dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2), tensor.data_ptr(), cute.AddressSpace.gmem, assumed_align=assumed_align, @@ -643,7 +604,37 @@ def _make_cute_tensor_descriptor( tensor_ptr = self._make_cute_pointer(tensor, assumed_align=assumed_align) tensor_shape = self._tensor_shape(tensor, name=name) tensor_stride = self._tensor_stride(tensor, name=name) - tensor_stride_order = tuple( - i for i, s in sorted(enumerate(tensor_stride), key=lambda x: x[1]) - ) + tensor_stride_order = tuple(i for i, s in sorted(enumerate(tensor_stride), key=lambda x: x[1])) return tensor_ptr, tensor_shape, tensor_stride_order + + +class TupleDict(dict): + """A dictionary that supports tuple unpacking. + + This class extends dict to allow unpacking like a tuple while still + providing dictionary-style key access. The unpacking order is determined + by the _keys attribute which preserves insertion order. + + Example: + >>> result = TupleDict(a=1, b=2, c=3) + >>> x, y, z = result # Unpacks as (1, 2, 3) + >>> result['a'] # Returns 1 + >>> result[0] # Returns 1 (integer indexing) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Store keys in order for tuple unpacking + self._keys = list(self.keys()) + + def __iter__(self): + """Iterate over values in insertion order for tuple unpacking.""" + return (self[k] for k in self._keys) + + def __getitem__(self, key): + """Support both string keys and integer indices.""" + if isinstance(key, int): + if key < 0 or key >= len(self._keys): + raise IndexError(f"index {key} out of range for TupleDict with {len(self._keys)} items") + return super().__getitem__(self._keys[key]) + return super().__getitem__(key) diff --git a/python/cudnn/datatypes.py b/python/cudnn/datatypes.py index aa7a7fcf..ebb024e5 100644 --- a/python/cudnn/datatypes.py +++ b/python/cudnn/datatypes.py @@ -71,39 +71,23 @@ def is_cutlass_available(): cutlass_available = True mapping = { torch.half: getattr(cutlass, "Float16", None), - getattr(torch, "float16", torch.half): getattr( - cutlass, "Float16", None - ), + getattr(torch, "float16", torch.half): getattr(cutlass, "Float16", None), getattr(torch, "bfloat16", None): getattr(cutlass, "BFloat16", None), torch.float: getattr(cutlass, "Float32", None), - getattr(torch, "float32", torch.float): getattr( - cutlass, "Float32", None - ), + getattr(torch, "float32", torch.float): getattr(cutlass, "Float32", None), torch.double: getattr(cutlass, "Float64", None), - getattr(torch, "float64", torch.double): getattr( - cutlass, "Float64", None - ), + getattr(torch, "float64", torch.double): getattr(cutlass, "Float64", None), getattr(torch, "int8", None): getattr(cutlass, "Int8", None), getattr(torch, "int32", None): getattr(cutlass, "Int32", None), getattr(torch, "int64", None): getattr(cutlass, "Int64", None), getattr(torch, "uint8", None): getattr(cutlass, "Uint8", None), getattr(torch, "bool", None): getattr(cutlass, "Boolean", None), - getattr(torch, "float8_e4m3fn", None): getattr( - cutlass, "Float8E4M3FN", None - ), - getattr(torch, "float8_e5m2", None): getattr( - cutlass, "Float8E5M2", None - ), - getattr(torch, "float8_e8m0fnu", None): getattr( - cutlass, "Float8E8M0FNU", None - ), - getattr(torch, "float4_e2m1fn_x2", None): getattr( - cutlass, "Float4E2M1FN", None - ), - } - _torch_to_cutlass_data_type_dict = { - t: c for t, c in mapping.items() if t is not None and c is not None + getattr(torch, "float8_e4m3fn", None): getattr(cutlass, "Float8E4M3FN", None), + getattr(torch, "float8_e5m2", None): getattr(cutlass, "Float8E5M2", None), + getattr(torch, "float8_e8m0fnu", None): getattr(cutlass, "Float8E8M0FNU", None), + getattr(torch, "float4_e2m1fn_x2", None): getattr(cutlass, "Float4E2M1FN", None), } + _torch_to_cutlass_data_type_dict = {t: c for t, c in mapping.items() if t is not None and c is not None} except ImportError: cutlass_available = False _torch_to_cutlass_data_type_dict = {} @@ -138,9 +122,7 @@ def _convert_to_cutlass_data_type(data_type, interpret_uint8_as_fp4x2: bool = Fa if isinstance(data_type, type) and issubclass(data_type, cutlass.Numeric): return data_type elif data_type is not None: - cutlass_data_type = _torch_to_cutlass_data_type( - data_type, interpret_uint8_as_fp4x2=interpret_uint8_as_fp4x2 - ) + cutlass_data_type = _torch_to_cutlass_data_type(data_type, interpret_uint8_as_fp4x2=interpret_uint8_as_fp4x2) if cutlass_data_type is None: raise ValueError("Unsupported tensor data type.") return cutlass_data_type @@ -177,9 +159,7 @@ def _library_type(input_type): if out is not None: return out - raise Exception( - f"No available conversion from type {input_type} to a library type." - ) + raise Exception(f"No available conversion from type {input_type} to a library type.") def _is_torch_tensor(input_tensor) -> bool: diff --git a/python/cudnn/gemm_amax/api.py b/python/cudnn/gemm_amax/api.py index f036c671..79fbade1 100644 --- a/python/cudnn/gemm_amax/api.py +++ b/python/cudnn/gemm_amax/api.py @@ -10,10 +10,10 @@ import cutlass import cutlass.cute as cute -from cutlass.cute.runtime import from_dlpack, make_ptr +from cutlass.cute.runtime import from_dlpack from cudnn.datatypes import _convert_to_cutlass_data_type -from cudnn.api_base import APIBase +from cudnn.api_base import APIBase, is_power_of_2, ceil_div class GemmAmaxSm100(APIBase): @@ -40,13 +40,7 @@ def __init__( self.sample_sfa = sample_sfa self.sample_sfb = sample_sfb self.sample_c = sample_c - self.sample_amax = sample_amax - if self.sample_amax.dim() < 3: - self._logger.info( - f"Reshaping sample_amax to (1, 1, 1) from {self.sample_amax.shape}" - ) - for _ in range(3 - self.sample_amax.dim()): - self.sample_amax = self.sample_amax.unsqueeze(-1) + self.sample_amax = self._pad_tensor_to_ndim(sample_amax, 3, "sample_amax") self.acc_dtype = acc_dtype self.mma_tiler_mn = mma_tiler_mn self.cluster_shape_mn = cluster_shape_mn @@ -56,277 +50,227 @@ def __init__( self.atom_m = (32, 4) self.atom_k = 4 + self._interpret_uint8_as_fp4x2 = True self._logger.debug( - f"__init__ completed with args: sample_a {sample_a.shape}, sample_b {sample_b.shape}, sample_sfa {sample_sfa.shape}, sample_sfb {sample_sfb.shape}, sample_c {sample_c.shape}, sample_amax {sample_amax.shape}, acc_dtype {acc_dtype}, mma_tiler_mn {mma_tiler_mn}, cluster_shape_mn {cluster_shape_mn}, sf_vec_size {sf_vec_size}" + f"__init__ completed with args: sample_a {sample_a.shape}, sample_b {sample_b.shape}, sample_sfa {sample_sfa.shape}, sample_sfb {sample_sfb.shape}, sample_c {sample_c.shape}, sample_amax {self.sample_amax.shape}, acc_dtype {acc_dtype}, mma_tiler_mn {mma_tiler_mn}, cluster_shape_mn {cluster_shape_mn}, sf_vec_size {sf_vec_size}" ) def check_support(self) -> bool: self._logger.debug("Entering check_support") - ab_dtype = self.sample_a.dtype - sf_dtype = self.sample_sfa.dtype - c_dtype = self.sample_c.dtype - self._logger.debug("Checking dtypes and sf_vec_size") - if self.sample_a.dtype != self.sample_b.dtype: - raise ValueError( - f"A and B tensor dtypes must match, got {self.sample_a.dtype} and {self.sample_b.dtype}" - ) - if ab_dtype not in { - torch.float4_e2m1fn_x2, - torch.uint8, - torch.float8_e5m2, - torch.float8_e4m3fn, - }: - raise ValueError( - f"Unsupported ab_dtype: received {ab_dtype}, expected {{float4_e2m1fn_x2, uint8, float8_e5m2, float8_e4m3fn}}" - ) + ab_dtype = self._check_dtype( + self.sample_a, + dtype=[torch.float4_e2m1fn_x2, torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn], + name="A", + ) + self._check_dtype( + self.sample_b, + dtype=ab_dtype, + name="B", + extra_error_msg="A and B tensor dtypes must match", + ) if ab_dtype == torch.uint8: - self._logger.warning( - "Uint8 ab_dtype will be interpreted as packed fp4, not as native uint8" - ) - if self.sf_vec_size not in {16, 32}: - raise ValueError( - f"Unsupported sf_vec_size: received {self.sf_vec_size}, expected {{16, 32}}" - ) - if sf_dtype not in { - torch.float8_e8m0fnu, - torch.float8_e4m3fn, - torch.int8, - }: - raise ValueError( - f"Unsupported sf_dtype: received {sf_dtype}, expected {{float8_e8m0fnu, float8_e4m3fn, int8}}" - ) + self._logger.warning("Uint8 ab_dtype will be interpreted as packed fp4, not as native uint8") + + self._value_error_if( + self.sf_vec_size not in {16, 32}, + f"Unsupported sf_vec_size: received {self.sf_vec_size}, expected {{16, 32}}", + ) + + sf_dtype = self._check_dtype( + self.sample_sfa, + dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn, torch.int8], + name="sfa", + ) + self._check_dtype( + self.sample_sfb, + dtype=sf_dtype, + name="sfb", + extra_error_msg="sfa and sfb tensor dtypes must match", + ) if sf_dtype == torch.int8: - self._logger.warning( - "Int8 sf_dtype will be interpreted as float8_e8m0fnu, not as native int8" - ) - if sf_dtype == torch.float8_e4m3fn and self.sf_vec_size == 32: - raise ValueError( - "Unsupported sf_dtype and sf_vec_size combination: float8_e4m3fn and 32 is not supported" - ) - if ( - ab_dtype in {torch.float8_e5m2, torch.float8_e4m3fn} - and self.sf_vec_size == 16 - ): - raise ValueError( - f"Unsupported ab_dtype and sf_vec_size combination: {{float8_e5m2, float8_e4m3fn}} and 16 is not supported" - ) - if c_dtype not in { - torch.float32, - torch.float16, - torch.bfloat16, - torch.float8_e5m2, - torch.float8_e4m3fn, - torch.float4_e2m1fn_x2, - torch.uint8, - }: - raise ValueError( - f"Unsupported c_dtype: received {c_dtype}, expected {{float32, float16, bfloat16, float8_e5m2, float8_e4m3fn, float4_e2m1fn_x2, uint8}}" - ) - if c_dtype in {torch.float4_e2m1fn_x2, torch.uint8}: - if ab_dtype not in {torch.float4_e2m1fn_x2, torch.uint8}: - raise ValueError( - f"Unsupported c_dtype and ab_dtype combination: fp4 c_dtype requires fp4 ab_dtype, got {ab_dtype}" - ) # Kernel fails to launch with other ab_dtype - if c_dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and ab_dtype in { - torch.float8_e5m2, - torch.float8_e4m3fn, - }: - raise NotImplementedError( - f"fp8 ab_dtype and fp8 c_dtype currently fails to launch" - ) - if not (self.acc_dtype == torch.float32): - raise ValueError( - f"Unsupported acc_dtype: received {self.acc_dtype}, expected {{float32}}" - ) + self._logger.warning("Int8 sf_dtype will be interpreted as float8_e8m0fnu, not as native int8") + + self._value_error_if( + sf_dtype == torch.float8_e4m3fn and self.sf_vec_size == 32, + "Unsupported sf_dtype and sf_vec_size combination: float8_e4m3fn and 32 is not supported", + ) + self._value_error_if( + ab_dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and self.sf_vec_size == 16, + f"Unsupported ab_dtype and sf_vec_size combination: {{float8_e5m2, float8_e4m3fn}} and 16 is not supported", + ) + + c_dtype = self._check_dtype( + self.sample_c, + dtype=[torch.float32, torch.float16, torch.bfloat16, torch.float8_e5m2, torch.float8_e4m3fn, torch.float4_e2m1fn_x2, torch.uint8], + name="C", + ) + self._value_error_if( + self._is_fp4x2(c_dtype) and not self._is_fp4x2(ab_dtype), + f"Unsupported c_dtype and ab_dtype combination: fp4 c_dtype requires fp4 ab_dtype, got {ab_dtype}", + ) + self._not_implemented_error_if( + self._is_fp8(c_dtype) and self._is_fp8(ab_dtype), + "Unsupported c_dtype and ab_dtype combination: fp8 ab_dtype and fp8 c_dtype (fails to launch)", + ) + self._check_dtype( + self.acc_dtype, + dtype=torch.float32, + name="Accumulator", + extra_error_msg="Accumulator must be float32", + ) + self.ab_dtype = ab_dtype self.c_dtype = c_dtype self._logger.debug("Checking tensor layout") - m, k, l = self.sample_a.shape - n, k, l = self.sample_b.shape - m_, n_, l = self.sample_c.shape - _, _, m_div_atom_m0_m1, _, sf_k_div_atom_k, l = self.sample_sfa.shape - _, _, n_div_atom_m0_m1, _, sf_k_div_atom_k, l = self.sample_sfb.shape - _, _, _ = self.sample_amax.shape - - if self.sample_a.shape != (m, k, l): - raise ValueError( - f"Input/Output shape mismatch: expected A tensor shape {m, k, l}, got {self.sample_a.shape}" - ) - if self.sample_b.shape != (n, k, l): - raise ValueError( - f"Input/Output shape mismatch: expected B tensor shape {n, k, l}, got {self.sample_b.shape}" - ) - if c_dtype == torch.float4_e2m1fn_x2 or c_dtype == torch.uint8: - if self.sample_c.shape != ( - m, - (n + 1) // 2, - l, - ): - raise ValueError( - f"Input/Output shape mismatch: expected C tensor shape {m, (n + 1) // 2, l}, got {self.sample_c.shape}" - ) - else: - if self.sample_c.shape != (m, n, l): - raise ValueError( - f"Input/Output shape mismatch: expected C tensor shape {m, n, l}, got {self.sample_c.shape}" - ) - if self.sample_sfa.shape != ( - self.atom_m[0], - self.atom_m[1], - m_div_atom_m0_m1, - self.atom_k, - sf_k_div_atom_k, - l, - ): - raise ValueError( - f"Input/Output shape mismatch: expected sfa tensor shape {self.atom_m[0], self.atom_m[1], m_div_atom_m0_m1, self.atom_k, sf_k_div_atom_k, l}, got {self.sample_sfa.shape}" - ) - if self.sample_sfb.shape != ( - self.atom_m[0], - self.atom_m[1], - n_div_atom_m0_m1, - self.atom_k, - sf_k_div_atom_k, - l, - ): - raise ValueError( - f"Input/Output shape mismatch: expected sfb tensor shape {self.atom_m[0], self.atom_m[1], n_div_atom_m0_m1, self.atom_k, sf_k_div_atom_k, l}, got {self.sample_sfb.shape}" - ) - if self.sample_amax.shape != (1, 1, 1): - raise ValueError( - f"Input/Output shape mismatch: expected amax tensor shape {1, 1, 1}, got {self.sample_amax.shape}" - ) - if m_div_atom_m0_m1 != (m + self.atom_m[0] * self.atom_m[1] - 1) // ( - self.atom_m[0] * self.atom_m[1] - ): - raise ValueError( - f"Input/Output shape mismatch: expected m_div_atom_m0_m1 (sfa.shape[2]) = {(m + self.atom_m[0] * self.atom_m[1] - 1) // (self.atom_m[0] * self.atom_m[1])}, got {m_div_atom_m0_m1}" - ) - if n_div_atom_m0_m1 != (n + self.atom_m[0] * self.atom_m[1] - 1) // ( - self.atom_m[0] * self.atom_m[1] - ): - raise ValueError( - f"Input/Output shape mismatch: expected n_div_atom_m0_m1 (sfb.shape[2]) = {(n + self.atom_m[0] * self.atom_m[1] - 1) // (self.atom_m[0] * self.atom_m[1])}, got {n_div_atom_m0_m1}" - ) - if self.sample_a.stride() == (1, m, m * k): - self.a_major = "m" - elif self.sample_a.stride() == (k, 1, m * k): - self.a_major = "k" - else: - raise ValueError( - f"Unsupported A tensor stride: expected {{(1, m, m * k), (k, 1, m * k)}}, got {self.sample_a.stride()}" - ) - if self.sample_b.stride() == (1, n, n * k): - self.b_major = "n" - elif self.sample_b.stride() == (k, 1, n * k): - self.b_major = "k" - else: - raise ValueError( - f"Unsupported B tensor stride: expected {{(1, n, n * k), (k, 1, n * k)}}, got {self.sample_b.stride()}" - ) - if self.sample_c.stride() == (1, m_, m_ * n_): - self.c_major = "m" - elif self.sample_c.stride() == (n_, 1, m_ * n_): - self.c_major = "n" - else: - raise ValueError( - f"Unsupported C tensor stride: expected {{(1, m, m * n), (n, 1, m * n)}}, got {self.sample_c.stride()}" - ) + m, k, l = self._tensor_shape(self.sample_a, name="sample_a") + n, _, _ = self._tensor_shape(self.sample_b, name="sample_b") + _, _, _ = self._tensor_shape(self.sample_c, name="sample_c") + _, _, m_div_atom_m0_m1, _, sf_k_div_atom_k, _ = self.sample_sfa.shape + _, _, n_div_atom_m0_m1, _, sf_k_div_atom_k, _ = self.sample_sfb.shape + + self._check_tensor_shape(self.sample_a, (m, k, l), "A") + self._check_tensor_shape(self.sample_b, (n, k, l), "B") + self._check_tensor_shape(self.sample_c, (m, n, l), "C") + self._check_tensor_shape( + self.sample_sfa, + (self.atom_m[0], self.atom_m[1], m_div_atom_m0_m1, self.atom_k, sf_k_div_atom_k, l), + "sfa", + ) + self._check_tensor_shape( + self.sample_sfb, + (self.atom_m[0], self.atom_m[1], n_div_atom_m0_m1, self.atom_k, sf_k_div_atom_k, l), + "sfb", + ) + self._check_tensor_shape(self.sample_amax, (1, 1, 1), "amax") - if ab_dtype in {torch.float4_e2m1fn_x2, torch.uint8} and not ( - self.a_major == "k" and self.b_major == "k" - ): - raise ValueError( - f"Unsupported A or B tensor stride: Float4 tensors require k-major layout for hardware efficiency, got {self.a_major} and {self.b_major}" - ) - if c_dtype in {torch.float4_e2m1fn_x2, torch.uint8} and self.c_major == "m": - raise ValueError( - f"Unsupported C tensor stride: Float4 tensors require n-major layout for hardware efficiency, got {self.c_major}" - ) + expected_m_div_atom = ceil_div(m, self.atom_m[0] * self.atom_m[1]) + expected_n_div_atom = ceil_div(n, self.atom_m[0] * self.atom_m[1]) + self._value_error_if( + m_div_atom_m0_m1 != expected_m_div_atom, + f"Input/Output shape mismatch: expected m_div_atom_m0_m1 (sfa.shape[2]) = {expected_m_div_atom}, got {m_div_atom_m0_m1}", + ) + self._value_error_if( + n_div_atom_m0_m1 != expected_n_div_atom, + f"Input/Output shape mismatch: expected n_div_atom_m0_m1 (sfb.shape[2]) = {expected_n_div_atom}, got {n_div_atom_m0_m1}", + ) + + # Check tensor strides + a_stride, self.a_stride_order = self._check_tensor_stride( + self.sample_a, + stride=[(1, m, m * k), (k, 1, m * k)], + name="A", + ) + b_stride, self.b_stride_order = self._check_tensor_stride( + self.sample_b, + stride=[(1, n, n * k), (k, 1, n * k)], + name="B", + ) + c_stride, self.c_stride_order = self._check_tensor_stride( + self.sample_c, + stride=[(1, m, m * n), (n, 1, m * n)], + name="C", + ) + + # Derive major mode from stride order + self.a_major = "m" if self.a_stride_order == (0, 1, 2) else "k" + self.b_major = "n" if self.b_stride_order == (0, 1, 2) else "k" + self.c_major = "m" if self.c_stride_order == (0, 1, 2) else "n" + + self._value_error_if( + self._is_fp4x2(ab_dtype) and not (self.a_major == "k" and self.b_major == "k"), + f"Unsupported A or B tensor stride: Float4 tensors require k-major layout for hardware efficiency, got {self.a_major} and {self.b_major}", + ) + self._value_error_if( + self._is_fp4x2(c_dtype) and self.c_major == "m", + f"Unsupported C tensor stride: Float4 tensors require n-major layout for hardware efficiency, got {self.c_major}", + ) self._logger.debug("Checking mma tiler and cluster shape") - if self.mma_tiler_mn[0] not in [128, 256]: - raise ValueError( - f"Unsupported mma_tiler_mn[0]: expected {{128, 256}}, got {self.mma_tiler_mn[0]}" - ) - if self.mma_tiler_mn[1] not in [128, 256]: - raise ValueError( - f"Unsupported mma_tiler_mn[1]: expected {{128, 256}}, got {self.mma_tiler_mn[1]}" - ) - if self.mma_tiler_mn[0] == 256: - raise NotImplementedError("mma_tiler_mn[0] == 256 currently hangs") - if ( - self.ab_dtype in {torch.float4_e2m1fn_x2, torch.uint8} - and self.mma_tiler_mn[1] == 256 - and k <= 128 - ): - raise ValueError( - f"mma_tiler_mn (X, 256) requires k > 128 (packed x2), got {k}" - ) - if not ( - self.cluster_shape_mn[0] % (2 if self.mma_tiler_mn[0] == 256 else 1) == 0 - ): - raise ValueError("Illegal cluster shape") - if ( - self.mma_tiler_mn == (128, 256) - and self.sf_vec_size == 16 - and c_dtype in {torch.float32, torch.float16, torch.bfloat16} - ): - raise NotImplementedError( - "mma_tiler_mn (128, 256), sf_vec_size 16, c_dtype {torch.float32, torch.float16, torch.bfloat16} fails to launch" - ) + self._value_error_if( + self.mma_tiler_mn[0] not in [128, 256], + f"Unsupported mma_tiler_mn[0]: expected {{128, 256}}, got {self.mma_tiler_mn[0]}", + ) + self._value_error_if( + self.mma_tiler_mn[1] not in [128, 256], + f"Unsupported mma_tiler_mn[1]: expected {{128, 256}}, got {self.mma_tiler_mn[1]}", + ) + self._not_implemented_error_if( + self.mma_tiler_mn[0] == 256, + "mma_tiler_mn[0] == 256 currently hangs", + ) + self._value_error_if( + self._is_fp4x2(self.ab_dtype) and self.mma_tiler_mn[1] == 256 and k <= 128, + f"mma_tiler_mn (X, 256) requires k > 128 (packed x2), got {k}", + ) + self._value_error_if( + not (self.cluster_shape_mn[0] % (2 if self.mma_tiler_mn[0] == 256 else 1) == 0), + "Illegal cluster shape", + ) + self._not_implemented_error_if( + self.mma_tiler_mn == (128, 256) and self.sf_vec_size == 16 and c_dtype in {torch.float32, torch.float16, torch.bfloat16}, + "mma_tiler_mn (128, 256), sf_vec_size 16, c_dtype {torch.float32, torch.float16, torch.bfloat16} fails to launch", + ) # Special cluster shape check for scale factor multicasts. # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. - def is_power_of_2(x): - return x > 0 and (x & (x - 1)) == 0 - - if not ( - self.cluster_shape_mn[0] <= 4 - and self.cluster_shape_mn[1] <= 4 - and self.cluster_shape_mn[0] > 0 - and self.cluster_shape_mn[1] > 0 - and is_power_of_2(self.cluster_shape_mn[0]) - and is_power_of_2(self.cluster_shape_mn[1]) - ): - raise ValueError( - f"Invalid cluster shape: expected cluster_shape_mn values in {{1, 2, 4}}, got {self.cluster_shape_mn}" - ) + self._value_error_if( + not ( + self.cluster_shape_mn[0] <= 4 + and self.cluster_shape_mn[1] <= 4 + and self.cluster_shape_mn[0] > 0 + and self.cluster_shape_mn[1] > 0 + and is_power_of_2(self.cluster_shape_mn[0]) + and is_power_of_2(self.cluster_shape_mn[1]) + ), + f"Invalid cluster shape: expected cluster_shape_mn values in {{1, 2, 4}}, got {self.cluster_shape_mn}", + ) self._logger.debug("Checking tensor alignment") def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): major_mode_idx = 0 if is_mode0_major else 1 num_major_elements = tensor_shape[major_mode_idx] - num_contiguous_elements = ( - 16 * 8 // (_convert_to_cutlass_data_type(dtype).width) - ) + num_contiguous_elements = 16 * 8 // (_convert_to_cutlass_data_type(dtype).width) return num_major_elements % num_contiguous_elements == 0 - if not ( - check_contigous_16B_alignment(ab_dtype, self.a_major == "m", (m, k, l)) - and check_contigous_16B_alignment(ab_dtype, self.b_major == "n", (n, k, l)) - and check_contigous_16B_alignment(c_dtype, self.c_major == "m", (m, n, l)) - ): - raise ValueError( - "Unsupported tensor alignment: tensors must be 16B aligned" - ) + self._value_error_if( + not ( + check_contigous_16B_alignment(ab_dtype, self.a_major == "m", (m, k, l)) + and check_contigous_16B_alignment(ab_dtype, self.b_major == "n", (n, k, l)) + and check_contigous_16B_alignment(c_dtype, self.c_major == "m", (m, n, l)) + ), + "Unsupported tensor alignment: tensors must be 16B aligned", + ) self._logger.debug("Checking environment") - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available") + self._runtime_error_if(not torch.cuda.is_available(), "CUDA is not available") device = torch.cuda.current_device() major, minor = torch.cuda.get_device_capability(device) compute_capability = major * 10 + minor - if compute_capability < 100: - raise RuntimeError( - f"GemmAmax requires SM100+ compute capability, but found SM{compute_capability} on device {device}" - ) - if compute_capability == 103: - raise RuntimeError("cuteDSL GemmAmax is not supported on SM103") + self._runtime_error_if( + compute_capability < 100, + f"GemmAmax requires SM100+ compute capability, but found SM{compute_capability} on device {device}", + ) + self._runtime_error_if( + compute_capability == 103, + "cuteDSL GemmAmax is not supported on SM103", + ) + + is_ab_fp4 = self._is_fp4x2(self.ab_dtype) + is_c_fp4 = self._is_fp4x2(self.c_dtype) + is_ab_fp8 = self._is_fp8(self.ab_dtype) + torch_version = version.parse(torch.__version__) + _fp8_dlpack_supported = version.parse(torch_version.base_version) >= version.parse("2.10.0") + use_no_dlpack_kernel = is_ab_fp4 or is_c_fp4 or (is_ab_fp8 and not _fp8_dlpack_supported) + + if use_no_dlpack_kernel: + self._logger.debug("Running no_dlpack kernel wrapper due to fp4 dtype or fp8 dtype on incompatible torch version") + self._kernel = Sm100BlockScaledPersistentDenseGemmKernelNoDlpack + else: + self._kernel = Sm100BlockScaledPersistentDenseGemmKernel self._is_supported = True self._logger.debug("check_support completed successfully") @@ -337,164 +281,63 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: current_stream = self._get_default_stream(current_stream) self._ensure_support_checked() - is_ab_fp4 = self.ab_dtype in {torch.float4_e2m1fn_x2, torch.uint8} - is_c_fp4 = self.c_dtype in {torch.float4_e2m1fn_x2, torch.uint8} - torch_version = version.parse(torch.__version__) - _fp8_dlpack_supported = version.parse( - torch_version.base_version - ) >= version.parse("2.10.0") - use_no_dlpack_kernel = is_ab_fp4 or is_c_fp4 or not _fp8_dlpack_supported - - if use_no_dlpack_kernel: - self._logger.debug( - "Running no_dlpack kernel wrapper due to fp4 dtype or fp8 dtype on incompatible torch version" - ) - self._kernel = Sm100BlockScaledPersistentDenseGemmKernelNoDlpack - else: - self._kernel = Sm100BlockScaledPersistentDenseGemmKernel - gemm_amax = self._kernel( sf_vec_size=self.sf_vec_size, mma_tiler_mn=self.mma_tiler_mn, cluster_shape_mn=self.cluster_shape_mn, ) hardware_info = cutlass.utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters( - self.cluster_shape_mn[0] * self.cluster_shape_mn[1] - ) - - if not use_no_dlpack_kernel: - sample_a_cute = from_dlpack(self.sample_a, assumed_align=16) - sample_b_cute = from_dlpack(self.sample_b, assumed_align=16) - - sample_c_cute = from_dlpack(self.sample_c, assumed_align=16) + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) + if self._kernel is Sm100BlockScaledPersistentDenseGemmKernel: self._logger.debug("Compiling gemm_amax") self._compiled_kernel = cute.compile( gemm_amax, - a_tensor=sample_a_cute, - b_tensor=sample_b_cute, + a_tensor=from_dlpack(self.sample_a, assumed_align=16), + b_tensor=from_dlpack(self.sample_b, assumed_align=16), sfa_tensor=from_dlpack(self.sample_sfa, assumed_align=16), sfb_tensor=from_dlpack(self.sample_sfb, assumed_align=16), - c_tensor=sample_c_cute, + c_tensor=from_dlpack(self.sample_c, assumed_align=16), amax_tensor=from_dlpack(self.sample_amax, assumed_align=16), max_active_clusters=max_active_clusters, stream=current_stream, ) - else: # use_no_dlpack + elif self._kernel is Sm100BlockScaledPersistentDenseGemmKernelNoDlpack: # Create cute pointers/tensors manually to avoid DLPack requirements # amax is never fp4 or fp8 and is safe to use directly with dlpack self._logger.debug("Compiling gemm_amax (no dlpack)") - a_ptr = make_ptr( - ( - cutlass.Float4E2M1FN - if is_ab_fp4 - else _convert_to_cutlass_data_type(self.sample_a.dtype) - ), - self.sample_a.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32 if is_ab_fp4 else 16, - ) - b_ptr = make_ptr( - ( - cutlass.Float4E2M1FN - if is_ab_fp4 - else _convert_to_cutlass_data_type(self.sample_b.dtype) - ), - self.sample_b.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32 if is_ab_fp4 else 16, - ) - c_ptr = make_ptr( - ( - cutlass.Float4E2M1FN - if is_c_fp4 - else _convert_to_cutlass_data_type(self.sample_c.dtype) - ), - self.sample_c.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32 if is_c_fp4 else 16, - ) - sfa_ptr = make_ptr( - ( - cutlass.Float8E8M0FNU - if self.sample_sfa.dtype == torch.int8 - else _convert_to_cutlass_data_type(self.sample_sfa.dtype) - ), - self.sample_sfa.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16, - ) - sfb_ptr = make_ptr( - ( - cutlass.Float8E8M0FNU - if self.sample_sfb.dtype == torch.int8 - else _convert_to_cutlass_data_type(self.sample_sfb.dtype) - ), - self.sample_sfb.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16, - ) - a_shape = ( - tuple( - dim * 2 if i == 1 else dim - for i, dim in enumerate(self.sample_a.shape) - ) - if is_ab_fp4 - else tuple(self.sample_a.shape) - ) - b_shape = ( - tuple( - dim * 2 if i == 1 else dim - for i, dim in enumerate(self.sample_b.shape) - ) - if is_ab_fp4 - else tuple(self.sample_b.shape) - ) - c_shape = ( - tuple( - dim * 2 if i == 1 else dim - for i, dim in enumerate(self.sample_c.shape) - ) - if is_c_fp4 - else tuple(self.sample_c.shape) - ) - sfa_shape = tuple(self.sample_sfa.shape) - sfb_shape = tuple(self.sample_sfb.shape) - - a_order = (1, 0, 2) if self.a_major == "k" else (0, 1, 2) - b_order = (1, 0, 2) if self.b_major == "k" else (0, 1, 2) - c_order = (1, 0, 2) if self.c_major == "n" else (0, 1, 2) - _sfa_strides = self.sample_sfa.stride() - _sfb_strides = self.sample_sfb.stride() - sfa_order = tuple( - sorted(range(len(sfa_shape)), key=lambda i: _sfa_strides[i]) - ) - sfb_order = tuple( - sorted(range(len(sfb_shape)), key=lambda i: _sfb_strides[i]) - ) + + is_ab_fp4 = self._is_fp4x2(self.ab_dtype) + is_c_fp4 = self._is_fp4x2(self.c_dtype) + a_ptr, a_shape, a_stride_order = self._make_cute_tensor_descriptor(self.sample_a, assumed_align=32 if is_ab_fp4 else 16, name="A") + b_ptr, b_shape, b_stride_order = self._make_cute_tensor_descriptor(self.sample_b, assumed_align=32 if is_ab_fp4 else 16, name="B") + c_ptr, c_shape, c_stride_order = self._make_cute_tensor_descriptor(self.sample_c, assumed_align=32 if is_c_fp4 else 16, name="C") + sfa_ptr, sfa_shape, sfa_stride_order = self._make_cute_tensor_descriptor(self.sample_sfa, assumed_align=16, name="sfa") + sfb_ptr, sfb_shape, sfb_stride_order = self._make_cute_tensor_descriptor(self.sample_sfb, assumed_align=16, name="sfb") self._compiled_kernel = cute.compile( gemm_amax, a_ptr=a_ptr, a_shape=a_shape, - a_order=a_order, + a_order=a_stride_order, b_ptr=b_ptr, b_shape=b_shape, - b_order=b_order, + b_order=b_stride_order, sfa_ptr=sfa_ptr, sfa_shape=sfa_shape, - sfa_order=sfa_order, + sfa_order=sfa_stride_order, sfb_ptr=sfb_ptr, sfb_shape=sfb_shape, - sfb_order=sfb_order, + sfb_order=sfb_stride_order, c_ptr=c_ptr, c_shape=c_shape, - c_order=c_order, + c_order=c_stride_order, amax_cute=from_dlpack(self.sample_amax, assumed_align=16), max_active_clusters=max_active_clusters, stream=current_stream, ) + else: + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") self._logger.debug("Kernel compiled successfully") def execute( @@ -511,116 +354,35 @@ def execute( self._logger.debug("Entering execute") current_stream = self._get_default_stream(current_stream) - if amax_tensor.dim() < 3: - self._logger.info( - f"Reshaping amax_tensor to (1, 1, 1) from {amax_tensor.shape}" + amax_tensor = self._pad_tensor_to_ndim(amax_tensor, 3, "amax_tensor") + + is_ab_fp4 = self._is_fp4x2(self.ab_dtype) + is_c_fp4 = self._is_fp4x2(self.c_dtype) + + if not skip_compile: + self._runtime_error_if( + self._compiled_kernel is None, + "GemmAmaxSm100 kernel not compiled; call compile() first or use execute(skip_compile=True)", ) - for _ in range(3 - amax_tensor.dim()): - amax_tensor = amax_tensor.unsqueeze(-1) + self._logger.debug("Executing with compiled kernel") - is_ab_fp4 = self.ab_dtype in {torch.float4_e2m1fn_x2, torch.uint8} - is_c_fp4 = self.c_dtype in {torch.float4_e2m1fn_x2, torch.uint8} - torch_version = version.parse(torch.__version__) - _fp8_dlpack_supported = version.parse( - torch_version.base_version - ) >= version.parse("2.10.0") - use_no_dlpack_kernel = is_ab_fp4 or is_c_fp4 or not _fp8_dlpack_supported - - if not use_no_dlpack_kernel: - a_tensor_cute = from_dlpack(a_tensor, assumed_align=16) - b_tensor_cute = from_dlpack(b_tensor, assumed_align=16) - c_tensor_cute = from_dlpack(c_tensor, assumed_align=16) - - if not skip_compile: - if self._compiled_kernel is None: - raise RuntimeError( - "GemmAmaxSm100 kernel not compiled; call compile() first or use execute(skip_compile=True)" - ) - self._logger.debug("Executing with compiled kernel") + if self._kernel is Sm100BlockScaledPersistentDenseGemmKernel: self._compiled_kernel( - a_tensor=a_tensor_cute, - b_tensor=b_tensor_cute, - sfa_tensor=from_dlpack(sfa_tensor, assumed_align=16), - sfb_tensor=from_dlpack(sfb_tensor, assumed_align=16), - c_tensor=c_tensor_cute, - amax_tensor=from_dlpack(amax_tensor, assumed_align=16), - stream=current_stream, - ) - self._logger.debug("Executed with compiled kernel successfully") - else: - self._logger.debug("Executing without compiled kernel (JIT)") - gemm_amax = self._kernel( - sf_vec_size=self.sf_vec_size, - mma_tiler_mn=self.mma_tiler_mn, - cluster_shape_mn=self.cluster_shape_mn, - ) - gemm_amax( - a_tensor=a_tensor_cute, - b_tensor=b_tensor_cute, + a_tensor=from_dlpack(a_tensor, assumed_align=16), + b_tensor=from_dlpack(b_tensor, assumed_align=16), sfa_tensor=from_dlpack(sfa_tensor, assumed_align=16), sfb_tensor=from_dlpack(sfb_tensor, assumed_align=16), - c_tensor=c_tensor_cute, + c_tensor=from_dlpack(c_tensor, assumed_align=16), amax_tensor=from_dlpack(amax_tensor, assumed_align=16), stream=current_stream, ) - else: # use_no_dlpack - a_ptr = make_ptr( - ( - cutlass.Float4E2M1FN - if is_ab_fp4 - else _convert_to_cutlass_data_type(a_tensor.dtype) - ), - a_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32 if is_ab_fp4 else 16, - ) - b_ptr = make_ptr( - ( - cutlass.Float4E2M1FN - if is_ab_fp4 - else _convert_to_cutlass_data_type(b_tensor.dtype) - ), - b_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32 if is_ab_fp4 else 16, - ) - c_ptr = make_ptr( - ( - cutlass.Float4E2M1FN - if is_c_fp4 - else _convert_to_cutlass_data_type(c_tensor.dtype) - ), - c_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=32 if is_c_fp4 else 16, - ) - sfa_ptr = make_ptr( - ( - cutlass.Float8E8M0FNU - if sfa_tensor.dtype == torch.int8 - else _convert_to_cutlass_data_type(sfa_tensor.dtype) - ), - sfa_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16, - ) - sfb_ptr = make_ptr( - ( - cutlass.Float8E8M0FNU - if sfb_tensor.dtype == torch.int8 - else _convert_to_cutlass_data_type(sfb_tensor.dtype) - ), - sfb_tensor.data_ptr(), - cute.AddressSpace.gmem, - assumed_align=16, - ) + elif self._kernel is Sm100BlockScaledPersistentDenseGemmKernelNoDlpack: + a_ptr = self._make_cute_pointer(a_tensor, assumed_align=32 if is_ab_fp4 else 16) + b_ptr = self._make_cute_pointer(b_tensor, assumed_align=32 if is_ab_fp4 else 16) + c_ptr = self._make_cute_pointer(c_tensor, assumed_align=32 if is_c_fp4 else 16) + sfa_ptr = self._make_cute_pointer(sfa_tensor, assumed_align=16) + sfb_ptr = self._make_cute_pointer(sfb_tensor, assumed_align=16) - if not skip_compile: - if self._compiled_kernel is None: - raise RuntimeError( - "GemmAmaxSm100 kernel not compiled; call compile() first or use execute(skip_compile=True)" - ) - self._logger.debug("Executing with compiled kernel") self._compiled_kernel( a_ptr=a_ptr, b_ptr=b_ptr, @@ -630,76 +392,59 @@ def execute( amax_cute=from_dlpack(amax_tensor, assumed_align=16), stream=current_stream, ) - self._logger.debug("Executed with compiled kernel successfully") else: - self._logger.debug("Executing without compiled kernel (JIT)") - gemm_amax = self._kernel( - sf_vec_size=self.sf_vec_size, - mma_tiler_mn=self.mma_tiler_mn, - cluster_shape_mn=self.cluster_shape_mn, - ) + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") + self._logger.debug("Executed with compiled kernel successfully") + else: + self._logger.debug("Executing without compiled kernel (JIT)") + gemm_amax = self._kernel( + sf_vec_size=self.sf_vec_size, + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + ) + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) - a_shape = ( - tuple( - dim * 2 if i == 1 else dim - for i, dim in enumerate(self.sample_a.shape) - ) - if is_ab_fp4 - else tuple(self.sample_a.shape) - ) - b_shape = ( - tuple( - dim * 2 if i == 1 else dim - for i, dim in enumerate(self.sample_b.shape) - ) - if is_ab_fp4 - else tuple(self.sample_b.shape) - ) - c_shape = ( - tuple( - dim * 2 if i == 1 else dim - for i, dim in enumerate(self.sample_c.shape) - ) - if is_c_fp4 - else tuple(self.sample_c.shape) - ) - sfa_shape = tuple(sfa_tensor.shape) - sfb_shape = tuple(sfb_tensor.shape) - a_order = (1, 0, 2) if self.a_major == "k" else (0, 1, 2) - b_order = (1, 0, 2) if self.b_major == "k" else (0, 1, 2) - c_order = (1, 0, 2) if self.c_major == "n" else (0, 1, 2) - _sfa_strides = sfa_tensor.stride() - _sfb_strides = sfb_tensor.stride() - sfa_order = tuple( - sorted(range(len(sfa_shape)), key=lambda i: _sfa_strides[i]) - ) - sfb_order = tuple( - sorted(range(len(sfb_shape)), key=lambda i: _sfb_strides[i]) + if self._kernel is Sm100BlockScaledPersistentDenseGemmKernel: + gemm_amax( + a_tensor=from_dlpack(a_tensor, assumed_align=16), + b_tensor=from_dlpack(b_tensor, assumed_align=16), + sfa_tensor=from_dlpack(sfa_tensor, assumed_align=16), + sfb_tensor=from_dlpack(sfb_tensor, assumed_align=16), + c_tensor=from_dlpack(c_tensor, assumed_align=16), + amax_tensor=from_dlpack(amax_tensor, assumed_align=16), + max_active_clusters=max_active_clusters, + stream=current_stream, ) - hardware_info = cutlass.utils.HardwareInfo() + elif self._kernel is Sm100BlockScaledPersistentDenseGemmKernelNoDlpack: + a_ptr, a_shape, a_stride_order = self._make_cute_tensor_descriptor(a_tensor, assumed_align=32 if is_ab_fp4 else 16, name="A") + b_ptr, b_shape, b_stride_order = self._make_cute_tensor_descriptor(b_tensor, assumed_align=32 if is_ab_fp4 else 16, name="B") + c_ptr, c_shape, c_stride_order = self._make_cute_tensor_descriptor(c_tensor, assumed_align=32 if is_c_fp4 else 16, name="C") + sfa_ptr, sfa_shape, sfa_stride_order = self._make_cute_tensor_descriptor(sfa_tensor, assumed_align=16, name="sfa") + sfb_ptr, sfb_shape, sfb_stride_order = self._make_cute_tensor_descriptor(sfb_tensor, assumed_align=16, name="sfb") gemm_amax( a_ptr=a_ptr, a_shape=a_shape, - a_order=a_order, + a_order=a_stride_order, b_ptr=b_ptr, b_shape=b_shape, - b_order=b_order, + b_order=b_stride_order, sfa_ptr=sfa_ptr, sfa_shape=sfa_shape, - sfa_order=sfa_order, + sfa_order=sfa_stride_order, sfb_ptr=sfb_ptr, sfb_shape=sfb_shape, - sfb_order=sfb_order, + sfb_order=sfb_stride_order, c_ptr=c_ptr, c_shape=c_shape, - c_order=c_order, + c_order=c_stride_order, amax_cute=from_dlpack(amax_tensor, assumed_align=16), - max_active_clusters=cutlass.utils.HardwareInfo().get_max_active_clusters( - self.cluster_shape_mn[0] * self.cluster_shape_mn[1] - ), + max_active_clusters=max_active_clusters, stream=current_stream, ) + else: + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") self._logger.debug("Executed successfully") @@ -729,18 +474,12 @@ def gemm_amax_wrapper_sm100( n, _, l = b_tensor.shape c_tensor = None if c_major == "m": - c_tensor = torch.empty_strided( - (m, n, l), (1, m, m * n), dtype=c_dtype, device=a_tensor.device - ) + c_tensor = torch.empty_strided((m, n, l), (1, m, m * n), dtype=c_dtype, device=a_tensor.device) elif c_major == "n": - c_tensor = torch.empty_strided( - (m, n, l), (n, 1, m * n), dtype=c_dtype, device=a_tensor.device - ) + c_tensor = torch.empty_strided((m, n, l), (n, 1, m * n), dtype=c_dtype, device=a_tensor.device) else: raise ValueError(f"c_major must be either 'm' or 'n', got {c_major}") - amax_tensor = torch.full( - (1, 1, 1), -float("inf"), device=a_tensor.device, dtype=torch.float32 - ) + amax_tensor = torch.full((1, 1, 1), -float("inf"), device=a_tensor.device, dtype=torch.float32) cache_key = ( a_tensor.shape, @@ -763,9 +502,7 @@ def gemm_amax_wrapper_sm100( sf_vec_size, ) if cache_key in _cache_of_GemmAmaxSm100Objects: - _logger.debug( - "gemm_amax_wrapper_sm100: Using previously cached GemmAmaxSm100 object" - ) + _logger.debug("gemm_amax_wrapper_sm100: Using previously cached GemmAmaxSm100 object") gemm_amax = _cache_of_GemmAmaxSm100Objects[cache_key] gemm_amax.execute( a_tensor=a_tensor, @@ -777,9 +514,7 @@ def gemm_amax_wrapper_sm100( current_stream=stream, ) else: - _logger.debug( - "gemm_amax_wrapper_sm100: No previously cached GemmAmaxSm100 object found, creating new GemmAmaxSm100 object" - ) + _logger.debug("gemm_amax_wrapper_sm100: No previously cached GemmAmaxSm100 object found, creating new GemmAmaxSm100 object") gemm_amax = GemmAmaxSm100( sample_a=a_tensor, sample_b=b_tensor, diff --git a/python/cudnn/gemm_amax/dense_blockscaled_gemm_persistent_amax.py b/python/cudnn/gemm_amax/dense_blockscaled_gemm_persistent_amax.py index 85158edc..e7d112a2 100644 --- a/python/cudnn/gemm_amax/dense_blockscaled_gemm_persistent_amax.py +++ b/python/cudnn/gemm_amax/dense_blockscaled_gemm_persistent_amax.py @@ -199,9 +199,7 @@ def __init__( # K dimension is deferred in _setup_attributes self.mma_tiler = (*mma_tiler_mn, 1) - self.cta_group = ( - tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE - ) + self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE self.occupancy = 1 # Set specialized warp ids @@ -213,9 +211,7 @@ def __init__( ) self.mma_warp_id = 4 self.tma_warp_id = 5 - self.threads_per_cta = 32 * len( - (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) - ) + self.threads_per_cta = 32 * len((self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id)) # Set barrier id for cta sync, epilogue sync and tmem ptr sync self.cta_sync_barrier = pipeline.NamedBarrier( barrier_id=1, @@ -433,15 +429,11 @@ def __call__( # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) - sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( - a_tensor.shape, self.sf_vec_size - ) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, self.sf_vec_size) sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( - b_tensor.shape, self.sf_vec_size - ) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b_tensor.shape, self.sf_vec_size) sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( @@ -467,9 +459,7 @@ def __call__( atom_thr_size = cute.size(tiled_mma.thr_id.shape) # Setup TMA load for A - a_op = sm100_utils.cluster_shape_to_tma_atom_A( - self.cluster_shape_mn, tiled_mma.thr_id - ) + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, @@ -481,9 +471,7 @@ def __call__( ) # Setup TMA load for B - b_op = sm100_utils.cluster_shape_to_tma_atom_B( - self.cluster_shape_mn, tiled_mma.thr_id - ) + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, @@ -495,12 +483,8 @@ def __call__( ) # Setup TMA load for SFA - sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( - self.cluster_shape_mn, tiled_mma.thr_id - ) - sfa_smem_layout = cute.slice_( - self.sfa_smem_layout_staged, (None, None, None, 0) - ) + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( sfa_op, sfa_tensor, @@ -512,12 +496,8 @@ def __call__( ) # Setup TMA load for SFB - sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( - self.cluster_shape_mn, tiled_mma.thr_id - ) - sfb_smem_layout = cute.slice_( - self.sfb_smem_layout_staged, (None, None, None, 0) - ) + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( sfb_op, sfb_tensor, @@ -532,9 +512,7 @@ def __call__( b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) - self.num_tma_load_bytes = ( - a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size - ) * atom_thr_size + self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size # Setup TMA store for C epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) @@ -574,30 +552,22 @@ class SharedStorage: ] # (MMA, MMA_M, MMA_K, STAGE) sA: cute.struct.Align[ - cute.struct.MemRange[ - self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) - ], + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], self.buffer_align_bytes, ] # (MMA, MMA_N, MMA_K, STAGE) sB: cute.struct.Align[ - cute.struct.MemRange[ - self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) - ], + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], self.buffer_align_bytes, ] # (MMA, MMA_M, MMA_K, STAGE) sSFA: cute.struct.Align[ - cute.struct.MemRange[ - self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) - ], + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)], self.buffer_align_bytes, ] # (MMA, MMA_N, MMA_K, STAGE) sSFB: cute.struct.Align[ - cute.struct.MemRange[ - self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) - ], + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], self.buffer_align_bytes, ] # Amax reduction shared memory (one FP32 per epilogue warp) @@ -695,15 +665,9 @@ def kernel( bidx, bidy, bidz = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) is_leader_cta = mma_tile_coord_v == 0 - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( - cta_rank_in_cluster - ) - block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( - cta_rank_in_cluster - ) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster) # Coord inside cta tidx, _, _ = cute.arch.thread_idx() @@ -716,9 +680,7 @@ def kernel( # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 - ab_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, num_tma_producer - ) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer) ab_pipeline = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, @@ -730,12 +692,8 @@ def kernel( # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - num_acc_consumer_threads = len(self.epilog_warp_id) * ( - 2 if use_2cta_instrs else 1 - ) - acc_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, num_acc_consumer_threads - ) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads) acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), num_stages=self.num_acc_stage, @@ -753,9 +711,7 @@ def kernel( if warp_idx == self.tma_warp_id: num_tmem_dealloc_threads = 32 with cute.arch.elect_one(): - cute.arch.mbarrier_init( - tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads - ) + cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads) cute.arch.mbarrier_init_fence() # Cluster arrive after barrier init @@ -766,17 +722,11 @@ def kernel( # Setup smem tensor A/B/SFA/SFB/C # # (EPI_TILE_M, EPI_TILE_N, STAGE) - sC = storage.sC.get_tensor( - c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner - ) + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) # (MMA, MMA_M, MMA_K, STAGE) - sA = storage.sA.get_tensor( - a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner - ) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) # (MMA, MMA_N, MMA_K, STAGE) - sB = storage.sB.get_tensor( - b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner - ) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) # (MMA, MMA_M, MMA_K, STAGE) sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) # (MMA, MMA_N, MMA_K, STAGE) @@ -795,42 +745,24 @@ def kernel( sfa_full_mcast_mask = None sfb_full_mcast_mask = None if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): - a_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 - ) - b_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 - ) - sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 - ) - sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 - ) + a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1) # # Local_tile partition global tensors # # (bM, bK, RestM, RestK, RestL) - gA_mkl = cute.local_tile( - mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) - ) + gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) # (bN, bK, RestN, RestK, RestL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) - ) + gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) # (bM, bK, RestM, RestK, RestL) - gSFA_mkl = cute.local_tile( - mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) - ) + gSFA_mkl = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) # (bN, bK, RestN, RestK, RestL) - gSFB_nkl = cute.local_tile( - mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) - ) + gSFB_nkl = cute.local_tile(mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) # (bM, bN, RestM, RestN, RestL) - gC_mnl = cute.local_tile( - mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) - ) + gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)) k_tile_cnt = cute.size(gA_mkl, mode=[3]) # @@ -853,9 +785,7 @@ def kernel( # Partition global/shared tensor for TMA load A/B # # TMA load A partition_S/D - a_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape - ) + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( @@ -866,9 +796,7 @@ def kernel( cute.group_modes(tCgA, 0, 3), ) # TMA load B partition_S/D - b_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape - ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( @@ -894,9 +822,7 @@ def kernel( tAgSFA = cute.filter_zeros(tAgSFA) # TMALDG_SFB partition_S/D - sfb_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape - ) + sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( @@ -919,9 +845,7 @@ def kernel( # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) # (MMA, MMA_M, MMA_N, STAGE) - tCtAcc_fake = tiled_mma.make_fragment_C( - cute.append(acc_shape, self.num_acc_stage) - ) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) # # Cluster wait before tensor memory alloc @@ -938,14 +862,10 @@ def kernel( # # Persistent tile scheduling loop # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() - ab_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_ab_stage - ) + ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) while work_tile.is_valid_tile: # Get tile coord from tile scheduler @@ -960,38 +880,26 @@ def kernel( # Slice to per mma tile index # # ((atom_v, rest_v), RestK) - tAgA_slice = tAgA[ - (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) - ] + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) - tBgB_slice = tBgB[ - (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) - ] + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) - tAgSFA_slice = tAgSFA[ - (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) - ] + tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) - tBgSFB_slice = tBgSFB[ - (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) - ] + tBgSFB_slice = tBgSFB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # # Tma load loop # for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # Conditionally wait for AB buffer empty - ab_pipeline.producer_acquire( - ab_producer_state, peek_ab_empty_status - ) + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) # TMA load A/B/SFA/SFB cute.copy( @@ -1027,9 +935,7 @@ def kernel( ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # # Advance to next tile @@ -1079,9 +985,7 @@ def kernel( # Make SFB tmem tensor sfb_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA), dtype=self.sf_dtype, ) # (MMA, MMA_N, MMA_K) @@ -1109,17 +1013,11 @@ def kernel( # # Persistent tile scheduling loop # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() - ab_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_ab_stage - ) - acc_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_acc_stage - ) + ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) + acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) while work_tile.is_valid_tile: # Get tile coord from tile scheduler @@ -1138,9 +1036,7 @@ def kernel( ab_consumer_state.reset_count() peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < k_tile_cnt and is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait( - ab_consumer_state - ) + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) # # Wait for accumulator buffer empty @@ -1159,9 +1055,7 @@ def kernel( for k_tile in range(k_tile_cnt): if is_leader_cta: # Conditionally wait for AB buffer full - ab_pipeline.consumer_wait( - ab_consumer_state, peek_ab_full_status - ) + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) # Copy SFA/SFB from smem to tmem s2t_stage_coord = ( @@ -1224,9 +1118,7 @@ def kernel( peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < k_tile_cnt: if is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait( - ab_consumer_state - ) + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) # # Async arrive accumulator buffer full @@ -1284,33 +1176,23 @@ def kernel( tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc, - ) = self.epilog_tmem_copy_and_partition( - epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs - ) + ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs) tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) - tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( - tiled_copy_t2r, tTR_rC, epi_tidx, sC - ) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rC, epi_tidx, sC) ( tma_atom_c, bSG_sC, bSG_gC_partitioned, - ) = self.epilog_gmem_copy_and_partition( - epi_tidx, tma_atom_c, tCgC, epi_tile, sC - ) + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, epi_tile, sC) # # Persistent tile scheduling loop # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() - acc_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) + acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) # Threads/warps participating in tma store pipeline c_producer_group = pipeline.CooperativeGroup( @@ -1346,9 +1228,7 @@ def kernel( # Set tensor memory buffer for current tile # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc = tTR_tAcc_base[ - (None, None, None, None, None, acc_consumer_state.index) - ] + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)] # # Wait for accumulator buffer full @@ -1379,12 +1259,8 @@ def kernel( # Note: We need absolute value maximum, so take abs first acc_values = tTR_rAcc.load() # Apply element-wise absolute value using math.absf (supports vectors) - abs_acc_values_ir = math.absf( - acc_values.ir_value() # operand (positional) - ) - abs_acc_values = type(acc_values)( - abs_acc_values_ir, acc_values.shape, acc_values.dtype - ) + abs_acc_values_ir = math.absf(acc_values.ir_value()) # operand (positional) + abs_acc_values = type(acc_values)(abs_acc_values_ir, acc_values.shape, acc_values.dtype) subtile_amax = abs_acc_values.reduce( cute.ReductionOp.MAX, cutlass.Float32(0.0), @@ -1430,9 +1306,7 @@ def kernel( self.epilog_sync_barrier.arrive_and_wait() # Perform amax reduction after all subtiles are processed - _val_i32 = llvm.bitcast( - T.i32(), thread_tile_amax.ir_value(), loc=None, ip=None - ) + _val_i32 = llvm.bitcast(T.i32(), thread_tile_amax.ir_value(), loc=None, ip=None) _res_i32 = nvvm.redux_sync( res=T.i32(), val=_val_i32, @@ -1441,9 +1315,7 @@ def kernel( loc=None, ip=None, ) - warp_amax = cutlass.Float32( - llvm.bitcast(T.f32(), _res_i32, loc=None, ip=None) - ) + warp_amax = cutlass.Float32(llvm.bitcast(T.f32(), _res_i32, loc=None, ip=None)) # Each epilogue warp's lane 0 writes warp amax to shared memory if cute.arch.lane_idx() == 0: @@ -1454,18 +1326,14 @@ def kernel( # Block-level reduction: only first epilogue warp's lane 0 handles this if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0: - block_amax = cutlass.Float32( - 0.0 - ) # Initial value for absolute maximum + block_amax = cutlass.Float32(0.0) # Initial value for absolute maximum for i in cutlass.range(self.num_epilog_warps): warp_amax_val = sAmax[i] block_amax = cute.arch.fmax(block_amax, warp_amax_val) # Global atomic max (accumulates across all tiles for final tensor amax) # Since we compute absolute values, all values are non-negative - _value_int = llvm.bitcast( - T.i32(), block_amax.ir_value(), loc=None, ip=None - ) + _value_int = llvm.bitcast(T.i32(), block_amax.ir_value(), loc=None, ip=None) _old_value_int = nvvm.atomicrmw( res=T.i32(), op=nvvm.AtomicOpKind.MAX, @@ -1474,9 +1342,7 @@ def kernel( loc=None, ip=None, ) - _ = cutlass.Float32( - llvm.bitcast(T.f32(), _old_value_int, loc=None, ip=None) - ) + _ = cutlass.Float32(llvm.bitcast(T.f32(), _old_value_int, loc=None, ip=None)) # # Async arrive accumulator buffer empty # @@ -1498,9 +1364,7 @@ def kernel( # Coordinate dealloc across 2-CTA instruction mode if use_2cta_instrs: - cute.arch.mbarrier_arrive( - tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 - ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1) cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) # Actually deallocate TMEM @@ -1549,9 +1413,7 @@ def mainloop_s2t_copy_and_partition( # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t, tCsSF_compact_s2t_ - ) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) @@ -1600,24 +1462,18 @@ def epilog_tmem_copy_and_partition( epi_tile, ) # (EPI_TILE_M, EPI_TILE_N) - tiled_copy_t2r = tcgen05.make_tmem_copy( - copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] - ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - gC_mnl_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile - ) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc = cute.make_rmem_tensor( - tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype - ) + tTR_rAcc = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc def epilog_smem_copy_and_partition( @@ -1646,9 +1502,7 @@ def epilog_smem_copy_and_partition( - tRS_sC: The partitioned tensor C (smem destination) :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] """ - copy_atom_r2s = sm100_utils.get_smem_store_op( - self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r - ) + copy_atom_r2s = sm100_utils.get_smem_store_op(self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r) tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) # (R2S, R2S_M, R2S_N, PIPE_D) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) @@ -1686,9 +1540,7 @@ def epilog_gmem_copy_and_partition( :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] """ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - gC_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile - ) + gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) tma_atom_c = atom sC_for_tma_partition = cute.group_modes(sC, 0, 2) @@ -1800,18 +1652,14 @@ def _compute_stages( # Start with total smem per CTA (capacity / occupancy) # Subtract reserved bytes and initial C stages bytes # Divide remaining by bytes needed per A/B/SFA/SFB stage - num_ab_stage = ( - smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) - ) // ab_bytes_per_stage + num_ab_stage = (smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes)) // ab_bytes_per_stage # Refine epilogue stages: # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes # Add remaining unused smem to epilogue - num_c_stage += ( - smem_capacity - - occupancy * ab_bytes_per_stage * num_ab_stage - - occupancy * (mbar_helpers_bytes + c_bytes) - ) // (occupancy * c_bytes_per_stage) + num_c_stage += (smem_capacity - occupancy * ab_bytes_per_stage * num_ab_stage - occupancy * (mbar_helpers_bytes + c_bytes)) // ( + occupancy * c_bytes_per_stage + ) return num_acc_stage, num_ab_stage, num_c_stage @@ -1843,12 +1691,8 @@ def _compute_grid( num_ctas_mnl = gc[(0, (None, None, None))].shape cluster_shape_mnl = (*cluster_shape_mn, 1) - tile_sched_params = utils.PersistentTileSchedulerParams( - num_ctas_mnl, cluster_shape_mnl - ) - grid = utils.StaticPersistentTileScheduler.get_grid_shape( - tile_sched_params, max_active_clusters - ) + tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl) + grid = utils.StaticPersistentTileScheduler.get_grid_shape(tile_sched_params, max_active_clusters) return tile_sched_params, grid @@ -1895,22 +1739,12 @@ def __call__( stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): - a_cute = cute.make_tensor( - a_ptr, layout=cute.make_ordered_layout(a_shape, order=a_order) - ) - b_cute = cute.make_tensor( - b_ptr, layout=cute.make_ordered_layout(b_shape, order=b_order) - ) - c_cute = cute.make_tensor( - c_ptr, layout=cute.make_ordered_layout(c_shape, order=c_order) - ) + a_cute = cute.make_tensor(a_ptr, layout=cute.make_ordered_layout(a_shape, order=a_order)) + b_cute = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout(b_shape, order=b_order)) + c_cute = cute.make_tensor(c_ptr, layout=cute.make_ordered_layout(c_shape, order=c_order)) - sfa_cute = cute.make_tensor( - sfa_ptr, layout=cute.make_ordered_layout(sfa_shape, order=sfa_order) - ) - sfb_cute = cute.make_tensor( - sfb_ptr, layout=cute.make_ordered_layout(sfb_shape, order=sfb_order) - ) + sfa_cute = cute.make_tensor(sfa_ptr, layout=cute.make_ordered_layout(sfa_shape, order=sfa_order)) + sfb_cute = cute.make_tensor(sfb_ptr, layout=cute.make_ordered_layout(sfb_shape, order=sfb_order)) self.kernel( a_cute, b_cute, diff --git a/python/cudnn/gemm_swiglu/api.py b/python/cudnn/gemm_swiglu/api.py index 598167ff..d2e4c696 100644 --- a/python/cudnn/gemm_swiglu/api.py +++ b/python/cudnn/gemm_swiglu/api.py @@ -46,7 +46,7 @@ import cutlass.cute.math as math from cudnn.datatypes import _convert_to_cutlass_data_type -from cudnn.api_base import APIBase, ceil_div +from cudnn.api_base import APIBase, ceil_div, is_power_of_2 class GemmSwigluSm100(APIBase): @@ -83,9 +83,7 @@ def __init__( self.acc_dtype = acc_dtype self.mma_tiler_mn = mma_tiler_mn if cluster_shape_mn is None: - self.cluster_shape_mn = ( - (1, 1) if not self.mma_tiler_mn[0] == 256 else (2, 2) - ) + self.cluster_shape_mn = (1, 1) if not self.mma_tiler_mn[0] == 256 else (2, 2) else: self.cluster_shape_mn = cluster_shape_mn @@ -94,29 +92,17 @@ def __init__( self.sample_sfb = sample_sfb self.sample_sfc = sample_sfc self.sample_amax = self._unpad_tensor_to_ndim(sample_amax, 1, "amax") - self.sample_norm_const = self._unpad_tensor_to_ndim( - sample_norm_const, 1, "norm_const" - ) + self.sample_norm_const = self._unpad_tensor_to_ndim(sample_norm_const, 1, "norm_const") self.sf_vec_size = sf_vec_size self.vector_f32 = vector_f32 self.ab12_stages = ab12_stages # Kernel selection - if ( - self.sample_sfa is None - and self.sample_sfb is None - and self.sample_amax is None - and self.sample_sfc is None - and self.sample_norm_const is None - ): - self._logger.debug( - "No quantization arguments provided, using regular GEMM swiglu kernel" - ) + if self.sample_sfa is None and self.sample_sfb is None and self.sample_amax is None and self.sample_sfc is None and self.sample_norm_const is None: + self._logger.debug("No quantization arguments provided, using regular GEMM swiglu kernel") self._kernel = PersistentDenseGemmKernel else: - self._logger.debug( - "Quantization arguments provided, using quantized GEMM swiglu kernel" - ) + self._logger.debug("Quantization arguments provided, using quantized GEMM swiglu kernel") self._kernel = Sm100BlockScaledPersistentDenseGemmKernel self._logger.debug( @@ -144,31 +130,17 @@ def check_support(self) -> bool: Sm100BlockScaledPersistentDenseGemmKernelNoDlpack, }: rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4) - self._check_tensor_shape( - self.sample_sfa, (32, 4, ceil_div(m, 128), 4, rest_k, l), "SFA" - ) - self._check_tensor_shape( - self.sample_sfb, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB" - ) + self._check_tensor_shape(self.sample_sfa, (32, 4, ceil_div(m, 128), 4, rest_k, l), "SFA") + self._check_tensor_shape(self.sample_sfb, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB") self._check_tensor_shape(self.sample_amax, (1,), "amax") rest_n2 = ceil_div(ceil_div(n // 2, self.sf_vec_size), 4) - self._check_tensor_shape( - self.sample_sfc, (32, 4, ceil_div(m, 128), 4, rest_n2, l), "SFC" - ) + self._check_tensor_shape(self.sample_sfc, (32, 4, ceil_div(m, 128), 4, rest_n2, l), "SFC") self._check_tensor_shape(self.sample_norm_const, (1,), "norm_const") - _, self.a_stride_order = self._check_tensor_stride( - self.sample_a, stride=[(1, m, m * k), (k, 1, m * k)] - ) - _, self.b_stride_order = self._check_tensor_stride( - self.sample_b, stride=[(1, n, n * k), (k, 1, n * k)] - ) - _, self.ab12_stride_order = self._check_tensor_stride( - self.sample_ab12, stride=[(1, m, m * n), (n, 1, m * n)] - ) - _, self.c_stride_order = self._check_tensor_stride( - self.sample_c, stride=[(1, m, m * n_2), (n_2, 1, m * n_2)] - ) + _, self.a_stride_order = self._check_tensor_stride(self.sample_a, stride=[(1, m, m * k), (k, 1, m * k)]) + _, self.b_stride_order = self._check_tensor_stride(self.sample_b, stride=[(1, n, n * k), (k, 1, n * k)]) + _, self.ab12_stride_order = self._check_tensor_stride(self.sample_ab12, stride=[(1, m, m * n), (n, 1, m * n)]) + _, self.c_stride_order = self._check_tensor_stride(self.sample_c, stride=[(1, m, m * n_2), (n_2, 1, m * n_2)]) self._value_error_if( self.ab12_stride_order != self.c_stride_order, f"AB12 and C tensor stride orders must match, got {self.ab12_stride_order} and {self.c_stride_order}", @@ -219,12 +191,8 @@ def check_support(self) -> bool: name="A/B (for float16 acc_dtype)", ) case _: - raise ValueError( - f"Unsupported acc_dtype: expected one of {{torch.float32, torch.float16}}, got {self.acc_dtype}" - ) - self.c_dtype = self._check_dtype( - self.sample_c, dtype=[torch.float16, torch.bfloat16], name="C" - ) + raise ValueError(f"Unsupported acc_dtype: expected one of {{torch.float32, torch.float16}}, got {self.acc_dtype}") + self.c_dtype = self._check_dtype(self.sample_c, dtype=[torch.float16, torch.bfloat16], name="C") elif self._kernel in { Sm100BlockScaledPersistentDenseGemmKernel, Sm100BlockScaledPersistentDenseGemmKernelNoDlpack, @@ -278,13 +246,11 @@ def check_support(self) -> bool: ) self._value_error_if( - self._is_fp8(self.c_dtype) - and (self.sample_sfc is None or self.sample_norm_const is None), + self._is_fp8(self.c_dtype) and (self.sample_sfc is None or self.sample_norm_const is None), "sfc and norm_const must be provided when c_dtype is fp8", ) self._value_error_if( - (self._is_fp4x2(self.ab_dtype) and self.c_dtype == torch.bfloat16) - and (self.sample_amax is None), + (self._is_fp4x2(self.ab_dtype) and self.c_dtype == torch.bfloat16) and (self.sample_amax is None), "amax must be provided when ab_dtype is fp4 and c_dtype is bf16", ) @@ -316,9 +282,7 @@ def check_support(self) -> bool: ) if self._is_fp8(self.ab_dtype): self._value_error_if( - not ( - self.sf_dtype == torch.float8_e8m0fnu and self.sf_vec_size == 32 - ), + not (self.sf_dtype == torch.float8_e8m0fnu and self.sf_vec_size == 32), "Invalid ab_dtype and sf_dtype/sf_vec_size combination: fp8 ab_dtype requires float8_e8m0fnu sf_dtype and 32 sf_vec_size", ) elif self._is_fp4x2(self.ab_dtype): @@ -329,8 +293,7 @@ def check_support(self) -> bool: if self._is_fp4x2(self.ab_dtype): self._value_error_if( - self.a_stride_order != (1, 0, 2) - or self.b_stride_order != (1, 0, 2), + self.a_stride_order != (1, 0, 2) or self.b_stride_order != (1, 0, 2), "Invalid A or B tensor stride: fp4 dtype requires k-major layout", ) self._value_error_if( @@ -346,9 +309,6 @@ def check_support(self) -> bool: self._logger.debug("Checking MMA tile shape and cluster shape") - def is_power_of_2(x): - return x > 0 and (x & (x - 1)) == 0 - self._value_error_if( self.mma_tiler_mn[0] not in [128, 256], f"Invalid MMA tile shape: expected mma_tiler_mn[0] in {{128, 256}}, got {self.mma_tiler_mn[0]}", @@ -372,24 +332,11 @@ def is_power_of_2(x): f"Invalid MMA tile shape: expected mma_tiler_mn[1] in {{64, 128, 192, 256}}, got {self.mma_tiler_mn[1]}", ) else: - self._value_error_if( - self.mma_tiler_mn[1] not in [256], - f"Invalid MMA tile shape: MXFP8 Quantized kernel only supports tile_n=256, got {self.mma_tiler_mn[1]}", - ) - - if ( - self.mma_tiler_mn == (256, 256) - and self.cluster_shape_mn != (1, 1) - and self.sf_vec_size == 32 - and self.sf_dtype == torch.float8_e8m0fnu - ): - self._value_error_if( - not ( - self.ab12_dtype == torch.bfloat16 - and self.c_dtype == torch.bfloat16 - ), - "Invalid MMA tile shape/cluster shape/dtype combination: for 256x256mma tile shape, non-1x1 cluster shape, 32 sf_vec_size, float8_e8m0fnu sf_dtype: ab12_dtype must be bfloat16 and c_dtype must be bfloat16", - ) + if self._is_fp8(self.ab_dtype): + self._value_error_if( + self._is_fp8(self.c_dtype) or self._is_fp8(self.ab12_dtype) or self.ab12_dtype == torch.float32, + "For MXFP8 inputs for blockscaled quantized GEMM swiglu kernel, ab12_dtype and c_dtype cannot be FP8. ab12_dtype also cannot be float32", + ) self._value_error_if( self.cluster_shape_mn[0] % (2 if self.mma_tiler_mn[0] == 256 else 1) != 0, @@ -415,30 +362,11 @@ def is_power_of_2(x): not use_2cta_instrs and self.cluster_shape_mn != (1, 1), "Invalid cluster shape: cluster_shape must be (1, 1) when use_2cta_instrs=False", ) - self._value_error_if( - not use_2cta_instrs and self.ab12_dtype == torch.float32, - "Invalid ab12_dtype: use_2cta_instrs=False is incompatbile with float32 accumulator", - ) - - self._value_error_if( - self.mma_tiler_mn == (128, 128) - and self.cluster_shape_mn == (1, 1) - and self.ab12_stride_order != (0, 1, 2), - "Invalid MMA tile shape and AB12 stride order combination: (128, 128) mma tile shape with 1x1 cluster shape is only supported with ab12 stride_order (0, 1, 2)", - ) - self._value_error_if( - self.mma_tiler_mn != (128, 128) and self.ab12_stride_order != (0, 1, 2), - f"Invalid AB12 tensor stride order: for non-128x128mma tile shape, ab12 stride_order must be (0, 1, 2), got {self.ab12_stride_order}", - ) if self.cluster_shape_mn != (1, 1) and self.mma_tiler_mn[0] == 128: self._value_error_if( self.mma_tiler_mn != (128, 128), "Invalid MMA tile shape: for non-1x1 cluster shape and 128xmma tile shape, mma_tiler_mn must be (128, 128)", ) - self._not_implemented_error_if( - self.mma_tiler_mn[0] == 256 and self.ab12_dtype == torch.float32, - "mma_tiler_mn[0] == 256 and ab12_dtype == torch.float32 currently disabled", - ) self._logger.debug("Checking tensor alignment") @@ -446,28 +374,14 @@ def check_contigous_16B_alignment(dtype, stride_order, tensor_shape): is_mode0_major = stride_order == (0, 1, 2) major_mode_idx = 0 if is_mode0_major else 1 num_major_elements = tensor_shape[major_mode_idx] - num_contiguous_elements = ( - 16 - * 8 - // ( - _convert_to_cutlass_data_type( - dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2 - ).width - ) - ) + num_contiguous_elements = 16 * 8 // (_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2).width) return num_major_elements % num_contiguous_elements == 0 self._value_error_if( not ( - check_contigous_16B_alignment( - self.ab_dtype, self.a_stride_order, (m, k, l) - ) - and check_contigous_16B_alignment( - self.ab_dtype, self.b_stride_order, (n, k, l) - ) - and check_contigous_16B_alignment( - self.ab12_dtype, self.ab12_stride_order, (m, n, l) - ) + check_contigous_16B_alignment(self.ab_dtype, self.a_stride_order, (m, k, l)) + and check_contigous_16B_alignment(self.ab_dtype, self.b_stride_order, (n, k, l)) + and check_contigous_16B_alignment(self.ab12_dtype, self.ab12_stride_order, (m, n, l)) ), "Invalid tensor alignment: tensors must be 16B aligned", ) @@ -488,9 +402,7 @@ def check_contigous_16B_alignment(dtype, stride_order, tensor_shape): major, minor = torch.cuda.get_device_capability(device) compute_capability = major * 10 + minor if compute_capability < 100: - raise RuntimeError( - f"GemmSwiglu requires SM100+ compute capability, but found SM{compute_capability} on device {device}" - ) + raise RuntimeError(f"GemmSwiglu requires SM100+ compute capability, but found SM{compute_capability} on device {device}") if compute_capability == 103: raise RuntimeError("cuteDSL GemmSwiglu is not supported on SM103") @@ -508,27 +420,17 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: is_ab12_fp4 = self._is_fp4x2(self.ab12_dtype) is_ab_fp8 = self._is_fp8(self.ab_dtype) is_ab12_fp8 = self._is_fp8(self.ab12_dtype) - _fp8_dlpack_supported = version.parse( - torch_version.base_version - ) >= version.parse("2.10.0") - use_no_dlpack_kernel = ( - is_ab_fp4 - or is_ab12_fp4 - or ((is_ab_fp8 or is_ab12_fp8) and not _fp8_dlpack_supported) - ) + _fp8_dlpack_supported = version.parse(torch_version.base_version) >= version.parse("2.10.0") + use_no_dlpack_kernel = is_ab_fp4 or is_ab12_fp4 or ((is_ab_fp8 or is_ab12_fp8) and not _fp8_dlpack_supported) if use_no_dlpack_kernel: - self._logger.debug( - "Running no_dlpack kernel wrapper due to fp4 dtype or fp8 dtype on incompatible torch version" - ) + self._logger.debug("Running no_dlpack kernel wrapper due to fp4 dtype or fp8 dtype on incompatible torch version") if self._kernel is PersistentDenseGemmKernel: self._kernel = PersistentDenseGemmKernelNoDlpack elif self._kernel is Sm100BlockScaledPersistentDenseGemmKernel: self._kernel = Sm100BlockScaledPersistentDenseGemmKernelNoDlpack else: - raise NotImplementedError( - f"Unreachable: invalid kernel type {self._kernel}" - ) + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") gemm_swiglu = None if self._kernel in ( @@ -553,14 +455,10 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: ab12_stages=self.ab12_stages, ) else: - raise NotImplementedError( - f"Unreachable: invalid kernel type {self._kernel}" - ) + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") hardware_info = cutlass.utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters( - self.cluster_shape_mn[0] * self.cluster_shape_mn[1] - ) + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) if self._kernel is PersistentDenseGemmKernel: self._logger.debug("Compiling gemm_swiglu (dlpack)") @@ -584,21 +482,9 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: sfb_tensor=from_dlpack(self.sample_sfb, assumed_align=16), c_tensor=from_dlpack(self.sample_c, assumed_align=16), ab12_tensor=from_dlpack(self.sample_ab12, assumed_align=8), - amax_tensor=( - from_dlpack(self.sample_amax, assumed_align=16) - if self.sample_amax is not None - else None - ), - sfc_tensor=( - from_dlpack(self.sample_sfc, assumed_align=16) - if self.sample_sfc is not None - else None - ), - norm_const_tensor=( - from_dlpack(self.sample_norm_const) - if self.sample_norm_const is not None - else None - ), + amax_tensor=(from_dlpack(self.sample_amax, assumed_align=16) if self.sample_amax is not None else None), + sfc_tensor=(from_dlpack(self.sample_sfc, assumed_align=16) if self.sample_sfc is not None else None), + norm_const_tensor=(from_dlpack(self.sample_norm_const) if self.sample_norm_const is not None else None), alpha=self.alpha, max_active_clusters=max_active_clusters, stream=current_stream, @@ -609,15 +495,9 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: ): # Create cute pointers/tensors manually to avoid DLPack requirements # c (output) is always fp16/bf16 and is safe to use directly with dlpack - a_ptr, a_shape, a_stride_order = self._make_cute_tensor_descriptor( - self.sample_a, name="A" - ) - b_ptr, b_shape, b_stride_order = self._make_cute_tensor_descriptor( - self.sample_b, name="B" - ) - ab12_ptr, ab12_shape, ab12_stride_order = self._make_cute_tensor_descriptor( - self.sample_ab12, name="AB12" - ) + a_ptr, a_shape, a_stride_order = self._make_cute_tensor_descriptor(self.sample_a, name="A") + b_ptr, b_shape, b_stride_order = self._make_cute_tensor_descriptor(self.sample_b, name="B") + ab12_ptr, ab12_shape, ab12_stride_order = self._make_cute_tensor_descriptor(self.sample_ab12, name="AB12") if self._kernel is PersistentDenseGemmKernelNoDlpack: self._compiled_kernel = cute.compile( @@ -637,26 +517,12 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: stream=current_stream, ) elif self._kernel is Sm100BlockScaledPersistentDenseGemmKernelNoDlpack: - c_ptr, c_shape, c_stride_order = self._make_cute_tensor_descriptor( - self.sample_c, name="C" - ) - sfa_ptr, sfa_shape, sfa_stride_order = ( - self._make_cute_tensor_descriptor(self.sample_sfa, name="SFA") - ) - sfb_ptr, sfb_shape, sfb_stride_order = ( - self._make_cute_tensor_descriptor(self.sample_sfb, name="SFB") - ) - amax_ptr, amax_shape, amax_stride_order = ( - self._make_cute_tensor_descriptor(self.sample_amax, name="AMAX") - ) - sfc_ptr, sfc_shape, sfc_stride_order = ( - self._make_cute_tensor_descriptor(self.sample_sfc, name="SFC") - ) - norm_const_ptr, norm_const_shape, norm_const_stride_order = ( - self._make_cute_tensor_descriptor( - self.sample_norm_const, name="NORM_CONST" - ) - ) + c_ptr, c_shape, c_stride_order = self._make_cute_tensor_descriptor(self.sample_c, name="C") + sfa_ptr, sfa_shape, sfa_stride_order = self._make_cute_tensor_descriptor(self.sample_sfa, name="SFA") + sfb_ptr, sfb_shape, sfb_stride_order = self._make_cute_tensor_descriptor(self.sample_sfb, name="SFB") + amax_ptr, amax_shape, amax_stride_order = self._make_cute_tensor_descriptor(self.sample_amax, name="AMAX") + sfc_ptr, sfc_shape, sfc_stride_order = self._make_cute_tensor_descriptor(self.sample_sfc, name="SFC") + norm_const_ptr, norm_const_shape, norm_const_stride_order = self._make_cute_tensor_descriptor(self.sample_norm_const, name="NORM_CONST") self._compiled_kernel = cute.compile( gemm_swiglu, @@ -692,13 +558,9 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: stream=current_stream, ) else: - raise NotImplementedError( - f"Unreachable: invalid kernel type {self._kernel}" - ) + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") else: - raise NotImplementedError( - f"Unreachable: invalid kernel type {self._kernel}" - ) + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") self._logger.debug("Kernel compiled successfully") def execute( @@ -737,9 +599,7 @@ def execute( ) elif self._kernel is Sm100BlockScaledPersistentDenseGemmKernel: amax_tensor = self._unpad_tensor_to_ndim(amax_tensor, 1, "amax") - norm_const_tensor = self._unpad_tensor_to_ndim( - norm_const_tensor, 1, "norm_const" - ) + norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const") self._compiled_kernel( a_tensor=from_dlpack(a_tensor, assumed_align=16), b_tensor=from_dlpack(b_tensor, assumed_align=16), @@ -747,21 +607,9 @@ def execute( sfb_tensor=from_dlpack(sfb_tensor, assumed_align=16), c_tensor=from_dlpack(c_tensor, assumed_align=16), ab12_tensor=from_dlpack(ab12_tensor, assumed_align=8), - amax_tensor=( - from_dlpack(amax_tensor, assumed_align=16) - if amax_tensor is not None - else None - ), - sfc_tensor=( - from_dlpack(sfc_tensor, assumed_align=16) - if sfc_tensor is not None - else None - ), - norm_const_tensor=( - from_dlpack(norm_const_tensor) - if norm_const_tensor is not None - else None - ), + amax_tensor=(from_dlpack(amax_tensor, assumed_align=16) if amax_tensor is not None else None), + sfc_tensor=(from_dlpack(sfc_tensor, assumed_align=16) if sfc_tensor is not None else None), + norm_const_tensor=(from_dlpack(norm_const_tensor) if norm_const_tensor is not None else None), alpha=alpha, stream=current_stream, ) @@ -784,9 +632,7 @@ def execute( ) elif self._kernel is Sm100BlockScaledPersistentDenseGemmKernelNoDlpack: amax_tensor = self._unpad_tensor_to_ndim(amax_tensor, 1, "amax") - norm_const_tensor = self._unpad_tensor_to_ndim( - norm_const_tensor, 1, "norm_const" - ) + norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const") c_ptr = self._make_cute_pointer(c_tensor, assumed_align=16) sfa_ptr = self._make_cute_pointer(sfa_tensor, assumed_align=16) sfb_ptr = self._make_cute_pointer(sfb_tensor, assumed_align=16) @@ -807,13 +653,9 @@ def execute( stream=current_stream, ) else: - raise NotImplementedError( - f"Unreachable: invalid kernel type {type(self._compiled_kernel)}" - ) + raise NotImplementedError(f"Unreachable: invalid kernel type {type(self._compiled_kernel)}") else: - raise NotImplementedError( - f"Unreachable: invalid kernel type {type(self._compiled_kernel)}" - ) + raise NotImplementedError(f"Unreachable: invalid kernel type {type(self._compiled_kernel)}") self._logger.debug("Executed with compiled kernel successfully") else: # skip_compile self._logger.debug("Executing without compiled kernel (JIT)") @@ -831,9 +673,7 @@ def execute( ab12=from_dlpack(ab12_tensor), c=from_dlpack(c_tensor), alpha=alpha, - max_active_clusters=cutlass.utils.HardwareInfo().get_max_active_clusters( - self.cluster_shape_mn[0] * self.cluster_shape_mn[1] - ), + max_active_clusters=cutlass.utils.HardwareInfo().get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]), stream=current_stream, ) elif self._kernel is PersistentDenseGemmKernelNoDlpack: @@ -844,15 +684,9 @@ def execute( cluster_shape_mn=self.cluster_shape_mn, ) - a_ptr, a_shape, a_stride_order = self._make_cute_tensor_descriptor( - a_tensor, name="A" - ) - b_ptr, b_shape, b_stride_order = self._make_cute_tensor_descriptor( - b_tensor, name="B" - ) - ab12_ptr, ab12_shape, ab12_stride_order = ( - self._make_cute_tensor_descriptor(ab12_tensor, name="AB12") - ) + a_ptr, a_shape, a_stride_order = self._make_cute_tensor_descriptor(a_tensor, name="A") + b_ptr, b_shape, b_stride_order = self._make_cute_tensor_descriptor(b_tensor, name="B") + ab12_ptr, ab12_shape, ab12_stride_order = self._make_cute_tensor_descriptor(ab12_tensor, name="AB12") gemm_swiglu( a_ptr=a_ptr, @@ -866,9 +700,7 @@ def execute( ab12_order=ab12_stride_order, c_cute=from_dlpack(c_tensor), alpha=alpha, - max_active_clusters=cutlass.utils.HardwareInfo().get_max_active_clusters( - self.cluster_shape_mn[0] * self.cluster_shape_mn[1] - ), + max_active_clusters=cutlass.utils.HardwareInfo().get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]), stream=current_stream, ) elif self._kernel is Sm100BlockScaledPersistentDenseGemmKernel: @@ -880,9 +712,7 @@ def execute( ab12_stages=self.ab12_stages, ) amax_tensor = self._unpad_tensor_to_ndim(amax_tensor, 1, "amax") - norm_const_tensor = self._unpad_tensor_to_ndim( - norm_const_tensor, 1, "norm_const" - ) + norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const") gemm_swiglu( a_tensor=from_dlpack(a_tensor, assumed_align=16), b_tensor=from_dlpack(b_tensor, assumed_align=16), @@ -890,25 +720,11 @@ def execute( sfb_tensor=from_dlpack(sfb_tensor, assumed_align=16), c_tensor=from_dlpack(c_tensor, assumed_align=16), ab12_tensor=from_dlpack(ab12_tensor, assumed_align=8), - amax_tensor=( - from_dlpack(amax_tensor, assumed_align=16) - if amax_tensor is not None - else None - ), - sfc_tensor=( - from_dlpack(sfc_tensor, assumed_align=16) - if sfc_tensor is not None - else None - ), - norm_const_tensor=( - from_dlpack(norm_const_tensor) - if norm_const_tensor is not None - else None - ), + amax_tensor=(from_dlpack(amax_tensor, assumed_align=16) if amax_tensor is not None else None), + sfc_tensor=(from_dlpack(sfc_tensor, assumed_align=16) if sfc_tensor is not None else None), + norm_const_tensor=(from_dlpack(norm_const_tensor) if norm_const_tensor is not None else None), alpha=alpha, - max_active_clusters=cutlass.utils.HardwareInfo().get_max_active_clusters( - self.cluster_shape_mn[0] * self.cluster_shape_mn[1] - ), + max_active_clusters=cutlass.utils.HardwareInfo().get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]), stream=current_stream, ) elif self._kernel is Sm100BlockScaledPersistentDenseGemmKernelNoDlpack: @@ -920,39 +736,17 @@ def execute( ab12_stages=self.ab12_stages, ) amax_tensor = self._unpad_tensor_to_ndim(amax_tensor, 1, "amax") - norm_const_tensor = self._unpad_tensor_to_ndim( - norm_const_tensor, 1, "norm_const" - ) - - a_ptr, a_shape, a_stride_order = self._make_cute_tensor_descriptor( - a_tensor, name="A" - ) - b_ptr, b_shape, b_stride_order = self._make_cute_tensor_descriptor( - b_tensor, name="B" - ) - ab12_ptr, ab12_shape, ab12_stride_order = ( - self._make_cute_tensor_descriptor(ab12_tensor, name="AB12") - ) - c_ptr, c_shape, c_stride_order = self._make_cute_tensor_descriptor( - c_tensor, name="C" - ) - sfa_ptr, sfa_shape, sfa_stride_order = ( - self._make_cute_tensor_descriptor(sfa_tensor, name="SFA") - ) - sfb_ptr, sfb_shape, sfb_stride_order = ( - self._make_cute_tensor_descriptor(sfb_tensor, name="SFB") - ) - amax_ptr, amax_shape, amax_stride_order = ( - self._make_cute_tensor_descriptor(amax_tensor, name="AMAX") - ) - sfc_ptr, sfc_shape, sfc_stride_order = ( - self._make_cute_tensor_descriptor(sfc_tensor, name="SFC") - ) - norm_const_ptr, norm_const_shape, norm_const_stride_order = ( - self._make_cute_tensor_descriptor( - norm_const_tensor, name="NORM_CONST" - ) - ) + norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const") + + a_ptr, a_shape, a_stride_order = self._make_cute_tensor_descriptor(a_tensor, name="A") + b_ptr, b_shape, b_stride_order = self._make_cute_tensor_descriptor(b_tensor, name="B") + ab12_ptr, ab12_shape, ab12_stride_order = self._make_cute_tensor_descriptor(ab12_tensor, name="AB12") + c_ptr, c_shape, c_stride_order = self._make_cute_tensor_descriptor(c_tensor, name="C") + sfa_ptr, sfa_shape, sfa_stride_order = self._make_cute_tensor_descriptor(sfa_tensor, name="SFA") + sfb_ptr, sfb_shape, sfb_stride_order = self._make_cute_tensor_descriptor(sfb_tensor, name="SFB") + amax_ptr, amax_shape, amax_stride_order = self._make_cute_tensor_descriptor(amax_tensor, name="AMAX") + sfc_ptr, sfc_shape, sfc_stride_order = self._make_cute_tensor_descriptor(sfc_tensor, name="SFC") + norm_const_ptr, norm_const_shape, norm_const_stride_order = self._make_cute_tensor_descriptor(norm_const_tensor, name="NORM_CONST") gemm_swiglu( a_ptr=a_ptr, @@ -983,15 +777,11 @@ def execute( norm_const_shape=norm_const_shape, norm_const_order=norm_const_stride_order, alpha=alpha, - max_active_clusters=cutlass.utils.HardwareInfo().get_max_active_clusters( - self.cluster_shape_mn[0] * self.cluster_shape_mn[1] - ), + max_active_clusters=cutlass.utils.HardwareInfo().get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]), stream=current_stream, ) else: - raise NotImplementedError( - f"Unreachable: invalid kernel type {type(self._kernel)}" - ) + raise NotImplementedError(f"Unreachable: invalid kernel type {type(self._kernel)}") self._logger.debug("Executed without compiled kernel (JIT) successfully") @@ -1026,16 +816,10 @@ def gemm_swiglu_wrapper_sm100( n, k, l = b_tensor.shape ab12_tensor, c_tensor = None, None if c_major == "m": - ab12_tensor = torch.empty_strided( - (m, n, l), (1, m, m * n), dtype=ab12_dtype, device=a_tensor.device - ) - c_tensor = torch.empty_strided( - (m, n // 2, l), (1, m, m * n // 2), dtype=c_dtype, device=a_tensor.device - ) + ab12_tensor = torch.empty_strided((m, n, l), (1, m, m * n), dtype=ab12_dtype, device=a_tensor.device) + c_tensor = torch.empty_strided((m, n // 2, l), (1, m, m * n // 2), dtype=c_dtype, device=a_tensor.device) elif c_major == "n": - ab12_tensor = torch.empty_strided( - (m, n, l), (n, 1, m * n), dtype=ab12_dtype, device=a_tensor.device - ) + ab12_tensor = torch.empty_strided((m, n, l), (n, 1, m * n), dtype=ab12_dtype, device=a_tensor.device) c_tensor = torch.empty_strided( (m, n // 2, l), (n // 2, 1, m * n // 2), @@ -1047,13 +831,9 @@ def gemm_swiglu_wrapper_sm100( sfc_tensor, amax_tensor = None, None if sfa_tensor is not None and sfb_tensor is not None: - _logger.debug( - "gemm_swiglu_wrapper_sm100: Detected sfa_tensor and sfb_tensor, constructing quantized output tensors" - ) + _logger.debug("gemm_swiglu_wrapper_sm100: Detected sfa_tensor and sfb_tensor, constructing quantized output tensors") if c_dtype in {torch.float8_e5m2, torch.float8_e4m3fn}: - _logger.debug( - "gemm_swiglu_wrapper_sm100: Detected fp8 c_dtype, constructing sfc_tensor" - ) + _logger.debug("gemm_swiglu_wrapper_sm100: Detected fp8 c_dtype, constructing sfc_tensor") sf_k = ceil_div(n // 2, sf_vec_size) mma_shape = ( @@ -1070,16 +850,9 @@ def gemm_swiglu_wrapper_sm100( dtype=torch.float8_e8m0fnu, device=a_tensor.device, ).permute(mma_permute_order) - if ( - a_tensor.dtype in {torch.float4_e2m1fn_x2, torch.uint8} - and c_dtype == torch.bfloat16 - ): - _logger.debug( - "gemm_swiglu_wrapper_sm100: Detected fp4 ab_dtype and bf16 c_dtype, constructing amax_tensor" - ) - amax_tensor = torch.full( - (1, 1, 1), -float("inf"), device=a_tensor.device, dtype=torch.float32 - ) + if a_tensor.dtype in {torch.float4_e2m1fn_x2, torch.uint8} and c_dtype == torch.bfloat16: + _logger.debug("gemm_swiglu_wrapper_sm100: Detected fp4 ab_dtype and bf16 c_dtype, constructing amax_tensor") + amax_tensor = torch.full((1, 1, 1), -float("inf"), device=a_tensor.device, dtype=torch.float32) cache_key = ( a_tensor.shape, @@ -1109,9 +882,7 @@ def gemm_swiglu_wrapper_sm100( ab12_stages, ) if cache_key in _cache_of_GemmSwigluSm100Objects: - _logger.debug( - "gemm_swiglu_wrapper_sm100: Using previously cached GemmSwigluSm100 object" - ) + _logger.debug("gemm_swiglu_wrapper_sm100: Using previously cached GemmSwigluSm100 object") gemm_swiglu = _cache_of_GemmSwigluSm100Objects[cache_key] gemm_swiglu.execute( a_tensor=a_tensor, @@ -1127,9 +898,7 @@ def gemm_swiglu_wrapper_sm100( current_stream=stream, ) else: - _logger.debug( - "gemm_swiglu_wrapper_sm100: No previously cached GemmSwigluSm100 object found, creating new GemmSwigluSm100 object" - ) + _logger.debug("gemm_swiglu_wrapper_sm100: No previously cached GemmSwigluSm100 object found, creating new GemmSwigluSm100 object") gemm_swiglu = GemmSwigluSm100( sample_a=a_tensor, sample_b=b_tensor, diff --git a/python/cudnn/gemm_swiglu/dense_blockscaled_gemm_persistent_swiglu_interleaved_quant.py b/python/cudnn/gemm_swiglu/dense_blockscaled_gemm_persistent_swiglu_interleaved_quant.py index d1e7f826..df0093e7 100644 --- a/python/cudnn/gemm_swiglu/dense_blockscaled_gemm_persistent_swiglu_interleaved_quant.py +++ b/python/cudnn/gemm_swiglu/dense_blockscaled_gemm_persistent_swiglu_interleaved_quant.py @@ -44,9 +44,7 @@ from cutlass.cute.typing import Float32 -def sigmoid_f32( - a: Union[float, Float32], fastmath: bool = False -) -> Union[float, Float32]: +def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]: """ Compute the sigmoid of the input tensor. """ @@ -194,9 +192,7 @@ def __init__( # K dimension is deferred in _setup_attributes self.mma_tiler = (*mma_tiler_mn, 1) - self.cta_group = ( - tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE - ) + self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE self.occupancy = 1 # Set specialized warp ids @@ -208,9 +204,7 @@ def __init__( ) self.mma_warp_id = 4 self.tma_warp_id = 5 - self.threads_per_cta = 32 * len( - (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) - ) + self.threads_per_cta = 32 * len((self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id)) # Set barrier id for epilogue sync and tmem ptr sync self.epilog_sync_barrier = pipeline.NamedBarrier( barrier_id=1, @@ -333,24 +327,22 @@ def _setup_attributes(self): self.epi_tile_ab12 = (cute.make_layout(128), cute.make_layout(64)) # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory - self.num_acc_stage, self.num_ab_stage, self.num_c_stage, self.num_ab12_stage = ( - self._compute_stages( - tiled_mma, - self.mma_tiler, - self.a_dtype, - self.b_dtype, - self.epi_tile, - self.c_dtype, - self.c_layout, - self.sf_dtype, - self.sf_vec_size, - self.ab12_dtype, - self.ab12_layout, - self.epi_tile_ab12, - self.smem_capacity, - self.occupancy, - self.ab12_stages, - ) + self.num_acc_stage, self.num_ab_stage, self.num_c_stage, self.num_ab12_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.ab12_dtype, + self.ab12_layout, + self.epi_tile_ab12, + self.smem_capacity, + self.occupancy, + self.ab12_stages, ) # Compute A/B/SFA/SFB/C shared memory layout @@ -461,15 +453,11 @@ def __call__( # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) - sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( - a_tensor.shape, self.sf_vec_size - ) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a_tensor.shape, self.sf_vec_size) sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( - b_tensor.shape, self.sf_vec_size - ) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b_tensor.shape, self.sf_vec_size) sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( @@ -495,9 +483,7 @@ def __call__( atom_thr_size = cute.size(tiled_mma.thr_id.shape) # Setup TMA load for A - a_op = sm100_utils.cluster_shape_to_tma_atom_A( - self.cluster_shape_mn, tiled_mma.thr_id - ) + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, @@ -509,9 +495,7 @@ def __call__( ) # Setup TMA load for B - b_op = sm100_utils.cluster_shape_to_tma_atom_B( - self.cluster_shape_mn, tiled_mma.thr_id - ) + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, @@ -523,12 +507,8 @@ def __call__( ) # Setup TMA load for SFA - sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( - self.cluster_shape_mn, tiled_mma.thr_id - ) - sfa_smem_layout = cute.slice_( - self.sfa_smem_layout_staged, (None, None, None, 0) - ) + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( sfa_op, sfa_tensor, @@ -540,12 +520,8 @@ def __call__( ) # Setup TMA load for SFB - sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( - self.cluster_shape_mn, tiled_mma.thr_id - ) - sfb_smem_layout = cute.slice_( - self.sfb_smem_layout_staged, (None, None, None, 0) - ) + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( sfb_op, sfb_tensor, @@ -573,17 +549,13 @@ def __call__( tma_tensor_sfb.stride[2], ) tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) - tma_tensor_sfb = cute.make_tensor( - tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout - ) + tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) - self.num_tma_load_bytes = ( - a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size - ) * atom_thr_size + self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size # Setup TMA store for C epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) @@ -593,9 +565,7 @@ def __call__( epi_smem_layout, self.epi_tile, ) - epi_ab12_smem_layout = cute.slice_( - self.ab12_smem_layout_staged, (None, None, 0) - ) + epi_ab12_smem_layout = cute.slice_(self.ab12_smem_layout_staged, (None, None, 0)) tma_atom_ab12, tma_tensor_ab12 = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), ab12_tensor, @@ -615,9 +585,7 @@ def __call__( self.generate_sfc = sfc_tensor is not None and norm_const_tensor is not None if cutlass.const_expr(self.generate_sfc): - sfc_layout = blockscaled_utils.tile_atom_to_shape_SF( - c_tensor.shape, self.sf_vec_size - ) + sfc_layout = blockscaled_utils.tile_atom_to_shape_SF(c_tensor.shape, self.sf_vec_size) sfc_tensor = cute.make_tensor(sfc_tensor.iterator, sfc_layout) self.generate_amax = amax_tensor is not None @@ -648,30 +616,22 @@ class SharedStorage: ] # (MMA, MMA_M, MMA_K, STAGE) sA: cute.struct.Align[ - cute.struct.MemRange[ - self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) - ], + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], self.buffer_align_bytes, ] # (MMA, MMA_N, MMA_K, STAGE) sB: cute.struct.Align[ - cute.struct.MemRange[ - self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) - ], + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], self.buffer_align_bytes, ] # (MMA, MMA_M, MMA_K, STAGE) sSFA: cute.struct.Align[ - cute.struct.MemRange[ - self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) - ], + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)], self.buffer_align_bytes, ] # (MMA, MMA_N, MMA_K, STAGE) sSFB: cute.struct.Align[ - cute.struct.MemRange[ - self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) - ], + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], self.buffer_align_bytes, ] # Amax reduction shared memory (one FP32 per epilogue warp) @@ -784,15 +744,9 @@ def kernel( bidx, bidy, bidz = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) is_leader_cta = mma_tile_coord_v == 0 - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( - cta_rank_in_cluster - ) - block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( - cta_rank_in_cluster - ) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster) # Coord inside cta tidx, _, _ = cute.arch.thread_idx() @@ -805,9 +759,7 @@ def kernel( # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 - ab_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, num_tma_producer - ) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer) ab_pipeline = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, @@ -820,12 +772,8 @@ def kernel( # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - num_acc_consumer_threads = len(self.epilog_warp_id) * ( - 2 if use_2cta_instrs else 1 - ) - acc_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, num_acc_consumer_threads - ) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads) acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), num_stages=self.num_acc_stage, @@ -851,13 +799,9 @@ def kernel( # Setup smem tensor A/B/SFA/SFB/C # # (MMA, MMA_M, MMA_K, STAGE) - sA = storage.sA.get_tensor( - a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner - ) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) # (MMA, MMA_N, MMA_K, STAGE) - sB = storage.sB.get_tensor( - b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner - ) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) # (MMA, MMA_M, MMA_K, STAGE) sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) # (MMA, MMA_N, MMA_K, STAGE) @@ -889,34 +833,20 @@ def kernel( sfa_full_mcast_mask = None sfb_full_mcast_mask = None if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): - a_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 - ) - b_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 - ) - sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 - ) - sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 - ) + a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1) # # Local_tile partition global tensors # # (bM, bK, RestM, RestK, RestL) - gA_mkl = cute.local_tile( - mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) - ) + gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) # (bN, bK, RestN, RestK, RestL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) - ) + gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) # (bM, bK, RestM, RestK, RestL) - gSFA_mkl = cute.local_tile( - mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) - ) + gSFA_mkl = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) # (bN, bK, RestN, RestK, RestL) gSFB_nkl = cute.local_tile( mSFB_nkl, @@ -924,9 +854,7 @@ def kernel( (None, None, None), ) # (bM, bN, RestM, RestN, RestL) - gC_mnl = cute.local_tile( - mC_mnl, cute.slice_(self.mma_tiler_c, (None, None, 0)), (None, None, None) - ) + gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler_c, (None, None, 0)), (None, None, None)) # (bM, bN, RestM, RestN, RestL) gAB12_mnl = cute.local_tile( mAB12_mnl, @@ -957,9 +885,7 @@ def kernel( # Partition global/shared tensor for TMA load A/B # # TMA load A partition_S/D - a_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape - ) + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( @@ -970,9 +896,7 @@ def kernel( cute.group_modes(tCgA, 0, 3), ) # TMA load B partition_S/D - b_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape - ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( @@ -998,9 +922,7 @@ def kernel( tAgSFA = cute.filter_zeros(tAgSFA) # TMA load SFB partition_S/D - sfb_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape - ) + sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( @@ -1023,9 +945,7 @@ def kernel( # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) # (MMA, MMA_M, MMA_N, STAGE) - tCtAcc_fake = tiled_mma.make_fragment_C( - cute.append(acc_shape, self.num_acc_stage) - ) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) # # Cluster wait before tensor memory alloc @@ -1039,14 +959,10 @@ def kernel( # # Persistent tile scheduling loop # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() - ab_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_ab_stage - ) + ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) while work_tile.is_valid_tile: # Get tile coord from tile scheduler @@ -1061,18 +977,12 @@ def kernel( # Slice to per mma tile index # # ((atom_v, rest_v), RestK) - tAgA_slice = tAgA[ - (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) - ] + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) - tBgB_slice = tBgB[ - (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) - ] + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) - tAgSFA_slice = tAgSFA[ - (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) - ] + tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] # Apply SFB slicing hack when cta_tile_shape_n=64 # {$nv-internal-release} slice_n = mma_tile_coord_mnl[1] @@ -1085,17 +995,13 @@ def kernel( ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # # Tma load loop # for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # Conditionally wait for AB buffer empty - ab_pipeline.producer_acquire( - ab_producer_state, peek_ab_empty_status - ) + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) # TMA load A/B/SFA/SFB cute.copy( @@ -1131,9 +1037,7 @@ def kernel( ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # # Advance to next tile @@ -1179,9 +1083,7 @@ def kernel( # Make SFB tmem tensor sfb_tmem_ptr = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA), dtype=self.sf_dtype, ) # (MMA, MMA_N, MMA_K) @@ -1209,17 +1111,11 @@ def kernel( # # Persistent tile scheduling loop # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() - ab_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_ab_stage - ) - acc_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_acc_stage - ) + ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) + acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) while work_tile.is_valid_tile: # Get tile coord from tile scheduler @@ -1238,9 +1134,7 @@ def kernel( ab_consumer_state.reset_count() peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < k_tile_cnt and is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait( - ab_consumer_state - ) + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) # # Wait for accumulator buffer empty @@ -1252,16 +1146,9 @@ def kernel( tCtSFB_mma = tCtSFB if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): # If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB) - offset = ( - cutlass.Int32(2) - if mma_tile_coord_mnl[1] % 2 == 1 - else cutlass.Int32(0) - ) + offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0) shifted_ptr = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA) - + offset, + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset, dtype=self.sf_dtype, ) tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) @@ -1269,10 +1156,7 @@ def kernel( # Move in increments of 64 columns of SFB offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) shifted_ptr = cute.recast_ptr( - acc_tmem_ptr - + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) - + tcgen05.find_tmem_tensor_col_offset(tCtSFA) - + offset, + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + offset, dtype=self.sf_dtype, ) tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) @@ -1288,9 +1172,7 @@ def kernel( for k_tile in range(k_tile_cnt): if is_leader_cta: # Conditionally wait for AB buffer full - ab_pipeline.consumer_wait( - ab_consumer_state, peek_ab_full_status - ) + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) # Copy SFA/SFB from smem to tmem s2t_stage_coord = ( @@ -1353,9 +1235,7 @@ def kernel( peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < k_tile_cnt: if is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait( - ab_consumer_state - ) + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) # # Async arrive accumulator buffer full @@ -1404,42 +1284,28 @@ def kernel( tTR_tAcc_base, tTR_rAcc_up, tTR_rAcc_gate, - ) = self.epilog_tmem_copy_and_partition( - epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs - ) + ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs) tTR_rC = cute.make_rmem_tensor(tTR_rAcc_up.shape, self.c_dtype) - tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( - tiled_copy_t2r, tTR_rC, epi_tidx, sC - ) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rC, epi_tidx, sC) ( bSG_sC, bSG_gC_mnl, - ) = self.epilog_gmem_copy_and_partition( - epi_tidx, tma_atom_c, tCgC, epi_tile, sC - ) + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, epi_tile, sC) tTR_rAB12 = cute.make_rmem_tensor(tTR_rAcc_up.shape, self.ab12_dtype) - _, tRS_rAB12, tRS_sAB12 = self.epilog_smem_copy_and_partition( - tiled_copy_t2r, tTR_rAB12, epi_tidx, sAB12 - ) + _, tRS_rAB12, tRS_sAB12 = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rAB12, epi_tidx, sAB12) ( bSG_sAB12, bSG_gAB12_mnl, - ) = self.epilog_gmem_copy_and_partition( - epi_tidx, tma_atom_ab12, tCgAB12, epi_tile_ab12, sAB12 - ) + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_ab12, tCgAB12, epi_tile_ab12, sAB12) # # Persistent tile scheduling loop # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() - acc_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) + acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) # Threads/warps participating in tma store pipeline c_producer_group = pipeline.CooperativeGroup( @@ -1470,13 +1336,9 @@ def kernel( tCgSFC_mnl = thr_copy_t2r.partition_D(gSFC_mnl) tCgSFC_mnl = cute.filter_zeros(tCgSFC_mnl) # ((T2R, T2R_M, T2R_N), SUBTILLE_IDX) # ((1,1),1,4):((0,0),0,1) - tCrSFC = cute.make_rmem_tensor( - tCgSFC_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype - ) + tCrSFC = cute.make_rmem_tensor(tCgSFC_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype) tCrSFC_pvscale = cute.make_rmem_tensor_like(tCrSFC, cutlass.Float32) - tCrSFC_qpvscale_up_fp32 = cute.make_rmem_tensor_like( - tCrSFC, cutlass.Float32 - ) + tCrSFC_qpvscale_up_fp32 = cute.make_rmem_tensor_like(tCrSFC, cutlass.Float32) while work_tile.is_valid_tile: # Get tile coord from tile scheduler @@ -1509,9 +1371,7 @@ def kernel( ] # Set tensor memory buffer for current tile # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc = tTR_tAcc_base[ - (None, None, None, None, None, acc_consumer_state.index) - ] + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)] # Initialize thread-local amax accumulator for this tile # Use 0.0 as initial value since we're computing absolute maximum @@ -1530,9 +1390,7 @@ def kernel( # # Store accumulator to global memory in subtiles # - subtile_cnt = cute.size( - tTR_tAcc.shape, mode=[3] - ) ## tTR_tAcc.shape: (((32, 32), 1), 1, 1, (1, 8)) + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) ## tTR_tAcc.shape: (((32, 32), 1), 1, 1, (1, 8)) num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt for subtile_idx in cutlass.range(0, subtile_cnt, 2): # Calculate subtile index for C output (one output per two input subtiles) @@ -1566,9 +1424,7 @@ def kernel( # # Store AB12 to shared memory for bprop # - AB12_buffer = ( - num_prev_subtiles + sfc_subtile_idx - ) % self.num_ab12_stage + AB12_buffer = (num_prev_subtiles + sfc_subtile_idx) % self.num_ab12_stage # Convert to ab12_dtype before storing # Load, convert type, and store back to temporary register tensor # tTR_rAB12_up = cute.make_rmem_tensor(tTR_tAcc_mn_up.shape, self.ab12_dtype) @@ -1583,9 +1439,7 @@ def kernel( cute.copy( tiled_copy_r2s, tRS_rAB12[(None, None, 0)], # ((1, 32), 1, 1), ((0, 1), 0, 0) - tRS_sAB12[ - (None, None, 1, AB12_buffer) - ], # ((1, 32), 1, 2, (1, 1)), ((0, 1), 0, 32, (0, 0)) + tRS_sAB12[(None, None, 1, AB12_buffer)], # ((1, 32), 1, 2, (1, 1)), ((0, 1), 0, 32, (0, 0)) ) # Fence and barrier to make sure shared memory store is visible to TMA store @@ -1600,12 +1454,8 @@ def kernel( if warp_idx == self.epilog_warp_id[0]: cute.copy( tma_atom_ab12, - bSG_sAB12[ - (None, sfc_subtile_idx % self.num_ab12_stage) - ], # ((8192, 1), (1, 4)), ((1, 0), (0, 8192)) - bSG_gAB12[ - (None, sfc_subtile_idx) - ], # (((64, 128), 1), (1, 4)) : (((1@0,1@1),0),(0,64@0)) + bSG_sAB12[(None, sfc_subtile_idx % self.num_ab12_stage)], # ((8192, 1), (1, 4)), ((1, 0), (0, 8192)) + bSG_gAB12[(None, sfc_subtile_idx)], # (((64, 128), 1), (1, 4)) : (((1@0,1@1),0),(0,64@0)) ) # Fence and barrier to make sure shared memory store is visible to TMA store ab12_pipeline.producer_commit() @@ -1617,20 +1467,14 @@ def kernel( if cutlass.const_expr(self.vector_f32): # SwiGelu Packed Version LOG2_E = cutlass.Float32(1.4426950408889634) - for i in cutlass.range( - 0, cute.size(acc_vec_gate), 2, unroll_full=True - ): + for i in cutlass.range(0, cute.size(acc_vec_gate), 2, unroll_full=True): tCompute_log2e = cute.arch.mul_packed_f32x2( (acc_vec_gate[i], acc_vec_gate[i + 1]), (-LOG2_E, -LOG2_E), ) ## replace to add_packed_f32x2 when no precision issue. - tCompute[i + 0] = ( - cute.math.exp2(tCompute_log2e[0], fastmath=True) + 1.0 - ) - tCompute[i + 1] = ( - cute.math.exp2(tCompute_log2e[1], fastmath=True) + 1.0 - ) + tCompute[i + 0] = cute.math.exp2(tCompute_log2e[0], fastmath=True) + 1.0 + tCompute[i + 1] = cute.math.exp2(tCompute_log2e[1], fastmath=True) + 1.0 tCompute[i + 0] = cute.arch.rcp_approx(tCompute[i + 0]) tCompute[i + 1] = cute.arch.rcp_approx(tCompute[i + 1]) ( @@ -1649,12 +1493,8 @@ def kernel( ) else: # SwiGelu Unpacked Version - for i in cutlass.range( - 0, cute.size(acc_vec_gate), 1, unroll_full=True - ): - tCompute[i] = acc_vec_up[i] * silu_f32( - acc_vec_gate[i], fastmath=True - ) + for i in cutlass.range(0, cute.size(acc_vec_gate), 1, unroll_full=True): + tCompute[i] = acc_vec_up[i] * silu_f32(acc_vec_gate[i], fastmath=True) # # Generate amax @@ -1662,20 +1502,14 @@ def kernel( if cutlass.const_expr(self.generate_amax): acc_values = tCompute.load() # Apply element-wise absolute value using math.absf (supports vectors) - abs_acc_values_ir = cutlass._mlir.dialects.math.absf( - acc_values.ir_value() # operand (positional) - ) - abs_acc_values = type(acc_values)( - abs_acc_values_ir, acc_values.shape, acc_values.dtype - ) + abs_acc_values_ir = cutlass._mlir.dialects.math.absf(acc_values.ir_value()) # operand (positional) + abs_acc_values = type(acc_values)(abs_acc_values_ir, acc_values.shape, acc_values.dtype) subtile_amax = abs_acc_values.reduce( cute.ReductionOp.MAX, cutlass.Float32(0.0), 0, # Use 0.0 as init for abs values ) - thread_tile_amax = cute.arch.fmax( - thread_tile_amax, subtile_amax - ) + thread_tile_amax = cute.arch.fmax(thread_tile_amax, subtile_amax) # # Generate sfc @@ -1697,34 +1531,22 @@ def kernel( # # Get absolute max across a vector and Compute SFC # - tTR_rAcc_frg = cute.logical_divide( - tCompute, cute.make_layout(self.sf_vec_size) - ) + tTR_rAcc_frg = cute.logical_divide(tCompute, cute.make_layout(self.sf_vec_size)) acc_frg = tTR_rAcc_frg.load() acc_frg = epilogue_op(acc_frg) # Apply element-wise absolute value using math.absf (supports vectors) - abs_acc_frg_ir = cutlass._mlir.dialects.math.absf( - acc_frg.ir_value() - ) - abs_acc_frg = type(acc_frg)( - abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype - ) + abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value()) + abs_acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype) if cutlass.const_expr(self.vector_f32): - tCrSFC_pvscale_subtile = tCrSFC_pvscale[ - None, None, sfc_subtile_idx - ] + tCrSFC_pvscale_subtile = tCrSFC_pvscale[None, None, sfc_subtile_idx] for vi in cutlass.range_constexpr(abs_acc_frg.shape[1]): - tCrSFC_pvscale_subtile[vi] = abs_acc_frg[ - None, vi - ].reduce( + tCrSFC_pvscale_subtile[vi] = abs_acc_frg[None, vi].reduce( cute.ReductionOp.MAX, cutlass.Float32(0.0), 0, # Use 0.0 as init for abs values ) - for vi in cutlass.range_constexpr( - 0, abs_acc_frg.shape[1], 2 - ): + for vi in cutlass.range_constexpr(0, abs_acc_frg.shape[1], 2): ( tCrSFC_pvscale_subtile[vi], tCrSFC_pvscale_subtile[vi + 1], @@ -1761,9 +1583,7 @@ def kernel( ) # TODO: need to investigate f32x2 -> f8x2 conversion - if ( - sfc_subtile_idx == 3 - ): # 3 is the last subtile to composite SFC to a reg + if sfc_subtile_idx == 3: # 3 is the last subtile to composite SFC to a reg # # convert SFC from fp32 to sf_dtype # @@ -1784,21 +1604,15 @@ def kernel( ## tCrSFC_qpvscale_up = tCrSFC_pvscale[None, None, sfc_subtile_idx] tCrSFC.store(tCrSFC_pvscale.load().to(self.sf_dtype)) tCrSFC_qpvscale_up_fp32.store(tCrSFC.load().to(cutlass.Float32)) - tCrSFC_qpvscale_up = tCrSFC_qpvscale_up_fp32[ - None, None, sfc_subtile_idx - ] + tCrSFC_qpvscale_up = tCrSFC_qpvscale_up_fp32[None, None, sfc_subtile_idx] fp32_max = cutlass.Float32(3.40282346638528859812e38) if cutlass.const_expr(self.vector_f32): - for vi in cutlass.range_constexpr( - 0, cute.size(tCrSFC_qpvscale_up), 2 - ): + for vi in cutlass.range_constexpr(0, cute.size(tCrSFC_qpvscale_up), 2): acc_scale = cute.arch.mul_packed_f32x2( ( cute.arch.rcp_approx(tCrSFC_qpvscale_up[vi]), - cute.arch.rcp_approx( - tCrSFC_qpvscale_up[vi + 1] - ), + cute.arch.rcp_approx(tCrSFC_qpvscale_up[vi + 1]), ), (norm_const, norm_const), ) @@ -1813,12 +1627,8 @@ def kernel( (acc_scale_min0, acc_scale_min1), ) else: - for vi in cutlass.range_constexpr( - cute.size(tCrSFC_qpvscale_up) - ): - acc_scale = norm_const * cute.arch.rcp_approx( - tCrSFC_qpvscale_up[vi] - ) + for vi in cutlass.range_constexpr(cute.size(tCrSFC_qpvscale_up)): + acc_scale = norm_const * cute.arch.rcp_approx(tCrSFC_qpvscale_up[vi]) acc_scale = fmin(acc_scale, fp32_max, nan=True) vec = tTR_rAcc_frg[None, vi] @@ -1878,9 +1688,7 @@ def kernel( # Block-level reduction: only first epilogue warp's lane 0 handles this if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0: - block_amax = cutlass.Float32( - 0.0 - ) # Initial value for absolute maximum + block_amax = cutlass.Float32(0.0) # Initial value for absolute maximum for i in cutlass.range(self.num_epilog_warps): warp_amax_val = sAmax[i] block_amax = cute.arch.fmax(block_amax, warp_amax_val) @@ -1970,9 +1778,7 @@ def mainloop_s2t_copy_and_partition( # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( - tiled_copy_s2t, tCsSF_compact_s2t_ - ) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) @@ -2022,28 +1828,20 @@ def epilog_tmem_copy_and_partition( epi_tile, ) # (EPI_TILE_M, EPI_TILE_N) - tiled_copy_t2r = tcgen05.make_tmem_copy( - copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] - ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - gC_mnl_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile - ) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc_up = cute.make_rmem_tensor( - tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype - ) + tTR_rAcc_up = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) # (T2R, T2R_M, T2R_N) - tTR_rAcc_gate = cute.make_rmem_tensor( - tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype - ) + tTR_rAcc_gate = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc_up, tTR_rAcc_gate def epilog_smem_copy_and_partition( @@ -2071,9 +1869,7 @@ def epilog_smem_copy_and_partition( - tRS_sC: The partitioned tensor C (smem destination) :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] """ - copy_atom_r2s = sm100_utils.get_smem_store_op( - self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r - ) + copy_atom_r2s = sm100_utils.get_smem_store_op(self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r) tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) # (R2S, R2S_M, R2S_N, PIPE) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) @@ -2111,9 +1907,7 @@ def epilog_gmem_copy_and_partition( :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] """ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - gC_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile - ) + gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) # ((ATOM_V, REST_V), EPI_M, EPI_N) # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) bSG_sC, bSG_gC = cpasync.tma_partition( @@ -2237,23 +2031,15 @@ def _compute_stages( ) mbar_helpers_bytes = 1024 c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) - ab12_bytes_per_stage = cute.size_in_bytes( - ab12_dtype, ab12_smem_layout_staged_one - ) + ab12_bytes_per_stage = cute.size_in_bytes(ab12_dtype, ab12_smem_layout_staged_one) amax_bytes = Sm100BlockScaledPersistentDenseGemmKernel.get_amax_smem_size() - epi_bytes = ( - c_bytes_per_stage * num_c_stage - + ab12_bytes_per_stage * num_ab12_stage - + amax_bytes - ) + epi_bytes = c_bytes_per_stage * num_c_stage + ab12_bytes_per_stage * num_ab12_stage + amax_bytes # Calculate A/B/SFA/SFB stages: # Start with total smem per CTA (capacity / occupancy) # Subtract reserved bytes and initial C stages bytes # Divide remaining by bytes needed per A/B/SFA/SFB stage - num_ab_stage = ( - smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes) - ) // ab_bytes_per_stage + num_ab_stage = (smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes)) // ab_bytes_per_stage return num_acc_stage, num_ab_stage, num_c_stage, num_ab12_stage @@ -2285,12 +2071,8 @@ def _compute_grid( num_ctas_mnl = gc[(0, (None, None, None))].shape cluster_shape_mnl = (*cluster_shape_mn, 1) - tile_sched_params = utils.PersistentTileSchedulerParams( - num_ctas_mnl, cluster_shape_mnl - ) - grid = utils.StaticPersistentTileScheduler.get_grid_shape( - tile_sched_params, max_active_clusters - ) + tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl) + grid = utils.StaticPersistentTileScheduler.get_grid_shape(tile_sched_params, max_active_clusters) return tile_sched_params, grid @@ -2378,44 +2160,26 @@ def __call__( stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): - a_cute = cute.make_tensor( - a_ptr, layout=cute.make_ordered_layout(a_shape, order=a_order) - ) - b_cute = cute.make_tensor( - b_ptr, layout=cute.make_ordered_layout(b_shape, order=b_order) - ) - sfa_cute = cute.make_tensor( - sfa_ptr, layout=cute.make_ordered_layout(sfa_shape, order=sfa_order) - ) - sfb_cute = cute.make_tensor( - sfb_ptr, layout=cute.make_ordered_layout(sfb_shape, order=sfb_order) - ) - c_cute = cute.make_tensor( - c_ptr, layout=cute.make_ordered_layout(c_shape, order=c_order) - ) - ab12_cute = cute.make_tensor( - ab12_ptr, layout=cute.make_ordered_layout(ab12_shape, order=ab12_order) - ) + a_cute = cute.make_tensor(a_ptr, layout=cute.make_ordered_layout(a_shape, order=a_order)) + b_cute = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout(b_shape, order=b_order)) + sfa_cute = cute.make_tensor(sfa_ptr, layout=cute.make_ordered_layout(sfa_shape, order=sfa_order)) + sfb_cute = cute.make_tensor(sfb_ptr, layout=cute.make_ordered_layout(sfb_shape, order=sfb_order)) + c_cute = cute.make_tensor(c_ptr, layout=cute.make_ordered_layout(c_shape, order=c_order)) + ab12_cute = cute.make_tensor(ab12_ptr, layout=cute.make_ordered_layout(ab12_shape, order=ab12_order)) amax_cute = None if cutlass.const_expr(amax_ptr is not None): - amax_cute = cute.make_tensor( - amax_ptr, layout=cute.make_ordered_layout(amax_shape, order=amax_order) - ) + amax_cute = cute.make_tensor(amax_ptr, layout=cute.make_ordered_layout(amax_shape, order=amax_order)) sfc_cute = None if cutlass.const_expr(sfc_ptr is not None): - sfc_cute = cute.make_tensor( - sfc_ptr, layout=cute.make_ordered_layout(sfc_shape, order=sfc_order) - ) + sfc_cute = cute.make_tensor(sfc_ptr, layout=cute.make_ordered_layout(sfc_shape, order=sfc_order)) norm_const_cute = None if cutlass.const_expr(norm_const_ptr is not None): norm_const_cute = cute.make_tensor( norm_const_ptr, - layout=cute.make_ordered_layout( - norm_const_shape, order=norm_const_order - ), + layout=cute.make_ordered_layout(norm_const_shape, order=norm_const_order), ) self.kernel( diff --git a/python/cudnn/gemm_swiglu/dense_gemm_persistent_swiglu.py b/python/cudnn/gemm_swiglu/dense_gemm_persistent_swiglu.py index 8636efae..ae9b7f72 100644 --- a/python/cudnn/gemm_swiglu/dense_gemm_persistent_swiglu.py +++ b/python/cudnn/gemm_swiglu/dense_gemm_persistent_swiglu.py @@ -181,9 +181,7 @@ def __init__( # K dimension is deferred in _setup_attributes self.mma_tiler = (*mma_tiler_mn, 1) - self.cta_group = ( - tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE - ) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE self.occupancy = 1 # Set specialized warp ids @@ -195,9 +193,7 @@ def __init__( ) self.mma_warp_id = 4 self.tma_warp_id = 5 - self.threads_per_cta = 32 * len( - (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) - ) + self.threads_per_cta = 32 * len((self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id)) # Set barrier id for cta sync, epilogue sync and tmem ptr sync self.cta_sync_bar_id = 0 self.epilog_sync_bar_id = 1 @@ -238,8 +234,7 @@ def _setup_attributes(self): ) self.mma_tiler_c = ( self.mma_tiler[0], - self.mma_tiler[1] - // 2, # divide by 2 because Glu advnces by half on N dimension + self.mma_tiler[1] // 2, # divide by 2 because Glu advnces by half on N dimension self.mma_tiler[2], ) self.cta_tile_shape_mnk = ( @@ -280,21 +275,19 @@ def _setup_attributes(self): ) # Setup A/B/AB12 stage count in shared memory and ACC stage count in tensor memory - self.num_acc_stage, self.num_ab_stage, self.num_ab12_stage, self.num_c_stage = ( - self._compute_stages( - tiled_mma, - self.mma_tiler, - self.a_dtype, - self.b_dtype, - self.epi_tile, - self.epi_tile_c, - self.ab12_dtype, - self.ab12_layout, - self.c_dtype, - self.c_layout, - self.smem_capacity, - self.occupancy, - ) + self.num_acc_stage, self.num_ab_stage, self.num_ab12_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.epi_tile_c, + self.ab12_dtype, + self.ab12_layout, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, ) # Compute A/B/AB12 shared memory layout @@ -324,9 +317,7 @@ def _setup_attributes(self): ) # Compute the number of tensor memory allocation columns - self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( - tiled_mma, self.mma_tiler, self.num_acc_stage - ) + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols(tiled_mma, self.mma_tiler, self.num_acc_stage) @cute.jit def __call__( @@ -393,9 +384,7 @@ def __call__( atom_thr_size = cute.size(tiled_mma.thr_id.shape) # Setup TMA load for A - a_op = sm100_utils.cluster_shape_to_tma_atom_A( - self.cluster_shape_mn, tiled_mma.thr_id - ) + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, @@ -404,15 +393,11 @@ def __call__( self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, - internal_type=( - cutlass.TFloat32 if a.element_type is cutlass.Float32 else None - ), + internal_type=(cutlass.TFloat32 if a.element_type is cutlass.Float32 else None), ) # Setup TMA load for B - b_op = sm100_utils.cluster_shape_to_tma_atom_B( - self.cluster_shape_mn, tiled_mma.thr_id - ) + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, @@ -421,9 +406,7 @@ def __call__( self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, - internal_type=( - cutlass.TFloat32 if b.element_type is cutlass.Float32 else None - ), + internal_type=(cutlass.TFloat32 if b.element_type is cutlass.Float32 else None), ) a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) @@ -431,12 +414,8 @@ def __call__( self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size # Setup TMA store for AB12 and C - ab12_cta_v_layout = cute.composition( - cute.make_identity_layout(ab12.shape), self.epi_tile - ) - c_cta_v_layout = cute.composition( - cute.make_identity_layout(c.shape), self.epi_tile_c - ) + ab12_cta_v_layout = cute.composition(cute.make_identity_layout(ab12.shape), self.epi_tile) + c_cta_v_layout = cute.composition(cute.make_identity_layout(c.shape), self.epi_tile_c) epi_smem_layout = cute.slice_(self.ab12_smem_layout_staged, (None, None, 0)) epi_smem_layout_c = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) tma_atom_ab12, tma_tensor_ab12 = cpasync.make_tiled_tma_atom( @@ -453,9 +432,7 @@ def __call__( ) # Compute grid size - self.tile_sched_params, grid = self._compute_grid( - ab12, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters - ) + self.tile_sched_params, grid = self._compute_grid(ab12, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters) self.buffer_align_bytes = 1024 @@ -493,16 +470,12 @@ class SharedStorage: # c_smem_size: S<1,4,3> o 0 o ((8,16),(32,1),(1,8)):((32,256),(1,0),(0,4096)) # (MMA, MMA_M, MMA_K, STAGE) sA: cute.struct.Align[ - cute.struct.MemRange[ - self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) - ], + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], self.buffer_align_bytes, ] # (MMA, MMA_N, MMA_K, STAGE) sB: cute.struct.Align[ - cute.struct.MemRange[ - self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) - ], + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], self.buffer_align_bytes, ] @@ -586,12 +559,8 @@ def kernel( bidx, bidy, bidz = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) is_leader_cta = mma_tile_coord_v == 0 - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( - cta_rank_in_cluster - ) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) # Coord inside cta tidx, _, _ = cute.arch.thread_idx() @@ -607,9 +576,7 @@ def kernel( # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 - ab_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, num_tma_producer - ) + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer) ab_pipeline = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, @@ -621,12 +588,8 @@ def kernel( # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - num_acc_consumer_threads = len(self.epilog_warp_id) * ( - 2 if use_2cta_instrs else 1 - ) - acc_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, num_acc_consumer_threads - ) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads) acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), num_stages=self.num_acc_stage, @@ -640,9 +603,7 @@ def kernel( if warp_idx == self.tma_warp_id: num_tmem_dealloc_threads = 32 with cute.arch.elect_one(): - cute.arch.mbarrier_init( - tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads - ) + cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads) cute.arch.mbarrier_init_fence() # Cluster arrive after barrier init @@ -653,21 +614,13 @@ def kernel( # Setup smem tensor A/B/AB12/C # # (EPI_TILE_M, EPI_TILE_N, STAGE) - sAB12 = storage.sAB12.get_tensor( - ab12_smem_layout_staged.outer, swizzle=ab12_smem_layout_staged.inner - ) + sAB12 = storage.sAB12.get_tensor(ab12_smem_layout_staged.outer, swizzle=ab12_smem_layout_staged.inner) # (EPI_TILE_M, EPI_TILE_N, STAGE) - sC = storage.sC.get_tensor( - c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner - ) + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) # (MMA, MMA_M, MMA_K, STAGE) - sA = storage.sA.get_tensor( - a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner - ) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) # (MMA, MMA_N, MMA_K, STAGE) - sB = storage.sB.get_tensor( - b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner - ) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) # # Compute multicast mask for A/B buffer full @@ -675,20 +628,14 @@ def kernel( a_full_mcast_mask = None b_full_mcast_mask = None if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): - a_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 - ) - b_full_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 - ) + a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1) # # Local_tile partition global tensors # # (bM, bK, RestM, RestK, RestL) - gA_mkl = cute.local_tile( - mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) - ) + gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) # (bN, bK, RestN, RestK, RestL) gB_nkl = cute.local_tile( mB_nkl, @@ -696,9 +643,7 @@ def kernel( (None, None, None), # Half of the tile ) # (bM, bN, RestM, RestN, RestL) - gAB12_mnl = cute.local_tile( - mAB12_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) - ) + gAB12_mnl = cute.local_tile(mAB12_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler_c, (None, None, 0)), @@ -723,9 +668,7 @@ def kernel( # Partition global/shared tensor for TMA load A/B # # TMA load A partition_S/D - a_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape - ) + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( @@ -736,9 +679,7 @@ def kernel( cute.group_modes(tCgA, 0, 3), ) # TMA load B partition_S/D - b_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape - ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( @@ -759,9 +700,7 @@ def kernel( # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) # (MMA, MMA_M, MMA_N, STAGE) - tCtAcc_fake = tiled_mma.make_fragment_C( - cute.append(acc_shape, self.num_acc_stage) - ) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) # # Cluster wait before tensor memory alloc @@ -769,9 +708,7 @@ def kernel( if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_wait() else: - cute.arch.barrier( - barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta - ) + cute.arch.barrier(barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta) # # Specialized TMA load warp @@ -781,14 +718,10 @@ def kernel( # # Persistent tile scheduling loop # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() - ab_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_ab_stage - ) + ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) while work_tile.is_valid_tile: # Get tile coord from tile scheduler @@ -803,29 +736,21 @@ def kernel( # Slice to per mma tile index # # ((atom_v, rest_v), RestK) - tAgA_slice = tAgA[ - (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) - ] + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] # ((atom_v, rest_v), RestK) - tBgB_slice = tBgB[ - (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) - ] + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < k_block_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # # Tma load loop # for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1): # Conditionally wait for AB buffer empty - ab_pipeline.producer_acquire( - ab_producer_state, peek_ab_empty_status - ) + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) # TMA load A/B cute.copy( @@ -847,9 +772,7 @@ def kernel( ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < k_block_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire( - ab_producer_state - ) + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # # Advance to next tile @@ -889,17 +812,11 @@ def kernel( # # Persistent tile scheduling loop # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() - ab_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_ab_stage - ) - acc_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_acc_stage - ) + ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) + acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) while work_tile.is_valid_tile: # Get tile coord from tile scheduler @@ -918,9 +835,7 @@ def kernel( ab_consumer_state.reset_count() peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < k_block_cnt and is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait( - ab_consumer_state - ) + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) # # Wait for accumulator buffer empty @@ -939,9 +854,7 @@ def kernel( for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1): if is_leader_cta: # Conditionally wait for AB buffer full - ab_pipeline.consumer_wait( - ab_consumer_state, peek_ab_full_status - ) + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) # tCtAcc += tCrA * tCrB num_kphases = cute.size(tCrA, mode=[2]) @@ -971,9 +884,7 @@ def kernel( peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < k_block_cnt: if is_leader_cta: - peek_ab_full_status = ab_pipeline.consumer_try_wait( - ab_consumer_state - ) + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) # # Async arrive accumulator buffer full @@ -1059,10 +970,8 @@ def kernel( tTR_rAB12 = cute.make_rmem_tensor(tTR_rAcc.shape, self.ab12_dtype) tTR_rAB12_1 = cute.make_rmem_tensor(tTR_rAcc.shape, self.ab12_dtype) tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) - tiled_copy_r2s, tRS_rAB12, tRS_rAB12_1, tRS_rC, tRS_sAB12, tRS_sC = ( - self.epilog_smem_copy_and_partition( - tiled_copy_t2r, tTR_rAB12, tTR_rAB12_1, tTR_rC, epi_tidx, sAB12, sC - ) + tiled_copy_r2s, tRS_rAB12, tRS_rAB12_1, tRS_rC, tRS_sAB12, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rAB12, tTR_rAB12_1, tTR_rC, epi_tidx, sAB12, sC ) ( @@ -1087,14 +996,10 @@ def kernel( # # Persistent tile scheduling loop # - tile_sched = utils.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() - acc_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) + acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) # Threads/warps participating in tma store pipeline c_producer_group = pipeline.CooperativeGroup( @@ -1138,9 +1043,7 @@ def kernel( ] # Set tensor memory buffer for current tile # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc = tTR_tAcc_base[ - (None, None, None, None, None, acc_consumer_state.index) - ] + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)] # # Wait for accumulator buffer full # @@ -1156,33 +1059,19 @@ def kernel( for subtile_idx in cutlass.range(0, subtile_cnt, 2): # # Load accumulator from tensor memory buffer to register - tTR_tAcc_mn = tTR_tAcc[ - (None, None, None, subtile_idx) - ] # input tile0 - tTR_tAcc_mn1 = tTR_tAcc[ - (None, None, None, subtile_idx + 1) - ] # input tile 1 - cute.copy( - tiled_copy_t2r, tTR_tAcc_mn1, tTR_rAcc1 - ) # copy input tile 1 - cute.copy( - tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc - ) # copy input tile 0 + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] # input tile0 + tTR_tAcc_mn1 = tTR_tAcc[(None, None, None, subtile_idx + 1)] # input tile 1 + cute.copy(tiled_copy_t2r, tTR_tAcc_mn1, tTR_rAcc1) # copy input tile 1 + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) # copy input tile 0 # Convert to C type - acc_vec0 = tiled_copy_r2s.retile( - tTR_rAcc - ).load() # copy input tile 0 - acc_vec1 = tiled_copy_r2s.retile( - tTR_rAcc1 - ).load() # copy input tile 1 + acc_vec0 = tiled_copy_r2s.retile(tTR_rAcc).load() # copy input tile 0 + acc_vec1 = tiled_copy_r2s.retile(tTR_rAcc1).load() # copy input tile 1 acc_vec0 = acc_vec0 * alpha acc_vec1 = acc_vec1 * alpha # Use exp2 with log2(e) conversion since cute.math.exp is not available # exp(x) = 2^(x * log2(e)) - gate_rcp = (1 + cute.math.exp2(-1 * acc_vec1 * LOG2_E, True)).to( - self.acc_dtype - ) + gate_rcp = (1 + cute.math.exp2(-1 * acc_vec1 * LOG2_E, True)).to(self.acc_dtype) res = cute.make_rmem_tensor(gate_rcp.shape, cutlass.Float32) res.store(gate_rcp) @@ -1202,12 +1091,8 @@ def kernel( tRS_rC.store(acc_vec_c) # Store AB12 and C to shared memory - ab12_buffer0 = ( - num_prev_subtiles + subtile_idx - ) % self.num_ab12_stage - ab12_buffer1 = ( - num_prev_subtiles + subtile_idx + 1 - ) % self.num_ab12_stage + ab12_buffer0 = (num_prev_subtiles + subtile_idx) % self.num_ab12_stage + ab12_buffer1 = (num_prev_subtiles + subtile_idx + 1) % self.num_ab12_stage c_buffer = (num_prev_subtiles + subtile_idx // 2) % self.num_c_stage cute.copy( @@ -1284,18 +1169,12 @@ def kernel( if warp_idx == self.epilog_warp_id[0]: cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) epilog_threads = 32 * len(self.epilog_warp_id) - cute.arch.barrier( - barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads - ) + cute.arch.barrier(barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads) if warp_idx == self.epilog_warp_id[0]: if use_2cta_instrs: - cute.arch.mbarrier_arrive( - tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 - ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1) cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) - cute.arch.dealloc_tmem( - tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs - ) + cute.arch.dealloc_tmem(tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs) # # Wait for C store complete # @@ -1350,27 +1229,19 @@ def epilog_tmem_copy_and_partition( epi_tile, ) # (EPI_TILE_M, EPI_TILE_N) - tiled_copy_t2r = tcgen05.make_tmem_copy( - copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] - ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - gAB12_mnl_epi = cute.flat_divide( - gAB12_mnl[((None, None), 0, 0, None, None, None)], epi_tile - ) + gAB12_mnl_epi = cute.flat_divide(gAB12_mnl[((None, None), 0, 0, None, None, None)], epi_tile) # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gAB12 = thr_copy_t2r.partition_D(gAB12_mnl_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc = cute.make_rmem_tensor( - tTR_gAB12[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype - ) - tTR_rAcc1 = cute.make_rmem_tensor( - tTR_gAB12[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype - ) + tTR_rAcc = cute.make_rmem_tensor(tTR_gAB12[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) + tTR_rAcc1 = cute.make_rmem_tensor(tTR_gAB12[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc, tTR_rAcc1 def epilog_smem_copy_and_partition( @@ -1382,9 +1253,7 @@ def epilog_smem_copy_and_partition( tidx: cutlass.Int32, sAB12: cute.Tensor, sC: cute.Tensor, - ) -> Tuple[ - cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor - ]: + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor]: """ Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). @@ -1411,9 +1280,7 @@ def epilog_smem_copy_and_partition( - tRS_sC: The partitioned tensor C (smem destination) :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor] """ - copy_atom_r2s = sm100_utils.get_smem_store_op( - self.ab12_layout, self.ab12_dtype, self.acc_dtype, tiled_copy_t2r - ) + copy_atom_r2s = sm100_utils.get_smem_store_op(self.ab12_layout, self.ab12_dtype, self.acc_dtype, tiled_copy_t2r) tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) # (R2S, R2S_M, R2S_N, PIPE_D) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) @@ -1436,9 +1303,7 @@ def epilog_gmem_copy_and_partition( epi_tile_c: cute.Tile, sAB12: cute.Tensor, sC: cute.Tensor, - ) -> Tuple[ - cute.CopyAtom, cute.CopyAtom, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor - ]: + ) -> Tuple[cute.CopyAtom, cute.CopyAtom, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor]: """Make tiledCopy for global memory store, then use it to: - partition register array (source) and global memory (destination) for none TMA store version; - partition shared memory (source) and global memory (destination) for TMA store version. @@ -1472,12 +1337,8 @@ def epilog_gmem_copy_and_partition( :rtype: Tuple[cute.CopyAtom, cute.CopyAtom, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor] """ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - gAB12_epi = cute.flat_divide( - gAB12_mnl[((None, None), 0, 0, None, None, None)], epi_tile - ) - gC_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile_c - ) + gAB12_epi = cute.flat_divide(gAB12_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile_c) tma_atom_ab12 = atom1 tma_atom_c = atom2 sAB12_for_tma_partition = cute.group_modes(sAB12, 0, 2) @@ -1580,13 +1441,9 @@ def _compute_stages( epi_tile_c, 1, ) - ab_bytes_per_stage = cute.size_in_bytes( - a_dtype, a_smem_layout_stage_one - ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + ab_bytes_per_stage = cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) mbar_helpers_bytes = 1024 - ab12_bytes_per_stage = cute.size_in_bytes( - ab12_dtype, ab12_smem_layout_staged_one - ) + ab12_bytes_per_stage = cute.size_in_bytes(ab12_dtype, ab12_smem_layout_staged_one) ab12_bytes = ab12_bytes_per_stage * num_ab12_stage c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) c_bytes = c_bytes_per_stage * num_c_stage @@ -1595,9 +1452,7 @@ def _compute_stages( # Start with total smem per CTA (capacity / occupancy) # Subtract reserved bytes and initial AB12/C stages bytes # Divide remaining by bytes needed per A/B stage - num_ab_stage = ( - smem_capacity // occupancy - (mbar_helpers_bytes + ab12_bytes + c_bytes) - ) // ab_bytes_per_stage + num_ab_stage = (smem_capacity // occupancy - (mbar_helpers_bytes + ab12_bytes + c_bytes)) // ab_bytes_per_stage # Refine epilogue stages: # Calculate remaining smem after allocating for A/B stages and reserved bytes @@ -1610,12 +1465,8 @@ def _compute_stages( # Assert: Check total shared memory usage doesn't exceed capacity total_ab_smem = occupancy * ab_bytes_per_stage * num_ab_stage - total_output_smem = occupancy * ( - ab12_bytes_per_stage * num_ab12_stage + c_bytes_per_stage * num_c_stage - ) - total_smem_used = ( - total_ab_smem + total_output_smem + occupancy * mbar_helpers_bytes - ) + total_output_smem = occupancy * (ab12_bytes_per_stage * num_ab12_stage + c_bytes_per_stage * num_c_stage) + total_smem_used = total_ab_smem + total_output_smem + occupancy * mbar_helpers_bytes return num_acc_stage, num_ab_stage, num_ab12_stage, num_c_stage @@ -1647,12 +1498,8 @@ def _compute_grid( num_ctas_mnl = gab12[(0, (None, None, None))].shape cluster_shape_mnl = (*cluster_shape_mn, 1) - tile_sched_params = utils.PersistentTileSchedulerParams( - num_ctas_mnl, cluster_shape_mnl - ) - grid = utils.StaticPersistentTileScheduler.get_grid_shape( - tile_sched_params, max_active_clusters - ) + tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl) + grid = utils.StaticPersistentTileScheduler.get_grid_shape(tile_sched_params, max_active_clusters) return tile_sched_params, grid @@ -1721,15 +1568,9 @@ def __call__( stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x / (1 + math.exp(-x, True)), ): - a_cute = cute.make_tensor( - a_ptr, layout=cute.make_ordered_layout(a_shape, order=a_order) - ) - b_cute = cute.make_tensor( - b_ptr, layout=cute.make_ordered_layout(b_shape, order=b_order) - ) - ab12_cute = cute.make_tensor( - ab12_ptr, layout=cute.make_ordered_layout(ab12_shape, order=ab12_order) - ) + a_cute = cute.make_tensor(a_ptr, layout=cute.make_ordered_layout(a_shape, order=a_order)) + b_cute = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout(b_shape, order=b_order)) + ab12_cute = cute.make_tensor(ab12_ptr, layout=cute.make_ordered_layout(ab12_shape, order=ab12_order)) self.kernel( a_cute, b_cute, diff --git a/python/cudnn/grouped_gemm/__init__.py b/python/cudnn/grouped_gemm/__init__.py new file mode 100644 index 00000000..bba4367e --- /dev/null +++ b/python/cudnn/grouped_gemm/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from .grouped_gemm_swiglu.api import ( + GroupedGemmSwigluSm100, + grouped_gemm_swiglu_wrapper_sm100, +) + +__all__ = [ + "GroupedGemmSwigluSm100", + "grouped_gemm_swiglu_wrapper_sm100", +] diff --git a/python/cudnn/grouped_gemm/grouped_gemm_swiglu/__init__.py b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/__init__.py new file mode 100644 index 00000000..36c52656 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +""" +Grouped GEMM SwiGLU Kernel Module + +This module provides the forward grouped GEMM with SwiGLU activation +for MoE (Mixture of Experts) workloads on SM100+ GPUs. +""" + +from .api import ( + GroupedGemmSwigluSm100, + grouped_gemm_swiglu_wrapper_sm100, +) + +__all__ = [ + "GroupedGemmSwigluSm100", + "grouped_gemm_swiglu_wrapper_sm100", +] diff --git a/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py new file mode 100644 index 00000000..eb4aec79 --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py @@ -0,0 +1,999 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +API for Grouped GEMM SwiGLU Forward Kernel (SM100+) + +This module provides the API class for contiguous grouped block-scaled GEMM +with SwiGLU activation for MoE (Mixture of Experts) workloads. +""" + +from .grouped_gemm_swiglu_quant import ( + BlockScaledContiguousGroupedGemmKernel, + BlockScaledContiguousGroupedGemmKernelNoDlpack, +) +from cuda.bindings import driver as cuda +import torch +from typing import Tuple, Optional + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack, make_ptr +from packaging import version + +from cudnn.datatypes import _convert_to_cutlass_data_type +from cudnn.api_base import APIBase, TupleDict, ceil_div, is_power_of_2 + + +class GroupedGemmSwigluSm100(APIBase): + """API class for Grouped GEMM SwiGLU forward operation on SM100+ GPUs. + + This kernel performs contiguous grouped block-scaled GEMM with SwiGLU activation, + designed for MoE (Mixture of Experts) workloads. + + Key features: + - Supports variable M per group (aligned to cta_tile_m) + - Contiguous memory layout for A and D tensors + - Block-scaled quantization support (MXF8, MXF4, NVF4) + + Example: + >>> api = GroupedGemmSwigluSm100( + ... sample_a=a_tensor, + ... ... + ... ) + >>> api.check_support() + >>> api.compile() + >>> api.execute(..., stream) + """ + + def __init__( + self, + sample_a: torch.Tensor, + sample_b: torch.Tensor, + sample_c: torch.Tensor, + sample_d: torch.Tensor, + sample_sfa: torch.Tensor, + sample_sfb: torch.Tensor, + sample_tile_idx_to_expert_idx: torch.Tensor, + sample_num_non_exiting_tiles: torch.Tensor, + sample_alpha: torch.Tensor, + # Required quantization output (column-quantized D tensor) + sample_d_col: torch.Tensor, + # Optional quantization output arguments + sample_sfd_row: Optional[torch.Tensor] = None, + sample_sfd_col: Optional[torch.Tensor] = None, + sample_amax: Optional[torch.Tensor] = None, + sample_norm_const: Optional[torch.Tensor] = None, + sample_prob: Optional[torch.Tensor] = None, + sample_m_split_cumsum: Optional[torch.Tensor] = None, + # Configuration + acc_dtype: torch.dtype = torch.float32, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + m_aligned: int = 256, + discrete_col_sfd: bool = False, + ): + """Initialize the GroupedGemmSwigluSm100 API. + + :param sample_a: Sample A tensor (valid_m, k, 1) + :param sample_b: Sample B tensor (n, k, l) where l = num_groups + :param sample_c: Sample C tensor for intermediate storage + :param sample_d: Sample D output tensor (valid_m, n/2, 1) after SwiGLU + :param sample_sfa: Sample scale factor A tensor + :param sample_sfb: Sample scale factor B tensor + :param sample_tile_idx_to_expert_idx: Mapping from tile index to expert/group index + :param sample_num_non_exiting_tiles: Number of valid tiles + :param sample_alpha: Per-group alpha scaling factors + :param sample_d_col: Column-quantized D tensor (required for quant kernel) + :param sample_sfd_row: Optional row scale factor for D + :param sample_sfd_col: Optional column scale factor for D + :param sample_amax: Optional amax tensor for quantization + :param sample_norm_const: Optional normalization constant + :param sample_prob: Optional probability tensor for gating + :param sample_m_split_cumsum: Optional m split cumulative sum tensor. Required when discrete_col_sfd is True. + :param acc_dtype: Accumulator data type + :param mma_tiler_mn: MMA tiler shape (M, N) + :param cluster_shape_mn: Cluster shape (M, N) + :param sf_vec_size: Scale factor vector size + :param vector_f32: Use vectorized f32 operations + :param m_aligned: Alignment for group M dimension + :param discrete_col_sfd: Boolean, True to generate discrete col-major scale factor tensor. Only applies when already output scale factor tensors are provided. + """ + super().__init__() + + self._logger.warning("GroupedGemmSwigluSm100 is an experimental API") + self._logger.debug("Entering __init__") + + # Store sample tensors + self.sample_a = sample_a + self.sample_b = sample_b + self.sample_c = sample_c + self.sample_d = sample_d + self.sample_sfa = sample_sfa + self.sample_sfb = sample_sfb + self.sample_tile_idx_to_expert_idx = sample_tile_idx_to_expert_idx + self.sample_num_non_exiting_tiles = sample_num_non_exiting_tiles + self.sample_alpha = sample_alpha + + # Optional quantization outputs + self.sample_d_col = sample_d_col + self.sample_sfd_row = sample_sfd_row + self.sample_sfd_col = sample_sfd_col + self.sample_amax = sample_amax + self.sample_norm_const = self._unpad_tensor_to_ndim(sample_norm_const, 1, "norm_const") + self.sample_prob = sample_prob + self.sample_m_split_cumsum = sample_m_split_cumsum + + # Configuration + self.acc_dtype = acc_dtype + self.mma_tiler_mn = mma_tiler_mn + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + if cluster_shape_mn is None: + self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1) + else: + self.cluster_shape_mn = cluster_shape_mn + self.sf_vec_size = sf_vec_size + self.vector_f32 = vector_f32 + self.m_aligned = m_aligned + self.discrete_col_sfd = discrete_col_sfd + + # Determine kernel variant based on sample tensor dtypes + # NoDlpack kernels are required for: + # - FP4 dtypes (any of ab_dtype, c_dtype, d_dtype) + # - FP8 dtypes on PyTorch < 2.10.0 + ab_dtype = self.sample_a.dtype + c_dtype = self.sample_c.dtype + d_dtype = self.sample_d.dtype + torch_version = version.parse(torch.__version__) + is_ab_fp4 = self._is_fp4x2(ab_dtype) + is_c_fp4 = self._is_fp4x2(c_dtype) + is_d_fp4 = self._is_fp4x2(d_dtype) + is_ab_fp8 = self._is_fp8(ab_dtype) + is_c_fp8 = self._is_fp8(c_dtype) + is_d_fp8 = self._is_fp8(d_dtype) + _fp8_dlpack_supported = version.parse(torch_version.base_version) >= version.parse("2.10.0") + use_no_dlpack_kernel = is_ab_fp4 or is_c_fp4 or is_d_fp4 or ((is_ab_fp8 or is_c_fp8 or is_d_fp8) and not _fp8_dlpack_supported) + + if use_no_dlpack_kernel: + self._logger.debug("Using NoDlpack kernel due to FP4 dtype or FP8 dtype on incompatible torch version") + self._kernel = BlockScaledContiguousGroupedGemmKernelNoDlpack + else: + self._kernel = BlockScaledContiguousGroupedGemmKernel + + self._interpret_uint8_as_fp4x2 = True + self._logger.debug(f"__init__ completed") + + def check_support(self) -> bool: + """Check if the kernel configuration is supported. + + :return: True if supported, raises exception otherwise + """ + self._logger.debug("Entering check_support") + + all_none = all(x is None for x in [self.sample_sfd_row, self.sample_sfd_col, self.sample_norm_const]) + none_none = all(x is not None for x in [self.sample_sfd_row, self.sample_sfd_col, self.sample_norm_const]) + self._value_error_if( + not (all_none or none_none), + "sample_sfd_row, sample_sfd_col, and norm_const must be all None or all not None", + ) + self.generate_sfd = none_none + if self.discrete_col_sfd and not self.generate_sfd: + self._logger.warning("discrete_col_sfd is True but generate_sfd is False, discrete_col_sfd will be ignored") + self.discrete_col_sfd = False + self._value_error_if(self.discrete_col_sfd and self.sample_m_split_cumsum is None, "sample_m_split_cumsum is required when discrete_col_sfd is True") + + self._logger.debug("Checking tensor shapes and strides") + tensor_m, k, _one = self._tensor_shape(self.sample_a, name="sample_a") + n, _, l = self._tensor_shape(self.sample_b, name="sample_b") + _, _, _one = self._tensor_shape(self.sample_c, name="sample_c") + _, n_2, _one = self._tensor_shape(self.sample_d, name="sample_d") + + self._check_tensor_shape(self.sample_a, (tensor_m, k, 1), "A") + self._check_tensor_shape(self.sample_b, (n, k, l), "B") + self._check_tensor_shape(self.sample_c, (tensor_m, n, 1), "C") + self._check_tensor_shape(self.sample_d, (tensor_m, n // 2, 1), "D") + + self._check_tensor_shape(self.sample_d_col, (tensor_m, n // 2, 1), "D_col") + + rest_k = ceil_div(ceil_div(k, self.sf_vec_size), 4) + self._check_tensor_shape(self.sample_sfa, (32, 4, ceil_div(tensor_m, 128), 4, rest_k, 1), "SFA") + self._check_tensor_shape(self.sample_sfb, (32, 4, ceil_div(n, 128), 4, rest_k, l), "SFB") + rest_n2 = ceil_div(ceil_div(n // 2, self.sf_vec_size), 4) + self._check_tensor_shape(self.sample_sfd_row, (32, 4, ceil_div(tensor_m, 128), 4, rest_n2, 1), "SFD_row") + rest_m = ceil_div(ceil_div(tensor_m, self.sf_vec_size), 4) + self._check_tensor_shape(self.sample_sfd_col, (32, 4, ceil_div(n // 2, 128), 4, rest_m, 1), "SFD_col") + + self._check_tensor_shape(self.sample_alpha, (l,), "alpha") + self._check_tensor_shape(self.sample_prob, (tensor_m, 1, 1), "prob") + self._check_tensor_shape(self.sample_amax, (l, 1), "amax") + self._check_tensor_shape(self.sample_num_non_exiting_tiles, (1,), "num_non_exiting_tiles") + self._check_tensor_shape(self.sample_norm_const, (1,), "norm_const") + self._check_tensor_shape(self.sample_m_split_cumsum, (l + 1,), "m_split_cumsum") + + _, self.a_stride_order = self._check_tensor_stride(self.sample_a, stride=[(k, 1, tensor_m * k)], extra_error_msg="A must have k-major layout") + _, self.b_stride_order = self._check_tensor_stride(self.sample_b, stride=[(k, 1, n * k)], extra_error_msg="B must have k-major layout") + _, self.c_stride_order = self._check_tensor_stride(self.sample_c, stride=[(n, 1, tensor_m * n)], extra_error_msg="C must have n-major layout") + _, self.d_stride_order = self._check_tensor_stride(self.sample_d, stride=[(n_2, 1, tensor_m * n_2)], extra_error_msg="D must have n-major layout") + _, self.d_col_stride_order = self._check_tensor_stride( + self.sample_d_col, stride=[(n_2, 1, tensor_m * n_2)], extra_error_msg="D_col must have n-major layout" + ) + self.cd_stride_order = self.c_stride_order + + self._logger.debug("Checking data types") + self.ab_dtype = self._check_dtype( + self.sample_a, + dtype=[ + torch.float4_e2m1fn_x2, + torch.uint8, + torch.float8_e5m2, + torch.float8_e4m3fn, + ], + name="A/B", + ) + self._check_dtype(self.sample_b, dtype=self.ab_dtype, name="B", extra_error_msg="B must have the same dtype as A") + + self.sf_dtype = self._check_dtype( + self.sample_sfa, + dtype=[torch.float8_e8m0fnu, torch.float8_e4m3fn], + name="SFA/SFB/SFD_row/SFD_col", + ) + self._check_dtype(self.sample_sfb, dtype=self.sf_dtype, name="SFB", extra_error_msg="SFB must have the same dtype as SFA") + self._check_dtype(self.sample_sfd_row, dtype=self.sf_dtype, name="SFD_row", extra_error_msg="SFD_row must have the same dtype as SFA") + self._check_dtype(self.sample_sfd_col, dtype=self.sf_dtype, name="SFD_col", extra_error_msg="SFD_col must have the same dtype as SFA") + + self._value_error_if(self.sf_vec_size not in [16, 32], f"sf_vec_size must be 16 or 32, got {self.sf_vec_size}") + self._value_error_if( + self.sf_dtype in [torch.float8_e4m3fn] and self.sf_vec_size == 32, + f"sf_dtype {self.sf_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported", + ) + self._value_error_if( + self._is_fp8(self.ab_dtype) and self.sf_vec_size == 16, f"ab_dtype {self.ab_dtype} and sf_vec_size {self.sf_vec_size} combination is not supported" + ) + + self._check_dtype(self.acc_dtype, dtype=torch.float32, name="Accumulator", extra_error_msg="Accumulator must be float32") + self.c_dtype = self._check_dtype( + self.sample_c, + dtype=[torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2, torch.float4_e2m1fn_x2], + name="C", + extra_error_msg="C must have the same dtype as A", + ) + + if self._is_fp4x2(self.ab_dtype): + self.d_dtype = self._check_dtype( + self.sample_d, dtype=[torch.bfloat16, torch.float32], name="D", extra_error_msg="D must be bf16 or float32 when ab_dtype is fp4" + ) + else: + self.d_dtype = self._check_dtype( + self.sample_d, + dtype=[ + torch.float16, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float4_e2m1fn_x2, + ], # torch.float32 fails non-deterministicly + name="D", + ) + self._check_dtype(self.sample_d_col, dtype=self.d_dtype, name="D_col", extra_error_msg="D_col must have the same dtype as D") + + self._not_implemented_error_if( + self._is_fp4x2(self.ab_dtype) and self.sf_vec_size == 16 and self.d_dtype == torch.float32, # Fails to compile + f"Invalid configuration: fp4 ab_dtype, sf_vec_size 16, d_dtype float32 is not supported. Please use sf_vec_size 32 or d_dtype bf16 instead", + ) + + self._logger.debug("Checking MMA tile shape and cluster shape") + self._value_error_if( + not self.use_2cta_instrs and self.mma_tiler_mn[0] not in [64, 128], + f"MMA tiler M must be 64 or 128 when use_2cta_instrs=False, got {self.mma_tiler_mn[0]}", + ) + self._value_error_if( + self.use_2cta_instrs and self.mma_tiler_mn[0] not in [128, 256], + f"MMA tiler M must be 128 or 256 when use_2cta_instrs=True, got {self.mma_tiler_mn[0]}", + ) + self._value_error_if(self.mma_tiler_mn[1] not in [128, 256], f"MMA tiler N must be 128 or 256, got {self.mma_tiler_mn[1]}") + self._value_error_if( + self.cluster_shape_mn[0] % (2 if self.use_2cta_instrs else 1) != 0, + f"cluster_shape_mn[0] must be divisible by 2 when use_2cta_instrs=True, got {self.cluster_shape_mn[0]}", + ) + self._value_error_if( + not ( + self.cluster_shape_mn[0] * self.cluster_shape_mn[1] <= 16 + and self.cluster_shape_mn[0] > 0 + and self.cluster_shape_mn[1] > 0 + and self.cluster_shape_mn[0] <= 4 + and self.cluster_shape_mn[1] <= 4 + and is_power_of_2(self.cluster_shape_mn[0]) + and is_power_of_2(self.cluster_shape_mn[1]) + ), + f"Invalid cluster shape: expected values to be powers of 2 and cluster_shape_mn[0] * cluster_shape_mn[1] <= 16, got {self.cluster_shape_mn[0]},{self.cluster_shape_mn[1]}", + ) + cluster_tiler_m = (self.cluster_shape_mn[0] // (2 if self.use_2cta_instrs else 1)) * self.mma_tiler_mn[0] + # Skip invalid cluster tiler shape since contiguous layout can't handle oob access + # The contiguous layout means the aligned data is stored in a contiguous manner. + # It can't handle runtime oob when alignment is not align with the tile_M, + # since the problem shape of TMA store can't be changed at runtime. + self._value_error_if(cluster_tiler_m not in [128, 256], f"Invalid cluster tiler shape: expected cluster_tiler_m in {{128, 256}}, got {cluster_tiler_m}") + # Check if m_aligned is a multiple of cluster_tiler_m + # This ensures that each group's M dimension (which is a multiple of m_aligned) + # won't be split across tiles, preventing a single tile from loading data + # from multiple groups (which would access wrong B matrix data) + self._value_error_if( + self.m_aligned % self.mma_tiler_mn[0] != 0, + f"Invalid m_aligned: expected m_aligned to be divisible by mma_tiler_mn[0], got {self.m_aligned} % {self.mma_tiler_mn[0]} != 0", + ) + + self._logger.debug("Checking tensor alignment") + + def check_contigous_16B_alignment(dtype, stride_order, tensor_shape): + is_mode0_major = stride_order == (0, 1, 2) + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // (_convert_to_cutlass_data_type(dtype, interpret_uint8_as_fp4x2=self._interpret_uint8_as_fp4x2).width) + return num_major_elements % num_contiguous_elements == 0 + + self._value_error_if( + not ( + check_contigous_16B_alignment(self.ab_dtype, self.a_stride_order, (tensor_m, k, l)) + and check_contigous_16B_alignment(self.ab_dtype, self.b_stride_order, (n, k, l)) + and check_contigous_16B_alignment(self.d_dtype, self.cd_stride_order, (tensor_m, n, l)) + ), + "Invalid tensor alignment: tensors must be 16B aligned", + ) + + # Disabled configurations + self._not_implemented_error_if( + (self._is_fp8(self.ab_dtype)) and (self.mma_tiler_mn[1] == 128) and (self._is_fp8(self.d_dtype)), + f"Invalid configuration: fp8 ab_dtype and sf_vec_size 32 with mma_tiler_mn[1] == 128 and fp8 d_dtype is not supported" + + f"Please use mma_tiler_mn[1] == 256 instead", + ) + self._not_implemented_error_if( + self._is_fp4x2(self.ab_dtype) and (self.c_dtype not in [torch.float16, torch.bfloat16]), + f"Invalid configuration: for fp4 ab_dtype, c_dtype must be float16 or bfloat16, got {self.c_dtype}", + ) + + # Check environment + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + major, minor = torch.cuda.get_device_capability(device) + compute_capability = major * 10 + minor + if compute_capability < 100: + raise RuntimeError(f"GroupedGemmSwiglu requires SM100+ compute capability, " f"but found SM{compute_capability} on device {device}") + + self._is_supported = True + self._logger.debug("check_support completed successfully") + return True + + def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: + """Compile the kernel. + + :param current_stream: CUDA stream to use + """ + self._logger.debug("Entering compile") + current_stream = self._get_default_stream(current_stream) + self._ensure_support_checked() + + gemm_swiglu = self._kernel( + sf_vec_size=self.sf_vec_size, + acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype), + use_2cta_instrs=self.use_2cta_instrs, + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + vector_f32=self.vector_f32, + generate_sfd=self.generate_sfd, + discrete_col_sfd=self.discrete_col_sfd, + ) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) + + if self._kernel is BlockScaledContiguousGroupedGemmKernel: + self._logger.debug("Compiling grouped_gemm_swiglu kernel (dlpack)") + self._compiled_kernel = cute.compile( + gemm_swiglu, + a=from_dlpack(self.sample_a, assumed_align=16), + b=from_dlpack(self.sample_b, assumed_align=16), + c=from_dlpack(self.sample_c, assumed_align=16), + d=from_dlpack(self.sample_d, assumed_align=16), + d_col=from_dlpack(self.sample_d_col, assumed_align=16) if self.sample_d_col is not None else None, + sfa=from_dlpack(self.sample_sfa, assumed_align=16), + sfb=from_dlpack(self.sample_sfb, assumed_align=16), + sfd_row_tensor=from_dlpack(self.sample_sfd_row, assumed_align=16) if self.sample_sfd_row is not None else None, + sfd_col_tensor=from_dlpack(self.sample_sfd_col, assumed_align=16) if self.sample_sfd_col is not None else None, + amax_tensor=from_dlpack(self.sample_amax, assumed_align=16) if self.sample_amax is not None else None, + norm_const_tensor=from_dlpack(self.sample_norm_const) if self.sample_norm_const is not None else None, + tile_idx_to_expert_idx=from_dlpack(self.sample_tile_idx_to_expert_idx, assumed_align=16), + num_non_exiting_tiles=from_dlpack(self.sample_num_non_exiting_tiles, assumed_align=16), + m_split_cumsum=from_dlpack(self.sample_m_split_cumsum, assumed_align=16) if self.sample_m_split_cumsum is not None else None, + alpha=from_dlpack(self.sample_alpha, assumed_align=16), + prob=from_dlpack(self.sample_prob, assumed_align=16) if self.sample_prob is not None else None, + max_active_clusters=max_active_clusters, + stream=current_stream, + ) + elif self._kernel is BlockScaledContiguousGroupedGemmKernelNoDlpack: + self._logger.debug("Compiling grouped_gemm_swiglu kernel (no_dlpack)") + # Create cute pointers/tensors manually to avoid DLPack requirements + a_ptr, a_shape, a_order = self._make_cute_tensor_descriptor(self.sample_a, name="A") + b_ptr, b_shape, b_order = self._make_cute_tensor_descriptor(self.sample_b, name="B") + c_ptr, c_shape, c_order = self._make_cute_tensor_descriptor(self.sample_c, name="C") + d_ptr, d_shape, d_order = self._make_cute_tensor_descriptor(self.sample_d, name="D") + d_col_ptr, d_col_shape, d_col_order = self._make_cute_tensor_descriptor(self.sample_d_col, name="D_col") + sfa_ptr, sfa_shape, sfa_order = self._make_cute_tensor_descriptor(self.sample_sfa, name="SFA") + sfb_ptr, sfb_shape, sfb_order = self._make_cute_tensor_descriptor(self.sample_sfb, name="SFB") + sfd_row_ptr, sfd_row_shape, sfd_row_order = self._make_cute_tensor_descriptor(self.sample_sfd_row, name="SFD_row") + sfd_col_ptr, sfd_col_shape, sfd_col_order = self._make_cute_tensor_descriptor(self.sample_sfd_col, name="SFD_col") + amax_ptr, amax_shape, amax_order = self._make_cute_tensor_descriptor(self.sample_amax, name="amax") + norm_const_ptr, norm_const_shape, norm_const_order = self._make_cute_tensor_descriptor(self.sample_norm_const, name="norm_const") + tile_idx_ptr, tile_idx_shape, tile_idx_order = self._make_cute_tensor_descriptor(self.sample_tile_idx_to_expert_idx, name="tile_idx") + num_tiles_ptr, num_tiles_shape, num_tiles_order = self._make_cute_tensor_descriptor(self.sample_num_non_exiting_tiles, name="num_tiles") + m_split_cumsum_ptr, m_split_cumsum_shape, m_split_cumsum_order = self._make_cute_tensor_descriptor( + self.sample_m_split_cumsum, name="m_split_cumsum" + ) + alpha_ptr, alpha_shape, alpha_order = self._make_cute_tensor_descriptor(self.sample_alpha, name="alpha") + prob_ptr, prob_shape, prob_order = self._make_cute_tensor_descriptor(self.sample_prob, name="prob") + + self._compiled_kernel = cute.compile( + gemm_swiglu, + a_ptr=a_ptr, + a_shape=a_shape, + a_order=a_order, + b_ptr=b_ptr, + b_shape=b_shape, + b_order=b_order, + c_ptr=c_ptr, + c_shape=c_shape, + c_order=c_order, + d_ptr=d_ptr, + d_shape=d_shape, + d_order=d_order, + d_col_ptr=d_col_ptr, + d_col_shape=d_col_shape, + d_col_order=d_col_order, + sfa_ptr=sfa_ptr, + sfa_shape=sfa_shape, + sfa_order=sfa_order, + sfb_ptr=sfb_ptr, + sfb_shape=sfb_shape, + sfb_order=sfb_order, + sfd_row_ptr=sfd_row_ptr, + sfd_row_shape=sfd_row_shape, + sfd_row_order=sfd_row_order, + sfd_col_ptr=sfd_col_ptr, + sfd_col_shape=sfd_col_shape, + sfd_col_order=sfd_col_order, + amax_ptr=amax_ptr, + amax_shape=amax_shape, + amax_order=amax_order, + norm_const_ptr=norm_const_ptr, + norm_const_shape=norm_const_shape, + norm_const_order=norm_const_order, + tile_idx_to_expert_idx_ptr=tile_idx_ptr, + tile_idx_to_expert_idx_shape=tile_idx_shape, + tile_idx_to_expert_idx_order=tile_idx_order, + num_non_exiting_tiles_ptr=num_tiles_ptr, + num_non_exiting_tiles_shape=num_tiles_shape, + num_non_exiting_tiles_order=num_tiles_order, + m_split_cumsum_ptr=m_split_cumsum_ptr, + m_split_cumsum_shape=m_split_cumsum_shape, + m_split_cumsum_order=m_split_cumsum_order, + alpha_ptr=alpha_ptr, + alpha_shape=alpha_shape, + alpha_order=alpha_order, + prob_ptr=prob_ptr, + prob_shape=prob_shape, + prob_order=prob_order, + max_active_clusters=max_active_clusters, + stream=current_stream, + ) + else: + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") + + self._logger.debug("Kernel compiled successfully") + + def execute( + self, + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + c_tensor: torch.Tensor, + d_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + tile_idx_to_expert_idx: torch.Tensor, + num_non_exiting_tiles: torch.Tensor, + alpha_tensor: torch.Tensor, + d_col_tensor: Optional[torch.Tensor] = None, + sfd_row_tensor: Optional[torch.Tensor] = None, + sfd_col_tensor: Optional[torch.Tensor] = None, + amax_tensor: Optional[torch.Tensor] = None, + norm_const_tensor: Optional[torch.Tensor] = None, + prob_tensor: Optional[torch.Tensor] = None, + m_split_cumsum: Optional[torch.Tensor] = None, + current_stream: Optional[cuda.CUstream] = None, + skip_compile: bool = False, + ) -> None: + """Execute the compiled kernel. + + :param a_tensor: Input A tensor + :param b_tensor: Input B tensor (weights) + :param c_tensor: Intermediate C tensor + :param d_tensor: Output D tensor + :param sfa_tensor: Scale factor A + :param sfb_tensor: Scale factor B + :param tile_idx_to_expert_idx: Tile to expert mapping + :param num_non_exiting_tiles: Number of valid tiles + :param alpha_tensor: Per-group scaling factors + :param d_col_tensor: Optional column-quantized output + :param sfd_row_tensor: Optional row scale factor D + :param sfd_col_tensor: Optional column scale factor D + :param amax_tensor: Optional amax tensor + :param norm_const_tensor: Optional normalization constant + :param prob_tensor: Optional probability tensor + :param m_split_cumsum: Optional m split cumulative sum tensor + :param current_stream: CUDA stream + :param skip_compile: If True, use JIT execution without prior compilation + """ + self._logger.debug("Entering execute") + current_stream = self._get_default_stream(current_stream) + + norm_const_tensor = self._unpad_tensor_to_ndim(norm_const_tensor, 1, "norm_const") + + if not skip_compile: + self._runtime_error_if( + self._compiled_kernel is None, + "Kernel not compiled; call compile() first or use skip_compile=True", + ) + + if self._kernel is BlockScaledContiguousGroupedGemmKernel: + self._logger.debug("Executing grouped_gemm_swiglu kernel (dlpack)") + self._compiled_kernel( + a=from_dlpack(a_tensor, assumed_align=16), + b=from_dlpack(b_tensor, assumed_align=16), + c=from_dlpack(c_tensor, assumed_align=16), + d=from_dlpack(d_tensor, assumed_align=16), + d_col=from_dlpack(d_col_tensor, assumed_align=16) if d_col_tensor is not None else None, + sfa=from_dlpack(sfa_tensor, assumed_align=16), + sfb=from_dlpack(sfb_tensor, assumed_align=16), + sfd_row_tensor=from_dlpack(sfd_row_tensor, assumed_align=16) if sfd_row_tensor is not None else None, + sfd_col_tensor=from_dlpack(sfd_col_tensor, assumed_align=16) if sfd_col_tensor is not None else None, + amax_tensor=from_dlpack(amax_tensor, assumed_align=16) if amax_tensor is not None else None, + norm_const_tensor=from_dlpack(norm_const_tensor, assumed_align=16) if norm_const_tensor is not None else None, + tile_idx_to_expert_idx=from_dlpack(tile_idx_to_expert_idx, assumed_align=16), + num_non_exiting_tiles=from_dlpack(num_non_exiting_tiles, assumed_align=16), + m_split_cumsum=from_dlpack(m_split_cumsum, assumed_align=16) if m_split_cumsum is not None else None, + alpha=from_dlpack(alpha_tensor, assumed_align=16), + prob=from_dlpack(prob_tensor, assumed_align=16) if prob_tensor is not None else None, + stream=current_stream, + ) + elif self._kernel is BlockScaledContiguousGroupedGemmKernelNoDlpack: + self._logger.debug("Executing grouped_gemm_swiglu kernel (no_dlpack)") + # Create cute pointers manually to avoid DLPack requirements + a_ptr = self._make_cute_pointer(a_tensor, assumed_align=16) + b_ptr = self._make_cute_pointer(b_tensor, assumed_align=16) + c_ptr = self._make_cute_pointer(c_tensor, assumed_align=16) + d_ptr = self._make_cute_pointer(d_tensor, assumed_align=16) + d_col_ptr = self._make_cute_pointer(d_col_tensor, assumed_align=16) + sfa_ptr = self._make_cute_pointer(sfa_tensor, assumed_align=16) + sfb_ptr = self._make_cute_pointer(sfb_tensor, assumed_align=16) + sfd_row_ptr = self._make_cute_pointer(sfd_row_tensor, assumed_align=16) + sfd_col_ptr = self._make_cute_pointer(sfd_col_tensor, assumed_align=16) + amax_ptr = self._make_cute_pointer(amax_tensor, assumed_align=16) + norm_const_ptr = self._make_cute_pointer(norm_const_tensor, assumed_align=16) + tile_idx_ptr = self._make_cute_pointer(tile_idx_to_expert_idx, assumed_align=16) + num_tiles_ptr = self._make_cute_pointer(num_non_exiting_tiles, assumed_align=16) + m_split_cumsum_ptr = self._make_cute_pointer(m_split_cumsum, assumed_align=16) + alpha_ptr = self._make_cute_pointer(alpha_tensor, assumed_align=16) + prob_ptr = self._make_cute_pointer(prob_tensor, assumed_align=16) + + self._compiled_kernel( + a_ptr=a_ptr, + b_ptr=b_ptr, + c_ptr=c_ptr, + d_ptr=d_ptr, + d_col_ptr=d_col_ptr, + sfa_ptr=sfa_ptr, + sfb_ptr=sfb_ptr, + sfd_row_ptr=sfd_row_ptr, + sfd_col_ptr=sfd_col_ptr, + amax_ptr=amax_ptr, + norm_const_ptr=norm_const_ptr, + tile_idx_to_expert_idx_ptr=tile_idx_ptr, + num_non_exiting_tiles_ptr=num_tiles_ptr, + m_split_cumsum_ptr=m_split_cumsum_ptr, + alpha_ptr=alpha_ptr, + prob_ptr=prob_ptr, + stream=current_stream, + ) + else: + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") + else: + self._logger.debug("Executing without compiled kernel (JIT)") + generate_sfd = sfd_row_tensor is not None and sfd_col_tensor is not None and norm_const_tensor is not None + discrete_col_sfd = self.discrete_col_sfd and generate_sfd + + gemm_swiglu = self._kernel( + sf_vec_size=self.sf_vec_size, + acc_dtype=_convert_to_cutlass_data_type(self.acc_dtype), + use_2cta_instrs=self.use_2cta_instrs, + mma_tiler_mn=self.mma_tiler_mn, + cluster_shape_mn=self.cluster_shape_mn, + vector_f32=self.vector_f32, + generate_sfd=generate_sfd, + discrete_col_sfd=discrete_col_sfd, + ) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(self.cluster_shape_mn[0] * self.cluster_shape_mn[1]) + + if self._kernel is BlockScaledContiguousGroupedGemmKernel: + self._logger.debug("JIT executing grouped_gemm_swiglu kernel (dlpack)") + gemm_swiglu( + a=from_dlpack(a_tensor, assumed_align=16), + b=from_dlpack(b_tensor, assumed_align=16), + c=from_dlpack(c_tensor, assumed_align=16), + d=from_dlpack(d_tensor, assumed_align=16), + d_col=from_dlpack(d_col_tensor, assumed_align=16) if d_col_tensor is not None else None, + sfa=from_dlpack(sfa_tensor, assumed_align=16), + sfb=from_dlpack(sfb_tensor, assumed_align=16), + sfd_row_tensor=from_dlpack(sfd_row_tensor, assumed_align=16) if sfd_row_tensor is not None else None, + sfd_col_tensor=from_dlpack(sfd_col_tensor, assumed_align=16) if sfd_col_tensor is not None else None, + amax_tensor=from_dlpack(amax_tensor, assumed_align=16) if amax_tensor is not None else None, + norm_const_tensor=from_dlpack(norm_const_tensor) if norm_const_tensor is not None else None, + tile_idx_to_expert_idx=from_dlpack(tile_idx_to_expert_idx, assumed_align=16), + num_non_exiting_tiles=from_dlpack(num_non_exiting_tiles, assumed_align=16), + m_split_cumsum=from_dlpack(m_split_cumsum, assumed_align=16) if self.m_split_cumsum is not None else None, + alpha=from_dlpack(alpha_tensor, assumed_align=16), + prob=from_dlpack(prob_tensor, assumed_align=16) if prob_tensor is not None else None, + max_active_clusters=max_active_clusters, + stream=current_stream, + ) + elif self._kernel is BlockScaledContiguousGroupedGemmKernelNoDlpack: + self._logger.debug("JIT executing grouped_gemm_swiglu kernel (no_dlpack)") + # Create cute tensor descriptors manually to avoid DLPack requirements + a_ptr, a_shape, a_order = self._make_cute_tensor_descriptor(a_tensor, name="A") + b_ptr, b_shape, b_order = self._make_cute_tensor_descriptor(b_tensor, name="B") + c_ptr, c_shape, c_order = self._make_cute_tensor_descriptor(c_tensor, name="C") + d_ptr, d_shape, d_order = self._make_cute_tensor_descriptor(d_tensor, name="D") + d_col_ptr, d_col_shape, d_col_order = self._make_cute_tensor_descriptor(d_col_tensor, name="D_col") + sfa_ptr, sfa_shape, sfa_order = self._make_cute_tensor_descriptor(sfa_tensor, name="SFA") + sfb_ptr, sfb_shape, sfb_order = self._make_cute_tensor_descriptor(sfb_tensor, name="SFB") + sfd_row_ptr, sfd_row_shape, sfd_row_order = self._make_cute_tensor_descriptor(sfd_row_tensor, name="SFD_row") + sfd_col_ptr, sfd_col_shape, sfd_col_order = self._make_cute_tensor_descriptor(sfd_col_tensor, name="SFD_col") + amax_ptr, amax_shape, amax_order = self._make_cute_tensor_descriptor(amax_tensor, name="amax") + norm_const_ptr, norm_const_shape, norm_const_order = self._make_cute_tensor_descriptor(norm_const_tensor, name="norm_const") + tile_idx_ptr, tile_idx_shape, tile_idx_order = self._make_cute_tensor_descriptor(tile_idx_to_expert_idx, name="tile_idx") + num_tiles_ptr, num_tiles_shape, num_tiles_order = self._make_cute_tensor_descriptor(num_non_exiting_tiles, name="num_tiles") + alpha_ptr, alpha_shape, alpha_order = self._make_cute_tensor_descriptor(alpha_tensor, name="alpha") + prob_ptr, prob_shape, prob_order = self._make_cute_tensor_descriptor(prob_tensor, name="prob") + m_split_cumsum_ptr, m_split_cumsum_shape, m_split_cumsum_order = self._make_cute_tensor_descriptor(m_split_cumsum, name="m_split_cumsum") + + gemm_swiglu( + a_ptr=a_ptr, + a_shape=a_shape, + a_order=a_order, + b_ptr=b_ptr, + b_shape=b_shape, + b_order=b_order, + c_ptr=c_ptr, + c_shape=c_shape, + c_order=c_order, + d_ptr=d_ptr, + d_shape=d_shape, + d_order=d_order, + d_col_ptr=d_col_ptr, + d_col_shape=d_col_shape, + d_col_order=d_col_order, + sfa_ptr=sfa_ptr, + sfa_shape=sfa_shape, + sfa_order=sfa_order, + sfb_ptr=sfb_ptr, + sfb_shape=sfb_shape, + sfb_order=sfb_order, + sfd_row_ptr=sfd_row_ptr, + sfd_row_shape=sfd_row_shape, + sfd_row_order=sfd_row_order, + sfd_col_ptr=sfd_col_ptr, + sfd_col_shape=sfd_col_shape, + sfd_col_order=sfd_col_order, + amax_ptr=amax_ptr, + amax_shape=amax_shape, + amax_order=amax_order, + norm_const_ptr=norm_const_ptr, + norm_const_shape=norm_const_shape, + norm_const_order=norm_const_order, + tile_idx_to_expert_idx_ptr=tile_idx_ptr, + tile_idx_to_expert_idx_shape=tile_idx_shape, + tile_idx_to_expert_idx_order=tile_idx_order, + num_non_exiting_tiles_ptr=num_tiles_ptr, + num_non_exiting_tiles_shape=num_tiles_shape, + num_non_exiting_tiles_order=num_tiles_order, + m_split_cumsum_ptr=m_split_cumsum_ptr, + m_split_cumsum_shape=m_split_cumsum_shape, + m_split_cumsum_order=m_split_cumsum_order, + alpha_ptr=alpha_ptr, + alpha_shape=alpha_shape, + alpha_order=alpha_order, + prob_ptr=prob_ptr, + prob_shape=prob_shape, + prob_order=prob_order, + max_active_clusters=max_active_clusters, + stream=current_stream, + ) + else: + raise NotImplementedError(f"Unreachable: invalid kernel type {self._kernel}") + + self._logger.debug("Execute completed") + + +import logging + +_logger = logging.getLogger(__name__) +_cache_of_GroupedGemmSwigluSm100Objects = {} + + +def grouped_gemm_swiglu_wrapper_sm100( + a_tensor: torch.Tensor, + b_tensor: torch.Tensor, + sfa_tensor: torch.Tensor, + sfb_tensor: torch.Tensor, + tile_idx_to_expert_idx: torch.Tensor, + num_non_exiting_tiles: torch.Tensor, + alpha_tensor: torch.Tensor, + norm_const_tensor: Optional[torch.Tensor] = None, + prob_tensor: Optional[torch.Tensor] = None, + m_split_cumsum: Optional[torch.Tensor] = None, + acc_dtype: torch.dtype = torch.float32, + c_dtype: torch.dtype = torch.bfloat16, + d_dtype: torch.dtype = torch.bfloat16, + cd_major: str = "n", + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Optional[Tuple[int, int]] = None, + sf_vec_size: int = 16, + vector_f32: bool = False, + m_aligned: int = 256, + discrete_col_sfd: bool = False, + current_stream: Optional[cuda.CUstream] = None, +) -> TupleDict: + """Convenience wrapper for grouped GEMM SwiGLU forward operation. + + This function creates the API, compiles, and executes in one call. + Compiled kernels are cached for reuse when called with the same configuration. + + Args: + a_tensor: Input A tensor (valid_m, k, 1) + b_tensor: Weight B tensor (n, k, l) + sfa_tensor: Scale factor A + sfb_tensor: Scale factor B + tile_idx_to_expert_idx: Tile to expert mapping + num_non_exiting_tiles: Number of valid tiles + alpha_tensor: Per-group scaling + norm_const_tensor: Optional normalization constant. Required when using FP8 + input configurations (i.e., when a_tensor.dtype is FP8 and sfa_tensor.dtype is FP8). + Should be None for FP4/BF16 input configurations. + prob_tensor: Optional probability tensor for gating + m_split_cumsum: Optional m split cumulative sum tensor. Required when discrete_col_sfd is True. + acc_dtype: Accumulator data type + c_dtype: Intermediate C tensor data type (always bfloat16) + d_dtype: Output D tensor data type (fp8 when ab is fp8, bf16 when ab is fp4) + cd_major: CD major dimension (note: only "n"-major layout is supported) + mma_tiler_mn: MMA tiler shape + cluster_shape_mn: Cluster shape + sf_vec_size: Scale factor vector size + vector_f32: Use vectorized f32 + m_aligned: M alignment + discrete_col_sfd: Boolean, True to generate discrete col-major scale factor tensor. Only applies when already output scale factor tensors are provided. + current_stream: CUDA stream + + Returns: + TupleDict: A dictionary-like object containing output tensors that can also be unpacked as a tuple. + Dictionary keys (also the unpacking order): + - **c_tensor** (torch.Tensor): Intermediate result tensor + - **d_tensor** (torch.Tensor): Final output tensor after SwiGLU + - **d_col_tensor** (torch.Tensor): Column-wise output tensor + - **amax_tensor** (torch.Tensor or None): Absolute maximum values (for quantization) + - **sfd_row_tensor** (torch.Tensor or None): Row-wise scale factors for D (FP8 only) + - **sfd_col_tensor** (torch.Tensor or None): Column-wise scale factors for D (FP8 only) + + Example usage:: + + # Dictionary-style access + result = grouped_gemm_swiglu_wrapper_sm100(...) + c = result["c_tensor"] + d = result["d_tensor"] + + # Tuple unpacking + c, d, d_col, amax, sfd_row, sfd_col = grouped_gemm_swiglu_wrapper_sm100(...) + + # Integer indexing + c = result[0] # c_tensor + """ + valid_m, k, _ = a_tensor.shape + n, _, l = b_tensor.shape + n_out = n // 2 # After SwiGLU + + _logger.debug(f"grouped_gemm_swiglu_wrapper_sm100: Creating output tensors c_tensor, d_tensor, d_col_tensor") + + if cd_major == "n": + # 1, m, n, permute (1, 2, 0) -> (m, n, 1) + c_tensor = torch.empty_strided((valid_m, n, 1), (n, 1, valid_m * n), dtype=c_dtype, device=a_tensor.device) + d_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) + d_col_tensor = torch.empty_strided((valid_m, n_out, 1), (n_out, 1, valid_m * n_out), dtype=d_dtype, device=a_tensor.device) + else: + raise ValueError(f"cd_major must be 'n', got {cd_major}") + + sfd_row_tensor = None + sfd_col_tensor = None + amax_tensor = None + + if a_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sfa_tensor.dtype in [torch.float8_e8m0fnu, torch.float8_e4m3fn]: + _logger.debug("grouped_gemm_swiglu_wrapper_sm100: Detected fp8 a_dtype and sfa_dtype, constructing sfd_row_tensor and sfd_col_tensor") + + sf_dtype = sfa_tensor.dtype + mma_permute_order = (3, 4, 1, 5, 2, 0) + + # sfd_row: l=1, mn=valid_m, k=n_out + sf_k_row = ceil_div(n_out, sf_vec_size) + mma_shape_row = ( + 1, + ceil_div(valid_m, 128), + ceil_div(sf_k_row, 4), + 32, + 4, + 4, + ) + sfd_row_tensor = torch.empty(mma_shape_row, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) + + # sfd_col: l=1, mn=n_out, k=valid_m + sf_k_col = ceil_div(valid_m, sf_vec_size) + mma_shape_col = ( + 1, + ceil_div(n_out, 128), + ceil_div(sf_k_col, 4), + 32, + 4, + 4, + ) + sfd_col_tensor = torch.empty(mma_shape_col, dtype=sf_dtype, device=a_tensor.device).permute(mma_permute_order) + + if d_dtype in [torch.bfloat16, torch.float16]: + _logger.debug("grouped_gemm_swiglu_wrapper_sm100: Detected bf16/float16 d_dtype, constructing amax_tensor") + amax_tensor = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=a_tensor.device) + + cache_key = ( + a_tensor.shape, + b_tensor.shape, + a_tensor.dtype, + b_tensor.dtype, + a_tensor.stride(), + b_tensor.stride(), + sfa_tensor.shape, + sfb_tensor.shape, + sfa_tensor.stride(), + sfb_tensor.stride(), + sfa_tensor.dtype, + sfb_tensor.dtype, + norm_const_tensor.shape if norm_const_tensor is not None else None, + norm_const_tensor.stride() if norm_const_tensor is not None else None, + norm_const_tensor.dtype if norm_const_tensor is not None else None, + m_split_cumsum.shape if m_split_cumsum is not None else None, + m_split_cumsum.stride() if m_split_cumsum is not None else None, + m_split_cumsum.dtype if m_split_cumsum is not None else None, + acc_dtype, + c_dtype, + d_dtype, + cd_major, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + vector_f32, + m_aligned, + discrete_col_sfd, + ) + + if cache_key in _cache_of_GroupedGemmSwigluSm100Objects: + _logger.debug("group_gemm_swiglu_wrapper_sm100: Using previously cached GroupedGemmSwigluSm100 object") + grouped_gemm_swiglu = _cache_of_GroupedGemmSwigluSm100Objects[cache_key] + grouped_gemm_swiglu.execute( + a_tensor=a_tensor, + b_tensor=b_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + tile_idx_to_expert_idx=tile_idx_to_expert_idx, + num_non_exiting_tiles=num_non_exiting_tiles, + alpha_tensor=alpha_tensor, + d_col_tensor=d_col_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + prob_tensor=prob_tensor, + m_split_cumsum=m_split_cumsum, + current_stream=current_stream, + ) + else: + _logger.debug("group_gemm_swiglu_wrapper_sm100: No previously cached GroupedGemmSwigluSm100 object found, creating new GroupedGemmSwigluSm100 object") + grouped_gemm_swiglu = GroupedGemmSwigluSm100( + sample_a=a_tensor, + sample_b=b_tensor, + sample_c=c_tensor, + sample_d=d_tensor, + sample_sfa=sfa_tensor, + sample_sfb=sfb_tensor, + sample_tile_idx_to_expert_idx=tile_idx_to_expert_idx, + sample_num_non_exiting_tiles=num_non_exiting_tiles, + sample_alpha=alpha_tensor, + sample_amax=amax_tensor, + sample_d_col=d_col_tensor, + sample_sfd_row=sfd_row_tensor, + sample_sfd_col=sfd_col_tensor, + sample_norm_const=norm_const_tensor, + sample_prob=prob_tensor, + sample_m_split_cumsum=m_split_cumsum, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + vector_f32=vector_f32, + m_aligned=m_aligned, + discrete_col_sfd=discrete_col_sfd, + ) + + assert grouped_gemm_swiglu.check_support(), "Unsupported configuration" + grouped_gemm_swiglu.compile(current_stream=current_stream) + grouped_gemm_swiglu.execute( + a_tensor=a_tensor, + b_tensor=b_tensor, + c_tensor=c_tensor, + d_tensor=d_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + tile_idx_to_expert_idx=tile_idx_to_expert_idx, + num_non_exiting_tiles=num_non_exiting_tiles, + alpha_tensor=alpha_tensor, + d_col_tensor=d_col_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + amax_tensor=amax_tensor, + norm_const_tensor=norm_const_tensor, + prob_tensor=prob_tensor, + m_split_cumsum=m_split_cumsum, + current_stream=current_stream, + ) + _cache_of_GroupedGemmSwigluSm100Objects[cache_key] = grouped_gemm_swiglu + + return TupleDict( + c_tensor=c_tensor, + d_tensor=d_tensor, + d_col_tensor=d_col_tensor, + amax_tensor=amax_tensor, + sfd_row_tensor=sfd_row_tensor, + sfd_col_tensor=sfd_col_tensor, + ) diff --git a/python/cudnn/grouped_gemm/grouped_gemm_swiglu/grouped_gemm_swiglu_quant.py b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/grouped_gemm_swiglu_quant.py new file mode 100644 index 00000000..b758fe9c --- /dev/null +++ b/python/cudnn/grouped_gemm/grouped_gemm_swiglu/grouped_gemm_swiglu_quant.py @@ -0,0 +1,3070 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Grouped GEMM SwiGLU Forward Kernel (SM100+) + +High-performance contiguous grouped block-scaled GEMM with SwiGLU activation +for MoE (Mixture of Experts) workloads on NVIDIA Blackwell GPUs. +""" + +from typing import Type, Tuple, Union, Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +from cutlass.cutlass_dsl import T +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import from_dlpack +from cutlass._mlir.dialects.nvvm import ReduxKind +from cutlass.cute.typing import Float32, Int32, Numeric +from cutlass.cutlass_dsl import T, dsl_user_op, if_generate +from cutlass._mlir.dialects import math, nvvm, llvm, scf +from cutlass._mlir.dialects.nvvm import FPRoundingMode + +from cudnn.api_base import ceil_div +from ..utils import ( + PersistentTileSchedulerParams, + fmin, + warp_redux_sync, + atomic_max_float32, + sigmoid_f32, + silu_f32, +) + +""" +High-performance persistent blockscaled contiguous grouped dense GEMM (D = alpha * (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell architecture +using CUTE DSL. +- Matrix A is MxKx1, A can be row-major("K"), ValidM is composed of valid m in different groups +- Matrix B is NxKxL, B can be column-major("K"), L is grouped dimension +- Matrix D is MxNx1, D can be row-major("N"), ValidM is composed of valid m in different groups +- Matrix SFA layout is filled internally according to A shape and BlockScaledBasicChunk, which has M×ceil_div(K, sf_vec_size)×L elements respectively +- Matrix SFB layout is filled internally according to B shape and BlockScaledBasicChunk, which has N×ceil_div(K, sf_vec_size)×L elements respectively + +Matrix A/D Memory Layout Diagrams: + + ``` + Group 0 Group 1 Group 2 + -+---------+---------+---------+ + | | | | + K| ValidM0 | ValidM1 | ValidM2 | + | | | | + -+---------+---------+---------+ + |<- ValidM ->| + ``` + Note: the Group(L) dimension will be flatted into M dimension, and the rest Group(L) size is 1. + each ValidM will be aligned to 256 or 128. The alignment is determined by the mma_tiler_mn parameter. + For NVFP4, 2CTA, the alignment is 256. For NVFP4, 1CTA, the alignment is 128. + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA operations. +2. MMA warp: + - Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction. + - Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Apply alpha and update the final accumulator Final = alpha * acc + - Type convert Final matrix to output type. + - Store D matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations. + +SM100 tcgen05.mma.kind.block_scale instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Read scalefactor A from TMEM +- Read scalefactor B from TMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +.. code-block:: bash + + python examples/blackwell/contiguous_blockscaled_grouped_gemm.py \ + --ab_dtype Float4E2M1FN --d_dtype BFloat16 --acc_dtype Float32 \ + --sf_dtype Float8E4M3FN --sf_vec_size 16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 256,4096,7168,1 --use_2cta_instrs --m_aligned 256 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/contiguous_blockscaled_grouped_gemm.py \ + --ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU --c_dtype BFloat16 \ + --d_dtype Float8E4M3FN --sf_vec_size 32 --mma_tiler_mn 256,256 \ + --cluster_shape_mn 2,1 --nkl 4096,7168,8 --use_2cta_instrs \ + --m_aligned 256 --fixed_m 4096 + +Constraints: +* Supported input data types: mxf8, nvf4 + see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation +* A/B tensor must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) +* Mma tiler M must be 128 or 256(use_2cta_instrs) +* Mma tiler N must be 64/128/192/256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) +* The contiguous dimension of A/B/D tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. + +CUDA Graph Support: +* For CUDA graph support, the tile_idx_to_expert_idx, A/D matrices, and scale factor A can be padded to a larger size + (e.g., permuted_m = m*topK + num_local_experts*(256-1), example: 4096*8 + (256/32)*255 = 34808) +* Use create_tensors() with permuted_m parameter to automatically pad: + - tile_idx_to_expert_idx: padded for invalid tiles + - A matrix: padded to permuted_m rows (padding rows contain dummy data) + - D matrix: padded to permuted_m rows (output buffer for cuda_graph) + - Scale factor A: padded to match A matrix dimensions +* Kernel handling of padding (similar to masked_grouped_gemm.py): + - Scheduler warp checks if tile_idx >= num_non_exiting_tiles to exit + - Only valid tiles (tile_idx < num_non_exiting_tiles) are written to tile_info pipeline + - When no more valid tiles exist, outer loop exits and calls producer_tail() + - Consumer warps process only valid tiles from pipeline + - No deadlock or synchronization issues +* Consumer warps check initial tile against num_non_exiting_tiles and set is_valid_tile=False if tile_idx >= num_non_exiting_tiles +* Only rows within (aligned_groupm[0]+aligned_groupm[1]+...) contain valid data +* Padding rows in D matrix will not be written by the kernel +""" + + +class BlockScaledContiguousGroupedGemmKernel: + """This class implements batched matrix multiplication (D = A x SFA x B x SFB) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported combinations of A/B data types, SF data typs and SF vector size: + - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 + + :note: Supported accumulator data types: + - Float32 + + :note: Supported D data types: + - BFloat16 + - Float8E4M3FN/Float8E5M2 + + :note: Constraints: + - MMA tiler M must be 128 or 256 (use_2cta_instrs) + - MMA tiler N must be 64/128/192/256 + - Cluster shape M must be multiple of 2 if Mma tiler M is 256 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors + + """ + + def __init__( + self, + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + vector_f32: bool, + generate_sfd: bool, + discrete_col_sfd: bool, + ): + """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + :param vector_f32: Boolean, True to use vectorized f32 operations. + :type vector_f32: bool + :param generate_sfd: Boolean, True to generate output scale factor tensor + :type generate_sfd: bool + :param discrete_col_sfd: Boolean, True to generate discrete col-major scale factor tensor + :type discrete_col_sfd: bool + """ + + self.sf_vec_size = sf_vec_size + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + self.epilog_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.sched_warp_id = 6 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.sched_warp_id, + ) + ) + self.threads_wo_sched = self.threads_per_warp * len( + ( + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + ) + ) + # TODO: Do we need to reallocate register? + # self.num_regs_uniform_warps = 64 + # self.num_regs_sched_warps = 64 + # self.num_regs_epilogue_warps = 216 + + # Set barrier for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.vector_f32 = vector_f32 + + self.generate_sfd = generate_sfd + self.discrete_col_sfd = discrete_col_sfd + + # Amax reduction configuration + self.num_epilog_warps = len(self.epilog_warp_id) + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/D stage counts in shared memory + - Computing A/B/D shared memory layout + - Computing tensor memory allocation columns + """ + + self.mma_inst_shape_mn = ( + self.mma_tiler[0], + self.mma_tiler[1], + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + # Configure tiled mma + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], + ) + + self.mma_tiler_d = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1] // 2, + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk_d = ( + self.mma_tiler_d[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_d[1], + self.mma_tiler_d[2], + ) + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Set epilogue subtile + self.epi_tile = (128, 32) + self.epi_tile_cnt = ( + self.cta_tile_shape_mnk_d[0] // self.epi_tile[0], + self.cta_tile_shape_mnk_d[1] // self.epi_tile[1], + ) + self.epi_tile_c = (128, 64) + + # Setup A/B/D/Scale stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_d_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.epi_tile_c, + self.c_dtype, + self.c_layout, + self.d_dtype, + self.d_layout, + self.sf_dtype, + self.sf_vec_size, + self.num_smem_capacity, + self.occupancy, + self.generate_sfd, + ) + + # Compute A/B/D/Scale shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile_c, + self.num_c_stage, + ) + + self.d_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.d_dtype, + self.d_layout, + self.epi_tile, + self.num_d_stage, + ) + + # Overlap and double buffer accumulator when num_acc_stage == 1 for cta_tile_n = 256 case + self.overlapping_accum = self.num_acc_stage == 1 and self.mma_tiler[1] == 256 + + # Compute number of TMEM columns for SFA/SFB/Accumulator + sf_atom_mn = 32 + self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k + self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = ( + self.cta_tile_shape_mnk[1] * self.num_acc_stage if not self.overlapping_accum else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols + ) + + self.epi_tile_n_required = 2 * cute.size(self.epi_tile[1]) + # Only when overlapping_accum is enabled, we need to release accumulator buffer early in epilogue + self.iter_acc_early_release_in_epilogue = ((self.num_sf_tmem_cols + self.epi_tile_n_required - 1) // self.epi_tile_n_required - 1) * 2 + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + d: cute.Tensor, + d_col: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + sfd_row_tensor: Optional[cute.Tensor], + sfd_col_tensor: Optional[cute.Tensor], + amax_tensor: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + tile_idx_to_expert_idx: cute.Tensor, + num_non_exiting_tiles: cute.Tensor, + m_split_cumsum: Optional[cute.Tensor], + alpha: cute.Tensor, + prob: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param d: Output tensor D + :type d: cute.Tensor + :param d_col: Output tensor D column quantized + :type d_col: cute.Tensor + :param sfa: Scale factor tensor A + :type sfa: cute.Tensor + :param sfb: Scale factor tensor B + :type sfb: cute.Tensor + :param sfd_row_tensor: Scale factor tensor D + :type sfd_row_tensor: Optional[cute.Tensor] + :param sfd_col_tensor: Scale factor tensor D + :type sfd_col_tensor: Optional[cute.Tensor] + :param amax_tensor: Absolute maximum value tensor + :type amax_tensor: Optional[cute.Tensor] + :param norm_const_tensor: Norm constant tensor + :type norm_const_tensor: Optional[cute.Tensor] + :param tile_idx_to_expert_idx: Mapping from tile index to expert ID, shape (permuted_m/cta_tile_m,) where cta_tile_m is the CTA tile M size + :type tile_idx_to_expert_idx: cute.Tensor + :param num_non_exiting_tiles: Number of valid tiles (valid_m/cta_tile_m), shape (1,) + :type num_non_exiting_tiles: cute.Tensor + :param alpha: Alpha tensor for each group + :type alpha: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.d_dtype: Type[cutlass.Numeric] = d.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + self.d_layout = utils.LayoutEnum.from_tensor(d) + + # Compute grid size + m, n, l = cute.shape(d) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, self.sf_vec_size) + sfa = cute.make_tensor(sfa.iterator, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) + sfb = cute.make_tensor(sfb.iterator, sfb_layout) + + # Setup sfd tensor by filling D tensor to scale factor atom layout + self.generate_sfd = sfd_row_tensor is not None and norm_const_tensor is not None + self.discrete_col_sfd = m_split_cumsum != None and self.discrete_col_sfd + if cutlass.const_expr(self.generate_sfd == False): + self.discrete_col_sfd = False + if cutlass.const_expr(self.generate_sfd): + sfd_row_layout = blockscaled_utils.tile_atom_to_shape_SF(d.shape, self.sf_vec_size) + sfd_row_tensor = cute.make_tensor(sfd_row_tensor.iterator, sfd_row_layout) + sfd_col_layout = cute.tile_to_shape( + blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, OperandMajorMode.MN).layout, + d.shape, + (1, 2, 3), + ) + if cutlass.const_expr(self.discrete_col_sfd): + sfd_col_layout = sfd_row_layout + sfd_col_tensor = cute.make_tensor(sfd_col_tensor.iterator, sfd_col_layout) + + self.generate_amax = amax_tensor is not None + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + # For 2CTA blockscaled kernels, SFB needs to be replicated across peer CTAs. # {$nv-internal-release} + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # {$nv-internal-release begin} + # This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF)) + # logical blocks for SFB when cta_tile_shape_n=192. + # {$nv-internal-release end} + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + + new_shape = ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ) + tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size) * atom_thr_size + + # Setup TMA store for C + c_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + c_smem_layout, + self.epi_tile_c, + ) + + # Setup TMA store for D + d_smem_layout = cute.slice_(self.d_smem_layout_staged, (None, None, 0)) + tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d, + d_smem_layout, + self.epi_tile, + ) + tma_atom_d_col, tma_tensor_d_col = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + d_col, + d_smem_layout, + self.epi_tile, + ) + + # Compute grid size + output_shape = (m, n, l) + self.tile_sched_params, grid = self._compute_grid( + output_shape, + self.cta_tile_shape_mnk_d, + self.cluster_shape_mn, + max_active_clusters, + ) + + self.buffer_align_bytes = 1024 + + # Define shared storage for kernel + @cute.struct + class SharedStorageFP8: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_tile_stage * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + sD: cute.struct.Align[ + cute.struct.MemRange[ + self.d_dtype, + cute.cosize(self.d_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + sD_col: cute.struct.Align[ + cute.struct.MemRange[ + self.d_dtype, + cute.cosize(self.d_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + # Amax reduction shared memory (one FP32 per epilogue warp) + # Use smaller alignment for amax since it's only 16 bytes + sAmax: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps], + 1, # byte alignment + ] + # (bidx, bidy, bidz, valid) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 6 * self.num_tile_stage], + 1, # byte alignment + ] + + @cute.struct + class SharedStorageFP4: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_tile_stage * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + sD: cute.struct.Align[ + cute.struct.MemRange[ + self.d_dtype, + cute.cosize(self.d_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + # Amax reduction shared memory (one FP32 per epilogue warp) + # Use smaller alignment for amax since it's only 16 bytes + sAmax: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.num_epilog_warps], + 1, # byte alignment + ] + # (bidx, bidy, bidz, valid) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 6 * self.num_tile_stage], + 1, # byte alignment + ] + + if cutlass.const_expr(self.generate_sfd): + self.shared_storage = SharedStorageFP8 + else: + self.shared_storage = SharedStorageFP4 + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + tma_atom_d, + tma_tensor_d, + tma_atom_d_col, + tma_tensor_d_col, + sfd_row_tensor, + sfd_col_tensor, + norm_const_tensor, + amax_tensor, + tile_idx_to_expert_idx, + num_non_exiting_tiles, + m_split_cumsum, + alpha, + prob, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.d_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + max_number_threads=[self.threads_per_cta, 1, 1], + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). + + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + @cute.jit + def amax_reduction_per_thread(self, vec_fp32, amax_fp32) -> None: + vec_fp32_ssa = vec_fp32.load() + abs_acc_values_ir = cutlass._mlir.dialects.math.absf(vec_fp32_ssa.ir_value()) + abs_acc_values = type(vec_fp32_ssa)(abs_acc_values_ir, vec_fp32_ssa.shape, vec_fp32_ssa.dtype) + subtile_amax = abs_acc_values.reduce(cute.ReductionOp.MAX, cutlass.Float32(0.0), 0) + return cute.arch.fmax(amax_fp32, subtile_amax) + + @cute.jit + def amax_reduction_per_warp_and_cta(self, amax_fp32, warp_idx, amax_smem, amax_gmem) -> None: + # Warp-level reduction using wrapper function + warp_amax = warp_redux_sync( + value=amax_fp32, + kind=ReduxKind.MAX, + mask_and_clamp=0xFFFFFFFF, + nan=True, + ) + # Each epilogue warp's lane 0 writes warp amax to shared memory + if cute.arch.lane_idx() == 0: + amax_smem[warp_idx] = cutlass.Float32(warp_amax) + + # Ensure all epilogue warps complete their writes before block reduction + self.epilog_sync_barrier.arrive_and_wait() + + # Block-level reduction: only first epilogue warp's lane 0 handles this + if warp_idx == self.epilog_warp_id[0] and cute.arch.lane_idx() == 0: + block_amax = cutlass.Float32(0.0) + for i in cutlass.range(self.num_epilog_warps): + warp_amax_val = amax_smem[i] + block_amax = cute.arch.fmax(block_amax, warp_amax_val) + + # Global atomic max (accumulates across all tiles for final tensor amax) + _ = atomic_max_float32(ptr=amax_gmem, value=block_amax) + + @cute.jit + def store_c( + self, + tiled_copy_r2s, + tma_atom_c, + warp_idx, + tTR_rAcc, + tTR_rAcc_gate, + tRS_rC, + tRS_sC, + bSG_gC, + bSG_sC, + c_pipeline, + prev_subtile_idx, + real_subtile_idx, + ) -> None: + c_buffer = prev_subtile_idx % self.num_c_stage + tRS_rC.store(tTR_rAcc.load().to(self.c_dtype)) + cute.copy( + tiled_copy_r2s, + tRS_rC[(None, None, 0)], + tRS_sC[(None, None, 0, c_buffer)], + ) + tRS_rC.store(tTR_rAcc_gate.load().to(self.c_dtype)) + cute.copy( + tiled_copy_r2s, + tRS_rC[(None, None, 0)], + tRS_sC[(None, None, 1, c_buffer)], # ((1, 32), 1, 2, (1, 1)), ((0, 1), 0, 32, (0, 0)) + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + # + # TMA store smem to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], # ((8192, 1), (1, 4)), ((1, 0), (0, 8192)) + bSG_gC[(None, real_subtile_idx)], # (((64, 128), 1), (1, 4)) : (((1@0,1@1),0),(0,64@0)) + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + @cute.jit + def quant_sfd_row( + self, + tile_idx, + tiled_copy_r2s, + src, + pvscale, + norm_const, + rcp_limit, + tRSrD, + tile_info, + ) -> None: + # Get absolute max across a vector and Compute SFD + tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size)) + acc_frg = tTR_rAcc_frg.load() + abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value()) + abs_acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype) + + pvscale_f32x4 = cute.make_rmem_tensor(4, cutlass.Float32) + sfd_f8x4 = cute.make_rmem_tensor(4, self.sf_dtype) + tmp_f32 = ( + abs_acc_frg[None, 0].reduce( + cute.ReductionOp.MAX, + cutlass.Float32(0.0), + 0, # Use 0.0 as init for abs values + ) + * rcp_limit + * norm_const + ) + # + # Manually store pvscale to avoid spilling + # + if tile_idx == 0: + pvscale[0] = tmp_f32 + elif tile_idx == 1: + pvscale[1] = tmp_f32 + elif tile_idx == 2: + pvscale[2] = tmp_f32 + elif tile_idx == 3: + pvscale[3] = tmp_f32 + + # + # Compute quantized output values and convert to D type + # + pvscale_f32x4[0] = tmp_f32 + sfd_f8x4.store(pvscale_f32x4.load().to(self.sf_dtype)) + pvscale_f32x4.store(sfd_f8x4.load().to(cutlass.Float32)) + qpvscale_up = pvscale_f32x4[0] + + fp32_max = cutlass.Float32(3.40282346638528859812e38) + acc_scale = norm_const * cute.arch.rcp_approx(qpvscale_up) + acc_scale = fmin(acc_scale, fp32_max, nan=True) + if cutlass.const_expr(self.vector_f32): + vec = tTR_rAcc_frg[None, 0] + for ei in cutlass.range_constexpr(0, self.sf_vec_size, 2): + ( + vec[ei], + vec[ei + 1], + ) = cute.arch.mul_packed_f32x2( + (vec[ei], vec[ei + 1]), + (acc_scale, acc_scale), + rnd=FPRoundingMode.RN, + ftz=False, + ) + else: + vec = tTR_rAcc_frg[None, 0] + for ei in cutlass.range_constexpr(self.sf_vec_size): + vec[ei] = vec[ei] * acc_scale + + acc_vec = tiled_copy_r2s.retile(src).load() + tRSrD.store(acc_vec.to(self.d_dtype)) + + @cute.jit + def quant_sfd_col( + self, + tile_idx, + tiled_copy_r2s, + src, + pvscale, + norm_const, + rcp_limit, + tRSrD, + tile_info, + ) -> None: + # Get absolute max across a vector and Compute SFD + tTR_rAcc_frg = cute.logical_divide(src, cute.make_layout(self.sf_vec_size)) + acc_frg = tTR_rAcc_frg.load() + abs_acc_frg_ir = cutlass._mlir.dialects.math.absf(acc_frg.ir_value()) + acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype) + + tmp_f32 = cutlass.Float32(0.0) + for vi in cutlass.range_constexpr(acc_frg.shape[0]): + max_value_original = ( + cutlass.Float32( + warp_redux_sync( + value=acc_frg[vi, 0], + kind=ReduxKind.MAX, + mask_and_clamp=0xFFFFFFFF, + nan=True, + ) + ) + * rcp_limit + * norm_const + ) + max_value_vec = cute.full(4, max_value_original, dtype=cutlass.Float32) + max_value_vec_f8 = max_value_vec.to(cutlass.Float8E8M0FNU) + max_value_vec_f32_chunked = max_value_vec_f8.to(cutlass.Float32) + max_value = max_value_vec_f32_chunked[0] + tidx = cute.arch.thread_idx()[0] + if tidx % 32 == vi: + tmp_f32 = max_value + + acc_scale_col = cutlass.Float32(0.0) + if max_value_vec_f32_chunked[0] == 0.000000: + acc_scale_col = cutlass.Float32(0.0) + else: + acc_scale_col = norm_const * cute.arch.rcp_approx(max_value_vec_f32_chunked[0]) + fp32_max = cutlass.Float32(3.40282346638528859812e38) + acc_scale_col = fmin(acc_scale_col, fp32_max) + tTR_rAcc_frg[vi] = tTR_rAcc_frg[vi] * acc_scale_col + pvscale[None, None, tile_idx][0] = tmp_f32 + + acc_vec = tiled_copy_r2s.retile(src).load() + tRSrD.store(acc_vec.to(self.d_dtype)) + + @cute.jit + def tile_info_to_mn_idx( + self, + tile_info: cute.Tensor, + ): + m_idx = tile_info[0] * cute.size(self.cta_tile_shape_mnk[0]) + n_idx = tile_info[1] * cute.size(self.cta_tile_shape_mnk[1]) + return m_idx, n_idx + + @cute.jit + def create_and_partition_new_SFDCol( + self, + tile_info: cute.Tensor, + mSFDCol_mnl: cute.Tensor, + ): + m_idx, n_idx = self.tile_info_to_mn_idx(tile_info) + cumsum_tokens = tile_info[5] + tokens_this_group = tile_info[4] + n_total = cute.size(mSFDCol_mnl.shape[1]) + + sf_tile_idx_begin = cumsum_tokens // cute.size(mSFDCol_mnl.shape[0][0]) + mSFDCol_mnl_new_ptr = mSFDCol_mnl[(None, sf_tile_idx_begin), None, 0].iterator + + sfd_col_quant_layout = cute.tile_to_shape( + blockscaled_utils.BlockScaledBasicChunk(self.sf_vec_size, cute.nvgpu.tcgen05.OperandMajorMode.MN).layout, + (tokens_this_group, n_total, mSFDCol_mnl.shape[2]), + (1, 2, 3), + ) + regPerSubtile = 4 + sfd_tile = ( + cute.make_layout(128), + cute.make_layout(32 * regPerSubtile), + ) + mSFDCol_mnl_new = cute.make_tensor(mSFDCol_mnl_new_ptr, sfd_col_quant_layout) + gSFDCol_mnl_new = cute.local_tile(mSFDCol_mnl_new, sfd_tile, (None, None, None)) + + thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0)) + val_layout = cute.make_ordered_layout((1,), order=(0,)) + copy_atom_sfd_col_quant = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gSFDCol_mnl_new.element_type, + num_bits_per_copy=8, + ) + tiled_copy_sfd_col_quant = cute.make_tiled_copy_tv(copy_atom_sfd_col_quant, thr_layout, val_layout) + tidx = cute.arch.thread_idx()[0] + thr_copy_sfd_col_quant = tiled_copy_sfd_col_quant.get_slice(tidx) + tCgSFDCol_mnl = thr_copy_sfd_col_quant.partition_D(cute.filter_zeros(gSFDCol_mnl_new)) + tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl) + return tCgSFDCol_mnl + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tma_atom_d: cute.CopyAtom, + mD_mnl: cute.Tensor, + tma_atom_d_col: cute.CopyAtom, + mD_col_mnl: cute.Tensor, + mSFDRow_mnl: Optional[cute.Tensor], + mSFDCol_mnl: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + mAmax_tensor: Optional[cute.Tensor], + tile_idx_to_expert_idx: cute.Tensor, + num_non_exiting_tiles: cute.Tensor, + m_split_cumsum: Optional[cute.Tensor], + alpha: cute.Tensor, + prob: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + d_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + cpasync.prefetch_descriptor(tma_atom_d) + if cutlass.const_expr(self.generate_sfd): + cpasync.prefetch_descriptor(tma_atom_d_col) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster) + + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_consumer_threads) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize tile info pipeline (barrier) and states + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_wo_sched, + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.tile_info_mbar_ptr.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/D/Scale + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) + sD = storage.sD.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner) + # (EPI_TILE_M, EPI_TILE_N, STAGE) + # placeholder again + sD_col = sD + if cutlass.const_expr(self.generate_sfd): + sD_col = storage.sD_col.get_tensor(d_smem_layout_staged.outer, swizzle=d_smem_layout_staged.inner) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # Shared memory for amax reduction (one FP32 per epilogue warp) + # Simple 1D layout. The allocation always here if no amax is generated, + # as the overhead is minimal and we want to keep the code simple. + amax_layout = cute.make_layout((self.num_epilog_warps,)) + sAmax = storage.sAmax.get_tensor(amax_layout) + # (bidx, bidy, bidz, valid) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + if cutlass.const_expr(self.discrete_col_sfd): + info_layout = cute.make_layout((6, self.num_tile_stage), stride=(1, 6)) + sInfo = storage.sInfo.get_tensor(info_layout) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + b_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) + + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None)) + gD_mnl = cute.local_tile(mD_mnl, cute.slice_(self.mma_tiler_d, (None, None, 0)), (None, None, None)) + # placeholder, it will be eventually removed by compiler as we won't do any store to it in FP4 mode + gD_col_mnl = gD_mnl + if cutlass.const_expr(self.generate_sfd): + gD_col_mnl = cute.local_tile( + mD_col_mnl, + cute.slice_(self.mma_tiler_d, (None, None, 0)), + (None, None, None), + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/D + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + tCgD = thr_mma.partition_C(gD_mnl) + # placeholder, same as above + tCgD_col = tCgD + if cutlass.const_expr(self.generate_sfd): + tCgD_col = thr_mma.partition_C(gD_col_mnl) + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA load SFA partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMA load SFB partition_S/D + sfb_cta_layout = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/D + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + if cutlass.const_expr(self.overlapping_accum): + num_acc_stage_overlapped = 2 + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage_overlapped)) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = cute.make_tensor( + tCtAcc_fake.iterator, + cute.make_layout( + tCtAcc_fake.shape, + stride=( + tCtAcc_fake.stride[0], + tCtAcc_fake.stride[1], + tCtAcc_fake.stride[2], + (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1], + ), + ), + ) + else: + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + # + # Specialized Schedule warp + # + if warp_idx == self.sched_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + tile_info_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_tile_stage) + + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape) + if mma_tile_coord_m < num_non_exiting_tiles[0]: + tile_info_pipeline.producer_acquire(tile_info_producer_state) + cur_tile_coord = work_tile.tile_idx + expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m] + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0] + sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1] + sInfo[(2, tile_info_producer_state.index)] = expert_idx + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(work_tile.is_valid_tile) + if cutlass.const_expr(self.discrete_col_sfd): + tokens_presum_this_group = m_split_cumsum[expert_idx] + tokens_presum_next_group = m_split_cumsum[expert_idx + 1] + # number of tokens in this group + sInfo[(4, tile_info_producer_state.index)] = tokens_presum_next_group - tokens_presum_this_group + # token prefix sum of this group + sInfo[(5, tile_info_producer_state.index)] = tokens_presum_this_group + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + self.sched_sync_barrier.arrive_and_wait() + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + tile_info_pipeline.producer_acquire(tile_info_producer_state) + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = work_tile.tile_idx[0] + sInfo[(1, tile_info_producer_state.index)] = work_tile.tile_idx[1] + sInfo[(2, tile_info_producer_state.index)] = -1 + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.sched_sync_barrier.arrive_and_wait() + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) + + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + # Get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, 0)] + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, 0)] + + # Apply SFB slicing hack when cta_tile_shape_n=64 # {$nv-internal-release} + slice_n = mma_tile_coord_mnl[1] + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_mnl[1] // 2 + + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)] + tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + tAsSFA_pipe = tAsSFA[(None, ab_producer_state.index)] + tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + + cute.copy( + tma_atom_sfa, + tAgSFA_k, + tAsSFA_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfa_full_mcast_mask, + ) + cute.copy( + tma_atom_sfb, + tBgSFB_k, + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols, + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols, + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # Partition for S2T copy of SFA/SFB + # + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) + acd_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + # Get the first tile info from pipeline (scheduler has filtered out tiles >= num_non_exiting_tiles) + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = 0 + acd_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state) + + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + + # Get accumulator stage index + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acd_producer_state.phase ^ 1 + else: + acc_stage_index = acd_producer_state.index + + tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)] + + # Apply TMEM pointer offset hack when cta_tile_shape_n=192 or cta_tile_shape_n=64 # {$nv-internal-release} + + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + # If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB) + offset = cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + # Move in increments of 64 columns of SFB + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acd_producer_state, peek_acc_empty_status) + # + # Mma mainloop + # + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB_mma[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # Async arrive accumulator buffer full(each kblock) + # + if is_leader_cta: + acc_pipeline.producer_commit(acd_producer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1 + acd_producer_state.advance() + if acd_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acd_producer_state) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acd_producer_state) + + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc_up, + tTR_rAcc_gate, + ) = self.epilog_tmem_copy_and_partition(epi_tidx, tCtAcc_base, tCgD, epi_tile, use_2cta_instrs) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc_up.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rC, epi_tidx, sC) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, self.epi_tile_c, sC) + + tTR_rD = cute.make_rmem_tensor(tTR_rAcc_up.shape, self.d_dtype) + tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rD, epi_tidx, sD) + ( + tma_atom_d, + bSG_sD, + bSG_gD_partitioned, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_d, tCgD, epi_tile, sD) + + tTR_rD_col = cute.make_rmem_tensor(tTR_rAcc_up.shape, self.d_dtype) + tiled_copy_r2s, tRS_rD_col, tRS_sD_col = self.epilog_smem_copy_and_partition(tiled_copy_t2r, tTR_rD_col, epi_tidx, sD_col) + ( + tma_atom_d_col, + bSG_sD_col, + bSG_gD_col_partitioned, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_d_col, tCgD_col, epi_tile, sD_col) + + if cutlass.const_expr(self.generate_sfd): + norm_const = norm_const_tensor[0] + regPerSubtile = 4 + sfd_row_tile = ( + cute.make_layout(128), + cute.make_layout(32 * regPerSubtile), + ) + # (EPI_TILE_M, EPI_TILE_N, RestM, RestN, RestL) + gSFDRow_mnl = cute.local_tile(mSFDRow_mnl, sfd_row_tile, (None, None, None)) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, RestM, RestN, RestL) + tCgSFDRow_mnl = thr_copy_t2r.partition_D(gSFDRow_mnl) + tCgSFDRow_mnl = cute.filter_zeros(tCgSFDRow_mnl) + # (T2R, T2R_M, T2R_N) + tCrSFDRow = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype) + tCrSFDRow_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32) + d_rcp_limits = self.get_dtype_rcp_limits(self.d_dtype) + + # both SFDs are stored in row major mode. + sfd_col_tile = sfd_row_tile + gSFDCol_mnl = cute.local_tile(mSFDCol_mnl, sfd_col_tile, (None, None, None)) + thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0)) + val_layout = cute.make_ordered_layout((1,), order=(0,)) + copy_atom_sfd_col = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gSFDCol_mnl.element_type, + num_bits_per_copy=8, + ) + tiled_copy_sfd_col = cute.make_tiled_copy_tv(copy_atom_sfd_col, thr_layout, val_layout) + thr_copy_sfd_col = tiled_copy_sfd_col.get_slice(tidx) + tCgSFDCol_mnl = thr_copy_sfd_col.partition_D(cute.filter_zeros(gSFDCol_mnl)) + tCgSFDCol_mnl = cute.filter_zeros(tCgSFDCol_mnl) + tCrSFDCol = cute.make_rmem_tensor(tCgSFDRow_mnl[(None, None, None, 0, 0, 0)].shape, self.sf_dtype) + tCrSFDCol_pvscale = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32) + tCrSFDCol_qpvscale_up_fp32 = cute.make_rmem_tensor_like(tCrSFDRow, cutlass.Float32) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + + c_pipeline = None + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + d_pipeline = None + # Threads/warps participating in tma store pipeline + d_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + d_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_d_stage, + producer_group=d_producer_group, + ) + d_col_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_d_stage, + producer_group=d_producer_group, + ) + + tile_info_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_tile_stage) + + # Get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + if cutlass.const_expr(self.discrete_col_sfd): + tile_info = cute.make_rmem_tensor((6,), cutlass.Int32) + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(tile_info.shape[0], unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + num_prev_subtiles = cutlass.Int32(0) + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Get alpha for current group + # + + expert_idx = mma_tile_coord_mnl[2] + alpha_val = alpha[expert_idx] + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + 0, + ) + ] + bSG_gD = bSG_gD_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + 0, + ) + ] + bSG_gD_col = bSG_gD_col_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + 0, + ) + ] + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + bSG_gD = cute.group_modes(bSG_gD, 1, cute.rank(bSG_gD)) + bSG_gD_col = cute.group_modes(bSG_gD_col, 1, cute.rank(bSG_gD_col)) + + # Get accumulator stage index + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_consumer_state.phase + reverse_subtile = cutlass.Boolean(True) if acc_stage_index == 0 else cutlass.Boolean(False) + else: + acc_stage_index = acc_consumer_state.index + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)] + + if cutlass.const_expr(self.generate_sfd): + # (T2R, T2R_M, T2R_N, RestM, RestN) + tCgSFDRow_mn = tCgSFDRow_mnl[ + ( + None, + None, + None, + None, + None, + 0, + ) + ] + tCgSFDCol_mnl_new = tCgSFDCol_mnl + if cutlass.const_expr(self.discrete_col_sfd): + tCgSFDCol_mnl_new = self.create_and_partition_new_SFDCol(tile_info, mSFDCol_mnl) + tCgSFDCol_mn = tCgSFDCol_mnl_new[ + ( + None, + None, + None, + None, + None, + 0, + ) + ] + + if cutlass.const_expr(self.generate_amax): + thread_tile_amax = cutlass.Float32(0.0) + + # + # Get PROB + # Note, it always assumes T2R_M/EPI_M is 1, otherwise it will break the result. + # + mPosition = tile_info[0] * self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape) + tidx + mProb = prob[mPosition, 0, 0] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(0, subtile_cnt, 2, unroll=1): + real_subtile_idx = subtile_idx // 2 + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + # Subtile always iterates on N dimension as we only have 4x1DP tmem load pattern for cta_tile_m = 128 cases. # {$nv-internal-release} + real_subtile_idx = self.cta_tile_shape_mnk[1] // self.epi_tile_n_required - 1 - subtile_idx // 2 + + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, real_subtile_idx * 2)] + tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, real_subtile_idx * 2 + 1)] + + cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up) + cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate) + + # + # Async arrive accumulator buffer empty ealier when overlapping_accum is enabled + # + if cutlass.const_expr(self.overlapping_accum): + if subtile_idx == self.iter_acc_early_release_in_epilogue: + # Fence for TMEM load + cute.arch.fence_view_async_tmem_load() + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Apply alpha + # + if cutlass.const_expr(self.vector_f32): + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_up), 2): + tTR_rAcc_up[i], tTR_rAcc_up[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rAcc_up[i], tTR_rAcc_up[i + 1]), + ( + cutlass.Float32(alpha_val), + cutlass.Float32(alpha_val), + ), + rnd=FPRoundingMode.RN, + ftz=False, + ) + tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1] = cute.arch.mul_packed_f32x2( + (tTR_rAcc_gate[i], tTR_rAcc_gate[i + 1]), + ( + cutlass.Float32(alpha_val), + cutlass.Float32(alpha_val), + ), + rnd=FPRoundingMode.RN, + ftz=False, + ) + else: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)): + tTR_rAcc_up[i] = tTR_rAcc_up[i] * cutlass.Float32(alpha_val) + tTR_rAcc_gate[i] = tTR_rAcc_gate[i] * cutlass.Float32(alpha_val) + + # + # Store to C tensor + # + self.store_c( + tiled_copy_r2s, + tma_atom_c, + warp_idx, + tTR_rAcc_up, + tTR_rAcc_gate, + tRS_rC, + tRS_sC, + bSG_gC, + bSG_sC, + c_pipeline, + num_prev_subtiles, + real_subtile_idx, + ) + + acc_vec_up = tTR_rAcc_up.load() + acc_vec_gate = tTR_rAcc_gate.load() + + # SwiGlu + tCompute = cute.make_rmem_tensor(acc_vec_gate.shape, self.acc_dtype) + if cutlass.const_expr(self.vector_f32): + # SwiGlu Packed Version + LOG2_E = cutlass.Float32(1.4426950408889634) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_up), 2): + tCompute_log2e = cute.arch.mul_packed_f32x2( + (acc_vec_gate[i], acc_vec_gate[i + 1]), + (-LOG2_E, -LOG2_E), + rnd=FPRoundingMode.RN, + ftz=False, + ) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.add_packed_f32x2( + ( + cute.math.exp2(tCompute_log2e[0], fastmath=True), + cute.math.exp2(tCompute_log2e[1], fastmath=True), + ), + (1.0, 1.0), + ) + tCompute[i] = cute.arch.rcp_approx(tCompute[i]) + tCompute[i + 1] = cute.arch.rcp_approx(tCompute[i + 1]) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (acc_vec_gate[i], acc_vec_gate[i + 1]), + rnd=FPRoundingMode.RN, + ftz=False, + ) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (acc_vec_up[i], acc_vec_up[i + 1]), + rnd=FPRoundingMode.RN, + ftz=False, + ) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (mProb, mProb), + rnd=FPRoundingMode.RN, + ftz=False, + ) + else: + # SwiGlu Unpacked Version + for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)): + tCompute[i] = acc_vec_up[i] * silu_f32(acc_vec_gate[i], fastmath=True) + tCompute[i] = tCompute[i] * mProb + + # + # Generate amax + # + if cutlass.const_expr(self.generate_amax): + thread_tile_amax = self.amax_reduction_per_thread(tCompute, thread_tile_amax) + + if cutlass.const_expr(self.generate_sfd): + tCompute_col = cute.make_rmem_tensor(tCompute.layout, tCompute.element_type) + tCompute_col.store(tCompute.load()) + # + # Generate row major SFD + # + self.quant_sfd_row( + real_subtile_idx, + tiled_copy_r2s, + tCompute, + tCrSFDRow_pvscale, + norm_const, + d_rcp_limits, + tRS_rD, + tile_info, + ) + # + # Generate col major SFD + # + self.quant_sfd_col( + real_subtile_idx, + tiled_copy_r2s, + tCompute_col, + tCrSFDCol_pvscale, + norm_const, + d_rcp_limits, + tRS_rD_col, + tile_info, + ) + + # Assume subtile partitioned always happens on n dimension + sfd_row_idx_mn = ( + tile_info[0], + tile_info[1], + ) + sfd_col_idx_mn = sfd_row_idx_mn + if cutlass.const_expr(self.discrete_col_sfd): + sfd_col_idx_mn = ( + tile_info[0] - tile_info[5] // 128, + tile_info[1], + ) + tCgSFDRow = tCgSFDRow_mn[ + ( + None, + None, + None, + *sfd_row_idx_mn, + ) + ] + tCgSFDCol = tCgSFDCol_mn[ + ( + None, + None, + None, + *sfd_col_idx_mn, + ) + ] + + if subtile_idx == 6: + tCrSFDRow.store(tCrSFDRow_pvscale.load().to(self.sf_dtype)) + cute.autovec_copy(tCrSFDRow, tCgSFDRow) + tCrSFDCol.store(tCrSFDCol_pvscale.load().to(self.sf_dtype)) + cute.autovec_copy(tCrSFDCol, tCgSFDCol) + else: + # + # Convert to D type + # + acc_vec = tiled_copy_r2s.retile(tCompute).load() + tRS_rD.store(acc_vec.to(self.d_dtype)) + + # + # Store D to shared memory + # + d_buffer = num_prev_subtiles % self.num_d_stage + num_prev_subtiles = num_prev_subtiles + 1 + cute.copy( + tiled_copy_r2s, + tRS_rD, + tRS_sD[(None, None, None, d_buffer)], + ) + if cutlass.const_expr(self.generate_sfd): + cute.copy( + tiled_copy_r2s, + tRS_rD_col, + tRS_sD_col[(None, None, None, d_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + # + # TMA store D to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_d, + bSG_sD[(None, d_buffer)], + bSG_gD[(None, real_subtile_idx)], + ) + if cutlass.const_expr(self.generate_sfd): + cute.copy( + tma_atom_d_col, + bSG_sD_col[(None, d_buffer)], + bSG_gD_col[(None, real_subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + d_pipeline.producer_commit() + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + if cutlass.const_expr(not self.overlapping_accum): + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(tile_info.shape[0], unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # Perform amax reduction after all subtiles are processed + if cutlass.const_expr(self.generate_amax): + gAmax = mAmax_tensor[(expert_idx, None)].iterator.llvm_ptr # First element + self.amax_reduction_per_warp_and_cta(thread_tile_amax, warp_idx, sAmax, gAmax) + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for C/D store complete + # + c_pipeline.producer_tail() + d_pipeline.producer_tail() + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gD_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gD_mnl: The global tensor D + :type gD_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc_up, tTR_rAcc_gate) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc_up: The partitioned accumulator tensor for acc up + - tTR_rAcc_gate: The partitioned accumulator tensor for acc gate + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.d_layout, + self.d_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gD_mnl_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gD_mnl_epi) + + # (T2R, T2R_M, T2R_N) + tTR_rAcc_up = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) + # (T2R, T2R_M, T2R_N) + tTR_rAcc_gate = cute.make_rmem_tensor(tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc_up, tTR_rAcc_gate + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sD: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sD: The shared memory tensor to be copied and partitioned + :type sD: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rD, tRS_sD) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rD: The partitioned tensor D (register source) + - tRS_sD: The partitioned tensor D (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op(self.d_layout, self.d_dtype, self.acc_dtype, tiled_copy_t2r) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sD) + # (R2S, R2S_M, R2S_N) + tRS_rD = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rD, tRS_sD + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gD_mnl: cute.Tensor, + epi_tile: cute.Tile, + sD: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gD_mnl: The global tensor D + :type gD_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sD: The shared memory tensor to be copied and partitioned + :type sD: cute.Tensor + + :return: A tuple containing : + - For TMA store: (tma_atom_d, bSG_sD, bSG_gD) where: + - tma_atom_d: The TMA copy atom + - bSG_sD: The partitioned shared memory tensor D + - bSG_gD: The partitioned global tensor D + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gD_epi = cute.flat_divide(gD_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + tma_atom_d = atom + sD_for_tma_partition = cute.group_modes(sD, 0, 2) + gD_for_tma_partition = cute.group_modes(gD_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sD, bSG_gD = cpasync.tma_partition( + tma_atom_d, + 0, + cute.make_layout(1), + sD_for_tma_partition, + gD_for_tma_partition, + ) + return tma_atom_d, bSG_sD, bSG_gD + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + epi_tile_c: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + d_dtype: Type[cutlass.Numeric], + d_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + num_smem_capacity: int, + occupancy: int, + generate_sfd: bool, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/D operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param d_layout: Layout of operand D. + :type d_layout: utils.LayoutEnum + :param sf_dtype: Data type of scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Vector size of scale factor. + :type sf_vec_size: int + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, D stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C/D stages + num_c_stage = 2 if generate_sfd else 1 + num_d_stage = 2 if generate_sfd else 1 + + # Default Tile info stages + num_tile_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and D + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile_c, + 1, + ) + + d_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + d_dtype, + d_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + # Mbar bytes + mbar_helpers_bytes = 1024 + # Sinfo bytes + sinfo_bytes = 4 * 4 * num_tile_stage + # C/D bytes + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + d_bytes_per_stage = cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) + d_bytes = d_bytes_per_stage * num_d_stage * (2 if generate_sfd else 1) + # AMAX bytes + amax_bytes = BlockScaledContiguousGroupedGemmKernel.get_amax_smem_size() if d_dtype == cutlass.BFloat16 else 0 + # Epilogue bytes + epi_bytes = c_bytes + d_bytes + amax_bytes + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial D stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes + sinfo_bytes)) // ab_bytes_per_stage + + # Refine epilogue stages: + ##num_d_stage += ( + ## num_smem_capacity + ## - occupancy * ab_bytes_per_stage * num_ab_stage + ## - occupancy * (mbar_helpers_bytes + epi_bytes) + ##) // (occupancy * d_bytes_per_stage) + + total_bytes = occupancy * (ab_bytes_per_stage * num_ab_stage + epi_bytes + sinfo_bytes + mbar_helpers_bytes) + + ## Display stage information + ## cute.printf( + ## f"generate_sfd: {generate_sfd}, num_acc_stage: {num_acc_stage}, num_ab_stage: {num_ab_stage}, num_c_stage: {num_c_stage}, num_d_stage: {num_d_stage}, num_tile_stage: {num_tile_stage}, total_bytes: {total_bytes}" + ## ) + return num_acc_stage, num_ab_stage, num_c_stage, num_d_stage, num_tile_stage + + @staticmethod + def _compute_grid( + output_shape: Tuple[int, int, int], + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor D. + + :param d: The output tensor D + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[PersistentTileSchedulerParams, tuple[int, int, int]] + """ + m, n, g = output_shape + d_layout = cute.make_layout(cute.shape(output_shape)) + d_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gd = cute.zipped_divide(d_layout, tiler=d_shape) + num_ctas_mnl = cute.slice_(gd.shape, (0, (None, None, None))) + cluster_shape_mnl = (*cluster_shape_mn, 1) + + swizzle_n_blackwell = 2048 + tile_sched_params = PersistentTileSchedulerParams( + num_ctas_mnl, + cluster_shape_mnl, + raster_along_m=False, + swizzle_size=swizzle_n_blackwell // cta_tile_shape_mnk[1], + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape(tile_sched_params, max_active_clusters) + + return tile_sched_params, grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def get_dtype_rcp_limits(dtype: Type[cutlass.Numeric]) -> float: + """ + Calculates the reciprocal of the maximum absolute value for a given data type. + + :param dtype: Data type + :type dtype: Type[cutlass.Numeric] + + :return: An float representing the reciprocal of the maximum absolute value + :rtype: float + """ + if dtype == cutlass.Float4E2M1FN: + return 1 / 6.0 + if dtype == cutlass.Float8E4M3FN: + return 1 / 448.0 + if dtype == cutlass.Float8E5M2: + return 1 / 128.0 + return 1.0 + + @staticmethod + def get_amax_smem_size(): + # Note: 4 is hardcoded for num_epilog_warps + return 4 * cute.size_in_bytes(cutlass.Float32, cute.make_layout((1,))) + + +class BlockScaledContiguousGroupedGemmKernelNoDlpack: + """Wrapper around BlockScaledContiguousGroupedGemmKernel that avoids DLPack. + + This wrapper constructs cute.Tensors directly from cute.Pointer, shapes, and + explicit layout orders for all operands. Useful when tensor dtypes (FP4, FP8) + aren't supported by DLPack on older PyTorch versions. + """ + + def __init__( + self, + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + vector_f32: bool, + generate_sfd: bool, + discrete_col_sfd: bool, + ): + self.kernel = BlockScaledContiguousGroupedGemmKernel( + sf_vec_size=sf_vec_size, + acc_dtype=acc_dtype, + use_2cta_instrs=use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + vector_f32=vector_f32, + generate_sfd=generate_sfd, + discrete_col_sfd=discrete_col_sfd, + ) + + @cute.jit + def __call__( + self, + a_ptr: cute.Pointer, + a_shape: cutlass.Constexpr[Tuple[int, int, int]], + a_order: cutlass.Constexpr[Tuple[int, int, int]], + b_ptr: cute.Pointer, + b_shape: cutlass.Constexpr[Tuple[int, int, int]], + b_order: cutlass.Constexpr[Tuple[int, int, int]], + c_ptr: cute.Pointer, + c_shape: cutlass.Constexpr[Tuple[int, int, int]], + c_order: cutlass.Constexpr[Tuple[int, int, int]], + d_ptr: cute.Pointer, + d_shape: cutlass.Constexpr[Tuple[int, int, int]], + d_order: cutlass.Constexpr[Tuple[int, int, int]], + d_col_ptr: cute.Pointer, + d_col_shape: cutlass.Constexpr[Tuple[int, int, int]], + d_col_order: cutlass.Constexpr[Tuple[int, int, int]], + sfa_ptr: cute.Pointer, + sfa_shape: cutlass.Constexpr[Tuple[int, int, int, int, int, int]], + sfa_order: cutlass.Constexpr[Tuple[int, int, int, int, int, int]], + sfb_ptr: cute.Pointer, + sfb_shape: cutlass.Constexpr[Tuple[int, int, int, int, int, int]], + sfb_order: cutlass.Constexpr[Tuple[int, int, int, int, int, int]], + sfd_row_ptr: Optional[cute.Pointer], + sfd_row_shape: cutlass.Constexpr[Tuple[int, int, int, int, int, int]], + sfd_row_order: cutlass.Constexpr[Tuple[int, int, int, int, int, int]], + sfd_col_ptr: Optional[cute.Pointer], + sfd_col_shape: cutlass.Constexpr[Tuple[int, int, int, int, int, int]], + sfd_col_order: cutlass.Constexpr[Tuple[int, int, int, int, int, int]], + amax_ptr: Optional[cute.Pointer], + amax_shape: cutlass.Constexpr[Tuple[int, int]], + amax_order: cutlass.Constexpr[Tuple[int, int]], + norm_const_ptr: Optional[cute.Pointer], + norm_const_shape: cutlass.Constexpr[Tuple[int]], + norm_const_order: cutlass.Constexpr[Tuple[int]], + tile_idx_to_expert_idx_ptr: cute.Pointer, + tile_idx_to_expert_idx_shape: cutlass.Constexpr[Tuple[int]], + tile_idx_to_expert_idx_order: cutlass.Constexpr[Tuple[int]], + num_non_exiting_tiles_ptr: cute.Pointer, + num_non_exiting_tiles_shape: cutlass.Constexpr[Tuple[int]], + num_non_exiting_tiles_order: cutlass.Constexpr[Tuple[int]], + m_split_cumsum_ptr: Optional[cute.Pointer], + m_split_cumsum_shape: cutlass.Constexpr[Tuple[int]], + m_split_cumsum_order: cutlass.Constexpr[Tuple[int]], + alpha_ptr: cute.Pointer, + alpha_shape: cutlass.Constexpr[Tuple[int]], + alpha_order: cutlass.Constexpr[Tuple[int]], + prob_ptr: Optional[cute.Pointer], + prob_shape: cutlass.Constexpr[Tuple[int, int, int]], + prob_order: cutlass.Constexpr[Tuple[int, int, int]], + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation using raw pointers and shapes. + + See BlockScaledContiguousGroupedGemmKernel.__call__ for parameter descriptions. + """ + # Construct cute.Tensors from pointers and shapes + a_cute = cute.make_tensor(a_ptr, layout=cute.make_ordered_layout(a_shape, order=a_order)) + b_cute = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout(b_shape, order=b_order)) + c_cute = cute.make_tensor(c_ptr, layout=cute.make_ordered_layout(c_shape, order=c_order)) + d_cute = cute.make_tensor(d_ptr, layout=cute.make_ordered_layout(d_shape, order=d_order)) + d_col_cute = cute.make_tensor(d_col_ptr, layout=cute.make_ordered_layout(d_col_shape, order=d_col_order)) + sfa_cute = cute.make_tensor(sfa_ptr, layout=cute.make_ordered_layout(sfa_shape, order=sfa_order)) + sfb_cute = cute.make_tensor(sfb_ptr, layout=cute.make_ordered_layout(sfb_shape, order=sfb_order)) + tile_idx_cute = cute.make_tensor( + tile_idx_to_expert_idx_ptr, + layout=cute.make_ordered_layout(tile_idx_to_expert_idx_shape, order=tile_idx_to_expert_idx_order), + ) + num_tiles_cute = cute.make_tensor( + num_non_exiting_tiles_ptr, + layout=cute.make_ordered_layout(num_non_exiting_tiles_shape, order=num_non_exiting_tiles_order), + ) + alpha_cute = cute.make_tensor(alpha_ptr, layout=cute.make_ordered_layout(alpha_shape, order=alpha_order)) + + # Optional tensors + sfd_row_cute = None + if cutlass.const_expr(sfd_row_ptr is not None): + sfd_row_cute = cute.make_tensor( + sfd_row_ptr, + layout=cute.make_ordered_layout(sfd_row_shape, order=sfd_row_order), + ) + + sfd_col_cute = None + if cutlass.const_expr(sfd_col_ptr is not None): + sfd_col_cute = cute.make_tensor( + sfd_col_ptr, + layout=cute.make_ordered_layout(sfd_col_shape, order=sfd_col_order), + ) + + amax_cute = None + if cutlass.const_expr(amax_ptr is not None): + amax_cute = cute.make_tensor(amax_ptr, layout=cute.make_ordered_layout(amax_shape, order=amax_order)) + + norm_const_cute = None + if cutlass.const_expr(norm_const_ptr is not None): + norm_const_cute = cute.make_tensor( + norm_const_ptr, + layout=cute.make_ordered_layout(norm_const_shape, order=norm_const_order), + ) + + m_split_cumsum_cute = None + if cutlass.const_expr(m_split_cumsum_ptr is not None): + m_split_cumsum_cute = cute.make_tensor( + m_split_cumsum_ptr, + layout=cute.make_ordered_layout(m_split_cumsum_shape, order=m_split_cumsum_order), + ) + + prob_cute = None + if cutlass.const_expr(prob_ptr is not None): + prob_cute = cute.make_tensor(prob_ptr, layout=cute.make_ordered_layout(prob_shape, order=prob_order)) + + self.kernel( + a=a_cute, + b=b_cute, + c=c_cute, + d=d_cute, + d_col=d_col_cute, + sfa=sfa_cute, + sfb=sfb_cute, + sfd_row_tensor=sfd_row_cute, + sfd_col_tensor=sfd_col_cute, + amax_tensor=amax_cute, + norm_const_tensor=norm_const_cute, + tile_idx_to_expert_idx=tile_idx_cute, + num_non_exiting_tiles=num_tiles_cute, + m_split_cumsum=m_split_cumsum_cute, + alpha=alpha_cute, + prob=prob_cute, + max_active_clusters=max_active_clusters, + stream=stream, + epilogue_op=epilogue_op, + ) + + +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_tensor: cute.Tensor, + sf_mma_tensor: cute.Tensor, +): + """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" + # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) + # group to ((32, 4, rest_m), (4, rest_k), l) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] diff --git a/python/cudnn/grouped_gemm/utils.py b/python/cudnn/grouped_gemm/utils.py new file mode 100644 index 00000000..522141c9 --- /dev/null +++ b/python/cudnn/grouped_gemm/utils.py @@ -0,0 +1,851 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Shared utilities for grouped GEMM kernels. + +This module contains the tile scheduler classes and helper functions used by both +the forward (grouped_gemm_swiglu) and backward (grouped_gemm_dswiglu) kernels. +""" + +from typing import Tuple, Union + +from cutlass.cutlass_dsl import ( + Boolean, + Integer, + Int32, + min, + extract_mlir_values, + new_from_mlir_values, + dsl_user_op, + const_expr, +) +from cutlass._mlir import ir +from cutlass._mlir.dialects import scf, llvm, nvvm +from cutlass.cutlass_dsl import T +from cutlass.cute.typing import Float32, Int32 as CuteInt32 +import cutlass.cute as cute +import cutlass + +############################################################################## +# Helper functions +############################################################################## + + +def fmin( + a: Union[float, Float32], + b: Union[float, Float32], + *, + nan: bool = True, + loc=None, + ip=None, +) -> Float32: + """Compute the minimum of two float32 values with NaN handling. + + :param a: First operand + :param b: Second operand + :param nan: If True, propagate NaN values + :return: Minimum value + """ + if nan: + ptx_instr = "min.NaN.f32 $0, $1, $2;" + else: + ptx_instr = "min.f32 $0, $1, $2;" + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + ptx_instr, + "=f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +def warp_redux_sync( + value, + kind=None, + mask_and_clamp: int = 0xFFFFFFFF, + abs: bool = False, + nan: bool = None, + *, + loc=None, + ip=None, +): + """Perform a warp-level reduction synchronization for max with abs and NaN. + + :param value: Value to reduce + :param kind: Reduction kind (unused, kept for API compatibility) + :param mask_and_clamp: Warp mask and clamp value + :param abs: Whether to use absolute value + :param nan: Whether to handle NaN values + :return: Reduced value across warp + """ + value_type = type(value) + value_ir = value.ir_value(loc=loc, ip=ip) + mask_ir = Int32(mask_and_clamp).ir_value(loc=loc, ip=ip) + ptx_instr = "redux.sync.max.abs.NaN.f32 $0, $1, $2;" + + return value_type( + llvm.inline_asm( + T.f32(), + [value_ir, mask_ir], + ptx_instr, + "=f,f,i", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +def atomic_max_float32( + ptr, + value: Float32, + *, + positive_only: bool = True, + loc=None, + ip=None, +) -> Float32: + """Perform atomic max operation on a float32 value in global memory. + + This implementation works correctly for non-negative values (>= 0) using direct bitcast. + + :param ptr: Pointer to the memory location + :param value: The float32 value to compare and potentially store (should be >= 0) + :return: The old value at the memory location + """ + value_int = llvm.bitcast(T.i32(), value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + + old_value_int = nvvm.atomicrmw( + res=T.i32(), + op=cutlass._mlir.dialects.nvvm.AtomicOpKind.MAX, + ptr=ptr, + a=value_int, + loc=loc, + ip=ip, + ) + + return Float32(llvm.bitcast(T.f32(), old_value_int, loc=loc, ip=ip)) + + +def atomic_add_float32( + ptr, + value: Float32, + *, + loc=None, + ip=None, +) -> Float32: + """Perform atomic add operation on a float32 value in global memory. + + :param ptr: Pointer to the memory location + :param value: The float32 value to add + :return: The old value at the memory location + """ + old_value = nvvm.atomicrmw( + res=T.f32(), + op=cutlass._mlir.dialects.nvvm.AtomicOpKind.FADD, + ptr=ptr, + a=value.ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + return Float32(old_value) + + +def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]: + """Compute the sigmoid function: 1 / (1 + exp(-a)). + + :param a: Input value + :param fastmath: Whether to use fast math approximations + :return: Sigmoid of input + """ + return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath)) + + +def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]: + """Compute the SiLU (Swish) activation: a * sigmoid(a). + + :param a: Input value + :param fastmath: Whether to use fast math approximations + :return: SiLU of input + """ + return a * sigmoid_f32(a, fastmath=fastmath) + + +############################################################################## +# Static persistent tile scheduler +############################################################################## + + +class WorkTileInfo: + """A class to represent information about a work tile. + + :ivar tile_idx: The index of the tile. + :type tile_idx: cute.Coord + :ivar is_valid_tile: Whether the tile is valid. + :type is_valid_tile: Boolean + """ + + def __init__(self, tile_idx: cute.Coord, is_valid_tile: Boolean): + self._tile_idx = tile_idx + self._is_valid_tile = Boolean(is_valid_tile) + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self.tile_idx) + values.extend(extract_mlir_values(self.is_valid_tile)) + return values + + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 4 + new_tile_idx = new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + @property + def is_valid_tile(self) -> Boolean: + """Check latest tile returned by the scheduler is valid or not. + + Any scheduling requests after all tasks completed will return an invalid tile. + + :return: The validity of the tile. + :rtype: Boolean + """ + return self._is_valid_tile + + @property + def tile_idx(self) -> cute.Coord: + """Get the index of the tile. + + :return: The index of the tile. + :rtype: cute.Coord + """ + return self._tile_idx + + +class PersistentTileSchedulerParams: + """A class to represent parameters for a persistent tile scheduler. + + This class is designed to manage and compute the layout of clusters and tiles + in a batched gemm problem. + + :ivar cluster_shape_mn: Shape of the cluster in (m, n) dimensions (K dimension cta count must be 1). + :type cluster_shape_mn: tuple + :ivar problem_layout_ncluster_mnl: Layout of the problem in terms of + number of clusters in (m, n, l) dimensions. + :type problem_layout_ncluster_mnl: cute.Layout + """ + + @dsl_user_op + def __init__( + self, + problem_shape_ntile_mnl: cute.Shape, + cluster_shape_mnk: cute.Shape, + raster_along_m: bool = True, + swizzle_size: int = 1, + *, + loc=None, + ip=None, + ): + """Initializes the PersistentTileSchedulerParams with the given parameters. + + :param problem_shape_ntile_mnl: The shape of the problem in terms of + number of CTA (Cooperative Thread Array) in (m, n, l) dimensions. + :type problem_shape_ntile_mnl: cute.Shape + :param cluster_shape_mnk: The shape of the cluster in (m, n) dimensions. + :type cluster_shape_mnk: cute.Shape + :param swizzle_size: Swizzling size in the unit of cluster. 1 means no swizzle + :type swizzle_size: int + :param raster_along_m: Rasterization order of clusters. Only used when swizzle_size > 1. + True means along M, false means along N. + :type raster_along_m: bool + + :raises ValueError: If cluster_shape_k is not 1. + """ + + if cluster_shape_mnk[2] != 1: + raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") + if swizzle_size < 1: + raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}") + + self.problem_shape_ntile_mnl = problem_shape_ntile_mnl + # cluster_shape_mnk is kept for reconstruction + self._cluster_shape_mnk = cluster_shape_mnk + self.cluster_shape_mn = cluster_shape_mnk[:2] + self.swizzle_size = swizzle_size + self._raster_along_m = raster_along_m + self._loc = loc + + # By default, we follow m major (col-major) raster order, so make a col-major layout + self.problem_layout_ncluster_mnl = cute.make_layout( + cute.ceil_div(self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + # Apply swizzle if swizzle_size > 1 + if swizzle_size > 1: + problem_shape_ncluster_mnl = cute.round_up( + self.problem_layout_ncluster_mnl.shape, + (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1), + ) + + if raster_along_m: + self.problem_layout_ncluster_mnl = cute.make_layout( + ( + problem_shape_ncluster_mnl[0], + (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size), + problem_shape_ncluster_mnl[2], + ), + stride=( + swizzle_size, + (1, swizzle_size * problem_shape_ncluster_mnl[0]), + problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], + ), + loc=loc, + ip=ip, + ) + else: + self.problem_layout_ncluster_mnl = cute.make_layout( + ( + (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size), + problem_shape_ncluster_mnl[1], + problem_shape_ncluster_mnl[2], + ), + stride=( + (1, swizzle_size * problem_shape_ncluster_mnl[1]), + swizzle_size, + problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], + ), + loc=loc, + ip=ip, + ) + + # Create FastDivmod divisors (only when swizzle_size == 1 for correctness) + # FastDivmod assumes simple col-major layout, incompatible with swizzled layouts + if swizzle_size == 1: + problem_layout_size = cute.size(self.problem_layout_ncluster_mnl, loc=loc, ip=ip) + cluster_count_m = self.problem_layout_ncluster_mnl.shape[0] + cluster_count_n = self.problem_layout_ncluster_mnl.shape[1] + + # batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling) + self.batch_fdd = cute.fast_divmod_create_divisor(problem_layout_size, loc=loc, ip=ip) + + # cluster_shape_m_fdd: Used to decode work_unit_id to cluster coordinates + self.cluster_shape_m_fdd = cute.fast_divmod_create_divisor(cluster_count_m, loc=loc, ip=ip) + + # cluster_shape_n_fdd: Used for the second level decomposition + self.cluster_shape_n_fdd = cute.fast_divmod_create_divisor(cluster_count_n, loc=loc, ip=ip) + else: + # FastDivmod not applicable with swizzling, set to None + self.batch_fdd = None + self.cluster_shape_m_fdd = None + self.cluster_shape_n_fdd = None + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.problem_shape_ntile_mnl, + self._cluster_shape_mnk, + self._raster_along_m, + self.swizzle_size, + ]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + + # Add FastDivmod divisors to MLIR values for Host->Device transfer + # Only add non-None values to avoid MLIR type errors + fastdivmod_values = [] + fastdivmod_indices = [] # Track which FastDivmod objects are present + + for i, (fdd_name, fdd_obj) in enumerate( + [ + ("batch_fdd", self.batch_fdd), + ("cluster_shape_m_fdd", self.cluster_shape_m_fdd), + ("cluster_shape_n_fdd", self.cluster_shape_n_fdd), + ] + ): + if fdd_obj is not None: + # Extract MLIR values from FastDivmodDivisor objects + fdd_values = extract_mlir_values(fdd_obj) + fastdivmod_values.extend(fdd_values) + fastdivmod_indices.append(i) + + values += fastdivmod_values + self._values_pos.append(len(fastdivmod_indices)) # Store count of FastDivmod objects, not values + self._fastdivmod_indices = fastdivmod_indices # Store for reconstruction + + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + values_copy = list(values) # Make a copy to avoid modifying original + + # Reconstruct original objects from MLIR values + for obj, n_items in zip( + [ + self.problem_shape_ntile_mnl, + self._cluster_shape_mnk, + self._raster_along_m, + self.swizzle_size, + ], + self._values_pos[:-1], # Exclude FastDivmod count + ): + obj_list.append(new_from_mlir_values(obj, values_copy[:n_items])) + values_copy = values_copy[n_items:] + + # Create new params object by calling __init__ with reconstructed values + # This properly recreates layouts and other derived attributes in the device context + new_params = PersistentTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + # Restore FastDivmod divisors from remaining values + fdd_names = ["batch_fdd", "cluster_shape_m_fdd", "cluster_shape_n_fdd"] + + if hasattr(self, "_fastdivmod_indices") and len(self._fastdivmod_indices) > 0: + # Override the FastDivmod divisors created by __init__ with reconstructed ones + for j, original_index in enumerate(self._fastdivmod_indices): + fdd_name = fdd_names[original_index] + # Get the original FastDivmodDivisor object + original_fdd = getattr(self, fdd_name) + if original_fdd is not None and j < len(values_copy): + # Each FastDivmodDivisor has 1 MLIR value + reconstructed_fdd = new_from_mlir_values(original_fdd, [values_copy[j]]) + setattr(new_params, fdd_name, reconstructed_fdd) + + return new_params + + @dsl_user_op + def get_grid_shape(self, max_active_clusters: Int32, *, loc=None, ip=None) -> Tuple[Integer, Integer, Integer]: + """Computes the grid shape based on the maximum active clusters allowed. + + :param max_active_clusters: The maximum number of active clusters that + can run in one wave. + :type max_active_clusters: Int32 + + :return: A tuple containing the grid shape in (m, n, persistent_clusters). + - m: self.cluster_shape_m. + - n: self.cluster_shape_n. + - persistent_clusters: Number of persistent clusters that can run. + """ + + # Total ctas in problem size + num_ctas_mnl = tuple(cute.size(x) * y for x, y in zip(self.problem_layout_ncluster_mnl.shape, self.cluster_shape_mn)) + ( + self.problem_layout_ncluster_mnl.shape[2], + ) + + num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip) + + num_ctas_per_cluster = cute.size(self.cluster_shape_mn, loc=loc, ip=ip) + # Total ctas that can run in one wave + num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster + + num_persistent_ctas = min(num_ctas_in_problem, num_ctas_per_wave) + num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster + + return (*self.cluster_shape_mn, num_persistent_clusters) + + +class StaticPersistentTileScheduler: + """A scheduler for static persistent tile execution in CUTLASS/CuTe kernels. + + :ivar params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl + :type params: PersistentTileSchedulerParams + :ivar num_persistent_clusters: Number of persistent clusters that can be launched + :type num_persistent_clusters: Int32 + :ivar cta_id_in_cluster: ID of the CTA within its cluster + :type cta_id_in_cluster: cute.Coord + :ivar _num_tiles_executed: Counter for executed tiles + :type _num_tiles_executed: Int32 + :ivar _current_work_linear_idx: Current cluster index + :type _current_work_linear_idx: Int32 + """ + + def __init__( + self, + params: PersistentTileSchedulerParams, + num_persistent_clusters: Int32, + current_work_linear_idx: Int32, + cta_id_in_cluster: cute.Coord, + num_tiles_executed: Int32, + ): + """Initializes the StaticPersistentTileScheduler with the given parameters. + + :param params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl. + :type params: PersistentTileSchedulerParams + :param num_persistent_clusters: Number of persistent clusters that can be launched. + :type num_persistent_clusters: Int32 + :param current_work_linear_idx: Current cluster index. + :type current_work_linear_idx: Int32 + :param cta_id_in_cluster: ID of the CTA within its cluster. + :type cta_id_in_cluster: cute.Coord + :param num_tiles_executed: Counter for executed tiles. + :type num_tiles_executed: Int32 + """ + self.params = params + self.num_persistent_clusters = num_persistent_clusters + self._current_work_linear_idx = current_work_linear_idx + self.cta_id_in_cluster = cta_id_in_cluster + self._num_tiles_executed = num_tiles_executed + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self.num_persistent_clusters) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self.cta_id_in_cluster)) + values.extend(extract_mlir_values(self._num_tiles_executed)) + + # CRITICAL: Also extract FastDivmod divisors from params + values.extend(extract_mlir_values(self.params)) + + return values + + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "StaticPersistentTileScheduler": + assert len(values) >= 6 + new_num_persistent_clusters = new_from_mlir_values(self.num_persistent_clusters, [values[0]]) + new_current_work_linear_idx = new_from_mlir_values(self._current_work_linear_idx, [values[1]]) + new_cta_id_in_cluster = new_from_mlir_values(self.cta_id_in_cluster, values[2:5]) + new_num_tiles_executed = new_from_mlir_values(self._num_tiles_executed, [values[5]]) + + # Reconstruct params with FastDivmod divisors + params_values = values[6:] # Remaining values are from params + new_params = new_from_mlir_values(self.params, params_values) + + return StaticPersistentTileScheduler( + new_params, # Use reconstructed params with FastDivmod divisors + new_num_persistent_clusters, + new_current_work_linear_idx, + new_cta_id_in_cluster, + new_num_tiles_executed, + ) + + @staticmethod + @dsl_user_op + def create( + params: PersistentTileSchedulerParams, + block_idx: Tuple[Integer, Integer, Integer], + grid_dim: Tuple[Integer, Integer, Integer], + *, + loc=None, + ip=None, + ): + """Initialize the static persistent tile scheduler. + + :param params: Parameters for the persistent tile scheduler. + :type params: PersistentTileSchedulerParams + :param block_idx: The 3d block index in the format (bidx, bidy, bidz). + :type block_idx: Tuple[Integer, Integer, Integer] + :param grid_dim: The 3d grid dimensions for kernel launch. + :type grid_dim: Tuple[Integer, Integer, Integer] + + :return: A StaticPersistentTileScheduler object. + :rtype: StaticPersistentTileScheduler + """ + + # Calculate the number of persistent clusters by dividing the total grid size + # by the number of CTAs per cluster + num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(params.cluster_shape_mn, loc=loc, ip=ip) + + bidx, bidy, bidz = block_idx + + # Initialize workload index equals to the cluster index in the grid + current_work_linear_idx = Int32(bidz) + + # CTA id in the cluster + cta_id_in_cluster = ( + Int32(bidx % params.cluster_shape_mn[0]), + Int32(bidy % params.cluster_shape_mn[1]), + Int32(0), + ) + # Initialize number of tiles executed to zero + num_tiles_executed = Int32(0) + return StaticPersistentTileScheduler( + params, + num_persistent_clusters, + current_work_linear_idx, + cta_id_in_cluster, + num_tiles_executed, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: PersistentTileSchedulerParams, + max_active_clusters: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Integer, Integer, Integer]: + """Calculates the grid shape to be launched on GPU using problem shape, + threadblock shape, and active cluster size. + + :param params: Parameters for grid shape calculation. + :type params: PersistentTileSchedulerParams + :param max_active_clusters: Maximum active clusters allowed. + :type max_active_clusters: Int32 + + :return: The calculated 3d grid shape. + :rtype: Tuple[Integer, Integer, Integer] + """ + + return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip) + + # private method + def _get_current_work_for_linear_idx(self, current_work_linear_idx: Int32, *, loc=None, ip=None) -> WorkTileInfo: + """Compute current tile coord given current_work_linear_idx and cta_id_in_cluster. + + :param current_work_linear_idx: The linear index of the current work. + :type current_work_linear_idx: Int32 + + :return: An object containing information about the current tile coordinates + and validity status. + :rtype: WorkTileInfo + """ + + is_valid = current_work_linear_idx < cute.size(self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip) + + # Choose coordinate calculation method based on swizzle configuration + if self.params.swizzle_size == 1: + # Use FastDivmod optimization for non-swizzled layouts + cur_cluster_coord = self._get_cluster_work_idx_with_fastdivmod(current_work_linear_idx, loc=loc, ip=ip) + else: + # Use get_flat_coord for swizzled layouts (FastDivmod doesn't support them) + cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_flat_coord(current_work_linear_idx, loc=loc, ip=ip) + + # cur_tile_coord is a tuple of i32 values + cur_tile_coord = tuple( + Int32(x) * Int32(z) + Int32(y) + for x, y, z in zip( + cur_cluster_coord, + self.cta_id_in_cluster, + (*self.params.cluster_shape_mn, Int32(1)), + ) + ) + + return WorkTileInfo(cur_tile_coord, is_valid) + + def _get_cluster_work_idx_with_fastdivmod(self, current_work_linear_idx: Int32, *, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """FastDivmod optimized CLUSTER coordinate calculation. + + CRITICAL: This should mimic problem_layout_ncluster_mnl.get_hier_coord() + which returns CLUSTER coordinates, not tile coordinates! + + :param current_work_linear_idx: Linear index in the work space + :type current_work_linear_idx: Int32 + :return: Cluster coordinates (m, n, l) or None if FastDivmod not available + :rtype: Tuple[Int32, Int32, Int32] or None + """ + + # Step 1: Handle persistent scheduling - map linear_idx to work_unit_id + work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd) + + # Step 2: Decode work_unit_id using FastDivmod objects + # The layout structure is: problem_layout_ncluster_mnl has shape (cluster_count_m, cluster_count_n, batch_count) + # work_unit_id needs to be decomposed into (batch_l, cluster_n, cluster_m) in little-endian order + + # First, get cluster_m using cluster_shape_m_fdd + cluster_n_batch, cluster_m = divmod(work_unit_id, self.params.cluster_shape_m_fdd) + + # Then decode cluster_n_batch to get cluster_n and batch_l using FastDivmod + batch_l, cluster_n = divmod(cluster_n_batch, self.params.cluster_shape_n_fdd) + + return (cluster_m, cluster_n, batch_l) + + @dsl_user_op + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + return self._get_current_work_for_linear_idx(self._current_work_linear_idx, loc=loc, ip=ip) + + @dsl_user_op + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + return self.get_current_work(loc=loc, ip=ip) + + @dsl_user_op + def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None): + self._current_work_linear_idx += Int32(advance_count) * Int32(self.num_persistent_clusters) + self._num_tiles_executed += Int32(1) + + @property + def num_tiles_executed(self) -> Int32: + return self._num_tiles_executed + + +class StaticPersistentRuntimeTileScheduler(StaticPersistentTileScheduler): + """A scheduler for static persistent runtime tile execution in CUTLASS/CuTe kernels. + + This scheduler will always launch all the SMs and the scheduler will generate + the real tile info for each SM. + + :ivar params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl + :type params: PersistentTileSchedulerParams + :ivar num_persistent_clusters: Number of persistent clusters that can be launched + :type num_persistent_clusters: Int32 + :ivar cta_id_in_cluster: ID of the CTA within its cluster + :type cta_id_in_cluster: cute.Coord + :ivar _num_tiles_executed: Counter for executed tiles + :type _num_tiles_executed: Int32 + :ivar _current_work_linear_idx: Current cluster index + :type _current_work_linear_idx: Int32 + """ + + def __init__( + self, + params: PersistentTileSchedulerParams, + num_persistent_clusters: Int32, + current_work_linear_idx: Int32, + cta_id_in_cluster: cute.Coord, + num_tiles_executed: Int32, + inner_mode: int = 1, + ): + """Initializes the StaticPersistentRuntimeTileScheduler with the given parameters. + + :param params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl. + :type params: PersistentTileSchedulerParams + :param num_persistent_clusters: Number of persistent clusters that can be launched. + :type num_persistent_clusters: Int32 + :param current_work_linear_idx: Current cluster index. + :type current_work_linear_idx: Int32 + :param cta_id_in_cluster: ID of the CTA within its cluster. + :type cta_id_in_cluster: cute.Coord + :param num_tiles_executed: Counter for executed tiles. + :type num_tiles_executed: Int32 + :param inner_mode: The inner mode along which the linear index will be decomposed first. + :type inner_mode: int + """ + super().__init__( + params, + num_persistent_clusters, + current_work_linear_idx, + cta_id_in_cluster, + num_tiles_executed, + ) + if inner_mode not in [0, 1]: + raise ValueError(f"inner_mode must be 0(for M mode) or 1(for N mode), but got {inner_mode}") + self.inner_mode = inner_mode + + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "StaticPersistentRuntimeTileScheduler": + assert len(values) >= 6 + new_num_persistent_clusters = new_from_mlir_values(self.num_persistent_clusters, [values[0]]) + new_current_work_linear_idx = new_from_mlir_values(self._current_work_linear_idx, [values[1]]) + new_cta_id_in_cluster = new_from_mlir_values(self.cta_id_in_cluster, values[2:5]) + new_num_tiles_executed = new_from_mlir_values(self._num_tiles_executed, [values[5]]) + + # Reconstruct params with FastDivmod divisors (same as parent class) + params_values = values[6:] # Remaining values are from params + new_params = new_from_mlir_values(self.params, params_values) + + return StaticPersistentRuntimeTileScheduler( + new_params, # Use reconstructed params with FastDivmod divisors + new_num_persistent_clusters, + new_current_work_linear_idx, + new_cta_id_in_cluster, + new_num_tiles_executed, + self.inner_mode, + ) + + @staticmethod + @dsl_user_op + def create( + params: PersistentTileSchedulerParams, + block_idx: Tuple[Integer, Integer, Integer], + grid_dim: Tuple[Integer, Integer, Integer], + inner_mode: int = 1, + *, + loc=None, + ip=None, + ): + """Initialize the static persistent tile scheduler. + + :param params: Parameters for the persistent tile scheduler. + :type params: PersistentTileSchedulerParams + :param block_idx: The 3d block index in the format (bidx, bidy, bidz). + :type block_idx: Tuple[Integer, Integer, Integer] + :param grid_dim: The 3d grid dimensions for kernel launch. + :type grid_dim: Tuple[Integer, Integer, Integer] + :param inner_mode: The inner mode along which the linear index will be decomposed first. + :type inner_mode: int + + :return: A StaticPersistentRuntimeTileScheduler object. + :rtype: StaticPersistentRuntimeTileScheduler + """ + + # Calculate the number of persistent clusters by dividing the total grid size + # by the number of CTAs per cluster + num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(params.cluster_shape_mn, loc=loc, ip=ip) + + bidx, bidy, bidz = block_idx + + # Initialize workload index equals to the cluster index in the grid + current_work_linear_idx = Int32(bidz) + + # CTA id in the cluster + cta_id_in_cluster = ( + Int32(bidx % params.cluster_shape_mn[0]), + Int32(bidy % params.cluster_shape_mn[1]), + Int32(0), + ) + # Initialize number of tiles executed to zero + num_tiles_executed = Int32(0) + return StaticPersistentRuntimeTileScheduler( + params, + num_persistent_clusters, + current_work_linear_idx, + cta_id_in_cluster, + num_tiles_executed, + inner_mode, + ) + + # private method + def _get_current_work_for_linear_idx(self, current_work_linear_idx: Int32, *, loc=None, ip=None) -> WorkTileInfo: + """Compute current tile coord given current_work_linear_idx and cta_id_in_cluster. + + :param current_work_linear_idx: The linear index of the current work. + :type current_work_linear_idx: Int32 + + :return: An object containing information about the current tile coordinates + and validity status. + :rtype: WorkTileInfo + """ + ntile_shape = self.params.problem_layout_ncluster_mnl.shape + int_max = 2147483647 + if const_expr(self.inner_mode == 1): + ntile_layout = cute.make_layout((int_max, ntile_shape[1]), stride=(ntile_shape[1], 1)) + else: + ntile_layout = cute.make_layout((ntile_shape[0], int_max), stride=(1, ntile_shape[0])) + cluster_tile_coord_mn = ntile_layout.get_hier_coord(current_work_linear_idx) + cur_tile_coord = ( + cluster_tile_coord_mn[0], + cluster_tile_coord_mn[1], + Int32(0), + ) + + # it is determined by kernel implementation + is_valid = True + + return WorkTileInfo(cur_tile_coord, is_valid) diff --git a/python/cudnn/native_sparse_attention/compression/api.py b/python/cudnn/native_sparse_attention/compression/api.py index 3931038c..f38f4251 100644 --- a/python/cudnn/native_sparse_attention/compression/api.py +++ b/python/cudnn/native_sparse_attention/compression/api.py @@ -94,38 +94,21 @@ def check_support(self) -> bool: b, h_q, s_qo, d_v = self.sample_o.shape if self.sample_q.shape != (b, h_qo, s_qo, d_qk): - raise ValueError( - f"Input shape mismatch: expected Q tensor shape {b, h_qo, s_qo, d_qk}, got {self.sample_q.shape}" - ) + raise ValueError(f"Input shape mismatch: expected Q tensor shape {b, h_qo, s_qo, d_qk}, got {self.sample_q.shape}") if self.sample_k.shape != (b, h_kv, s_kv, d_qk): - raise ValueError( - f"Input shape mismatch: expected K tensor shape {b, h_kv, s_kv, d_qk}, got {self.sample_k.shape}" - ) + raise ValueError(f"Input shape mismatch: expected K tensor shape {b, h_kv, s_kv, d_qk}, got {self.sample_k.shape}") if self.sample_v.shape != (b, h_kv, s_kv, d_v): - raise ValueError( - f"Input shape mismatch: expected V tensor shape {b, h_kv, s_kv, d_v}, got {self.sample_v.shape}" - ) + raise ValueError(f"Input shape mismatch: expected V tensor shape {b, h_kv, s_kv, d_v}, got {self.sample_v.shape}") if self.sample_o.shape != (b, h_q, s_qo, d_v): - raise ValueError( - f"Output shape mismatch: expected O tensor shape {b, h_q, s_qo, d_v}, got {self.sample_o.shape}" - ) + raise ValueError(f"Output shape mismatch: expected O tensor shape {b, h_q, s_qo, d_v}, got {self.sample_o.shape}") if self.enable_lse: - self.sample_lse = self._unpad_tensor_to_ndim( - self.sample_lse, 3, "sample_lse" - ) + self.sample_lse = self._unpad_tensor_to_ndim(self.sample_lse, 3, "sample_lse") if self.sample_lse.shape != (b, h_q, s_qo): - raise ValueError( - f"Output shape mismatch: expected LSE tensor shape {b, h_q, s_qo}, got {self.sample_lse.shape}" - ) + raise ValueError(f"Output shape mismatch: expected LSE tensor shape {b, h_q, s_qo}, got {self.sample_lse.shape}") if not self.sample_lse.is_contiguous(): raise ValueError("LSE tensor must be contiguous") - if ( - self.sample_cum_seqlen_q is not None - or self.sample_cum_seqlen_k is not None - ): - self._logger.warning( - "sample_cum_seqlen_q and sample_cum_seqlen_k are ignored for B,H,S,D layout" - ) + if self.sample_cum_seqlen_q is not None or self.sample_cum_seqlen_k is not None: + self._logger.warning("sample_cum_seqlen_q and sample_cum_seqlen_k are ignored for B,H,S,D layout") # Shapes self.batch_size = b @@ -144,40 +127,24 @@ def check_support(self) -> bool: t, h_q, d_v = self.sample_o.shape if self.sample_q.shape != (t, h_q, d_qk): - raise ValueError( - f"Input shape mismatch: expected Q tensor shape {t, h_q, d_qk}, got {self.sample_q.shape}" - ) + raise ValueError(f"Input shape mismatch: expected Q tensor shape {t, h_q, d_qk}, got {self.sample_q.shape}") if self.sample_k.shape != (t_kv, h_kv, d_qk): - raise ValueError( - f"Input shape mismatch: expected K tensor shape {t_kv, h_kv, d_qk}, got {self.sample_k.shape}" - ) + raise ValueError(f"Input shape mismatch: expected K tensor shape {t_kv, h_kv, d_qk}, got {self.sample_k.shape}") if self.sample_v.shape != (t_kv, h_kv, d_v): - raise ValueError( - f"Input shape mismatch: expected V tensor shape {t_kv, h_kv, d_v}, got {self.sample_v.shape}" - ) + raise ValueError(f"Input shape mismatch: expected V tensor shape {t_kv, h_kv, d_v}, got {self.sample_v.shape}") if self.sample_o.shape != (t, h_q, d_v): - raise ValueError( - f"Output shape mismatch: expected O tensor shape {t, h_q, d_v}, got {self.sample_o.shape}" - ) + raise ValueError(f"Output shape mismatch: expected O tensor shape {t, h_q, d_v}, got {self.sample_o.shape}") if self.enable_lse: - self.sample_lse = self._unpad_tensor_to_ndim( - self.sample_lse, 2, "sample_lse" - ) + self.sample_lse = self._unpad_tensor_to_ndim(self.sample_lse, 2, "sample_lse") if self.sample_lse.shape != (t, h_q): - raise ValueError( - f"Output shape mismatch: expected LSE tensor shape {t, h_q}, got {self.sample_lse.shape}" - ) + raise ValueError(f"Output shape mismatch: expected LSE tensor shape {t, h_q}, got {self.sample_lse.shape}") if self.sample_cum_seqlen_q is None or self.sample_cum_seqlen_k is None: raise ValueError( f"sample_cum_seqlen_q and sample_cum_seqlen_k must be provided for T,H,D layout, got {self.sample_cum_seqlen_q} and {self.sample_cum_seqlen_k}" ) - self.sample_cum_seqlen_q = self._unpad_tensor_to_ndim( - self.sample_cum_seqlen_q, 1, "sample_cum_seqlen_q" - ) - self.sample_cum_seqlen_k = self._unpad_tensor_to_ndim( - self.sample_cum_seqlen_k, 1, "sample_cum_seqlen_k" - ) + self.sample_cum_seqlen_q = self._unpad_tensor_to_ndim(self.sample_cum_seqlen_q, 1, "sample_cum_seqlen_q") + self.sample_cum_seqlen_k = self._unpad_tensor_to_ndim(self.sample_cum_seqlen_k, 1, "sample_cum_seqlen_k") if self.sample_cum_seqlen_q.ndim != 1 or self.sample_cum_seqlen_k.ndim != 1: raise ValueError( f"sample_cum_seqlen_q and sample_cum_seqlen_k must be 1D tensors, got {self.sample_cum_seqlen_q.ndim}D and {self.sample_cum_seqlen_k.ndim}D" @@ -203,9 +170,7 @@ def check_support(self) -> bool: self.head_dim = d_qk else: - raise ValueError( - f"Invalid input layout: sample_q must be rank-3 (T,H,D) or rank-4 (B,H,S,D), got {self.sample_q.ndim}" - ) + raise ValueError(f"Invalid input layout: sample_q must be rank-3 (T,H,D) or rank-4 (B,H,S,D), got {self.sample_q.ndim}") if d_qk != d_v: raise ValueError("D_qk must match D_v") if d_qk not in {32, 64, 128}: @@ -217,25 +182,15 @@ def check_support(self) -> bool: in_dtype = self.sample_q.dtype out_dtype = self.sample_o.dtype if self.sample_k.dtype != in_dtype or self.sample_v.dtype != in_dtype: - raise ValueError( - f"Inputs must have the same dtype, got K {self.sample_k.dtype}, V {self.sample_v.dtype} for Q {in_dtype}" - ) + raise ValueError(f"Inputs must have the same dtype, got K {self.sample_k.dtype}, V {self.sample_v.dtype} for Q {in_dtype}") if in_dtype not in {torch.float16, torch.bfloat16, torch.float8_e4m3fn}: - raise ValueError( - f"Inputs must be Float16, BFloat16, or Float8E4M3FN, got {in_dtype}" - ) + raise ValueError(f"Inputs must be Float16, BFloat16, or Float8E4M3FN, got {in_dtype}") if out_dtype not in {torch.float16, torch.bfloat16, torch.float8_e4m3fn}: - raise ValueError( - f"Outputs must be Float16, BFloat16, or Float8E4M3FN, got {out_dtype}" - ) + raise ValueError(f"Outputs must be Float16, BFloat16, or Float8E4M3FN, got {out_dtype}") if self.qk_acc_dtype_torch not in {torch.float32}: - raise ValueError( - f"qk_acc_dtype must be Float32, got {self.qk_acc_dtype_torch}" - ) + raise ValueError(f"qk_acc_dtype must be Float32, got {self.qk_acc_dtype_torch}") if self.pv_acc_dtype_torch not in {torch.float32}: - raise ValueError( - f"pv_acc_dtype must be Float32, got {self.pv_acc_dtype_torch}" - ) + raise ValueError(f"pv_acc_dtype must be Float32, got {self.pv_acc_dtype_torch}") # Scale defaults if self.scale_softmax is None: @@ -250,9 +205,7 @@ def check_support(self) -> bool: major, minor = torch.cuda.get_device_capability(device) compute_capability = major * 10 + minor if compute_capability < 100: - raise RuntimeError( - f"CompressionAttention requires SM100+ compute capability, but found SM{compute_capability} on device {device}" - ) + raise RuntimeError(f"CompressionAttention requires SM100+ compute capability, but found SM{compute_capability} on device {device}") if compute_capability == 103: raise RuntimeError("cuteDSL is not supported on SM103") @@ -279,16 +232,8 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: scale_softmax_log2 = scale_softmax * log2_e scale_output = self.scale_v * self.inv_scale_o - s_q = ( - self.s_q - if self.input_layout == "B,H,S,D" - else max(self.sample_cum_seqlen_q).item() - ) - s_kv = ( - self.s_kv - if self.input_layout == "B,H,S,D" - else max(self.sample_cum_seqlen_k).item() - ) + s_q = self.s_q if self.input_layout == "B,H,S,D" else max(self.sample_cum_seqlen_q).item() + s_kv = self.s_kv if self.input_layout == "B,H,S,D" else max(self.sample_cum_seqlen_k).item() self.problem_size = ( self.batch_size, s_q, @@ -303,50 +248,18 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: self._compiled_kernel = cute.compile( fmha_kernel, q_iter=from_dlpack(self.sample_q, assumed_align=16).iterator, - q_stride=( - self.sample_q.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (self.sample_q.stride()[0], *self.sample_q.stride()) - ), + q_stride=(self.sample_q.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (self.sample_q.stride()[0], *self.sample_q.stride())), k_iter=from_dlpack(self.sample_k, assumed_align=16).iterator, - k_stride=( - self.sample_k.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (self.sample_k.stride()[0], *self.sample_k.stride()) - ), + k_stride=(self.sample_k.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (self.sample_k.stride()[0], *self.sample_k.stride())), v_iter=from_dlpack(self.sample_v, assumed_align=16).iterator, - v_stride=( - self.sample_v.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (self.sample_v.stride()[0], *self.sample_v.stride()) - ), + v_stride=(self.sample_v.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (self.sample_v.stride()[0], *self.sample_v.stride())), o_iter=from_dlpack(self.sample_o, assumed_align=16).iterator, - o_stride=( - self.sample_o.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (self.sample_o.stride()[0], *self.sample_o.stride()) - ), + o_stride=(self.sample_o.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (self.sample_o.stride()[0], *self.sample_o.stride())), problem_size=self.problem_size, - cum_seqlen_q=( - from_dlpack(self.sample_cum_seqlen_q, assumed_align=16) - if self.input_layout == "T,H,D" - else None - ), - cum_seqlen_k=( - from_dlpack(self.sample_cum_seqlen_k, assumed_align=16) - if self.input_layout == "T,H,D" - else None - ), - lse_iter=( - from_dlpack(self.sample_lse, assumed_align=16).iterator - if self.enable_lse - else None - ), - lse_stride=( - self.sample_lse.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (0, *self.sample_lse.stride()) - ), + cum_seqlen_q=(from_dlpack(self.sample_cum_seqlen_q, assumed_align=16) if self.input_layout == "T,H,D" else None), + cum_seqlen_k=(from_dlpack(self.sample_cum_seqlen_k, assumed_align=16) if self.input_layout == "T,H,D" else None), + lse_iter=(from_dlpack(self.sample_lse, assumed_align=16).iterator if self.enable_lse else None), + lse_stride=(self.sample_lse.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (0, *self.sample_lse.stride())), scale_softmax_log2=scale_softmax_log2, scale_softmax=scale_softmax, scale_output=scale_output, @@ -378,23 +291,15 @@ def execute( if self.enable_lse: if lse_tensor is None: - raise ValueError( - "kernel was compiled with lse_tensor provided, but lse_tensor was not provided during execute" - ) - lse_tensor = self._unpad_tensor_to_ndim( - lse_tensor, o_tensor.ndim - 1, "lse_tensor" - ) + raise ValueError("kernel was compiled with lse_tensor provided, but lse_tensor was not provided during execute") + lse_tensor = self._unpad_tensor_to_ndim(lse_tensor, o_tensor.ndim - 1, "lse_tensor") if self.input_layout == "T,H,D": if cum_seqlen_q_tensor is None or cum_seqlen_k_tensor is None: raise ValueError( f"cum_seqlen_q_tensor and cum_seqlen_k_tensor must be provided for T,H,D layout, got {cum_seqlen_q_tensor} and {cum_seqlen_k_tensor}" ) - cum_seqlen_q_tensor = self._unpad_tensor_to_ndim( - cum_seqlen_q_tensor, 1, "cum_seqlen_q_tensor" - ) - cum_seqlen_k_tensor = self._unpad_tensor_to_ndim( - cum_seqlen_k_tensor, 1, "cum_seqlen_k_tensor" - ) + cum_seqlen_q_tensor = self._unpad_tensor_to_ndim(cum_seqlen_q_tensor, 1, "cum_seqlen_q_tensor") + cum_seqlen_k_tensor = self._unpad_tensor_to_ndim(cum_seqlen_k_tensor, 1, "cum_seqlen_k_tensor") # Scale values scale_q = self.scale_q if scale_q is None else scale_q @@ -413,53 +318,25 @@ def execute( self._logger.debug("Executing with compiled kernel") self._compiled_kernel( q_iter=from_dlpack( - ( - q_tensor.transpose(1, 2) - if self.input_layout == "B,H,S,D" - else q_tensor - ), + (q_tensor.transpose(1, 2) if self.input_layout == "B,H,S,D" else q_tensor), assumed_align=16, ).iterator, k_iter=from_dlpack( - ( - k_tensor.transpose(1, 2) - if self.input_layout == "B,H,S,D" - else k_tensor - ), + (k_tensor.transpose(1, 2) if self.input_layout == "B,H,S,D" else k_tensor), assumed_align=16, ).iterator, v_iter=from_dlpack( - ( - v_tensor.transpose(1, 2) - if self.input_layout == "B,H,S,D" - else v_tensor - ), + (v_tensor.transpose(1, 2) if self.input_layout == "B,H,S,D" else v_tensor), assumed_align=16, ).iterator, o_iter=from_dlpack( - ( - o_tensor.transpose(1, 2) - if self.input_layout == "B,H,S,D" - else o_tensor - ), + (o_tensor.transpose(1, 2) if self.input_layout == "B,H,S,D" else o_tensor), assumed_align=16, ).iterator, problem_size=self.problem_size, - cum_seqlen_q=( - from_dlpack(cum_seqlen_q_tensor, assumed_align=16).iterator - if self.input_layout == "T,H,D" - else None - ), - cum_seqlen_k=( - from_dlpack(cum_seqlen_k_tensor, assumed_align=16).iterator - if self.input_layout == "T,H,D" - else None - ), - lse_iter=( - from_dlpack(lse_tensor, assumed_align=16).iterator - if self.enable_lse - else None - ), + cum_seqlen_q=(from_dlpack(cum_seqlen_q_tensor, assumed_align=16).iterator if self.input_layout == "T,H,D" else None), + cum_seqlen_k=(from_dlpack(cum_seqlen_k_tensor, assumed_align=16).iterator if self.input_layout == "T,H,D" else None), + lse_iter=(from_dlpack(lse_tensor, assumed_align=16).iterator if self.enable_lse else None), scale_softmax_log2=scale_softmax_log2_val, scale_softmax=scale_softmax_val, scale_output=scale_output_val, @@ -479,78 +356,30 @@ def execute( ) fmha_kernel( q_iter=from_dlpack( - ( - q_tensor.transpose(1, 2) - if self.input_layout == "B,H,S,D" - else q_tensor - ), + (q_tensor.transpose(1, 2) if self.input_layout == "B,H,S,D" else q_tensor), assumed_align=16, ).iterator, - q_stride=( - q_tensor.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (q_tensor.stride()[0], *q_tensor.stride()) - ), + q_stride=(q_tensor.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (q_tensor.stride()[0], *q_tensor.stride())), k_iter=from_dlpack( - ( - k_tensor.transpose(1, 2) - if self.input_layout == "B,H,S,D" - else k_tensor - ), + (k_tensor.transpose(1, 2) if self.input_layout == "B,H,S,D" else k_tensor), assumed_align=16, ).iterator, - k_stride=( - k_tensor.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (k_tensor.stride()[0], *k_tensor.stride()) - ), + k_stride=(k_tensor.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (k_tensor.stride()[0], *k_tensor.stride())), v_iter=from_dlpack( - ( - v_tensor.transpose(1, 2) - if self.input_layout == "B,H,S,D" - else v_tensor - ), + (v_tensor.transpose(1, 2) if self.input_layout == "B,H,S,D" else v_tensor), assumed_align=16, ).iterator, - v_stride=( - v_tensor.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (v_tensor.stride()[0], *v_tensor.stride()) - ), + v_stride=(v_tensor.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (v_tensor.stride()[0], *v_tensor.stride())), o_iter=from_dlpack( - ( - o_tensor.transpose(1, 2) - if self.input_layout == "B,H,S,D" - else o_tensor - ), + (o_tensor.transpose(1, 2) if self.input_layout == "B,H,S,D" else o_tensor), assumed_align=16, ).iterator, - o_stride=( - o_tensor.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (o_tensor.stride()[0], *o_tensor.stride()) - ), + o_stride=(o_tensor.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (o_tensor.stride()[0], *o_tensor.stride())), problem_size=self.problem_size, - cum_seqlen_q=( - from_dlpack(cum_seqlen_q_tensor, assumed_align=16) - if self.input_layout == "T,H,D" - else None - ), - cum_seqlen_k=( - from_dlpack(cum_seqlen_k_tensor, assumed_align=16) - if self.input_layout == "T,H,D" - else None - ), - lse_iter=( - from_dlpack(lse_tensor, assumed_align=16).iterator - if self.enable_lse - else None - ), - lse_stride=( - lse_tensor.transpose(1, 2).stride() - if self.input_layout == "B,H,S,D" - else (0, *lse_tensor.stride()) - ), + cum_seqlen_q=(from_dlpack(cum_seqlen_q_tensor, assumed_align=16) if self.input_layout == "T,H,D" else None), + cum_seqlen_k=(from_dlpack(cum_seqlen_k_tensor, assumed_align=16) if self.input_layout == "T,H,D" else None), + lse_iter=(from_dlpack(lse_tensor, assumed_align=16).iterator if self.enable_lse else None), + lse_stride=(lse_tensor.transpose(1, 2).stride() if self.input_layout == "B,H,S,D" else (0, *lse_tensor.stride())), scale_softmax_log2=scale_softmax_log2_val, scale_softmax=scale_softmax_val, scale_output=scale_output_val, @@ -592,9 +421,7 @@ def compression_attention_wrapper( Returns: tuple: (o_tensor, lse_tensor | None) """ - _logger.debug( - "compression_attention_wrapper: Creating empty output tensor o and optional lse" - ) + _logger.debug("compression_attention_wrapper: Creating empty output tensor o and optional lse") o_tensor, lse_tensor = None, None o_dtype = o_dtype if o_dtype is not None else q_tensor.dtype @@ -602,30 +429,18 @@ def compression_attention_wrapper( b, h_q, s_q, d = q_tensor.shape _, h_k, s_k, d_v = v_tensor.shape - o_tensor = make_tensor_strided_like( - q_tensor, (b, h_q, s_q, d_v), dtype=o_dtype, device=q_tensor.device - ) + o_tensor = make_tensor_strided_like(q_tensor, (b, h_q, s_q, d_v), dtype=o_dtype, device=q_tensor.device) if enable_lse: - lse_tensor = torch.empty( - b, h_q, s_q, dtype=torch.float32, device=q_tensor.device - ).contiguous() + lse_tensor = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q_tensor.device).contiguous() elif q_tensor.ndim == 3: # thd t, h_q, d = q_tensor.shape _, h_k, d_v = v_tensor.shape - o_tensor = make_tensor_strided_like( - q_tensor, (t, h_q, d_v), dtype=o_dtype, device=q_tensor.device - ) + o_tensor = make_tensor_strided_like(q_tensor, (t, h_q, d_v), dtype=o_dtype, device=q_tensor.device) if enable_lse: - lse_tensor = ( - torch.empty(1, h_q, t, dtype=torch.float32, device=q_tensor.device) - .contiguous() - .permute(2, 1, 0) - ) + lse_tensor = torch.empty(1, h_q, t, dtype=torch.float32, device=q_tensor.device).contiguous().permute(2, 1, 0) else: - raise ValueError( - f"Invalid input layout: q_tensor must be rank-4 (B,H,S,D) or rank-3 (T,H,D), got {q_tensor.ndim}" - ) + raise ValueError(f"Invalid input layout: q_tensor must be rank-4 (B,H,S,D) or rank-3 (T,H,D), got {q_tensor.ndim}") cache_key = ( q_tensor.shape, @@ -656,9 +471,7 @@ def compression_attention_wrapper( scale_softmax, ) if cache_key in _cache_of_CompressionAttentionObjects: - _logger.debug( - "compression_attention_wrapper: Using previously cached CompressionAttention object" - ) + _logger.debug("compression_attention_wrapper: Using previously cached CompressionAttention object") comp_attn = _cache_of_CompressionAttentionObjects[cache_key] comp_attn.execute( q_tensor=q_tensor, @@ -671,9 +484,7 @@ def compression_attention_wrapper( current_stream=stream, ) else: - _logger.debug( - "compression_attention_wrapper: No cached object found, creating new CompressionAttention object" - ) + _logger.debug("compression_attention_wrapper: No cached object found, creating new CompressionAttention object") comp_attn = CompressionAttention( sample_q=q_tensor, sample_k=k_tensor, diff --git a/python/cudnn/native_sparse_attention/compression/fmha.py b/python/cudnn/native_sparse_attention/compression/fmha.py index 0cd789d1..8eb72709 100644 --- a/python/cudnn/native_sparse_attention/compression/fmha.py +++ b/python/cudnn/native_sparse_attention/compression/fmha.py @@ -208,10 +208,7 @@ def __init__( self.buffer_align_bytes = 1024 num_warps_per_warpgroup = 4 - self.softmax_warpgroup_count = ( - len((*self.softmax0_warp_ids, *self.softmax1_warp_ids)) - // num_warps_per_warpgroup - ) + self.softmax_warpgroup_count = len((*self.softmax0_warp_ids, *self.softmax1_warp_ids)) // num_warps_per_warpgroup def _setup_attributes(self): """Set up configurations and parameters for the FMHA kernel operation. @@ -511,9 +508,7 @@ class SharedStorage: mma_s1_mbar_ptr: cute.struct.MemRange[Int64, self.mma_softmax_stage * 2] s0_corr_mbar_ptr: cute.struct.MemRange[Int64, self.softmax_corr_stage * 2] s1_corr_mbar_ptr: cute.struct.MemRange[Int64, self.softmax_corr_stage * 2] - s0_s1_sequence_mbar_ptr: cute.struct.MemRange[ - Int64, self.softmax_warpgroup_count - ] + s0_s1_sequence_mbar_ptr: cute.struct.MemRange[Int64, self.softmax_warpgroup_count] corr_epi_mbar_ptr: cute.struct.MemRange[Int64, self.epi_stage * 2] mma_corr_mbar_ptr: cute.struct.MemRange[Int64, self.mma_corr_stage * 2] tmem_dealloc_mbar_ptr: cute.struct.MemRange[Int64, 1] @@ -686,69 +681,45 @@ def kernel( mma_s0_producer, mma_s0_consumer = pipeline.PipelineUmmaAsync.create( num_stages=self.mma_softmax_stage, producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), - consumer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.softmax0_warp_ids) - ), + consumer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.softmax0_warp_ids)), barrier_storage=storage.mma_s0_mbar_ptr.data_ptr(), ).make_participants() mma_s1_producer, mma_s1_consumer = pipeline.PipelineUmmaAsync.create( num_stages=self.mma_softmax_stage, producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), - consumer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.softmax1_warp_ids) - ), + consumer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.softmax1_warp_ids)), barrier_storage=storage.mma_s1_mbar_ptr.data_ptr(), ).make_participants() s0_corr_producer, s0_corr_consumer = pipeline.PipelineAsync.create( num_stages=self.softmax_corr_stage, - producer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.softmax0_warp_ids) - ), - consumer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.correction_warp_ids) - ), + producer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.softmax0_warp_ids)), + consumer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.correction_warp_ids)), barrier_storage=storage.s0_corr_mbar_ptr.data_ptr(), ).make_participants() s1_corr_producer, s1_corr_consumer = pipeline.PipelineAsync.create( num_stages=self.softmax_corr_stage, - producer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.softmax1_warp_ids) - ), - consumer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.correction_warp_ids) - ), + producer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.softmax1_warp_ids)), + consumer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.correction_warp_ids)), barrier_storage=storage.s1_corr_mbar_ptr.data_ptr(), ).make_participants() corr_epi_producer, corr_epi_consumer = pipeline.PipelineAsync.create( num_stages=self.epi_stage, - producer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.correction_warp_ids) - ), - consumer_group=make_thread_cooperative_group( - self.threads_per_warp * len([self.epilogue_warp_id]) - ), + producer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.correction_warp_ids)), + consumer_group=make_thread_cooperative_group(self.threads_per_warp * len([self.epilogue_warp_id])), barrier_storage=storage.corr_epi_mbar_ptr.data_ptr(), ).make_participants() mma_corr_producer, mma_corr_consumer = pipeline.PipelineUmmaAsync.create( num_stages=self.mma_corr_stage, producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), - consumer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.correction_warp_ids) - ), + consumer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.correction_warp_ids)), barrier_storage=storage.mma_corr_mbar_ptr.data_ptr(), ).make_participants() - s0_s1_sequence_producer, s0_s1_sequence_consumer = ( - pipeline.PipelineAsync.create( - num_stages=1, - producer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.softmax0_warp_ids) - ), - consumer_group=make_thread_cooperative_group( - self.threads_per_warp * len(self.softmax1_warp_ids) - ), - barrier_storage=storage.s0_s1_sequence_mbar_ptr.data_ptr(), - ).make_participants() - ) + s0_s1_sequence_producer, s0_s1_sequence_consumer = pipeline.PipelineAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.softmax0_warp_ids)), + consumer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.softmax1_warp_ids)), + barrier_storage=storage.s0_s1_sequence_mbar_ptr.data_ptr(), + ).make_participants() tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() # Correction & Epilogue & tmem barrier init @@ -768,32 +739,22 @@ def kernel( # Generate smem tensor Q/K/V/O # (MMA, MMA_Q, MMA_D, PIPE) - sQ = storage.sQ.get_tensor( - q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner - ) + sQ = storage.sQ.get_tensor(q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner) # (MMA, MMA_K, MMA_D, PIPE) - sK = storage.sK.get_tensor( - k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner - ) + sK = storage.sK.get_tensor(k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner) # (MMA, MMA_K, MMA_D, PIPE) # Strip swizzle info to reuse smem sV_ptr = cute.recast_ptr(sK.iterator, v_smem_layout_staged.inner) sV = cute.make_tensor(sV_ptr, v_smem_layout_staged.outer) - sO = storage.sO.get_tensor( - o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner - ) + sO = storage.sO.get_tensor(o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner) qk_thr_mma = qk_tiled_mma.get_slice(0) # default 1sm pv_thr_mma = pv_tiled_mma.get_slice(0) # default 1sm tSrQ = qk_thr_mma.make_fragment_A(sQ) tSrK = qk_thr_mma.make_fragment_B(sK) tOrV = pv_thr_mma.make_fragment_B(sV) - qk_acc_shape = qk_thr_mma.partition_shape_C( - (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) - ) + qk_acc_shape = qk_thr_mma.partition_shape_C((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) tStS = qk_thr_mma.make_fragment_C(qk_acc_shape) - pv_acc_shape = pv_thr_mma.partition_shape_C( - (self.pv_mma_tiler[0], self.pv_mma_tiler[1]) - ) + pv_acc_shape = pv_thr_mma.partition_shape_C((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) tOtO = pv_thr_mma.make_fragment_C(pv_acc_shape) tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) @@ -804,13 +765,11 @@ def kernel( tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] tOrP0 = cute.make_tensor( - tOrP.iterator - + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, tOrP.layout, ) tOrP1 = cute.make_tensor( - tOrP.iterator - + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p1_offset, + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p1_offset, tOrP.layout, ) cute.arch.barrier( @@ -829,9 +788,7 @@ def kernel( if warp_idx == self.load_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - tile_sched = fmha_utils.create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = fmha_utils.create_fmha_static_tile_scheduler(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: @@ -892,9 +849,7 @@ def kernel( # Local tile partition global tensors # (bM, bK, loopM, loopK, loopL) - gQ_qdl = cute.flat_divide( - mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2]) - ) + gQ_qdl = cute.flat_divide(mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2])) tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( tma_atom_q, @@ -905,9 +860,7 @@ def kernel( ) tQgQ = tQgQ_qdl[None, None, 0, curr_block_coord_q[2]] - gK_kdl = cute.flat_divide( - mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2]) - ) + gK_kdl = cute.flat_divide(mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2])) tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( tma_atom_k, @@ -918,9 +871,7 @@ def kernel( ) tKgK = tKgK_kdl[None, None, 0, curr_block_coord_kv[2]] - gV_dkl = cute.flat_divide( - mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2]) - ) + gV_dkl = cute.flat_divide(mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2])) tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( tma_atom_v, @@ -1026,9 +977,7 @@ def kernel( barrier_id=self.tmem_alloc_sync_bar_id, number_of_threads=self.threads_per_warp, ) - tile_sched = fmha_utils.create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = fmha_utils.create_fmha_static_tile_scheduler(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: @@ -1153,9 +1102,7 @@ def kernel( tSrKi = tSrK[None, None, None, k_handle.index] # 2. gemm inner_num_kphases = cute.size(tSrQ0, mode=[2]) - for kphase_idx in cutlass.range( - inner_num_kphases, unroll_full=True - ): + for kphase_idx in cutlass.range(inner_num_kphases, unroll_full=True): kphase_coord = (None, None, kphase_idx) qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( @@ -1176,9 +1123,7 @@ def kernel( s1_handle = mma_s1_producer.acquire_and_advance() # 3. gemm inner_num_kphases = cute.size(tOrP0, mode=[2]) - for kphase_idx in cutlass.range( - inner_num_kphases, unroll_full=True - ): + for kphase_idx in cutlass.range(inner_num_kphases, unroll_full=True): kphase_coord = (None, None, kphase_idx) pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) cute.gemm( @@ -1198,9 +1143,7 @@ def kernel( # GEMM_QK1i (Q1 * Ki -> S1), Q1 is ready in GEMM_QK10; Ki is ready in GEMM_QK0i # 1. gemm inner_num_kphases = cute.size(tSrQ1, mode=[2]) - for kphase_idx in cutlass.range( - inner_num_kphases, unroll_full=True - ): + for kphase_idx in cutlass.range(inner_num_kphases, unroll_full=True): kphase_coord = (None, None, kphase_idx) qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) cute.gemm( @@ -1225,9 +1168,7 @@ def kernel( s0_handle = mma_s0_producer.acquire_and_advance() # 4. gemm inner_num_kphases = cute.size(tOrP0, mode=[2]) - for kphase_idx in cutlass.range( - inner_num_kphases, unroll_full=True - ): + for kphase_idx in cutlass.range(inner_num_kphases, unroll_full=True): kphase_coord = (None, None, kphase_idx) pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) cute.gemm( @@ -1296,9 +1237,7 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.epilogue_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - tile_sched = fmha_utils.create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = fmha_utils.create_fmha_static_tile_scheduler(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: @@ -1334,9 +1273,7 @@ def kernel( o0_coord = 2 * curr_block_coord_o[0] o1_coord = o0_coord + 1 - gO_qdl = cute.flat_divide( - mO_qdl_, cute.select(self.pv_mma_tiler, mode=[0, 1]) - ) + gO_qdl = cute.flat_divide(mO_qdl_, cute.select(self.pv_mma_tiler, mode=[0, 1])) gO = gO_qdl[None, None, None, 0, curr_block_coord_o[2]] tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( tma_atom_o, @@ -1403,10 +1340,7 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Softmax1 # /////////////////////////////////////////////////////////////////////////////// - if ( - warp_idx < self.correction_warp_ids[0] - and warp_idx >= self.softmax1_warp_ids[0] - ): + if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]: # increase register after decreasing cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) @@ -1441,12 +1375,8 @@ def kernel( tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) - tStS_vec0 = cute.make_tensor( - tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout - ) - tStS_vec1 = cute.make_tensor( - tStS.iterator + self.tmem_vec1_offset, tStS_vec_layout - ) + tStS_vec0 = cute.make_tensor(tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout) + tStS_vec1 = cute.make_tensor(tStS.iterator + self.tmem_vec1_offset, tStS_vec_layout) tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) @@ -1464,9 +1394,7 @@ def kernel( tTMEM_LOAD_VECtS1 = thr_tmem_load_vec.partition_S(tStS_vec1) tTMEM_LOAD_VECcS = thr_tmem_load_vec.partition_D(tScS_vec) - tile_sched = fmha_utils.create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = fmha_utils.create_fmha_static_tile_scheduler(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: @@ -1494,9 +1422,7 @@ def kernel( ) if not continue_cond: - row_idx = ( - curr_block_coord[0] * self.cta_tiler[0] + tTMEM_LOAD_VECcS[0][0] - ) + row_idx = curr_block_coord[0] * self.cta_tiler[0] + tTMEM_LOAD_VECcS[0][0] if cutlass.const_expr(cum_seqlen_k is not None): cuseqlen_k = cum_seqlen_k[batch_coord] seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k @@ -1520,15 +1446,9 @@ def kernel( for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): # wait for vec0 (row_wise current max & previous max) vec0_handle = s0_corr_consumer.wait_and_advance() - tTMEM_LOAD_VECrS = cute.make_rmem_tensor( - tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype - ) - cute.copy( - tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS - ) - scale_ = scale_softmax_log2 * ( - tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] - ) + tTMEM_LOAD_VECrS = cute.make_rmem_tensor(tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) + scale_ = scale_softmax_log2 * (tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1]) scale = cute.math.exp2(scale_, fastmath=True) # wait for o0 @@ -1541,12 +1461,8 @@ def kernel( # wait for vec1 (row_wise current max & previous max) vec1_handle = s1_corr_consumer.wait_and_advance() - cute.copy( - tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS - ) - scale_ = scale_softmax_log2 * ( - tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] - ) + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS) + scale_ = scale_softmax_log2 * (tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1]) scale = cute.math.exp2(scale_, fastmath=True) o1_handle = mma_corr_consumer.wait_and_advance() @@ -1559,9 +1475,7 @@ def kernel( # wait for vec0 (row_wise global sum) vec0_handle = s0_corr_consumer.wait_and_advance() - tTMEM_LOAD_VECrS = cute.make_rmem_tensor( - tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype - ) + tTMEM_LOAD_VECrS = cute.make_rmem_tensor(tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype) cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) cute.arch.fence_view_async_tmem_load() vec0_handle.release() @@ -1670,9 +1584,7 @@ def softmax_step( :rtype: tuple """ cS, row_max, row_sum, vec_i_handle = iter_args - seqlen_k, seqlen_q, scale_softmax_log2, window_size_left, window_size_right = ( - value_args - ) + seqlen_k, seqlen_q, scale_softmax_log2, window_size_left, window_size_right = value_args ( mma_si_consumer, si_corr_producer, @@ -1699,9 +1611,7 @@ def softmax_step( tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) - tScS_P_layout = cute.composition( - tScS.layout, cute.make_layout((128, tilePlikeFP32)) - ) + tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) tTMEM_LOADcS = thr_tmem_load.partition_D(tScS) tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec) @@ -1736,15 +1646,12 @@ def softmax_step( smallestUnmaskedKInWarp = min( seqlen_k - 1, - (index_q - (cute.arch.thread_idx()[0] % 32)) // compression_factor - - 1, + (index_q - (cute.arch.thread_idx()[0] % 32)) // compression_factor - 1, ) largestUnmaskedKInWarp = min( seqlen_k - 1, - (index_q + 32 - (cute.arch.thread_idx()[0] % 32)) - // compression_factor - - 1, + (index_q + 32 - (cute.arch.thread_idx()[0] % 32)) // compression_factor - 1, ) if smallestUnmaskedKInWarp - tScS[0][1] < 64: @@ -1762,9 +1669,7 @@ def softmax_step( if row_max == -cutlass.Float32.inf: row_max_safe = 0.0 - tTMEM_STORE_VECrS = cute.make_rmem_tensor( - tTMEM_STORE_VECcS.shape, self.qk_acc_dtype - ) + tTMEM_STORE_VECrS = cute.make_rmem_tensor(tTMEM_STORE_VECcS.shape, self.qk_acc_dtype) tTMEM_STORE_VECrS[0] = old_row_max tTMEM_STORE_VECrS[1] = row_max_safe @@ -1790,27 +1695,19 @@ def softmax_step( frg_cnt = 4 frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) - tTMEM_STORErS_x4_e_frg = cute.logical_divide( - tTMEM_STORErS_x4_e, cute.make_layout(frg_tile) - ) + tTMEM_STORErS_x4_e_frg = cute.logical_divide(tTMEM_STORErS_x4_e, cute.make_layout(frg_tile)) for j in cutlass.range(frg_cnt): for k in cutlass.range(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): - tTMEM_LOADrS_frg[k, j], tTMEM_LOADrS_frg[k + 1, j] = ( - cute.arch.fma_packed_f32x2( - (tTMEM_LOADrS_frg[k, j], tTMEM_LOADrS_frg[k + 1, j]), - (scale, scale), - (minus_row_max_scale, minus_row_max_scale), - ) + tTMEM_LOADrS_frg[k, j], tTMEM_LOADrS_frg[k + 1, j] = cute.arch.fma_packed_f32x2( + (tTMEM_LOADrS_frg[k, j], tTMEM_LOADrS_frg[k + 1, j]), + (scale, scale), + (minus_row_max_scale, minus_row_max_scale), ) - tTMEM_LOADrS_frg[k, j] = cute.math.exp2( - tTMEM_LOADrS_frg[k, j], fastmath=True - ) + tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) - tTMEM_LOADrS_frg[k + 1, j] = cute.math.exp2( - tTMEM_LOADrS_frg[k + 1, j], fastmath=True - ) + tTMEM_LOADrS_frg[k + 1, j] = cute.math.exp2(tTMEM_LOADrS_frg[k + 1, j], fastmath=True) s_vec = tTMEM_LOADrS_frg[None, j].load() tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) @@ -1840,18 +1737,10 @@ def softmax_step( tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) for j in cutlass.range_constexpr(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): - local_row_sum_0 = cute.arch.add_packed_f32x2( - local_row_sum_0, (tTMEM_LOADrS_frg[j, 0], tTMEM_LOADrS_frg[j + 1, 0]) - ) - local_row_sum_1 = cute.arch.add_packed_f32x2( - local_row_sum_1, (tTMEM_LOADrS_frg[j, 1], tTMEM_LOADrS_frg[j + 1, 1]) - ) - local_row_sum_2 = cute.arch.add_packed_f32x2( - local_row_sum_2, (tTMEM_LOADrS_frg[j, 2], tTMEM_LOADrS_frg[j + 1, 2]) - ) - local_row_sum_3 = cute.arch.add_packed_f32x2( - local_row_sum_3, (tTMEM_LOADrS_frg[j, 3], tTMEM_LOADrS_frg[j + 1, 3]) - ) + local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, (tTMEM_LOADrS_frg[j, 0], tTMEM_LOADrS_frg[j + 1, 0])) + local_row_sum_1 = cute.arch.add_packed_f32x2(local_row_sum_1, (tTMEM_LOADrS_frg[j, 1], tTMEM_LOADrS_frg[j + 1, 1])) + local_row_sum_2 = cute.arch.add_packed_f32x2(local_row_sum_2, (tTMEM_LOADrS_frg[j, 2], tTMEM_LOADrS_frg[j + 1, 2])) + local_row_sum_3 = cute.arch.add_packed_f32x2(local_row_sum_3, (tTMEM_LOADrS_frg[j, 3], tTMEM_LOADrS_frg[j + 1, 3])) local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_1) local_row_sum_2 = cute.arch.add_packed_f32x2(local_row_sum_2, local_row_sum_3) @@ -1934,18 +1823,9 @@ def softmax( :type fused_mask: fmha_utils.FusedMask """ tidx, _, _ = cute.arch.thread_idx() - thread_idx = tidx % ( - self.threads_per_warp - * ( - len(self.softmax0_warp_ids) - if stage == 0 - else len(self.softmax1_warp_ids) - ) - ) + thread_idx = tidx % (self.threads_per_warp * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids))) - cS_base = cute.make_identity_tensor( - (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) - ) + cS_base = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) tilePlikeFP32 = self.qk_mma_tiler[1] // 32 * self.o_dtype.width tScS = qk_thr_mma.partition_C(cS_base) tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) @@ -1953,9 +1833,7 @@ def softmax( tStS_vec = cute.make_tensor(tStS.iterator + tmem_vec_offset, tStS_vec_layout) tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) - tStS_P_layout = cute.composition( - tStS.layout, cute.make_layout((128, tilePlikeFP32)) - ) + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) tmem_p_offset = self.tmem_p0_offset if stage == 0 else self.tmem_p1_offset tStS_P = cute.make_tensor(tStS.iterator + tmem_p_offset, tStS_P_layout) tmem_load_atom = cute.make_copy_atom( @@ -1963,14 +1841,7 @@ def softmax( self.qk_acc_dtype, ) tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi) - thread_idx = tidx % ( - self.threads_per_warp - * ( - len(self.softmax0_warp_ids) - if stage == 0 - else len(self.softmax1_warp_ids) - ) - ) + thread_idx = tidx % (self.threads_per_warp * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids))) thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) tTMEM_LOADtS = thr_tmem_load.partition_S(tStSi) tmem_store_vec_atom = cute.make_copy_atom( @@ -1989,9 +1860,7 @@ def softmax( thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P) - tile_sched = fmha_utils.create_fmha_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) + tile_sched = fmha_utils.create_fmha_static_tile_scheduler(tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()) work_tile = tile_sched.initial_work_tile_info() while work_tile.is_valid_tile: @@ -2039,8 +1908,7 @@ def softmax( ) logical_offset = ( - curr_block_coord[0] * self.cta_tiler[0] - + stage * self.qk_mma_tiler[0], + curr_block_coord[0] * self.cta_tiler[0] + stage * self.qk_mma_tiler[0], 0, ) cS = cute.domain_offset(logical_offset, cS_base) @@ -2065,9 +1933,7 @@ def softmax( window_size_right, ) - for i in cutlass.range( - start_count, start_count + leading_mask_count, 1, unroll=1 - ): + for i in cutlass.range(start_count, start_count + leading_mask_count, 1, unroll=1): cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) iter_args = (cS_iter, row_max, row_sum, vec_i_handle) pipeline_args = ( @@ -2147,10 +2013,7 @@ def softmax( for i in cutlass.range( start_count + leading_mask_count + unmask_count, - start_count - + leading_mask_count - + unmask_count - + trailing_mask_count, + start_count + leading_mask_count + unmask_count + trailing_mask_count, 1, unroll=1, ): @@ -2180,9 +2043,7 @@ def softmax( tensor_args, ) si_handle = mma_si_consumer.wait_and_advance() - tTMEM_STORE_VECrS = cute.make_rmem_tensor( - tTMEM_STORE_VECcS.shape, self.qk_acc_dtype - ) + tTMEM_STORE_VECrS = cute.make_rmem_tensor(tTMEM_STORE_VECcS.shape, self.qk_acc_dtype) if row_sum == 0.0: row_sum = 1.0 @@ -2247,12 +2108,8 @@ def correction_rescale( self.pv_acc_dtype, ) - tOtO_i_layout = cute.composition( - tOtO.layout, cute.make_layout((128, corr_tile_size)) - ) - tOcO_i_layout = cute.composition( - tOcO.layout, cute.make_layout((128, corr_tile_size)) - ) + tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) @@ -2269,21 +2126,13 @@ def correction_rescale( tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i) - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.pv_acc_dtype - ) + tTMrO = cute.make_rmem_tensor((tTMEM_LOADcO.shape, 128 // corr_tile_size), self.pv_acc_dtype) for i in cutlass.range(self.cta_tiler[2] // corr_tile_size): tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) + tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])) tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout - ) + tTMEM_LOADtO_i = cute.make_tensor(tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout) + tTMEM_STOREtO_i = cute.make_tensor(tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout) cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i) for j in cutlass.range(0, cute.size(tTMrO_i), 2): @@ -2371,14 +2220,10 @@ def correction_epilog( use_2cta_instrs=False, ) - tiled_tmem_load = tcgen05.make_tmem_copy( - tmem_copy_atom, tOtO_i[(None, None), 0] - ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]) thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) - smem_copy_atom = sm100_utils.get_smem_store_op( - self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load - ) + smem_copy_atom = sm100_utils.get_smem_store_op(self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load) tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) @@ -2388,9 +2233,7 @@ def correction_epilog( for i in cutlass.range(self.cta_tiler[2] // corr_tile_size): tTMEM_LOADtO_i = tTMEM_LOADtO[None, 0, 0, i] tTMEM_LOADsO_i = tTMEM_LOADsO[None, 0, 0, i] - tTMrO = cute.make_rmem_tensor( - tTMEM_LOADoO[None, 0, 0, i].shape, self.pv_acc_dtype - ) + tTMrO = cute.make_rmem_tensor(tTMEM_LOADoO[None, 0, 0, i].shape, self.pv_acc_dtype) cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO) for j in cutlass.range(0, cute.size(tTMrO), 2): @@ -2405,10 +2248,7 @@ def correction_epilog( cute.copy(tiled_smem_store, tSMrO, tTMEM_LOADsO_i) if cutlass.const_expr(mLSE is not None): - lse = ( - cute.math.log(tTMEM_LOAD_VECrS[0], fastmath=True) - + scale_softmax * tTMEM_LOAD_VECrS[1] - ) + lse = cute.math.log(tTMEM_LOAD_VECrS[0], fastmath=True) + scale_softmax * tTMEM_LOAD_VECrS[1] if row_idx < seqlen_q: mLSE[row_idx + cuseqlen_q, blk_coord[2]] = lse diff --git a/python/cudnn/native_sparse_attention/compression/fmha_helpers.py b/python/cudnn/native_sparse_attention/compression/fmha_helpers.py index 70041e36..d0e89630 100644 --- a/python/cudnn/native_sparse_attention/compression/fmha_helpers.py +++ b/python/cudnn/native_sparse_attention/compression/fmha_helpers.py @@ -76,9 +76,7 @@ def __new_from_mlir_values__(self, values): for obj, n_items in zip([self.problem_shape_mbh], self._values_pos): obj_list.append(new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return FmhaStaticTileSchedulerParams( - self.is_persistent, *(tuple(obj_list)), loc=self._loc - ) + return FmhaStaticTileSchedulerParams(self.is_persistent, *(tuple(obj_list)), loc=self._loc) class FmhaStaticTileScheduler: @@ -135,9 +133,7 @@ def __init__( self._grid_shape = grid_shape self._is_persistent = params.is_persistent self._current_work_linear_idx = current_work_linear_idx - self._problem_shape_mbh = cute.make_layout( - params.problem_shape_mbh, loc=loc, ip=ip - ) + self._problem_shape_mbh = cute.make_layout(params.problem_shape_mbh, loc=loc, ip=ip) self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) self._is_first_block = True self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) @@ -210,17 +206,11 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: :return: WorkTileInfo containing tile coordinates and validity flag. :rtype: WorkTileInfo """ - is_valid = ( - self._current_work_linear_idx < self._num_blocks - if self._is_persistent - else self._is_first_block - ) + is_valid = self._current_work_linear_idx < self._num_blocks if self._is_persistent else self._is_first_block blk_coord = (0, 0, 0) if self._is_persistent: - blk_coord = self._problem_shape_mbh.get_hier_coord( - self._current_work_linear_idx, loc=loc, ip=ip - ) + blk_coord = self._problem_shape_mbh.get_hier_coord(self._current_work_linear_idx, loc=loc, ip=ip) else: blk_coord = self._blk_coord @@ -266,14 +256,10 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): assert len(values) == 10 new_params = new_from_mlir_values(self._params, values[0:3]) - new_current_work_linear_idx = new_from_mlir_values( - self._current_work_linear_idx, [values[3]] - ) + new_current_work_linear_idx = new_from_mlir_values(self._current_work_linear_idx, [values[3]]) new_blk_coord = new_from_mlir_values(self._blk_coord, values[4:7]) new_grid_shape = new_from_mlir_values(self._grid_shape, values[7:]) - return FmhaStaticTileScheduler( - new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape - ) + return FmhaStaticTileScheduler(new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape) def create_fmha_static_tile_scheduler( @@ -425,17 +411,10 @@ def get_trip_count( """ result = 0 - offset = ( - 0 - if cutlass.const_expr(mask_type is not MaskType.WINDOW_MASK_INFERENCE) - else seqlen_k - seqlen_q - ) + offset = 0 if cutlass.const_expr(mask_type is not MaskType.WINDOW_MASK_INFERENCE) else seqlen_k - seqlen_q if cutlass.const_expr(mask_type == MaskType.RESIDUAL_MASK): result = cute.ceil_div(seqlen_k, tile_shape[1]) - if cutlass.const_expr( - mask_type == MaskType.WINDOW_MASK - or mask_type == MaskType.WINDOW_MASK_INFERENCE - ): + if cutlass.const_expr(mask_type == MaskType.WINDOW_MASK or mask_type == MaskType.WINDOW_MASK_INFERENCE): if cutlass.const_expr(window_size_right is None): result = cute.ceil_div(seqlen_k, tile_shape[1]) else: @@ -447,20 +426,14 @@ def get_trip_count( elif cutlass.const_expr(mask_type == MaskType.COMPRESSED_CAUSAL_MASK): compression_factor = seqlen_q // seqlen_k - block_end = ( - (blk_coord[0] + 1) * tile_shape[0] - 1 + offset + window_size_right - ) + block_end = (blk_coord[0] + 1) * tile_shape[0] - 1 + offset + window_size_right - tmp_blocks_k = cute.ceil_div( - ((block_end + 1) // compression_factor), tile_shape[1] - ) + tmp_blocks_k = cute.ceil_div(((block_end + 1) // compression_factor), tile_shape[1]) max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1]) result = max(1, min(max_blocks_k, tmp_blocks_k)) - start_block = FusedMask.get_trip_start( - mask_type, blk_coord, tile_shape, seqlen_q, seqlen_k, window_size_left - ) + start_block = FusedMask.get_trip_start(mask_type, blk_coord, tile_shape, seqlen_q, seqlen_k, window_size_left) result = result - start_block return result @@ -493,11 +466,7 @@ def get_trip_start( :type window_size_right: Optional[Int32] """ result = 0 - offset = ( - 0 - if cutlass.const_expr(mask_type is not MaskType.WINDOW_MASK_INFERENCE) - else seqlen_k - seqlen_q - ) + offset = 0 if cutlass.const_expr(mask_type is not MaskType.WINDOW_MASK_INFERENCE) else seqlen_k - seqlen_q if cutlass.const_expr(window_size_left is not None): min_idx_q = blk_coord[0] * tile_shape[0] idx_k = min_idx_q + offset - window_size_left @@ -536,14 +505,8 @@ def get_leading_mask_id( :return: Tuple of (begin, end) tile idx for the leading mask. :rtype: Tuple[Int32, Int32] """ - offset = ( - 0 - if cutlass.const_expr(mask_type is not MaskType.WINDOW_MASK_INFERENCE) - else seqlen_k - seqlen_q - ) - leading_mask_begin = FusedMask.get_trip_start( - mask_type, blk_coord, tile_shape, seqlen_q, seqlen_k, window_size_left - ) + offset = 0 if cutlass.const_expr(mask_type is not MaskType.WINDOW_MASK_INFERENCE) else seqlen_k - seqlen_q + leading_mask_begin = FusedMask.get_trip_start(mask_type, blk_coord, tile_shape, seqlen_q, seqlen_k, window_size_left) trip_count = FusedMask.get_trip_count( mask_type, blk_coord, @@ -554,9 +517,7 @@ def get_leading_mask_id( window_size_right, ) min_idx_q = (blk_coord[0] + 1) * tile_shape[0] + offset - window_size_left - leading_mask_end = min( - max(min_idx_q // tile_shape[1], 0), trip_count + leading_mask_begin - 1 - ) + leading_mask_end = min(max(min_idx_q // tile_shape[1], 0), trip_count + leading_mask_begin - 1) return leading_mask_begin, leading_mask_end @cute.jit @@ -590,14 +551,8 @@ def get_trailing_mask_id( :return: Tuple of (begin, end) tile idx for the trailing mask. :rtype: Tuple[Int32, Int32] """ - offset = ( - 0 - if cutlass.const_expr(mask_type is not MaskType.WINDOW_MASK_INFERENCE) - else seqlen_k - seqlen_q - ) - trip_start = FusedMask.get_trip_start( - mask_type, blk_coord, tile_shape, seqlen_q, seqlen_k, window_size_left - ) + offset = 0 if cutlass.const_expr(mask_type is not MaskType.WINDOW_MASK_INFERENCE) else seqlen_k - seqlen_q + trip_start = FusedMask.get_trip_start(mask_type, blk_coord, tile_shape, seqlen_q, seqlen_k, window_size_left) trip_count = FusedMask.get_trip_count( mask_type, blk_coord, @@ -608,9 +563,7 @@ def get_trailing_mask_id( window_size_right, ) min_idx_q = blk_coord[0] * tile_shape[0] + offset + window_size_right - trailing_mask_begin = max( - min(min_idx_q // tile_shape[1], trip_count + trip_start - 1), 0 - ) + trailing_mask_begin = max(min(min_idx_q // tile_shape[1], trip_count + trip_start - 1), 0) trailing_mask_end = trip_count + trip_start - 1 return trailing_mask_begin, trailing_mask_end @@ -700,10 +653,7 @@ def get_masked_trailing_count( """ result = 0 - if cutlass.const_expr( - mask_type == MaskType.WINDOW_MASK - or mask_type == MaskType.WINDOW_MASK_INFERENCE - ): + if cutlass.const_expr(mask_type == MaskType.WINDOW_MASK or mask_type == MaskType.WINDOW_MASK_INFERENCE): if cutlass.const_expr(window_size_right is not None): trailing_mask_begin, trailing_mask_end = FusedMask.get_trailing_mask_id( mask_type, @@ -715,16 +665,14 @@ def get_masked_trailing_count( window_size_right, ) if cutlass.const_expr(window_size_left is not None): - leading_mask_begin, leading_mask_end = ( - FusedMask.get_leading_mask_id( - mask_type, - blk_coord, - tile_shape, - seqlen_q, - seqlen_k, - window_size_left, - window_size_right, - ) + leading_mask_begin, leading_mask_end = FusedMask.get_leading_mask_id( + mask_type, + blk_coord, + tile_shape, + seqlen_q, + seqlen_k, + window_size_left, + window_size_right, ) if trailing_mask_begin <= leading_mask_end: result = trailing_mask_end - leading_mask_end @@ -865,14 +813,10 @@ def apply_mask( offset = seqlen_k - seqlen_q for i in cutlass.range(cute.size(acc_qk)): index_q, index_k = index_qk[i] - if cutlass.const_expr( - window_size_left is not None or window_size_right is not None - ): + if cutlass.const_expr(window_size_left is not None or window_size_right is not None): if cutlass.const_expr(mask_type == MaskType.COMPRESSED_CAUSAL_MASK): compression_factor = seqlen_q // seqlen_k - if ( - index_q + 1 - ) // compression_factor - 1 < index_k or index_k >= seqlen_k: + if (index_q + 1) // compression_factor - 1 < index_k or index_k >= seqlen_k: acc_qk[i] = -Float32.inf if index_k >= seqlen_k or index_q >= seqlen_q: # residual mask acc_qk[i] = -Float32.inf diff --git a/python/cudnn/native_sparse_attention/selection/NSA_select_attn_fwd_hmma.py b/python/cudnn/native_sparse_attention/selection/NSA_select_attn_fwd_hmma.py index 67f8d845..3b9db9dd 100644 --- a/python/cudnn/native_sparse_attention/selection/NSA_select_attn_fwd_hmma.py +++ b/python/cudnn/native_sparse_attention/selection/NSA_select_attn_fwd_hmma.py @@ -42,11 +42,7 @@ def __init__( assert self.dtype in [cutlass.Float16, cutlass.BFloat16] assert self.acc_dtype in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32] - assert ( - self.block_size % 8 == 0 - and self.block_size >= 16 - and self.block_size <= 128 - ), "block_size should be a multiple of 8 and >= 16 and <= 128" + assert self.block_size % 8 == 0 and self.block_size >= 16 and self.block_size <= 128, "block_size should be a multiple of 8 and >= 16 and <= 128" self.K_stage = 1 self.V_stage = 1 @@ -62,9 +58,7 @@ def __init__( self.GQA_group_size = GQA_group_size self.log2_e = 1.4426950408889634074 - assert ( - self.GQA_group_size <= 16 - ), "GQA_group_size should be less than or equal to 16" + assert self.GQA_group_size <= 16, "GQA_group_size should be less than or equal to 16" @cute.jit def __call__( @@ -117,9 +111,7 @@ def __call__( self.O_dtype = O.element_type if cutlass.const_expr(self.Q_dtype.width != self.K_dtype.width): - raise TypeError( - f"Type width mismatch: {self.Q_dtype.width} != {self.K_dtype.width}" - ) + raise TypeError(f"Type width mismatch: {self.Q_dtype.width} != {self.K_dtype.width}") mma_n_itr = self.block_size // 8 tiled_mma_QK = cute.make_tiled_mma( @@ -143,10 +135,7 @@ def __call__( cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER, self.Q_dtype, ) - assert ( - self.Q_layout.sm90_mma_major_mode() - == cute.nvgpu.warpgroup.OperandMajorMode.K - ), "Q_layout should be K-major" + assert self.Q_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K, "Q_layout should be K-major" Q_smem_layout_staged = cute.tile_to_shape( Q_smem_layout_atom, cute.append(Q_smem_shape, 1), @@ -155,16 +144,11 @@ def __call__( K_smem_shape = (self.block_size, self.tile_shape_mnk_QK[2]) K_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - self.K_layout, self.K_dtype, K_smem_shape[1] - ), # K-major by default + sm90_utils.get_smem_layout_atom(self.K_layout, self.K_dtype, K_smem_shape[1]), # K-major by default # cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER, self.K_dtype, ) - assert ( - self.K_layout.sm90_mma_major_mode() - == cute.nvgpu.warpgroup.OperandMajorMode.K - ), "K_layout should be K-major" + assert self.K_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K, "K_layout should be K-major" K_smem_layout_staged = cute.tile_to_shape( K_smem_layout_atom, cute.append(K_smem_shape, self.K_stage), @@ -173,17 +157,12 @@ def __call__( V_smem_shape = (self.tile_shape_mnk_PV[2], self.tile_shape_mnk_PV[1]) V_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - self.V_layout, self.V_dtype, V_smem_shape[1] - ), # K-major by default + sm90_utils.get_smem_layout_atom(self.V_layout, self.V_dtype, V_smem_shape[1]), # K-major by default # cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER, self.V_dtype, ) - assert ( - self.V_layout.sm90_mma_major_mode() - == cute.nvgpu.warpgroup.OperandMajorMode.K - ), "V_layout should be K-major" + assert self.V_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K, "V_layout should be K-major" V_smem_layout_staged = cute.tile_to_shape( V_smem_layout_atom, cute.append(V_smem_shape, self.V_stage), @@ -191,25 +170,15 @@ def __call__( ) # import pdb; pdb.set_trace() - V_layout_atom = sm90_utils.get_smem_layout_atom( - self.V_layout, self.V_dtype, V_smem_shape[1] - ) + V_layout_atom = sm90_utils.get_smem_layout_atom(self.V_layout, self.V_dtype, V_smem_shape[1]) - if cutlass.const_expr( - V_layout_atom == cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128 - ): + if cutlass.const_expr(V_layout_atom == cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128): Vt_layout_atom = cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128 - elif cutlass.const_expr( - V_layout_atom == cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64 - ): + elif cutlass.const_expr(V_layout_atom == cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64): Vt_layout_atom = cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW64 - elif cutlass.const_expr( - V_layout_atom == cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32 - ): + elif cutlass.const_expr(V_layout_atom == cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32): Vt_layout_atom = cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW32 - elif cutlass.const_expr( - V_layout_atom == cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER - ): + elif cutlass.const_expr(V_layout_atom == cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER): Vt_layout_atom = cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_INTER else: raise ValueError(f"Unsupported V_layout_atom: {V_layout_atom}") @@ -228,9 +197,7 @@ def __call__( O_smem_shape = self.epi_tile O_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( - sm90_utils.get_smem_layout_atom( - self.O_layout, self.O_dtype, O_smem_shape[1] - ), # K-major by default + sm90_utils.get_smem_layout_atom(self.O_layout, self.O_dtype, O_smem_shape[1]), # K-major by default self.O_dtype, ) O_smem_layout_staged = cute.tile_to_shape( @@ -265,9 +232,7 @@ def __call__( ) smem_layout_O = cute.slice_(O_smem_layout_staged, (None, None, 0)) - O_cta_v_layout = cute.composition( - cute.make_identity_layout(O.shape), self.epi_tile - ) + O_cta_v_layout = cute.composition(cute.make_identity_layout(O.shape), self.epi_tile) tma_atom_O, tma_tensor_O = cute.nvgpu.cpasync.make_tiled_tma_atom( cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), O, @@ -276,24 +241,16 @@ def __call__( ) L_smem_shape = (self.tile_shape_mnk_QK[0], 1) - L_smem_layout = cute.make_layout( - shape=L_smem_shape, stride=(1, self.tile_shape_mnk_QK[0]) - ) + L_smem_layout = cute.make_layout(shape=L_smem_shape, stride=(1, self.tile_shape_mnk_QK[0])) M_smem_shape = (self.tile_shape_mnk_QK[0], 1) - M_smem_layout = cute.make_layout( - shape=M_smem_shape, stride=(1, self.tile_shape_mnk_QK[0]) - ) + M_smem_layout = cute.make_layout(shape=M_smem_shape, stride=(1, self.tile_shape_mnk_QK[0])) BUFFER_ALIGN_BYTES = 128 @cute.struct class SharedStorageShare: - mainloop_pipeline_array_ptr: cute.struct.MemRange[ - cutlass.Int64, self.K_stage * 2 - ] - mainloop_pipeline_v_array_ptr: cute.struct.MemRange[ - cutlass.Int64, self.V_stage * 2 - ] + mainloop_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.K_stage * 2] + mainloop_pipeline_v_array_ptr: cute.struct.MemRange[cutlass.Int64, self.V_stage * 2] prefetchQ_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] sQ: cute.struct.Align[ cute.struct.MemRange[self.Q_dtype, cute.cosize(Q_smem_layout_staged)], @@ -312,9 +269,7 @@ class SharedStorageShare: BUFFER_ALIGN_BYTES, ] - assert cute.cosize(Q_smem_layout_staged) + cute.cosize( - K_smem_layout_staged - ) + cute.cosize(V_smem_layout_staged) >= cute.cosize( + assert cute.cosize(Q_smem_layout_staged) + cute.cosize(K_smem_layout_staged) + cute.cosize(V_smem_layout_staged) >= cute.cosize( O_smem_layout_staged ), "shared storage size is not enough for so" self.shared_storage = SharedStorageShare @@ -360,9 +315,7 @@ class SharedStorageShare: stream=stream, ) - def _threadquad_reduce( - self, val: cutlass.Float32, op: Callable, mask: int - ) -> cutlass.Float32: + def _threadquad_reduce(self, val: cutlass.Float32, op: Callable, mask: int) -> cutlass.Float32: """thread quad reduction :param val: register value @@ -382,9 +335,7 @@ def _threadquad_reduce( ) return val - def _threadquad_reduce_max( - self, val: cutlass.Float32, mask: int - ) -> cutlass.Float32: + def _threadquad_reduce_max(self, val: cutlass.Float32, mask: int) -> cutlass.Float32: """thread quad reduction max :param val: register value @@ -394,9 +345,7 @@ def _threadquad_reduce_max( """ return self._threadquad_reduce(val, lambda x, y: cute.arch.fmax(x, y), mask) - def _threadquad_reduce_sum( - self, val: cutlass.Float32, mask: int - ) -> cutlass.Float32: + def _threadquad_reduce_sum(self, val: cutlass.Float32, mask: int) -> cutlass.Float32: """thread quad reduction sum :param val: register value @@ -441,9 +390,7 @@ def _make_acc_tensor_mn_view(self, acc: cute.Tensor) -> cute.Tensor: return cute.make_tensor(acc.iterator, acc_layout_mn) @cute.jit - def _exp2f( - self, x: Union[cute.TensorSSA, cutlass.Float32] - ) -> Union[cute.TensorSSA, cutlass.Float32]: + def _exp2f(self, x: Union[cute.TensorSSA, cutlass.Float32]) -> Union[cute.TensorSSA, cutlass.Float32]: """exp2f calculation for both vector and scalar. :param x: input value @@ -503,9 +450,7 @@ def kernel( t, KV_head_idx, offset_idx = cute.arch.block_idx() tidx, _, _ = cute.arch.thread_idx() - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) smem = cutlass.utils.SmemAllocator() @@ -518,12 +463,8 @@ def kernel( K_tma_copy_bytes = cute.size_in_bytes(self.K_dtype, K_smem_layout) # one consumer consumer_arrive_cnt = self.threads_per_block // 32 - mainloop_pipeline_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread - ) - mainloop_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, consumer_arrive_cnt - ) + mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() mainloop_pipeline = pipeline.PipelineTmaAsync.create( barrier_storage=mainloop_pipeline_array_ptr, @@ -534,12 +475,8 @@ def kernel( ) Q_tma_copy_bytes = cute.size_in_bytes(self.Q_dtype, Q_smem_layout) - prefetchQ_pipeline_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread - ) - prefetchQ_pipeline_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, consumer_arrive_cnt - ) + prefetchQ_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + prefetchQ_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) prefetchQ_pipeline_array_ptr = storage.prefetchQ_pipeline_array_ptr.data_ptr() prefetchQ_pipeline = pipeline.PipelineTmaAsync.create( barrier_storage=prefetchQ_pipeline_array_ptr, @@ -550,12 +487,8 @@ def kernel( ) V_tma_copy_bytes = cute.size_in_bytes(self.V_dtype, V_smem_layout) - mainloop_pipeline_producer_group_v = pipeline.CooperativeGroup( - pipeline.Agent.Thread - ) - mainloop_pipeline_consumer_group_v = pipeline.CooperativeGroup( - pipeline.Agent.Thread, consumer_arrive_cnt - ) + mainloop_pipeline_producer_group_v = pipeline.CooperativeGroup(pipeline.Agent.Thread) + mainloop_pipeline_consumer_group_v = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) mainloop_pipeline_array_ptr_v = storage.mainloop_pipeline_v_array_ptr.data_ptr() mainloop_pipeline_V = pipeline.PipelineTmaAsync.create( barrier_storage=mainloop_pipeline_array_ptr_v, @@ -565,21 +498,11 @@ def kernel( tx_count=V_tma_copy_bytes, ) - sQ = storage.sQ.get_tensor( - Q_smem_layout_staged.outer, swizzle=Q_smem_layout_staged.inner - ) - sK = storage.sK.get_tensor( - K_smem_layout_staged.outer, swizzle=K_smem_layout_staged.inner - ) - sV = storage.sV.get_tensor( - V_smem_layout_staged.outer, swizzle=V_smem_layout_staged.inner - ) - sVt = storage.sV.get_tensor( - Vt_smem_layout_staged.outer, swizzle=Vt_smem_layout_staged.inner - ) - sO = storage.sQ.get_tensor( - O_smem_layout_staged.outer, swizzle=O_smem_layout_staged.inner - ) # sO shared with sK + sQ = storage.sQ.get_tensor(Q_smem_layout_staged.outer, swizzle=Q_smem_layout_staged.inner) + sK = storage.sK.get_tensor(K_smem_layout_staged.outer, swizzle=K_smem_layout_staged.inner) + sV = storage.sV.get_tensor(V_smem_layout_staged.outer, swizzle=V_smem_layout_staged.inner) + sVt = storage.sV.get_tensor(Vt_smem_layout_staged.outer, swizzle=Vt_smem_layout_staged.inner) + sO = storage.sQ.get_tensor(O_smem_layout_staged.outer, swizzle=O_smem_layout_staged.inner) # sO shared with sK sIDX = storage.sIDX.get_tensor(block_indices.shape[2]) smem_copy_atom_Q = cute.make_copy_atom( @@ -605,21 +528,12 @@ def kernel( seq_len = seq_offsets[offset_idx + 1] - seq_offsets[offset_idx] offset = seq_offsets[offset_idx] - for i in cutlass.range( - (block_indices.shape[2] + self.threads_per_block - 1) - // self.threads_per_block - ): + for i in cutlass.range((block_indices.shape[2] + self.threads_per_block - 1) // self.threads_per_block): if i * self.threads_per_block + tidx < block_indices.shape[2]: - sIDX[i * self.threads_per_block + tidx] = block_indices[ - offset + t, KV_head_idx, i * self.threads_per_block + tidx - ] + sIDX[i * self.threads_per_block + tidx] = block_indices[offset + t, KV_head_idx, i * self.threads_per_block + tidx] cute.arch.sync_threads() - seq_len_aligned = ( - (seq_len + self.tile_shape_mnk_QK[1] - 1) - // self.tile_shape_mnk_QK[1] - * self.tile_shape_mnk_QK[1] - ) + seq_len_aligned = (seq_len + self.tile_shape_mnk_QK[1] - 1) // self.tile_shape_mnk_QK[1] * self.tile_shape_mnk_QK[1] mQ_offset = cute.domain_offset((0, 0, offset, 0), mQ) mQ = cute.make_tensor( @@ -632,16 +546,12 @@ def kernel( mK_offset = cute.domain_offset((offset, 0, 0), mK) mK = cute.make_tensor( mK_offset.iterator, - cute.make_layout( - shape=(seq_len_aligned, mK.shape[1], mK.shape[2]), stride=mK.stride - ), + cute.make_layout(shape=(seq_len_aligned, mK.shape[1], mK.shape[2]), stride=mK.stride), ) mV_offset = cute.domain_offset((offset, 0, 0), mV) # `[K, B*T, H]` mV = cute.make_tensor( mV_offset.iterator, - cute.make_layout( - shape=(seq_len_aligned, mV.shape[1], mV.shape[2]), stride=mV.stride - ), + cute.make_layout(shape=(seq_len_aligned, mV.shape[1], mV.shape[2]), stride=mV.stride), ) mO_offset = cute.domain_offset((0, 0, offset, 0), mO) mO = cute.make_tensor( @@ -654,16 +564,12 @@ def kernel( mL_offset = cute.domain_offset((0, offset, 0), mL) mL = cute.make_tensor( mL_offset.iterator, - cute.make_layout( - shape=(mL.shape[0], seq_len_aligned, mL.shape[2]), stride=mL.stride - ), + cute.make_layout(shape=(mL.shape[0], seq_len_aligned, mL.shape[2]), stride=mL.stride), ) mM_offset = cute.domain_offset((0, offset, 0), mM) mM = cute.make_tensor( mM_offset.iterator, - cute.make_layout( - shape=(mM.shape[0], seq_len_aligned, mM.shape[2]), stride=mM.stride - ), + cute.make_layout(shape=(mM.shape[0], seq_len_aligned, mM.shape[2]), stride=mM.stride), ) if t < seq_len: @@ -707,9 +613,7 @@ def kernel( thr_mma_QK = tiled_mma_QK.get_slice(tidx) - q_cta_layout = cute.make_layout( - cute.slice_(cta_layout_mnk, (0, None, 0)).shape - ) + q_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) q_cta_crd = cluster_coord_mnk[1] sQ_for_tma_partition = cute.group_modes(sQ, 0, 2) gQ_for_tma_partition = cute.group_modes(gQ, 0, 2) @@ -722,9 +626,7 @@ def kernel( gQ_for_tma_partition, ) - K_cta_layout = cute.make_layout( - cute.slice_(cta_layout_mnk, (None, 0, 0)).shape - ) + K_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) k_cta_crd = cluster_coord_mnk[0] sK_for_tma_partition = cute.group_modes(sK, 0, 2) gK_for_tma_partition = cute.group_modes(gK, 0, 2) @@ -736,9 +638,7 @@ def kernel( gK_for_tma_partition, ) - v_cta_layout = cute.make_layout( - cute.slice_(cta_layout_mnk, (None, 0, 0)).shape - ) + v_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) v_cta_crd = cluster_coord_mnk[0] sV_for_tma_partition = cute.group_modes(sV, 0, 2) gV_for_tma_partition = cute.group_modes(gV, 0, 2) @@ -758,56 +658,32 @@ def kernel( tSsK = smem_thr_copy_K.partition_S(sK) # import ipdb; ipdb.set_trace() - acc_shape_QK = thr_mma_QK.partition_shape_C( - (self.tile_shape_mnk_QK[0], self.tile_shape_mnk_QK[1]) - ) + acc_shape_QK = thr_mma_QK.partition_shape_C((self.tile_shape_mnk_QK[0], self.tile_shape_mnk_QK[1])) acc_QK = cute.make_rmem_tensor(acc_shape_QK, self.acc_dtype) acc_QK.fill(0) - mainloop_producer_state_K = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.K_stage - ) + mainloop_producer_state_K = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.K_stage) - prefetchQ_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, 1 - ) + prefetchQ_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1) - mainloop_producer_state_V = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.V_stage - ) - mainloop_consumer_read_state_V = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.V_stage - ) - mainloop_consumer_release_state_V = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.V_stage - ) + mainloop_producer_state_V = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.V_stage) + mainloop_consumer_read_state_V = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.V_stage) + mainloop_consumer_release_state_V = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.V_stage) - mainloop_consumer_read_state_K = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.K_stage - ) + mainloop_consumer_read_state_K = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.K_stage) - mainloop_consumer_release_state_K = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.K_stage - ) + mainloop_consumer_release_state_K = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.K_stage) - prefetchQ_consumer_read_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, 1 - ) - prefetchQ_consumer_release_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, 1 - ) + prefetchQ_consumer_read_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1) + prefetchQ_consumer_release_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1) # ******************** # softmax intermediate result # ******************** # shape:(mmaSahpeM * mma_m) - row_max = cute.make_rmem_tensor( - (acc_shape_QK[0][0] * acc_shape_QK[1]), cutlass.Float32 - ) - row_sum = cute.make_rmem_tensor( - (acc_shape_QK[0][0] * acc_shape_QK[1]), cutlass.Float32 - ) + row_max = cute.make_rmem_tensor((acc_shape_QK[0][0] * acc_shape_QK[1]), cutlass.Float32) + row_sum = cute.make_rmem_tensor((acc_shape_QK[0][0] * acc_shape_QK[1]), cutlass.Float32) row_max.fill(-cutlass.Float32.inf) row_sum.fill(0.0) @@ -816,9 +692,7 @@ def kernel( # ******************** K_tile_cnt = block_counts[offset + t, KV_head_idx] prefetch_K_tile_cnt = cutlass.max(cutlass.min(self.K_stage, K_tile_cnt), 0) - prefetch_V_tile_cnt = cutlass.max( - cutlass.min(self.V_stage - 1, K_tile_cnt), 0 - ) + prefetch_V_tile_cnt = cutlass.max(cutlass.min(self.V_stage - 1, K_tile_cnt), 0) if warp_idx == 0: prefetchQ_pipeline.producer_acquire(prefetchQ_producer_state) @@ -828,9 +702,7 @@ def kernel( tma_atom_Q, tAgQ_k, tAsQ_pipe, - tma_bar_ptr=prefetchQ_pipeline.producer_get_barrier( - prefetchQ_producer_state - ), + tma_bar_ptr=prefetchQ_pipeline.producer_get_barrier(prefetchQ_producer_state), ) prefetchQ_pipeline.producer_commit(prefetchQ_producer_state) prefetchQ_producer_state.advance() @@ -847,9 +719,7 @@ def kernel( tma_atom_K, tKgK_k, tKsK_pipe, - tma_bar_ptr=mainloop_pipeline.producer_get_barrier( - mainloop_producer_state_K - ), + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state_K), ) mainloop_pipeline.producer_commit(mainloop_producer_state_K) @@ -866,9 +736,7 @@ def kernel( tma_atom_V, tVgV_k, tVsV_pipe, - tma_bar_ptr=mainloop_pipeline_V.producer_get_barrier( - mainloop_producer_state_V - ), + tma_bar_ptr=mainloop_pipeline_V.producer_get_barrier(mainloop_producer_state_V), ) mainloop_pipeline_V.producer_commit(mainloop_producer_state_V) mainloop_producer_state_V.advance() @@ -876,9 +744,7 @@ def kernel( # cute.printf("prefetch_idx: %d, block_idx: %d, mainloop_producer_state_V.index: %d\n", prefetch_idx, block_idx, mainloop_producer_state_V.index) peek_q_full_status = cutlass.Boolean(1) - peek_q_full_status = prefetchQ_pipeline.consumer_try_wait( - prefetchQ_consumer_read_state - ) + peek_q_full_status = prefetchQ_pipeline.consumer_try_wait(prefetchQ_consumer_read_state) # tiled_mma_QK.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) num_K_blocks = cute.size(tSrQ, mode=[2]) @@ -888,9 +754,7 @@ def kernel( # ******************** thr_mma_PV = tiled_mma_PV.get_slice(tidx) - acc_shape_PV = thr_mma_PV.partition_shape_C( - (self.tile_shape_mnk_PV[0], self.tile_shape_mnk_PV[1]) - ) + acc_shape_PV = thr_mma_PV.partition_shape_C((self.tile_shape_mnk_PV[0], self.tile_shape_mnk_PV[1])) acc_PV = cute.make_rmem_tensor(acc_shape_PV, self.acc_dtype) acc_PV.fill(0) @@ -903,9 +767,7 @@ def kernel( gL_thr = tiled_mma_QK.get_slice(tidx).partition_C(gL) gM_thr = tiled_mma_QK.get_slice(tidx).partition_C(gM) - prefetchQ_pipeline.consumer_wait( - prefetchQ_consumer_read_state, peek_q_full_status - ) + prefetchQ_pipeline.consumer_wait(prefetchQ_consumer_read_state, peek_q_full_status) for k in cutlass.range_constexpr(0, cute.size(tSrQ, mode=[2])): cute.copy( smem_tiled_copy_Q, @@ -915,14 +777,10 @@ def kernel( peak_k_full_status = cutlass.Boolean(1) if prefetch_K_tile_cnt > 0: - peak_k_full_status = mainloop_pipeline.consumer_try_wait( - mainloop_consumer_read_state_K - ) + peak_k_full_status = mainloop_pipeline.consumer_try_wait(mainloop_consumer_read_state_K) for K_tile in cutlass.range(0, K_tile_cnt, 1, unroll=1): - mainloop_pipeline.consumer_wait( - mainloop_consumer_read_state_K, peak_k_full_status - ) + mainloop_pipeline.consumer_wait(mainloop_consumer_read_state_K, peak_k_full_status) acc_QK.fill(0) # cute.nvgpu.warpgroup.fence() # if tidx == 0: @@ -940,9 +798,7 @@ def kernel( tma_atom_V, tVgV_k, tVsV_pipe, - tma_bar_ptr=mainloop_pipeline_V.producer_get_barrier( - mainloop_producer_state_V - ), + tma_bar_ptr=mainloop_pipeline_V.producer_get_barrier(mainloop_producer_state_V), ) mainloop_pipeline_V.producer_commit(mainloop_producer_state_V) mainloop_producer_state_V.advance() @@ -953,18 +809,12 @@ def kernel( tSrK_copy_view[None, None, 0, mainloop_consumer_read_state_K.index], ) - for k in cutlass.range_constexpr( - 0, cute.size(tSrQ, mode=[2]), unroll=True - ): + for k in cutlass.range_constexpr(0, cute.size(tSrQ, mode=[2]), unroll=True): if k < cute.size(tSrK, mode=[2]) - 1: cute.copy( smem_tiled_copy_K, - tSsK[ - None, None, k + 1, mainloop_consumer_read_state_K.index - ], - tSrK_copy_view[ - None, None, k + 1, mainloop_consumer_read_state_K.index - ], + tSsK[None, None, k + 1, mainloop_consumer_read_state_K.index], + tSrK_copy_view[None, None, k + 1, mainloop_consumer_read_state_K.index], ) cute.gemm( @@ -992,9 +842,7 @@ def kernel( tma_atom_K, tKgK_k, tKsK_pipe, - tma_bar_ptr=mainloop_pipeline.producer_get_barrier( - mainloop_producer_state_K - ), + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state_K), ) mainloop_pipeline.producer_commit(mainloop_producer_state_K) @@ -1010,42 +858,26 @@ def kernel( for r in cutlass.range_constexpr(cute.size(gL_thr.shape[0][1])): if cute.elem_less(cLM_thr[(0, r), 0, 0][0], self.GQA_group_size): acc_QK_row = acc_QK_mn[r, None].load() * softmax_scale - row_max_cur_row = acc_QK_row.reduce( - cute.ReductionOp.MAX, -cutlass.Float32.inf, 0 - ) - row_max_cur_row = self._threadquad_reduce_max( - row_max_cur_row, mask=(1 << self.GQA_group_size) - 1 - ) + row_max_cur_row = acc_QK_row.reduce(cute.ReductionOp.MAX, -cutlass.Float32.inf, 0) + row_max_cur_row = self._threadquad_reduce_max(row_max_cur_row, mask=(1 << self.GQA_group_size) - 1) row_max_prev_row = row_max_prev[r] if is_not_first_n_block: - row_max_cur_row = cute.arch.fmax( - row_max_prev_row, row_max_cur_row - ) + row_max_cur_row = cute.arch.fmax(row_max_prev_row, row_max_cur_row) acc_QK_row_exp = cute.TensorSSA( # e^{Sn-mn} self._exp2f((acc_QK_row - row_max_cur_row) * self.log2_e), tuple(acc_QK_row.shape), cutlass.Float32, ) - acc_QK_row_sum = acc_QK_row_exp.reduce( - cute.ReductionOp.ADD, cutlass.Float32.zero, 0 - ) - acc_QK_row_sum = self._threadquad_reduce_sum( - acc_QK_row_sum, mask=(1 << self.GQA_group_size) - 1 - ) # rowsum(e^{Sn-mn}) + acc_QK_row_sum = acc_QK_row_exp.reduce(cute.ReductionOp.ADD, cutlass.Float32.zero, 0) + acc_QK_row_sum = self._threadquad_reduce_sum(acc_QK_row_sum, mask=(1 << self.GQA_group_size) - 1) # rowsum(e^{Sn-mn}) if is_not_first_n_block: - prev_minus_cur_exp = self._exp2f( # e^{M^{(n-1)} - M^{(n)}} - (row_max_prev_row - row_max_cur_row) * self.log2_e - ) + prev_minus_cur_exp = self._exp2f((row_max_prev_row - row_max_cur_row) * self.log2_e) # e^{M^{(n-1)} - M^{(n)}} # L^{(n)} = rowsum(e^{Sn-mn}) + L^{(n-1)} * e^{M^{(n-1)} - M^{(n)}} - acc_QK_row_sum = ( - acc_QK_row_sum + row_sum[r] * prev_minus_cur_exp - ) + acc_QK_row_sum = acc_QK_row_sum + row_sum[r] * prev_minus_cur_exp # O^{(n-1)}' = O^{(n-1)} * e^{M^{(n-1)} - M^{(n)}} - acc_PV_mn[r, None] = ( - acc_PV_mn[r, None].load() * prev_minus_cur_exp - ) + acc_PV_mn[r, None] = acc_PV_mn[r, None].load() * prev_minus_cur_exp row_max[r] = row_max_cur_row row_sum[r] = acc_QK_row_sum @@ -1055,9 +887,7 @@ def kernel( # p@V gemm calculation # /////////////////////////////////////////////////////////////////////////////// peak_v_full_status = cutlass.Boolean(1) - peak_v_full_status = mainloop_pipeline_V.consumer_try_wait( - mainloop_consumer_read_state_V - ) + peak_v_full_status = mainloop_pipeline_V.consumer_try_wait(mainloop_consumer_read_state_V) rP = cute.make_fragment_like(acc_QK, self.dtype) # rP.store(acc_QK.load().to(self.dtype)) @@ -1070,13 +900,9 @@ def kernel( # convert rP from ((2, 2, 2*num_k_blocks_pv), 1, 1) to ((2, 2, 2), 1, 1, num_k_blocks_pv) num_k_blocks_pv = cute.size(tOrVt, mode=[2]) - rP_divided_dim3 = thr_mma_PV.partition_shape_A( - (self.tile_shape_mnk_PV[0], self.tile_shape_mnk_PV[2]) - )[0][2] + rP_divided_dim3 = thr_mma_PV.partition_shape_A((self.tile_shape_mnk_PV[0], self.tile_shape_mnk_PV[2]))[0][2] - rP_layout_divided = cute.logical_divide( - rP.layout, (None, None, rP_divided_dim3) - ) + rP_layout_divided = cute.logical_divide(rP.layout, (None, None, rP_divided_dim3)) rP_mma_view = cute.make_layout( ( ( @@ -1100,26 +926,18 @@ def kernel( rP = cute.make_tensor(rP.iterator, rP_mma_view) - mainloop_pipeline_V.consumer_wait( - mainloop_consumer_read_state_V, peak_v_full_status - ) + mainloop_pipeline_V.consumer_wait(mainloop_consumer_read_state_V, peak_v_full_status) cute.copy( smem_tiled_copy_V, tOsVt[None, None, 0, mainloop_consumer_read_state_V.index], - tOrVt_copy_view[ - None, None, 0, mainloop_consumer_read_state_V.index - ], + tOrVt_copy_view[None, None, 0, mainloop_consumer_read_state_V.index], ) for k in cutlass.range_constexpr(0, cute.size(tOrVt, mode=[2])): if k < cute.size(tOrVt, mode=[2]) - 1: cute.copy( smem_tiled_copy_V, - tOsVt[ - None, None, k + 1, mainloop_consumer_read_state_V.index - ], - tOrVt_copy_view[ - None, None, k + 1, mainloop_consumer_read_state_V.index - ], + tOsVt[None, None, k + 1, mainloop_consumer_read_state_V.index], + tOrVt_copy_view[None, None, k + 1, mainloop_consumer_read_state_V.index], ) cute.gemm( tiled_mma_PV, @@ -1131,9 +949,7 @@ def kernel( peak_k_full_status = cutlass.Boolean(1) if K_tile < K_tile_cnt - 1: - peak_k_full_status = mainloop_pipeline.consumer_try_wait( - mainloop_consumer_read_state_K - ) + peak_k_full_status = mainloop_pipeline.consumer_try_wait(mainloop_consumer_read_state_K) mainloop_pipeline_V.consumer_release(mainloop_consumer_release_state_V) mainloop_consumer_read_state_V.advance() @@ -1152,14 +968,8 @@ def kernel( # softmax normalization: O^{(n)} = O^{(n)} / L^{(n)} for row_idx in cutlass.range(gL_thr.shape[0][1]): if cute.elem_less(cLM_thr[(0, row_idx), 0, 0][0], self.GQA_group_size): - acc_pv_mn_is_zero_or_nan = ( - row_sum[row_idx] == 0.0 or row_sum[row_idx] != row_sum[row_idx] - ) - scale = ( - 1.0 - if acc_pv_mn_is_zero_or_nan - else cute.arch.rcp_approx(row_sum[row_idx]) - ) + acc_pv_mn_is_zero_or_nan = row_sum[row_idx] == 0.0 or row_sum[row_idx] != row_sum[row_idx] + scale = 1.0 if acc_pv_mn_is_zero_or_nan else cute.arch.rcp_approx(row_sum[row_idx]) acc_PV_mn[row_idx, None] = acc_PV_mn[row_idx, None].load() * scale tOgO_for_tma_partition = cute.zipped_divide( @@ -1251,9 +1061,7 @@ def copy_reg_to_gmem( # Copy from D registers to shared memory epi_buffer = epi_idx % cute.size(tRS_dv_sD, mode=[3]) - cute.copy( - tiled_copy_r2s, tRS_rD_out, tRS_dv_sD[(None, None, None, epi_buffer)] - ) + cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_dv_sD[(None, None, None, epi_buffer)]) cute.arch.fence_proxy( cute.arch.ProxyKind.async_shared, @@ -1262,9 +1070,7 @@ def copy_reg_to_gmem( # barrier for sync cute.arch.barrier() - epi_tile_layout = cute.make_layout( - epi_tile_shape, stride=(epi_tile_shape[1], 1) - ) + epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1)) gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) # Copy from shared memory to global memory if warp_idx == 0: diff --git a/python/cudnn/native_sparse_attention/selection/api.py b/python/cudnn/native_sparse_attention/selection/api.py index 225a2e44..b4331098 100644 --- a/python/cudnn/native_sparse_attention/selection/api.py +++ b/python/cudnn/native_sparse_attention/selection/api.py @@ -89,83 +89,47 @@ def check_support(self) -> bool: t, h_q, d_v = self.sample_o.shape if self.sample_q.shape != (t, h_q, d_qk): - raise ValueError( - f"Input shape mismatch: expected Q tensor shape {t, h_q, d_qk}, got {self.sample_q.shape}" - ) + raise ValueError(f"Input shape mismatch: expected Q tensor shape {t, h_q, d_qk}, got {self.sample_q.shape}") if self.sample_k.shape != (t, h_kv, d_qk): - raise ValueError( - f"Input shape mismatch: expected K tensor shape {t, h_kv, d_qk}, got {self.sample_k.shape}" - ) + raise ValueError(f"Input shape mismatch: expected K tensor shape {t, h_kv, d_qk}, got {self.sample_k.shape}") if self.sample_v.shape != (t, h_kv, d_v): - raise ValueError( - f"Input shape mismatch: expected V tensor shape {t, h_kv, d_v}, got {self.sample_v.shape}" - ) + raise ValueError(f"Input shape mismatch: expected V tensor shape {t, h_kv, d_v}, got {self.sample_v.shape}") if self.sample_o.shape != (t, h_q, d_v): - raise ValueError( - f"Output shape mismatch: expected O tensor shape {t, h_q, d_v}, got {self.sample_o.shape}" - ) + raise ValueError(f"Output shape mismatch: expected O tensor shape {t, h_q, d_v}, got {self.sample_o.shape}") self.sample_l = self._unpad_tensor_to_ndim(self.sample_l, 2, "sample_l") if self.sample_l.shape != (t, h_q): - raise ValueError( - f"Output shape mismatch: expected L tensor shape {t, h_q}, got {self.sample_l.shape}" - ) + raise ValueError(f"Output shape mismatch: expected L tensor shape {t, h_q}, got {self.sample_l.shape}") self.sample_m = self._unpad_tensor_to_ndim(self.sample_m, 2, "sample_m") if self.sample_m.shape != (t, h_q): - raise ValueError( - f"Output shape mismatch: expected M tensor shape {t, h_q}, got {self.sample_m.shape}" - ) + raise ValueError(f"Output shape mismatch: expected M tensor shape {t, h_q}, got {self.sample_m.shape}") if self.sample_cum_seqlen_q is None: - raise ValueError( - f"sample_cum_seqlen_q must be provided for T,H,D format, got {self.sample_cum_seqlen_q}" - ) - if self.sample_cum_seqlen_k is not None and not torch.equal( - self.sample_cum_seqlen_q, self.sample_cum_seqlen_k - ): + raise ValueError(f"sample_cum_seqlen_q must be provided for T,H,D format, got {self.sample_cum_seqlen_q}") + if self.sample_cum_seqlen_k is not None and not torch.equal(self.sample_cum_seqlen_q, self.sample_cum_seqlen_k): raise NotImplementedError( f"SelectionAttention requires sample_cum_seqlen_q and sample_cum_seqlen_k to be identical, but got {self.sample_cum_seqlen_q} and {self.sample_cum_seqlen_k}" ) if self.max_s_q is None: - raise ValueError( - f"max_s_q must be provided for T,H,D format, got {self.max_s_q}" - ) + raise ValueError(f"max_s_q must be provided for T,H,D format, got {self.max_s_q}") if self.max_s_k is not None and self.max_s_q != self.max_s_k: - raise NotImplementedError( - f"SelectionAttention requires max_s_q and max_s_k to be identical, but got {self.max_s_q} and {self.max_s_k}" - ) + raise NotImplementedError(f"SelectionAttention requires max_s_q and max_s_k to be identical, but got {self.max_s_q} and {self.max_s_k}") self.batch_size = len(self.sample_cum_seqlen_q) - 1 if self.batch_size <= 0: - raise ValueError( - f"batch_size (len(sample_cum_seqlen_q) - 1) must be greater than 0, got {self.batch_size}" - ) + raise ValueError(f"batch_size (len(sample_cum_seqlen_q) - 1) must be greater than 0, got {self.batch_size}") if self.sample_cum_seqlen_q.dtype not in (torch.int32, torch.int64): - raise ValueError( - f"sample_cum_seqlen_q must be int32 or int64, got {self.sample_cum_seqlen_q.dtype}" - ) + raise ValueError(f"sample_cum_seqlen_q must be int32 or int64, got {self.sample_cum_seqlen_q.dtype}") - if ( - self.sample_block_indices.shape[:2] != (t, h_kv) - and self.sample_block_indices.ndim != 3 - ): - raise ValueError( - f"sample_block_indices shape mismatch: expected {(t, h_kv, 'K')}, got {tuple(self.sample_block_indices.shape)}" - ) + if self.sample_block_indices.shape[:2] != (t, h_kv) and self.sample_block_indices.ndim != 3: + raise ValueError(f"sample_block_indices shape mismatch: expected {(t, h_kv, 'K')}, got {tuple(self.sample_block_indices.shape)}") if self.sample_block_counts.shape != (t, h_kv): - raise ValueError( - f"sample_block_counts shape mismatch: expected {(t, h_kv)}, got {tuple(self.sample_block_counts.shape)}" - ) - if ( - self.sample_block_indices.dtype != torch.int32 - or self.sample_block_counts.dtype != torch.int32 - ): + raise ValueError(f"sample_block_counts shape mismatch: expected {(t, h_kv)}, got {tuple(self.sample_block_counts.shape)}") + if self.sample_block_indices.dtype != torch.int32 or self.sample_block_counts.dtype != torch.int32: raise ValueError( f"sample_block_indices and sample_block_counts must be int32, got {self.sample_block_indices.dtype} and {self.sample_block_counts.dtype}" ) else: - raise ValueError( - f"sample_q must be rank-3 (T,H,D) or rank-4 (B,H,S,D), got {self.sample_q.ndim}" - ) + raise ValueError(f"sample_q must be rank-3 (T,H,D) or rank-4 (B,H,S,D), got {self.sample_q.ndim}") # Shared derived attributes if h_q % h_kv != 0: @@ -179,12 +143,7 @@ def check_support(self) -> bool: # Validate dtypes and config self._logger.debug("Checking dtypes and config") self.dtype = self.sample_q.dtype - if not ( - self.dtype - == self.sample_k.dtype - == self.sample_v.dtype - == self.sample_o.dtype - ): + if not (self.dtype == self.sample_k.dtype == self.sample_v.dtype == self.sample_o.dtype): raise ValueError("All input/output tensors must have the same dtype") if self.dtype not in {torch.float16, torch.bfloat16}: raise ValueError("dtype must be Float16 or BFloat16") @@ -206,12 +165,8 @@ def check_support(self) -> bool: major, minor = torch.cuda.get_device_capability(device) compute_capability = major * 10 + minor if compute_capability < 90: - self._logger.error( - f"Requires SM90+ compute capability, but found SM{compute_capability} on device {device}" - ) - raise RuntimeError( - f"Requires SM90+ compute capability, but found SM{compute_capability} on device {device}" - ) + self._logger.error(f"Requires SM90+ compute capability, but found SM{compute_capability} on device {device}") + raise RuntimeError(f"Requires SM90+ compute capability, but found SM{compute_capability} on device {device}") if compute_capability == 103: raise RuntimeError("cuteDSL SelectionAttention is not supported on SM103") @@ -265,29 +220,17 @@ def shares_memory(original, reshaped): return original.data_ptr() == reshaped.data_ptr() if not shares_memory(q, q_reshaped): - raise ValueError( - "Q tensor memory changed during reshape - expected view operation" - ) + raise ValueError("Q tensor memory changed during reshape - expected view operation") if not shares_memory(k, k_reshaped): - raise ValueError( - "K tensor memory changed during reshape - expected view operation" - ) + raise ValueError("K tensor memory changed during reshape - expected view operation") if not shares_memory(v, v_reshaped): - raise ValueError( - "V tensor memory changed during reshape - expected view operation" - ) + raise ValueError("V tensor memory changed during reshape - expected view operation") if not shares_memory(o, o_reshaped): - raise ValueError( - "O tensor memory changed during reshape - expected view operation" - ) + raise ValueError("O tensor memory changed during reshape - expected view operation") if not shares_memory(l, l_reshaped): - raise ValueError( - "L tensor memory changed during reshape - expected view operation" - ) + raise ValueError("L tensor memory changed during reshape - expected view operation") if not shares_memory(m, m_reshaped): - raise ValueError( - "M tensor memory changed during reshape - expected view operation" - ) + raise ValueError("M tensor memory changed during reshape - expected view operation") return q_reshaped, k_reshaped, v_reshaped, o_reshaped, l_reshaped, m_reshaped @@ -306,15 +249,13 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: ) self._logger.debug("Reshaping tensors to kernel expected format") - q_reshaped, k_reshaped, v_reshaped, o_reshaped, l_reshaped, m_reshaped = ( - self._reshape_tensors( - self.sample_q, - self.sample_k, - self.sample_v, - self.sample_o, - self.sample_l, - self.sample_m, - ) + q_reshaped, k_reshaped, v_reshaped, o_reshaped, l_reshaped, m_reshaped = self._reshape_tensors( + self.sample_q, + self.sample_k, + self.sample_v, + self.sample_o, + self.sample_l, + self.sample_m, ) mQ = from_dlpack(q_reshaped, assumed_align=128) @@ -368,10 +309,8 @@ def execute( self._logger.debug("Reshaping tensors to kernel expected format") l_tensor = self._unpad_tensor_to_ndim(l_tensor, 2, "l_tensor") m_tensor = self._unpad_tensor_to_ndim(m_tensor, 2, "m_tensor") - q_reshaped, k_reshaped, v_reshaped, o_reshaped, l_reshaped, m_reshaped = ( - self._reshape_tensors( - q_tensor, k_tensor, v_tensor, o_tensor, l_tensor, m_tensor - ) + q_reshaped, k_reshaped, v_reshaped, o_reshaped, l_reshaped, m_reshaped = self._reshape_tensors( + q_tensor, k_tensor, v_tensor, o_tensor, l_tensor, m_tensor ) mQ = from_dlpack(q_reshaped, assumed_align=128) @@ -461,20 +400,10 @@ def selection_attention_wrapper( Returns: tuple: (o_tensor, l_tensor, m_tensor) - Output, logsumexp, and max tensors """ - _logger.debug( - "selection_attention_wrapper: Creating empty output tensors o, l, and m" - ) + _logger.debug("selection_attention_wrapper: Creating empty output tensors o, l, and m") - max_s_q = ( - max(cum_seqlen_q_tensor[1:] - cum_seqlen_q_tensor[:-1]).item() - if max_s_q is None - else max_s_q - ) - max_s_k = ( - max(cum_seqlen_k_tensor[1:] - cum_seqlen_k_tensor[:-1]).item() - if max_s_k is None - else max_s_k - ) + max_s_q = max(cum_seqlen_q_tensor[1:] - cum_seqlen_q_tensor[:-1]).item() if max_s_q is None else max_s_q + max_s_k = max(cum_seqlen_k_tensor[1:] - cum_seqlen_k_tensor[:-1]).item() if max_s_k is None else max_s_k t, h_q, d = q_tensor.shape _, h_kv, d_v = v_tensor.shape @@ -509,9 +438,7 @@ def selection_attention_wrapper( max_s_k, ) if cache_key in _cache_of_SelectionAttentionObjects: - _logger.debug( - "selection_attention_wrapper: Using previously cached SelectionAttention object" - ) + _logger.debug("selection_attention_wrapper: Using previously cached SelectionAttention object") selection_attention = _cache_of_SelectionAttentionObjects[cache_key] selection_attention.execute( q_tensor=q_tensor, @@ -528,9 +455,7 @@ def selection_attention_wrapper( current_stream=stream, ) else: - _logger.debug( - "selection_attention_wrapper: No previously cached SelectionAttention object found, creating new SelectionAttention object" - ) + _logger.debug("selection_attention_wrapper: No previously cached SelectionAttention object found, creating new SelectionAttention object") selection_attention = SelectionAttention( sample_q=q_tensor, sample_k=k_tensor, diff --git a/python/cudnn/native_sparse_attention/sliding_window_attention/api.py b/python/cudnn/native_sparse_attention/sliding_window_attention/api.py index 7463565a..16b4d652 100644 --- a/python/cudnn/native_sparse_attention/sliding_window_attention/api.py +++ b/python/cudnn/native_sparse_attention/sliding_window_attention/api.py @@ -43,47 +43,23 @@ def __init__( self.sample_v = sample_v self.sample_o = sample_o self.is_infer = sample_stats is None - self.sample_stats = ( - self._pad_tensor_to_ndim(sample_stats, self.sample_o.ndim, "sample_stats") - if sample_stats is not None - else None - ) + self.sample_stats = self._pad_tensor_to_ndim(sample_stats, self.sample_o.ndim, "sample_stats") if sample_stats is not None else None self.left_bound = left_bound self.right_bound = right_bound - self.sample_seq_len_q = self._pad_tensor_to_ndim( - sample_seq_len_q, 4, "sample_seq_len_q" - ) - self.sample_seq_len_kv = self._pad_tensor_to_ndim( - sample_seq_len_kv, 4, "sample_seq_len_kv" - ) + self.sample_seq_len_q = self._pad_tensor_to_ndim(sample_seq_len_q, 4, "sample_seq_len_q") + self.sample_seq_len_kv = self._pad_tensor_to_ndim(sample_seq_len_kv, 4, "sample_seq_len_kv") self.max_seq_len_q = max_seq_len_q self.max_seq_len_kv = max_seq_len_kv - self.sample_q_ragged_offset = self._pad_tensor_to_ndim( - sample_q_ragged_offset, 4, "sample_q_ragged_offset" - ) - self.sample_k_ragged_offset = self._pad_tensor_to_ndim( - sample_k_ragged_offset, 4, "sample_k_ragged_offset" - ) - self.sample_v_ragged_offset = self._pad_tensor_to_ndim( - sample_v_ragged_offset, 4, "sample_v_ragged_offset" - ) - self.sample_o_ragged_offset = self._pad_tensor_to_ndim( - sample_o_ragged_offset, 4, "sample_o_ragged_offset" - ) + self.sample_q_ragged_offset = self._pad_tensor_to_ndim(sample_q_ragged_offset, 4, "sample_q_ragged_offset") + self.sample_k_ragged_offset = self._pad_tensor_to_ndim(sample_k_ragged_offset, 4, "sample_k_ragged_offset") + self.sample_v_ragged_offset = self._pad_tensor_to_ndim(sample_v_ragged_offset, 4, "sample_v_ragged_offset") + self.sample_o_ragged_offset = self._pad_tensor_to_ndim(sample_o_ragged_offset, 4, "sample_o_ragged_offset") self.sample_stats_ragged_offset = ( - self._pad_tensor_to_ndim( - sample_stats_ragged_offset, 4, "sample_stats_ragged_offset" - ) - if sample_stats_ragged_offset is not None - else None + self._pad_tensor_to_ndim(sample_stats_ragged_offset, 4, "sample_stats_ragged_offset") if sample_stats_ragged_offset is not None else None ) - self.attn_scale = ( - attn_scale - if attn_scale is not None - else 1.0 / math.sqrt(self.sample_q.shape[-1]) - ) + self.attn_scale = attn_scale if attn_scale is not None else 1.0 / math.sqrt(self.sample_q.shape[-1]) self.intermediate_data_type = intermediate_data_type self.compute_data_type = compute_data_type @@ -95,9 +71,7 @@ def __init__( self._logger.critical( "cudnn_handle not provided, creating new handle. This is not recommended as this is significant overhead and will occur for each SlidingWindowAttention object created." ) - self._cudnn_handle = ( - cudnn_handle if cudnn_handle is not None else cudnn.create_handle() - ) + self._cudnn_handle = cudnn_handle if cudnn_handle is not None else cudnn.create_handle() self._cudnn_swa_graph = None self._cudnn_compiled = False self._logger.debug( @@ -130,25 +104,11 @@ def compute_exclusive_prefix_sum(tensor): ) # Calculate ragged offsets - q_ragged_offset = ( - compute_exclusive_prefix_sum(seq_len_q) * self.sample_q.stride()[0] - ).to(dtype=torch.int64) - k_ragged_offset = ( - compute_exclusive_prefix_sum(seq_len_kv) * self.sample_k.stride()[0] - ).to(dtype=torch.int64) - v_ragged_offset = ( - compute_exclusive_prefix_sum(seq_len_kv) * self.sample_v.stride()[0] - ).to(dtype=torch.int64) - o_ragged_offset = ( - compute_exclusive_prefix_sum(seq_len_q) * self.sample_o.stride()[0] - ).to(dtype=torch.int64) - stats_ragged_offset = ( - ( - compute_exclusive_prefix_sum(seq_len_q) * self.sample_stats.stride()[0] - ).to(dtype=torch.int64) - if not self.is_infer - else None - ) + q_ragged_offset = (compute_exclusive_prefix_sum(seq_len_q) * self.sample_q.stride()[0]).to(dtype=torch.int64) + k_ragged_offset = (compute_exclusive_prefix_sum(seq_len_kv) * self.sample_k.stride()[0]).to(dtype=torch.int64) + v_ragged_offset = (compute_exclusive_prefix_sum(seq_len_kv) * self.sample_v.stride()[0]).to(dtype=torch.int64) + o_ragged_offset = (compute_exclusive_prefix_sum(seq_len_q) * self.sample_o.stride()[0]).to(dtype=torch.int64) + stats_ragged_offset = (compute_exclusive_prefix_sum(seq_len_q) * self.sample_stats.stride()[0]).to(dtype=torch.int64) if not self.is_infer else None return ( q_ragged_offset, @@ -164,10 +124,7 @@ def check_support(self) -> bool: if not torch.cuda.is_available(): raise RuntimeError("CUDA is not available") self.dtype = self.sample_q.dtype - self.sm_version = ( - torch.cuda.get_device_capability()[0] * 10 - + torch.cuda.get_device_capability()[1] - ) + self.sm_version = torch.cuda.get_device_capability()[0] * 10 + torch.cuda.get_device_capability()[1] if self.sample_q.ndim == 4: self._logger.debug("Inferred bshd layout") self.input_layout = "bshd" @@ -179,9 +136,7 @@ def check_support(self) -> bool: swa_graph = cudnn.pygraph( io_data_type=_torch_to_cudnn_data_type(self.dtype), - intermediate_data_type=_torch_to_cudnn_data_type( - self.intermediate_data_type - ), + intermediate_data_type=_torch_to_cudnn_data_type(self.intermediate_data_type), compute_data_type=_torch_to_cudnn_data_type(self.compute_data_type), handle=self._cudnn_handle, sm_version=self.sm_version, @@ -206,37 +161,23 @@ def check_support(self) -> bool: b, h_q, s_q, d_v = self.sample_o.shape if self.sample_q.shape != (b, h_q, s_q, d_qk): - raise ValueError( - f"Input shape mismatch: expected Q tensor shape {b, h_q, s_q, d_qk}, got {self.sample_q.shape}" - ) + raise ValueError(f"Input shape mismatch: expected Q tensor shape {b, h_q, s_q, d_qk}, got {self.sample_q.shape}") if self.sample_k.shape != (b, h_kv, s_kv, d_qk): - raise ValueError( - f"Input shape mismatch: expected K tensor shape {b, h_kv, s_kv, d_qk}, got {self.sample_k.shape}" - ) + raise ValueError(f"Input shape mismatch: expected K tensor shape {b, h_kv, s_kv, d_qk}, got {self.sample_k.shape}") if self.sample_v.shape != (b, h_kv, s_kv, d_v): - raise ValueError( - f"Input shape mismatch: expected V tensor shape {b, h_kv, s_kv, d_v}, got {self.sample_v.shape}" - ) + raise ValueError(f"Input shape mismatch: expected V tensor shape {b, h_kv, s_kv, d_v}, got {self.sample_v.shape}") if self.sample_o.shape != (b, h_q, s_q, d_v): - raise ValueError( - f"Output shape mismatch: expected O tensor shape {b, h_q, s_q, d_v}, got {self.sample_o.shape}" - ) + raise ValueError(f"Output shape mismatch: expected O tensor shape {b, h_q, s_q, d_v}, got {self.sample_o.shape}") if not self.is_infer: - self.sample_stats = self._pad_tensor_to_ndim( - self.sample_stats, 4, "sample_stats" - ) + self.sample_stats = self._pad_tensor_to_ndim(self.sample_stats, 4, "sample_stats") if self.sample_stats.shape != (b, h_q, s_q, 1): - raise ValueError( - f"Output shape mismatch: expected Stats tensor shape {b, h_q, s_q, 1}, got {self.sample_stats.shape}" - ) + raise ValueError(f"Output shape mismatch: expected Stats tensor shape {b, h_q, s_q, 1}, got {self.sample_stats.shape}") if self.sample_seq_len_q is not None or self.sample_seq_len_kv is not None: raise ValueError( f"sample_seq_len_q and sample_seq_len_kv should be None for bshd layout, got {self.sample_seq_len_q} and {self.sample_seq_len_kv}" ) if self.max_seq_len_q is not None or self.max_seq_len_kv is not None: - raise ValueError( - f"max_seq_len_q and max_seq_len_kv should be None for bshd layout, got {self.max_seq_len_q} and {self.max_seq_len_kv}" - ) + raise ValueError(f"max_seq_len_q and max_seq_len_kv should be None for bshd layout, got {self.max_seq_len_q} and {self.max_seq_len_kv}") if ( self.sample_q_ragged_offset is not None or self.sample_k_ragged_offset is not None @@ -258,38 +199,24 @@ def check_support(self) -> bool: t, h_q, d_v = self.sample_o.shape if self.sample_q.shape != (t, h_q, d_qk): - raise ValueError( - f"Input shape mismatch: expected Q tensor shape {t, h_q, d_qk}, got {self.sample_q.shape}" - ) + raise ValueError(f"Input shape mismatch: expected Q tensor shape {t, h_q, d_qk}, got {self.sample_q.shape}") if self.sample_k.shape != (t, h_kv, d_qk): - raise ValueError( - f"Input shape mismatch: expected K tensor shape {t, h_kv, d_qk}, got {self.sample_k.shape}" - ) + raise ValueError(f"Input shape mismatch: expected K tensor shape {t, h_kv, d_qk}, got {self.sample_k.shape}") if self.sample_v.shape != (t, h_kv, d_v): - raise ValueError( - f"Input shape mismatch: expected V tensor shape {t, h_kv, d_v}, got {self.sample_v.shape}" - ) + raise ValueError(f"Input shape mismatch: expected V tensor shape {t, h_kv, d_v}, got {self.sample_v.shape}") if self.sample_o.shape != (t, h_q, d_v): - raise ValueError( - f"Output shape mismatch: expected O tensor shape {t, h_q, d_v}, got {self.sample_o.shape}" - ) + raise ValueError(f"Output shape mismatch: expected O tensor shape {t, h_q, d_v}, got {self.sample_o.shape}") if not self.is_infer: - self.sample_stats = self._pad_tensor_to_ndim( - self.sample_stats, 3, "sample_stats" - ) + self.sample_stats = self._pad_tensor_to_ndim(self.sample_stats, 3, "sample_stats") if self.sample_stats.shape != (t, h_q, 1): - raise ValueError( - f"Output shape mismatch: expected Stats tensor shape {t, h_q, 1}, got {self.sample_stats.shape}" - ) + raise ValueError(f"Output shape mismatch: expected Stats tensor shape {t, h_q, 1}, got {self.sample_stats.shape}") if self.sample_seq_len_q is None or self.sample_seq_len_kv is None: raise ValueError( f"sample_seq_len_q and sample_seq_len_kv must be provided for thd layout, got {self.sample_seq_len_q} and {self.sample_seq_len_kv}" ) if self.max_seq_len_q is None or self.max_seq_len_kv is None: - raise ValueError( - f"max_seq_len_q and max_seq_len_kv must be provided for thd layout, got {self.max_seq_len_q} and {self.max_seq_len_kv}" - ) + raise ValueError(f"max_seq_len_q and max_seq_len_kv must be provided for thd layout, got {self.max_seq_len_q} and {self.max_seq_len_kv}") if ( self.sample_q_ragged_offset is None @@ -303,17 +230,12 @@ def check_support(self) -> bool: or self.sample_k_ragged_offset is not None or self.sample_v_ragged_offset is not None or self.sample_o_ragged_offset is not None - or ( - not self.is_infer - and self.sample_stats_ragged_offset is not None - ) + or (not self.is_infer and self.sample_stats_ragged_offset is not None) ): raise ValueError( f"sample_q_ragged_offset, sample_k_ragged_offset, sample_v_ragged_offset, sample_o_ragged_offset, and sample_stats_ragged_offset must be all provided or all None, got {self.sample_q_ragged_offset}, {self.sample_k_ragged_offset}, {self.sample_v_ragged_offset}, {self.sample_o_ragged_offset}, and {self.sample_stats_ragged_offset}" ) - self._logger.info( - "Calculating ragged offsets internally assuming fully packed THD layout" - ) + self._logger.info("Calculating ragged offsets internally assuming fully packed THD layout") ( self.sample_q_ragged_offset, self.sample_k_ragged_offset, @@ -360,22 +282,12 @@ def check_support(self) -> bool: ) self.seq_len_q_cudnn = swa_graph.tensor_like(self.sample_seq_len_q) self.seq_len_kv_cudnn = swa_graph.tensor_like(self.sample_seq_len_kv) - self.q_ragged_offset_cudnn = swa_graph.tensor_like( - self.sample_q_ragged_offset - ) - self.k_ragged_offset_cudnn = swa_graph.tensor_like( - self.sample_k_ragged_offset - ) - self.v_ragged_offset_cudnn = swa_graph.tensor_like( - self.sample_v_ragged_offset - ) - self.o_ragged_offset_cudnn = swa_graph.tensor_like( - self.sample_o_ragged_offset - ) + self.q_ragged_offset_cudnn = swa_graph.tensor_like(self.sample_q_ragged_offset) + self.k_ragged_offset_cudnn = swa_graph.tensor_like(self.sample_k_ragged_offset) + self.v_ragged_offset_cudnn = swa_graph.tensor_like(self.sample_v_ragged_offset) + self.o_ragged_offset_cudnn = swa_graph.tensor_like(self.sample_o_ragged_offset) if not self.is_infer: - self.stats_ragged_offset_cudnn = swa_graph.tensor_like( - self.sample_stats_ragged_offset - ) + self.stats_ragged_offset_cudnn = swa_graph.tensor_like(self.sample_stats_ragged_offset) self.q_cudnn.set_ragged_offset(self.q_ragged_offset_cudnn) self.k_cudnn.set_ragged_offset(self.k_ragged_offset_cudnn) @@ -408,9 +320,7 @@ def check_support(self) -> bool: self.o_cudnn.set_dim(self.sample_o.shape).set_stride(self.sample_o.stride()) if not self.is_infer: self.stats_cudnn.set_output(True).set_data_type(cudnn.data_type.FLOAT) - self.stats_cudnn.set_dim(self.sample_stats.shape).set_stride( - self.sample_stats.stride() - ) + self.stats_cudnn.set_dim(self.sample_stats.shape).set_stride(self.sample_stats.stride()) elif self.input_layout == "thd": self.o_cudnn.set_dim((b, h_q, self.max_seq_len_q, d_v)) self.o_cudnn.set_stride( @@ -439,9 +349,7 @@ def check_support(self) -> bool: try: swa_graph.validate() except cudnn.cudnnGraphNotSupportedError as e: - self._logger.error( - f"Graph not supported (cudnnGraphNotSupportedError): {e}" - ) + self._logger.error(f"Graph not supported (cudnnGraphNotSupportedError): {e}") return False except Exception as e: self._logger.error(f"Graph not supported: {e}") @@ -454,16 +362,12 @@ def check_support(self) -> bool: def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: if current_stream is not None: - self._logger.warning( - "Overwriting cudnn_handle stream with provided cuda stream. Do not pass in current_stream if this is not intended." - ) + self._logger.warning("Overwriting cudnn_handle stream with provided cuda stream. Do not pass in current_stream if this is not intended.") cudnn.set_stream(self._cudnn_handle, current_stream) self._ensure_support_checked() self._cudnn_swa_graph.build_operation_graph() - self._cudnn_swa_graph.create_execution_plans( - [cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK] - ) + self._cudnn_swa_graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) self._cudnn_swa_graph.check_support() self._cudnn_swa_graph.build_plans() @@ -491,57 +395,31 @@ def execute( self._logger.debug("Entering execute") cudnn_handle = self._cudnn_handle if cudnn_handle is None else cudnn_handle if current_stream is not None: - self._logger.info( - "Overwriting cudnn_handle stream with provided cuda stream. Do not pass in current_stream if this is not intended." - ) + self._logger.info("Overwriting cudnn_handle stream with provided cuda stream. Do not pass in current_stream if this is not intended.") cudnn.set_stream(cudnn_handle, current_stream) if skip_compile: - raise NotImplementedError( - "cudnn sliding window attention kernel does not support skip_compile" - ) + raise NotImplementedError("cudnn sliding window attention kernel does not support skip_compile") if self._cudnn_swa_graph is None or not self._cudnn_compiled: raise ValueError("SlidingWindowAttention kernel not compiled") self._logger.debug("Executing with compiled kernel") self._logger.debug("Reshaping tensors to kernel expected format") - stats_tensor = ( - self._pad_tensor_to_ndim(stats_tensor, self.sample_o.ndim, "stats_tensor") - if stats_tensor is not None - else None - ) - seq_len_q_tensor = self._pad_tensor_to_ndim( - seq_len_q_tensor, 4, "seq_len_q_tensor" - ) - seq_len_kv_tensor = self._pad_tensor_to_ndim( - seq_len_kv_tensor, 4, "seq_len_kv_tensor" - ) - q_ragged_offset_tensor = self._pad_tensor_to_ndim( - q_ragged_offset_tensor, 4, "q_ragged_offset_tensor" - ) - k_ragged_offset_tensor = self._pad_tensor_to_ndim( - k_ragged_offset_tensor, 4, "k_ragged_offset_tensor" - ) - v_ragged_offset_tensor = self._pad_tensor_to_ndim( - v_ragged_offset_tensor, 4, "v_ragged_offset_tensor" - ) - o_ragged_offset_tensor = self._pad_tensor_to_ndim( - o_ragged_offset_tensor, 4, "o_ragged_offset_tensor" - ) - stats_ragged_offset_tensor = self._pad_tensor_to_ndim( - stats_ragged_offset_tensor, 4, "stats_ragged_offset_tensor" - ) + stats_tensor = self._pad_tensor_to_ndim(stats_tensor, self.sample_o.ndim, "stats_tensor") if stats_tensor is not None else None + seq_len_q_tensor = self._pad_tensor_to_ndim(seq_len_q_tensor, 4, "seq_len_q_tensor") + seq_len_kv_tensor = self._pad_tensor_to_ndim(seq_len_kv_tensor, 4, "seq_len_kv_tensor") + q_ragged_offset_tensor = self._pad_tensor_to_ndim(q_ragged_offset_tensor, 4, "q_ragged_offset_tensor") + k_ragged_offset_tensor = self._pad_tensor_to_ndim(k_ragged_offset_tensor, 4, "k_ragged_offset_tensor") + v_ragged_offset_tensor = self._pad_tensor_to_ndim(v_ragged_offset_tensor, 4, "v_ragged_offset_tensor") + o_ragged_offset_tensor = self._pad_tensor_to_ndim(o_ragged_offset_tensor, 4, "o_ragged_offset_tensor") + stats_ragged_offset_tensor = self._pad_tensor_to_ndim(stats_ragged_offset_tensor, 4, "stats_ragged_offset_tensor") if not self.is_infer and stats_tensor is None: - raise ValueError( - f"stats_tensor must be provided when compiled in non-inference mode, got {stats_tensor}" - ) + raise ValueError(f"stats_tensor must be provided when compiled in non-inference mode, got {stats_tensor}") if self.input_layout == "thd": if seq_len_q_tensor is None or seq_len_kv_tensor is None: - raise ValueError( - f"seq_len_q_tensor and seq_len_kv_tensor must be provided for thd layout, got {seq_len_q_tensor} and {seq_len_kv_tensor}" - ) + raise ValueError(f"seq_len_q_tensor and seq_len_kv_tensor must be provided for thd layout, got {seq_len_q_tensor} and {seq_len_kv_tensor}") if ( q_ragged_offset_tensor is None or k_ragged_offset_tensor is None @@ -559,9 +437,7 @@ def execute( raise ValueError( f"q_ragged_offset_tensor, k_ragged_offset_tensor, v_ragged_offset_tensor, o_ragged_offset_tensor, and stats_ragged_offset_tensor must be all provided or all None, got {q_ragged_offset_tensor}, {k_ragged_offset_tensor}, {v_ragged_offset_tensor}, {o_ragged_offset_tensor}, and {stats_ragged_offset_tensor}" ) - self._logger.info( - "Calculating ragged offsets internally assuming fully packed THD layout" - ) + self._logger.info("Calculating ragged offsets internally assuming fully packed THD layout") ( q_ragged_offset_tensor, k_ragged_offset_tensor, @@ -637,37 +513,21 @@ def sliding_window_attention_wrapper( o_tensor, stats_tensor = None, None o_dtype = o_dtype if o_dtype is not None else q_tensor.dtype if q_tensor.ndim == 3: # thd - _logger.debug( - "sliding_window_attention_wrapper: Creating empty output tensor o for thd layout" - ) + _logger.debug("sliding_window_attention_wrapper: Creating empty output tensor o for thd layout") t, h_q, d = q_tensor.shape _, h_k, d_v = v_tensor.shape - o_tensor = make_tensor_strided_like( - q_tensor, (t, h_q, d_v), dtype=o_dtype, device=q_tensor.device - ) + o_tensor = make_tensor_strided_like(q_tensor, (t, h_q, d_v), dtype=o_dtype, device=q_tensor.device) if not is_infer: - _logger.debug( - "sliding_window_attention_wrapper: Creating empty output tensor stats for thd layout" - ) - stats_tensor = make_tensor_strided_like( - q_tensor, (t, h_q, 1), dtype=torch.float32, device=q_tensor.device - ) + _logger.debug("sliding_window_attention_wrapper: Creating empty output tensor stats for thd layout") + stats_tensor = make_tensor_strided_like(q_tensor, (t, h_q, 1), dtype=torch.float32, device=q_tensor.device) else: # bshd - _logger.debug( - "sliding_window_attention_wrapper: Creating empty output tensor o for bshd layout" - ) + _logger.debug("sliding_window_attention_wrapper: Creating empty output tensor o for bshd layout") b, h_q, s_q, d = q_tensor.shape _, h_k, s_k, d_v = v_tensor.shape - o_tensor = make_tensor_strided_like( - q_tensor, (b, h_q, s_q, d_v), dtype=o_dtype, device=q_tensor.device - ) + o_tensor = make_tensor_strided_like(q_tensor, (b, h_q, s_q, d_v), dtype=o_dtype, device=q_tensor.device) if not is_infer: - _logger.debug( - "sliding_window_attention_wrapper: Creating empty output tensor stats for bshd layout" - ) - stats_tensor = make_tensor_strided_like( - q_tensor, (b, h_q, s_q, 1), dtype=torch.float32, device=q_tensor.device - ) + _logger.debug("sliding_window_attention_wrapper: Creating empty output tensor stats for bshd layout") + stats_tensor = make_tensor_strided_like(q_tensor, (b, h_q, s_q, 1), dtype=torch.float32, device=q_tensor.device) cache_key = ( q_tensor.shape, @@ -679,11 +539,7 @@ def sliding_window_attention_wrapper( k_ragged_offset_tensor.shape if k_ragged_offset_tensor is not None else None, v_ragged_offset_tensor.shape if v_ragged_offset_tensor is not None else None, o_ragged_offset_tensor.shape if o_ragged_offset_tensor is not None else None, - ( - stats_ragged_offset_tensor.shape - if stats_ragged_offset_tensor is not None - else None - ), + (stats_ragged_offset_tensor.shape if stats_ragged_offset_tensor is not None else None), q_tensor.stride(), k_tensor.stride(), v_tensor.stride(), @@ -693,11 +549,7 @@ def sliding_window_attention_wrapper( k_ragged_offset_tensor.stride() if k_ragged_offset_tensor is not None else None, v_ragged_offset_tensor.stride() if v_ragged_offset_tensor is not None else None, o_ragged_offset_tensor.stride() if o_ragged_offset_tensor is not None else None, - ( - stats_ragged_offset_tensor.stride() - if stats_ragged_offset_tensor is not None - else None - ), + (stats_ragged_offset_tensor.stride() if stats_ragged_offset_tensor is not None else None), q_tensor.dtype, k_tensor.dtype, v_tensor.dtype, @@ -710,12 +562,8 @@ def sliding_window_attention_wrapper( ) sliding_window_attention_object = None if cache_key in _cache_of_SlidingWindowAttentionObjects: - _logger.debug( - "sliding_window_attention_wrapper: Using previously cached SlidingWindowAttention object" - ) - sliding_window_attention_object = _cache_of_SlidingWindowAttentionObjects[ - cache_key - ] + _logger.debug("sliding_window_attention_wrapper: Using previously cached SlidingWindowAttention object") + sliding_window_attention_object = _cache_of_SlidingWindowAttentionObjects[cache_key] sliding_window_attention_object.execute( q_tensor=q_tensor, @@ -734,9 +582,7 @@ def sliding_window_attention_wrapper( cudnn_handle=cudnn_handle, ) else: - _logger.debug( - "sliding_window_attention_wrapper: No previously cached SlidingWindowAttention object found, creating new SlidingWindowAttention object" - ) + _logger.debug("sliding_window_attention_wrapper: No previously cached SlidingWindowAttention object found, creating new SlidingWindowAttention object") sliding_window_attention_object = SlidingWindowAttention( sample_q=q_tensor, sample_k=k_tensor, @@ -750,12 +596,8 @@ def sliding_window_attention_wrapper( sample_v_ragged_offset=v_ragged_offset_tensor, sample_o_ragged_offset=o_ragged_offset_tensor, sample_stats_ragged_offset=stats_ragged_offset_tensor, - max_seq_len_q=( - max(seq_len_q_tensor).item() if seq_len_q_tensor is not None else None - ), - max_seq_len_kv=( - max(seq_len_kv_tensor).item() if seq_len_kv_tensor is not None else None - ), + max_seq_len_q=(max(seq_len_q_tensor).item() if seq_len_q_tensor is not None else None), + max_seq_len_kv=(max(seq_len_kv_tensor).item() if seq_len_kv_tensor is not None else None), left_bound=left_bound, right_bound=right_bound, attn_scale=attn_scale, @@ -781,8 +623,6 @@ def sliding_window_attention_wrapper( stats_ragged_offset_tensor=stats_ragged_offset_tensor, current_stream=stream, ) - _cache_of_SlidingWindowAttentionObjects[cache_key] = ( - sliding_window_attention_object - ) + _cache_of_SlidingWindowAttentionObjects[cache_key] = sliding_window_attention_object return o_tensor, stats_tensor diff --git a/python/cudnn/native_sparse_attention/top_k/api.py b/python/cudnn/native_sparse_attention/top_k/api.py index 34b9e17b..728aed1c 100644 --- a/python/cudnn/native_sparse_attention/top_k/api.py +++ b/python/cudnn/native_sparse_attention/top_k/api.py @@ -78,10 +78,7 @@ def check_support(self) -> bool: if self.sample_cum_seqlen_q is None and self.sample_cum_seqlen_k is None: self.input_layout = "B,H,S,D" - elif ( - self.sample_cum_seqlen_q is not None - and self.sample_cum_seqlen_k is not None - ): + elif self.sample_cum_seqlen_q is not None and self.sample_cum_seqlen_k is not None: self.input_layout = "T,H,D" if self.sample_q.ndim == 3: @@ -95,129 +92,70 @@ def check_support(self) -> bool: self.sample_lse = self.sample_lse.unsqueeze(0).transpose(1, 2) elif self.sample_lse.ndim == 3: self._logger.info("reshaping lse_tensor from T,H,1 to 1,H,T") - self.sample_lse = ( - self._unpad_tensor_to_ndim(self.sample_lse, 2, "sample_lse") - .unsqueeze(0) - .transpose(1, 2) - ) + self.sample_lse = self._unpad_tensor_to_ndim(self.sample_lse, 2, "sample_lse").unsqueeze(0).transpose(1, 2) if self.sample_topk_scores.ndim == 3: self._logger.info("reshaping topk_scores_tensor from T,H,D to 1,H,T,D") - self.sample_topk_scores = self.sample_topk_scores.unsqueeze( - 0 - ).transpose(1, 2) + self.sample_topk_scores = self.sample_topk_scores.unsqueeze(0).transpose(1, 2) if self.sample_topk_indices.ndim == 3: self._logger.info("reshaping topk_indices_tensor from T,H,D to 1,H,T,D") - self.sample_topk_indices = self.sample_topk_indices.unsqueeze( - 0 - ).transpose(1, 2) + self.sample_topk_indices = self.sample_topk_indices.unsqueeze(0).transpose(1, 2) if self.sample_cum_seqlen_q.ndim != 1: - self._logger.info( - "cum_seqlen_q must be 1D tensor. Attempting to squeeze last dimension(s)" - ) + self._logger.info("cum_seqlen_q must be 1D tensor. Attempting to squeeze last dimension(s)") for _ in range(self.sample_cum_seqlen_q.ndim - 1): self.sample_cum_seqlen_q = self.sample_cum_seqlen_q.squeeze(-1) if self.sample_cum_seqlen_q.ndim != 1: - raise ValueError( - f"cum_seqlen_q must be 1D tensor, got {self.sample_cum_seqlen_q.ndim}D" - ) + raise ValueError(f"cum_seqlen_q must be 1D tensor, got {self.sample_cum_seqlen_q.ndim}D") if self.sample_cum_seqlen_k.ndim != 1: - self._logger.info( - "cum_seqlen_k must be 1D tensor. Attempting to squeeze last dimension(s)" - ) + self._logger.info("cum_seqlen_k must be 1D tensor. Attempting to squeeze last dimension(s)") for _ in range(self.sample_cum_seqlen_k.ndim - 1): self.sample_cum_seqlen_k = self.sample_cum_seqlen_k.squeeze(-1) if self.sample_cum_seqlen_k.ndim != 1: - raise ValueError( - f"cum_seqlen_k must be 1D tensor, got {self.sample_cum_seqlen_k.ndim}D" - ) + raise ValueError(f"cum_seqlen_k must be 1D tensor, got {self.sample_cum_seqlen_k.ndim}D") if self.max_s_q is None: - self._logger.warning( - "max_s_q not provided, inferring from cum_seqlen_q" - ) - self.max_s_q = ( - (self.sample_cum_seqlen_q[1:] - self.sample_cum_seqlen_q[:-1]) - .max() - .item() - ) + self._logger.warning("max_s_q not provided, inferring from cum_seqlen_q") + self.max_s_q = (self.sample_cum_seqlen_q[1:] - self.sample_cum_seqlen_q[:-1]).max().item() if self.max_s_k is None: - self._logger.warning( - "max_s_k not provided, inferring from cum_seqlen_k" - ) - self.max_s_k = ( - (self.sample_cum_seqlen_k[1:] - self.sample_cum_seqlen_k[:-1]) - .max() - .item() - ) + self._logger.warning("max_s_k not provided, inferring from cum_seqlen_k") + self.max_s_k = (self.sample_cum_seqlen_k[1:] - self.sample_cum_seqlen_k[:-1]).max().item() else: - raise ValueError( - f"cum_seqlen_q and cum_seqlen_k must be None or both not None, got {self.sample_cum_seqlen_q} and {self.sample_cum_seqlen_k}" - ) + raise ValueError(f"cum_seqlen_q and cum_seqlen_k must be None or both not None, got {self.sample_cum_seqlen_q} and {self.sample_cum_seqlen_k}") b, h_q, s_q, d = self.sample_q.shape b, h_k, s_k, d = self.sample_k.shape if self.sample_q.shape != (b, h_q, s_q, d): - raise ValueError( - f"Input shape mismatch: expected Q tensor shape {b, h_q, s_q, d}, got {self.sample_q.shape}" - ) + raise ValueError(f"Input shape mismatch: expected Q tensor shape {b, h_q, s_q, d}, got {self.sample_q.shape}") if self.sample_k.shape != (b, h_k, s_k, d): - raise ValueError( - f"Input shape mismatch: expected K tensor shape {b, h_k, s_k, d}, got {self.sample_k.shape}" - ) + raise ValueError(f"Input shape mismatch: expected K tensor shape {b, h_k, s_k, d}, got {self.sample_k.shape}") if self.sample_lse.shape == (b, h_q, s_q, 1): - self._logger.info( - "reshaping lse_tensor from (b, h_q, s_q, 1) to (b, h_q, s_q)" - ) + self._logger.info("reshaping lse_tensor from (b, h_q, s_q, 1) to (b, h_q, s_q)") self.sample_lse = self.sample_lse.squeeze(-1) if self.sample_lse.shape != (b, h_q, s_q): - raise ValueError( - f"Input shape mismatch: expected LSE tensor shape {b, h_q, s_q}, got {self.sample_lse.shape}" - ) + raise ValueError(f"Input shape mismatch: expected LSE tensor shape {b, h_q, s_q}, got {self.sample_lse.shape}") if self.sample_lse.stride(-1) != 1: - self._logger.warning( - "lse_tensor is expected to have leading stride in last dimension of shape (b, h_q, s_q), copying lse_tensor to contiguous" - ) + self._logger.warning("lse_tensor is expected to have leading stride in last dimension of shape (b, h_q, s_q), copying lse_tensor to contiguous") self.sample_lse = self.sample_lse.contiguous() if self.sample_topk_scores.shape != (b, h_k, s_q, self.k_value): - raise ValueError( - f"Input shape mismatch: expected TopK Scores tensor shape {b, h_k, s_q, self.k_value}, got {self.sample_topk_scores.shape}" - ) + raise ValueError(f"Input shape mismatch: expected TopK Scores tensor shape {b, h_k, s_q, self.k_value}, got {self.sample_topk_scores.shape}") if self.sample_topk_indices.shape != (b, h_k, s_q, self.k_value): - raise ValueError( - f"Input shape mismatch: expected TopK Indices tensor shape {b, h_k, s_q, self.k_value}, got {self.sample_topk_indices.shape}" - ) + raise ValueError(f"Input shape mismatch: expected TopK Indices tensor shape {b, h_k, s_q, self.k_value}, got {self.sample_topk_indices.shape}") - self.batch_size = ( - b - if (self.input_layout == "B,H,S,D") - else (len(self.sample_cum_seqlen_q) - 1) - ) + self.batch_size = b if (self.input_layout == "B,H,S,D") else (len(self.sample_cum_seqlen_q) - 1) self.h_q, self.h_k, self.head_dim = h_q, h_k, d if self.input_layout == "B,H,S,D": self.max_s_q, self.max_s_k = s_q, s_k self._logger.debug("Checking dtypes") if self.sample_q.dtype != self.sample_k.dtype: - raise ValueError( - f"Q and K must have the same dtype, got {self.sample_q.dtype} and {self.sample_k.dtype}" - ) + raise ValueError(f"Q and K must have the same dtype, got {self.sample_q.dtype} and {self.sample_k.dtype}") self.dtype = self.sample_q.dtype if self.sample_lse.dtype != self.acc_dtype: - raise ValueError( - f"LSE and Accumulator must have the same dtype, got {self.sample_lse.dtype} and {self.acc_dtype}" - ) + raise ValueError(f"LSE and Accumulator must have the same dtype, got {self.sample_lse.dtype} and {self.acc_dtype}") if self.sample_topk_scores.dtype != self.acc_dtype: - raise ValueError( - f"TopK Scores and Accumulator must have the same dtype, got {self.sample_topk_scores.dtype} and {self.acc_dtype}" - ) + raise ValueError(f"TopK Scores and Accumulator must have the same dtype, got {self.sample_topk_scores.dtype} and {self.acc_dtype}") if self.sample_topk_indices.dtype != torch.int32: - raise ValueError( - f"TopK Indices must be int32, got {self.sample_topk_indices.dtype}" - ) + raise ValueError(f"TopK Indices must be int32, got {self.sample_topk_indices.dtype}") if self.input_layout == "T,H,D": - if ( - self.sample_cum_seqlen_q.dtype != torch.int32 - or self.sample_cum_seqlen_k.dtype != torch.int32 - ): + if self.sample_cum_seqlen_q.dtype != torch.int32 or self.sample_cum_seqlen_k.dtype != torch.int32: raise ValueError( f"cum_seqlen_q and cum_seqlen_k tensors must be int32, got {self.sample_cum_seqlen_q.dtype} and {self.sample_cum_seqlen_k.dtype}" ) @@ -230,9 +168,7 @@ def check_support(self) -> bool: major, minor = torch.cuda.get_device_capability(device) compute_capability = major * 10 + minor if compute_capability < 100: - raise RuntimeError( - f"TopKReduction requires SM100+ compute capability, but found SM{compute_capability} on device {device}" - ) + raise RuntimeError(f"TopKReduction requires SM100+ compute capability, but found SM{compute_capability} on device {device}") if compute_capability == 103: raise RuntimeError("cuteDSL TopKReduction is not supported on SM103") @@ -257,11 +193,7 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: is_causal=self.is_causal, ) - scale_softmax = ( - 1.0 / math.sqrt(self.head_dim) - if self.scale_softmax is None - else self.scale_softmax - ) + scale_softmax = 1.0 / math.sqrt(self.head_dim) if self.scale_softmax is None else self.scale_softmax log2_e = math.log2(math.e) softmax_scale_log2_e = scale_softmax * log2_e problem_size = ( @@ -273,31 +205,13 @@ def compile(self, current_stream: Optional[cuda.CUstream] = None) -> None: self.head_dim, ) - sample_q_cute = from_dlpack( - self.sample_q, assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) - sample_k_cute = from_dlpack( - self.sample_k, assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) - sample_lse_cute = from_dlpack( - self.sample_lse, assumed_align=16 - ).mark_layout_dynamic(leading_dim=2) - sample_topk_scores_cute = from_dlpack( - self.sample_topk_scores, assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) - sample_topk_indices_cute = from_dlpack( - self.sample_topk_indices, assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) - sample_cum_seqlen_q_cute = ( - from_dlpack(self.sample_cum_seqlen_q).mark_layout_dynamic() - if self.input_layout == "T,H,D" - else None - ) - sample_cum_seqlen_k_cute = ( - from_dlpack(self.sample_cum_seqlen_k).mark_layout_dynamic() - if self.input_layout == "T,H,D" - else None - ) + sample_q_cute = from_dlpack(self.sample_q, assumed_align=16).mark_layout_dynamic(leading_dim=3) + sample_k_cute = from_dlpack(self.sample_k, assumed_align=16).mark_layout_dynamic(leading_dim=3) + sample_lse_cute = from_dlpack(self.sample_lse, assumed_align=16).mark_layout_dynamic(leading_dim=2) + sample_topk_scores_cute = from_dlpack(self.sample_topk_scores, assumed_align=16).mark_layout_dynamic(leading_dim=3) + sample_topk_indices_cute = from_dlpack(self.sample_topk_indices, assumed_align=16).mark_layout_dynamic(leading_dim=3) + sample_cum_seqlen_q_cute = from_dlpack(self.sample_cum_seqlen_q).mark_layout_dynamic() if self.input_layout == "T,H,D" else None + sample_cum_seqlen_k_cute = from_dlpack(self.sample_cum_seqlen_k).mark_layout_dynamic() if self.input_layout == "T,H,D" else None self._compiled_kernel = cute.compile( topk_reduction, @@ -331,9 +245,7 @@ def execute( if self.input_layout == "T,H,D": if cumulative_s_q_tensor is None or cumulative_s_k_tensor is None: - raise ValueError( - "cumulative_s_q_tensor and cumulative_s_k_tensor are required when using T,H,D layout" - ) + raise ValueError("cumulative_s_q_tensor and cumulative_s_k_tensor are required when using T,H,D layout") if q_tensor.ndim == 3: self._logger.info("reshaping q_tensor from T,H,D to 1,H,T,D") q_tensor = q_tensor.unsqueeze(0).transpose(1, 2) @@ -345,11 +257,7 @@ def execute( lse_tensor = lse_tensor.unsqueeze(0).transpose(1, 2) elif lse_tensor.ndim == 3: self._logger.info("reshaping lse_tensor from T,H,1 to 1,H,T") - lse_tensor = ( - self._unpad_tensor_to_ndim(lse_tensor, 2, "lse_tensor") - .unsqueeze(0) - .transpose(1, 2) - ) + lse_tensor = self._unpad_tensor_to_ndim(lse_tensor, 2, "lse_tensor").unsqueeze(0).transpose(1, 2) if topk_scores_tensor.ndim == 3: self._logger.info("reshaping topk_scores_tensor from T,H,D to 1,H,T,D") topk_scores_tensor = topk_scores_tensor.unsqueeze(0).transpose(1, 2) @@ -361,41 +269,17 @@ def execute( self._logger.info("reshaping lse_tensor to remove trailing dimension") lse_tensor = lse_tensor.squeeze(-1) if lse_tensor.stride(-1) != 1: - self._logger.warning( - "lse_tensor is expected to have leading stride in last dimension of shape (b, h_q, s_q), copying lse_tensor to contiguous" - ) + self._logger.warning("lse_tensor is expected to have leading stride in last dimension of shape (b, h_q, s_q), copying lse_tensor to contiguous") lse_tensor = lse_tensor.contiguous() - q_cute = from_dlpack(q_tensor, assumed_align=16).mark_layout_dynamic( - leading_dim=3 - ) - k_cute = from_dlpack(k_tensor, assumed_align=16).mark_layout_dynamic( - leading_dim=3 - ) - lse_cute = from_dlpack(lse_tensor, assumed_align=16).mark_layout_dynamic( - leading_dim=2 - ) - topk_scores_cute = from_dlpack( - topk_scores_tensor, assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) - topk_indices_cute = from_dlpack( - topk_indices_tensor, assumed_align=16 - ).mark_layout_dynamic(leading_dim=3) - cumulative_s_q_cute = ( - from_dlpack(cumulative_s_q_tensor).mark_layout_dynamic() - if self.input_layout == "T,H,D" - else None - ) - cumulative_s_k_cute = ( - from_dlpack(cumulative_s_k_tensor).mark_layout_dynamic() - if self.input_layout == "T,H,D" - else None - ) - scale_softmax = ( - 1.0 / math.sqrt(self.head_dim) - if self.scale_softmax is None - else self.scale_softmax - ) + q_cute = from_dlpack(q_tensor, assumed_align=16).mark_layout_dynamic(leading_dim=3) + k_cute = from_dlpack(k_tensor, assumed_align=16).mark_layout_dynamic(leading_dim=3) + lse_cute = from_dlpack(lse_tensor, assumed_align=16).mark_layout_dynamic(leading_dim=2) + topk_scores_cute = from_dlpack(topk_scores_tensor, assumed_align=16).mark_layout_dynamic(leading_dim=3) + topk_indices_cute = from_dlpack(topk_indices_tensor, assumed_align=16).mark_layout_dynamic(leading_dim=3) + cumulative_s_q_cute = from_dlpack(cumulative_s_q_tensor).mark_layout_dynamic() if self.input_layout == "T,H,D" else None + cumulative_s_k_cute = from_dlpack(cumulative_s_k_tensor).mark_layout_dynamic() if self.input_layout == "T,H,D" else None + scale_softmax = 1.0 / math.sqrt(self.head_dim) if self.scale_softmax is None else self.scale_softmax log2_e = math.log2(math.e) softmax_scale_log2_e = scale_softmax * log2_e problem_size = ( @@ -479,21 +363,13 @@ def topk_reduction_wrapper( if cum_seqlen_q_tensor is not None and cum_seqlen_k_tensor is not None: # T,H,D total_seq_len_q = cum_seqlen_q_tensor[-1].item() h_k = k_tensor.shape[1] - topk_scores_tensor = torch.empty( - total_seq_len_q, h_k, k_value, dtype=acc_dtype, device=q_tensor.device - ) - topk_indices_tensor = torch.empty( - total_seq_len_q, h_k, k_value, dtype=torch.int32, device=q_tensor.device - ) + topk_scores_tensor = torch.empty(total_seq_len_q, h_k, k_value, dtype=acc_dtype, device=q_tensor.device) + topk_indices_tensor = torch.empty(total_seq_len_q, h_k, k_value, dtype=torch.int32, device=q_tensor.device) elif cum_seqlen_q_tensor is None and cum_seqlen_k_tensor is None: # B,H,S,D b, _, s_q, _ = q_tensor.shape _, h_k, _, _ = k_tensor.shape - topk_scores_tensor = torch.empty( - b, s_q, h_k, k_value, dtype=acc_dtype, device=q_tensor.device - ).transpose(1, 2) - topk_indices_tensor = torch.empty( - b, s_q, h_k, k_value, dtype=torch.int32, device=q_tensor.device - ).transpose(1, 2) + topk_scores_tensor = torch.empty(b, s_q, h_k, k_value, dtype=acc_dtype, device=q_tensor.device).transpose(1, 2) + topk_indices_tensor = torch.empty(b, s_q, h_k, k_value, dtype=torch.int32, device=q_tensor.device).transpose(1, 2) else: raise ValueError( f"cum_seqlen_q_tensor and cum_seqlen_k_tensor must either both be None (B,H,S,D) or both not None (T,H,D), got {cum_seqlen_q_tensor} and {cum_seqlen_k_tensor}" @@ -527,9 +403,7 @@ def topk_reduction_wrapper( ) if cache_key in _cache_of_TopKReductionObjects: - _logger.debug( - "topk_reduction_wrapper: Using previously cached TopKReduction object" - ) + _logger.debug("topk_reduction_wrapper: Using previously cached TopKReduction object") topk_reduction = _cache_of_TopKReductionObjects[cache_key] topk_reduction.execute( q_tensor=q_tensor, diff --git a/python/cudnn/native_sparse_attention/top_k/nsa_top_k_reduction_fwd.py b/python/cudnn/native_sparse_attention/top_k/nsa_top_k_reduction_fwd.py index a636701f..9d0c8c3e 100644 --- a/python/cudnn/native_sparse_attention/top_k/nsa_top_k_reduction_fwd.py +++ b/python/cudnn/native_sparse_attention/top_k/nsa_top_k_reduction_fwd.py @@ -68,9 +68,7 @@ def __init__( self.k_value = k_value self.selection_block_size = selection_block_size self.compress_block_sliding_stride = compress_block_sliding_stride - self.num_elem_for_reduction = ( - selection_block_size // compress_block_sliding_stride - ) + self.num_elem_for_reduction = selection_block_size // compress_block_sliding_stride self.cluster_shape_mn = (1, 1) self.mma_tiler = mma_tiler @@ -121,15 +119,9 @@ def __call__( stride_b_q = s_q * head_dim * h_k * h_r if cumulative_s_q is None else 0 stride_b_k = s_k * head_dim * h_k if cumulative_s_k is None else 0 stride_b_lse = s_q * h_r * h_k if cumulative_s_q is None else 0 - stride_b_out = ( - s_q * (s_k_max // self.num_elem_for_reduction) * h_k - if cumulative_s_q is None - else 0 - ) + stride_b_out = s_q * (s_k_max // self.num_elem_for_reduction) * h_k if cumulative_s_q is None else 0 stride_b_topk_scores = s_q * self.k_value * h_k if cumulative_s_q is None else 0 - stride_b_topk_indices = ( - s_q * self.k_value * h_k if cumulative_s_q is None else 0 - ) + stride_b_topk_indices = s_q * self.k_value * h_k if cumulative_s_q is None else 0 Q = cute.make_tensor( Q.iterator, @@ -253,29 +245,17 @@ def __call__( @cute.struct class SharedStorage: - load_mma_Q_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.load_mma_Q_stage * 2 - ] - load_mma_K_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.load_mma_K_stage * 2 - ] - load_compute_LSE_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.load_compute_LSE_stage * 2 - ] - mma_compute_S_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.mma_compute_S_stage * 2 - ] + load_mma_Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_Q_stage * 2] + load_mma_K_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_mma_K_stage * 2] + load_compute_LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_compute_LSE_stage * 2] + mma_compute_S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_compute_S_stage * 2] sQ: cute.struct.Align[ - cute.struct.MemRange[ - self.element_dtype, cute.cosize(Q_smem_layout_staged) - ], + cute.struct.MemRange[self.element_dtype, cute.cosize(Q_smem_layout_staged)], 1024, ] sK: cute.struct.Align[ - cute.struct.MemRange[ - self.element_dtype, cute.cosize(K_smem_layout_staged) - ], + cute.struct.MemRange[self.element_dtype, cute.cosize(K_smem_layout_staged)], 1024, ] sLSE: cute.struct.Align[ @@ -315,12 +295,8 @@ class SharedStorage: ) def make_and_init_load_mma_Q_pipeline(self, load_mma_Q_mbar_ptr): - load_mma_Q_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.load_warp_id]) - ) - load_mma_Q_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) + load_mma_Q_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, len([self.load_warp_id])) + load_mma_Q_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, len([self.mma_warp_id])) return pipeline.PipelineTmaUmma.create( barrier_storage=load_mma_Q_mbar_ptr, num_stages=self.load_mma_Q_stage, @@ -330,12 +306,8 @@ def make_and_init_load_mma_Q_pipeline(self, load_mma_Q_mbar_ptr): ) def make_and_init_load_mma_K_pipeline(self, load_mma_K_mbar_ptr): - load_mma_K_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.load_warp_id]) - ) - load_mma_K_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) + load_mma_K_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, len([self.load_warp_id])) + load_mma_K_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, len([self.mma_warp_id])) return pipeline.PipelineTmaUmma.create( barrier_storage=load_mma_K_mbar_ptr, num_stages=self.load_mma_K_stage, @@ -407,29 +379,15 @@ def kernel( smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - load_mma_Q_pipeline = self.make_and_init_load_mma_Q_pipeline( - storage.load_mma_Q_mbar_ptr.data_ptr() - ) - load_mma_K_pipeline = self.make_and_init_load_mma_K_pipeline( - storage.load_mma_K_mbar_ptr.data_ptr() - ) - load_compute_LSE_pipeline = self.make_and_init_load_compute_LSE_pipeline( - storage.load_compute_LSE_mbar_ptr.data_ptr() - ) - mma_compute_S_pipeline = self.make_and_init_mma_compute_S_pipeline( - storage.mma_compute_S_mbar_ptr.data_ptr() - ) + load_mma_Q_pipeline = self.make_and_init_load_mma_Q_pipeline(storage.load_mma_Q_mbar_ptr.data_ptr()) + load_mma_K_pipeline = self.make_and_init_load_mma_K_pipeline(storage.load_mma_K_mbar_ptr.data_ptr()) + load_compute_LSE_pipeline = self.make_and_init_load_compute_LSE_pipeline(storage.load_compute_LSE_mbar_ptr.data_ptr()) + mma_compute_S_pipeline = self.make_and_init_mma_compute_S_pipeline(storage.mma_compute_S_mbar_ptr.data_ptr()) - cute.arch.barrier( - barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta - ) + cute.arch.barrier(barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta) - sQ = storage.sQ.get_tensor( - Q_smem_layout_staged.outer, swizzle=Q_smem_layout_staged.inner - ) - sK = storage.sK.get_tensor( - K_smem_layout_staged.outer, swizzle=K_smem_layout_staged.inner - ) + sQ = storage.sQ.get_tensor(Q_smem_layout_staged.outer, swizzle=Q_smem_layout_staged.inner) + sK = storage.sK.get_tensor(K_smem_layout_staged.outer, swizzle=K_smem_layout_staged.inner) sLSE = storage.sLSE.get_tensor(LSE_smem_layout) block_offset = (Int32(0), Int32(0), Int32(0), (Int32(0), Int32(0), Int32(0))) @@ -455,24 +413,18 @@ def kernel( mQ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), tma_tensor_Q) mK = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), tma_tensor_K) mTopk_scores = cute.make_tensor( - Topk_scores.iterator - + cute.assume(block_offset[0] * Topk_scores.stride[0], divby=self.k_value), + Topk_scores.iterator + cute.assume(block_offset[0] * Topk_scores.stride[0], divby=self.k_value), Topk_scores.layout, ) mTopk_indices = cute.make_tensor( - Topk_indices.iterator - + cute.assume(block_offset[0] * Topk_indices.stride[0], divby=self.k_value), + Topk_indices.iterator + cute.assume(block_offset[0] * Topk_indices.stride[0], divby=self.k_value), Topk_indices.layout, ) # (MMA_M, MMA_K, REST_M, REST_K, (H_r, H_k, B)) - gQ = cute.local_tile( - mQ, cute.select(self.mma_tiler, mode=[0, 2]), (None, None, None) - ) + gQ = cute.local_tile(mQ, cute.select(self.mma_tiler, mode=[0, 2]), (None, None, None)) # (MMA_N, MMA_K, REST_N, REST_K, (1, H_k, B)) - gK = cute.local_tile( - mK, cute.select(self.mma_tiler, mode=[0, 2]), (None, None, None) - ) + gK = cute.local_tile(mK, cute.select(self.mma_tiler, mode=[0, 2]), (None, None, None)) # (MMA_M, MMA_K, H_r) gQ = gQ[None, None, bidx, 0, (None, bidy, bidz)] @@ -506,9 +458,7 @@ def kernel( tSrQ = QK_tiled_mma.make_fragment_A(sQ) tSrK = QK_tiled_mma.make_fragment_B(sK) - tStS_shape = QK_tiled_mma.partition_shape_C( - cute.select(self.mma_tiler, mode=[0, 1]) - ) + tStS_shape = QK_tiled_mma.partition_shape_C(cute.select(self.mma_tiler, mode=[0, 1])) # ((MMA_M, MMA_N), REST_M, REST_N) tStS = QK_tiled_mma.make_fragment_C(tStS_shape) # another tmem for reduction @@ -528,22 +478,14 @@ def kernel( # TODO: reconfig regs cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - load_mma_Q_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.load_mma_Q_stage - ) - load_mma_K_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.load_mma_K_stage - ) - load_compute_LSE_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.load_compute_LSE_stage - ) + load_mma_Q_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.load_mma_Q_stage) + load_mma_K_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.load_mma_K_stage) + load_compute_LSE_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.load_compute_LSE_stage) while load_iter_count > 0: # Wait for K to be empty load_mma_K_pipeline.producer_acquire(load_mma_K_producer_state) - K_tma_barrier = load_mma_K_pipeline.producer_get_barrier( - load_mma_K_producer_state - ) + K_tma_barrier = load_mma_K_pipeline.producer_get_barrier(load_mma_K_producer_state) # Load K tile cute.copy( @@ -558,9 +500,7 @@ def kernel( # Load Q and LSE for h_r_idx in cutlass.range(cute.size(tQgQ, mode=[1])): - load_compute_LSE_pipeline.producer_acquire( - load_compute_LSE_producer_state - ) + load_compute_LSE_pipeline.producer_acquire(load_compute_LSE_producer_state) # Load LSE thread_idx = tidx % self.threads_per_warp @@ -576,10 +516,7 @@ def kernel( LSE_for_copy = cute.flat_divide(LSE, (1,)) LSE_idx_offset = block_offset[0] * LSE.stride[0] for i in cutlass.range_constexpr(async_copy_num_elts): - LSE_idx = ( - self.mma_tiler[0] * bidx - + thread_idx * async_copy_num_elts - ) + LSE_idx = self.mma_tiler[0] * bidx + thread_idx * async_copy_num_elts if cute.elem_less(LSE_idx + i, cur_s_q): cute.copy( atom_async_copy, @@ -603,16 +540,12 @@ def kernel( load_compute_LSE_producer_state.index, ].fill(0.0) - load_compute_LSE_pipeline.producer_commit( - load_compute_LSE_producer_state - ) + load_compute_LSE_pipeline.producer_commit(load_compute_LSE_producer_state) load_compute_LSE_producer_state.advance() # Wait for Q to be empty load_mma_Q_pipeline.producer_acquire(load_mma_Q_producer_state) - Q_tma_barrier = load_mma_Q_pipeline.producer_get_barrier( - load_mma_Q_producer_state - ) + Q_tma_barrier = load_mma_Q_pipeline.producer_get_barrier(load_mma_Q_producer_state) # Load Q tile cute.copy( @@ -639,15 +572,9 @@ def kernel( number_of_threads=self.threads_per_warp, ) - load_mma_Q_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_mma_Q_stage - ) - load_mma_K_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_mma_K_stage - ) - mma_compute_S_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_compute_S_stage - ) + load_mma_Q_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.load_mma_Q_stage) + load_mma_K_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.load_mma_K_stage) + mma_compute_S_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.mma_compute_S_stage) while mma_iter_count > 0: # Wait for K to be full @@ -656,14 +583,10 @@ def kernel( for h_r_idx in cutlass.range(cute.size(tQgQ, mode=[1])): # Wait for Q to be full load_mma_Q_pipeline.consumer_wait(load_mma_Q_consumer_state) - mma_compute_S_pipeline.producer_acquire( - mma_compute_S_producer_state - ) + mma_compute_S_pipeline.producer_acquire(mma_compute_S_producer_state) QK_tiled_mma.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, False) - for k_block_idx in cutlass.range_constexpr( - cute.size(tSrQ, mode=[2]) - ): + for k_block_idx in cutlass.range_constexpr(cute.size(tSrQ, mode=[2])): cute.gemm( QK_tiled_mma, tStS, @@ -683,9 +606,7 @@ def kernel( ) QK_tiled_mma.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, True) - mma_compute_S_pipeline.producer_commit( - mma_compute_S_producer_state - ) + mma_compute_S_pipeline.producer_commit(mma_compute_S_producer_state) mma_compute_S_producer_state.advance() load_mma_Q_pipeline.consumer_release(load_mma_Q_consumer_state) @@ -701,12 +622,8 @@ def kernel( if warp_idx in self.compute_warp_id: cute.arch.warpgroup_reg_alloc(self.num_regs_compute) - mma_compute_S_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_compute_S_stage - ) - load_compute_LSE_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_compute_LSE_stage - ) + mma_compute_S_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.mma_compute_S_stage) + load_compute_LSE_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.load_compute_LSE_stage) load_compute_LSE_pipeline.consumer_wait(load_compute_LSE_consumer_state) thread_idx = tidx % (self.threads_per_warp * self.num_compute_warps) @@ -714,12 +631,8 @@ def kernel( heap_size_ref = cute.make_rmem_tensor((1,), Int32) heap_size_ref[0] = 0 # # Create temporary register heaps for computation - scores_heap_rf = cute.make_rmem_tensor( - ((4, self.k_value // 4), 1, 1), Float32 - ) - idx_heap_rf = cute.make_rmem_tensor( - ((4, self.k_value // 4), 1, 1), Int32 - ) + scores_heap_rf = cute.make_rmem_tensor(((4, self.k_value // 4), 1, 1), Float32) + idx_heap_rf = cute.make_rmem_tensor(((4, self.k_value // 4), 1, 1), Int32) tmem_load_atom = cute.make_copy_atom( tcgen05.Ld32x32bOp(tcgen05.copy.Repetition(32)), @@ -733,15 +646,9 @@ def kernel( cS = cute.make_identity_tensor((self.mma_tiler[0], self.mma_tiler[1])) comp_tile_size = 32 - tStS_tiled = cute.logical_divide( - tStS, cute.make_layout((self.mma_tiler[0], comp_tile_size)) - ) - tStS_compute_tiled = cute.logical_divide( - tStS_reduce, cute.make_layout((self.mma_tiler[0], comp_tile_size)) - ) - cS_tiled = cute.logical_divide( - cS, cute.make_layout((self.mma_tiler[0], comp_tile_size)) - ) + tStS_tiled = cute.logical_divide(tStS, cute.make_layout((self.mma_tiler[0], comp_tile_size))) + tStS_compute_tiled = cute.logical_divide(tStS_reduce, cute.make_layout((self.mma_tiler[0], comp_tile_size))) + cS_tiled = cute.logical_divide(cS, cute.make_layout((self.mma_tiler[0], comp_tile_size))) tStS_slice = tStS_tiled[None, 0] # ((128, 16), 8) tStS_compute_slice = tStS_compute_tiled[None, 0] @@ -753,108 +660,61 @@ def kernel( tTR_cS = thr_t2r.partition_D(cS_tiled) tTR_tS = thr_t2r.partition_S(tStS_tiled) tTR_tS_compute = thr_t2r.partition_S(tStS_compute_tiled) - tTR_rS = cute.make_rmem_tensor( - tTR_cS[None, None, 0].shape, self.acc_dtype - ) - tTR_rS_compute = cute.make_rmem_tensor( - tTR_cS[None, None, 0].shape, self.acc_dtype - ) + tTR_rS = cute.make_rmem_tensor(tTR_cS[None, None, 0].shape, self.acc_dtype) + tTR_rS_compute = cute.make_rmem_tensor(tTR_cS[None, None, 0].shape, self.acc_dtype) tiled_r2t = tcgen05.make_tmem_copy(tmem_store_atom, tStS_compute_slice) thr_r2t = tiled_r2t.get_slice(thread_idx) tRT_tS_compute = thr_r2t.partition_D(tStS_compute_tiled) - tiled_t2r_reduce = tcgen05.make_tmem_copy( - tmem_load_atom, tStS[(None, None), 0, 0] - ) + tiled_t2r_reduce = tcgen05.make_tmem_copy(tmem_load_atom, tStS[(None, None), 0, 0]) thr_t2r_reduce = tiled_t2r_reduce.get_slice(thread_idx) - tTR_tS_reduce = thr_t2r_reduce.partition_S( - tStS_reduce[(None, None), 0, 0] - ) + tTR_tS_reduce = thr_t2r_reduce.partition_S(tStS_reduce[(None, None), 0, 0]) tTR_cS_reduce = thr_t2r_reduce.partition_D(cS) - tTR_rS_reduce = cute.make_rmem_tensor( - tTR_cS_reduce.shape, self.acc_dtype - ) + tTR_rS_reduce = cute.make_rmem_tensor(tTR_cS_reduce.shape, self.acc_dtype) - tmp = cute.make_rmem_tensor( - (self.mma_tiler[1] // self.num_elem_for_reduction), self.acc_dtype - ) + tmp = cute.make_rmem_tensor((self.mma_tiler[1] // self.num_elem_for_reduction), self.acc_dtype) while compute_iter_count > 0: for h_r_idx in range(cute.size(tQgQ, mode=[1])): - mma_compute_S_pipeline.consumer_wait( - mma_compute_S_consumer_state - ) + mma_compute_S_pipeline.consumer_wait(mma_compute_S_consumer_state) # TODO: Added this as we should wait for the producer to load - load_compute_LSE_pipeline.consumer_wait( - load_compute_LSE_consumer_state - ) + load_compute_LSE_pipeline.consumer_wait(load_compute_LSE_consumer_state) - for sub_tile in cutlass.range( - self.mma_tiler[1] // comp_tile_size - ): + for sub_tile in cutlass.range(self.mma_tiler[1] // comp_tile_size): tTR_tS_sub_tile = tTR_tS[None, None, sub_tile] - tTR_tS_compute_sub_tile = tTR_tS_compute[ - None, None, sub_tile - ] - tRT_tS_compute_sub_tile = tRT_tS_compute[ - None, None, sub_tile - ] + tTR_tS_compute_sub_tile = tTR_tS_compute[None, None, sub_tile] + tRT_tS_compute_sub_tile = tRT_tS_compute[None, None, sub_tile] tTR_cS_sub_tile = tTR_cS[None, None, sub_tile] # Copy S from tmem to rmem cute.copy(tiled_t2r, tTR_tS_sub_tile, tTR_rS) - is_residual_k = ( - compute_iter_index * self.mma_tiler[1] - + self.mma_tiler[1] - > cur_s_k - ) + is_residual_k = compute_iter_index * self.mma_tiler[1] + self.mma_tiler[1] > cur_s_k leading_causal_masking = cutlass.Boolean(False) if cutlass.const_expr(self.is_causal): leading_causal_masking = ( - ((compute_iter_index + 1) * self.mma_tiler[1] + 1) - * self.compress_block_sliding_stride - - 1 - > bidx * self.mma_tiler[0] - ) - leading_causal_masking = cute.arch.shuffle_sync( - leading_causal_masking, 0 - ) + (compute_iter_index + 1) * self.mma_tiler[1] + 1 + ) * self.compress_block_sliding_stride - 1 > bidx * self.mma_tiler[0] + leading_causal_masking = cute.arch.shuffle_sync(leading_causal_masking, 0) trailing_residual_masking = cutlass.Boolean(False) trailing_residual_masking = is_residual_k - trailing_residual_masking = cute.arch.shuffle_sync( - trailing_residual_masking, 0 - ) + trailing_residual_masking = cute.arch.shuffle_sync(trailing_residual_masking, 0) - is_masked_tile = ( - leading_causal_masking or trailing_residual_masking - ) + is_masked_tile = leading_causal_masking or trailing_residual_masking # Apply mask if is_masked_tile: - for i in cutlass.range( - cute.size(tTR_rS), unroll_full=True - ): - q_idx = ( - cute.get(tTR_cS_sub_tile[i], mode=[0]) - + bidx * self.mma_tiler[0] - ) - k_block_idx = ( - cute.get(tTR_cS_sub_tile[i], mode=[1]) - + compute_iter_index * self.mma_tiler[1] - ) + for i in cutlass.range(cute.size(tTR_rS), unroll_full=True): + q_idx = cute.get(tTR_cS_sub_tile[i], mode=[0]) + bidx * self.mma_tiler[0] + k_block_idx = cute.get(tTR_cS_sub_tile[i], mode=[1]) + compute_iter_index * self.mma_tiler[1] if is_masked_tile: if cutlass.const_expr(self.is_causal): - k_idx = ( - (k_block_idx + 1) - * self.compress_block_sliding_stride - - 1 - ) + k_idx = (k_block_idx + 1) * self.compress_block_sliding_stride - 1 if k_idx > q_idx: tTR_rS[i] = -cutlass.Float32.inf if q_idx > cur_s_q or k_block_idx > cur_s_k: @@ -867,9 +727,7 @@ def kernel( # LSE should be set negative before and has be already multiplied by log2_e # Copy S_reduce from tmem to rmem - cute.copy( - tiled_t2r, tTR_tS_compute_sub_tile, tTR_rS_compute - ) + cute.copy(tiled_t2r, tTR_tS_compute_sub_tile, tTR_rS_compute) for i in cutlass.range(0, cute.size(tTR_rS, mode=[0]), 2): lse = ( @@ -892,47 +750,34 @@ def kernel( lse, ) tTR_rS[i] = cute.math.exp2(tTR_rS[i], fastmath=True) - tTR_rS[i + 1] = cute.math.exp2( - tTR_rS[i + 1], fastmath=True - ) + tTR_rS[i + 1] = cute.math.exp2(tTR_rS[i + 1], fastmath=True) if h_r_idx == 0: - (tTR_rS_compute[i], tTR_rS_compute[i + 1]) = ( - cute.arch.add_packed_f32x2( - (0.0, 0.0), - (tTR_rS[i], tTR_rS[i + 1]), - ) + tTR_rS_compute[i], tTR_rS_compute[i + 1] = cute.arch.add_packed_f32x2( + (0.0, 0.0), + (tTR_rS[i], tTR_rS[i + 1]), ) else: - (tTR_rS_compute[i], tTR_rS_compute[i + 1]) = ( - cute.arch.add_packed_f32x2( - (tTR_rS_compute[i], tTR_rS_compute[i + 1]), - (tTR_rS[i], tTR_rS[i + 1]), - ) + tTR_rS_compute[i], tTR_rS_compute[i + 1] = cute.arch.add_packed_f32x2( + (tTR_rS_compute[i], tTR_rS_compute[i + 1]), + (tTR_rS[i], tTR_rS[i + 1]), ) cute.arch.fence_view_async_tmem_load() cute.arch.barrier( barrier_id=self.compute_sync_bar_id, - number_of_threads=self.num_compute_warps - * self.threads_per_warp, + number_of_threads=self.num_compute_warps * self.threads_per_warp, ) # Copy tS_reduce back to tmem - cute.copy( - tiled_r2t, tTR_rS_compute, tRT_tS_compute_sub_tile - ) + cute.copy(tiled_r2t, tTR_rS_compute, tRT_tS_compute_sub_tile) cute.arch.fence_view_async_tmem_store() - load_compute_LSE_pipeline.consumer_release( - load_compute_LSE_consumer_state - ) + load_compute_LSE_pipeline.consumer_release(load_compute_LSE_consumer_state) load_compute_LSE_consumer_state.advance() - mma_compute_S_pipeline.consumer_release( - mma_compute_S_consumer_state - ) + mma_compute_S_pipeline.consumer_release(mma_compute_S_consumer_state) mma_compute_S_consumer_state.advance() # Reduce @@ -983,12 +828,8 @@ def kernel( compute_iter_index += 1 # (s_q, k_value, (1, h_k, b)) - gTopk_scores = cute.flat_divide( - mTopk_scores, (self.epi_tile[0], self.k_value) - ) - gTopk_indices = cute.flat_divide( - mTopk_indices, (self.epi_tile[0], self.k_value) - ) + gTopk_scores = cute.flat_divide(mTopk_scores, (self.epi_tile[0], self.k_value)) + gTopk_indices = cute.flat_divide(mTopk_indices, (self.epi_tile[0], self.k_value)) gTopk_scores = gTopk_scores[None, None, bidx, 0, (0, bidy, bidz)] gTopk_indices = gTopk_indices[None, None, bidx, 0, (0, bidy, bidz)] cTopk = cute.make_identity_tensor((self.epi_tile[0], self.k_value)) diff --git a/python/cudnn/native_sparse_attention/utils.py b/python/cudnn/native_sparse_attention/utils.py index 69e83807..9ec421ba 100644 --- a/python/cudnn/native_sparse_attention/utils.py +++ b/python/cudnn/native_sparse_attention/utils.py @@ -15,9 +15,7 @@ def make_tensor_strided_like( """ q_strides = q_tensor.stride() rank_out = len(o_shape) - order = tuple( - sorted(range(min(len(q_strides), rank_out)), key=lambda i: q_strides[i]) - ) + order = tuple(sorted(range(min(len(q_strides), rank_out)), key=lambda i: q_strides[i])) strides = [0] * rank_out current = 1 diff --git a/python/cudnn/wrapper.py b/python/cudnn/wrapper.py index 772095c3..6d1cfc2d 100644 --- a/python/cudnn/wrapper.py +++ b/python/cudnn/wrapper.py @@ -116,11 +116,7 @@ def _find_tensor( for tensor_name, tensor_value in tensor_map.items(): if tensor is tensor_value: return tensor_name - elif ( - hasattr(tensor, "__dlpack__") - and isinstance(dlpack_map, dict) - and id(tensor) in dlpack_map - ): + elif hasattr(tensor, "__dlpack__") and isinstance(dlpack_map, dict) and id(tensor) in dlpack_map: tensor = dlpack_map[id(tensor)] for tensor_name, tensor_value in tensor_map.items(): if tensor_value == tensor: @@ -128,9 +124,7 @@ def _find_tensor( raise ValueError("Input not found in tensor map") -def _extract_tensor( - name: str, tensor: cudnn.tensor, arg_dict: dict -) -> Optional["torch.Tensor"]: +def _extract_tensor(name: str, tensor: cudnn.tensor, arg_dict: dict) -> Optional["torch.Tensor"]: """Extract a dlpack tensor from the arg_dict that matches the provided name or cudnn tensor Args: @@ -155,9 +149,7 @@ def _extract_tensor( return None # not found -def _tensor_like( - cudnn_tensor: cudnn.tensor, tensor_type: str = "pyt" -) -> "torch.Tensor": +def _tensor_like(cudnn_tensor: cudnn.tensor, tensor_type: str = "pyt") -> "torch.Tensor": """Create a tensor like the provided cudnn tensor Args: @@ -173,9 +165,7 @@ def _tensor_like( raise RuntimeError("PyTorch is not available") dtype = cudnn.datatypes._cudnn_to_torch_data_type(cudnn_tensor.get_data_type()) if dtype is None: - raise TypeError( - f"cuDNN uses an unsupported data type in PyTorch: {cudnn_tensor.get_data_type()}" - ) + raise TypeError(f"cuDNN uses an unsupported data type in PyTorch: {cudnn_tensor.get_data_type()}") tensor = torch.empty(cudnn_tensor.get_dim(), device="cuda", dtype=dtype) tensor = torch.as_strided(tensor, cudnn_tensor.get_dim(), cudnn_tensor.get_stride()) return tensor @@ -246,25 +236,15 @@ def __init__( self.__kwargs = kwargs self.__graph = None # to hold the cudnn.pygraph object self.__tensor_map = {} # obj id of dlpack tensor -> cudnn tensor - self.__tensor_in = ( - OrderedDict() - ) # canonical node::argname -> cudnn tensors used as the input - self.__tensor_out = ( - OrderedDict() - ) # canonical node::outname -> cudnn tensors produced by the node + self.__tensor_in = OrderedDict() # canonical node::argname -> cudnn tensors used as the input + self.__tensor_out = OrderedDict() # canonical node::outname -> cudnn tensors produced by the node self.__tensor_unknown = [] # list of cuDNN tensors created by user directly self.__node_count = {} # function name of graph node -> number of times used - self.__node_names = ( - set() - ) # set of assigned names of graph nodes, to check name collision + self.__node_names = set() # set of assigned names of graph nodes, to check name collision self.__input_tuples = None # tuple of input tensors, if set by set_io_tuples self.__output_tuples = None # tuple of output tensors, if set by set_io_tuples - self.__inputs = ( - inputs or [] - ) # hold the list of inputs, to be used by set_io_tuples() implicitly - self.__outputs = ( - outputs or [] - ) # hold the list of outputs, to be used by set_io_tuples() implicitly + self.__inputs = inputs or [] # hold the list of inputs, to be used by set_io_tuples() implicitly + self.__outputs = outputs or [] # hold the list of outputs, to be used by set_io_tuples() implicitly self.__heuristics = heuristics or [heur_mode.A, heur_mode.FALLBACK] if not workspace_alloc: self.__workspace = False @@ -273,10 +253,7 @@ def __init__( # silently replace the PyTorch dtype into cuDNN dtype for key in ["io_data_type", "intermediate_data_type", "compute_data_type"]: if key in kwargs: - kwargs[key] = ( - cudnn.datatypes._torch_to_cudnn_data_type(kwargs[key]) - or kwargs[key] - ) + kwargs[key] = cudnn.datatypes._torch_to_cudnn_data_type(kwargs[key]) or kwargs[key] def __del__(self): pass @@ -286,9 +263,7 @@ def __enter__(self): raise RuntimeError("Graph already created") self.__graph = cudnn.pygraph( # Pass handle only if self.__handle is not None - **( - {"handle": self.__handle} if self.__handle not in ["auto", None] else {} - ), + **({"handle": self.__handle} if self.__handle not in ["auto", None] else {}), **self.__kwargs, ) return self @@ -450,25 +425,17 @@ def __call__(self, *args, **kwargs): if self.__graph is None: raise RuntimeError("Graph not created") if not self.__graph.get_execution_plan_count(): - raise RuntimeError( - "You should not invoke the graph before the context exits" - ) + raise RuntimeError("You should not invoke the graph before the context exits") if len(args) == 1 and isinstance(args[0], dict): return self.__call_with_tensor_dict(args[0], **kwargs) else: if len(args) > 0 and not self.__input_tuples: - raise ValueError( - "You should not invoke the graph with positional arguments before running set_io_tuples()" - ) + raise ValueError("You should not invoke the graph with positional arguments before running set_io_tuples()") if len(args) != len(self.__input_tuples): - raise ValueError( - f"Number of arguments ({len(args)}) does not match number of inputs ({len(self.__input_tuples)})" - ) + raise ValueError(f"Number of arguments ({len(args)}) does not match number of inputs ({len(self.__input_tuples)})") return self.__call_with_positional_args(*args, **kwargs) - def __call_with_positional_args( - self, *args, **kwargs - ) -> Union["torch.Tensor", Tuple["torch.Tensor", ...]]: + def __call_with_positional_args(self, *args, **kwargs) -> Union["torch.Tensor", Tuple["torch.Tensor", ...]]: """Execute the graph with positional arguments. Args: @@ -488,9 +455,7 @@ def __call_with_positional_args( variant_pack = {} for cudnn_tensor, user_tensor in zip(self.__input_tuples, args): variant_pack[cudnn_tensor.get_uid()] = user_tensor - output_tuple = [ - _tensor_like(cudnn_tensor, "pyt") for cudnn_tensor in self.__output_tuples - ] + output_tuple = [_tensor_like(cudnn_tensor, "pyt") for cudnn_tensor in self.__output_tuples] for cudnn_tensor, user_tensor in zip(self.__output_tuples, output_tuple): variant_pack[cudnn_tensor.get_uid()] = user_tensor # execute the graph @@ -543,9 +508,7 @@ def __call_with_tensor_dict( # all non-virtual tensors in __tensor_in and __tensor_out should be filled variant_pack = {} missing_tensors = {} - for name, tensor in itertools.chain( - self.__tensor_in.items(), self.__tensor_out.items() - ): + for name, tensor in itertools.chain(self.__tensor_in.items(), self.__tensor_out.items()): if tensor.get_uid() in variant_pack or tensor.get_is_virtual(): continue # already filled or not needed user_tensor = _extract_tensor(name, tensor, tensor_dict) @@ -563,17 +526,13 @@ def __call_with_tensor_dict( continue # already filled if name in self.__tensor_out: # output tensor not specified, should be created automatically - variant_pack[tensor.get_uid()] = tensor_dict[name] = _tensor_like( - tensor, "pyt" - ) + variant_pack[tensor.get_uid()] = tensor_dict[name] = _tensor_like(tensor, "pyt") missing_outputs.append(name) else: # input tensor not specified, flag it as missing missing_inputs.append(name) if missing_inputs: - raise RuntimeError( - f"Non-virtual input tensors not found in variant pack: {missing_inputs}" - ) + raise RuntimeError(f"Non-virtual input tensors not found in variant pack: {missing_inputs}") if missing_outputs: logger.debug("Added output tensors: %s", missing_outputs) # execute the graph @@ -635,9 +594,7 @@ def set_io_tuples( tensors_found.add(id(tensor)) input_tensors.append(tensor) except ValueError: - raise ValueError( - f"Input at index {i} ({name}) not found in tensor map" - ) from None + raise ValueError(f"Input at index {i} ({name}) not found in tensor map") from None # Convert "outputs" to a list of names that can be looked up in __tensor_out output_tensors = [] for i, name in enumerate(outputs): @@ -652,9 +609,7 @@ def set_io_tuples( tensors_found.add(id(tensor)) output_tensors.append(tensor) except ValueError: - raise ValueError( - f"Output at index {i} ({name}) not found in tensor map" - ) from None + raise ValueError(f"Output at index {i} ({name}) not found in tensor map") from None # Verify that all input tensors are non-virtual for i, tensor in enumerate(input_tensors): if tensor.get_is_virtual(): @@ -662,14 +617,10 @@ def set_io_tuples( # Verify that all non-virtual tensors are covered by input or output for name, tensor in self.__tensor_out.items(): if not tensor.get_is_virtual() and tensor not in output_tensors: - raise ValueError( - f"Node output {name} is a non-virtual tensor but not specified as output" - ) + raise ValueError(f"Node output {name} is a non-virtual tensor but not specified as output") for name, tensor in self.__tensor_in.items(): if not tensor.get_is_virtual() and id(tensor) not in tensors_found: - raise ValueError( - f"Node input {name} is a non-virtual tensor but not specified as input or output" - ) + raise ValueError(f"Node input {name} is a non-virtual tensor but not specified as input or output") # Set the input and output names self.__input_tuples = tuple(input_tensors) self.__output_tuples = tuple(output_tensors) diff --git a/python/pygraph/pygraph.cpp b/python/pygraph/pygraph.cpp index a4078835..558daa31 100644 --- a/python/pygraph/pygraph.cpp +++ b/python/pygraph/pygraph.cpp @@ -571,20 +571,38 @@ PyGraph::populate_cuda_graph( void PyGraph::execute(std::unordered_map var_pack, std::intptr_t workspace, - std::optional exec_handle) { + std::optional exec_handle, + py::object override_uids, + py::object override_shapes, + py::object override_strides) { std::unordered_map var_pack_; var_pack_.reserve(var_pack.size()); for (auto const& [uid, device_pointer] : var_pack) { var_pack_.emplace(uid, (void*)device_pointer); } + // Convert override_uids to a vector of int64_t (one-liner) + std::vector override_uids_vec = + override_uids.is_none() ? std::vector() : override_uids.cast>(); + std::vector> override_shapes_vec = + override_shapes.is_none() ? std::vector>() + : override_shapes.cast>>(); + std::vector> override_strides_vec = + override_strides.is_none() ? std::vector>() + : override_strides.cast>>(); + auto workspace_ptr = (void*)workspace; cudnnHandle_t handle_ = exec_handle.has_value() ? static_cast((void*)(exec_handle.value())) : handle; - auto status = graph->execute(handle_, var_pack_, workspace_ptr); + cudnn_frontend::error_t status = {error_code_t::OK, ""}; + if (override_uids_vec.empty()) { + status = graph->execute(handle_, var_pack_, workspace_ptr); + } else { + status = graph->execute( + handle_, var_pack_, workspace_ptr, override_uids_vec, override_shapes_vec, override_strides_vec); + } throw_if(status.is_bad(), status.get_code(), status.get_message()); - return; } @@ -592,19 +610,37 @@ void PyGraph::execute_plan_at_index(std::unordered_map var_pack, std::intptr_t workspace, int64_t index, - std::optional exec_handle) { + std::optional exec_handle, + py::object override_uids, + py::object override_shapes, + py::object override_strides) { std::unordered_map var_pack_; for (auto const& [uid, device_pointer] : var_pack) { var_pack_.emplace(uid, (void*)device_pointer); } + // Convert override_uids to a vector of int64_t (one-liner) + std::vector override_uids_vec = + override_uids.is_none() ? std::vector() : override_uids.cast>(); + std::vector> override_shapes_vec = + override_shapes.is_none() ? std::vector>() + : override_shapes.cast>>(); + std::vector> override_strides_vec = + override_strides.is_none() ? std::vector>() + : override_strides.cast>>(); + auto workspace_ptr = (void*)workspace; cudnnHandle_t handle_ = exec_handle.has_value() ? static_cast((void*)(exec_handle.value())) : handle; - auto status = graph->execute_plan_at_index(handle_, var_pack_, workspace_ptr, index); + cudnn_frontend::error_t status = {error_code_t::OK, ""}; + if (override_uids_vec.empty()) { + status = graph->execute_plan_at_index(handle_, var_pack_, workspace_ptr, index); + } else { + status = graph->execute_plan_at_index( + handle_, var_pack_, workspace_ptr, index, override_uids_vec, override_shapes_vec, override_strides_vec); + } throw_if(status.is_bad(), status.get_code(), status.get_message()); - return; } @@ -641,7 +677,8 @@ init_pygraph_submodule(py::module_& m) { py::object, py::object, std::shared_ptr, - std::shared_ptr>(), + std::shared_ptr, + bool>(), py::arg_v("name", "test_graph"), py::arg_v("io_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("intermediate_data_type", cudnn_frontend::DataType_t::NOT_SET), @@ -650,7 +687,8 @@ init_pygraph_submodule(py::module_& m) { py::arg_v("sm_count", py::none()), py::arg_v("sm_version", py::none()), py::arg_v("kernel_cache", nullptr), - py::arg_v("device_property", nullptr)) + py::arg_v("device_property", nullptr), + py::arg_v("is_dynamic_shape_enabled", false)) .def("tensor_like", py::overload_cast const&, std::string const&>( &PyGraph::tensor_like), @@ -1037,7 +1075,14 @@ init_pygraph_submodule(py::module_& m) { Args: index (int): The index of the plan to get workspace from. )pbdoc") - .def("_execute", &PyGraph::execute) + .def("_execute", + &PyGraph::execute, + py::arg("var_pack"), + py::arg("workspace"), + py::arg("handle"), + py::arg("override_uids") = py::none(), + py::arg("override_shapes") = py::none(), + py::arg("override_strides") = py::none()) .def("populate_cuda_graph", &PyGraph::populate_cuda_graph) .def("update_cuda_graph", &PyGraph::update_cuda_graph) .def("serialize", &PyGraph::serialize) @@ -1046,7 +1091,15 @@ init_pygraph_submodule(py::module_& m) { py::arg("handle_"), py::arg("pyobj")) .def("deserialize", (void (PyGraph::*)(py::object const&))&PyGraph::deserialize, py::arg("pyobj")) - .def("_execute_plan_at_index", &PyGraph::execute_plan_at_index) + .def("_execute_plan_at_index", + &PyGraph::execute_plan_at_index, + py::arg("var_pack"), + py::arg("workspace"), + py::arg("index"), + py::arg("handle"), + py::arg("override_uids") = py::none(), + py::arg("override_shapes") = py::none(), + py::arg("override_strides") = py::none()) .def("__repr__", [](PyGraph const& pygraph) { std::stringstream ss; json j = pygraph.graph; diff --git a/python/pygraph/pygraph.h b/python/pygraph/pygraph.h index a2167ec5..a473ffe4 100644 --- a/python/pygraph/pygraph.h +++ b/python/pygraph/pygraph.h @@ -60,7 +60,8 @@ class PyGraph { py::object sm_count, py::object sm_version, std::shared_ptr kernel_cache, - std::shared_ptr device_properties) + std::shared_ptr device_properties, + bool is_dynamic_shape_enabled) : graph(std::make_shared()) { graph->set_compute_data_type(compute_data_type) .set_intermediate_data_type(intermediate_data_type) @@ -84,6 +85,10 @@ class PyGraph { graph->set_sm_version(sm_version.cast()); } + if (is_dynamic_shape_enabled) { + graph->set_dynamic_shape_enabled(true); + } + if (kernel_cache) { graph->set_kernel_cache(kernel_cache); graph->set_dynamic_shape_enabled(true); @@ -475,6 +480,7 @@ class PyGraph { std::shared_ptr& seq_len_kv, bool const use_causal_mask, bool const use_causal_mask_bottom_right, + bool const use_deterministic_algorithm, py::object const& dropout, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); @@ -542,13 +548,21 @@ class PyGraph { std::intptr_t cuda_graph); void - execute(std::unordered_map var_pack, int64_t workspace, std::optional); + execute(std::unordered_map var_pack, + int64_t workspace, + std::optional, + py::object override_uids = py::none(), + py::object override_shapes = py::none(), + py::object override_strides = py::none()); void execute_plan_at_index(std::unordered_map var_pack, int64_t workspace, int64_t index, - std::optional); + std::optional, + py::object override_uids = py::none(), + py::object override_shapes = py::none(), + py::object override_strides = py::none()); std::vector get_behavior_notes(); diff --git a/python/pygraph/sdpa.cpp b/python/pygraph/sdpa.cpp index 22b40e3b..6eb11c57 100644 --- a/python/pygraph/sdpa.cpp +++ b/python/pygraph/sdpa.cpp @@ -613,6 +613,7 @@ PyGraph::sdpa_fp8_backward(std::shared_ptr& seq_len_kv, bool const use_causal_mask, bool const use_causal_mask_bottom_right, + bool const use_deterministic_algorithm, py::object const& dropout, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name) { @@ -622,6 +623,7 @@ PyGraph::sdpa_fp8_backward(std::shared_ptr& m) { py::arg_v("seq_len_kv", nullptr), py::arg_v("use_causal_mask", false), py::arg_v("use_causal_mask_bottom_right", false), + py::arg_v("use_deterministic_algorithm", false), py::arg_v("dropout", py::none()), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("name", ""), @@ -973,6 +976,8 @@ init_pygraph_sdpa_submodule(py::class_& m) { seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + use_causal_mask_bottom_right (Optional[bool]): Whether to use bottom right aligned causal mask. Default is False. + use_deterministic_algorithm (Optional[bool]): Whether to always use deterministic algorithm. Default is False. dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. diff --git a/requirements.txt b/requirements.txt index cac00a79..0860f8c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,5 @@ pybind11[global] pytest pytest-xdist looseversion -black +black==26.1.0 clang-format==21.1.6 diff --git a/samples/cpp/CMakeLists.txt b/samples/cpp/CMakeLists.txt index 0a35b3bb..48d23fbb 100644 --- a/samples/cpp/CMakeLists.txt +++ b/samples/cpp/CMakeLists.txt @@ -10,6 +10,7 @@ add_executable( sdpa/fp16_bwd_with_flexible_graphs.cpp sdpa/fp16_fwd_with_custom_dropout.cpp sdpa/fp16_fwd_with_paged_caches.cpp + sdpa/fp16_dynamic_shapes.cpp sdpa/fp16_fwd_paged_decode_and_prefill.cpp sdpa/fp16_fwd_with_cudagraphs.cpp sdpa/fp16_bwd_with_cudagraphs.cpp @@ -39,6 +40,8 @@ add_executable( matmul/general_block_scale_matmul.cpp matmul/complex_fp32_matmul.cpp + moe_grouped_matmul/moe_grouped_matmul.cpp + norm/batchnorm.cpp norm/layernorm.cpp norm/adaptive_layernorm.cpp diff --git a/samples/cpp/matmul/blackwell_nvfp4_mxfp8_block_scale_matmul.cpp b/samples/cpp/matmul/blackwell_nvfp4_mxfp8_block_scale_matmul.cpp index 01dfef84..595920db 100644 --- a/samples/cpp/matmul/blackwell_nvfp4_mxfp8_block_scale_matmul.cpp +++ b/samples/cpp/matmul/blackwell_nvfp4_mxfp8_block_scale_matmul.cpp @@ -746,4 +746,172 @@ TEST_CASE("Block Scale Matmul Swiglu", "[matmul][graph][FP4]") { REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); } +TEST_CASE("Blackwell Block Scale Matmul dynamic shape overrides", "[matmul][graph][dynamic_shape]") { +#if (CUDNN_VERSION < 91800) + SKIP("Dynamic shape with overrides is not supported in cudnn versions prior to 9.18.0"); +#endif + + if (check_device_arch_newer_than("blackwell") == false) { + SKIP("Hardware accelerated NVFP4/MXFP8 block scale matmul requires Blackwell and up"); + } + + namespace fe = cudnn_frontend; + + constexpr int A_UID = 1; + constexpr int SF_A_UID = 2; + constexpr int B_UID = 3; + constexpr int SF_B_UID = 4; + constexpr int C_UID = 5; + + static constexpr int indestructible_128x4_block_m_n = 128; + static constexpr int indestructible_128x4_block_k = 4; + + int block_size = 16; + + struct matmul_shapes { + int64_t b, m, n, k; + }; + + matmul_shapes matmul_cache_shape = {1, 1024, 1024, 1024}; + matmul_shapes matmul_dynamic_shape[] = { + {2, 1024, 1024, 1024}, + {2, 2048, 2048, 2048}, + }; + + constexpr int matmul_dynamic_shape_count = sizeof(matmul_dynamic_shape) / sizeof(matmul_cache_shape); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + // build graph and execution plan with a fake shape + auto graph = std::make_shared(); + + graph->set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT) + .set_dynamic_shape_enabled(true); // must be set true for dynamic shape + + int64_t block_scale_dim_m_cache = + div_up(matmul_cache_shape.m, indestructible_128x4_block_m_n) * indestructible_128x4_block_m_n; + int64_t block_scale_dim_n_cache = + div_up(matmul_cache_shape.n, indestructible_128x4_block_m_n) * indestructible_128x4_block_m_n; + int64_t block_scale_dim_k_cache = + div_up(div_up(matmul_cache_shape.k, block_size), indestructible_128x4_block_k) * indestructible_128x4_block_k; + + auto A = graph->tensor(fe::graph::Tensor_attributes() + .set_name("A") + .set_uid(A_UID) + .set_dim({matmul_cache_shape.b, matmul_cache_shape.m, matmul_cache_shape.k}) + .set_stride({matmul_cache_shape.m * matmul_cache_shape.k, matmul_cache_shape.k, 1}) + .set_data_type(fe::DataType_t::FP4_E2M1)); + + auto SF_A = + graph->tensor(fe::graph::Tensor_attributes() + .set_name("SF_A") + .set_uid(SF_A_UID) + .set_dim({matmul_cache_shape.b, block_scale_dim_m_cache, block_scale_dim_k_cache}) + .set_stride({block_scale_dim_m_cache * block_scale_dim_k_cache, block_scale_dim_k_cache, 1}) + .set_data_type(fe::DataType_t::FP8_E4M3) + .set_reordering_type(cudnn_frontend::TensorReordering_t::F8_128x4)); + + auto dequantize_attr_a = fe::graph::Block_scale_dequantize_attributes().set_block_size({1, block_size}); + auto dequan_tensor_a = graph->block_scale_dequantize(A, SF_A, dequantize_attr_a); + + auto B = graph->tensor(fe::graph::Tensor_attributes() + .set_name("B") + .set_uid(B_UID) + .set_dim({matmul_cache_shape.b, matmul_cache_shape.k, matmul_cache_shape.n}) + .set_stride({matmul_cache_shape.n * matmul_cache_shape.k, 1, matmul_cache_shape.k}) + .set_data_type(fe::DataType_t::FP4_E2M1)); + + auto SF_B = + graph->tensor(fe::graph::Tensor_attributes() + .set_name("SF_B") + .set_uid(SF_B_UID) + .set_dim({matmul_cache_shape.b, block_scale_dim_k_cache, block_scale_dim_n_cache}) + .set_stride({block_scale_dim_n_cache * block_scale_dim_k_cache, 1, block_scale_dim_k_cache}) + .set_data_type(fe::DataType_t::FP8_E4M3) + .set_reordering_type(cudnn_frontend::TensorReordering_t::F8_128x4)); + + auto dequantize_attr_b = fe::graph::Block_scale_dequantize_attributes().set_block_size({block_size, 1}); + auto dequan_tensor_b = graph->block_scale_dequantize(B, SF_B, dequantize_attr_b); + + auto C = graph->matmul( + dequan_tensor_a, dequan_tensor_b, fe::graph::Matmul_attributes().set_compute_data_type(fe::DataType_t::FLOAT)); + C->set_uid(C_UID).set_output(true).set_data_type(fe::DataType_t::BFLOAT16); + + // For dynamic shape, recommend to query fallback plan to get a general good performance + // Heuristics Mode A is recommended if the dynamic problem shapes are similar in size + REQUIRE(graph->build(handle, {fe::HeurMode_t::FALLBACK}).is_good()); + + // run graph with dynamic shapes + for (int idx_shape = 0; idx_shape < matmul_dynamic_shape_count; ++idx_shape) { + int64_t block_scale_dim_m = + div_up(matmul_dynamic_shape[idx_shape].m, indestructible_128x4_block_m_n) * indestructible_128x4_block_m_n; + int64_t block_scale_dim_n = + div_up(matmul_dynamic_shape[idx_shape].n, indestructible_128x4_block_m_n) * indestructible_128x4_block_m_n; + int64_t block_scale_dim_k = + div_up(div_up(matmul_dynamic_shape[idx_shape].k, block_size), indestructible_128x4_block_k) * + indestructible_128x4_block_k; + + std::vector override_uids = {A_UID, SF_A_UID, B_UID, SF_B_UID, C_UID}; + std::vector> override_shapes = { + {matmul_dynamic_shape[idx_shape].b, matmul_dynamic_shape[idx_shape].m, matmul_dynamic_shape[idx_shape].k}, + {matmul_dynamic_shape[idx_shape].b, block_scale_dim_m, block_scale_dim_k}, + {matmul_dynamic_shape[idx_shape].b, matmul_dynamic_shape[idx_shape].k, matmul_dynamic_shape[idx_shape].n}, + {matmul_dynamic_shape[idx_shape].b, block_scale_dim_k, block_scale_dim_n}, + {matmul_dynamic_shape[idx_shape].b, matmul_dynamic_shape[idx_shape].m, matmul_dynamic_shape[idx_shape].n}}; + std::vector> override_strides = { + {matmul_dynamic_shape[idx_shape].m * matmul_dynamic_shape[idx_shape].k, + matmul_dynamic_shape[idx_shape].k, + 1}, + {block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1}, + {matmul_dynamic_shape[idx_shape].n * matmul_dynamic_shape[idx_shape].k, + 1, + matmul_dynamic_shape[idx_shape].k}, + {block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k}, + {matmul_dynamic_shape[idx_shape].m * matmul_dynamic_shape[idx_shape].n, + matmul_dynamic_shape[idx_shape].n, + 1}}; + + Surface A_gpu(div_up(matmul_dynamic_shape[idx_shape].b * matmul_dynamic_shape[idx_shape].m * + matmul_dynamic_shape[idx_shape].k * + cudnn_frontend::detail::get_element_size_in_bits(fe::DataType_t::FP4_E2M1), + 8), + false); + Surface SF_A_gpu(div_up(matmul_dynamic_shape[idx_shape].b * block_scale_dim_m * block_scale_dim_k * + cudnn_frontend::detail::get_element_size_in_bits(fe::DataType_t::FP8_E4M3), + 8), + false); + Surface B_gpu(div_up(matmul_dynamic_shape[idx_shape].b * matmul_dynamic_shape[idx_shape].k * + matmul_dynamic_shape[idx_shape].n * + cudnn_frontend::detail::get_element_size_in_bits(fe::DataType_t::FP4_E2M1), + 8), + false); + Surface SF_B_gpu(div_up(matmul_dynamic_shape[idx_shape].b * block_scale_dim_k * block_scale_dim_n * + cudnn_frontend::detail::get_element_size_in_bits(fe::DataType_t::FP8_E4M3), + 8), + false); + Surface C_gpu(div_up(matmul_dynamic_shape[idx_shape].b * matmul_dynamic_shape[idx_shape].m * + matmul_dynamic_shape[idx_shape].n * + cudnn_frontend::detail::get_element_size_in_bits(fe::DataType_t::BFLOAT16), + 8), + false); + + std::unordered_map variant_pack = {{A_UID, A_gpu.devPtr}, + {SF_A_UID, SF_A_gpu.devPtr}, + {B_UID, B_gpu.devPtr}, + {SF_B_UID, SF_B_gpu.devPtr}, + {C_UID, C_gpu.devPtr}}; + + int64_t workspace_size = 0; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr, override_uids, override_shapes, override_strides) + .is_good()); + + CUDA_CHECK(cudaDeviceSynchronize()); + } +} + } // namespace BlackwellNVFP4MXFP8BlockScaleMatmul \ No newline at end of file diff --git a/samples/cpp/matmul/matmuls.cpp b/samples/cpp/matmul/matmuls.cpp index ff30ba1d..8020b7a1 100644 --- a/samples/cpp/matmul/matmuls.cpp +++ b/samples/cpp/matmul/matmuls.cpp @@ -612,4 +612,100 @@ TEST_CASE("Matmul with restricted shared memory", "[matmul][graph]") { std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); -} \ No newline at end of file +} + +TEST_CASE("Matmul dynamic shape overrides", "[matmul][graph][dynamic_shape]") { +#if (CUDNN_VERSION < 91800) + SKIP("Dynamic shape with overrides is not supported in cudnn versions prior to 9.18.0"); +#endif + + namespace fe = cudnn_frontend; + + constexpr int A_UID = 1; + constexpr int B_UID = 2; + constexpr int C_UID = 3; + + struct matmul_shapes { + int64_t b, m, n, k; + }; + + matmul_shapes matmul_cache_shape = {1, 1024, 1024, 1024}; + matmul_shapes matmul_dynamic_shape[] = { + {2, 1024, 1024, 1024}, + {2, 2048, 2048, 2048}, + }; + + constexpr int matmul_dynamic_shape_count = sizeof(matmul_dynamic_shape) / sizeof(matmul_cache_shape); + + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + // build graph and execution plan with a fake shape + auto graph = std::make_shared(); + + graph->set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT) + .set_dynamic_shape_enabled(true); // must be set true for dynamic shape + + auto A = graph->tensor(fe::graph::Tensor_attributes() + .set_name("A") + .set_uid(A_UID) + .set_dim({matmul_cache_shape.b, matmul_cache_shape.m, matmul_cache_shape.k}) + .set_stride({matmul_cache_shape.m * matmul_cache_shape.k, matmul_cache_shape.k, 1}) + .set_data_type(fe::DataType_t::BFLOAT16)); + + auto B = graph->tensor(fe::graph::Tensor_attributes() + .set_name("B") + .set_uid(B_UID) + .set_dim({matmul_cache_shape.b, matmul_cache_shape.k, matmul_cache_shape.n}) + .set_stride({matmul_cache_shape.n * matmul_cache_shape.k, 1, matmul_cache_shape.k}) + .set_data_type(fe::DataType_t::BFLOAT16)); + + auto C = graph->matmul(A, B, fe::graph::Matmul_attributes().set_compute_data_type(fe::DataType_t::FLOAT)); + C->set_uid(C_UID).set_output(true).set_data_type(fe::DataType_t::BFLOAT16); + + // For dynamic shape, recommend to query fallback plan to get a general good performance + // Heuristics Mode A is recommended if the dynamic problem shapes are similar in size + REQUIRE(graph->build(handle, {fe::HeurMode_t::FALLBACK}).is_good()); + + // run graph with dynamic shapes + for (int idx_shape = 0; idx_shape < matmul_dynamic_shape_count; ++idx_shape) { + std::vector override_uids = {A_UID, B_UID, C_UID}; + std::vector> override_shapes = { + {matmul_dynamic_shape[idx_shape].b, matmul_dynamic_shape[idx_shape].m, matmul_dynamic_shape[idx_shape].k}, + {matmul_dynamic_shape[idx_shape].b, matmul_dynamic_shape[idx_shape].k, matmul_dynamic_shape[idx_shape].n}, + {matmul_dynamic_shape[idx_shape].b, matmul_dynamic_shape[idx_shape].m, matmul_dynamic_shape[idx_shape].n}}; + std::vector> override_strides = { + {matmul_dynamic_shape[idx_shape].m * matmul_dynamic_shape[idx_shape].k, + matmul_dynamic_shape[idx_shape].k, + 1}, + {matmul_dynamic_shape[idx_shape].n * matmul_dynamic_shape[idx_shape].k, + 1, + matmul_dynamic_shape[idx_shape].k}, + {matmul_dynamic_shape[idx_shape].m * matmul_dynamic_shape[idx_shape].n, + matmul_dynamic_shape[idx_shape].n, + 1}}; + + Surface A_gpu( + matmul_dynamic_shape[idx_shape].b * matmul_dynamic_shape[idx_shape].m * matmul_dynamic_shape[idx_shape].k, + false); + Surface B_gpu( + matmul_dynamic_shape[idx_shape].b * matmul_dynamic_shape[idx_shape].k * matmul_dynamic_shape[idx_shape].n, + false); + Surface C_gpu( + matmul_dynamic_shape[idx_shape].b * matmul_dynamic_shape[idx_shape].m * matmul_dynamic_shape[idx_shape].n, + false); + + std::unordered_map variant_pack = { + {A_UID, A_gpu.devPtr}, {B_UID, B_gpu.devPtr}, {C_UID, C_gpu.devPtr}}; + + int64_t workspace_size = 0; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr, override_uids, override_shapes, override_strides) + .is_good()); + + CUDA_CHECK(cudaDeviceSynchronize()); + } +} diff --git a/samples/cpp/moe_grouped_matmul/moe_grouped_matmul.cpp b/samples/cpp/moe_grouped_matmul/moe_grouped_matmul.cpp new file mode 100644 index 00000000..9c4a75d2 --- /dev/null +++ b/samples/cpp/moe_grouped_matmul/moe_grouped_matmul.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include + +#include + +#include "../utils/helpers.h" + +#include + +TEST_CASE("WoQ MoeGroupedMatmul", "[MoeGroupedMatmul][graph]") { +#if (CUDNN_VERSION < 91800) + SKIP("MoE is not supported in cudnn versions prior to 9.18.0"); +#endif + + if (is_arch_supported_by_cudnn() == false) { + SKIP("Architecture is not supported by current cudnn version"); + } + namespace fe = cudnn_frontend; + + // problem size + int64_t const batch_size = 2; + int64_t const num_experts = 3; + int64_t const top_k = 2; + int64_t const token_num = 512; + int64_t const weight_size = 256; + int64_t const hidden_size = 512; + int64_t const block_size = 128; + + // Initialize input tensors + Surface token_gpu( + div_up(batch_size * token_num * top_k * hidden_size * + cudnn_frontend::detail::get_element_size_in_bits(cudnn_frontend::DataType_t::HALF), + 8), + false); + Surface weight_gpu( + div_up(num_experts * hidden_size * weight_size * + cudnn_frontend::detail::get_element_size_in_bits(cudnn_frontend::DataType_t::INT4), + 8), + false); + Surface block_scale_gpu( + div_up(num_experts * div_up(hidden_size, block_size) * weight_size * + cudnn_frontend::detail::get_element_size_in_bits(cudnn_frontend::DataType_t::HALF), + 8), + false); + Surface first_token_offset_gpu( + div_up(batch_size * num_experts * + cudnn_frontend::detail::get_element_size_in_bits(cudnn_frontend::DataType_t::INT32), + 8), + false); + Surface moe_grouped_matmul_gpu( + div_up(batch_size * token_num * top_k * weight_size * + cudnn_frontend::detail::get_element_size_in_bits(cudnn_frontend::DataType_t::HALF), + 8), + false); + + std::vector first_token_offset_cpu({0, 128, 512, 768, 1152, 1536}); + CUDA_CHECK(cudaMemcpy(first_token_offset_gpu.devPtr, + first_token_offset_cpu.data(), + first_token_offset_cpu.size() * sizeof(int32_t), + cudaMemcpyHostToDevice)); + + // Make cudnn graph + fe::graph::Graph graph{}; + + graph.set_intermediate_data_type(fe::DataType_t::HALF); + graph.set_compute_data_type(fe::DataType_t::HALF); + + auto tensor_token = graph.tensor(fe::graph::Tensor_attributes() + .set_name("token") + .set_dim({1, batch_size * token_num * top_k, hidden_size}) + .set_stride({batch_size * token_num * top_k * hidden_size, hidden_size, 1}) + .set_data_type(fe::DataType_t::HALF)); + + auto tensor_weight = graph.tensor(fe::graph::Tensor_attributes() + .set_name("weight") + .set_dim({num_experts, hidden_size, weight_size}) + .set_stride({hidden_size * weight_size, 1, hidden_size}) + .set_data_type(fe::DataType_t::INT4)); + + auto tensor_block_scale = graph.tensor( + fe::graph::Tensor_attributes() + .set_name("block_scale") + .set_dim({num_experts, div_up(hidden_size, block_size), weight_size}) + .set_stride({div_up(hidden_size, block_size) * weight_size, 1, div_up(hidden_size, block_size)}) + .set_data_type(fe::DataType_t::HALF)); + + auto tensor_first_token_offset = graph.tensor(fe::graph::Tensor_attributes() + .set_name("first_token_offset") + .set_dim({batch_size * num_experts, 1, 1}) + .set_stride({1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + + auto dequantize_weight_attr = fe::graph::Block_scale_dequantize_attributes() + .set_block_size({block_size, 1}) + .set_compute_data_type(fe::DataType_t::HALF); + + auto tensor_dequantized_weight = + graph.block_scale_dequantize(tensor_weight, tensor_block_scale, dequantize_weight_attr); + tensor_dequantized_weight->set_data_type(fe::DataType_t::HALF); + + auto moe_grouped_matmul_attr = fe::graph::Moe_grouped_matmul_attributes() + .set_name("moe_grouped_matmul") + .set_mode(fe::MoeGroupedMatmulMode_t::NONE) + .set_compute_data_type(fe::DataType_t::HALF) + .set_top_k(top_k); + + auto tensor_moe_grouped_matmul = graph.moe_grouped_matmul( + tensor_token, tensor_dequantized_weight, tensor_first_token_offset, nullptr, nullptr, moe_grouped_matmul_attr); + + tensor_moe_grouped_matmul->set_data_type(fe::DataType_t::HALF); + tensor_moe_grouped_matmul->set_output(true); + + std::cout << graph << std::endl; + REQUIRE(graph.validate().is_good()); + + // Create a unique_ptr for the cuDNN handle + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.check_support().is_good()); + + REQUIRE(graph.build_plans(fe::BuildPlanPolicy_t::ALL).is_good()); + + // Run cudnn graph + int64_t workspace_size = 0; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + std::unordered_map, void*> variant_pack = { + {tensor_token, token_gpu.devPtr}, + {tensor_weight, weight_gpu.devPtr}, + {tensor_block_scale, block_scale_gpu.devPtr}, + {tensor_first_token_offset, first_token_offset_gpu.devPtr}, + {tensor_moe_grouped_matmul, moe_grouped_matmul_gpu.devPtr}}; + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); +} diff --git a/samples/cpp/sdpa/fp16_benchmark.cpp b/samples/cpp/sdpa/fp16_benchmark.cpp index e9d7a030..883ad1a0 100644 --- a/samples/cpp/sdpa/fp16_benchmark.cpp +++ b/samples/cpp/sdpa/fp16_benchmark.cpp @@ -49,9 +49,7 @@ create_sdpa_forward_graph(int64_t const b, float const attn_scale = 1.0f, bool const generate_stats = true, bool const causal_mask = false, - bool const alibi_mask = false, - bool const padding_mask = false, - bool has_attn_bias = false); + bool const padding_mask = false); // Directly use the backward graph builder from the toy example std::shared_ptr diff --git a/samples/cpp/sdpa/fp16_cached.cpp b/samples/cpp/sdpa/fp16_cached.cpp index e26bdd56..11511949 100644 --- a/samples/cpp/sdpa/fp16_cached.cpp +++ b/samples/cpp/sdpa/fp16_cached.cpp @@ -48,9 +48,7 @@ create_sdpa_forward_graph(int64_t const b, float const attn_scale = 1.0f, bool const generate_stats = true, bool const causal_mask = false, - bool const alibi_mask = false, - bool const padding_mask = false, - bool has_attn_bias = false); + bool const padding_mask = false); // Directly use the backward graph builder from the toy example std::shared_ptr diff --git a/samples/cpp/sdpa/fp16_dynamic_shapes.cpp b/samples/cpp/sdpa/fp16_dynamic_shapes.cpp new file mode 100644 index 00000000..51d0cb43 --- /dev/null +++ b/samples/cpp/sdpa/fp16_dynamic_shapes.cpp @@ -0,0 +1,256 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../utils/helpers.h" + +#include + +#include +namespace fe = cudnn_frontend; + +/* +Run this example by using command: +bin/samples "Toy sdpa forward" + +This example shows how to construct a sdpa forward graph. +*/ + +// Tensors in forward pass +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +#define STATS_UID 5 +#define BIAS_UID 6 +#define SEQ_LEN_Q_UID 7 +#define SEQ_LEN_KV_UID 8 + +static std::shared_ptr +create_sdpa_forward_graph(int64_t const b, + int64_t const h_q, + int64_t const h_k, + int64_t const h_v, + int64_t const s_q, + int64_t const s_kv, + int64_t const d_qk, + int64_t const d_v, + float const attn_scale = 1.0f, + bool const generate_stats = true, + bool const causal_mask = false, + bool const padding_mask = false) { + // Create a graph and set common global properties. + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::BFLOAT16) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT) + .set_dynamic_shape_enabled(true); + + auto Q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_uid(Q_UID) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1})); + + auto K = graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_uid(K_UID) + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1})); + + auto V = graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_uid(V_UID) + .set_dim({b, h_v, s_kv, d_v}) + .set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1})); + + auto sdpa_options = fe::graph::SDPA_attributes() + .set_name("flash_attention") + .set_generate_stats(generate_stats) + .set_attn_scale(attn_scale); + + if (causal_mask) { + sdpa_options.set_diagonal_alignment(cudnn_frontend::DiagonalAlignment_t::TOP_LEFT) + .set_diagonal_band_right_bound(0); + } + + if (padding_mask) { + auto seq_q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_uid(SEQ_LEN_Q_UID) + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto seq_kv = graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_uid(SEQ_LEN_KV_UID) + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_padding_mask(padding_mask).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); + } + + auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options); + + O->set_output(true).set_dim({b, h_q, s_q, d_v}).set_stride({h_q * d_v, d_v, b * h_q * d_v, 1}).set_uid(O_UID); + + if (generate_stats) { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_uid(STATS_UID); + } else { + assert(Stats == nullptr); + } + + return graph; +} + +TEST_CASE("Toy sdpa forward with dynamic shapes", "[graph][sdpa][flash][forward]") { + int64_t b = 2; // batch size + int64_t h_q = 4; // head dim + int64_t h_k = 4; // head dim + int64_t h_v = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length + int64_t d_qk = 128; // hidden dim + int64_t d_v = 128; // hidden dim + bool generate_stats = true; + float attn_scale = 0.123f; + bool causal_mask = true; + bool padding_mask = true; + +#if (CUDNN_VERSION < 91900) + SKIP("Test is disabled till backend is updated"); +#endif + + std::cout << "Running size: {" << b << ", " << h_q << ", " << h_k << ", " << h_v << ", " << s_q << ", " << s_kv + << ", " << d_qk << ", " << d_v << "}" << std::endl; + + // Create a unique_ptr for the cuDNN handle + auto handle_ptr = create_cudnn_handle(); + auto handle = *handle_ptr; + + auto graph = create_sdpa_forward_graph( + b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v, attn_scale, generate_stats, causal_mask, padding_mask); + + REQUIRE(graph->build(handle, {fe::HeurMode_t::A}).is_good()); + + //// Build variant pack + Surface q_tensor(b * h_q * s_q * d_qk, false); + Surface k_tensor(b * h_k * d_qk * s_kv, false); + Surface v_tensor(b * h_v * d_v * s_kv, false); + + Surface o_tensor(b * s_q * h_q * d_qk, false); + + std::unordered_map variant_pack = { + {Q_UID, q_tensor.devPtr}, {K_UID, k_tensor.devPtr}, {V_UID, v_tensor.devPtr}, {O_UID, o_tensor.devPtr}}; + + Surface devActualSeqlenQ(b, false); + Surface devActualSeqlenKV(b, false); + if (padding_mask) { + std::vector hostActualSeqlenQ(b, 20); + std::vector hostActualSeqlenKV(b, 20); + + CUDA_CHECK(cudaMemcpy(devActualSeqlenQ.devPtr, + hostActualSeqlenQ.data(), + sizeof(hostActualSeqlenQ[0]) * b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(devActualSeqlenKV.devPtr, + hostActualSeqlenKV.data(), + sizeof(hostActualSeqlenKV[0]) * b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + + variant_pack[SEQ_LEN_Q_UID] = devActualSeqlenQ.devPtr; + variant_pack[SEQ_LEN_KV_UID] = devActualSeqlenKV.devPtr; + } + + Surface statsTensor(b * h_q * s_q * 1, false); + if (generate_stats == true) { + variant_pack[STATS_UID] = statsTensor.devPtr; + } + + int64_t workspace_size = 0; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + workspace_size = 256 * 1024; + Surface workspace(workspace_size, false); + + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + + // Override shapes + + int64_t override_b = 4; + Surface q_tensor_2(override_b * h_q * s_q * d_qk, false); + Surface k_tensor_2(override_b * h_k * d_qk * s_kv, false); + Surface v_tensor_2(override_b * h_v * d_v * s_kv, false); + + Surface o_tensor_2(override_b * s_q * h_q * d_qk, false); + + std::unordered_map variant_pack_2 = { + {Q_UID, q_tensor_2.devPtr}, {K_UID, k_tensor_2.devPtr}, {V_UID, v_tensor_2.devPtr}, {O_UID, o_tensor_2.devPtr}}; + + Surface devActualSeqlenQ_2(override_b, false); + Surface devActualSeqlenKV_2(override_b, false); + if (padding_mask) { + std::vector hostActualSeqlenQ(override_b, 20); + std::vector hostActualSeqlenKV(override_b, 20); + + CUDA_CHECK(cudaMemcpy(devActualSeqlenQ_2.devPtr, + hostActualSeqlenQ.data(), + sizeof(hostActualSeqlenQ[0]) * override_b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(devActualSeqlenKV_2.devPtr, + hostActualSeqlenKV.data(), + sizeof(hostActualSeqlenKV[0]) * override_b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + + variant_pack_2[SEQ_LEN_Q_UID] = devActualSeqlenQ_2.devPtr; + variant_pack_2[SEQ_LEN_KV_UID] = devActualSeqlenKV_2.devPtr; + } + + Surface statsTensor_2(override_b * h_q * s_q * 1, false); + if (generate_stats == true) { + variant_pack_2[STATS_UID] = statsTensor_2.devPtr; + } + + std::cout << "Running size: {" << override_b << ", " << h_q << ", " << h_k << ", " << h_v << ", " << s_q << ", " + << s_kv << ", " << d_qk << ", " << d_v << "}" << std::endl; + + std::vector override_uids = {Q_UID, K_UID, V_UID, O_UID, SEQ_LEN_Q_UID, SEQ_LEN_KV_UID, STATS_UID}; + std::vector> override_shapes = {{override_b, h_q, s_q, d_qk}, + {override_b, h_k, s_kv, d_qk}, + {override_b, h_v, s_kv, d_v}, + {override_b, s_q, h_q, d_v}, + {override_b, 1, 1, 1}, + {override_b, 1, 1, 1}, + {override_b, h_q * s_q * 1, 1, 1}}; + std::vector> override_strides = {{h_q * s_q * d_qk, s_q * d_qk, d_qk, 1}, + {h_k * d_qk * s_kv, d_qk * s_kv, s_kv, 1}, + {h_v * d_v * s_kv, d_v * s_kv, s_kv, 1}, + {h_q * d_v, d_v, b * h_q * d_v, 1}, + {1, 1, 1, 1}, + {1, 1, 1, 1}, + {h_q * d_v, d_v, override_b * h_q * d_v, 1}}; + REQUIRE(graph->execute(handle, variant_pack_2, workspace.devPtr, override_uids, override_shapes, override_strides) + .is_good()); + + CUDA_CHECK(cudaDeviceSynchronize()); +} diff --git a/samples/cpp/sdpa/fp16_fwd.cpp b/samples/cpp/sdpa/fp16_fwd.cpp index 26457d30..76e4e155 100644 --- a/samples/cpp/sdpa/fp16_fwd.cpp +++ b/samples/cpp/sdpa/fp16_fwd.cpp @@ -57,9 +57,7 @@ create_sdpa_forward_graph(int64_t const b, float const attn_scale = 1.0f, bool const generate_stats = true, bool const causal_mask = false, - bool const alibi_mask = false, - bool const padding_mask = false, - bool has_attn_bias = false) { + bool const padding_mask = false) { // Create a graph and set common global properties. auto graph = std::make_shared(); graph->set_io_data_type(fe::DataType_t::BFLOAT16) @@ -87,7 +85,6 @@ create_sdpa_forward_graph(int64_t const b, auto sdpa_options = fe::graph::SDPA_attributes() .set_name("flash_attention") .set_generate_stats(generate_stats) - .set_alibi_mask(alibi_mask) .set_attn_scale(attn_scale); if (causal_mask) { @@ -95,16 +92,6 @@ create_sdpa_forward_graph(int64_t const b, .set_diagonal_band_right_bound(0); } - if (has_attn_bias) { - auto bias = graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_uid(BIAS_UID) - .set_data_type(fe::DataType_t::HALF) - .set_dim({b, 1, s_q, s_kv}) - .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); - sdpa_options.set_bias(bias); - } - if (padding_mask) { auto seq_q = graph->tensor(fe::graph::Tensor_attributes() .set_name("seq_q") @@ -147,8 +134,6 @@ TEST_CASE("Toy sdpa forward", "[graph][sdpa][flash][forward]") { float attn_scale = 0.123f; bool causal_mask = true; bool padding_mask = (cudnnGetVersion() >= 8903); - bool alibi_mask = (cudnnGetVersion() >= 8904); - bool has_attn_bias = (cudnnGetVersion() >= 8903); if (cudnnGetVersion() < 8903) { SKIP("Test requires cudnn 8.9.3 or above"); @@ -159,20 +144,8 @@ TEST_CASE("Toy sdpa forward", "[graph][sdpa][flash][forward]") { auto handle_ptr = create_cudnn_handle(); auto handle = *handle_ptr; - auto graph = create_sdpa_forward_graph(b, - h_q, - h_k, - h_v, - s_q, - s_kv, - d_qk, - d_v, - attn_scale, - generate_stats, - causal_mask, - alibi_mask, - padding_mask, - has_attn_bias); + auto graph = create_sdpa_forward_graph( + b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v, attn_scale, generate_stats, causal_mask, padding_mask); REQUIRE(graph->build(handle, {fe::HeurMode_t::A}).is_good()); @@ -187,9 +160,6 @@ TEST_CASE("Toy sdpa forward", "[graph][sdpa][flash][forward]") { {Q_UID, q_tensor.devPtr}, {K_UID, k_tensor.devPtr}, {V_UID, v_tensor.devPtr}, {O_UID, o_tensor.devPtr}}; Surface bias_tensor(b * 1 * s_q * s_kv, false); - if (has_attn_bias) { - variant_pack[BIAS_UID] = bias_tensor.devPtr; - } Surface devActualSeqlenQ(b, false); Surface devActualSeqlenKV(b, false); diff --git a/samples/cpp/sdpa/fp16_fwd_with_cudagraphs.cpp b/samples/cpp/sdpa/fp16_fwd_with_cudagraphs.cpp index c9b57328..f7b09e1a 100644 --- a/samples/cpp/sdpa/fp16_fwd_with_cudagraphs.cpp +++ b/samples/cpp/sdpa/fp16_fwd_with_cudagraphs.cpp @@ -62,9 +62,7 @@ create_sdpa_forward_graph(int64_t const b, float const attn_scale = 1.0f, bool const generate_stats = true, bool const causal_mask = false, - bool const alibi_mask = false, - bool const padding_mask = false, - bool has_attn_bias = false); + bool const padding_mask = false); // Convenience class to encapsulate SDPA test data for this example class SdpaTestData { @@ -80,7 +78,6 @@ class SdpaTestData { int64_t const workspace_size, bool const generate_stats, bool const padding_mask, - bool const has_attn_bias, float const qkv_fill_value) : q_tensor(b * h_q * s_q * d_qk, false, cpu_float2half_rn(qkv_fill_value)), k_tensor(b * h_k * d_qk * s_kv, false, cpu_float2half_rn(qkv_fill_value)), @@ -92,8 +89,7 @@ class SdpaTestData { statsTensor(b * h_q * s_q * 1, false), workspace(workspace_size, false), generate_stats_(generate_stats), - padding_mask_(padding_mask), - has_attn_bias_(has_attn_bias) {} + padding_mask_(padding_mask) {} std::unordered_map build_variant_pack() { @@ -102,9 +98,6 @@ class SdpaTestData { variant_pack[K_UID] = k_tensor.devPtr; variant_pack[V_UID] = v_tensor.devPtr; variant_pack[O_UID] = o_tensor.devPtr; - if (has_attn_bias_) { - variant_pack[BIAS_UID] = bias_tensor.devPtr; - } if (padding_mask_) { variant_pack[SEQ_LEN_Q_UID] = devActualSeqlenQ.devPtr; variant_pack[SEQ_LEN_KV_UID] = devActualSeqlenKV.devPtr; @@ -168,7 +161,6 @@ class SdpaTestData { Surface workspace; bool generate_stats_; bool padding_mask_; - bool has_attn_bias_; }; TEST_CASE("Toy sdpa forward as CUDA graph", "[graph][sdpa][flash][forward][cudagraph]") { @@ -176,9 +168,6 @@ TEST_CASE("Toy sdpa forward as CUDA graph", "[graph][sdpa][flash][forward][cudag // Because the below test depends on some CUDA graph APIs that changed // between CUDA 11.x and 12.0, it wouldn't even compile in <12.0 anyway, // so we just disable the whole test by #if in that case. -#if (CUDART_VERSION < 12000) - SKIP("Test requires cuda toolkit 12.0 or above"); -#else // Also check the CUDA version at runtime, for good measure. if (cudnnGetCudartVersion() < 12000) { SKIP("Test requires cuda toolkit 12.0 or above"); @@ -202,27 +191,13 @@ TEST_CASE("Toy sdpa forward as CUDA graph", "[graph][sdpa][flash][forward][cudag float attn_scale = 0.123f; bool causal_mask = true; bool padding_mask = (cudnnGetVersion() >= 8903); - bool alibi_mask = false; // TODO: (cudnnGetVersion() >= 8904) - bool has_attn_bias = (cudnnGetVersion() >= 8903); // Create a unique_ptr for the cuDNN handle auto handle_ptr = create_cudnn_handle(); auto handle = *handle_ptr; - auto graph = create_sdpa_forward_graph(b, - h_q, - h_k, - h_v, - s_q, - s_kv, - d_qk, - d_v, - attn_scale, - generate_stats, - causal_mask, - alibi_mask, - padding_mask, - has_attn_bias); + auto graph = create_sdpa_forward_graph( + b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v, attn_scale, generate_stats, causal_mask, padding_mask); // Validate the graph and lower the FE graph to BE graph REQUIRE(graph->validate().is_good()); @@ -254,7 +229,6 @@ TEST_CASE("Toy sdpa forward as CUDA graph", "[graph][sdpa][flash][forward][cudag workspace_size, generate_stats, padding_mask, - has_attn_bias, /*fillValue_qkv=*/1.1f); auto variant_pack_1 = test_data_1.build_variant_pack(); @@ -284,7 +258,6 @@ TEST_CASE("Toy sdpa forward as CUDA graph", "[graph][sdpa][flash][forward][cudag workspace_size, generate_stats, padding_mask, - has_attn_bias, /*fillValue_qkv=*/1.3f); auto variant_pack_3 = test_data_3.build_variant_pack(); REQUIRE( @@ -306,5 +279,4 @@ TEST_CASE("Toy sdpa forward as CUDA graph", "[graph][sdpa][flash][forward][cudag //// Cleanup CUDA_CHECK(cudaGraphExecDestroy(cuda_graph_exec)); CUDA_CHECK(cudaGraphDestroy(cudnn_cuda_graph)); -#endif // CUDART_VERSION < 12000 } diff --git a/setup.py b/setup.py index 3a7057b7..ac9ba60d 100644 --- a/setup.py +++ b/setup.py @@ -66,9 +66,7 @@ def build_extension(self, ext: CMakeExtension) -> None: cmake_args.append(f"-DCUDNN_PATH={os.environ['CUDNN_PATH']}") if "FETCHCONTENT_SOURCE_DIR_DLPACK" in os.environ: - cmake_args.append( - f"-DFETCHCONTENT_SOURCE_DIR_DLPACK={os.environ['FETCHCONTENT_SOURCE_DIR_DLPACK']}" - ) + cmake_args.append(f"-DFETCHCONTENT_SOURCE_DIR_DLPACK={os.environ['FETCHCONTENT_SOURCE_DIR_DLPACK']}") # Using Ninja-build since it a) is available as a wheel and b) # multithreads automatically. MSVC would require all variables be @@ -104,12 +102,8 @@ def build_extension(self, ext: CMakeExtension) -> None: build_temp.mkdir(parents=True) print(" ".join(cmake_args)) - subprocess.run( - ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True - ) - subprocess.run( - ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True - ) + subprocess.run(["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True) + subprocess.run(["cmake", "--build", ".", *build_args], cwd=build_temp, check=True) setup( diff --git a/test/python/conftest.py b/test/python/conftest.py index 672eb3de..b5f21db7 100644 --- a/test/python/conftest.py +++ b/test/python/conftest.py @@ -26,7 +26,9 @@ def cudnn_handle(): # =================== PyTest Hooks ===================== def pytest_load_initial_conftests(args, early_config, parser): if not any(arg.startswith("--tb=") for arg in args): - args.append("--tb=short") + args.insert(0, "--tb=short") + if "--no-header" not in args: + args.insert(0, "--no-header") def pytest_configure(config): @@ -52,6 +54,8 @@ def pytest_addoption(parser): parser.addoption("--dryrun", action="store", nargs="?", const=1, type=int, default=0, help="show repro commands when 1, 2, or 3 (use with '-s')") parser.addoption("--diffs", action="store", type=int, default=10, help="set number of numerical mismatches to display") parser.addoption("--repro", action="store", type=str, default=None, help="specify config string to run repro function") + parser.addoption("--seed", action="store", type=int, default=None, help="[fuzzer] random seed for reproducibility") + parser.addoption("--num-tests", action="store", type=int, default=100, help="[fuzzer] number of random tests to run") parser.addoption("--perf", action="store_true", help="enable performance profiling") # MHA command line options to overwrite specific test dimensions in test_mhas.py and test_mhas_v2.py. @@ -93,4 +97,10 @@ def pytest_addoption(parser): parser.addoption("--gemm-amax-mma-tiler", action="store", default=None, type=str, help="[test_gemm_amax.py] MMA tiler (M,N) dimensions as comma-separated values (e.g., '128,128')") parser.addoption("--gemm-amax-cluster-shape", action="store", default=None, type=str, help="[test_gemm_amax.py] Cluster shape (M,N) dimensions as comma-separated values (e.g., '1,1')") parser.addoption("--gemm-amax-skip-ref", action="store_true", help="[test_gemm_amax.py] Skip reference computation for performance testing") + + # Grouped GEMM SwiGLU command line options for test_grouped_gemm_swiglu.py + parser.addoption("--grouped-gemm-nkl", action="store", default=None, type=str, help="[test_grouped_gemm_swiglu.py] N,K,L dimensions as comma-separated values (e.g., '512,512,4')") + parser.addoption("--grouped-gemm-group-m", action="store", default=None, type=str, help="[test_grouped_gemm_swiglu.py] M values per group as comma-separated values (e.g., '256,512,256,256')") + parser.addoption("--grouped-gemm-m-aligned", action="store", default=None, type=int, help="[test_grouped_gemm_swiglu.py] M alignment (e.g., 256)") + parser.addoption("--grouped-gemm-skip-ref", action="store_true", help="[test_grouped_gemm_swiglu.py] Skip reference computation for performance testing") # fmt: on diff --git a/test/python/fe_api/nsa/nsa_reference.py b/test/python/fe_api/nsa/nsa_reference.py index 36d6e824..3e658d0d 100644 --- a/test/python/fe_api/nsa/nsa_reference.py +++ b/test/python/fe_api/nsa/nsa_reference.py @@ -19,9 +19,7 @@ def convert_thd_to_bshd(thd_tensor, seq_len: torch.Tensor, s: int): b = seq_len.size(0) seq_len = seq_len.flatten() - bshd_tensor = torch.zeros( - (b, s, h, d), dtype=thd_tensor.dtype, device=thd_tensor.device - ) + bshd_tensor = torch.zeros((b, s, h, d), dtype=thd_tensor.dtype, device=thd_tensor.device) cumulative_seq_len = torch.cumsum(seq_len, dim=0) - seq_len for bi in range(b): @@ -43,9 +41,7 @@ def convert_bshd_to_thd(bshd_tensor, seq_len: torch.Tensor, maxT: int): assert seq_len.size(1) == seq_len.size(2) == seq_len.size(3) == 1 seq_len = seq_len.flatten() - thd_tensor = torch.zeros( - (maxT, h, d), dtype=bshd_tensor.dtype, device=bshd_tensor.device - ) + thd_tensor = torch.zeros((maxT, h, d), dtype=bshd_tensor.dtype, device=bshd_tensor.device) # Interpret input as (b, s, h, d) in memory while keeping the (b, h, s, d) layout bshd_base = bshd_tensor.permute(0, 2, 1, 3) @@ -142,16 +138,10 @@ def run_ref_nsa_selection_attention( dtype=torch.float32, ) seq_block_counts = block_counts[seq_offset:seq_end, h] # [seq_len] - seq_block_indices = block_indices[ - seq_offset:seq_end, h, : - ] # [seq_len, topk_size] + seq_block_indices = block_indices[seq_offset:seq_end, h, :] # [seq_len, topk_size] topk_size = seq_block_indices.size(-1) - block_range = torch.arange(topk_size, device=mask.device).unsqueeze( - 0 - ) # [1, topk_size] - valid_mask = block_range < seq_block_counts.unsqueeze( - 1 - ) # [seq_len, topk_size] + block_range = torch.arange(topk_size, device=mask.device).unsqueeze(0) # [1, topk_size] + valid_mask = block_range < seq_block_counts.unsqueeze(1) # [seq_len, topk_size] query_indices, block_indices_flat = torch.where(valid_mask) if len(query_indices) > 0: @@ -163,27 +153,17 @@ def run_ref_nsa_selection_attention( max_block_size = block_sizes.max().item() if len(block_sizes) > 0 else 0 if max_block_size > 0: - offsets = torch.arange( - max_block_size, device=mask.device - ) # [max_block_size] + offsets = torch.arange(max_block_size, device=mask.device) # [max_block_size] num_blocks = len(block_ids) - offsets_expanded = offsets.unsqueeze(0).expand( - num_blocks, -1 - ) # [num_blocks, max_block_size] + offsets_expanded = offsets.unsqueeze(0).expand(num_blocks, -1) # [num_blocks, max_block_size] block_sizes_expanded = block_sizes.unsqueeze(1) # [num_blocks, 1] token_starts_expanded = token_starts.unsqueeze(1) # [num_blocks, 1] - query_indices_expanded = query_indices.unsqueeze( - 1 - ) # [num_blocks, 1] + query_indices_expanded = query_indices.unsqueeze(1) # [num_blocks, 1] - position_valid = ( - offsets_expanded < block_sizes_expanded - ) # [num_blocks, max_block_size] + position_valid = offsets_expanded < block_sizes_expanded # [num_blocks, max_block_size] - token_positions = ( - token_starts_expanded + offsets_expanded - ) # [num_blocks, max_block_size] + token_positions = token_starts_expanded + offsets_expanded # [num_blocks, max_block_size] valid_positions = torch.where(position_valid) if len(valid_positions[0]) > 0: @@ -196,39 +176,25 @@ def run_ref_nsa_selection_attention( mask[final_query_indices, final_key_indices] = 0.0 # Step 3: Apply mask to attention scores - qk_scores_fp32 = qk_scores.float() + mask.unsqueeze( - 1 - ) # [seq_len, 1, seq_len] -> [seq_len, GQA_group_size, seq_len] + qk_scores_fp32 = qk_scores.float() + mask.unsqueeze(1) # [seq_len, 1, seq_len] -> [seq_len, GQA_group_size, seq_len] # Step 4: Compute softmax - qk_max = torch.max(qk_scores_fp32, dim=-1, keepdim=True)[ - 0 - ] # [seq_len, GQA_group_size, 1] - qk_exp = torch.exp( - qk_scores_fp32 - qk_max - ) # [seq_len, GQA_group_size, seq_len] - qk_sum = torch.sum( - qk_exp, dim=-1, keepdim=True - ) # [seq_len, GQA_group_size, 1] + qk_max = torch.max(qk_scores_fp32, dim=-1, keepdim=True)[0] # [seq_len, GQA_group_size, 1] + qk_exp = torch.exp(qk_scores_fp32 - qk_max) # [seq_len, GQA_group_size, seq_len] + qk_sum = torch.sum(qk_exp, dim=-1, keepdim=True) # [seq_len, GQA_group_size, 1] attn_weights = qk_exp / qk_sum # [seq_len, GQA_group_size, seq_len] # Step 5: Compute output O = attention_weights @ V # attn_weights: [seq_len, GQA_group_size, seq_len] @ v_seq: [seq_len, d_v] -> [seq_len, GQA_group_size, d_v] - output = torch.matmul( - attn_weights, v_seq.float() - ) # [seq_len, GQA_group_size, d_v] + output = torch.matmul(attn_weights, v_seq.float()) # [seq_len, GQA_group_size, d_v] # Store results O[seq_offset:seq_end, h, :, :] = output.to(dtype) # Store L (sum of exp) and M (max) statistics - reusing computed values # L should store the sum of exponentials (row_sum), not logsumexp, to match reference - L[seq_offset:seq_end, h, :] = qk_sum.squeeze( - -1 - ) # [seq_len, GQA_group_size] - M[seq_offset:seq_end, h, :] = qk_max.squeeze( - -1 - ) # [seq_len, GQA_group_size] + L[seq_offset:seq_end, h, :] = qk_sum.squeeze(-1) # [seq_len, GQA_group_size] + M[seq_offset:seq_end, h, :] = qk_max.squeeze(-1) # [seq_len, GQA_group_size] seq_offset = seq_end @@ -287,9 +253,7 @@ def run_ref_nsa_compression_attention( q_coords = torch.arange(0, s_q_i, device=s_i.device).view(-1, 1) num_compress_blocks = s_k_i stride = max(1, s_q_i // max(1, s_k_i)) - k_coords = ( - ((torch.arange(0, num_compress_blocks, device=s_i.device) + 1) * stride) - 1 - ).view(1, -1) + k_coords = (((torch.arange(0, num_compress_blocks, device=s_i.device) + 1) * stride) - 1).view(1, -1) _mask = k_coords > q_coords s_i = s_i.masked_fill(_mask, -torch.inf) @@ -469,19 +433,11 @@ def check_ref_nsa_compression_attention( ) return scale_softmax = ( - scale_softmax - if scale_softmax is not None - else ( - test_config["scale_softmax"] - if test_config is not None - else 1.0 / math.sqrt(test_config["d_qk"]) - ) + scale_softmax if scale_softmax is not None else (test_config["scale_softmax"] if test_config is not None else 1.0 / math.sqrt(test_config["d_qk"])) ) if test_config["layout"] == "thd": - assert ( - "actual_s_q" in test_config - ), "actual_s_q is required when using T,H,D layout" + assert "actual_s_q" in test_config, "actual_s_q is required when using T,H,D layout" seq_len_q = test_config["actual_s_q"].to(device=Q.device) max_seq_len_q = int(seq_len_q.max().item()) @@ -501,9 +457,7 @@ def check_ref_nsa_compression_attention( # Convert O_ref back to THD for comparison total_T = int(seq_len_q.sum().item()) - O_ref_thd = convert_bshd_to_thd(O_ref_bshd, seq_len_q, total_T).to( - dtype=O.dtype - ) + O_ref_thd = convert_bshd_to_thd(O_ref_bshd, seq_len_q, total_T).to(dtype=O.dtype) torch.testing.assert_close(O, O_ref_thd, atol=atol, rtol=rtol) if LSE is not None: diff --git a/test/python/fe_api/nsa/nsa_utils.py b/test/python/fe_api/nsa/nsa_utils.py index e517a23a..957cd5cf 100644 --- a/test/python/fe_api/nsa/nsa_utils.py +++ b/test/python/fe_api/nsa/nsa_utils.py @@ -95,15 +95,9 @@ def nsa_init( ): major, _ = torch.cuda.get_device_capability() if major < 10: - pytest.skip( - f"Environment not supported: requires compute capability >= 10, found {major}" - ) + pytest.skip(f"Environment not supported: requires compute capability >= 10, found {major}") - b = ( - int(request.config.getoption("--nsa-b")) - if request.config.getoption("--nsa-b") is not None - else 2 - ) + b = int(request.config.getoption("--nsa-b")) if request.config.getoption("--nsa-b") is not None else 2 s_q = ( int(request.config.getoption("--nsa-s_q")) if request.config.getoption("--nsa-s_q") is not None @@ -114,43 +108,15 @@ def nsa_init( if request.config.getoption("--nsa-s_kv") is not None else 1024 if s_kv_default_override is None else s_kv_default_override ) - d_qk = ( - int(request.config.getoption("--nsa-d_qk")) - if request.config.getoption("--nsa-d_qk") is not None - else 128 - ) - d_v = ( - int(request.config.getoption("--nsa-d_v")) - if request.config.getoption("--nsa-d_v") is not None - else 128 - ) - h_q = ( - int(request.config.getoption("--nsa-h_q")) - if request.config.getoption("--nsa-h_q") is not None - else 4 - ) - h_k = ( - int(request.config.getoption("--nsa-h_k")) - if request.config.getoption("--nsa-h_k") is not None - else 1 - ) - h_v = ( - int(request.config.getoption("--nsa-h_v")) - if request.config.getoption("--nsa-h_v") is not None - else 1 - ) + d_qk = int(request.config.getoption("--nsa-d_qk")) if request.config.getoption("--nsa-d_qk") is not None else 128 + d_v = int(request.config.getoption("--nsa-d_v")) if request.config.getoption("--nsa-d_v") is not None else 128 + h_q = int(request.config.getoption("--nsa-h_q")) if request.config.getoption("--nsa-h_q") is not None else 4 + h_k = int(request.config.getoption("--nsa-h_k")) if request.config.getoption("--nsa-h_k") is not None else 1 + h_v = int(request.config.getoption("--nsa-h_v")) if request.config.getoption("--nsa-h_v") is not None else 1 - actual_s_q = ( - torch.tensor([s_q] * b, dtype=torch.int32).cuda() if layout == "thd" else None - ) - actual_s_kv = ( - torch.tensor([s_kv] * b, dtype=torch.int32).cuda() if layout == "thd" else None - ) - topk_sizes = ( - torch.tensor([topk_size] * b, dtype=torch.int32).cuda() - if (layout == "thd" and topk_size is not None) - else None - ) + actual_s_q = torch.tensor([s_q] * b, dtype=torch.int32).cuda() if layout == "thd" else None + actual_s_kv = torch.tensor([s_kv] * b, dtype=torch.int32).cuda() if layout == "thd" else None + topk_sizes = torch.tensor([topk_size] * b, dtype=torch.int32).cuda() if (layout == "thd" and topk_size is not None) else None scale_softmax = 1.0 / math.sqrt(d_qk) if scale_softmax is None else scale_softmax @@ -223,26 +189,12 @@ def allocate_input_tensors(cfg): Q = torch.randn(b, s_q, h_q, d_qk, dtype=dtype).transpose(1, 2).cuda() K = torch.randn(b, s_kv, h_k, d_qk, dtype=dtype).transpose(1, 2).cuda() V = torch.randn(b, s_kv, h_k, d_v, dtype=dtype).transpose(1, 2).cuda() - LSE = ( - -1.0 - * torch.randn(b, s_q, h_q, dtype=torch.float32) - .transpose(1, 2) - .contiguous() - .cuda() - ) + LSE = -1.0 * torch.randn(b, s_q, h_q, dtype=torch.float32).transpose(1, 2).contiguous().cuda() block_counts, block_indices = None, None # TODO elif layout == "thd": - cum_seqlen_q = ( - torch.cat([torch.tensor([0]).cuda(), torch.cumsum(actual_s_q, dim=0)]) - .to(torch.int32) - .cuda() - ) - cum_seqlen_kv = ( - torch.cat([torch.tensor([0]).cuda(), torch.cumsum(actual_s_kv, dim=0)]) - .to(torch.int32) - .cuda() - ) + cum_seqlen_q = torch.cat([torch.tensor([0]).cuda(), torch.cumsum(actual_s_q, dim=0)]).to(torch.int32).cuda() + cum_seqlen_kv = torch.cat([torch.tensor([0]).cuda(), torch.cumsum(actual_s_kv, dim=0)]).to(torch.int32).cuda() max_s_q = max(actual_s_q).item() max_s_kv = max(actual_s_kv).item() @@ -255,12 +207,7 @@ def allocate_input_tensors(cfg): # V: (T, H_kv, D_v) V = torch.randn((total_seq_len_kv, h_k, d_v), dtype=dtype).cuda() # LSE: (T, H_q, 1) - LSE = ( - -1.0 - * torch.randn((1, h_q, total_seq_len_q), dtype=torch.float32) - .transpose(0, 2) - .cuda() - ) + LSE = -1.0 * torch.randn((1, h_q, total_seq_len_q), dtype=torch.float32).transpose(0, 2).cuda() # block_counts: (T, H_kv), block_indices: (T, H_kv, max(topk_sizes)) block_counts, block_indices = None, None # TODO @@ -299,16 +246,8 @@ def allocate_output_tensors(cfg): M = torch.empty(b, s_q, h_q, 1, dtype=torch.float32).transpose(1, 2).cuda() if k_value is not None: - topk_scores = ( - torch.empty(b, s_q, h_k, k_value, dtype=acc_dtype) - .transpose(1, 2) - .cuda() - ) - topk_indices = ( - torch.empty(b, s_q, h_k, k_value, dtype=torch.int32) - .transpose(1, 2) - .cuda() - ) + topk_scores = torch.empty(b, s_q, h_k, k_value, dtype=acc_dtype).transpose(1, 2).cuda() + topk_indices = torch.empty(b, s_q, h_k, k_value, dtype=torch.int32).transpose(1, 2).cuda() elif layout == "thd": total_seq_len = actual_s_q.sum().item() @@ -317,12 +256,8 @@ def allocate_output_tensors(cfg): M = torch.empty(total_seq_len, h_q, 1, dtype=torch.float32).cuda() if k_value is not None: - topk_scores = torch.empty( - total_seq_len, h_k, k_value, dtype=acc_dtype - ).cuda() - topk_indices = torch.empty( - total_seq_len, h_k, k_value, dtype=torch.int32 - ).cuda() + topk_scores = torch.empty(total_seq_len, h_k, k_value, dtype=acc_dtype).cuda() + topk_indices = torch.empty(total_seq_len, h_k, k_value, dtype=torch.int32).cuda() return ( O, @@ -384,9 +319,7 @@ def generate_ragged_offset(cfg): ) -def generate_block_indices( - seq_lens: list[int], num_kv_heads: int, topk_sizes: list[int], block_size: int -): +def generate_block_indices(seq_lens: list[int], num_kv_heads: int, topk_sizes: list[int], block_size: int): """ Generate block indices and counts for sparse attention. @@ -402,23 +335,17 @@ def generate_block_indices( total_seq_len = sum(seq_lens) max_topk_size = max(topk_sizes) block_counts = torch.zeros(total_seq_len, num_kv_heads, dtype=torch.int32) - block_indices = torch.zeros( - total_seq_len, num_kv_heads, max_topk_size, dtype=torch.int32 - ) + block_indices = torch.zeros(total_seq_len, num_kv_heads, max_topk_size, dtype=torch.int32) seq_len_offset = 0 for i in range(len(seq_lens)): seq_len = seq_lens[i] topk_size = topk_sizes[i] max_index = seq_len // block_size - assert ( - topk_size <= max_index - ), "topk_size must be less than or equal to the number of blocks" + assert topk_size <= max_index, "topk_size must be less than or equal to the number of blocks" for t in range(seq_len): for h in range(num_kv_heads): - block_indices[seq_len_offset + t, h, :topk_size] = ( - torch.randperm(max_index)[:topk_size].sort().values - ) + block_indices[seq_len_offset + t, h, :topk_size] = torch.randperm(max_index)[:topk_size].sort().values block_counts[seq_len_offset + t, h] = topk_size seq_len_offset += seq_len diff --git a/test/python/fe_api/nsa/test_NSA_compression_attention.py b/test/python/fe_api/nsa/test_NSA_compression_attention.py index f55ba672..fff4997e 100644 --- a/test/python/fe_api/nsa/test_NSA_compression_attention.py +++ b/test/python/fe_api/nsa/test_NSA_compression_attention.py @@ -32,9 +32,7 @@ def test_nsa_compression_compile_execute( from cudnn import NSA from cuda.bindings import driver as cuda except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = nsa_init( request=request, @@ -50,9 +48,7 @@ def test_nsa_compression_compile_execute( scale_softmax=scale_softmax, ) - Q, K, V, _, _, _, cum_seqlen_q, cum_seqlen_k, max_s_q, max_s_k = ( - allocate_input_tensors(cfg) - ) + Q, K, V, _, _, _, cum_seqlen_q, cum_seqlen_k, max_s_q, max_s_k = allocate_input_tensors(cfg) O, LSE, _, _, _ = allocate_output_tensors(cfg) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -128,9 +124,7 @@ def test_nsa_compression_wrapper( from cudnn import NSA from cuda.bindings import driver as cuda except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = nsa_init( request=request, @@ -146,9 +140,7 @@ def test_nsa_compression_wrapper( scale_softmax=scale_softmax, ) - Q, K, V, _, _, _, cum_seqlen_q, cum_seqlen_k, max_s_q, max_s_k = ( - allocate_input_tensors(cfg) - ) + Q, K, V, _, _, _, cum_seqlen_q, cum_seqlen_k, max_s_q, max_s_k = allocate_input_tensors(cfg) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) O, LSE = NSA.compression_attention_wrapper( diff --git a/test/python/fe_api/nsa/test_NSA_selection_attention.py b/test/python/fe_api/nsa/test_NSA_selection_attention.py index 68c5f792..105bfd15 100644 --- a/test/python/fe_api/nsa/test_NSA_selection_attention.py +++ b/test/python/fe_api/nsa/test_NSA_selection_attention.py @@ -33,13 +33,9 @@ def test_nsa_selection_compile_execute( from cudnn import NSA from cuda.bindings import driver as cuda except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") if layout != "thd": - pytest.skip( - "Only THD layout supported for selection attention, bshd layout not yet implemented" - ) + pytest.skip("Only THD layout supported for selection attention, bshd layout not yet implemented") cfg = nsa_init( request=request, @@ -50,9 +46,7 @@ def test_nsa_selection_compile_execute( block_size=block_size, ) - Q, K, V, _, actual_s_q, _, cum_seqlen_q, cum_seqlen_kv, max_s_q, max_s_kv = ( - allocate_input_tensors(cfg) - ) + Q, K, V, _, actual_s_q, _, cum_seqlen_q, cum_seqlen_kv, max_s_q, max_s_kv = allocate_input_tensors(cfg) block_counts, block_indices = generate_block_indices( cfg["actual_s_q"], cfg["h_k"], @@ -129,14 +123,10 @@ def test_nsa_selection_wrapper( from cudnn import NSA from cuda.bindings import driver as cuda except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") if layout != "thd": - pytest.skip( - "Only THD layout supported for selection attention, bshd layout not yet implemented" - ) + pytest.skip("Only THD layout supported for selection attention, bshd layout not yet implemented") cfg = nsa_init( request=request, @@ -147,9 +137,7 @@ def test_nsa_selection_wrapper( block_size=block_size, ) - Q, K, V, _, actual_s_q, _, cum_seqlen_q, cum_seqlen_kv, max_s_q, max_s_kv = ( - allocate_input_tensors(cfg) - ) + Q, K, V, _, actual_s_q, _, cum_seqlen_q, cum_seqlen_kv, max_s_q, max_s_kv = allocate_input_tensors(cfg) block_counts, block_indices = generate_block_indices( cfg["actual_s_q"], cfg["h_k"], diff --git a/test/python/fe_api/nsa/test_NSA_swa.py b/test/python/fe_api/nsa/test_NSA_swa.py index 1804a772..d9522321 100644 --- a/test/python/fe_api/nsa/test_NSA_swa.py +++ b/test/python/fe_api/nsa/test_NSA_swa.py @@ -29,9 +29,7 @@ def test_nsa_swa_compile_execute( try: from cudnn import NSA except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = nsa_init( request=request, layout=layout, @@ -115,9 +113,7 @@ def test_nsa_swa_wrapper( try: from cudnn import NSA except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = nsa_init( request=request, layout=layout, diff --git a/test/python/fe_api/nsa/test_NSA_topk_reduction.py b/test/python/fe_api/nsa/test_NSA_topk_reduction.py index a022b424..e17d747d 100644 --- a/test/python/fe_api/nsa/test_NSA_topk_reduction.py +++ b/test/python/fe_api/nsa/test_NSA_topk_reduction.py @@ -30,9 +30,7 @@ def test_nsa_topk_reduction_compile_execute( from cudnn import NSA from cuda.bindings import driver as cuda except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = nsa_init( request=request, @@ -48,9 +46,7 @@ def test_nsa_topk_reduction_compile_execute( s_kv_default_override=128, ) - Q, K, _, LSE, _, _, cum_seqlen_q, cum_seqlen_kv, max_s_q, max_s_kv = ( - allocate_input_tensors(cfg) - ) + Q, K, _, LSE, _, _, cum_seqlen_q, cum_seqlen_kv, max_s_q, max_s_kv = allocate_input_tensors(cfg) _, _, _, topk_scores, topk_indices = allocate_output_tensors(cfg) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -106,9 +102,7 @@ def test_nsa_topk_reduction_wrapper( from cudnn import NSA from cuda.bindings import driver as cuda except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = nsa_init( request=request, @@ -124,9 +118,7 @@ def test_nsa_topk_reduction_wrapper( s_kv_default_override=128, ) - Q, K, _, LSE, _, _, cum_seqlen_q, cum_seqlen_kv, max_s_q, max_s_kv = ( - allocate_input_tensors(cfg) - ) + Q, K, _, LSE, _, _, cum_seqlen_q, cum_seqlen_kv, max_s_q, max_s_kv = allocate_input_tensors(cfg) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) topk_scores, topk_indices = NSA.topk_reduction_wrapper( diff --git a/test/python/fe_api/test_fe_api_utils.py b/test/python/fe_api/test_fe_api_utils.py index 8459ac07..ed212792 100644 --- a/test/python/fe_api/test_fe_api_utils.py +++ b/test/python/fe_api/test_fe_api_utils.py @@ -29,6 +29,11 @@ def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( cvt_sf_MKL_to_M32x4xrm_K4xrk_L = None +def ceil_div(a: int, b: int) -> int: + """Compute ceiling division of a by b.""" + return (a + b - 1) // b + + def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype): # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) # else: (l, mode0, mode1) -> (mode0, mode1, l) @@ -47,38 +52,44 @@ def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype): # Generate random values according to dtype support if dtype in {torch.int8, torch.int16, torch.int32, torch.int64}: - ref_tensor = torch.randint( - int(min_val), int(max_val), shape, dtype=torch.int32, device="cuda" - ).permute(permute_order) + ref_tensor = torch.randint(int(min_val), int(max_val), shape, dtype=torch.int32, device="cuda").permute(permute_order) dtype_tensor = ref_tensor.to(dtype) if dtype not in {torch.float4_e2m1fn_x2, torch.uint8}: - dtype_tensor = ( - torch.empty(shape, dtype=torch.float32, device="cuda") - .uniform_(float(min_val), float(max_val)) - .permute(permute_order) - .to(dtype) - ) + dtype_tensor = torch.empty(shape, dtype=torch.float32, device="cuda").uniform_(float(min_val), float(max_val)).permute(permute_order).to(dtype) ref_tensor = dtype_tensor.to(torch.float32) else: dtype_tensor = _bfloat16_to_float4_e2m1fn_x2( - torch.empty(shape, dtype=torch.float32, device="cuda") - .uniform_(float(min_val), float(max_val)) - .to(torch.bfloat16) - ) - ref_tensor = ( - float4_e2m1fn_x2_to_float32(dtype_tensor) - .to(torch.float32) - .permute(permute_order) + torch.empty(shape, dtype=torch.float32, device="cuda").uniform_(float(min_val), float(max_val)).to(torch.bfloat16) ) + ref_tensor = float4_e2m1fn_x2_to_float32(dtype_tensor).to(torch.float32).permute(permute_order) dtype_tensor = dtype_tensor.permute(permute_order).view(dtype) return ref_tensor, dtype_tensor -def create_sf_layout_tensor(l, mn, nk, sf_vec_size): - def ceil_div(a, b): - return (a + b - 1) // b +def compute_reference_amax(output_tensor: torch.Tensor) -> float: + """ + Compute reference amax value on CPU. + + Args: + output_tensor: torch.Tensor, GEMM output result (CPU tensor) + + Returns: + float: reference amax value + """ + # Ensure FP32 for computation + if output_tensor.dtype != torch.float32: + output_fp32 = output_tensor.float() + else: + output_fp32 = output_tensor + + # Compute absolute maximum value + reference_amax = torch.amax(torch.abs(output_fp32)) + return reference_amax.item() + + +def create_sf_layout_tensor(l, mn, nk, sf_vec_size): sf_k = ceil_div(nk, sf_vec_size) atom_m = (32, 4) @@ -95,9 +106,7 @@ def ceil_div(a, b): mma_permute_order = (3, 4, 1, 5, 2, 0) # Create f32 cute torch tensor (cpu) - cute_f32_torch_tensor_cpu = torch.zeros(mma_shape, dtype=torch.float32).permute( - mma_permute_order - ) + cute_f32_torch_tensor_cpu = torch.zeros(mma_shape, dtype=torch.float32).permute(mma_permute_order) return cute_f32_torch_tensor_cpu, sf_k @@ -109,13 +118,7 @@ def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): ref_permute_order = (1, 2, 0) # Create f32 ref torch tensor (cpu) - ref_f32_torch_tensor_cpu = ( - torch.empty(ref_shape, dtype=torch.float32) - .uniform_(1, 3) - .permute(ref_permute_order) - .to(torch.int8) - .to(torch.float32) - ) + ref_f32_torch_tensor_cpu = torch.empty(ref_shape, dtype=torch.float32).uniform_(1, 3).permute(ref_permute_order).to(torch.int8).to(torch.float32) # convert ref f32 tensor to cute f32 tensor try: @@ -126,25 +129,17 @@ def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): from_dlpack(cute_f32_torch_tensor_cpu), ) except Exception: - pytest.skip( - "CUTLASS is not installed; skipping tests due to scale factor tensor creation requiring CUTLASS." - ) + pytest.skip("CUTLASS is not installed; skipping tests due to scale factor tensor creation requiring CUTLASS.") # reshape makes memory contiguous ref_f32_torch_tensor_cpu = ( - ref_f32_torch_tensor_cpu.permute(2, 0, 1) - .unsqueeze(-1) - .expand(l, mn, sf_k, sf_vec_size) - .reshape(l, mn, sf_k * sf_vec_size) - .permute(*ref_permute_order) + ref_f32_torch_tensor_cpu.permute(2, 0, 1).unsqueeze(-1).expand(l, mn, sf_k, sf_vec_size).reshape(l, mn, sf_k * sf_vec_size).permute(*ref_permute_order) ) ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] if dtype != torch.int8: cute_torch_tensor = cute_f32_torch_tensor_cpu.to(dtype).cuda() else: - cute_torch_tensor = ( - cute_f32_torch_tensor_cpu.to(torch.float8_e8m0fnu).cuda().view(dtype) - ) + cute_torch_tensor = cute_f32_torch_tensor_cpu.to(torch.float8_e8m0fnu).cuda().view(dtype) return ref_f32_torch_tensor_cpu.cuda(), cute_torch_tensor diff --git a/test/python/fe_api/test_gemm_amax.py b/test/python/fe_api/test_gemm_amax.py index b8cb9b88..bdce9a1c 100644 --- a/test/python/fe_api/test_gemm_amax.py +++ b/test/python/fe_api/test_gemm_amax.py @@ -4,19 +4,142 @@ from test_utils import torch_fork_set_rng from fe_api.test_gemm_amax_utils import ( - with_gemm_amax_params, + with_gemm_amax_params_fp4, + with_gemm_amax_params_fp8, ) + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_gemm_amax_params_fp4 +def test_gemm_amax_compile_execute_fp4( + a_major, + b_major, + c_major, + ab_dtype, + sf_dtype, + c_dtype, + acc_dtype, + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + request, +): + _test_gemm_amax_compile_execute( + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + sf_vec_size=sf_vec_size, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_gemm_amax_params_fp8 +def test_gemm_amax_compile_execute_fp8( + a_major, + b_major, + c_major, + ab_dtype, + sf_dtype, + c_dtype, + acc_dtype, + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + request, +): + _test_gemm_amax_compile_execute( + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + sf_vec_size=sf_vec_size, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_gemm_amax_params_fp4 +def test_gemm_amax_wrapper_fp4( + a_major, + b_major, + c_major, + ab_dtype, + sf_dtype, + c_dtype, + acc_dtype, + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + request, +): + _test_gemm_amax_wrapper( + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + sf_vec_size=sf_vec_size, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_gemm_amax_params_fp8 +def test_gemm_amax_wrapper_fp8( + a_major, + b_major, + c_major, + ab_dtype, + sf_dtype, + c_dtype, + acc_dtype, + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + request, +): + _test_gemm_amax_wrapper( + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + sf_vec_size=sf_vec_size, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + request=request, + ) + + """ GemmAmax API with explicit set_params, compile, and execute paths. Use this method when running one static configuration for each GemmAmax object. """ -@pytest.mark.L0 -@torch_fork_set_rng(seed=0) -@with_gemm_amax_params -def test_gemm_amax_compile_execute( +def _test_gemm_amax_compile_execute( a_major, b_major, c_major, @@ -39,9 +162,7 @@ def test_gemm_amax_compile_execute( gemm_amax_init, ) except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = gemm_amax_init( request, a_major, @@ -56,22 +177,18 @@ def test_gemm_amax_compile_execute( cluster_shape_mn, ) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - a_torch, a_ref, b_torch, b_ref, sfa_torch, sfa_ref, sfb_torch, sfb_ref = ( - allocate_input_tensors( - cfg["m"], - cfg["n"], - cfg["k"], - cfg["l"], - cfg["ab_dtype"], - cfg["sf_dtype"], - cfg["sf_vec_size"], - cfg["a_major"], - cfg["b_major"], - ) - ) - c_torch, amax_torch = allocate_output_tensors( - cfg["m"], cfg["n"], cfg["l"], cfg["c_dtype"], cfg["c_major"] + a_torch, a_ref, b_torch, b_ref, sfa_torch, sfa_ref, sfb_torch, sfb_ref = allocate_input_tensors( + cfg["m"], + cfg["n"], + cfg["k"], + cfg["l"], + cfg["ab_dtype"], + cfg["sf_dtype"], + cfg["sf_vec_size"], + cfg["a_major"], + cfg["b_major"], ) + c_torch, amax_torch = allocate_output_tensors(cfg["m"], cfg["n"], cfg["l"], cfg["c_dtype"], cfg["c_major"]) gemm = GemmAmaxSm100( sample_a=a_torch, @@ -100,9 +217,7 @@ def test_gemm_amax_compile_execute( current_stream=stream, ) - check_ref_gemm_amax( - a_ref, b_ref, sfa_ref, sfb_ref, c_torch, amax_torch, skip_ref=cfg["skip_ref"] - ) + check_ref_gemm_amax(a_ref, b_ref, sfa_ref, sfb_ref, c_torch, amax_torch, skip_ref=cfg["skip_ref"]) """ @@ -111,10 +226,7 @@ def test_gemm_amax_compile_execute( """ -@pytest.mark.L0 -@torch_fork_set_rng(seed=0) -@with_gemm_amax_params -def test_gemm_amax_wrapper( +def _test_gemm_amax_wrapper( a_major, b_major, c_major, @@ -137,9 +249,7 @@ def test_gemm_amax_wrapper( gemm_amax_init, ) except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = gemm_amax_init( request, a_major, @@ -154,18 +264,16 @@ def test_gemm_amax_wrapper( cluster_shape_mn, ) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - a_torch, a_ref, b_torch, b_ref, sfa_torch, sfa_ref, sfb_torch, sfb_ref = ( - allocate_input_tensors( - cfg["m"], - cfg["n"], - cfg["k"], - cfg["l"], - cfg["ab_dtype"], - cfg["sf_dtype"], - cfg["sf_vec_size"], - cfg["a_major"], - cfg["b_major"], - ) + a_torch, a_ref, b_torch, b_ref, sfa_torch, sfa_ref, sfb_torch, sfb_ref = allocate_input_tensors( + cfg["m"], + cfg["n"], + cfg["k"], + cfg["l"], + cfg["ab_dtype"], + cfg["sf_dtype"], + cfg["sf_vec_size"], + cfg["a_major"], + cfg["b_major"], ) try: @@ -186,6 +294,4 @@ def test_gemm_amax_wrapper( except (ValueError, NotImplementedError) as e: pytest.skip(f"Unsupported testcase: {e}") - check_ref_gemm_amax( - a_ref, b_ref, sfa_ref, sfb_ref, c_torch, amax_torch, skip_ref=cfg["skip_ref"] - ) + check_ref_gemm_amax(a_ref, b_ref, sfa_ref, sfb_ref, c_torch, amax_torch, skip_ref=cfg["skip_ref"]) diff --git a/test/python/fe_api/test_gemm_amax_utils.py b/test/python/fe_api/test_gemm_amax_utils.py index 35ac3e58..30ba91c0 100644 --- a/test/python/fe_api/test_gemm_amax_utils.py +++ b/test/python/fe_api/test_gemm_amax_utils.py @@ -12,41 +12,92 @@ ) from test_fe_api_utils import create_and_permute_tensor, create_scale_factor_tensor - # Parameterization marks for GEMM Amax -GEMM_AMAX_PARAM_MARKS = [ - pytest.mark.parametrize("a_major", ["k", "m"]), - pytest.mark.parametrize("b_major", ["k", "n"]), +GEMM_AMAX_PARAM_MARKS_FP4 = [ + pytest.mark.parametrize("a_major", ["k"]), + pytest.mark.parametrize("b_major", ["k"]), pytest.mark.parametrize("c_major", ["m", "n"]), pytest.mark.parametrize( "ab_dtype", - [torch.float8_e5m2, torch.float8_e4m3fn, torch.uint8, torch.float4_e2m1fn_x2], + [ + torch.float4_e2m1fn_x2, + # torch.uint8, + ], ), pytest.mark.parametrize( - "sf_dtype", [torch.float8_e8m0fnu, torch.int8, torch.float8_e4m3fn] + "sf_dtype", + [ + torch.float8_e8m0fnu, + # torch.int8, + torch.float8_e4m3fn, + ], ), pytest.mark.parametrize( "c_dtype", [ torch.float32, - torch.float16, + # torch.float16, torch.bfloat16, - torch.float8_e5m2, + # torch.float8_e5m2, torch.float8_e4m3fn, torch.float4_e2m1fn_x2, - torch.uint8, + # torch.uint8, ], ), pytest.mark.parametrize("acc_dtype", [torch.float32]), pytest.mark.parametrize("sf_vec_size", [16, 32]), + pytest.mark.parametrize( + "mma_tiler_mn", + [ + (128, 128), + ], + ), + pytest.mark.parametrize("cluster_shape_mn", [(1, 1), (2, 2)]), +] + +GEMM_AMAX_PARAM_MARKS_FP8 = [ + pytest.mark.parametrize("a_major", ["k", "m"]), + pytest.mark.parametrize("b_major", ["k", "n"]), + pytest.mark.parametrize("c_major", ["m", "n"]), + pytest.mark.parametrize( + "ab_dtype", + [ + torch.float8_e5m2, + # torch.float8_e4m3fn, + ], + ), + pytest.mark.parametrize( + "sf_dtype", + [ + torch.float8_e8m0fnu, + # torch.int8, + ], + ), + pytest.mark.parametrize( + "c_dtype", + [ + torch.float32, + # torch.float16, + torch.bfloat16, + ], + ), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize("sf_vec_size", [32]), pytest.mark.parametrize("mma_tiler_mn", [(128, 128), (128, 256)]), - pytest.mark.parametrize("cluster_shape_mn", [(1, 1), (1, 2), (2, 2)]), + pytest.mark.parametrize("cluster_shape_mn", [(1, 1), (2, 2)]), ] -def with_gemm_amax_params(func): +def with_gemm_amax_params_fp4(func): """Apply all GEMM Amax parameterization marks to a test function.""" - for mark in reversed(GEMM_AMAX_PARAM_MARKS): + for mark in reversed(GEMM_AMAX_PARAM_MARKS_FP4): + func = mark(func) + return func + + +def with_gemm_amax_params_fp8(func): + """Apply all GEMM Amax parameterization marks to a test function.""" + for mark in reversed(GEMM_AMAX_PARAM_MARKS_FP8): func = mark(func) return func @@ -67,15 +118,11 @@ def gemm_amax_init( """Build test config, allowing CLI overrides for problem size/tiling/cluster/skip-ref.""" major, _ = torch.cuda.get_device_capability() if major < 10: - pytest.skip( - f"Environment not supported: requires compute capability >= 10, found {major}" - ) + pytest.skip(f"Environment not supported: requires compute capability >= 10, found {major}") mnkl_str = request.config.getoption("--gemm-amax-mnkl", default=None) mma_tiler_str = request.config.getoption("--gemm-amax-mma-tiler", default=None) - cluster_shape_str = request.config.getoption( - "--gemm-amax-cluster-shape", default=None - ) + cluster_shape_str = request.config.getoption("--gemm-amax-cluster-shape", default=None) skip_ref = request.config.getoption("--gemm-amax-skip-ref", default=False) if mnkl_str is not None: @@ -107,9 +154,7 @@ def gemm_amax_init( } -def allocate_input_tensors( - m, n, k, l, ab_dtype, sf_dtype, sf_vec_size, a_major, b_major -): +def allocate_input_tensors(m, n, k, l, ab_dtype, sf_dtype, sf_vec_size, a_major, b_major): """Allocate and initialize input tensors for GEMM Amax tests.""" a_ref, a_tensor = create_and_permute_tensor(l, m, k, a_major == "m", ab_dtype) b_ref, b_tensor = create_and_permute_tensor(l, n, k, b_major == "n", ab_dtype) @@ -122,9 +167,7 @@ def allocate_input_tensors( def allocate_output_tensors(m, n, l, c_dtype, c_major): """Allocate and initialize output tensors for GEMM Amax tests.""" _, c_tensor = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) - amax_tensor = torch.full( - (1, 1, 1), -float("inf"), device="cuda", dtype=torch.float32 - ) + amax_tensor = torch.full((1, 1, 1), -float("inf"), device="cuda", dtype=torch.float32) return c_tensor, amax_tensor @@ -155,43 +198,23 @@ def check_ref_gemm_amax(a, b, sfa_ref, sfb_ref, c, amax, skip_ref=False): m, n, l = c_ref.shape # Convert ref: f32 -> f8 -> f32 using CUTE's conversion - ref_f8_ = torch.empty(l, m, n, dtype=torch.uint8, device="cuda").permute( - 1, 2, 0 - ) - ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic( - leading_dim=1 - ) + ref_f8_ = torch.empty(l, m, n, dtype=torch.uint8, device="cuda").permute(1, 2, 0) + ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic(leading_dim=1) ref_f8.element_type = _convert_to_cutlass_data_type(c.dtype) ref_device = c_ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda() - ref_tensor = from_dlpack(ref_device, assumed_align=16).mark_layout_dynamic( - leading_dim=1 - ) + ref_tensor = from_dlpack(ref_device, assumed_align=16).mark_layout_dynamic(leading_dim=1) cute.testing.convert(ref_tensor, ref_f8) # f32 -> f8 cute.testing.convert(ref_f8, ref_tensor) # f8 -> f32 c_ref = ref_device.cpu() - torch.testing.assert_close( - c_ref.to(torch.float32), c.cpu().to(torch.float32), atol=0.1, rtol=0.1 - ) + torch.testing.assert_close(c_ref.to(torch.float32), c.cpu().to(torch.float32), atol=0.1, rtol=0.1) elif is_c_fp4: - fp4_c_ref = _bfloat16_to_float4_e2m1fn_x2( - c_ref.permute(2, 0, 1).to(torch.bfloat16) - ) - c_ref = ( - float4_e2m1fn_x2_to_float32(fp4_c_ref).to(torch.float32).permute(1, 2, 0) - ) - - c_f32 = ( - float4_e2m1fn_x2_to_float32( - c.cpu().permute(2, 0, 1).view(torch.float4_e2m1fn_x2) - ) - .to(torch.float32) - .permute(1, 2, 0) - ) - - torch.testing.assert_close( - c_ref.to(torch.float32), c_f32.to(torch.float32), atol=0.1, rtol=0.1 - ) + fp4_c_ref = _bfloat16_to_float4_e2m1fn_x2(c_ref.permute(2, 0, 1).to(torch.bfloat16)) + c_ref = float4_e2m1fn_x2_to_float32(fp4_c_ref).to(torch.float32).permute(1, 2, 0) + + c_f32 = float4_e2m1fn_x2_to_float32(c.cpu().permute(2, 0, 1).view(torch.float4_e2m1fn_x2)).to(torch.float32).permute(1, 2, 0) + + torch.testing.assert_close(c_ref.to(torch.float32), c_f32.to(torch.float32), atol=0.1, rtol=0.1) else: c_ref = c_ref.to(c.dtype) torch.testing.assert_close(c_ref, c.cpu(), atol=0.01, rtol=0.01) diff --git a/test/python/fe_api/test_gemm_swiglu.py b/test/python/fe_api/test_gemm_swiglu.py index bd264ad9..b877f592 100644 --- a/test/python/fe_api/test_gemm_swiglu.py +++ b/test/python/fe_api/test_gemm_swiglu.py @@ -8,11 +8,11 @@ check_ref_gemm_swiglu, with_gemm_swiglu_params, gemm_swiglu_init, - with_gemm_swiglu_quant_params, + with_gemm_swiglu_quant_params_fp4, + with_gemm_swiglu_quant_params_fp8, check_ref_gemm_swiglu_quant, ) - """ GemmSwiglu API with explicit set_params, compile, and execute paths. Use this method when running one static configuration for each GemmSwiglu object. @@ -38,10 +38,7 @@ def test_gemm_swiglu_compile_execute( from cudnn import GemmSwigluSm100 from cuda.bindings import driver as cuda except ImportError as e: - # raise e - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = gemm_swiglu_init( request, a_major, @@ -65,9 +62,7 @@ def test_gemm_swiglu_compile_execute( cfg["a_major"], cfg["b_major"], ) - ab12_torch, c_torch, _, _, _ = allocate_output_tensors( - cfg["m"], cfg["n"], cfg["l"], cfg["ab12_dtype"], cfg["c_dtype"], cfg["c_major"] - ) + ab12_torch, c_torch, _, _, _ = allocate_output_tensors(cfg["m"], cfg["n"], cfg["l"], cfg["ab12_dtype"], cfg["c_dtype"], cfg["c_major"]) gemm_swiglu = GemmSwigluSm100( sample_a=a_torch, @@ -82,7 +77,6 @@ def test_gemm_swiglu_compile_execute( try: assert gemm_swiglu.check_support(), "Unsupported testcase" except (ValueError, NotImplementedError) as e: - # raise e pytest.skip(f"Unsupported testcase: {e}") gemm_swiglu.compile(current_stream=stream) gemm_swiglu.execute( @@ -130,9 +124,7 @@ def test_gemm_swiglu_wrapper( from cuda.bindings import driver as cuda except ImportError as e: print(f"ImportError: {e}") - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = gemm_swiglu_init( request, a_major, @@ -186,8 +178,145 @@ def test_gemm_swiglu_wrapper( @pytest.mark.L0 @torch_fork_set_rng(seed=0) -@with_gemm_swiglu_quant_params -def test_gemm_swiglu_compile_execute_quantize( +@with_gemm_swiglu_quant_params_fp4 +def test_gemm_swiglu_compile_execute_quant_fp4( + a_major, + b_major, + c_major, + ab_dtype, + ab12_dtype, + c_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + request, +): + _test_gemm_swiglu_compile_execute_quant( + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_dtype=ab_dtype, + ab12_dtype=ab12_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_gemm_swiglu_quant_params_fp8 +def test_gemm_swiglu_compile_execute_quant_fp8( + a_major, + b_major, + c_major, + ab_dtype, + ab12_dtype, + c_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + request, +): + _test_gemm_swiglu_compile_execute_quant( + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_dtype=ab_dtype, + ab12_dtype=ab12_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_gemm_swiglu_quant_params_fp4 +def test_gemm_swiglu_wrapper_quant_fp4( + a_major, + b_major, + c_major, + ab_dtype, + ab12_dtype, + c_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + request, +): + _test_gemm_swiglu_wrapper_quant( + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_dtype=ab_dtype, + ab12_dtype=ab12_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_gemm_swiglu_quant_params_fp8 +def test_gemm_swiglu_wrapper_quant_fp8( + a_major, + b_major, + c_major, + ab_dtype, + ab12_dtype, + c_dtype, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + request, +): + _test_gemm_swiglu_wrapper_quant( + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_dtype=ab_dtype, + ab12_dtype=ab12_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + request=request, + ) + + +def _test_gemm_swiglu_compile_execute_quant( a_major, b_major, c_major, @@ -206,9 +335,7 @@ def test_gemm_swiglu_compile_execute_quantize( from cudnn import GemmSwigluSm100 from cuda.bindings import driver as cuda except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = gemm_swiglu_init( request, a_major, @@ -318,10 +445,7 @@ def test_gemm_swiglu_compile_execute_quantize( ) -@pytest.mark.L0 -@torch_fork_set_rng(seed=0) -@with_gemm_swiglu_quant_params -def test_gemm_swiglu_wrapper_quantize( +def _test_gemm_swiglu_wrapper_quant( a_major, b_major, c_major, @@ -340,9 +464,7 @@ def test_gemm_swiglu_wrapper_quantize( from cudnn import gemm_swiglu_wrapper_sm100 from cuda.bindings import driver as cuda except ImportError as e: - pytest.skip( - "Environment not supported: cudnn optional dependencies not installed" - ) + pytest.skip("Environment not supported: cudnn optional dependencies not installed") cfg = gemm_swiglu_init( request, a_major, diff --git a/test/python/fe_api/test_gemm_swiglu_utils.py b/test/python/fe_api/test_gemm_swiglu_utils.py index 85698b4c..32aec9e4 100644 --- a/test/python/fe_api/test_gemm_swiglu_utils.py +++ b/test/python/fe_api/test_gemm_swiglu_utils.py @@ -7,6 +7,7 @@ import pytest from typing import Optional, Tuple from test_fe_api_utils import ( + compute_reference_amax, create_and_permute_tensor, create_scale_factor_tensor, create_sf_layout_tensor, @@ -24,24 +25,52 @@ pytest.mark.parametrize( "ab_dtype", [ - torch.float16, + # torch.float16, torch.bfloat16, torch.float32, torch.float8_e4m3fn, - torch.float8_e5m2, + # torch.float8_e5m2, ], ), pytest.mark.parametrize( - "ab12_dtype", [torch.float16, torch.bfloat16, torch.float32] + "ab12_dtype", + [ + # torch.float16, + torch.bfloat16, + torch.float32, + ], ), pytest.mark.parametrize( - "acc_dtype", [torch.float32] - ), # Note: float16 accumulator is supported but disabled in testing - pytest.mark.parametrize("c_dtype", [torch.float16, torch.bfloat16]), + "acc_dtype", + [ + torch.float32, + # torch.float16, + ], + ), + pytest.mark.parametrize( + "c_dtype", + [ + # torch.float16, + torch.bfloat16 + ], + ), pytest.mark.parametrize( - "mma_tiler_mn", [(128, 128), (128, 64), (256, 256), (256, 128)] + "mma_tiler_mn", + [ + (128, 128), + (256, 256), + # (128, 64), + # (256, 128), + ], + ), + pytest.mark.parametrize( + "cluster_shape_mn", + [ + (1, 1), + (2, 2), + # (4, 4), + ], ), - pytest.mark.parametrize("cluster_shape_mn", [(1, 1), (2, 2), (4, 4)]), ] @@ -51,37 +80,33 @@ def with_gemm_swiglu_params(func): return func -GEMM_SWIGLU_QUANT_PARAM_MARKS = [ - pytest.mark.parametrize("a_major", ["k", "m"]), - pytest.mark.parametrize("b_major", ["k", "n"]), - pytest.mark.parametrize("c_major", ["m", "n"]), +GEMM_SWIGLU_QUANT_PARAM_MARKS_FP4 = [ + pytest.mark.parametrize("a_major", ["k"]), + pytest.mark.parametrize("b_major", ["k"]), + pytest.mark.parametrize("c_major", ["n"]), pytest.mark.parametrize( "ab_dtype", [ torch.float4_e2m1fn_x2, # torch.uint8, - torch.float8_e4m3fn, - torch.float8_e5m2, ], ), pytest.mark.parametrize( "ab12_dtype", [ - # torch.float32 - torch.float16, + torch.float32, + # torch.float16, torch.bfloat16, torch.float8_e4m3fn, - torch.float8_e5m2, + # torch.float8_e5m2, ], ), pytest.mark.parametrize( "c_dtype", [ - # torch.float32 - torch.float16, + torch.float32, + # torch.float16, torch.bfloat16, - torch.float8_e4m3fn, - torch.float8_e5m2, ], ), pytest.mark.parametrize("acc_dtype", [torch.float32]), @@ -99,7 +124,7 @@ def with_gemm_swiglu_params(func): [ (1, 1), (2, 2), - (4, 4), + # (4, 4), ], ), pytest.mark.parametrize("sf_vec_size", [16, 32]), @@ -107,9 +132,64 @@ def with_gemm_swiglu_params(func): pytest.mark.parametrize("vector_f32", [True, False]), ] +GEMM_SWIGLU_QUANT_PARAM_MARKS_FP8 = [ + pytest.mark.parametrize("a_major", ["k", "m"]), + pytest.mark.parametrize("b_major", ["k", "n"]), + pytest.mark.parametrize("c_major", ["m", "n"]), + pytest.mark.parametrize( + "ab_dtype", + [ + torch.float8_e4m3fn, + # torch.float8_e5m2, + ], + ), + pytest.mark.parametrize( + "ab12_dtype", + [ + # torch.float16, + torch.bfloat16, + ], + ), + pytest.mark.parametrize( + "c_dtype", + [ + torch.float32, + # torch.float16, + torch.bfloat16, + ], + ), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize( + "mma_tiler_mn", + [ + (128, 128), + (256, 256), + # (128, 64), + # (256, 128), + ], + ), + pytest.mark.parametrize( + "cluster_shape_mn", + [ + (1, 1), + (2, 2), + # (4, 4), + ], + ), + pytest.mark.parametrize("sf_vec_size", [32]), + pytest.mark.parametrize("sf_dtype", [torch.float8_e8m0fnu]), + pytest.mark.parametrize("vector_f32", [True, False]), +] -def with_gemm_swiglu_quant_params(func): - for mark in reversed(GEMM_SWIGLU_QUANT_PARAM_MARKS): + +def with_gemm_swiglu_quant_params_fp4(func): + for mark in reversed(GEMM_SWIGLU_QUANT_PARAM_MARKS_FP4): + func = mark(func) + return func + + +def with_gemm_swiglu_quant_params_fp8(func): + for mark in reversed(GEMM_SWIGLU_QUANT_PARAM_MARKS_FP8): func = mark(func) return func @@ -133,15 +213,11 @@ def gemm_swiglu_init( """Initialize configuration for GEMM SwiGLU tests.""" major, _ = torch.cuda.get_device_capability() if major < 10: - pytest.skip( - f"Environment not supported: requires compute capability >= 10, found {major}" - ) + pytest.skip(f"Environment not supported: requires compute capability >= 10, found {major}") mnkl_str = request.config.getoption("--gemm-swiglu-mnkl", default=None) mma_tiler_str = request.config.getoption("--gemm-swiglu-mma-tiler", default=None) - cluster_shape_str = request.config.getoption( - "--gemm-swiglu-cluster-shape", default=None - ) + cluster_shape_str = request.config.getoption("--gemm-swiglu-cluster-shape", default=None) alpha_opt = request.config.getoption("--gemm-swiglu-alpha", default=None) skip_ref = request.config.getoption("--gemm-swiglu-skip-ref", default=False) @@ -218,9 +294,7 @@ def run_gemm_swiglu_quant_ref( n = b_ref.shape[0] assert n % group == 0, "N must be divisible by 32 for GLU block grouping" num_blocks = n // group - assert ( - num_blocks % 2 == 0 - ), "Number of 32-col blocks must be even (pairs of input/gate)" + assert num_blocks % 2 == 0, "Number of 32-col blocks must be even (pairs of input/gate)" cols = torch.arange(n, device=ab12_ref.device, dtype=torch.long) block_cols = cols.view(num_blocks, group) @@ -253,24 +327,16 @@ def run_gemm_swiglu_quant_ref( ref_sfc_f32 = ref_sfc_f32.permute(1, 2, 0) # For some reason, using `ref_sfc_32_torch = ref_sfc_f32.to(sfc_dtype).to(torch.float32)` leads to different/incorrect results - ref_sfc_f8_torch = torch.empty( - (l, sfm, sfn), dtype=torch.uint8, device="cuda" - ).permute(1, 2, 0) - ref_sfc_f8 = from_dlpack( - ref_sfc_f8_torch, assumed_align=16 - ).mark_layout_dynamic(leading_dim=1) + ref_sfc_f8_torch = torch.empty((l, sfm, sfn), dtype=torch.uint8, device="cuda").permute(1, 2, 0) + ref_sfc_f8 = from_dlpack(ref_sfc_f8_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1) ref_sfc_f8.element_type = _convert_to_cutlass_data_type(sfc_dtype) ref_sfc_f32_device = ref_sfc_f32.cuda() - ref_sfc_f32_tensor = from_dlpack( - ref_sfc_f32_device, assumed_align=16 - ).mark_layout_dynamic(leading_dim=1) + ref_sfc_f32_tensor = from_dlpack(ref_sfc_f32_device, assumed_align=16).mark_layout_dynamic(leading_dim=1) cute.testing.convert(ref_sfc_f32_tensor, ref_sfc_f8) cute.testing.convert(ref_sfc_f8, ref_sfc_f32_tensor) ref_sfc_32 = ref_sfc_f32_device.cpu() - ref_sfc_f32_cute_torch_tensor_cpu, _ = create_sf_layout_tensor( - l, sfm, n // 2, sf_vec_size - ) + ref_sfc_f32_cute_torch_tensor_cpu, _ = create_sf_layout_tensor(l, sfm, n // 2, sf_vec_size) cvt_sf_MKL_to_M32x4xrm_K4xrk_L( from_dlpack(ref_sfc_32), from_dlpack(ref_sfc_f32_cute_torch_tensor_cpu), @@ -287,14 +353,6 @@ def run_gemm_swiglu_quant_ref( return ab12_ref_ret, c_ref, sfc_ref, amax_ref -def compute_reference_amax(output_tensor: torch.Tensor) -> float: - if output_tensor.dtype != torch.float32: - output_fp32 = output_tensor.float() - else: - output_fp32 = output_tensor - return torch.amax(torch.abs(output_fp32)).item() - - def run_gemm_swiglu_ref(a_ref, b_ref, alpha): ab12_ref, c_ref = None, None if a_ref.dtype in {torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float8_e5m2}: @@ -306,18 +364,13 @@ def run_gemm_swiglu_ref(a_ref, b_ref, alpha): n = b_ref.shape[0] assert n % group == 0, "N must be divisible by 32 for GLU block grouping" num_blocks = n // group - assert ( - num_blocks % 2 == 0 - ), "Number of 32-col blocks must be even (pairs of input/gate)" + assert num_blocks % 2 == 0, "Number of 32-col blocks must be even (pairs of input/gate)" cols = torch.arange(n, device=ab12_ref.device, dtype=torch.long) block_cols = cols.view(num_blocks, group) input_idx = block_cols[0::2].reshape(-1) gate_idx = block_cols[1::2].reshape(-1) - c_ref = ab12_ref.index_select(1, input_idx) * ( - ab12_ref.index_select(1, gate_idx) - * torch.sigmoid(ab12_ref.index_select(1, gate_idx)) - ) + c_ref = ab12_ref.index_select(1, input_idx) * (ab12_ref.index_select(1, gate_idx) * torch.sigmoid(ab12_ref.index_select(1, gate_idx))) c_ref = c_ref.to(torch.float32) return ab12_ref, c_ref @@ -345,9 +398,7 @@ def check_ref_gemm_swiglu( rtol=0.1, ) else: - torch.testing.assert_close( - ab12.cpu(), ab12_ref.to(ab12.dtype), atol=0.01, rtol=9e-03 - ) + torch.testing.assert_close(ab12.cpu(), ab12_ref.to(ab12.dtype), atol=0.01, rtol=9e-03) is_c_fp8 = c.dtype in {torch.float8_e4m3fn, torch.float8_e5m2} if is_c_fp8: @@ -358,9 +409,7 @@ def check_ref_gemm_swiglu( rtol=0.1, ) else: - torch.testing.assert_close( - c.cpu(), c_ref.to(c.dtype), atol=0.01, rtol=9e-03 - ) + torch.testing.assert_close(c.cpu(), c_ref.to(c.dtype), atol=0.01, rtol=9e-03) else: print("Skipping reference check") @@ -392,9 +441,7 @@ def check_ref_gemm_swiglu_quant( b_ref = b_ref.clone().to(torch.float32).cpu() sfa_ref = sfa_ref.float().cpu() sfb_ref = sfb_ref.float().cpu() - norm_const_ref = ( - norm_const_ref.float().cpu() if norm_const_ref is not None else None - ) + norm_const_ref = norm_const_ref.float().cpu() if norm_const_ref is not None else None sfc_dtype = sfc.dtype if sfc is not None else None ab12_ref, c_ref, sfc_ref, amax_ref = run_gemm_swiglu_quant_ref( a_ref, @@ -418,9 +465,7 @@ def check_ref_gemm_swiglu_quant( rtol=0.01, ) else: - torch.testing.assert_close( - ab12.cpu(), ab12_ref.to(ab12.dtype), atol=0.01, rtol=0.01 - ) + torch.testing.assert_close(ab12.cpu(), ab12_ref.to(ab12.dtype), atol=0.01, rtol=0.01) if c_dtype in {torch.float32, torch.float16, torch.bfloat16}: torch.testing.assert_close(c.cpu(), c_ref.to(c.dtype), atol=0.01, rtol=0.01) @@ -429,9 +474,7 @@ def check_ref_gemm_swiglu_quant( torch.uint8, }: reference_amax = torch.tensor(compute_reference_amax(c_ref.clone())) - torch.testing.assert_close( - amax.cpu().squeeze(), reference_amax, atol=0.01, rtol=0.01 - ) + torch.testing.assert_close(amax.cpu().squeeze(), reference_amax, atol=0.01, rtol=0.01) elif c_dtype in {torch.float8_e5m2, torch.float8_e4m3fn}: torch.testing.assert_close( sfc.cpu().to(torch.float32), @@ -439,9 +482,7 @@ def check_ref_gemm_swiglu_quant( atol=0.01, rtol=0.01, ) - torch.testing.assert_close( - c.cpu().to(torch.float32), c_ref.to(torch.float32), atol=0.01, rtol=0.01 - ) + torch.testing.assert_close(c.cpu().to(torch.float32), c_ref.to(torch.float32), atol=0.01, rtol=0.01) def allocate_input_tensors( @@ -523,12 +564,8 @@ def allocate_output_tensors( sfc_ref, sfc_tensor, amax_tensor = None, None, None if is_block_scaled: if c_dtype in {torch.float8_e5m2, torch.float8_e4m3fn}: - sfc_ref, sfc_tensor = create_scale_factor_tensor( - l, m, n // 2, sf_vec_size, sf_dtype - ) + sfc_ref, sfc_tensor = create_scale_factor_tensor(l, m, n // 2, sf_vec_size, sf_dtype) if c_dtype == torch.bfloat16: - amax_tensor = torch.full( - (1, 1, 1), -float("inf"), device="cuda", dtype=torch.float32 - ) + amax_tensor = torch.full((1, 1, 1), -float("inf"), device="cuda", dtype=torch.float32) return ab12_tensor, c_tensor, sfc_tensor, sfc_ref, amax_tensor diff --git a/test/python/fe_api/test_grouped_gemm_swiglu.py b/test/python/fe_api/test_grouped_gemm_swiglu.py new file mode 100644 index 00000000..3738dd1d --- /dev/null +++ b/test/python/fe_api/test_grouped_gemm_swiglu.py @@ -0,0 +1,371 @@ +""" +Tests for Grouped GEMM SwiGLU Forward Kernel (SM100+) + +This module tests the contiguous grouped block-scaled GEMM with SwiGLU activation +for MoE (Mixture of Experts) workloads. + +Reference: continugous_blockscaled_grouped_gemm_swiglu_quant_fusion.py +""" + +import torch +import pytest +from test_utils import torch_fork_set_rng +from fe_api.test_grouped_gemm_swiglu_utils import ( + grouped_gemm_swiglu_init, + with_grouped_gemm_swiglu_params_fp4, + with_grouped_gemm_swiglu_params_fp8, + allocate_grouped_gemm_input_tensors, + allocate_grouped_gemm_output_tensors, + check_ref_grouped_gemm_swiglu, +) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_swiglu_params_fp4 +def test_grouped_gemm_swiglu_compile_execute_fp4( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_swiglu_compile_execute( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_swiglu_params_fp8 +def test_grouped_gemm_swiglu_compile_execute_fp8( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_swiglu_compile_execute( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_swiglu_params_fp4 +def test_grouped_gemm_swiglu_wrapper_fp4( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_swiglu_wrapper( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +@with_grouped_gemm_swiglu_params_fp8 +def test_grouped_gemm_swiglu_wrapper_fp8( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + _test_grouped_gemm_swiglu_wrapper( + ab_dtype=ab_dtype, + c_dtype=c_dtype, + d_dtype=d_dtype, + cd_major=cd_major, + acc_dtype=acc_dtype, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + vector_f32=vector_f32, + discrete_col_sfd=discrete_col_sfd, + request=request, + ) + + +""" +GroupedGemmSwiglu API with explicit check_support, compile, and execute paths. +Use this method when running one static configuration for each GroupedGemmSwiglu object. +""" + + +def _test_grouped_gemm_swiglu_compile_execute( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + try: + from cudnn import GroupedGemmSwigluSm100 + from cuda.bindings import driver as cuda + except ImportError as e: + raise e + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = grouped_gemm_swiglu_init( + request, + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + cta_tile_m=cfg["mma_tiler_mn"][0], + ) + + outputs = allocate_grouped_gemm_output_tensors( + tensor_m=inputs["tensor_m"], + n=cfg["n"], + l=cfg["l"], + ab_dtype=cfg["ab_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + ) + + api = GroupedGemmSwigluSm100( + sample_a=inputs["a_tensor"], + sample_b=inputs["b_tensor"], + sample_c=outputs["c_tensor"], + sample_d=outputs["d_tensor"], + sample_sfa=inputs["sfa_tensor"], + sample_sfb=inputs["sfb_tensor"], + sample_tile_idx_to_expert_idx=inputs["tile_idx_to_expert_idx"], + sample_num_non_exiting_tiles=inputs["num_non_exiting_tiles"], + sample_alpha=inputs["alpha_tensor"], + sample_amax=outputs.get("amax_tensor"), + sample_d_col=outputs["d_col_tensor"], + sample_sfd_row=outputs.get("sfd_row_tensor"), + sample_sfd_col=outputs.get("sfd_col_tensor"), + sample_norm_const=inputs.get("norm_const_tensor"), + sample_prob=inputs.get("prob_tensor"), + sample_m_split_cumsum=inputs.get("num_m_split_cumsum_tensor"), + acc_dtype=cfg["acc_dtype"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + ) + + try: + assert api.check_support(), "Unsupported testcase" + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + + api.compile(current_stream=stream) + api.execute( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + c_tensor=outputs["c_tensor"], + d_tensor=outputs["d_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + tile_idx_to_expert_idx=inputs["tile_idx_to_expert_idx"], + num_non_exiting_tiles=inputs["num_non_exiting_tiles"], + alpha_tensor=inputs["alpha_tensor"], + d_col_tensor=outputs["d_col_tensor"], + sfd_row_tensor=outputs.get("sfd_row_tensor"), + sfd_col_tensor=outputs.get("sfd_col_tensor"), + norm_const_tensor=inputs.get("norm_const_tensor"), + prob_tensor=inputs.get("prob_tensor"), + amax_tensor=outputs.get("amax_tensor"), + m_split_cumsum=inputs.get("num_m_split_cumsum_tensor"), + current_stream=stream, + ) + + check_ref_grouped_gemm_swiglu( + inputs, + outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) + + +""" +GroupedGemmSwiglu API with grouped_gemm_swiglu_wrapper: +Use the wrapper to directly call GroupedGemmSwiglu without explicit setup and compilation. +""" + + +def _test_grouped_gemm_swiglu_wrapper( + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + request, +): + try: + from cudnn import grouped_gemm_swiglu_wrapper_sm100 + from cuda.bindings import driver as cuda + except ImportError as e: + pytest.skip("Environment not supported: cudnn optional dependencies not installed") + + cfg = grouped_gemm_swiglu_init( + request, + ab_dtype, + c_dtype, + d_dtype, + cd_major, + acc_dtype, + mma_tiler_mn, + cluster_shape_mn, + sf_vec_size, + sf_dtype, + vector_f32, + discrete_col_sfd, + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + inputs = allocate_grouped_gemm_input_tensors( + n=cfg["n"], + k=cfg["k"], + l=cfg["l"], + group_m_list=cfg["group_m_list"], + ab_dtype=cfg["ab_dtype"], + sf_dtype=cfg["sf_dtype"], + sf_vec_size=cfg["sf_vec_size"], + m_aligned=cfg["m_aligned"], + cta_tile_m=cfg["mma_tiler_mn"][0], + ) + + try: + for _ in range(2): # Run twice to test caching path + outputs = grouped_gemm_swiglu_wrapper_sm100( + a_tensor=inputs["a_tensor"], + b_tensor=inputs["b_tensor"], + sfa_tensor=inputs["sfa_tensor"], + sfb_tensor=inputs["sfb_tensor"], + tile_idx_to_expert_idx=inputs["tile_idx_to_expert_idx"], + num_non_exiting_tiles=inputs["num_non_exiting_tiles"], + alpha_tensor=inputs["alpha_tensor"], + norm_const_tensor=inputs.get("norm_const_tensor"), + prob_tensor=inputs.get("prob_tensor"), + m_split_cumsum=inputs.get("num_m_split_cumsum_tensor"), + acc_dtype=cfg["acc_dtype"], + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + cd_major=cfg["cd_major"], + mma_tiler_mn=cfg["mma_tiler_mn"], + cluster_shape_mn=cfg["cluster_shape_mn"], + sf_vec_size=cfg["sf_vec_size"], + vector_f32=cfg["vector_f32"], + m_aligned=cfg["m_aligned"], + discrete_col_sfd=cfg["discrete_col_sfd"], + current_stream=stream, + ) + except (ValueError, NotImplementedError) as e: + pytest.skip(f"Unsupported testcase: {e}") + + check_ref_grouped_gemm_swiglu( + inputs, + outputs, + cfg, + skip_ref=cfg["skip_ref"], + ) diff --git a/test/python/fe_api/test_grouped_gemm_swiglu_utils.py b/test/python/fe_api/test_grouped_gemm_swiglu_utils.py new file mode 100644 index 00000000..dac7d8be --- /dev/null +++ b/test/python/fe_api/test_grouped_gemm_swiglu_utils.py @@ -0,0 +1,831 @@ +""" +Utilities and parameterization for Grouped GEMM SwiGLU tests. +Contains test configuration fixtures, tensor creation, and reference implementations. + +Reference: continugous_blockscaled_grouped_gemm_swiglu_quant_fusion.py (lines 3518-4825) +""" + +import torch +import pytest +from typing import Optional, Tuple, List, Dict, Any +from test_fe_api_utils import ( + ceil_div, + compute_reference_amax, + create_and_permute_tensor, + create_scale_factor_tensor, + create_sf_layout_tensor, + cvt_sf_MKL_to_M32x4xrm_K4xrk_L, +) +from test_low_precision_matmul import ( + _bfloat16_to_float4_e2m1fn_x2, + float4_e2m1fn_x2_to_float32, +) + +# ============================================================================= +# Parameterization Marks +# ============================================================================= + +GROUPED_GEMM_SWIGLU_PARAM_MARKS_FP8 = [ + pytest.mark.parametrize( + "ab_dtype", + [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ], + ), + pytest.mark.parametrize( + "c_dtype", + [ + # torch.float8_e4m3fn, + # torch.float8_e5m2, + # torch.float16, + torch.bfloat16, + # torch.float32, + ], + ), + pytest.mark.parametrize( + "d_dtype", + [ + torch.float8_e4m3fn, + # torch.float8_e5m2, + # torch.bfloat16, + ], + ), + pytest.mark.parametrize("cd_major", ["n"]), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize( + "mma_tiler_mn", + [ + (256, 256), + ], + ), + pytest.mark.parametrize( + "cluster_shape_mn", + [ + (2, 1), + (1, 1), + ], + ), + pytest.mark.parametrize("sf_vec_size", [32]), + pytest.mark.parametrize( + "sf_dtype", + [ + torch.float8_e8m0fnu, + ], + ), + pytest.mark.parametrize("vector_f32", [True, False]), + pytest.mark.parametrize("discrete_col_sfd", [True, False]), +] + +GROUPED_GEMM_SWIGLU_PARAM_MARKS_FP4 = [ + pytest.mark.parametrize( + "ab_dtype", + [ + torch.float4_e2m1fn_x2, + # torch.uint8, + ], + ), + pytest.mark.parametrize( + "c_dtype", + [ + # torch.float16, + torch.bfloat16, + ], + ), + pytest.mark.parametrize( + "d_dtype", + [ + torch.bfloat16, + torch.float32, + ], + ), + pytest.mark.parametrize("cd_major", ["n"]), + pytest.mark.parametrize("acc_dtype", [torch.float32]), + pytest.mark.parametrize( + "mma_tiler_mn", + [ + (256, 256), + (128, 128), + ], + ), + pytest.mark.parametrize( + "cluster_shape_mn", + [ + (2, 1), + (1, 1), + ], + ), + pytest.mark.parametrize("sf_vec_size", [16, 32]), + pytest.mark.parametrize( + "sf_dtype", + [ + torch.float8_e8m0fnu, + torch.float8_e4m3fn, + ], + ), + pytest.mark.parametrize("vector_f32", [True, False]), + pytest.mark.parametrize("discrete_col_sfd", [False]), +] + + +def with_grouped_gemm_swiglu_params_fp4(func): + """Decorator to apply grouped GEMM SwiGLU FP4 test parameters.""" + for mark in reversed(GROUPED_GEMM_SWIGLU_PARAM_MARKS_FP4): + func = mark(func) + return func + + +def with_grouped_gemm_swiglu_params_fp8(func): + """Decorator to apply grouped GEMM SwiGLU FP8 test parameters.""" + for mark in reversed(GROUPED_GEMM_SWIGLU_PARAM_MARKS_FP8): + func = mark(func) + return func + + +# ============================================================================= +# Configuration Initialization +# ============================================================================= + + +def grouped_gemm_swiglu_init( + request, + ab_dtype: torch.dtype, + c_dtype: torch.dtype, + d_dtype: torch.dtype, + cd_major: str, + acc_dtype: torch.dtype, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + sf_vec_size: int, + sf_dtype: torch.dtype, + vector_f32: bool = False, + discrete_col_sfd: bool = False, +) -> Dict[str, Any]: + """Initialize configuration for Grouped GEMM SwiGLU tests. + + :param request: pytest request object + :param ab_dtype: Data type for A and B tensors + :param c_dtype: Data type for intermediate C tensor (always bfloat16) + :param d_dtype: Data type for output D tensor (fp8 when ab is fp8, bf16 when ab is fp4) + :param cd_major: Major dimension for output C and D tensors + :param acc_dtype: Accumulator data type + :param mma_tiler_mn: MMA tiler shape + :param cluster_shape_mn: Cluster shape + :param sf_vec_size: Scale factor vector size + :param sf_dtype: Scale factor data type + :param vector_f32: Use vectorized f32 operations + :param discrete_col_sfd: Generate discrete col-major scale factor tensor + :return: Configuration dictionary + """ + major, _ = torch.cuda.get_device_capability() + if major < 10: + pytest.skip(f"Environment not supported: requires compute capability >= 10, found {major}") + + # Parse CLI options + nkl_str = request.config.getoption("--grouped-gemm-nkl", default=None) + group_m_str = request.config.getoption("--grouped-gemm-group-m", default=None) + m_aligned_opt = request.config.getoption("--grouped-gemm-m-aligned", default=None) + skip_ref = request.config.getoption("--grouped-gemm-skip-ref", default=False) + + # Default values + if nkl_str is not None: + n, k, l = [int(x.strip()) for x in nkl_str.split(",")] + else: + n, k, l = 512, 512, 4 + + if group_m_str is not None: + group_m_list = [int(x.strip()) for x in group_m_str.split(",")] + else: + # Default: equal M values per group + group_m_list = [256] * l + + m_aligned = int(m_aligned_opt) if m_aligned_opt is not None else mma_tiler_mn[0] + + config = { + "n": n, + "k": k, + "l": l, + "group_m_list": group_m_list, + "m_aligned": m_aligned, + "mma_tiler_mn": mma_tiler_mn, + "cluster_shape_mn": cluster_shape_mn, + "ab_dtype": ab_dtype, + "c_dtype": c_dtype, + "d_dtype": d_dtype, + "cd_major": cd_major, + "acc_dtype": acc_dtype, + "sf_vec_size": sf_vec_size, + "sf_dtype": sf_dtype, + "vector_f32": vector_f32, + "skip_ref": skip_ref, + "discrete_col_sfd": discrete_col_sfd, + } + + return config + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def get_dtype_rcp_limits(dtype: torch.dtype) -> float: + """Get reciprocal of max value for quantization.""" + if dtype == torch.float8_e5m2: + return 1 / 128.0 + elif dtype == torch.float8_e4m3fn: + return 1 / 448.0 + elif dtype in {torch.float4_e2m1fn_x2, torch.uint8}: + return 1 / 6.0 + return 1.0 + + +def create_mask( + group_m_list: List[int], + cta_tile_m: int, + m_aligned: int = 128, + permuted_m: Optional[int] = None, +) -> Tuple[int, List[int], torch.Tensor, torch.Tensor]: + """Create mask and group mapping for contiguous grouped GEMM. + + :param group_m_list: List of M values for each group (will be aligned to m_aligned) + :param cta_tile_m: CTA tile size in M dimension (from mma_tiler_mn[0]) + :param m_aligned: Alignment requirement for group M dimension + :param permuted_m: Optional padded M dimension for CUDA graph support + + Note: m_aligned should be a multiple of the CTA tile M dimension to prevent + a single tile from spanning multiple groups, which would cause incorrect + B matrix access. + + Note: For cuda_graph support, set permuted_m to the pre-calculated padded size: + permuted_m = m * topK + num_local_experts * (256 - 1) + Example: 4096*8 + (256/32)*255 = 34808 + Only the actual valid rows (aligned_groupm[0]+aligned_groupm[1]+...) contain + valid data. The kernel will exit when tile_idx >= num_non_exiting_tiles. + + :return: Tuple of (valid_m, aligned_group_m_list, tile_idx_to_expert_idx, num_non_exiting_tiles, num_m_split_cumsum) + - tile_idx_to_expert_idx: shape (permuted_m/cta_tile_m,) if permuted_m provided, + else (valid_m/cta_tile_m,) + - num_non_exiting_tiles: scalar value = valid_m/cta_tile_m + - num_m_split_cumsum: cumulative sum of aligned_group_m_list + """ + valid_m = 0 + aligned_group_m_list = [] + tile_idx_to_expert_idx = [] + m_split_cumsum = [] + m_split_cumsum.append(valid_m) + + for i, group_m in enumerate(group_m_list): + aligned_group_m = ((group_m + m_aligned - 1) // m_aligned) * m_aligned + valid_m += aligned_group_m + aligned_group_m_list.append(aligned_group_m) + + # Calculate number of tiles for this group based on CTA tile M size + # Each tile covers cta_tile_m rows in M dimension + num_tiles_in_group = aligned_group_m // cta_tile_m + # Add expert_idx for each tile in this group + tile_idx_to_expert_idx.extend([i] * num_tiles_in_group) + m_split_cumsum.append(valid_m) + + # Compute num_non_exiting_tiles (number of valid tiles in M dimension) + num_non_exiting_tiles = len(tile_idx_to_expert_idx) + + # Apply padding if requested (for cuda_graph support) + if permuted_m is not None: + if permuted_m < valid_m: + raise ValueError(f"permuted_m ({permuted_m}) must be >= valid_m ({valid_m}). " f"Cannot pad to a smaller size.") + if permuted_m > valid_m: + # Calculate how many padding tiles are needed based on CTA tile M size + num_padding_tiles = (permuted_m - valid_m) // cta_tile_m + # Pad with large negative value (these tiles won't be accessed due to + # num_non_exiting_tiles check) + tile_idx_to_expert_idx.extend([int(-2e9)] * num_padding_tiles) + + # Convert to tensors + tile_idx_to_expert_idx_tensor = torch.tensor(tile_idx_to_expert_idx, device="cuda", dtype=torch.int32) + num_non_exiting_tiles_tensor = torch.tensor([num_non_exiting_tiles], device="cuda", dtype=torch.int32) + num_m_split_cumsum_tensor = torch.tensor(m_split_cumsum, device="cuda", dtype=torch.int32) + + return ( + valid_m, + aligned_group_m_list, + tile_idx_to_expert_idx_tensor, + num_non_exiting_tiles_tensor, + num_m_split_cumsum_tensor, + ) + + +# ============================================================================= +# Tensor Allocation +# ============================================================================= + + +def allocate_grouped_gemm_input_tensors( + n: int, + k: int, + l: int, + group_m_list: List[int], + ab_dtype: torch.dtype, + sf_dtype: torch.dtype, + sf_vec_size: int, + m_aligned: int, + cta_tile_m: int, + permuted_m: Optional[int] = None, + norm_const: float = 1.0, + device: str = "cuda", +) -> Dict[str, Any]: + """Allocate input tensors for grouped GEMM SwiGLU. + + Matches the original create_tensors() implementation. + + :return: Dictionary containing all input tensors and metadata + """ + + ( + valid_m, + aligned_group_m_list, + tile_idx_to_expert_idx, + num_non_exiting_tiles, + num_m_split_cumsum, + ) = create_mask(group_m_list, cta_tile_m, m_aligned, permuted_m) + + tensor_m = permuted_m if permuted_m is not None else valid_m + + # Note: a and b tensors are always K-major + a_ref, a_tensor = create_and_permute_tensor(1, tensor_m, k, False, ab_dtype) + b_ref, b_tensor = create_and_permute_tensor(l, n, k, False, ab_dtype) + + sfa_ref, sfa_tensor = create_scale_factor_tensor(1, tensor_m, k, sf_vec_size, sf_dtype) + sfb_ref, sfb_tensor = create_scale_factor_tensor(l, n, k, sf_vec_size, sf_dtype) + + alpha_tensor = torch.randint(-2, 2, (l,), dtype=torch.float32, device=device).float() + + prob_tensor = torch.randint(-2, 2, (tensor_m, 1, 1), dtype=torch.float32, device=device).float() + + result = { + "a_tensor": a_tensor, + "a_ref": a_ref, + "b_tensor": b_tensor, + "b_ref": b_ref, + "sfa_tensor": sfa_tensor, + "sfa_ref": sfa_ref, + "sfb_tensor": sfb_tensor, + "sfb_ref": sfb_ref, + "alpha_tensor": alpha_tensor, + "prob_tensor": prob_tensor, + "tile_idx_to_expert_idx": tile_idx_to_expert_idx, + "num_non_exiting_tiles": num_non_exiting_tiles, + "num_m_split_cumsum_tensor": num_m_split_cumsum, + "aligned_group_m_list": aligned_group_m_list, + "valid_m": valid_m, + "tensor_m": tensor_m, + "norm_const_tensor": None, + } + + # Norm constant tensor + if ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sf_dtype in [ + torch.float8_e8m0fnu, + torch.float8_e4m3fn, + ]: + result["norm_const_tensor"] = torch.tensor([norm_const], dtype=torch.float32, device=device) + + return result + + +def allocate_grouped_gemm_output_tensors( + tensor_m: int, + n: int, + l: int, + ab_dtype: torch.dtype, + c_dtype: torch.dtype, + d_dtype: torch.dtype, + cd_major: str, + sf_dtype: torch.dtype, + sf_vec_size: int = 16, + device: str = "cuda", +) -> Dict[str, Any]: + """Allocate output tensors for grouped GEMM SwiGLU. + + Matches the original create_tensors() implementation. + + :param c_dtype: Data type for intermediate C tensor (always bfloat16) + :param d_dtype: Data type for output D tensor (fp8 when ab is fp8, bf16 when ab is fp4) + :return: Dictionary containing all output tensors + """ + n_out = n // 2 # After SwiGLU + + _, c_tensor = create_and_permute_tensor(1, tensor_m, n, cd_major == "m", c_dtype) + _, d_tensor = create_and_permute_tensor(1, tensor_m, n_out, cd_major == "m", d_dtype) + _, d_col_tensor = create_and_permute_tensor(1, tensor_m, n_out, cd_major == "m", d_dtype) + + result = { + "c_tensor": c_tensor, + "d_tensor": d_tensor, + "d_col_tensor": d_col_tensor, + "sfd_row_tensor": None, + "sfd_col_tensor": None, + } + + if d_dtype in [torch.bfloat16, torch.float16]: + result["amax_tensor"] = torch.full((l, 1), float("-inf"), dtype=torch.float32, device=device) + + if ab_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and sf_dtype in [ + torch.float8_e8m0fnu, + torch.float8_e4m3fn, + ]: # generate_sfd + sfd_row_ref, sfd_row_tensor = create_scale_factor_tensor(1, tensor_m, n_out, sf_vec_size, sf_dtype) + result["sfd_row_tensor"] = sfd_row_tensor + result["sfd_row_ref"] = sfd_row_ref + + sfd_col_ref, sfd_col_tensor = create_scale_factor_tensor(1, n_out, tensor_m, sf_vec_size, sf_dtype) + result["sfd_col_tensor"] = sfd_col_tensor + result["sfd_col_ref"] = sfd_col_ref + + return result + + +# ============================================================================= +# Reference Implementations +# ============================================================================= + + +def run_grouped_gemm_swiglu_ref( + a_ref: torch.Tensor, + b_ref: torch.Tensor, + sfa_ref: torch.Tensor, + sfb_ref: torch.Tensor, + alpha_tensor: torch.Tensor, + prob_tensor: torch.Tensor, + aligned_group_m_list: List[int], + valid_m: int, + generate_amax: bool = False, + generate_sfd: bool = False, + norm_const_tensor: Optional[torch.Tensor] = None, + c_dtype: torch.dtype = torch.bfloat16, + d_dtype: torch.dtype = torch.float32, + sf_vec_size: int = 16, + sf_dtype: torch.dtype = torch.float8_e8m0fnu, +) -> torch.Tensor: + """Run reference implementation for grouped GEMM SwiGLU. + + Matches the reference checking in continugous_blockscaled_grouped_gemm_swiglu_quant_fusion.py + (lines 4113-4179) + + :param a_ref: A tensor (tensor_m, k, 1) in float32 + :param b_ref: B tensor (n, k, l) in float32 + :param sfa_ref: Scale factor A tensor (tensor_m, k, 1) in float32 + :param sfb_ref: Scale factor B tensor (n, k, l) in float32 + :param alpha_tensor: Per-group alpha scaling (l,) + :param prob_tensor: Per-row probability scaling (tensor_m, 1, 1) + :param aligned_group_m_list: Aligned M values per group + :param valid_m: Total valid M dimension + :param generate_amax: Generate AMAX tensor + :param generate_sfd: Generate SFD tensor + :param norm_const_tensor: Normalization constant tensor (1,) + :param c_dtype: Intermediate C tensor dtype (always bfloat16) + :param d_dtype: Output D tensor dtype + :param sf_vec_size: Scale factor vector size + :param sf_dtype: Scale factor dtype + :return: Reference output tensor (valid_m, n_out, 1) + """ + n, k, l = b_ref.shape + n_out = n // 2 + ref_tensors = {} + + # Step 1: Compute GEMM per group with scale factors + ref = torch.empty((1, valid_m, n), dtype=torch.float32, device=a_ref.device) + start = 0 + for i, group_m in enumerate(aligned_group_m_list): + end = start + group_m + res_a = torch.einsum("mk,mk->mk", a_ref[start:end, :, 0], sfa_ref[start:end, :, 0]) + res_b = torch.einsum("nk,nk->nk", b_ref[:, :, i], sfb_ref[:, :, i]) + ref[0, start:end, :] = torch.einsum("mk,nk->mn", res_a, res_b) + start = end + ref = ref.permute((1, 2, 0)) + + # Step 2: Apply alpha per group + start = 0 + for i, group_m in enumerate(aligned_group_m_list): + end = start + group_m + ref[start:end, :, 0] = ref[start:end, :, 0] * alpha_tensor[i].item() + start = end + + ref_tensors["c_ref"] = ref.clone() + + # Step 3: Apply SwiGLU with interleaved block layout + group = 32 + assert n % group == 0, "N must be divisible by 32 for GLU block grouping" + num_blocks = n // group + assert num_blocks % 2 == 0, "Number of 32-col blocks must be even (pairs of input/gate)" + + cols = torch.arange(n, device=ref.device, dtype=torch.long) + block_cols = cols.view(num_blocks, group) + # up: blocks 0,2,4,6,... (even blocks) + # gate: blocks 1,3,5,7,... (odd blocks) + up_idx = block_cols[0::2].reshape(-1) + gate_idx = block_cols[1::2].reshape(-1) + ref_up = ref.index_select(1, up_idx) + ref_gate = ref.index_select(1, gate_idx) + + # SwiGLU: up * (gate * sigmoid(gate)) + ref_gate = ref_gate * torch.sigmoid(ref_gate) + ref_after_swiglu = ref_up * ref_gate + + # Step 4: Apply prob + ref_after_swiglu = ref_after_swiglu * prob_tensor.expand(-1, n_out, -1) + ref_tensors["d_ref"] = ref_after_swiglu.clone() + + if generate_amax: + amax_ref = torch.empty((l,), dtype=torch.float32, device=a_ref.device) + start = 0 + for i, group_m in enumerate(aligned_group_m_list): + end = start + group_m + amax_ref[i] = compute_reference_amax(ref_after_swiglu[start:end, :, 0].clone()) + start = end + ref_tensors["amax_ref"] = amax_ref + + if generate_sfd: + try: + from cutlass.cute.runtime import from_dlpack + import cutlass.cute as cute + from cudnn.datatypes import _convert_to_cutlass_data_type + except ImportError: + pytest.skip("CUTLASS not available for scale factor conversion") + + norm_const = norm_const_tensor[0].item() + + # 1. Compute reference SFDRow (m, sfn, l) in fp32 + sfn = ceil_div(n_out, sf_vec_size) + # Resahpe ref to (l, m, sfn, sf_vec_size) + ref_for_sf = ref_after_swiglu.permute(2, 0, 1).contiguous() # (l, m, n) + # l is involved in valid_m + ref_for_sf = ref_for_sf.view(1, valid_m, sfn, sf_vec_size) + # Take abs max over sf_vec_size dimension + ref_for_sf, _ = torch.abs(ref_for_sf).max(dim=3) # (l, m, sfn) + # Multiply by norm_const and rcp_limits + ref_sfd_row_f32 = ref_for_sf * norm_const * get_dtype_rcp_limits(d_dtype) + # Permute to (m, sfn, l) + ref_sfd_row_f32 = ref_sfd_row_f32.permute(1, 2, 0) + + # Convert fp32 -> f8 -> fp32 for ref_sfd_row_f32 + ref_sfd_row_f8_torch = torch.empty(*(1, valid_m, sfn), dtype=torch.uint8, device="cuda").permute(1, 2, 0) + ref_sfd_row_f8 = from_dlpack(ref_sfd_row_f8_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1) + ref_sfd_row_f8.element_type = _convert_to_cutlass_data_type(sf_dtype) + ref_sfd_row_f32_device = ref_sfd_row_f32.cuda() + ref_sfd_row_f32_tensor = from_dlpack(ref_sfd_row_f32_device, assumed_align=16).mark_layout_dynamic(leading_dim=1) + cute.testing.convert(ref_sfd_row_f32_tensor, ref_sfd_row_f8) + cute.testing.convert(ref_sfd_row_f8, ref_sfd_row_f32_tensor) + ref_sfd_row_f32 = ref_sfd_row_f32_device.cpu() + + # 2. Convert ref_sfd_row_f32 to scale factor layout and compare with kernel sfd tensor + ref_sfd_row_f32_cute_torch_tensor_cpu, _ = create_sf_layout_tensor(1, valid_m, n_out, sf_vec_size) + + # convert ref_after_swiglu f32 tensor to cute f32 tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref_sfd_row_f32), + from_dlpack(ref_sfd_row_f32_cute_torch_tensor_cpu), + ) + ref_sfd_row_f32 = ref_sfd_row_f32.cuda() + ref_tensors["sfd_row_ref"] = ref_sfd_row_f32_cute_torch_tensor_cpu.clone() + + # 3. Quantized output with scale factor + # Compute reciprocal of ref_sfd_row_f32 and multiply by norm_const + ref_sfd_row_rcp = norm_const * ref_sfd_row_f32.reciprocal() + ref_sfd_row_rcp = torch.clamp(ref_sfd_row_rcp, max=3.40282346638528859812e38) + # Expand the sfn dimension by repeating each value sf_vec_size times + # ref_sfd_row_rcp: (m, sfn, l) -> (m, sfn, sf_vec_size, l) -> (m, n, l) + ref_sfd_row_rcp_expanded = ref_sfd_row_rcp.unsqueeze(2).expand(valid_m, sfn, sf_vec_size, 1) + ref_sfd_row_rcp_expanded = ref_sfd_row_rcp_expanded.reshape(valid_m, sfn * sf_vec_size, 1) + # Trim to exact n dimension if needed + ref_sfd_row_rcp_expanded = ref_sfd_row_rcp_expanded[:, :n_out, :] + + # Apply scale to reference output: ref = ref * ref_sfd_row_rcp + ref_after_row_quant = torch.einsum("mnl,mnl->mnl", ref_after_swiglu, ref_sfd_row_rcp_expanded) + ref_tensors["d_ref"] = ref_after_row_quant.clone() + + # Col Quantized SFD tensor + # 1. Compute reference SFDCol (m, sfn, l) in fp32 + ref_after_swiglu = ref_after_swiglu.permute(2, 1, 0).contiguous().permute(1, 2, 0) + n_after_swiglu = ref_after_swiglu.shape[1] + sfn = ceil_div(n_after_swiglu, sf_vec_size) + valid_m = ref_after_swiglu.shape[0] + # Reshape ref to (l, m, sfn, sf_vec_size) + ref_for_sf = ref_after_swiglu.permute(2, 0, 1).contiguous() # (l, m, n) + # l is involved in valid_m + ref_for_sf = ref_for_sf.view(1, valid_m, sfn, sf_vec_size) + # Take abs max over sf_vec_size dimension + ref_for_sf, _ = torch.abs(ref_for_sf).max(dim=3) # (l, m, sfn) + # Multiply by norm_const and rcp_limits + ref_sfd_row_f32 = ref_for_sf * norm_const * get_dtype_rcp_limits(d_dtype) + # Permute to (m, sfn, l) + ref_sfd_row_f32 = ref_sfd_row_f32.permute(1, 2, 0) + + # Convert fp32 -> f8 -> fp32 for ref_sfd_row_f32 + ref_sfd_row_f8_torch = torch.empty(*(1, valid_m, sfn), dtype=torch.uint8, device="cuda").permute(1, 2, 0) + ref_sfd_row_f8 = from_dlpack(ref_sfd_row_f8_torch, assumed_align=16).mark_layout_dynamic(leading_dim=1) + ref_sfd_row_f8.element_type = _convert_to_cutlass_data_type(sf_dtype) + ref_sfd_row_f32_device = ref_sfd_row_f32.cuda() + ref_sfd_row_f32_tensor = from_dlpack(ref_sfd_row_f32_device, assumed_align=16).mark_layout_dynamic(leading_dim=1) + cute.testing.convert(ref_sfd_row_f32_tensor, ref_sfd_row_f8) + cute.testing.convert(ref_sfd_row_f8, ref_sfd_row_f32_tensor) + ref_sfd_row_f32 = ref_sfd_row_f32_device.cpu() + + # 2. Convert ref_sfd_row_f32 to scale factor layout and compare with kernel sfd tensor + ref_sfd_row_f32_cute_torch_tensor_cpu, _ = create_sf_layout_tensor(1, valid_m, n_after_swiglu, sf_vec_size) + + # convert ref_after_swiglu f32 tensor to cute f32 tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref_sfd_row_f32), + from_dlpack(ref_sfd_row_f32_cute_torch_tensor_cpu), + ) + ref_sfd_row_f32 = ref_sfd_row_f32.cuda() + ref_tensors["sfd_col_ref"] = ref_sfd_row_f32_cute_torch_tensor_cpu.clone() + + # 3. Quantized output with scale factor + # Compute reciprocal of ref_sfd_row_f32 and multiply by norm_const + ref_sfd_row_rcp = norm_const * ref_sfd_row_f32.reciprocal() + ref_sfd_row_rcp = torch.clamp(ref_sfd_row_rcp, max=3.40282346638528859812e38) + # Expand the sfn dimension by repeating each value sf_vec_size times + # ref_sfd_row_rcp: (m, sfn, l) -> (m, sfn, sf_vec_size, l) -> (m, n, l) + ref_sfd_row_rcp_expanded = ref_sfd_row_rcp.unsqueeze(2).expand(valid_m, sfn, sf_vec_size, 1) + ref_sfd_row_rcp_expanded = ref_sfd_row_rcp_expanded.reshape(valid_m, sfn * sf_vec_size, 1) + # Trim to exact n dimension if needed + ref_sfd_row_rcp_expanded = ref_sfd_row_rcp_expanded[:, :n_after_swiglu, :] + + # Apply scale to reference output: ref = ref * ref_sfd_row_rcp + ref_after_row_quant = torch.einsum("mnl,mnl->mnl", ref_after_swiglu, ref_sfd_row_rcp_expanded) + + # Convert ref_after_row_quant : f32 -> f8 -> f32 + ref_ = torch.empty(*(1, valid_m, n_after_swiglu), dtype=torch.uint8, device="cuda").permute(1, 2, 0) + ref_ = from_dlpack(ref_, assumed_align=16).mark_layout_dynamic(leading_dim=1) + ref_.element_type = _convert_to_cutlass_data_type(d_dtype) + ref_device = ref_after_row_quant.cuda() + ref_tensor = from_dlpack(ref_device, assumed_align=16).mark_layout_dynamic(leading_dim=1) + cute.testing.convert(ref_tensor, ref_) + cute.testing.convert(ref_, ref_tensor) + + ref_tensors["d_col_ref"] = ref_device.clone().permute(1, 0, 2) + + return ref_tensors + + +# ============================================================================= +# Reference Checking +# ============================================================================= + + +def check_ref_grouped_gemm_swiglu( + inputs: Dict[str, Any], + outputs: Dict[str, Any], + cfg: Dict[str, Any], + atol: float = 1e-1, + rtol: float = 1e-2, + skip_ref: bool = False, +) -> None: + """Check grouped GEMM SwiGLU result against reference. + + :param inputs: Dictionary of input tensors (from allocate_grouped_gemm_input_tensors) + :param outputs: Dictionary of output tensors (from allocate_grouped_gemm_output_tensors) + :param cfg: Configuration dictionary (from grouped_gemm_swiglu_init) + :param atol: Absolute tolerance + :param rtol: Relative tolerance + :param skip_ref: Skip reference check if True + """ + if skip_ref: + print("Skipping reference check") + return + + # Run reference + ref_tensors = run_grouped_gemm_swiglu_ref( + a_ref=inputs["a_ref"].to(torch.float32), + b_ref=inputs["b_ref"].to(torch.float32), + sfa_ref=inputs["sfa_ref"].to(torch.float32), + sfb_ref=inputs["sfb_ref"].to(torch.float32), + alpha_tensor=inputs["alpha_tensor"], + prob_tensor=inputs["prob_tensor"], + aligned_group_m_list=inputs["aligned_group_m_list"], + valid_m=inputs["valid_m"], + generate_amax=(outputs.get("amax_tensor") is not None), + generate_sfd=(outputs.get("sfd_row_tensor") is not None), + norm_const_tensor=inputs.get("norm_const_tensor"), + c_dtype=cfg["c_dtype"], + d_dtype=cfg["d_dtype"], + sf_vec_size=cfg["sf_vec_size"], + sf_dtype=cfg["sf_dtype"], + ) + + torch.cuda.synchronize() + + c_gpu = outputs["c_tensor"][: inputs["valid_m"]] + c_ref = ref_tensors["c_ref"] + torch.testing.assert_close( + c_gpu.cpu().float(), + c_ref.cpu().to(cfg["c_dtype"]).to(torch.float32), + atol=atol, + rtol=rtol, + ) + + if cfg["d_dtype"] in [torch.float32, torch.float16, torch.bfloat16]: + if ref_tensors.get("amax_ref") is not None: + amax_gpu = outputs["amax_tensor"] + amax_ref = ref_tensors["amax_ref"] + torch.testing.assert_close( + amax_gpu.cpu().squeeze(), + amax_ref.cpu(), + atol=atol, + rtol=rtol, + ) + + d_gpu = outputs["d_tensor"][: inputs["valid_m"]] + d_ref = ref_tensors["d_ref"] + torch.testing.assert_close( + d_gpu.cpu().float(), + d_ref.cpu().to(cfg["d_dtype"]).to(torch.float32), + atol=atol, + rtol=rtol, + ) + elif cfg["d_dtype"] in [torch.float8_e4m3fn, torch.float8_e5m2]: + if ref_tensors.get("sfd_row_ref") is not None: # generate_sfd + # sfd_row_ref + sfd_row_gpu = outputs["sfd_row_tensor"] + sfd_row_ref = ref_tensors["sfd_row_ref"] + torch.testing.assert_close( + sfd_row_gpu.cpu().float(), + sfd_row_ref.cpu().to(torch.float32), + atol=atol, + rtol=rtol, + ) + + # d_ref (row) + d_gpu = outputs["d_tensor"] + d_ref = ref_tensors["d_ref"] + torch.testing.assert_close( + d_gpu.cpu().float(), + d_ref.to(cfg["d_dtype"]).to(torch.float32).cpu(), + atol=atol, + rtol=rtol, + ) + + # sfd_col + if cfg["discrete_col_sfd"]: + # discrete col sfd + group_m_list = inputs["aligned_group_m_list"] + group_n_tile_list = [group // 128 for group in group_m_list] + m_tile = ref_tensors["sfd_col_ref"].shape[2] + + sfd_col_torch_gpu_f8 = outputs["sfd_col_tensor"].cpu().to(torch.float32) + sfd_col_ref_f32 = ref_tensors["sfd_col_ref"].cpu().to(torch.float32) + + res_real_idx = 0 + cumsum_n = 0 + total_n = sum(group_n_tile_list) + for n_tile in group_n_tile_list: + for m_idx in range(m_tile): + for n_idx in range(n_tile): + res_real_m_idx = res_real_idx // total_n + res_real_n_idx = res_real_idx % total_n + + ref_real_n_idx = n_idx + cumsum_n + ref_slice = sfd_col_ref_f32[:, :, m_idx, :, ref_real_n_idx, :] + res_slice = sfd_col_torch_gpu_f8[:, :, res_real_m_idx, :, res_real_n_idx, :] + torch.testing.assert_close( + ref_slice, + res_slice, + atol=atol, + rtol=rtol, + ) + res_real_idx += 1 + cumsum_n += n_tile + else: + # contiguous col sfd + sfd_col_gpu = outputs["sfd_col_tensor"] + sfd_col_ref = ref_tensors["sfd_col_ref"] + torch.testing.assert_close( + sfd_col_gpu.cpu().float(), + sfd_col_ref.cpu().to(torch.float32), + atol=atol, + rtol=rtol, + ) + + # d_col_ref + d_col_gpu = outputs["d_col_tensor"] + d_col_ref = ref_tensors["d_col_ref"] + torch.testing.assert_close( + d_col_gpu.cpu().float(), + d_col_ref.to(cfg["d_dtype"]).to(torch.float32).cpu(), + atol=atol, + rtol=rtol, + ) + else: + # Note: This is outside support surface + d_gpu = outputs["d_tensor"][: inputs["valid_m"]] + d_ref = ref_tensors["d_ref"][: inputs["valid_m"]] + torch.testing.assert_close( + d_gpu.cpu().float(), + d_ref.cpu().to(cfg["d_dtype"]).to(torch.float32), + atol=atol, + rtol=rtol, + ) + + else: + raise NotImplementedError(f"Unsupported dtype: {cfg['d_dtype']}") diff --git a/test/python/pytest.ini b/test/python/pytest.ini index d5da807d..b87edc3c 100644 --- a/test/python/pytest.ini +++ b/test/python/pytest.ini @@ -6,5 +6,4 @@ markers = L3: specifies L3 level (use -m L3) L4: specifies L4 level (use -m L4) addopts = - -m L0 --tb=short - + -m L0 --tb=short --no-header diff --git a/test/python/sdpa/blocked.py b/test/python/sdpa/blocked.py new file mode 100644 index 00000000..2f5526db --- /dev/null +++ b/test/python/sdpa/blocked.py @@ -0,0 +1,112 @@ +# Blocked tests configuration +# Format: "test_name": {"sms": ["SM_90", "SM_100"], "cudnn_versions": ["91100"]} +# - sms: List of GPU architectures to block on (e.g., "SM_90", "SM_100") +# - cudnn_versions: List of cuDNN versions to block on (e.g., "91100") +# If a field is None or missing, the test is blocked on all values for that field. + +# fmt: off + +BLOCKED_TESTS = { + # Currently empty - add blocked tests as needed + # Example entries: + # "test_sdpa_random_bwd[test64]": {"sms": ["SM_90", "SM_100"], "cudnn_versions": ["91100"]}, + # "test_sdpa_random_bwd[test65]": {"sms": ["SM_100"], "cudnn_versions": ["91100", "91000"]}, + # "test_sdpa_random_bwd[test66]": {"sms": ["SM_80"]}, + # "test_sdpa_random_bwd[test67]": {"cudnn_versions": ["90000"]}, + # "test_sdpa_random_bwd[test68]": {}, + + # FP8 forward edge cases producing NaN - blocked until investigated + # Original test_sdpa_fp8.py only tested: h_q=h_k=h_v=4, s_kv=256/1024, d_qk=64/128/192, d_v=64/128 + # + # | Test | s_q | s_kv | h_q | h_k | d_qk | d_v | dtype | otype | Issue | + # |---------|-----|------|-----|-----|------|-----|---------|--------|-------------------------------| + # | test14 | 89 | 569 | 8 | 1 | 128 | 128 | e5m2 | fp16 | e5m2+GQA+non-aligned s_q | + # | test17 | 207 | 207 | 9 | 9 | 120 | 120 | e4m3 | e4m3 | d_qk=120 not multiple of 16 | + # | test18 | 766 | 766 | 13 | 1 | 192 | 128 | e5m2 | fp16 | e5m2+d_qk=192+GQA | + # | test21 | 1 | 936 | 10 | 5 | 64 | 64 | e4m3 | e5m2 | s_q=1 + GQA + mixed fp8 out | + # | test40 | 1 | 552 | 3 | 3 | 64 | 64 | e4m3 | e4m3 | s_q=1 + MHA | + # | test41 | 1 | 225 | 11 | 11 | 64 | 64 | e4m3 | fp16 | s_q=1 + MHA | + # | test42 | 896 | 896 | 13 | 13 | 192 | 128 | e4m3 | fp16 | d_qk=192 + large MHA | + # | test57 | 1 | 949 | 8 | 8 | 64 | 64 | e5m2 | e4m3 | s_q=1 + MHA + mixed fp8 out | + # | test64 | 1 | 489 | 9 | 1 | 64 | 64 | e5m2 | fp16 | s_q=1 + GQA + e5m2 | + # | test73 | 1 | 321 | 9 | 1 | 64 | 64 | e4m3 | fp16 | s_q=1 + GQA | + # | test86 | 1 | 375 | 8 | 2 | 64 | 64 | e5m2 | fp16 | s_q=1 + GQA + e5m2 | + # | test90 | 1 | 213 | 12 | 3 | 64 | 64 | e4m3 | fp16 | s_q=1 + GQA | + # | test96 | 1 | 132 | 13 | 1 | 64 | 64 | e4m3 | fp16 | s_q=1 + GQA | + # | test128 | 1 | 707 | 10 | 1 | 64 | 64 | e4m3 | e5m2 | s_q=1 + GQA + mixed fp8 out | + "test_sdpa_fp8_fwd_L0[test14]": {}, + "test_sdpa_fp8_fwd_L0[test17]": {}, + "test_sdpa_fp8_fwd_L0[test18]": {}, + "test_sdpa_fp8_fwd_L0[test21]": {}, + "test_sdpa_fp8_fwd_L0[test40]": {}, + "test_sdpa_fp8_fwd_L0[test41]": {}, + "test_sdpa_fp8_fwd_L0[test42]": {}, + "test_sdpa_fp8_fwd_L0[test57]": {}, + "test_sdpa_fp8_fwd_L0[test64]": {}, + "test_sdpa_fp8_fwd_L0[test73]": {}, + "test_sdpa_fp8_fwd_L0[test86]": {}, + "test_sdpa_fp8_fwd_L0[test90]": {}, + "test_sdpa_fp8_fwd_L0[test96]": {}, + "test_sdpa_fp8_fwd_L0[test128]": {}, + + # Ragged backward tests failing on Ampere (SM_80) - disallowed mismatches + "test_sdpa_random_bwd_ragged_L0[test2]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test13]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test40]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test41]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test59]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test60]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test66]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test72]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test91]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test96]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test111]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test116]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test126]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test131]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test133]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test136]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test139]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test144]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test145]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test153]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test155]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test162]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test163]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test166]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test188]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test192]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test213]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test218]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test220]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test235]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test237]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test238]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test241]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test243]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test247]": {"sms": ["SM_80"]}, + "test_sdpa_random_bwd_ragged_L0[test256]": {"sms": ["SM_80"]}, +} + + +def show_blocked_tests(blocked_tests, gpu_arch, cudnn_ver): + print(f"\n\nBlocked tests on {gpu_arch} and cudnn_ver={cudnn_ver}:") + if blocked_tests: + for index, test in enumerate(blocked_tests): + print(f"{index+1:<4} : {test}") + else: + print("[empty]") + +def fetch_blocked_tests(gpu_arch, cudnn_ver): + """ + Returns a list of test names that should be blocked for the given GPU architecture + and cuDNN version. + """ + assert type(gpu_arch) == type(cudnn_ver) == str, "expecting strings" + blocked_tests = [] + for test, config in BLOCKED_TESTS.items(): + sms = config.get("sms") + libs = config.get("cudnn_versions") + if (test not in blocked_tests) and (sms is None or gpu_arch in sms) and (libs is None or cudnn_ver in libs): + blocked_tests.append(test) + return blocked_tests diff --git a/test/python/sdpa/fp16.py b/test/python/sdpa/fp16.py new file mode 100644 index 00000000..72cb102c --- /dev/null +++ b/test/python/sdpa/fp16.py @@ -0,0 +1,600 @@ +import cudnn +import pytest +import torch +from enum import IntEnum +from looseversion import LooseVersion + +from .fp16_ref import compute_ref +from .helpers import ( + convert_to_cudnn_type, + exact_equal, + approx_equal, + alloc_tensor, + prefix_sum, + convert_packed_to_uniform, + convert_uniform_to_packed, + create_container_and_page_table, + time_execution, + profile_execution, +) + +# fmt: off + +class TensorUid(IntEnum): + q = 0 + k = 1 + v = 2 + o = 3 + stats = 4 + bias = 5 + dQ = 6 + dK = 7 + dV = 8 + dO = 9 + dBias = 10 + seq_len_q = 11 + seq_len_kv = 12 + q_ragged_offset = 13 + k_ragged_offset = 14 + v_ragged_offset = 15 + o_ragged_offset = 16 + stats_ragged_offset = 17 + seed = 18 + offset = 19 + rng_dump = 20 + block_mask = 21 + container_k = 22 + container_v = 23 + page_table_k = 24 + page_table_v = 25 + workspace = 26 + + +def validate_config(cfg): + if not all((x > 0 and type(x) == int) for x in (cfg.batches, cfg.d_qk, cfg.d_v, cfg.s_q, cfg.s_kv, cfg.h_q, cfg.h_k, cfg.h_v)): + assert False, "tensor dimensions must be integer and positive" + + assert cfg.shape_q == (cfg.batches, cfg.h_q, cfg.s_q, cfg.d_qk), f"wrong shape_q={cfg.shape_q}" + assert cfg.shape_k == (cfg.batches, cfg.h_k, cfg.s_kv, cfg.d_qk), f"wrong shape_k={cfg.shape_k}" + assert cfg.shape_v == (cfg.batches, cfg.h_v, cfg.s_kv, cfg.d_v), f"wrong shape_v={cfg.shape_v}" + assert cfg.shape_o == (cfg.batches, cfg.h_q, cfg.s_q, cfg.d_v), f"wrong shape_o={cfg.shape_o}" + + if cfg.is_train: + assert cfg.is_paged == False and cfg.block_size == None, "paged attention not allowed in backward pass" + + if cfg.is_ragged: + assert cfg.is_padding == True, "is_ragged=True and is_padding=False not allowed" + + assert isinstance(cfg.seq_len_q, (list, tuple)), "input 'seq_len_q' must be list or tuple" + if cfg.is_padding: + assert len(cfg.seq_len_q) == cfg.batches, f"wrong 'seq_len_q' length" + else: + assert len(cfg.seq_len_q) == 0, f"wrong 'seq_len_q' length, expecting 0" + + assert isinstance(cfg.seq_len_kv, (list, tuple)), "input 'seq_len_kv' must be list or tuple" + if cfg.is_padding: + assert len(cfg.seq_len_kv) == cfg.batches, f"wrong 'seq_len_kv' length, expecting {cfg.batches}" + else: + assert len(cfg.seq_len_kv) == 0, f"wrong 'seq_len_kv' length, expecting 0" + + assert all(x >= 0 and type(x) == int for x in cfg.seq_len_q), f"wrong seq_len_q={cfg.seq_len_q}" + assert all(x >= 0 and type(x) == int for x in cfg.seq_len_kv), f"wrong seq_len_kv={cfg.seq_len_kv}" + + cudnn_version = LooseVersion(cudnn.backend_version_string()) + if cudnn_version < "9.10.0": + print("@@@@ Overall result: WAIVED, test_mhas_v2.py supports cudnn 9.10.0 or higher.") + pytest.skip("test_mhas_v2.py requires cudnn 9.10.0 or higher") + + if cudnn_version < "9.13.1" and cfg.implementation == cudnn.attention_implementation.UNIFIED: + print("@@@@ Overall result: WAIVED, unified SDPA implementation requires cudnn 9.13.1 or higher.") + pytest.skip("unified SDPA implementation requires cudnn 9.13.1 or higher") + + if cfg.s_q == cfg.s_kv == 1: + print("@@@@ Overall result: WAIVED, skipping known issue of s_q == s_kv == 1.") + pytest.skip("skipping known issue of s_q == s_kv == 1") + + +def allocate_tensors(cfg, rng_data_gen): + allocs = {} + max_t_q = max(64, ((sum(cfg.seq_len_q) + 63) // 64) * 64) if cfg.is_ragged else None + max_t_kv = max(64, ((sum(cfg.seq_len_kv) + 63) // 64) * 64) if cfg.is_ragged else None + + if cfg.is_ragged: + allocs[TensorUid.q] = alloc_tensor((max_t_q, cfg.h_q, cfg.d_qk), cfg.data_type, rng=rng_data_gen, mean=-0.5, std=1.0) + allocs[TensorUid.k] = alloc_tensor((max_t_kv, cfg.h_k, cfg.d_qk), cfg.data_type, rng=rng_data_gen, mean=-0.5, std=1.0) + allocs[TensorUid.v] = alloc_tensor((max_t_kv, cfg.h_v, cfg.d_v), cfg.data_type, rng=rng_data_gen, mean=-0.5, std=1.0) + allocs[TensorUid.o] = alloc_tensor((max_t_q, cfg.h_q, cfg.d_v), cfg.data_type) + allocs[TensorUid.stats] = alloc_tensor((max_t_q, cfg.h_q, 1), torch.float32) if cfg.is_train else (None, None, None) + if cfg.is_train: + allocs[TensorUid.dQ] = alloc_tensor((max_t_q, cfg.h_q, cfg.d_qk), cfg.data_type) + allocs[TensorUid.dK] = alloc_tensor((max_t_kv, cfg.h_k, cfg.d_qk), cfg.data_type) + allocs[TensorUid.dV] = alloc_tensor((max_t_kv, cfg.h_v, cfg.d_v), cfg.data_type) + allocs[TensorUid.dO] = alloc_tensor((max_t_q, cfg.h_q, cfg.d_v), cfg.data_type, rng=rng_data_gen, mean=0.0, std=0.1) + else: + allocs[TensorUid.q] = alloc_tensor(cfg.shape_q, cfg.data_type, strides=cfg.stride_q, rng=rng_data_gen, mean=-0.5, std=1.0) + allocs[TensorUid.k] = alloc_tensor(cfg.shape_k, cfg.data_type, strides=cfg.stride_k, rng=rng_data_gen, mean=-0.5, std=1.0) + allocs[TensorUid.v] = alloc_tensor(cfg.shape_v, cfg.data_type, strides=cfg.stride_v, rng=rng_data_gen, mean=-0.5, std=1.0) + allocs[TensorUid.o] = alloc_tensor(cfg.shape_o, cfg.data_type, strides=cfg.stride_o) + allocs[TensorUid.stats] = alloc_tensor((cfg.batches, cfg.h_q, cfg.s_q, 1), torch.float32) if cfg.is_train else (None, None, None) + if cfg.is_train: + allocs[TensorUid.dQ] = alloc_tensor(cfg.shape_q, cfg.data_type, strides=cfg.stride_q) + allocs[TensorUid.dK] = alloc_tensor(cfg.shape_k, cfg.data_type, strides=cfg.stride_k) + allocs[TensorUid.dV] = alloc_tensor(cfg.shape_v, cfg.data_type, strides=cfg.stride_v) + allocs[TensorUid.dO] = alloc_tensor(cfg.shape_o, cfg.data_type, strides=cfg.stride_o, rng=rng_data_gen, mean=0.0, std=0.1) + + seq_len_q_gpu = torch.tensor(cfg.seq_len_q, dtype=torch.int32, device="cuda").view(-1, 1, 1, 1) if len(cfg.seq_len_q) > 0 else None + seq_len_kv_gpu = torch.tensor(cfg.seq_len_kv, dtype=torch.int32, device="cuda").view(-1, 1, 1, 1) if len(cfg.seq_len_kv) > 0 else None + allocs[TensorUid.seq_len_q] = (seq_len_q_gpu, None, None) + allocs[TensorUid.seq_len_kv] = (seq_len_kv_gpu, None, None) + + if cfg.is_ragged: + allocs[TensorUid.q_ragged_offset] = ((prefix_sum(seq_len_q_gpu) * cfg.h_q * cfg.d_qk).to(torch.int64), None, None) + allocs[TensorUid.k_ragged_offset] = ((prefix_sum(seq_len_kv_gpu) * cfg.h_k * cfg.d_qk).to(torch.int64), None, None) + allocs[TensorUid.v_ragged_offset] = ((prefix_sum(seq_len_kv_gpu) * cfg.h_v * cfg.d_v).to(torch.int64), None, None) + allocs[TensorUid.o_ragged_offset] = ((prefix_sum(seq_len_q_gpu) * cfg.h_q * cfg.d_v).to(torch.int64), None, None) + allocs[TensorUid.stats_ragged_offset] = ((prefix_sum(seq_len_q_gpu) * cfg.h_q * 1).to(torch.int64), None, None) + + if cfg.is_bias: + allocs[TensorUid.bias] = alloc_tensor((1, cfg.h_q, cfg.s_q, cfg.s_kv), cfg.data_type, rng=rng_data_gen, mean=0.0, std=1.0) + if cfg.is_train and cfg.is_bias: + allocs[TensorUid.dBias] = alloc_tensor((1, cfg.h_q, cfg.s_q, cfg.s_kv), cfg.data_type) + + if cfg.is_block_mask: + TILE_M, TILE_N = 128, 128 + block_mask_gpu = torch.randint(0, 256, (cfg.batches, cfg.h_q, (cfg.s_q + TILE_M - 1) // TILE_M, ((cfg.s_kv + TILE_N - 1) // TILE_N + 7) // 8), dtype=torch.uint8, device="cuda") + allocs[TensorUid.block_mask] = (block_mask_gpu, None, None) + + if cfg.is_dropout: + allocs[TensorUid.seed] = (torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda"), None, None) + allocs[TensorUid.offset] = (torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda"), None, None) + allocs[TensorUid.rng_dump] = (torch.zeros((cfg.batches, cfg.h_q, cfg.s_q, cfg.s_kv), dtype=torch.float32, device="cuda"), None, None) + + if cfg.is_paged: + container_k, page_table_k = create_container_and_page_table(allocs[TensorUid.k][0], cfg.block_size) + container_v, page_table_v = create_container_and_page_table(allocs[TensorUid.v][0], cfg.block_size) + allocs[TensorUid.container_k] = (container_k, None, None) + allocs[TensorUid.container_v] = (container_v, None, None) + allocs[TensorUid.page_table_k] = (page_table_k, None, None) + allocs[TensorUid.page_table_v] = (page_table_v, None, None) + + tensors = {uid: alloc[0] for uid, alloc in allocs.items()} + return allocs, tensors, max_t_q, max_t_kv + + +def create_forward_graph(cfg, tensors, cudnn_handle): + cudnn_dtype = convert_to_cudnn_type(cfg.data_type) + stream = torch.cuda.current_stream().cuda_stream + cudnn.set_stream(handle=cudnn_handle, stream=stream) + graph = cudnn.pygraph( + io_data_type=cudnn_dtype, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=cudnn_handle, + ) + + q = graph.tensor(uid=int(TensorUid.q), dim=cfg.shape_q, stride=cfg.stride_q, data_type=cudnn_dtype) + k = graph.tensor(uid=int(TensorUid.k), dim=cfg.shape_k, stride=cfg.stride_k, data_type=cudnn_dtype) + v = graph.tensor(uid=int(TensorUid.v), dim=cfg.shape_v, stride=cfg.stride_v, data_type=cudnn_dtype) + + page_table_k = page_table_v = paged_attention_max_seq_len_kv = None + if cfg.is_paged: + container_k_gpu = tensors.get(TensorUid.container_k) + container_v_gpu = tensors.get(TensorUid.container_v) + page_table_k_gpu = tensors.get(TensorUid.page_table_k) + page_table_v_gpu = tensors.get(TensorUid.page_table_v) + k = graph.tensor(uid=int(TensorUid.container_k), dim=container_k_gpu.size(), stride=container_k_gpu.stride(), data_type=cudnn_dtype) + v = graph.tensor(uid=int(TensorUid.container_v), dim=container_v_gpu.size(), stride=container_v_gpu.stride(), data_type=cudnn_dtype) + page_table_k = graph.tensor(uid=int(TensorUid.page_table_k), dim=page_table_k_gpu.size(), stride=page_table_k_gpu.stride(), data_type=cudnn.data_type.INT32) + page_table_v = graph.tensor(uid=int(TensorUid.page_table_v), dim=page_table_v_gpu.size(), stride=page_table_v_gpu.stride(), data_type=cudnn.data_type.INT32) + paged_attention_max_seq_len_kv = cfg.s_kv + + bias = graph.tensor(uid=int(TensorUid.bias), dim=(1, cfg.h_q, cfg.s_q, cfg.s_kv), stride=(cfg.h_q * cfg.s_q * cfg.s_kv, cfg.s_q * cfg.s_kv, cfg.s_kv, 1), data_type=cudnn_dtype) if cfg.is_bias else None + + TILE_M, TILE_N = 128, 128 + block_mask_dim = (cfg.batches, cfg.h_q, (cfg.s_q + TILE_M - 1) // TILE_M, ((cfg.s_kv + TILE_N - 1) // TILE_N + 7) // 8) + block_mask = graph.tensor(uid=int(TensorUid.block_mask), dim=block_mask_dim, stride=(block_mask_dim[1]*block_mask_dim[2]*block_mask_dim[3], block_mask_dim[2]*block_mask_dim[3], block_mask_dim[3], 1), data_type=cudnn.data_type.UINT8) if cfg.is_block_mask else None + + seq_len_q = graph.tensor(uid=int(TensorUid.seq_len_q), dim=(cfg.batches, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32) if cfg.is_padding else None + seq_len_kv = graph.tensor(uid=int(TensorUid.seq_len_kv), dim=(cfg.batches, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32) if cfg.is_padding else None + + seed = offset = dropout_tuple = rng_dump = None + if cfg.is_dropout: + seed = graph.tensor(uid=int(TensorUid.seed), dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) + offset = graph.tensor(uid=int(TensorUid.offset), dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) + dropout_tuple = (cfg.dropout_prob, seed, offset) + rng_dump = graph.tensor(uid=int(TensorUid.rng_dump), dim=(cfg.batches, cfg.h_q, cfg.s_q, cfg.s_kv), stride=(cfg.h_q * cfg.s_q * cfg.s_kv, cfg.s_q * cfg.s_kv, cfg.s_kv, 1), data_type=cudnn.data_type.FLOAT) + + q_ragged_offset = graph.tensor(uid=int(TensorUid.q_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) if cfg.is_ragged else None + k_ragged_offset = graph.tensor(uid=int(TensorUid.k_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) if cfg.is_ragged else None + v_ragged_offset = graph.tensor(uid=int(TensorUid.v_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) if cfg.is_ragged else None + o_ragged_offset = graph.tensor(uid=int(TensorUid.o_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) if cfg.is_ragged else None + stats_ragged_offset = graph.tensor(uid=int(TensorUid.stats_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) if cfg.is_ragged and cfg.is_train else None + + if cfg.is_ragged: + q.set_ragged_offset(q_ragged_offset) + k.set_ragged_offset(k_ragged_offset) + v.set_ragged_offset(v_ragged_offset) + + attn_scale = 0.125 + + o, stats = graph.sdpa( + name="sdpa_forward", + q=q, k=k, v=v, + generate_stats=cfg.is_train, + attn_scale=attn_scale, + bias=bias, + block_mask=block_mask, + use_alibi_mask=cfg.is_alibi, + use_padding_mask=cfg.is_padding, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + diagonal_band_left_bound=cfg.left_bound, + diagonal_band_right_bound=cfg.right_bound, + diagonal_alignment=cfg.diag_align, + dropout=dropout_tuple, + rng_dump=rng_dump, + paged_attention_k_table=page_table_k, + paged_attention_v_table=page_table_v, + paged_attention_max_seq_len_kv=paged_attention_max_seq_len_kv, + implementation=cfg.implementation, + ) + + o.set_uid(int(TensorUid.o)).set_output(True).set_dim(cfg.shape_o).set_stride(cfg.stride_o) + if cfg.is_ragged: + o.set_ragged_offset(o_ragged_offset) + + if cfg.is_train: + dim_stats = (cfg.batches, cfg.h_q, cfg.s_q, 1) + stride_stats = (cfg.s_q * cfg.h_q, 1, cfg.h_q, 1) if cfg.is_ragged else (cfg.h_q * cfg.s_q, cfg.s_q, 1, 1) + stats.set_uid(int(TensorUid.stats)).set_output(True).set_data_type(cudnn.data_type.FLOAT).set_dim(dim_stats).set_stride(stride_stats) + if cfg.is_ragged: + stats.set_ragged_offset(stats_ragged_offset) + + try: + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + except cudnn.cudnnGraphNotSupportedError as e: + print(f"@@@@ Overall result: WAIVED, not supported forward graph. {e}") + pytest.skip("not supported forward graph") + except Exception as e: + print(f"@@@@ Overall result: FAILED, unexpected '{e.__class__.__name__}' exception during forward graph build. {e}") + pytest.fail("unexpected exception during forward graph build", pytrace=False) + + variant_pack = { + int(TensorUid.q): tensors.get(TensorUid.q), + int(TensorUid.container_k) if cfg.is_paged else int(TensorUid.k): tensors.get(TensorUid.container_k) if cfg.is_paged else tensors.get(TensorUid.k), + int(TensorUid.container_v) if cfg.is_paged else int(TensorUid.v): tensors.get(TensorUid.container_v) if cfg.is_paged else tensors.get(TensorUid.v), + int(TensorUid.bias): tensors.get(TensorUid.bias), + int(TensorUid.block_mask): tensors.get(TensorUid.block_mask), + int(TensorUid.seq_len_q): tensors.get(TensorUid.seq_len_q), + int(TensorUid.seq_len_kv): tensors.get(TensorUid.seq_len_kv), + int(TensorUid.q_ragged_offset): tensors.get(TensorUid.q_ragged_offset), + int(TensorUid.k_ragged_offset): tensors.get(TensorUid.k_ragged_offset), + int(TensorUid.v_ragged_offset): tensors.get(TensorUid.v_ragged_offset), + int(TensorUid.o_ragged_offset): tensors.get(TensorUid.o_ragged_offset), + int(TensorUid.stats_ragged_offset): tensors.get(TensorUid.stats_ragged_offset), + int(TensorUid.o): tensors.get(TensorUid.o), + int(TensorUid.stats): tensors.get(TensorUid.stats), + int(TensorUid.page_table_k): tensors.get(TensorUid.page_table_k), + int(TensorUid.page_table_v): tensors.get(TensorUid.page_table_v), + int(TensorUid.seed): tensors.get(TensorUid.seed), + int(TensorUid.offset): tensors.get(TensorUid.offset), + int(TensorUid.rng_dump): tensors.get(TensorUid.rng_dump), + } + variant_pack = {k: v for k, v in variant_pack.items() if v is not None} + + return graph, variant_pack + + +def create_backward_graph(cfg, tensors, cudnn_handle, max_t_q, max_t_kv): + cudnn_dtype = convert_to_cudnn_type(cfg.data_type) + stream = torch.cuda.current_stream().cuda_stream + cudnn.set_stream(handle=cudnn_handle, stream=stream) + sm_version = torch.cuda.get_device_capability()[0] * 10 + torch.cuda.get_device_capability()[1] + + graph = cudnn.pygraph( + io_data_type=cudnn_dtype, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=cudnn_handle, + sm_version=sm_version + ) + + dim_stats = (cfg.batches, cfg.h_q, cfg.s_q, 1) + stride_stats = (cfg.s_q * cfg.h_q, 1, cfg.h_q, 1) if cfg.is_ragged else (cfg.h_q * cfg.s_q, cfg.s_q, 1, 1) + + q = graph.tensor(uid=int(TensorUid.q), dim=cfg.shape_q, stride=cfg.stride_q, data_type=cudnn_dtype) + k = graph.tensor(uid=int(TensorUid.k), dim=cfg.shape_k, stride=cfg.stride_k, data_type=cudnn_dtype) + v = graph.tensor(uid=int(TensorUid.v), dim=cfg.shape_v, stride=cfg.stride_v, data_type=cudnn_dtype) + o = graph.tensor(uid=int(TensorUid.o), dim=cfg.shape_o, stride=cfg.stride_o, data_type=cudnn_dtype) + dO = graph.tensor(uid=int(TensorUid.dO), dim=cfg.shape_o, stride=cfg.stride_o, data_type=cudnn_dtype) + stats = graph.tensor(uid=int(TensorUid.stats), dim=dim_stats, stride=stride_stats, data_type=cudnn.data_type.FLOAT) + + bias_dim = (1, cfg.h_q, cfg.s_q, cfg.s_kv) + bias_stride = (cfg.h_q * cfg.s_q * cfg.s_kv, cfg.s_q * cfg.s_kv, cfg.s_kv, 1) + bias = graph.tensor(uid=int(TensorUid.bias), dim=bias_dim, stride=bias_stride, data_type=cudnn_dtype) if cfg.is_bias else None + dBias = graph.tensor(uid=int(TensorUid.dBias), dim=bias_dim, stride=bias_stride, data_type=cudnn_dtype) if cfg.is_bias else None + + seq_len_q = graph.tensor(uid=int(TensorUid.seq_len_q), dim=(cfg.batches, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32) if cfg.is_padding else None + seq_len_kv = graph.tensor(uid=int(TensorUid.seq_len_kv), dim=(cfg.batches, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32) if cfg.is_padding else None + + seed = offset = dropout_tuple = None + if cfg.is_dropout: + seed = graph.tensor(uid=int(TensorUid.seed), dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) + offset = graph.tensor(uid=int(TensorUid.offset), dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) + dropout_tuple = (cfg.dropout_prob, seed, offset) + + attn_scale = 0.125 + + dQ, dK, dV = graph.sdpa_backward( + name="sdpa_backward", + q=q, k=k, v=v, o=o, dO=dO, stats=stats, + attn_scale=attn_scale, + bias=bias, + dBias=dBias, + use_alibi_mask=cfg.is_alibi, + use_padding_mask=cfg.is_padding, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + max_total_seq_len_q=max_t_q, + max_total_seq_len_kv=max_t_kv, + diagonal_band_left_bound=cfg.left_bound, + diagonal_band_right_bound=cfg.right_bound, + diagonal_alignment=cfg.diag_align, + dropout=dropout_tuple, + use_deterministic_algorithm=cfg.is_determin, + ) + + dQ.set_uid(int(TensorUid.dQ)).set_output(True).set_dim(cfg.shape_q).set_stride(cfg.stride_q) + dK.set_uid(int(TensorUid.dK)).set_output(True).set_dim(cfg.shape_k).set_stride(cfg.stride_k) + dV.set_uid(int(TensorUid.dV)).set_output(True).set_dim(cfg.shape_v).set_stride(cfg.stride_v) + + if cfg.is_ragged: + q_ragged_offset = graph.tensor(uid=int(TensorUid.q_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) + k_ragged_offset = graph.tensor(uid=int(TensorUid.k_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) + v_ragged_offset = graph.tensor(uid=int(TensorUid.v_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) + o_ragged_offset = graph.tensor(uid=int(TensorUid.o_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) + stats_ragged_offset = graph.tensor(uid=int(TensorUid.stats_ragged_offset), dim=(cfg.batches + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64) + q.set_ragged_offset(q_ragged_offset) + k.set_ragged_offset(k_ragged_offset) + v.set_ragged_offset(v_ragged_offset) + o.set_ragged_offset(o_ragged_offset) + stats.set_ragged_offset(stats_ragged_offset) + dQ.set_ragged_offset(q_ragged_offset) + dK.set_ragged_offset(k_ragged_offset) + dV.set_ragged_offset(v_ragged_offset) + dO.set_ragged_offset(o_ragged_offset) + + try: + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + except cudnn.cudnnGraphNotSupportedError as e: + print(f"@@@@ Overall result: WAIVED, not supported backward graph. {e}") + pytest.skip("not supported backward graph") + except Exception as e: + print(f"@@@@ Overall result: FAILED, unexpected '{e.__class__.__name__}' exception during backward graph build. {e}") + pytest.fail("unexpected exception during backward graph build", pytrace=False) + + variant_pack = { + int(TensorUid.q): tensors.get(TensorUid.q), + int(TensorUid.k): tensors.get(TensorUid.k), + int(TensorUid.v): tensors.get(TensorUid.v), + int(TensorUid.o): tensors.get(TensorUid.o), + int(TensorUid.stats): tensors.get(TensorUid.stats), + int(TensorUid.dQ): tensors.get(TensorUid.dQ), + int(TensorUid.dK): tensors.get(TensorUid.dK), + int(TensorUid.dV): tensors.get(TensorUid.dV), + int(TensorUid.dO): tensors.get(TensorUid.dO), + int(TensorUid.bias): tensors.get(TensorUid.bias), + int(TensorUid.dBias): tensors.get(TensorUid.dBias), + int(TensorUid.seq_len_q): tensors.get(TensorUid.seq_len_q), + int(TensorUid.seq_len_kv): tensors.get(TensorUid.seq_len_kv), + int(TensorUid.q_ragged_offset): tensors.get(TensorUid.q_ragged_offset), + int(TensorUid.k_ragged_offset): tensors.get(TensorUid.k_ragged_offset), + int(TensorUid.v_ragged_offset): tensors.get(TensorUid.v_ragged_offset), + int(TensorUid.o_ragged_offset): tensors.get(TensorUid.o_ragged_offset), + int(TensorUid.stats_ragged_offset): tensors.get(TensorUid.stats_ragged_offset), + int(TensorUid.seed): tensors.get(TensorUid.seed), + int(TensorUid.offset): tensors.get(TensorUid.offset), + } + variant_pack = {k: v for k, v in variant_pack.items() if v is not None} + + return graph, variant_pack + + +def check_deterministic(cfg, tensors, allocs, bwd_graph, bwd_pack, cudnn_handle, request): + if not cfg.is_determin: + return + + dQ_gpu = tensors.get(TensorUid.dQ) + dK_gpu = tensors.get(TensorUid.dK) + dV_gpu = tensors.get(TensorUid.dV) + workspace = allocs[TensorUid.workspace] + + dQ_gpu_rerun = dQ_gpu.clone().detach() + dK_gpu_rerun = dK_gpu.clone().detach() + dV_gpu_rerun = dV_gpu.clone().detach() + + torch.fill_(dQ_gpu, float("nan")) + torch.fill_(dK_gpu, float("nan")) + torch.fill_(dV_gpu, float("nan")) + bwd_graph.execute(bwd_pack, workspace[0], cudnn_handle) + torch.cuda.synchronize() + + determin_err_count = 0 + determin_err_count += exact_equal(dQ_gpu, dQ_gpu_rerun, tag="dQ_determin", disp_elems=request.config.getoption("--diffs")) + determin_err_count += exact_equal(dK_gpu, dK_gpu_rerun, tag="dK_determin", disp_elems=request.config.getoption("--diffs")) + determin_err_count += exact_equal(dV_gpu, dV_gpu_rerun, tag="dV_determin", disp_elems=request.config.getoption("--diffs")) + + if determin_err_count != 0: + print("@@@@ Overall result: FAILED, determinism check failed - outputs differ between runs.") + pytest.fail("determinism check failed", pytrace=False) + print("@@@@ Determinism check: PASSED, dQ, dK, dV bitwise match between runs.") + + +def execute_graph(graph, variant_pack, allocs, tensors, cudnn_handle, request, label="Graph"): + workspace = alloc_tensor(graph.get_workspace_size(), torch.uint8) + allocs[TensorUid.workspace] = workspace + tensors[TensorUid.workspace] = workspace[0] + + if request.config.getoption("--perf"): + times_ms = time_execution(graph.execute, variant_pack, workspace[0], cudnn_handle) + print(f"@@@@ {label} graph.execute avg_time_ms={times_ms.mean().item():.3f}") + profile_execution(graph.execute, variant_pack, workspace[0], cudnn_handle) + + graph.execute(variant_pack, workspace[0], cudnn_handle) + torch.cuda.synchronize() + + if workspace[1] is not None and not torch.all(workspace[1]==-1).item(): + print(f"@@@@ Overall result: FAILED, {label} workspace overwritten outside its boundaries.") + print(workspace[1]) + pytest.fail(f"{label} workspace overwritten outside boundaries", pytrace=False) + + +def compute_and_compare_reference(cfg, allocs, tensors, diffs): + cudnn_version = LooseVersion(cudnn.backend_version_string()) + + q_gpu = tensors.get(TensorUid.q) + k_gpu = tensors.get(TensorUid.k) + v_gpu = tensors.get(TensorUid.v) + dO_gpu = tensors.get(TensorUid.dO) + seq_len_q_gpu = tensors.get(TensorUid.seq_len_q) + seq_len_kv_gpu = tensors.get(TensorUid.seq_len_kv) + block_mask_gpu = tensors.get(TensorUid.block_mask) + bias_gpu = tensors.get(TensorUid.bias) + rng_dump_gpu = tensors.get(TensorUid.rng_dump) + + q_ref = q_gpu.detach().float() + k_ref = k_gpu.detach().float() + v_ref = v_gpu.detach().float() + dO_ref = dO_gpu.detach().float() if dO_gpu is not None else None + seq_len_q_ref = seq_len_q_gpu.flatten().detach() if seq_len_q_gpu is not None else None + seq_len_kv_ref = seq_len_kv_gpu.flatten().detach() if seq_len_kv_gpu is not None else None + block_mask_ref = block_mask_gpu.detach() if block_mask_gpu is not None else None + bias_ref = bias_gpu.detach().float() if bias_gpu is not None else None + rng_dump_ref = rng_dump_gpu.detach().float() if rng_dump_gpu is not None else None + + if cfg.is_train: + q_ref.requires_grad_() + k_ref.requires_grad_() + v_ref.requires_grad_() + if cfg.is_train and cfg.is_bias: + bias_ref.requires_grad_() + + if cfg.is_ragged: + q_ref = convert_packed_to_uniform(q_ref, seq_len_q_ref, cfg.s_q) + k_ref = convert_packed_to_uniform(k_ref, seq_len_kv_ref, cfg.s_kv) + v_ref = convert_packed_to_uniform(v_ref, seq_len_kv_ref, cfg.s_kv) + if cfg.is_ragged and cfg.is_train: + dO_ref = convert_packed_to_uniform(dO_ref, seq_len_q_ref, cfg.s_q) + + max_t_q = max(64, ((seq_len_q_ref.sum().item() + 63) // 64) * 64) if cfg.is_ragged else None + max_t_kv = max(64, ((seq_len_kv_ref.sum().item() + 63) // 64) * 64) if cfg.is_ragged else None + + attn_scale = 0.125 + + ret = compute_ref( + q_ref, k_ref, v_ref, + attn_scale=attn_scale, + bias=bias_ref, + block_mask=block_mask_ref, + is_alibi=cfg.is_alibi, + padding=(seq_len_q_ref, seq_len_kv_ref) if cfg.is_padding else None, + left_bound=cfg.left_bound, + right_bound=cfg.right_bound, + diag_align=cfg.diag_align, + dropout_prob=cfg.dropout_prob, + dropout_mask=rng_dump_ref, + generate_stats=cfg.is_train, + ) + + o_ref, stats_ref = ret if cfg.is_train else (ret, None) + + o_gpu = tensors.get(TensorUid.o) + stats_gpu = tensors.get(TensorUid.stats) + + if cfg.is_padding and not cfg.is_ragged: + for i, m in enumerate(seq_len_q_ref): + o_ref[i, :, m:, :] = 0 + o_gpu[i, :, m:, :] = 0 + if cfg.is_train: + if cudnn_version < "9.14.0": + stats_ref[i, :, m:, :] = 0 + stats_gpu[i, :, m:, :] = 0 + else: + stats_ref[i, :, m:, :] = -float("inf") + + if cfg.is_train: + inputs_ref = [q_ref, k_ref, v_ref, bias_ref] if cfg.is_bias else [q_ref, k_ref, v_ref] + grads = torch.autograd.grad(outputs=o_ref, inputs=inputs_ref, grad_outputs=dO_ref) + dQ_ref = grads[0] + dK_ref = grads[1] + dV_ref = grads[2] + dBias_ref = grads[3] if cfg.is_bias else None + + if cfg.is_train and cfg.is_padding: + for i, (m, n) in enumerate(zip(seq_len_q_ref, seq_len_kv_ref)): + dQ_ref[i, :, m:, :] = 0 + dK_ref[i, :, n:, :] = 0 + dV_ref[i, :, n:, :] = 0 + + if cfg.is_ragged: + o_ref = convert_uniform_to_packed(o_ref, seq_len_q_ref, max_t_q) + if cfg.is_train and cfg.is_ragged: + dQ_ref = convert_uniform_to_packed(dQ_ref, seq_len_q_ref, max_t_q) + dK_ref = convert_uniform_to_packed(dK_ref, seq_len_kv_ref, max_t_kv) + dV_ref = convert_uniform_to_packed(dV_ref, seq_len_kv_ref, max_t_kv) + stats_ref = convert_uniform_to_packed(stats_ref, seq_len_q_ref, max_t_q) + + err_count = 0 + err_count += approx_equal(allocs[TensorUid.o], o_ref, atol=2e-2, rtol=2e-2, tag="o", disp_elems=diffs) + if cfg.is_train: + dkv_atol = 2e-2 if cfg.data_type == torch.float16 else 7e-2 + err_count += approx_equal(allocs[TensorUid.stats], stats_ref, atol=2e-2, rtol=2e-2, tag="stats", disp_elems=diffs) + err_count += approx_equal(allocs[TensorUid.dQ], dQ_ref, atol=2e-2, rtol=2e-2, tag="dQ", disp_elems=diffs) + err_count += approx_equal(allocs[TensorUid.dK], dK_ref, atol=dkv_atol, rtol=2e-2, tag="dK", disp_elems=diffs) + err_count += approx_equal(allocs[TensorUid.dV], dV_ref, atol=dkv_atol, rtol=2e-2, tag="dV", disp_elems=diffs) + if cfg.is_train and cfg.is_bias: + err_count += approx_equal(allocs[TensorUid.dBias], dBias_ref, atol=2e-2, rtol=2e-2, tag="dBias", disp_elems=diffs) + + if err_count != 0: + print("@@@@ Overall result: FAILED, disallowed mismatches") + pytest.fail("disallowed mismatches", pytrace=False) + else: + print("@@@@ Overall result: PASSED, everything looks good!") + + +def cleanup_tensors(allocs): + for uid in list(allocs.keys()): + entry = allocs.get(uid) + if entry is not None and entry[0] is not None: + del allocs[uid] + torch.cuda.empty_cache() + + +def exec_sdpa(cfg, request, cudnn_handle): + if request.config.option.dryrun: + pytest.skip("dry run mode") + + validate_config(cfg) + + rng_data_gen = torch.Generator(device="cuda").manual_seed(cfg.rng_data_seed) + allocs, tensors, max_t_q, max_t_kv = allocate_tensors(cfg, rng_data_gen) + + fwd_graph, fwd_pack = create_forward_graph(cfg, tensors, cudnn_handle) + bwd_graph, bwd_pack = create_backward_graph(cfg, tensors, cudnn_handle, max_t_q, max_t_kv) if cfg.is_train else (None, None) + + execute_graph(fwd_graph, fwd_pack, allocs, tensors, cudnn_handle, request, label="Forward") + + if cfg.is_train: + execute_graph(bwd_graph, bwd_pack, allocs, tensors, cudnn_handle, request, label="Backward") + check_deterministic(cfg, tensors, allocs, bwd_graph, bwd_pack, cudnn_handle, request) + + compute_and_compare_reference(cfg, allocs, tensors, request.config.getoption("--diffs")) + cleanup_tensors(allocs) diff --git a/test/python/sdpa/fp16_ref.py b/test/python/sdpa/fp16_ref.py new file mode 100644 index 00000000..a5af5b73 --- /dev/null +++ b/test/python/sdpa/fp16_ref.py @@ -0,0 +1,175 @@ +import torch +import math +import cudnn + +# fmt: off + +def compute_ref( + q, + k, + v, + attn_scale=None, + bias=None, + block_mask=None, + is_alibi=False, + padding=None, + diag_align=cudnn.diagonal_alignment.TOP_LEFT, + left_bound=None, + right_bound=None, + dropout_prob=0.0, + dropout_mask=None, + generate_stats=False, + device="cuda", +): + b, h_q, s_q, d_qk = q.shape + _, h_k, s_kv, _ = k.shape + _, h_v, _, d_v = v.shape + + assert k.shape == (b, h_k, s_kv, d_qk) + assert v.shape == (b, h_v, s_kv, d_v) + + # use float32 datatype and math for reference computation + q = q.to(dtype=torch.float32, device=device) + k = k.to(dtype=torch.float32, device=device) + v = v.to(dtype=torch.float32, device=device) + + # expand tensors for GQA and MQA + if h_q != h_k: + assert h_q % h_k == 0 + k = k.unsqueeze(2) + k = k.expand(-1, -1, h_q // h_k, -1, -1) + k = k.reshape(k.size(0), -1, k.size(3), k.size(4)) + if h_q != h_v: + assert h_q % h_v == 0 + v = v.unsqueeze(2) + v = v.expand(-1, -1, h_q // h_v, -1, -1) + v = v.reshape(v.size(0), -1, v.size(3), v.size(4)) + + # generate masks to compute reference values for padding mask (also called variable sequence length) + if padding is not None: + q_mask = torch.zeros(b, 1, s_q, 1, dtype=torch.bool, device=device) + k_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) + v_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) + s_mask = torch.zeros(b, 1, s_q, s_kv, dtype=torch.bool, device=device) + p_mask = torch.zeros(b, 1, s_q, s_kv, dtype=torch.bool, device=device) + seq_len_q, seq_len_kv = padding + for i, (m, n) in enumerate(zip(seq_len_q, seq_len_kv)): + q_mask[i, :, m:, :] = True + k_mask[i, :, n:, :] = True + v_mask[i, :, n:, :] = True + s_mask[i, :, :, n:] = True + p_mask[i, :, m:, :] = True + + q = q.masked_fill(q_mask, 0.0) + k = k.masked_fill(k_mask, 0.0) + v = v.masked_fill(v_mask, 0.0) + + s = torch.einsum("bhqd,bhkd->bhqk", q, k) + if attn_scale is not None: + s = s * attn_scale + + # Attention masks are applied in the following order: + # - Bias mask + # - Alibi mask + # - Padding mask + # - Causal mask + if bias is not None: + s = s + bias + if is_alibi: + index_row = torch.arange(s_q, dtype=torch.float32, device=device).view(-1, 1) + index_col = torch.arange(s_kv, dtype=torch.float32, device=device) + distance = index_col - index_row + + # Get the closest power of 2 to `n_heads`. + # If `n_heads` is not a power of 2, then we first calculate slopes to the closest (smaller) power of 2, + # and then add the remaining slopes. + n = 2 ** math.floor(math.log2(h_q)) + m_0 = 2.0 ** (-8.0 / n) + m = torch.pow(m_0, torch.arange(1, 1 + n)) + + # If `n_heads` is not a power of 2, then we add the remaining slopes. + # We calculate the remaining slopes for $n * 2$ (avoiding slopes added previously). + # And pick the slopes upto `n_heads`. + if n < h_q: + m_hat_0 = 2.0 ** (-4.0 / n) + m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (h_q - n), 2)) + # Concatenate the slopes with the remaining slopes. + m = torch.cat([m, m_hat]) + + # Reshape the tensor to [1, num_heads, 1, 1] + m = m.view(1, -1, 1, 1).to(device=device) + + alibi_mask = distance.to(dtype=torch.float32) * m + s = s + alibi_mask + + if padding is not None: + s = s.masked_fill(s_mask, float("-inf")) + + if diag_align == diag_align.TOP_LEFT and right_bound is not None: + causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) + causal_mask.triu_(diagonal=1 + right_bound) + s = s.masked_fill(causal_mask, float("-inf")) + elif diag_align == diag_align.BOTTOM_RIGHT and right_bound is not None: + causal_mask_bottom_right = None + if padding: + causal_mask_bottom_right = torch.ones(b, 1, s_q, s_kv, dtype=torch.bool, device=device) + seq_len_q, seq_len_kv = padding + for i in range(b): + causal_mask_bottom_right[i, :, :, :].triu_(diagonal=seq_len_kv[i] - seq_len_q[i] + 1 + right_bound) + else: + causal_mask_bottom_right = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) + causal_mask_bottom_right.triu_(diagonal=s_kv - s_q + 1 + right_bound) + s = s.masked_fill(causal_mask_bottom_right, float("-inf")) + + if left_bound is not None: + assert diag_align is not None + if diag_align == diag_align.TOP_LEFT: + swa_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) + swa_mask.tril_(diagonal=-1 * left_bound) + elif diag_align == diag_align.BOTTOM_RIGHT: + # BRCM + SWA for variable sequence lengths + if padding: + swa_mask = torch.ones(b, 1, s_q, s_kv, dtype=torch.bool, device=device) + seq_len_q, seq_len_kv = padding + for i in range(b): + swa_mask[i, :, :, :].tril_(diagonal=seq_len_kv[i] - seq_len_q[i] - left_bound) + # BRCM + SWA for fixed sequence lengths + else: + swa_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) + swa_mask.tril_(diagonal=-1 * left_bound + (s_kv - s_q)) + s = s.masked_fill(swa_mask, float("-inf")) + + if block_mask is not None: + TILE_M = 128 + TILE_N = 128 + + block_mask = block_mask.to(dtype=torch.uint8, device=device) + block_mask = ((block_mask[..., None] & (1 << torch.arange(8, device=block_mask.device))) != 0).reshape(block_mask.shape[0], block_mask.shape[1], block_mask.shape[2], block_mask.shape[3] * 8) + block_mask = block_mask.unsqueeze(3).unsqueeze(5) + block_mask = block_mask.repeat(1, 1, 1, TILE_M, 1, TILE_N) + block_mask = block_mask.reshape(block_mask.shape[0], block_mask.shape[1], block_mask.shape[2] * TILE_M, block_mask.shape[4] * TILE_N) + block_mask = block_mask[:, :, :s_q, :s_kv] + s += torch.where(block_mask, torch.tensor(0.0), torch.tensor(float('-inf'))) + + p = torch.softmax(s, dim=-1) + + all_inf = torch.isneginf(s).all(dim=-1, keepdim=True) + if torch.any(all_inf): + p = torch.where(all_inf, torch.zeros_like(p), p) + + if padding is not None: + p = p.masked_fill(p_mask, 0.0) + + # apply dropout mask over softmax outputs + if dropout_prob != 0.0: + assert dropout_mask != None, "PyTorch reference must have dropout_mask for dropout" + p = (p * dropout_mask) / (1 - dropout_prob) + + o = torch.einsum("bhqk,bhkd->bhqd", p, v) + + # softmax stats is used for backwards computation + if generate_stats: + stats = torch.logsumexp(s, dim=-1, keepdim=True) + return o, stats + + return o diff --git a/test/python/sdpa/fp8.py b/test/python/sdpa/fp8.py new file mode 100644 index 00000000..09dac0e8 --- /dev/null +++ b/test/python/sdpa/fp8.py @@ -0,0 +1,384 @@ +import cudnn +import pytest +import torch +import math +from enum import IntEnum +from looseversion import LooseVersion + +from .fp8_ref import compute_ref +from .helpers import get_fp8_scale_factor, get_fp8_descale_factor, convert_to_cudnn_type + +# fmt: off + +class GraphFwdUid(IntEnum): + q = 0 + k = 1 + v = 2 + q_descale = 5 + k_descale = 6 + v_descale = 7 + s_scale = 9 + s_descale = 8 + o_scale = 10 + o = 3 + stats = 4 + s_amax = 11 + o_amax = 12 + kv_seq_len = 13 + q_seq_len = 14 + k_block_table = 15 + v_block_table = 16 + +class GraphBwdUid(IntEnum): + q = 100 + k = 101 + v = 102 + o = 103 + dO = 104 + stats = 105 + q_descale = 106 + k_descale = 107 + v_descale = 108 + o_descale = 109 + dO_descale = 110 + s_descale = 111 + dP_descale = 112 + s_scale = 113 + dQ_scale = 114 + dK_scale = 115 + dV_scale = 116 + dP_scale = 117 + dQ = 118 + dK = 119 + dV = 120 + dQ_amax = 121 + dK_amax = 122 + dV_amax = 123 + dP_amax = 124 + +def generate_graph_fwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale, block_size): + graph_fwd = cudnn.pygraph(io_data_type=cudnn_itype, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT) + + use_padding_mask = None + kv_seq_len = None + q_seq_len = None + k_block_table = None + v_block_table = None + + if block_size == 0: + q = graph_fwd.tensor(uid=GraphFwdUid.q, dim=(b, h_q, s_qo, d_qk), stride=(s_qo * h_q * d_qk, d_qk, h_q * d_qk, 1), data_type=cudnn_itype) + k = graph_fwd.tensor(uid=GraphFwdUid.k, dim=(b, h_k, s_kv, d_qk), stride=(s_kv * h_k * d_qk, d_qk, h_k * d_qk, 1), data_type=cudnn_itype) + v = graph_fwd.tensor(uid=GraphFwdUid.v, dim=(b, h_v, s_kv, d_vo), stride=(s_kv * h_v * d_vo, d_vo, h_v * d_vo, 1), data_type=cudnn_itype) + else: + table_size = math.ceil(s_kv / block_size) + num_blocks = table_size * b + + q = graph_fwd.tensor(uid=GraphFwdUid.q, dim=(b, h_q, s_qo, d_qk), stride=(s_qo * h_q * d_qk, d_qk, h_q * d_qk, 1), data_type=cudnn_itype) + k = graph_fwd.tensor(uid=GraphFwdUid.k, dim=(num_blocks, h_k, block_size, d_qk), stride=(block_size * h_k * d_qk, block_size * d_qk, d_qk, 1), data_type=cudnn_itype) + v = graph_fwd.tensor(uid=GraphFwdUid.v, dim=(num_blocks, h_v, block_size, d_vo), stride=(block_size * h_v * d_vo, block_size * d_vo, d_vo, 1), data_type=cudnn_itype) + + use_padding_mask = True + kv_seq_len = graph_fwd.tensor(uid=GraphFwdUid.kv_seq_len, dim=(b, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32) + q_seq_len = graph_fwd.tensor(uid=GraphFwdUid.q_seq_len, dim=(b, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32) + k_block_table = graph_fwd.tensor(uid=GraphFwdUid.k_block_table, dim=(b, 1, table_size, 1), stride=(table_size, table_size, 1, 1), data_type=cudnn.data_type.INT32) + v_block_table = graph_fwd.tensor(uid=GraphFwdUid.v_block_table, dim=(b, 1, table_size, 1), stride=(table_size, table_size, 1, 1), data_type=cudnn.data_type.INT32) + + q_descale = graph_fwd.tensor(uid=GraphFwdUid.q_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + k_descale = graph_fwd.tensor(uid=GraphFwdUid.k_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + v_descale = graph_fwd.tensor(uid=GraphFwdUid.v_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + s_scale = graph_fwd.tensor(uid=GraphFwdUid.s_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + s_descale = graph_fwd.tensor(uid=GraphFwdUid.s_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + o_scale = graph_fwd.tensor(uid=GraphFwdUid.o_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + + o, stats, amax_s, amax_o = graph_fwd.sdpa_fp8( + q=q, k=k, v=v, + descale_q=q_descale, descale_k=k_descale, descale_v=v_descale, + scale_s=s_scale, descale_s=s_descale, scale_o=o_scale, + generate_stats=True, attn_scale=attn_scale, use_causal_mask=False, + use_padding_mask=use_padding_mask, seq_len_kv=kv_seq_len, seq_len_q=q_seq_len, + paged_attention_k_table=k_block_table, paged_attention_v_table=v_block_table, + paged_attention_max_seq_len_kv=s_kv, + ) + + o.set_uid(GraphFwdUid.o).set_output(True).set_dim((b, h_q, s_qo, d_vo)).set_stride((s_qo * h_q * d_vo, d_vo, h_q * d_vo, 1)).set_data_type(cudnn_otype) + stats.set_uid(GraphFwdUid.stats).set_output(True).set_dim((b, h_q, s_qo, 1)).set_stride((s_qo * h_q, s_qo, 1, 1)).set_data_type(cudnn.data_type.FLOAT) + amax_s.set_uid(GraphFwdUid.s_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) + amax_o.set_uid(GraphFwdUid.o_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) + + return graph_fwd + +def generate_graph_bwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale, deterministic): + graph_bwd = cudnn.pygraph(io_data_type=cudnn_itype, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT) + + q = graph_bwd.tensor(uid=GraphBwdUid.q, dim=(b, h_q, s_qo, d_qk), stride=(s_qo * h_q * d_qk, d_qk, h_q * d_qk, 1), data_type=cudnn_itype) + k = graph_bwd.tensor(uid=GraphBwdUid.k, dim=(b, h_k, s_kv, d_qk), stride=(s_kv * h_k * d_qk, d_qk, h_k * d_qk, 1), data_type=cudnn_itype) + v = graph_bwd.tensor(uid=GraphBwdUid.v, dim=(b, h_v, s_kv, d_vo), stride=(s_kv * h_v * d_vo, d_vo, h_v * d_vo, 1), data_type=cudnn_itype) + o = graph_bwd.tensor(uid=GraphBwdUid.o, dim=(b, h_q, s_qo, d_vo), stride=(s_qo * h_q * d_vo, d_vo, h_q * d_vo, 1), data_type=cudnn_otype) + dO = graph_bwd.tensor(uid=GraphBwdUid.dO, dim=(b, h_q, s_qo, d_vo), stride=(s_qo * h_q * d_vo, d_vo, h_q * d_vo, 1), data_type=cudnn_itype) + stats = graph_bwd.tensor(uid=GraphBwdUid.stats, dim=(b, h_q, s_qo, 1), stride=(s_qo * h_q, s_qo, 1, 1), data_type=cudnn.data_type.FLOAT) + + q_descale = graph_bwd.tensor(uid=GraphBwdUid.q_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + k_descale = graph_bwd.tensor(uid=GraphBwdUid.k_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + v_descale = graph_bwd.tensor(uid=GraphBwdUid.v_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + o_descale = graph_bwd.tensor(uid=GraphBwdUid.o_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + dO_descale = graph_bwd.tensor(uid=GraphBwdUid.dO_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + s_descale = graph_bwd.tensor(uid=GraphBwdUid.s_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + dP_descale = graph_bwd.tensor(uid=GraphBwdUid.dP_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + + s_scale = graph_bwd.tensor(uid=GraphBwdUid.s_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + dQ_scale = graph_bwd.tensor(uid=GraphBwdUid.dQ_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + dK_scale = graph_bwd.tensor(uid=GraphBwdUid.dK_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + dV_scale = graph_bwd.tensor(uid=GraphBwdUid.dV_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + dP_scale = graph_bwd.tensor(uid=GraphBwdUid.dP_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) + + dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP = graph_bwd.sdpa_fp8_backward( + q=q, k=k, v=v, o=o, dO=dO, stats=stats, + descale_q=q_descale, descale_k=k_descale, descale_v=v_descale, + descale_o=o_descale, descale_dO=dO_descale, descale_s=s_descale, descale_dP=dP_descale, + scale_s=s_scale, scale_dQ=dQ_scale, scale_dK=dK_scale, scale_dV=dV_scale, scale_dP=dP_scale, + attn_scale=attn_scale, use_padding_mask=False, use_deterministic_algorithm=deterministic, + ) + + dQ.set_uid(GraphBwdUid.dQ).set_output(True).set_dim((b, h_q, s_qo, d_qk)).set_stride((s_qo * h_q * d_qk, d_qk, h_q * d_qk, 1)).set_data_type(cudnn_itype) + dK.set_uid(GraphBwdUid.dK).set_output(True).set_dim((b, h_k, s_kv, d_qk)).set_stride((s_kv * h_k * d_qk, d_qk, h_k * d_qk, 1)).set_data_type(cudnn_itype) + dV.set_uid(GraphBwdUid.dV).set_output(True).set_dim((b, h_v, s_kv, d_vo)).set_stride((s_kv * h_v * d_vo, d_vo, h_v * d_vo, 1)).set_data_type(cudnn_itype) + + amax_dQ.set_uid(GraphBwdUid.dQ_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) + amax_dK.set_uid(GraphBwdUid.dK_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) + amax_dV.set_uid(GraphBwdUid.dV_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) + amax_dP.set_uid(GraphBwdUid.dP_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) + + return graph_bwd + +def create_paged_container_and_block_table(tensor, block_size): + B, H, S, D = tensor.shape + blocks_per_batch = math.ceil(S / block_size) + + padding_seq = blocks_per_batch * block_size - S + if padding_seq > 0: + zeros = torch.zeros(B, H, padding_seq, D, device="cuda", dtype=tensor.dtype) + cat_tensor = torch.cat((tensor, zeros), dim=2) + else: + cat_tensor = tensor + + container = torch.cat(cat_tensor.chunk(blocks_per_batch, dim=2), dim=0) + + table_size = math.ceil(S / block_size) + block_table_temp = torch.linspace(0, B * table_size - 1, B * table_size, device="cuda", dtype=torch.int32).reshape(table_size, 1, B, 1) + block_table_temp = torch.transpose(block_table_temp, 0, 2) + + block_table = (torch.zeros(blocks_per_batch * B, device="cuda", dtype=torch.int32).as_strided((B, 1, blocks_per_batch, 1), (blocks_per_batch, blocks_per_batch, 1, 1))) + block_table.copy_(block_table_temp) + + return (container, block_table) + +def exec_sdpa_fp8(cfg, request, cudnn_handle): + if request.config.option.dryrun: + pytest.skip("dryrun") + + cudnn_version = LooseVersion(cudnn.backend_version_string()) + if cudnn_version < "9.14.0": + pytest.skip("SDPA FP8 requires cuDNN 9.14.0 or higher") + if torch.cuda.get_device_capability()[0] < 10: + pytest.skip("SDPA FP8 requires Blackwell or higher") + + torch_itype = cfg.data_type + torch_otype = cfg.output_type if hasattr(cfg, 'output_type') and cfg.output_type else cfg.data_type + cudnn_itype = convert_to_cudnn_type(torch_itype) + cudnn_otype = convert_to_cudnn_type(torch_otype) + + b = cfg.batches + h_q, h_k, h_v = cfg.h_q, cfg.h_k, cfg.h_v + s_qo, s_kv = cfg.s_q, cfg.s_kv + d_qk, d_vo = cfg.d_qk, cfg.d_v + block_size = cfg.block_size if cfg.is_paged else 0 + deterministic = cfg.is_determin if hasattr(cfg, 'is_determin') else False + + attn_scale = 0.125 + + is_paged = block_size > 0 + + try: + if cfg.is_infer: + graph = generate_graph_fwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale, block_size) + else: + graph = generate_graph_bwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale, deterministic) + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + except cudnn.cudnnGraphNotSupportedError as e: + pytest.skip(f"unsupported graph: {e}") + except Exception as e: + pytest.fail(f"Error building graph: {e}") + + rng_data = torch.Generator(device="cuda").manual_seed(cfg.rng_data_seed) + + q_gen = torch.clamp(torch.randn(b, s_qo, h_q, d_qk, dtype=torch.float, device="cuda", generator=rng_data), min=-2.0, max=2.0) + k_gen = torch.clamp(torch.randn(b, s_kv, h_k, d_qk, dtype=torch.float, device="cuda", generator=rng_data), min=-2.0, max=2.0) + v_gen = torch.clamp(torch.randn(b, s_kv, h_v, d_vo, dtype=torch.float, device="cuda", generator=rng_data), min=-2.0, max=2.0) + + q_amax = q_gen.abs().max().item() + k_amax = k_gen.abs().max().item() + v_amax = v_gen.abs().max().item() + s_amax, o_amax = compute_ref(q_gen, k_gen, v_gen, attn_scale, return_type="amax") + + q_gpu = (q_gen * get_fp8_scale_factor(q_amax, torch_itype)).to(torch_itype) + k_gpu = (k_gen * get_fp8_scale_factor(k_amax, torch_itype)).to(torch_itype) + v_gpu = (v_gen * get_fp8_scale_factor(v_amax, torch_itype)).to(torch_itype) + + if cfg.is_infer: + if is_paged: + k_gpu_bhsd = torch.einsum('bshd->bhsd', k_gpu).contiguous() + v_gpu_bhsd = torch.einsum('bshd->bhsd', v_gpu).contiguous() + container_k_gpu, k_block_table_gpu = create_paged_container_and_block_table(k_gpu_bhsd, block_size) + container_v_gpu, v_block_table_gpu = create_paged_container_and_block_table(v_gpu_bhsd, block_size) + + kv_seq_len_gpu = torch.full((b, 1, 1, 1), s_kv, device="cuda", dtype=torch.int32) + q_seq_len_gpu = torch.full((b, 1, 1, 1), s_qo, device="cuda", dtype=torch.int32) + o_gpu = torch.full((b, s_qo, h_q, d_vo), float('nan'), dtype=torch_otype, device="cuda") + stats_gpu = torch.full((b, h_q, s_qo, 1), float('nan'), dtype=torch.float, device="cuda") + + q_descale_gpu = torch.tensor([get_fp8_descale_factor(q_amax, torch_itype)], dtype=torch.float, device="cuda") + k_descale_gpu = torch.tensor([get_fp8_descale_factor(k_amax, torch_itype)], dtype=torch.float, device="cuda") + v_descale_gpu = torch.tensor([get_fp8_descale_factor(v_amax, torch_itype)], dtype=torch.float, device="cuda") + s_scale_gpu = torch.tensor([get_fp8_scale_factor(s_amax, torch_itype)], dtype=torch.float, device="cuda") + s_descale_gpu = torch.tensor([get_fp8_descale_factor(s_amax, torch_itype)], dtype=torch.float, device="cuda") + o_scale_gpu = torch.tensor([get_fp8_scale_factor(o_amax, torch_otype)], dtype=torch.float, device="cuda") + + s_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") + o_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") + + variant_pack = { + int(GraphFwdUid.q): q_gpu, + int(GraphFwdUid.k): k_gpu, + int(GraphFwdUid.v): v_gpu, + int(GraphFwdUid.q_descale): q_descale_gpu, + int(GraphFwdUid.k_descale): k_descale_gpu, + int(GraphFwdUid.v_descale): v_descale_gpu, + int(GraphFwdUid.s_descale): s_descale_gpu, + int(GraphFwdUid.s_scale): s_scale_gpu, + int(GraphFwdUid.o_scale): o_scale_gpu, + int(GraphFwdUid.o): o_gpu, + int(GraphFwdUid.stats): stats_gpu, + int(GraphFwdUid.s_amax): s_amax_gpu, + int(GraphFwdUid.o_amax): o_amax_gpu, + } + + if is_paged: + variant_pack[int(GraphFwdUid.k)] = container_k_gpu + variant_pack[int(GraphFwdUid.v)] = container_v_gpu + variant_pack[int(GraphFwdUid.kv_seq_len)] = kv_seq_len_gpu + variant_pack[int(GraphFwdUid.q_seq_len)] = q_seq_len_gpu + variant_pack[int(GraphFwdUid.k_block_table)] = k_block_table_gpu + variant_pack[int(GraphFwdUid.v_block_table)] = v_block_table_gpu + + workspace = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda") + graph.execute(variant_pack, workspace, handle=cudnn_handle) + torch.cuda.synchronize() + + q_ref = q_gpu.detach().float() * get_fp8_descale_factor(q_amax, torch_itype) + k_ref = k_gpu.detach().float() * get_fp8_descale_factor(k_amax, torch_itype) + v_ref = v_gpu.detach().float() * get_fp8_descale_factor(v_amax, torch_itype) + o_ref = compute_ref(q_ref, k_ref, v_ref, attn_scale=attn_scale) + + o_gpu_comp = o_gpu.detach().float() * get_fp8_descale_factor(o_amax, torch_otype) + + atol, rtol = 0.08, 0.2 + if torch_itype == torch.float8_e5m2: + atol, rtol = 0.16, 0.4 + + torch.testing.assert_close(o_gpu_comp, o_ref, atol=atol, rtol=rtol) + + else: + dO_gen = torch.clamp(torch.randn(b, s_qo, h_q, d_vo, dtype=torch.float, device="cuda", generator=rng_data), min=-2.0, max=2.0) + dO_amax = dO_gen.abs().max().item() + + q_gpu = q_gen.to(torch_itype) + k_gpu = k_gen.to(torch_itype) + v_gpu = v_gen.to(torch_itype) + + graph_fwd = generate_graph_fwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale, 0) + graph_fwd.validate(); graph_fwd.build_operation_graph() + graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph_fwd.check_support(); graph_fwd.build_plans() + + o_gpu = torch.full((b, s_qo, h_q, d_vo), float('nan'), dtype=torch_otype, device="cuda") + stats_gpu = torch.full((b, h_q, s_qo, 1), float('nan'), dtype=torch.float, device="cuda") + dO_gpu = dO_gen.to(torch_itype) + + q_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + k_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + v_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + s_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + s_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + o_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + s_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") + o_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") + + variant_pack_fwd = { + int(GraphFwdUid.q): q_gpu, int(GraphFwdUid.k): k_gpu, int(GraphFwdUid.v): v_gpu, + int(GraphFwdUid.q_descale): q_descale_gpu, int(GraphFwdUid.k_descale): k_descale_gpu, + int(GraphFwdUid.v_descale): v_descale_gpu, int(GraphFwdUid.s_descale): s_descale_gpu, + int(GraphFwdUid.s_scale): s_scale_gpu, int(GraphFwdUid.o_scale): o_scale_gpu, + int(GraphFwdUid.o): o_gpu, int(GraphFwdUid.stats): stats_gpu, + int(GraphFwdUid.s_amax): s_amax_gpu, int(GraphFwdUid.o_amax): o_amax_gpu, + } + + workspace_fwd = torch.empty(graph_fwd.get_workspace_size(), dtype=torch.uint8, device="cuda") + graph_fwd.execute(variant_pack_fwd, workspace_fwd, handle=cudnn_handle) + torch.cuda.synchronize() + + o_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + dO_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + dP_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + dQ_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + dK_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + dV_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + dP_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") + + dQ_gpu = torch.full((b, s_qo, h_q, d_qk), float('nan'), dtype=torch_itype, device="cuda") + dK_gpu = torch.full((b, s_kv, h_k, d_qk), float('nan'), dtype=torch_itype, device="cuda") + dV_gpu = torch.full((b, s_kv, h_v, d_vo), float('nan'), dtype=torch_itype, device="cuda") + dQ_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") + dK_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") + dV_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") + dP_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") + + variant_pack_bwd = { + int(GraphBwdUid.q): q_gpu, int(GraphBwdUid.k): k_gpu, int(GraphBwdUid.v): v_gpu, + int(GraphBwdUid.o): o_gpu, int(GraphBwdUid.dO): dO_gpu, int(GraphBwdUid.stats): stats_gpu, + int(GraphBwdUid.q_descale): q_descale_gpu, int(GraphBwdUid.k_descale): k_descale_gpu, + int(GraphBwdUid.v_descale): v_descale_gpu, int(GraphBwdUid.o_descale): o_descale_gpu, + int(GraphBwdUid.dO_descale): dO_descale_gpu, int(GraphBwdUid.s_descale): s_descale_gpu, + int(GraphBwdUid.s_scale): s_scale_gpu, int(GraphBwdUid.dP_descale): dP_descale_gpu, + int(GraphBwdUid.dP_scale): dP_scale_gpu, int(GraphBwdUid.dQ_scale): dQ_scale_gpu, + int(GraphBwdUid.dK_scale): dK_scale_gpu, int(GraphBwdUid.dV_scale): dV_scale_gpu, + int(GraphBwdUid.dQ): dQ_gpu, int(GraphBwdUid.dK): dK_gpu, int(GraphBwdUid.dV): dV_gpu, + int(GraphBwdUid.dQ_amax): dQ_amax_gpu, int(GraphBwdUid.dK_amax): dK_amax_gpu, + int(GraphBwdUid.dV_amax): dV_amax_gpu, int(GraphBwdUid.dP_amax): dP_amax_gpu, + } + + workspace_bwd = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda") + graph.execute(variant_pack_bwd, workspace_bwd, handle=cudnn_handle) + torch.cuda.synchronize() + + q_ref = q_gpu.detach().float() + k_ref = k_gpu.detach().float() + v_ref = v_gpu.detach().float() + + q_ref.requires_grad_(True) + k_ref.requires_grad_(True) + v_ref.requires_grad_(True) + o_tmp = compute_ref(q_ref, k_ref, v_ref, attn_scale=attn_scale) + dQ_ref, dK_ref, dV_ref = torch.autograd.grad(outputs=o_tmp, inputs=[q_ref, k_ref, v_ref], grad_outputs=dO_gen) + + dQ_out = dQ_gpu.detach().float() + dK_out = dK_gpu.detach().float() + dV_out = dV_gpu.detach().float() + + atol, rtol = 0.16, 0.2 + torch.testing.assert_close(dQ_out, dQ_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(dK_out, dK_ref, atol=atol, rtol=rtol) + torch.testing.assert_close(dV_out, dV_ref, atol=atol, rtol=rtol) diff --git a/test/python/sdpa/fp8_ref.py b/test/python/sdpa/fp8_ref.py new file mode 100644 index 00000000..deecf36b --- /dev/null +++ b/test/python/sdpa/fp8_ref.py @@ -0,0 +1,29 @@ +import torch + +# fmt: off + +def compute_ref(q, k, v, attn_scale=1.0, return_type="o"): + b, s_q, h_q, d_qk = q.shape + _, s_kv, h_k, _ = k.shape + _, _, h_v, d_v = v.shape + + assert k.shape == (b, s_kv, h_k, d_qk) + assert v.shape == (b, s_kv, h_v, d_v) + + if h_q != h_k: + k = k.repeat_interleave(h_q // h_k, dim=2) + if h_q != h_v: + v = v.repeat_interleave(h_q // h_v, dim=2) + + s = torch.einsum("bqhd,bkhd->bhqk", q, k) * attn_scale + p = s.softmax(dim=-1) + o = torch.einsum("bhqk,bkhd->bqhd", p, v) + + if return_type == "o": + return o + if return_type == "o_stats": + return o, torch.zeros() + elif return_type == "amax": + return p.abs().max().item(), o.abs().max().item() + else: + raise ValueError(f"Unsupported return type: {return_type}") diff --git a/test/python/sdpa/helpers.py b/test/python/sdpa/helpers.py new file mode 100644 index 00000000..71998f0c --- /dev/null +++ b/test/python/sdpa/helpers.py @@ -0,0 +1,255 @@ +import cudnn +import torch +import math + +# fmt: off + +def get_fp8_largest_po2(dtype: torch.dtype): + if dtype == torch.float8_e4m3fn: + return 128.0 + elif dtype == torch.float8_e5m2: + return 32768.0 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + +def get_fp8_scale_factor(amax: float, dtype: torch.dtype, fudge_factor: float = 0.25, epsilon = 0.0625): + if dtype == torch.float16 or dtype == torch.bfloat16: + return 1.0 + po2_next = 2 ** math.ceil(math.log2(max(amax, epsilon))) + return get_fp8_largest_po2(dtype) / po2_next * fudge_factor + +def get_fp8_descale_factor(amax: float, dtype: torch.dtype, fudge_factor: float = 0.25, epsilon = 0.0625): + return 1.0 / get_fp8_scale_factor(amax, dtype, fudge_factor, epsilon) + +def compute_total_elems(shape, strides): + """Compute total element count (max offset + 1) from shape and strides.""" + return sum((s - 1) * st for s, st in zip(shape, strides)) + 1 + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + elif torch_type == torch.float8_e4m3fn: + return cudnn.data_type.FP8_E4M3 + elif torch_type == torch.float8_e5m2: + return cudnn.data_type.FP8_E5M2 + else: + assert False, "unsupported tensor data type" + +def alloc_tensor(shape, data_type, *, elems=None, strides=None, rng=None, mean=0.0, std=1.0, margins=512): + if strides is None: + # Compute default contiguous strides + if hasattr(shape, '__iter__'): + strides = [] + prod = 1 + for dim in reversed(shape): + strides.insert(0, prod) + prod *= int(dim) + if elems is None: + elems = prod + else: + if elems is None: + elems = int(shape) + strides = (1,) + shape = (shape,) + elif elems is None: + elems = compute_total_elems(shape, strides) + + assert margins >= 0 and type(margins) == int, "wrong input" + + rawbuf = torch.empty(elems+2*margins, dtype=data_type, device="cuda") + if torch.is_floating_point(rawbuf): + rawbuf.fill_(float('nan')) + else: + rawbuf.fill_(-1) + + tensor = torch.as_strided(rawbuf, shape, strides, storage_offset=margins) + sepbuf = (torch.as_strided(rawbuf, (2, margins), (elems+margins, 1), storage_offset=0) if margins > 0 else None) + + if rng is not None: + tensor.normal_(mean=mean, std=std, generator=rng) + + if math.prod(shape) == elems: + rawbuf = None + + return tensor, sepbuf, rawbuf + +def prefix_sum(t): + t = t.flatten() + return torch.cat((torch.zeros(1, dtype=t.dtype, device=t.device), torch.cumsum(t, dim=0))) + +def convert_packed_to_uniform(packed_tensor, seq_len, s_max, fill_value=0): + assert packed_tensor.dim() == 3 + t, h, d = packed_tensor.size() + seq_len = seq_len.flatten() + b = seq_len.size(0) + + uniform_tensor = torch.full((b, s_max, h, d), fill_value, dtype=packed_tensor.dtype, device=packed_tensor.device) + + t_idx = 0 + for bi, s in enumerate(seq_len): + uniform_tensor[bi, 0:s, :, :] = packed_tensor[t_idx : t_idx + s, :, :] + t_idx += s + + uniform_tensor = torch.einsum("bshd->bhsd", uniform_tensor) + return uniform_tensor + +def convert_uniform_to_packed(uniform_tensor, seq_len, max_t): + assert uniform_tensor.dim() == 4 + uniform_tensor = torch.einsum("bhsd->bshd", uniform_tensor) + b, s, h, d = uniform_tensor.size() + seq_len = seq_len.flatten() + assert seq_len.size(0) == b + packed_tensor = torch.full((max_t, h, d), float('nan'), dtype=uniform_tensor.dtype, device=uniform_tensor.device) + + t_idx = 0 + for bi, s_len in enumerate(seq_len): + packed_tensor[t_idx : t_idx + s_len, :, :] = uniform_tensor[bi, 0:s_len, :, :] + t_idx += s_len + + return packed_tensor + +def create_container_and_page_table(tensor, block_size): + B, H, S, D = tensor.shape + blocks_per_batch = math.ceil(S/block_size) + + padding_seq = (blocks_per_batch * block_size) - S + if padding_seq > 0: + zeros = torch.zeros(B,H,padding_seq,D, device='cuda', dtype=tensor.dtype) + cat_tensor = torch.cat((tensor, zeros), axis = 2) + else: + cat_tensor = tensor + + reshaped = torch.cat((cat_tensor.clone()).chunk(blocks_per_batch, dim=2), dim=0) + + table_size = math.ceil(S/block_size) + page_table = torch.linspace(0, B*table_size-1, B*table_size, device='cuda', dtype=torch.int32).reshape(table_size,1,B,1) + page_table = torch.transpose(page_table,0,2) + + return(reshaped, page_table) + +def exact_equal(actual, expected, tag, disp_elems): + both_nan = torch.isnan(actual) & torch.isnan(expected) + mismatches = torch.where((actual != expected) & ~both_nan) + mismatch_cnt = mismatches[0].numel() + num_elements = torch.numel(actual) + if mismatch_cnt != 0: + percentage = 100 * mismatch_cnt / num_elements + if disp_elems > 0: + print(f"Comparing '{tag}' for exact (bitwise) equality") + combined = torch.stack(mismatches, dim=-1).tolist() + count = 0 + for index in combined: + diff = actual[tuple(index)].float() - expected[tuple(index)].float() + print(f"idx{index}: {tag}_run1={actual[tuple(index)]}, {tag}_run2={expected[tuple(index)]}, diff={diff:+.2e}") + count += 1 + if count >= disp_elems: + break + print(f"%%%% Total {mismatch_cnt:,} mismatches ({percentage:.1f}%) when validating '{tag}' for exact equality (first {count} mismatches displayed)") + else: + print(f"%%%% Total {mismatch_cnt:,} mismatches ({percentage:.1f}%) when validating '{tag}' for exact equality") + else: + print(f"%%%% Exact (bitwise) equality of '{tag}' verified") + return mismatch_cnt + +def approx_equal(alloc, expected, atol, rtol, tag, disp_elems): + actual, sepbuf, rawbuf = alloc + mismatches = torch.where(torch.isclose(actual.float(), expected, rtol=rtol, atol=atol, equal_nan=True) == False) + mismatch_cnt = mismatches[0].numel() + num_elements = torch.numel(actual) + if mismatch_cnt != 0: + percentage = 100 * mismatch_cnt / num_elements + if disp_elems > 0: + print(f"Comparing '{tag}' using rtol={rtol:.4e}, atol={atol:.4e}") + combined = torch.stack(mismatches, dim=-1).tolist() + count = 0 + for index in combined: + diff = actual[tuple(index)] - expected[tuple(index)] + if math.isfinite(diff): + print(f"idx{index}: {tag}_gpu={actual[tuple(index)]:+.6e}, {tag}_ref={expected[tuple(index)]:+.6e}, diff={diff:+.2e}") + else: + print(f"idx{index}: {tag}_gpu={actual[tuple(index)]:+.6e}, {tag}_ref={expected[tuple(index)]:+.6e}") + count += 1 + if count >= disp_elems: + break + print(f"%%%% Total {mismatch_cnt:,} mismatches ({percentage:.1f}%) when validating '{tag}' results (first {count} mismatches displayed)") + else: + print(f"%%%% Total {mismatch_cnt:,} mismatches ({percentage:.1f}%) when validating '{tag}' results") + + num_nans = torch.isnan(actual).sum().item() + num_infs = torch.isinf(actual).sum().item() + num_zeros = num_elements - torch.count_nonzero(actual) + num_finites_nz = num_elements - num_nans - num_infs - num_zeros + + print(f"%%%% {tag}_gpu overview: elements={num_elements:,}, finites_nz={num_finites_nz:,}, zeros={num_zeros:,}, nans={num_nans:,}, infs={num_infs:,}") + + num_nans = torch.isnan(expected).sum().item() + num_infs = torch.isinf(expected).sum().item() + num_zeros = num_elements - torch.count_nonzero(expected) + num_finites_nz = num_elements - num_nans - num_infs - num_zeros + + print(f"%%%% {tag}_ref overview: elements={num_elements:,}, finites_nz={num_finites_nz:,}, zeros={num_zeros:,}, nans={num_nans:,}, infs={num_infs:,}") + else: + print(f"%%%% Numerical divergence of '{tag}' within limits") + + if sepbuf is not None and not torch.all(torch.isnan(sepbuf)).item(): + print(f"%%%% Buffer '{tag}' overwritten outside its boundaries") + print(sepbuf) + mismatch_cnt += 1 + + if rawbuf is not None: + actual.fill_(float('nan')) + if not torch.all(torch.isnan(rawbuf)).item(): + print(f"%%%% Unused gaps of '{tag}' tensor were overwritten") + mismatch_cnt += 1 + + return mismatch_cnt + +def time_execution(fn, *args, num_warmup: int = 3, num_trials: int = 10) -> torch.Tensor: + elapsed_times = torch.zeros(num_trials, dtype=torch.float) + for _ in range(num_warmup): + fn(*args) + torch.cuda.synchronize() + for i in range(num_trials): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + fn(*args) + end_event.record() + torch.cuda.synchronize() + elapsed_times[i] = start_event.elapsed_time(end_event) + return elapsed_times + +def profile_execution(fn, *args, trace_dir=None): + activities = [torch.profiler.ProfilerActivity.CUDA] + if trace_dir: + activities.append(torch.profiler.ProfilerActivity.CPU) + with torch.profiler.profile( + activities=activities, + record_shapes=True, + profile_memory=True, + with_stack=True, + on_trace_ready=(torch.profiler.tensorboard_trace_handler(trace_dir) if trace_dir else None), + ) as prof: + fn(*args) + torch.cuda.synchronize() + print("Sorted by CUDA time:") + print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)) + print() + if torch.profiler.ProfilerActivity.CPU in activities: + print("Sorted by CPU time:") + print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)) + print() + +def print_section_begin(msg, width=80): + print(f" {msg} ".center(width, "=")) + +def print_section_end(width=80): + print("=" * width) diff --git a/test/python/mha_v2_utils.py b/test/python/sdpa/random_config.py similarity index 58% rename from test/python/mha_v2_utils.py rename to test/python/sdpa/random_config.py index 0d070662..9c6205cf 100644 --- a/test/python/mha_v2_utils.py +++ b/test/python/sdpa/random_config.py @@ -6,24 +6,21 @@ from dataclasses import dataclass, field, asdict -# Invalid left/right attention bound (negative values may be used in the future). -INVALID_BOUND = 99999 - - -def get_strides_from_indices( - shape, indices=[0, 1, 2, 3], gaps=[0, 0, 0, 0], rng_geom=None -): - assert ( - len(shape) == len(gaps) == 4 - and sorted(indices) == [0, 1, 2, 3] - and indices[3] == 3 - ), "wrong input" +# fmt: off + +def generate_test_seeds(*, num_tests, rng_seed): + rng = random.Random(rng_seed) + return [(i+1, num_tests, rng.randint(65536, 2147483647)) for i in range(num_tests)] + + +def get_strides_from_indices(shape, indices=[0, 1, 2, 3], gaps=[0, 0, 0, 0], rng_geom=None): + """Compute strides for a given dimension order and optional gaps.""" + assert len(shape) == len(gaps) == 4 and sorted(indices) == [0, 1, 2, 3] and indices[3] == 3, "wrong input" strides = [0, 0, 0, 1] # d should always have stride 1 curr_stride = 1 for i in range(3, 0, -1): j = indices[i] - curr_stride = (shape[j] + gaps[j]) * curr_stride j = indices[i - 1] strides[j] = curr_stride @@ -32,19 +29,47 @@ def get_strides_from_indices( if rng_geom is not None and shape[j] == 1: strides[j] = max(strides[j], rng_geom.choice([0, 3331333, 99990001])) - total_size = shape[j] * curr_stride - return tuple(strides), tuple(gaps), total_size + return tuple(strides) def get_strides_from_layout(shape, layout, gaps=[0, 0, 0, 0], rng_geom=None): + """Compute strides for a given layout string (e.g. 'bshd', 'bhsd').""" assert "".join(sorted(layout)) == "bdhs", f"wrong layout '{layout}'" indices = ["bhsd".index(ch) for ch in layout] return get_strides_from_indices(shape, indices, gaps, rng_geom) +def compute_default_BHSD_strides(shape): + """Compute default BHSD strides (rightmost dim is innermost with stride=1, no gaps).""" + if shape is None: + return None + strides = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + strides[i] = strides[i + 1] * shape[i + 1] + return tuple(strides) + + +def compute_packed_strides(shape): + """Compute packed (ragged) BSHD strides for BHSD shape: (s*h*d, d, h*d, 1).""" + if shape is None: + return None + b, h, s, d = shape + return (s * h * d, d, h * d, 1) + + @dataclass -class exec_cfg: +class ExecConfig: + # Registry for enum-like fields: field_name -> module/class to getattr from + _ENUM_FIELDS = { + 'data_type': torch, + 'output_type': torch, + 'diag_align': cudnn.diagonal_alignment, + 'implementation': cudnn.attention_implementation, + } + data_type: torch.dtype = None + output_type: torch.dtype = None + rng_geom_seed: int = None rng_data_seed: int = None is_alibi: bool = None @@ -76,19 +101,15 @@ class exec_cfg: shape_q: tuple[int, int, int, int] = None stride_q: tuple[int, int, int, int] = None - elems_q: int = None shape_k: tuple[int, int, int, int] = None stride_k: tuple[int, int, int, int] = None - elems_k: int = None shape_v: tuple[int, int, int, int] = None stride_v: tuple[int, int, int, int] = None - elems_v: int = None shape_o: tuple[int, int, int, int] = None stride_o: tuple[int, int, int, int] = None - elems_o: int = None seq_len_q: list[int] = field(default_factory=list) seq_len_kv: list[int] = field(default_factory=list) @@ -97,6 +118,85 @@ class exec_cfg: implementation: cudnn.attention_implementation = cudnn.attention_implementation.AUTO + @property + def is_train(self): + return not self.is_infer + + def fill_derived_fields(self): + """ + Fill in derived fields (shapes, strides) from basic dims. + - Shapes are computed from basic dims (batches, h_q/k/v, s_q/kv, d_qk/v) + - Strides default to BHSD layout if not provided + """ + # Compute shapes from basic dims if not provided + if self.shape_q is None and all(x is not None for x in [self.batches, self.h_q, self.s_q, self.d_qk]): + self.shape_q = (self.batches, self.h_q, self.s_q, self.d_qk) + if self.shape_k is None and all(x is not None for x in [self.batches, self.h_k, self.s_kv, self.d_qk]): + self.shape_k = (self.batches, self.h_k, self.s_kv, self.d_qk) + if self.shape_v is None and all(x is not None for x in [self.batches, self.h_v, self.s_kv, self.d_v]): + self.shape_v = (self.batches, self.h_v, self.s_kv, self.d_v) + if self.shape_o is None and all(x is not None for x in [self.batches, self.h_q, self.s_q, self.d_v]): + self.shape_o = (self.batches, self.h_q, self.s_q, self.d_v) + + # Compute strides if not provided (packed for ragged, default BHSD otherwise) + stride_fn = compute_packed_strides if self.is_ragged else compute_default_BHSD_strides + if self.stride_q is None and self.shape_q is not None: + self.stride_q = stride_fn(self.shape_q) + if self.stride_k is None and self.shape_k is not None: + self.stride_k = stride_fn(self.shape_k) + if self.stride_v is None and self.shape_v is not None: + self.stride_v = stride_fn(self.shape_v) + if self.stride_o is None and self.shape_o is not None: + self.stride_o = stride_fn(self.shape_o) + + def serialize(self) -> dict: + """Convert config to a serializable dict for repro commands.""" + cfg_dict = asdict(self) + for field, enum_cls in self._ENUM_FIELDS.items(): + if cfg_dict.get(field) is not None: + val = cfg_dict[field] + if hasattr(val, 'name'): + module_prefix = enum_cls.__module__.split('.')[0] + cfg_dict[field] = f"{module_prefix}.{enum_cls.__name__}.{val.name}" + else: + cfg_dict[field] = str(val) + return cfg_dict + + @classmethod + def deserialize(cls, d: dict) -> "ExecConfig": + """Create ExecConfig from a serialized dict.""" + for field, enum_cls in cls._ENUM_FIELDS.items(): + if d.get(field) is not None: + name = d[field].split('.')[-1] + assert hasattr(enum_cls, name), f"Invalid {field}: {name}" + d[field] = getattr(enum_cls, name) + cfg = cls(**d) + cfg.fill_derived_fields() + return cfg + + def to_repro_cmd(self, test_file: str) -> str: + """Generate a readable multi-line repro command with aligned backslashes.""" + cfg_dict = self.serialize() + indent = " " * 4 + # Build lines without trailing backslash first + lines = [ + "pytest -vv -s -rA --no-header --tb=short", + f"{indent}{test_file}::test_repro", + f"{indent}--repro \"", + f"{indent}{indent}" + "{", + ] + items = list(cfg_dict.items()) + for i, (k, v) in enumerate(items): + comma = "," if i < len(items) - 1 else "" + lines.append(f"{indent}{indent}{indent}'{k}': {repr(v)}{comma}") + lines.append(f"{indent}{indent}" + "}") + lines.append(f'{indent}"') + # Find max length and align backslashes (except last line) + max_len = max(len(line) for line in lines[:-1]) + aligned = [f"{line:<{max_len}} \\" for line in lines[:-1]] + aligned.append(lines[-1]) + return "\n".join(aligned) + class RandomizationContext: def __init__(self, **kwargs): @@ -108,71 +208,56 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): pass - def __call__(self, rng, rng_data_seed): + def __call__(self, rng, rng_data_seed, rng_geom_seed=None): - randoms_ = exec_cfg() + randoms_ = ExecConfig() randoms = {} + randoms_.rng_geom_seed = rng_geom_seed randoms_.rng_data_seed = rng_data_seed randoms["rng_data_seed"] = rng_data_seed self.rng_data = torch.Generator(device="cuda").manual_seed(rng_data_seed) - randoms = { - k: v(rng) for k, v in self.kwargs.items() if not hasattr(randoms_, k) - } - [ - setattr(randoms_, k, v(rng)) - for k, v in self.kwargs.items() - if hasattr(randoms_, k) - ] + randoms = {k: v(rng) for k, v in self.kwargs.items() if not hasattr(randoms_, k)} + [setattr(randoms_, k, v(rng)) for k, v in self.kwargs.items() if hasattr(randoms_, k)] if "is_deterministic" in randoms: randoms_.is_determin = randoms["is_deterministic"] == True + if "is_bias" in randoms: + randoms_.is_bias = randoms["is_bias"] == True + randoms_.s_q, randoms_.s_kv = randoms["s_q_s_kv"] randoms_.d_qk, randoms_.d_v = randoms["d_qk_d_v"] randoms_.h_q, randoms_.h_k, randoms_.h_v = randoms["head_count"] randoms_.is_ragged = randoms["is_q_ragged_or_padded_or_full"] == "ragged" - randoms_.is_padding = ( - randoms["is_q_ragged_or_padded_or_full"] == "padded" - or randoms["is_q_ragged_or_padded_or_full"] == "ragged" - ) + randoms_.is_padding = randoms["is_q_ragged_or_padded_or_full"] == "padded" or randoms["is_q_ragged_or_padded_or_full"] == "ragged" if randoms["is_q_ragged_or_padded_or_full"] != "full": - randoms_.seq_len_q = [ - rng.randint(1, randoms_.s_q) for _ in range(randoms_.batches) + # ~10% chance of 0-length sequence for each batch + randoms_.seq_len_q = [0 if rng.random() < 0.1 else rng.randint(1, randoms_.s_q) for _ in range(randoms_.batches)] + # ~10% chance of 0-length sequence for each batch (independent of seq_len_q) + randoms_.seq_len_kv = [ + # 0 if rng.random() < 0.1 else rng.randint(randoms_.seq_len_q[i], randoms_.s_kv) for i in range(randoms_.batches) + rng.randint(1, randoms_.s_kv) + for i in range(randoms_.batches) ] - if randoms_.seq_len_q is not None: - randoms_.seq_len_kv = [ - rng.randint(randoms_.seq_len_q[i], randoms_.s_kv) - for i in range(randoms_.batches) - ] - else: - randoms_.seq_len_kv = [ - rng.randint(randoms_.s_q, randoms_.s_kv) - for _ in range(randoms_.batches) - ] - - # Decide the left and right bounds for the sliding window mask - randoms_.left_bound = INVALID_BOUND - randoms_.right_bound = INVALID_BOUND + + # Decide the left and right bounds for the sliding window mask (None = no bound) + randoms_.left_bound = None + randoms_.right_bound = None if randoms["with_sliding_mask"] == "no_mask": - randoms_.left_bound = INVALID_BOUND - randoms_.right_bound = INVALID_BOUND + pass # left_bound and right_bound stay None elif randoms["with_sliding_mask"] == "left_window_only": - randoms_.left_bound = rng.randint(0, randoms_.s_kv // 2) + randoms_.left_bound = rng.randint(1, max(1, randoms_.s_kv // 2)) randoms_.right_bound = 0 elif randoms["with_sliding_mask"] == "right_window_only": - randoms_.left_bound = ( - INVALID_BOUND - if randoms_.diag_align == cudnn.diagonal_alignment.BOTTOM_RIGHT - else 1 - ) + randoms_.left_bound = None if randoms_.diag_align == cudnn.diagonal_alignment.BOTTOM_RIGHT else 1 randoms_.right_bound = rng.randint(0, randoms_.s_kv // 2) elif randoms["with_sliding_mask"] == "band_around_diag": randoms_.left_bound = rng.randint(1, randoms_.s_kv // 2) @@ -186,12 +271,8 @@ def __call__(self, rng, rng_data_seed): randoms_.shape_o = (randoms_.batches, randoms_.h_q, randoms_.s_q, randoms_.d_v) if randoms_.is_ragged: # Ideally Q ragged and O ragged - randoms_.stride_q, _, randoms_.elems_q = get_strides_from_layout( - randoms_.shape_q, "bshd" - ) - randoms_.stride_o, _, randoms_.elems_o = get_strides_from_layout( - randoms_.shape_o, "bshd" - ) + randoms_.stride_q = get_strides_from_layout(randoms_.shape_q, "bshd") + randoms_.stride_o = get_strides_from_layout(randoms_.shape_o, "bshd") else: indices = [0, 1, 2] @@ -206,12 +287,8 @@ def __call__(self, rng, rng_data_seed): gaps_q.append(elem_align * rng.randint(0, 2)) gaps_o.append(elem_align * rng.randint(0, 2)) - (randoms_.stride_q, randoms_.gaps_q, randoms_.elems_q) = ( - get_strides_from_indices(randoms_.shape_q, indices, gaps_q, rng) - ) - (randoms_.stride_o, randoms_.gaps_o, randoms_.elems_o) = ( - get_strides_from_indices(randoms_.shape_o, indices, gaps_o, rng) - ) + randoms_.stride_q = get_strides_from_indices(randoms_.shape_q, indices, gaps_q, rng) + randoms_.stride_o = get_strides_from_indices(randoms_.shape_o, indices, gaps_o, rng) # Decide K, V randoms_.shape_k = ( @@ -223,12 +300,8 @@ def __call__(self, rng, rng_data_seed): randoms_.shape_v = (randoms_.batches, randoms_.h_v, randoms_.s_kv, randoms_.d_v) if randoms_.is_ragged: # Ideally K ragged and V ragged - randoms_.stride_k, _, randoms_.elems_k = get_strides_from_layout( - randoms_.shape_k, "bshd" - ) - randoms_.stride_v, _, randoms_.elems_v = get_strides_from_layout( - randoms_.shape_v, "bshd" - ) + randoms_.stride_k = get_strides_from_layout(randoms_.shape_k, "bshd") + randoms_.stride_v = get_strides_from_layout(randoms_.shape_v, "bshd") else: indices = [0, 1, 2] @@ -243,12 +316,8 @@ def __call__(self, rng, rng_data_seed): gaps_k.append(elem_align * rng.randint(0, 2)) gaps_v.append(elem_align * rng.randint(0, 2)) - (randoms_.stride_k, randoms_.gaps_k, randoms_.elems_k) = ( - get_strides_from_indices(randoms_.shape_k, indices, gaps_k, rng) - ) - (randoms_.stride_v, randoms_.gaps_v, randoms_.elems_v) = ( - get_strides_from_indices(randoms_.shape_v, indices, gaps_v, rng) - ) + randoms_.stride_k = get_strides_from_indices(randoms_.shape_k, indices, gaps_k, rng) + randoms_.stride_v = get_strides_from_indices(randoms_.shape_v, indices, gaps_v, rng) return randoms_ @@ -290,41 +359,19 @@ def __call__(self, rng): max_exp = math.floor(math.log2(self.max)) if min_exp > max_exp: raise ValueError(f"No power of two in range [{self.min}, {self.max}]") - exp = ( - rng.randint(min_exp, max_exp) - if dice == 0 - else self.with_high_probability[ - rng.randint(0, len(self.with_high_probability) - 1) - ] - ) + exp = rng.randint(min_exp, max_exp) if dice == 0 else self.with_high_probability[rng.randint(0, len(self.with_high_probability) - 1)] return 1 << exp if dice == 0 else exp elif self.multiple_of: # compute the first and last valid multiples, then pick randomly - first = ( - (self.min + self.multiple_of - 1) // self.multiple_of - ) * self.multiple_of + first = ((self.min + self.multiple_of - 1) // self.multiple_of) * self.multiple_of last = (self.max // self.multiple_of) * self.multiple_of if first > self.max: - raise ValueError( - f"No multiples of {self.multiple_of} in range [{self.min}, {self.max}]" - ) + raise ValueError(f"No multiples of {self.multiple_of} in range [{self.min}, {self.max}]") count = ((last - first) // self.multiple_of) + 1 - idx = ( - rng.randint(0, count - 1) - if dice == 0 - else self.with_high_probability[ - rng.randint(0, len(self.with_high_probability) - 1) - ] - ) + idx = rng.randint(0, count - 1) if dice == 0 else self.with_high_probability[rng.randint(0, len(self.with_high_probability) - 1)] return first + idx * self.multiple_of else: - return ( - rng.randint(self.min, self.max) - if dice == 0 - else self.with_high_probability[ - rng.randint(0, len(self.with_high_probability) - 1) - ] - ) + return rng.randint(self.min, self.max) if dice == 0 else self.with_high_probability[rng.randint(0, len(self.with_high_probability) - 1)] class RandomHeadGenerator: @@ -401,9 +448,7 @@ def __call__(self, rng): if d_qk < d_v: d_qk = d_v else: - d_qk, d_v = self.with_high_probability[ - rng.randint(0, len(self.with_high_probability) - 1) - ] + d_qk, d_v = self.with_high_probability[rng.randint(0, len(self.with_high_probability) - 1)] return d_qk, d_v @@ -433,6 +478,7 @@ def __call__(self, rng): s_q = s_kv else: s_q = self.s_q_gen(rng) + # Always s_q <=s_kv if s_q > s_kv: s_q = s_kv @@ -440,9 +486,7 @@ def __call__(self, rng): class RandomBatchSize(RandomIntValue): - def __init__( - self, min: int, max: int, with_high_probability: Optional[List[int]] = None - ): + def __init__(self, min: int, max: int, with_high_probability: Optional[List[int]] = None): super().__init__(min, max, with_high_probability=with_high_probability) def __call__(self, rng): @@ -450,12 +494,8 @@ def __call__(self, rng): class RandomBlockSize(RandomIntValue): - def __init__( - self, min: int, max: int, with_high_probability: Optional[List[int]] = None - ): - super().__init__( - min, max, with_high_probability=with_high_probability, power_of_two=True - ) + def __init__(self, min: int, max: int, with_high_probability: Optional[List[int]] = None): + super().__init__(min, max, with_high_probability=with_high_probability, power_of_two=True) def __call__(self, rng): return super().__call__(rng) @@ -494,70 +534,13 @@ def test_randomization_context(seed): cudnn.diagonal_alignment.BOTTOM_RIGHT: 1, } ), - is_q_ragged_or_padded_or_full=RandomChoice( - {"ragged": 1, "padded": 1, "full": 1} - ), - is_kv_ragged_or_paged_or_padded_or_full=RandomChoice( - {"ragged": 1, "paged": 1, "padded": 1, "full": 1} - ), + is_q_ragged_or_padded_or_full=RandomChoice({"ragged": 1, "padded": 1, "full": 1}), + is_kv_ragged_or_paged_or_padded_or_full=RandomChoice({"ragged": 1, "paged": 1, "padded": 1, "full": 1}), stats_layout=RandomChoice({"ragged": 1, "full": 1, "disabled": 2}), ) as ctx: return ctx -def time_execution( - fn, - *args, - num_warmup: int = 3, - num_trials: int = 10, -) -> torch.Tensor: - elapsed_times = torch.zeros(num_trials, dtype=torch.float) - - for _ in range(num_warmup): - fn(*args) - torch.cuda.synchronize() - - for i in range(num_trials): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - fn(*args) - end_event.record() - torch.cuda.synchronize() - - elapsed_times[i] = start_event.elapsed_time(end_event) - - return elapsed_times - - -def profile_execution(fn, *args, trace_dir=None): - activities = [torch.profiler.ProfilerActivity.CUDA] - if trace_dir: - activities.append(torch.profiler.ProfilerActivity.CPU) - - with torch.profiler.profile( - activities=activities, - record_shapes=True, - profile_memory=True, - with_stack=True, - on_trace_ready=( - torch.profiler.tensorboard_trace_handler(trace_dir) if trace_dir else None - ), - ) as prof: - fn(*args) - torch.cuda.synchronize() - - print("Sorted by CUDA time:") - print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)) - print() - - if torch.profiler.ProfilerActivity.CPU in activities: - print("Sorted by CPU time:") - print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)) - print() - - if __name__ == "__main__": num_tests = 10 seed = 768 diff --git a/test/python/test_apply_rope.py b/test/python/test_apply_rope.py index 8f0f2520..95fb12a1 100644 --- a/test/python/test_apply_rope.py +++ b/test/python/test_apply_rope.py @@ -32,9 +32,7 @@ def build_rope_cache( return cos, sin -def apply_rope_ref( - q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> torch.Tensor: +def apply_rope_ref(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: def fn(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: head_size = x.size(-1) x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) @@ -103,9 +101,7 @@ def test_apply_rope(cudnn_handle): stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) - g, uids = create_rope_graph( - cudnn_handle, x1_gpu, x2_gpu, cos1_gpu, cos2_gpu, sin1_gpu, sin2_gpu - ) + g, uids = create_rope_graph(cudnn_handle, x1_gpu, x2_gpu, cos1_gpu, cos2_gpu, sin1_gpu, sin2_gpu) x1_uid, x2_uid, sin1_uid, sin2_uid, cos1_uid, cos2_uid, Y1_uid, Y2_uid = uids workspace = torch.empty(g.get_workspace_size(), device="cuda", dtype=torch.uint8) diff --git a/test/python/test_batchnorm.py b/test/python/test_batchnorm.py index 9e169c2f..0169104a 100644 --- a/test/python/test_batchnorm.py +++ b/test/python/test_batchnorm.py @@ -74,17 +74,15 @@ def test_bn_relu_with_mask(cudnn_handle): momentum = graph.tensor_like(momentum_cpu) comparison = graph.tensor_like(x_gpu) - y_before_relu, saved_mean, saved_inv_var, out_running_mean, out_running_var = ( - graph.batchnorm( - name="BN", - input=x, - scale=scale, - bias=bias, - in_running_mean=in_running_mean, - in_running_var=in_running_var, - epsilon=epsilon, - momentum=momentum, - ) + y_before_relu, saved_mean, saved_inv_var, out_running_mean, out_running_var = graph.batchnorm( + name="BN", + input=x, + scale=scale, + bias=bias, + in_running_mean=in_running_mean, + in_running_var=in_running_var, + epsilon=epsilon, + momentum=momentum, ) y = graph.relu(name="relu", input=y_before_relu) mask = graph.cmp_gt(name="cmp", input=y, comparison=comparison) @@ -125,9 +123,7 @@ def test_bn_relu_with_mask(cudnn_handle): comparison: comparison_gpu, mask: mask_gpu, } - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute( variant_pack, workspace, @@ -166,9 +162,7 @@ def test_bn_relu_with_mask(cudnn_handle): # fmt: on -@pytest.mark.parametrize( - "dump_dX_dRelu", [True, False], ids=lambda p: f"dump_dX_dRelu{int(p)}" -) +@pytest.mark.parametrize("dump_dX_dRelu", [True, False], ids=lambda p: f"dump_dX_dRelu{int(p)}") @pytest.mark.skipif( LooseVersion(cudnn.backend_version_string()) < "8.9", reason="DBN fusions not supported below cudnn 8.9", @@ -259,9 +253,7 @@ def test_drelu_dadd_dbn(dump_dX_dRelu, cudnn_handle): if dump_dX_dRelu: variant_pack[dX_drelu] = dX_dRelu_gpu - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute(variant_pack, workspace, handle=cudnn_handle) torch.cuda.synchronize() @@ -308,9 +300,7 @@ def test_bn_infer_drelu_dbn(cudnn_handle): stride=x_gpu.stride(), data_type=x_gpu.dtype, ) - dY = graph.tensor( - name="dY", dim=dY_gpu.size(), stride=dY_gpu.stride(), data_type=dY_gpu.dtype - ) + dY = graph.tensor(name="dY", dim=dY_gpu.size(), stride=dY_gpu.stride(), data_type=dY_gpu.dtype) scale = graph.tensor( name="scale", dim=scale_gpu.size(), @@ -336,9 +326,7 @@ def test_bn_infer_drelu_dbn(cudnn_handle): data_type=inv_var_gpu.dtype, ) - y = graph.batchnorm_inference( - input=x, mean=mean, inv_variance=inv_variance, scale=scale, bias=bias - ) + y = graph.batchnorm_inference(input=x, mean=mean, inv_variance=inv_variance, scale=scale, bias=bias) dX_dRelu = graph.relu_backward(loss=dY, input=y) @@ -381,9 +369,7 @@ def test_bn_infer_drelu_dbn(cudnn_handle): dBias: dBias_gpu, } - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute(variant_pack, workspace, handle=cudnn_handle) torch.cuda.synchronize() diff --git a/test/python/test_block_scale_quantize.py b/test/python/test_block_scale_quantize.py index 4c4c6c53..e0d8f94a 100644 --- a/test/python/test_block_scale_quantize.py +++ b/test/python/test_block_scale_quantize.py @@ -12,7 +12,7 @@ def get_cc(): """Get CUDA compute capability.""" - (major, minor) = torch.cuda.get_device_capability() + major, minor = torch.cuda.get_device_capability() return major * 10 + minor @@ -53,23 +53,13 @@ def calculate_block_scale_dims(m, n, k, block_size): INDESTRUCTIBLE_128x4_BLOCK_M_N = 128 INDESTRUCTIBLE_128x4_BLOCK_K = 4 - block_scale_dim_m = ( - div_up(m, INDESTRUCTIBLE_128x4_BLOCK_M_N) * INDESTRUCTIBLE_128x4_BLOCK_M_N - ) - block_scale_dim_n = ( - div_up(n, INDESTRUCTIBLE_128x4_BLOCK_M_N) * INDESTRUCTIBLE_128x4_BLOCK_M_N - ) - block_scale_dim_k = ( - div_up(div_up(k, block_size), INDESTRUCTIBLE_128x4_BLOCK_K) - * INDESTRUCTIBLE_128x4_BLOCK_K - ) + block_scale_dim_m = div_up(m, INDESTRUCTIBLE_128x4_BLOCK_M_N) * INDESTRUCTIBLE_128x4_BLOCK_M_N + block_scale_dim_n = div_up(n, INDESTRUCTIBLE_128x4_BLOCK_M_N) * INDESTRUCTIBLE_128x4_BLOCK_M_N + block_scale_dim_k = div_up(div_up(k, block_size), INDESTRUCTIBLE_128x4_BLOCK_K) * INDESTRUCTIBLE_128x4_BLOCK_K # For output quantization (lines 461-463) block_scale_dim_out_m = block_scale_dim_m - block_scale_dim_out_n = ( - div_up(div_up(n, block_size), INDESTRUCTIBLE_128x4_BLOCK_K) - * INDESTRUCTIBLE_128x4_BLOCK_K - ) + block_scale_dim_out_n = div_up(div_up(n, block_size), INDESTRUCTIBLE_128x4_BLOCK_K) * INDESTRUCTIBLE_128x4_BLOCK_K return ( block_scale_dim_m, @@ -203,14 +193,10 @@ def test_block_scale_quantize_matmul( ) # Dequantize A (lines 515-517) - dequant_tensor_a = g.block_scale_dequantize( - tensor_a, block_descale_a, block_size=[1, block_size], name="dequantize_a" - ) + dequant_tensor_a = g.block_scale_dequantize(tensor_a, block_descale_a, block_size=[1, block_size], name="dequantize_a") # Dequantize B (lines 519-521) - dequant_tensor_b = g.block_scale_dequantize( - tensor_b, block_descale_b, block_size=[block_size, 1], name="dequantize_b" - ) + dequant_tensor_b = g.block_scale_dequantize(tensor_b, block_descale_b, block_size=[block_size, 1], name="dequantize_b") # Matmul (lines 523-526) tensor_c = g.matmul( @@ -231,9 +217,7 @@ def test_block_scale_quantize_matmul( # Set output properties (lines 533-536) tensor_d.set_output(True).set_data_type(datatype_output) - block_scale.set_output(True).set_data_type(datatype_scale).set_reordering_type( - cudnn.tensor_reordering.F8_128x4 - ) + block_scale.set_output(True).set_data_type(datatype_scale).set_reordering_type(cudnn.tensor_reordering.F8_128x4) # Build and validate graph (lines 540-551) g.validate() @@ -250,32 +234,20 @@ def test_block_scale_quantize_matmul( # Using uint8 as a generic container since we're just testing the graph execution if dtype_a == "FP4_E2M1": # FP4 is packed, so size is smaller - tensor_a_data = torch.randint( - 0, 16, (b, m, k // 2), dtype=torch.uint8, device="cuda" - ) + tensor_a_data = torch.randint(0, 16, (b, m, k // 2), dtype=torch.uint8, device="cuda") elif dtype_a == "FP8_E4M3": - tensor_a_data = torch.randint( - 0, 256, (b, m, k), dtype=torch.uint8, device="cuda" - ) + tensor_a_data = torch.randint(0, 256, (b, m, k), dtype=torch.uint8, device="cuda") elif dtype_a == "FP8_E5M2": - tensor_a_data = torch.randint( - 0, 256, (b, m, k), dtype=torch.uint8, device="cuda" - ) + tensor_a_data = torch.randint(0, 256, (b, m, k), dtype=torch.uint8, device="cuda") else: tensor_a_data = torch.randn((b, m, k), dtype=torch.float16, device="cuda") if dtype_b == "FP4_E2M1": - tensor_b_data = torch.randint( - 0, 16, (b, k, n // 2), dtype=torch.uint8, device="cuda" - ) + tensor_b_data = torch.randint(0, 16, (b, k, n // 2), dtype=torch.uint8, device="cuda") elif dtype_b == "FP8_E4M3": - tensor_b_data = torch.randint( - 0, 256, (b, k, n), dtype=torch.uint8, device="cuda" - ) + tensor_b_data = torch.randint(0, 256, (b, k, n), dtype=torch.uint8, device="cuda") elif dtype_b == "FP8_E5M2": - tensor_b_data = torch.randint( - 0, 256, (b, k, n), dtype=torch.uint8, device="cuda" - ) + tensor_b_data = torch.randint(0, 256, (b, k, n), dtype=torch.uint8, device="cuda") else: tensor_b_data = torch.randn((b, k, n), dtype=torch.float16, device="cuda") @@ -315,9 +287,7 @@ def test_block_scale_quantize_matmul( # Output tensor if dtype_output == "FP4_E2M1": - tensor_d_data = torch.empty( - (b, m, n // 2), dtype=torch.uint8, device="cuda" - ) + tensor_d_data = torch.empty((b, m, n // 2), dtype=torch.uint8, device="cuda") elif dtype_output == "FP8_E4M3": tensor_d_data = torch.empty((b, m, n), dtype=torch.uint8, device="cuda") elif dtype_output == "FP8_E5M2": @@ -326,9 +296,7 @@ def test_block_scale_quantize_matmul( tensor_d_data = torch.empty((b, m, n), dtype=torch.float16, device="cuda") # Get workspace - workspace = torch.empty( - g.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(g.get_workspace_size(), device="cuda", dtype=torch.uint8) # Execute (lines 557-565) variant_pack = { @@ -347,7 +315,5 @@ def test_block_scale_quantize_matmul( assert scale_output_data is not None print( - f"✓ Test passed: b={b}, m={m}, n={n}, k={k}, " - f"dtype_a={dtype_a}, dtype_b={dtype_b}, " - f"dtype_scale={dtype_scale}, dtype_output={dtype_output}" + f"✓ Test passed: b={b}, m={m}, n={n}, k={k}, " f"dtype_a={dtype_a}, dtype_b={dtype_b}, " f"dtype_scale={dtype_scale}, dtype_output={dtype_output}" ) diff --git a/test/python/test_block_scale_quantize_dynamic_shape.py b/test/python/test_block_scale_quantize_dynamic_shape.py new file mode 100644 index 00000000..8eb1eb52 --- /dev/null +++ b/test/python/test_block_scale_quantize_dynamic_shape.py @@ -0,0 +1,233 @@ +""" +Test suite for block_scale_quantize with dynamic shape overrides Python API. +Based on blackwell_nvfp4_mxfp8_block_scale_matmul.cpp +""" + +import cudnn +import pytest +import torch + +from test_utils import torch_fork_set_rng + + +def get_cc(): + """Get CUDA compute capability.""" + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def div_up(a, b): + """Integer division with rounding up.""" + return (a + b - 1) // b + + +def calculate_block_scale_dims(m, n, k, block_size): + """ + Calculate block scale dimensions using indestructible block formula. + Based on C++ lines 319-325, 454-463. + """ + INDESTRUCTIBLE_128x4_BLOCK_M_N = 128 + INDESTRUCTIBLE_128x4_BLOCK_K = 4 + + block_scale_dim_m = div_up(m, INDESTRUCTIBLE_128x4_BLOCK_M_N) * INDESTRUCTIBLE_128x4_BLOCK_M_N + block_scale_dim_n = div_up(n, INDESTRUCTIBLE_128x4_BLOCK_M_N) * INDESTRUCTIBLE_128x4_BLOCK_M_N + block_scale_dim_k = div_up(div_up(k, block_size), INDESTRUCTIBLE_128x4_BLOCK_K) * INDESTRUCTIBLE_128x4_BLOCK_K + + return block_scale_dim_m, block_scale_dim_n, block_scale_dim_k + + +class TestBlockScaleQuantizeMatmulDynamicShape: + """ + Test block_scale_quantize API with full matmul workflow. + Based on C++ TEST_CASE "Blackwell Block Scale Matmul dynamic shape overrides" (lines 749-910). + """ + + @pytest.mark.skipif( + cudnn.backend_version() < 91800, + reason="block_scale_quantize requires cuDNN >= 9.18.0", + ) + @pytest.mark.skipif( + get_cc() < 100, + reason="block_scale_quantize requires CUDA compute capability larger than 100", + ) + @pytest.mark.parametrize( + "b,m,n,k", + [ + (1, 1024, 1024, 1024), + ], + ) + @pytest.mark.L0 + @torch_fork_set_rng(seed=999) + def test_block_scale_quantize_matmul_dynamic_shape(self, cudnn_handle, b, m, n, k): + """ + Test block_scale_quantize in a full matmul workflow: + 1. Create quantized inputs A, B with block scales + 2. Dequantize A and B + 3. Perform matmul + 4. Quantize output using block_scale_quantize + 5. Validate execution succeeds + + This mirrors the C++ test at lines 749-910. + """ + # Skip FP4 tests if PyTorch doesn't support it + if not hasattr(torch, "float4_e2m1fn_x2"): + pytest.skip("PyTorch does not support float4_e2m1fn_x2") + + A_UID = 1 + SF_A_UID = 2 + B_UID = 3 + SF_B_UID = 4 + C_UID = 5 + + datatype_a = cudnn.data_type.FP4_E2M1 + datatype_b = cudnn.data_type.FP4_E2M1 + datatype_scale = cudnn.data_type.FP8_E4M3 + datatype_output = cudnn.data_type.BFLOAT16 + block_size = 16 + + matmul_dynamic_shapes = [ + {"b": 2, "m": 1024, "n": 1024, "k": 1024}, + {"b": 2, "m": 2048, "n": 2048, "k": 2048}, + ] + + graph = cudnn.pygraph( + io_data_type=cudnn.data_type.FLOAT, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=cudnn_handle, + is_dynamic_shape_enabled=True, + ) + + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = calculate_block_scale_dims(m, n, k, block_size) + + A = graph.tensor( + name="A", + uid=A_UID, + dim=[b, m, k], + stride=[m * k, k, 1], + data_type=datatype_a, + ) + + SF_A = graph.tensor( + name="SF_A", + uid=SF_A_UID, + dim=[b, block_scale_dim_m, block_scale_dim_k], + stride=[block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1], + data_type=datatype_scale, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + + dequan_tensor_a = graph.block_scale_dequantize(A, SF_A, block_size=[1, block_size], name="dequantize_a") + + B = graph.tensor( + name="B", + uid=B_UID, + dim=[b, k, n], + stride=[n * k, 1, k], + data_type=datatype_b, + ) + + SF_B = graph.tensor( + name="SF_B", + uid=SF_B_UID, + dim=[b, block_scale_dim_k, block_scale_dim_n], + stride=[block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k], + data_type=datatype_scale, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + + dequan_tensor_b = graph.block_scale_dequantize(B, SF_B, block_size=[block_size, 1], name="dequantize_b") + + C = graph.matmul( + dequan_tensor_a, + dequan_tensor_b, + compute_data_type=cudnn.data_type.FLOAT, + name="matmul", + ) + C.set_uid(C_UID).set_output(True).set_data_type(datatype_output) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + for dynamic_shape in matmul_dynamic_shapes: + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = calculate_block_scale_dims( + dynamic_shape["m"], + dynamic_shape["n"], + dynamic_shape["k"], + block_size, + ) + + override_uids = [A_UID, SF_A_UID, B_UID, SF_B_UID, C_UID] + + override_shapes = [ + [dynamic_shape["b"], dynamic_shape["m"], dynamic_shape["k"]], + [dynamic_shape["b"], block_scale_dim_m, block_scale_dim_k], + [dynamic_shape["b"], dynamic_shape["k"], dynamic_shape["n"]], + [dynamic_shape["b"], block_scale_dim_k, block_scale_dim_n], + [dynamic_shape["b"], dynamic_shape["m"], dynamic_shape["n"]], + ] + + override_strides = [ + [dynamic_shape["m"] * dynamic_shape["k"], dynamic_shape["k"], 1], + [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1], + [dynamic_shape["n"] * dynamic_shape["k"], 1, dynamic_shape["k"]], + [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k], + [dynamic_shape["m"] * dynamic_shape["n"], dynamic_shape["n"], 1], + ] + + A_gpu = torch.randint( + 0, + 256, + (dynamic_shape["b"], dynamic_shape["m"], dynamic_shape["k"] // 2), + dtype=torch.uint8, + device="cuda", + ) + SF_A_gpu = torch.ones( + (b, block_scale_dim_m, block_scale_dim_k), + dtype=torch.float8_e4m3fn, + device="cuda", + ) + B_gpu = torch.randint( + 0, + 256, + (dynamic_shape["b"], dynamic_shape["k"] // 2, dynamic_shape["n"]), + dtype=torch.uint8, + device="cuda", + ) + SF_B_gpu = torch.ones( + (b, block_scale_dim_k, block_scale_dim_n), + dtype=torch.float8_e4m3fn, + device="cuda", + ) + C_gpu = torch.empty( + (dynamic_shape["b"], dynamic_shape["m"], dynamic_shape["n"]), + dtype=torch.bfloat16, + device="cuda", + ) + + variant_pack = { + A_UID: A_gpu, + SF_A_UID: SF_A_gpu, + B_UID: B_gpu, + SF_B_UID: SF_B_gpu, + C_UID: C_gpu, + } + + workspace_size = graph.get_workspace_size() + workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda") + + graph.execute( + variant_pack, + workspace, + handle=cudnn_handle, + override_uids=override_uids, + override_shapes=override_shapes, + override_strides=override_strides, + ) + + torch.cuda.synchronize() + + print(f"✓ Test passed: b={b}, m={m}, n={n}, k={k}") diff --git a/test/python/test_conv_bias.py b/test/python/test_conv_bias.py index 5ed61375..67391a13 100644 --- a/test/python/test_conv_bias.py +++ b/test/python/test_conv_bias.py @@ -20,9 +20,7 @@ def forward( ): if b is not None: b = b.reshape(-1) # Conv2d needs a 1D tensor - conv_output = torch.nn.functional.conv2d( - x, w, bias=b, padding=padding, stride=stride, dilation=dilation - ) + conv_output = torch.nn.functional.conv2d(x, w, bias=b, padding=padding, stride=stride, dilation=dilation) return torch.clamp(conv_output, min=lower_clip, max=upper_clip) @@ -58,18 +56,14 @@ def create_conv_bias_relu_graph( ) bias_output = g.bias(name="bias", input=conv_output, bias=B) - Y = g.relu( - name="relu", input=bias_output, lower_clip=lower_clip, upper_clip=upper_clip - ) + Y = g.relu(name="relu", input=bias_output, lower_clip=lower_clip, upper_clip=upper_clip) Y.set_output(True) return g, [X, W, B, Y] @cudnn.jit(heur_modes=[cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) -def create_conv_relu_graph( - handle, X_gpu, W_gpu, padding, stride, dilation, lower_clip=0.5, upper_clip=0.55 -): +def create_conv_relu_graph(handle, X_gpu, W_gpu, padding, stride, dilation, lower_clip=0.5, upper_clip=0.55): with cudnn.graph( handle, io_data_type=cudnn.data_type.HALF, @@ -79,13 +73,9 @@ def create_conv_relu_graph( X = g.tensor_like(X_gpu) W = g.tensor_like(W_gpu) - conv_output = g.conv_fprop( - image=X, weight=W, padding=padding, stride=stride, dilation=dilation - ) + conv_output = g.conv_fprop(image=X, weight=W, padding=padding, stride=stride, dilation=dilation) - Y = g.relu( - name="relu", input=conv_output, lower_clip=lower_clip, upper_clip=upper_clip - ) + Y = g.relu(name="relu", input=conv_output, lower_clip=lower_clip, upper_clip=upper_clip) Y.set_output(True) return g, [X, W, Y] @@ -99,15 +89,9 @@ def create_conv_relu_graph( @torch_fork_set_rng(seed=0) def test_conv_bias_relu(cudnn_handle): # Reference code - X_gpu = torch.randn(4, 16, 56, 56, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) - W_gpu = torch.randn(16, 16, 3, 3, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) - B_gpu = torch.randn(1, 16, 1, 1, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) + X_gpu = torch.randn(4, 16, 56, 56, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + W_gpu = torch.randn(16, 16, 3, 3, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + B_gpu = torch.randn(1, 16, 1, 1, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) padding = [1, 1] stride = [3, 3] dilation = [1, 1] @@ -128,12 +112,8 @@ def test_conv_bias_relu(cudnn_handle): stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) - single_mode_graph = cudnn.jit(heur_modes=cudnn.heur_mode.A)( - create_conv_bias_relu_graph.__wrapped__ - ) - g, uids = single_mode_graph( - cudnn_handle, X_gpu, W_gpu, B_gpu, padding, stride, dilation - ) + single_mode_graph = cudnn.jit(heur_modes=cudnn.heur_mode.A)(create_conv_bias_relu_graph.__wrapped__) + g, uids = single_mode_graph(cudnn_handle, X_gpu, W_gpu, B_gpu, padding, stride, dilation) X_uid, W_uid, B_uid, Y_uid = uids @@ -154,12 +134,8 @@ def test_conv_bias_relu(cudnn_handle): @torch_fork_set_rng(seed=0) def test_conv_relu(cudnn_handle): # Reference code - X_gpu = torch.randn(20, 40, 30, 40, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) - W_gpu = torch.randn(54, 40, 3, 4, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) + X_gpu = torch.randn(20, 40, 30, 40, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + W_gpu = torch.randn(54, 40, 3, 4, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) padding = [0, 1] stride = [2, 3] dilation = [1, 1] @@ -179,17 +155,13 @@ def test_conv_relu(cudnn_handle): stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) - g, uids = create_conv_relu_graph( - cudnn_handle, X_gpu, W_gpu, padding, stride, dilation - ) + g, uids = create_conv_relu_graph(cudnn_handle, X_gpu, W_gpu, padding, stride, dilation) X_uid, W_uid, Y_uid = uids Y_actual = torch.zeros_like(Y_expected) workspace = torch.empty(g.get_workspace_size(), device="cuda", dtype=torch.uint8) - g.execute( - {X_uid: X_gpu, W_uid: W_gpu, Y_uid: Y_actual}, workspace, handle=cudnn_handle - ) + g.execute({X_uid: X_gpu, W_uid: W_gpu, Y_uid: Y_actual}, workspace, handle=cudnn_handle) torch.cuda.synchronize() torch.testing.assert_close(Y_expected, Y_actual, atol=1e-3, rtol=1e-3) @@ -199,12 +171,8 @@ def test_conv_relu(cudnn_handle): @torch_fork_set_rng(seed=0) def test_conv_relu_execution_plan_creation(cudnn_handle): # Reference code - X_gpu = torch.randn( - 20, 40, 30, 40, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - W_gpu = torch.randn( - 54, 40, 3, 4, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) + X_gpu = torch.randn(20, 40, 30, 40, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + W_gpu = torch.randn(54, 40, 3, 4, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) padding = [0, 1] stride = [2, 3] dilation = [1, 1] @@ -230,16 +198,10 @@ def test_conv_relu_execution_plan_creation(cudnn_handle): handle=cudnn_handle, ) - X = graph.tensor( - name="X", dim=X_gpu.size(), stride=X_gpu.stride(), data_type=X_gpu.dtype - ) - W = graph.tensor( - name="W", dim=W_gpu.size(), stride=W_gpu.stride(), data_type=W_gpu.dtype - ) + X = graph.tensor(name="X", dim=X_gpu.size(), stride=X_gpu.stride(), data_type=X_gpu.dtype) + W = graph.tensor(name="W", dim=W_gpu.size(), stride=W_gpu.stride(), data_type=W_gpu.dtype) - conv_output = graph.conv_fprop( - image=X, weight=W, padding=padding, stride=stride, dilation=dilation - ) + conv_output = graph.conv_fprop(image=X, weight=W, padding=padding, stride=stride, dilation=dilation) Y = graph.relu(name="relu", input=conv_output, lower_clip=0.5, upper_clip=0.55) Y.set_output(True) @@ -256,22 +218,16 @@ def test_conv_relu_execution_plan_creation(cudnn_handle): for knob in knobs: if knob.type == cudnn.knob_type.KERNEL_CFG: - for kernel_cfg in range( - knob.min_value, knob.max_value + 1, knob.stride - ): + for kernel_cfg in range(knob.min_value, knob.max_value + 1, knob.stride): try: - graph.create_execution_plan( - engine, {cudnn.knob_type.KERNEL_CFG: kernel_cfg} - ) + graph.create_execution_plan(engine, {cudnn.knob_type.KERNEL_CFG: kernel_cfg}) except RuntimeError: continue graph.check_support() graph.build_plans() - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) Y_actual = torch.zeros_like(Y_expected) graph.execute({X: X_gpu, W: W_gpu, Y: Y_actual}, workspace, handle=cudnn_handle) @@ -281,9 +237,7 @@ def test_conv_relu_execution_plan_creation(cudnn_handle): @cudnn.jit(heur_modes=[cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) -def create_conv3d_bias_leaky_relu_graph( - handle, X_gpu, W_gpu, B_gpu, padding, stride, dilation, negative_slope -): +def create_conv3d_bias_leaky_relu_graph(handle, X_gpu, W_gpu, B_gpu, padding, stride, dilation, negative_slope): with cudnn.graph( handle, io_data_type=cudnn.data_type.HALF, @@ -294,9 +248,7 @@ def create_conv3d_bias_leaky_relu_graph( W = g.tensor_like(W_gpu) B = g.tensor_like(B_gpu) - conv_output = g.conv_fprop( - image=X, weight=W, padding=padding, stride=stride, dilation=dilation - ) + conv_output = g.conv_fprop(image=X, weight=W, padding=padding, stride=stride, dilation=dilation) bias_output = g.bias(name="bias", input=conv_output, bias=B) Y = g.leaky_relu(name="relu", input=bias_output, negative_slope=negative_slope) @@ -320,15 +272,9 @@ def test_conv3d_bias_leaky_relu(cudnn_handle): negative_slope = 0.01 # Reference code - X_gpu = torch.randn(N, D, H, W, C, device="cuda", dtype=torch.float16).permute( - 0, 4, 1, 2, 3 - ) - W_gpu = torch.randn(K, R, S, T, C, device="cuda", dtype=torch.float16).permute( - 0, 4, 1, 2, 3 - ) - B_gpu = torch.randn(1, 1, 1, 1, K, device="cuda", dtype=torch.float16).permute( - 0, 4, 1, 2, 3 - ) + X_gpu = torch.randn(N, D, H, W, C, device="cuda", dtype=torch.float16).permute(0, 4, 1, 2, 3) + W_gpu = torch.randn(K, R, S, T, C, device="cuda", dtype=torch.float16).permute(0, 4, 1, 2, 3) + B_gpu = torch.randn(1, 1, 1, 1, K, device="cuda", dtype=torch.float16).permute(0, 4, 1, 2, 3) # Get reference result conv_out_expected = ( @@ -343,16 +289,12 @@ def test_conv3d_bias_leaky_relu(cudnn_handle): .to("cuda") .to(torch.float16) ) - Y_expected = torch.nn.functional.leaky_relu( - conv_out_expected, negative_slope=negative_slope - ) + Y_expected = torch.nn.functional.leaky_relu(conv_out_expected, negative_slope=negative_slope) stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) - g, uids = create_conv3d_bias_leaky_relu_graph( - cudnn_handle, X_gpu, W_gpu, B_gpu, padding, stride, dilation, negative_slope - ) + g, uids = create_conv3d_bias_leaky_relu_graph(cudnn_handle, X_gpu, W_gpu, B_gpu, padding, stride, dilation, negative_slope) X_uid, W_uid, B_uid, Y_uid = uids Y_actual = torch.zeros_like(Y_expected) @@ -392,12 +334,8 @@ def test_leaky_relu_backward(cudnn_handle): negative_slope = 0.01 # Reference code - loss_gpu = torch.randn(N, C, H, W, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) - input_gpu = torch.randn(N, C, H, W, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) + loss_gpu = torch.randn(N, C, H, W, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + input_gpu = torch.randn(N, C, H, W, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) def dleaky_relu(grad: torch.Tensor, mask: torch.Tensor, negative_slope: float): return torch.ones_like(grad).masked_fill_(mask <= 0.0, negative_slope) * grad @@ -407,9 +345,7 @@ def dleaky_relu(grad: torch.Tensor, mask: torch.Tensor, negative_slope: float): stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) - g, uids = create_leaky_relu_backward_graph( - cudnn_handle, loss_gpu, input_gpu, negative_slope - ) + g, uids = create_leaky_relu_backward_graph(cudnn_handle, loss_gpu, input_gpu, negative_slope) loss_uid, input_uid, Y_uid = uids Y_actual = torch.zeros_like(Y_expected) @@ -436,9 +372,7 @@ def create_conv_int8_graph(handle, X_gpu, W_gpu, padding, stride, dilation): X = g.tensor_like(X_gpu) W = g.tensor_like(W_gpu) - conv_output = g.conv_fprop( - image=X, weight=W, padding=padding, stride=stride, dilation=dilation - ) + conv_output = g.conv_fprop(image=X, weight=W, padding=padding, stride=stride, dilation=dilation) Y = g.identity(name="identity", input=conv_output) Y.set_output(True).set_data_type(cudnn.data_type.INT32) @@ -461,43 +395,25 @@ def test_conv_int8(cudnn_handle): compare_output = True # Reference code - X_gpu = torch.randint(-127, 128, (N, C, H, W), device="cuda", dtype=torch.int8).to( - memory_format=torch.channels_last - ) - W_gpu = torch.randint(-127, 128, (K, C, R, S), device="cuda", dtype=torch.int8).to( - memory_format=torch.channels_last - ) + X_gpu = torch.randint(-127, 128, (N, C, H, W), device="cuda", dtype=torch.int8).to(memory_format=torch.channels_last) + W_gpu = torch.randint(-127, 128, (K, C, R, S), device="cuda", dtype=torch.int8).to(memory_format=torch.channels_last) try: - Y_expected = ( - torch.nn.functional.conv2d( - X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation - ) - .to("cuda") - .to(torch.int32) - ) + Y_expected = torch.nn.functional.conv2d(X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation).to("cuda").to(torch.int32) except: - print( - "Torch does not support int8 convolution. Disabling comparison of output tensor" - ) + print("Torch does not support int8 convolution. Disabling comparison of output tensor") compare_output = False stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) - g, uids = create_conv_int8_graph( - cudnn_handle, X_gpu, W_gpu, padding, stride, dilation - ) + g, uids = create_conv_int8_graph(cudnn_handle, X_gpu, W_gpu, padding, stride, dilation) X_uid, W_uid, Y_uid = uids - Y_actual = torch.randint(0, 127, X_gpu.size(), device="cuda", dtype=torch.int32).to( - memory_format=torch.channels_last - ) + Y_actual = torch.randint(0, 127, X_gpu.size(), device="cuda", dtype=torch.int32).to(memory_format=torch.channels_last) workspace = torch.empty(g.get_workspace_size(), device="cuda", dtype=torch.uint8) - g.execute( - {X_uid: X_gpu, W_uid: W_gpu, Y_uid: Y_actual}, workspace, handle=cudnn_handle - ) + g.execute({X_uid: X_gpu, W_uid: W_gpu, Y_uid: Y_actual}, workspace, handle=cudnn_handle) torch.cuda.synchronize() diff --git a/test/python/test_conv_fprop.py b/test/python/test_conv_fprop.py new file mode 100644 index 00000000..6ac4c06b --- /dev/null +++ b/test/python/test_conv_fprop.py @@ -0,0 +1,112 @@ +""" +Test for conv fprop using tvm-ffi based execute API. + +This test validates that the TVM-FFI migration for PyGraph::execute +works correctly by running a simple convolution forward pass. +""" + +import cudnn +import pytest +import torch +from test_utils import torch_fork_set_rng + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +def test_conv_fprop_tvm_ffi(cudnn_handle): + """Test conv fprop using the tvm-ffi based execute API.""" + # Setup tensors + X_gpu = torch.randn(4, 16, 32, 32, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + W_gpu = torch.randn(32, 16, 3, 3, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + padding = [1, 1] + stride = [1, 1] + dilation = [1, 1] + + # Reference result using PyTorch + Y_expected = torch.nn.functional.conv2d(X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation) + + # Set stream + stream = torch.cuda.current_stream().cuda_stream + cudnn.set_stream(handle=cudnn_handle, stream=stream) + + # Build cudnn graph + graph = cudnn.pygraph( + io_data_type=cudnn.data_type.HALF, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=cudnn_handle, + ) + + X = graph.tensor_like(X_gpu) + W = graph.tensor_like(W_gpu) + + conv_output = graph.conv_fprop(image=X, weight=W, padding=padding, stride=stride, dilation=dilation) + conv_output.set_output(True) + + graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + + Y_actual = torch.zeros_like(Y_expected) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + # Execute using the tvm-ffi based execute API + graph.execute( + {X: X_gpu, W: W_gpu, conv_output: Y_actual}, + workspace, + handle=cudnn_handle, + ) + + torch.cuda.synchronize() + torch.testing.assert_close(Y_expected, Y_actual, atol=1e-3, rtol=1e-3) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=0) +def test_conv_fprop_execute_plan_at_index_tvm_ffi(cudnn_handle): + """Test conv fprop using execute_plan_at_index with tvm-ffi.""" + # Setup tensors + X_gpu = torch.randn(2, 8, 16, 16, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + W_gpu = torch.randn(16, 8, 3, 3, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + padding = [1, 1] + stride = [1, 1] + dilation = [1, 1] + + # Reference result using PyTorch + Y_expected = torch.nn.functional.conv2d(X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation) + + # Set stream + stream = torch.cuda.current_stream().cuda_stream + cudnn.set_stream(handle=cudnn_handle, stream=stream) + + # Build cudnn graph + graph = cudnn.pygraph( + io_data_type=cudnn.data_type.HALF, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=cudnn_handle, + ) + + X = graph.tensor_like(X_gpu) + W = graph.tensor_like(W_gpu) + + conv_output = graph.conv_fprop(image=X, weight=W, padding=padding, stride=stride, dilation=dilation) + conv_output.set_output(True) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + Y_actual = torch.zeros_like(Y_expected) + workspace = torch.empty(graph.get_workspace_size_plan_at_index(0), device="cuda", dtype=torch.uint8) + + # Execute using execute_plan_at_index with tvm-ffi + graph.execute_plan_at_index( + {X: X_gpu, W: W_gpu, conv_output: Y_actual}, + workspace, + index=0, + handle=cudnn_handle, + ) + + torch.cuda.synchronize() + torch.testing.assert_close(Y_expected, Y_actual, atol=1e-3, rtol=1e-3) diff --git a/test/python/test_conv_fuzzer.py b/test/python/test_conv_fuzzer.py new file mode 100644 index 00000000..e2f423cc --- /dev/null +++ b/test/python/test_conv_fuzzer.py @@ -0,0 +1,1158 @@ +""" +Convolution Fuzzer - Randomized stress testing for cuDNN convolution operations. + +This fuzzer tests convolution operations with randomized: +- Shapes (batch, channels, spatial dimensions) +- Spatial dimensions (2D or 3D) +- Data types (fp16, bf16, fp32, int8) +- Convolution parameters (padding, stride, dilation) +- Operation types (fprop, dgrad, wgrad) +- Epilogues (none, bias, relu, bias_relu) + +Layout: NHWC/NDHWC (channels last) for memory layout +Logical dimension order: N, C, spatial_dims... + +Run with: + pytest -vv -s -rA test_conv_fuzzer.py + +Options: + --num-tests N Number of random tests to run (default: 100) + --seed N Random seed for reproducibility (default: random) + --diffs N Number of mismatches to display (default: 10) +""" + +import cudnn +import pytest +import random +import torch +import math +import sys +import signal +from datetime import datetime +from dataclasses import dataclass +from typing import Optional, Tuple +from enum import IntEnum + +# fmt: off + +# Handle Ctrl-C gracefully +def signal_handler(sig, frame): + print("\n\nInterrupted by user (Ctrl-C), exiting...") + if torch.cuda.is_available(): + torch.cuda.synchronize() + sys.exit(1) + +signal.signal(signal.SIGINT, signal_handler) + +if __name__ == "__main__": + print("This is pytest script. Run with: pytest -vv -s -rA test_conv_fuzzer.py") + sys.exit(0) + + +# ============================================================================ +# Configuration and Constants +# ============================================================================ + +class ConvType(IntEnum): + FPROP = 0 # Forward convolution + DGRAD = 1 # Input gradient (backward data) + WGRAD = 2 # Weight gradient (backward filter) + +class EpilogueType(IntEnum): + NONE = 0 + BIAS = 1 + RELU = 2 + BIAS_RELU = 3 + +SUPPORTED_DTYPES = [ + torch.float16, + torch.bfloat16, + torch.float32, +] + +# int8 convolutions have stricter requirements, test separately +SUPPORTED_DTYPES_WITH_INT8 = SUPPORTED_DTYPES + [torch.int8] + + +# ============================================================================ +# Utility Functions +# ============================================================================ + +def convert_to_cudnn_type(torch_type): + """Convert PyTorch dtype to cuDNN data type.""" + mapping = { + torch.float16: cudnn.data_type.HALF, + torch.bfloat16: cudnn.data_type.BFLOAT16, + torch.float32: cudnn.data_type.FLOAT, + torch.bool: cudnn.data_type.BOOLEAN, + torch.uint8: cudnn.data_type.UINT8, + torch.int8: cudnn.data_type.INT8, + torch.int32: cudnn.data_type.INT32, + torch.int64: cudnn.data_type.INT64, + } + if torch_type not in mapping: + raise ValueError(f"Unsupported tensor data type: {torch_type}") + return mapping[torch_type] + + +def get_gpu_arch(): + """Get GPU SM architecture version.""" + major, minor = torch.cuda.get_device_capability() + return f"SM_{major * 10 + minor}" + + +def get_sm_count(): + """Get number of SMs on the GPU.""" + props = torch.cuda.get_device_properties(0) + return props.multi_processor_count + + +def get_gpu_name(): + """Get GPU name.""" + return torch.cuda.get_device_name() + + +def conv_type_name(conv_type: ConvType) -> str: + """Get human-readable conv type name.""" + names = { + ConvType.FPROP: "fprop", + ConvType.DGRAD: "dgrad", + ConvType.WGRAD: "wgrad", + } + return names.get(conv_type, "unknown") + + +def epilogue_name(epilogue: EpilogueType) -> str: + """Get human-readable epilogue name.""" + names = { + EpilogueType.NONE: "none", + EpilogueType.BIAS: "bias", + EpilogueType.RELU: "relu", + EpilogueType.BIAS_RELU: "bias_relu", + } + return names.get(epilogue, "unknown") + + +def compute_channels_last_strides(shape: Tuple[int, ...]) -> Tuple[int, ...]: + """ + Compute channels-last strides for NHWC (2D) or NDHWC (3D) layout. + + Logical dim order: (N, C, spatial_dims...) + Memory order: N, spatial_dims..., C + + For 2D (NCHW logical -> NHWC memory): + shape = (N, C, H, W) + memory_order = (N, H, W, C) -> strides computed from last to first + strides[N] = H*W*C, strides[C] = 1, strides[H] = W*C, strides[W] = C + + For 3D (NCDHW logical -> NDHWC memory): + shape = (N, C, D, H, W) + memory_order = (N, D, H, W, C) -> strides computed from last to first + """ + ndim = len(shape) + if ndim < 3: + raise ValueError(f"Shape must have at least 3 dimensions, got {ndim}") + + # shape = (N, C, spatial_dims...) + N = shape[0] + C = shape[1] + spatial = shape[2:] # (H, W) or (D, H, W) + + # Memory layout: (N, spatial_dims..., C) + # Compute strides from innermost to outermost + strides = [0] * ndim + + # C is innermost in memory (stride = 1) + strides[1] = 1 + + # Spatial dims next (reversed order in memory) + stride = C + for i in range(ndim - 1, 1, -1): # W, H, [D] order + strides[i] = stride + stride *= shape[i] + + # N is outermost + strides[0] = stride + + return tuple(strides) + + +def compute_num_elements(shape: Tuple[int, ...], strides: Tuple[int, ...]) -> int: + """Compute number of elements needed for storage given shape and strides.""" + if not shape: + return 1 + max_offset = sum((d - 1) * s for d, s in zip(shape, strides)) + return max_offset + 1 + + +def compute_output_spatial(input_spatial: int, filter_spatial: int, + padding: int, stride: int, dilation: int) -> int: + """Compute output spatial dimension for convolution.""" + effective_filter = (filter_spatial - 1) * dilation + 1 + return (input_spatial + 2 * padding - effective_filter) // stride + 1 + + +def fill_with_garbage(tensor: torch.Tensor, nan_probability: float = 0.1) -> None: + """ + Fill tensor with garbage values (mix of random values and NaNs). + This helps catch bugs where cuDNN doesn't write all output locations. + """ + # Choose range based on dtype to avoid overflow + if tensor.dtype in (torch.float16, torch.bfloat16): + lo, hi = -1e4, 1e4 # FP16 max is ~65504 + else: + lo, hi = -1e6, 1e6 + + # Fill with random garbage + tensor.uniform_(lo, hi) + + # Sprinkle in some NaNs (only for float types) + if nan_probability > 0 and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + nan_mask = torch.rand(tensor.shape, device=tensor.device) < nan_probability + tensor[nan_mask] = float('nan') + + +# ============================================================================ +# Test Configuration +# ============================================================================ + +@dataclass +class ConvConfig: + """Configuration for a single convolution test.""" + # Spatial dimensions (2 for 2D, 3 for 3D) + spatial_dims: int + + # Basic dimensions + batch: int # N + in_channels: int # C_in + out_channels: int # C_out (K) + + # Spatial sizes: (H, W) for 2D or (D, H, W) for 3D + input_spatial: Tuple[int, ...] # Input spatial dimensions + filter_spatial: Tuple[int, ...] # Filter spatial dimensions + + # Convolution parameters (per spatial dimension) + padding: Tuple[int, ...] + stride: Tuple[int, ...] + dilation: Tuple[int, ...] + + # Operation type + conv_type: ConvType + + # Data types + x_dtype: torch.dtype # Input dtype + w_dtype: torch.dtype # Weight dtype + y_dtype: torch.dtype # Output dtype + + # Epilogue (only for fprop) + epilogue: EpilogueType + + # Random seed for data generation + rng_seed: int + + # Computed shapes and strides (set during tensor creation) + # Logical order: (N, C, spatial...) + x_shape: Tuple[int, ...] = None + w_shape: Tuple[int, ...] = None + y_shape: Tuple[int, ...] = None + x_strides: Tuple[int, ...] = None + w_strides: Tuple[int, ...] = None + y_strides: Tuple[int, ...] = None + x_elems: int = 0 + w_elems: int = 0 + y_elems: int = 0 + + # Bias tensor info (for epilogue) + bias_shape: Tuple[int, ...] = None + bias_strides: Tuple[int, ...] = None + bias_elems: int = 0 + + @property + def output_spatial(self) -> Tuple[int, ...]: + """Compute output spatial dimensions.""" + return tuple( + compute_output_spatial(inp, flt, pad, strd, dil) + for inp, flt, pad, strd, dil in zip( + self.input_spatial, self.filter_spatial, + self.padding, self.stride, self.dilation + ) + ) + + def to_repro_dict(self) -> dict: + """Convert config to reproducible dictionary.""" + return { + 'spatial_dims': self.spatial_dims, + 'batch': self.batch, + 'in_channels': self.in_channels, + 'out_channels': self.out_channels, + 'input_spatial': self.input_spatial, + 'filter_spatial': self.filter_spatial, + 'padding': self.padding, + 'stride': self.stride, + 'dilation': self.dilation, + 'conv_type': int(self.conv_type), + 'x_dtype': str(self.x_dtype), + 'w_dtype': str(self.w_dtype), + 'y_dtype': str(self.y_dtype), + 'epilogue': int(self.epilogue), + 'rng_seed': self.rng_seed, + } + + +class ConfigGenerator: + """Generator for random convolution configurations.""" + + def __init__(self, seed: int, allow_unaligned: bool = False): + self.rng = random.Random(seed) + self.sm_version = torch.cuda.get_device_capability()[0] * 10 + torch.cuda.get_device_capability()[1] + self.allow_unaligned = allow_unaligned + + def random_spatial_dims(self) -> int: + """Generate random spatial dimension count (2 or 3).""" + return self.rng.choice([2, 2, 2, 3]) # Prefer 2D + + def random_batch(self) -> int: + """Generate random batch size.""" + return self.rng.choice([1, 1, 2, 2, 4, 4, 8]) + + def random_channels(self, min_val: int = 1, max_val: int = 256) -> int: + """Generate random channel count (reduced max for memory).""" + val = self.rng.randint(int(math.sqrt(min_val)), int(math.sqrt(max_val))) + val = val * val + if self.allow_unaligned: + return max(1, val) + else: + # Round up to multiple of 8 for tensor core alignment + return max(8, ((val + 7) // 8) * 8) + + def random_spatial_size(self, min_val: int = 1, max_val: int = 128) -> int: + """Generate random spatial dimension size (reduced max for memory).""" + val = self.rng.randint(int(math.sqrt(min_val)), int(math.sqrt(max_val))) + val = val * val + if self.allow_unaligned: + return max(1, val) + else: + return max(8, ((val + 7) // 8) * 8) + + def random_filter_size(self) -> int: + """Generate random filter spatial size.""" + return self.rng.choice([1, 1, 3, 3, 3, 5, 7]) + + def random_padding(self, filter_size: int) -> int: + """Generate random padding.""" + # Padding typically 0 to (filter_size - 1) // 2 + max_pad = (filter_size - 1) // 2 + return self.rng.randint(0, max(0, max_pad)) + + def random_stride(self) -> int: + """Generate random stride.""" + return self.rng.choice([1, 1, 1, 2, 2, 3]) + + def random_dilation(self) -> int: + """Generate random dilation.""" + return self.rng.choice([1, 1, 1, 1, 2]) + + def random_dtype(self) -> torch.dtype: + """Generate random data type.""" + return self.rng.choice(SUPPORTED_DTYPES) + + def random_conv_type(self) -> ConvType: + """Generate random convolution type.""" + # Weight towards fprop but include dgrad/wgrad + weights = [0.5, 0.25, 0.25] # fprop, dgrad, wgrad + return self.rng.choices(list(ConvType), weights=weights)[0] + + def random_epilogue(self) -> EpilogueType: + """Generate random epilogue type.""" + weights = [0.6, 0.15, 0.15, 0.1] + return self.rng.choices(list(EpilogueType), weights=weights)[0] + + def generate(self) -> ConvConfig: + """Generate a random convolution configuration.""" + spatial_dims = self.random_spatial_dims() + + batch = self.random_batch() + in_channels = self.random_channels() + out_channels = self.random_channels() + + # Generate spatial dimensions + input_spatial = tuple(self.random_spatial_size() for _ in range(spatial_dims)) + filter_spatial = tuple(self.random_filter_size() for _ in range(spatial_dims)) + + # Ensure output spatial dims are positive + padding = [] + stride = [] + dilation = [] + for i in range(spatial_dims): + flt = filter_spatial[i] + dil = self.random_dilation() + strd = self.random_stride() + pad = self.random_padding(flt) + + # Check output size is positive + effective_filter = (flt - 1) * dil + 1 + out_size = (input_spatial[i] + 2 * pad - effective_filter) // strd + 1 + + # If output would be non-positive, adjust padding + while out_size < 1: + pad += 1 + out_size = (input_spatial[i] + 2 * pad - effective_filter) // strd + 1 + + padding.append(pad) + stride.append(strd) + dilation.append(dil) + + padding = tuple(padding) + stride = tuple(stride) + dilation = tuple(dilation) + + # Convolution type + conv_type = self.random_conv_type() + + # Data types - ensure compatible combinations + # Keep same dtype for all tensors for stability (like test_conv_bias.py) + x_dtype = self.random_dtype() + w_dtype = x_dtype # Same dtype for input and weights + y_dtype = x_dtype # Same dtype for output (mixed precision needs special handling) + + # Epilogue only for fprop + if conv_type == ConvType.FPROP: + epilogue = self.random_epilogue() + else: + epilogue = EpilogueType.NONE + + config = ConvConfig( + spatial_dims=spatial_dims, + batch=batch, + in_channels=in_channels, + out_channels=out_channels, + input_spatial=input_spatial, + filter_spatial=filter_spatial, + padding=padding, + stride=stride, + dilation=dilation, + conv_type=conv_type, + x_dtype=x_dtype, + w_dtype=w_dtype, + y_dtype=y_dtype, + epilogue=epilogue, + rng_seed=self.rng.randint(0, 2**31 - 1), + ) + + return config + + +# ============================================================================ +# Test Execution +# ============================================================================ + +def create_tensors(config: ConvConfig, rng: random.Random): + """ + Create tensors based on configuration. + + Tensor naming convention (shapes are always the same regardless of conv_type): + X: (N, C_in, spatial...) - input image shape + W: (C_out, C_in, filter...) - weight/filter shape + Y: (N, C_out, output_spatial...) - output shape + + Meaning varies by conv_type: + FPROP: X=input, W=weights, Y=output (compute Y from X,W) + DGRAD: X=dX(output), W=weights, Y=dY(input) (compute dX from dY,W) + WGRAD: X=input, W=dW(output), Y=dY(input) (compute dW from X,dY) + """ + torch_rng = torch.Generator(device='cuda') + torch_rng.manual_seed(config.rng_seed) + + # Compute shapes (same for all conv types) + x_shape = (config.batch, config.in_channels) + config.input_spatial + w_shape = (config.out_channels, config.in_channels) + config.filter_spatial + y_shape = (config.batch, config.out_channels) + config.output_spatial + + # Use PyTorch's native channels_last memory format for proper cuDNN compatibility + if config.spatial_dims == 2: + memory_format = torch.channels_last + else: # 3D + memory_format = torch.channels_last_3d + + # Create tensors - which ones are input (random) vs output (garbage) depends on conv_type + # Output tensors are filled with garbage (random + NaNs) to catch bugs where cuDNN + # doesn't write all output locations + if config.conv_type == ConvType.FPROP: + # FPROP: X,W are inputs, Y is output + X = torch.empty(x_shape, device='cuda', dtype=config.x_dtype).to(memory_format=memory_format) + X.normal_(mean=0.5, std=0.1, generator=torch_rng) + W = torch.empty(w_shape, device='cuda', dtype=config.w_dtype).to(memory_format=memory_format) + W.normal_(mean=0.5, std=0.1, generator=torch_rng) + Y = torch.empty(y_shape, device='cuda', dtype=config.y_dtype).to(memory_format=memory_format) + fill_with_garbage(Y) # Output - fill with garbage + + elif config.conv_type == ConvType.DGRAD: + # DGRAD: Y(dY),W are inputs, X(dX) is output + Y = torch.empty(y_shape, device='cuda', dtype=config.y_dtype).to(memory_format=memory_format) + Y.normal_(mean=0.5, std=0.1, generator=torch_rng) # dY - gradient from upstream + W = torch.empty(w_shape, device='cuda', dtype=config.w_dtype).to(memory_format=memory_format) + W.normal_(mean=0.5, std=0.1, generator=torch_rng) # weights + X = torch.empty(x_shape, device='cuda', dtype=config.x_dtype).to(memory_format=memory_format) + fill_with_garbage(X) # dX output - fill with garbage + + else: # WGRAD + # WGRAD: X,Y(dY) are inputs, W(dW) is output + X = torch.empty(x_shape, device='cuda', dtype=config.x_dtype).to(memory_format=memory_format) + X.normal_(mean=0.5, std=0.1, generator=torch_rng) # input image + Y = torch.empty(y_shape, device='cuda', dtype=config.y_dtype).to(memory_format=memory_format) + Y.normal_(mean=0.5, std=0.1, generator=torch_rng) # dY - gradient from upstream + W = torch.empty(w_shape, device='cuda', dtype=config.w_dtype).to(memory_format=memory_format) + fill_with_garbage(W) # dW output - fill with garbage + + # Update config with actual shapes and strides + config.x_shape = tuple(X.size()) + config.w_shape = tuple(W.size()) + config.y_shape = tuple(Y.size()) + config.x_strides = tuple(X.stride()) + config.w_strides = tuple(W.stride()) + config.y_strides = tuple(Y.stride()) + config.x_elems = X.numel() + config.w_elems = W.numel() + config.y_elems = Y.numel() + + # Bias tensor if needed (only for FPROP, shape: 1, K, 1, 1, ... for broadcasting) + bias = None + if config.conv_type == ConvType.FPROP and config.epilogue in [EpilogueType.BIAS, EpilogueType.BIAS_RELU]: + bias_shape = (1, config.out_channels) + (1,) * config.spatial_dims + bias = torch.empty(bias_shape, device='cuda', dtype=config.y_dtype).contiguous() + bias.normal_(mean=0.0, std=0.1, generator=torch_rng) + + config.bias_shape = tuple(bias.size()) + config.bias_strides = tuple(bias.stride()) + config.bias_elems = bias.numel() + + return X, W, Y, bias + + +def compute_reference(config: ConvConfig, X: torch.Tensor, W: torch.Tensor, + Y: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: + """ + Compute reference result using PyTorch. + + Convention: + FPROP: compute Y = conv(X, W) + bias + relu + DGRAD: compute dX from dY(=Y) and W + WGRAD: compute dW from X and dY(=Y) + + Returns the tensor that should match the cuDNN output: + FPROP -> Y_ref (to compare with Y) + DGRAD -> dX_ref (to compare with X) + WGRAD -> dW_ref (to compare with W) + """ + compute_dtype = torch.float32 + + if config.conv_type == ConvType.FPROP: + # FPROP: Y = conv(X, W) + X_f = X.to(compute_dtype).contiguous() + W_f = W.to(compute_dtype).contiguous() + + try: + if config.spatial_dims == 2: + ref = torch.nn.functional.conv2d( + X_f, W_f, + padding=config.padding, + stride=config.stride, + dilation=config.dilation + ) + else: + ref = torch.nn.functional.conv3d( + X_f, W_f, + padding=config.padding, + stride=config.stride, + dilation=config.dilation + ) + + # Apply epilogue + if bias is not None and config.epilogue in [EpilogueType.BIAS, EpilogueType.BIAS_RELU]: + ref = ref + bias.to(compute_dtype) + if config.epilogue in [EpilogueType.RELU, EpilogueType.BIAS_RELU]: + ref = torch.relu(ref) + + return ref.to(config.y_dtype) + finally: + del X_f, W_f + + elif config.conv_type == ConvType.DGRAD: + # DGRAD: dX = conv_dgrad(dY, W) + # Y contains dY (gradient from upstream), W contains weights + # We compute dX (gradient w.r.t. input) + dY_f = Y.to(compute_dtype).contiguous() + W_f = W.to(compute_dtype).contiguous() + + # Use autograd to compute the reference + # Create a dummy input and run forward, then backward to get dX + dummy_X = torch.zeros(config.x_shape, device='cuda', dtype=compute_dtype, requires_grad=True) + + try: + if config.spatial_dims == 2: + dummy_Y = torch.nn.functional.conv2d( + dummy_X, W_f, + padding=config.padding, + stride=config.stride, + dilation=config.dilation + ) + else: + dummy_Y = torch.nn.functional.conv3d( + dummy_X, W_f, + padding=config.padding, + stride=config.stride, + dilation=config.dilation + ) + + # Backward pass to get dX + dummy_Y.backward(dY_f) + dX_ref = dummy_X.grad.clone() + + return dX_ref.to(config.x_dtype) + finally: + del dY_f, W_f, dummy_X, dummy_Y + + else: # WGRAD + # WGRAD: dW = conv_wgrad(X, dY) + # X contains input, Y contains dY (gradient from upstream) + # We compute dW (gradient w.r.t. weights) + X_f = X.to(compute_dtype).contiguous() + dY_f = Y.to(compute_dtype).contiguous() + + # Use autograd to compute the reference + # Create a dummy weight and run forward, then backward to get dW + dummy_W = torch.zeros(config.w_shape, device='cuda', dtype=compute_dtype, requires_grad=True) + + try: + if config.spatial_dims == 2: + dummy_Y = torch.nn.functional.conv2d( + X_f, dummy_W, + padding=config.padding, + stride=config.stride, + dilation=config.dilation + ) + else: + dummy_Y = torch.nn.functional.conv3d( + X_f, dummy_W, + padding=config.padding, + stride=config.stride, + dilation=config.dilation + ) + + # Backward pass to get dW + dummy_Y.backward(dY_f) + dW_ref = dummy_W.grad.clone() + + return dW_ref.to(config.w_dtype) + finally: + del X_f, dY_f, dummy_W, dummy_Y + + +def run_cudnn_conv(config: ConvConfig, X: torch.Tensor, W: torch.Tensor, Y: torch.Tensor, + bias: Optional[torch.Tensor], cudnn_handle) -> Tuple[bool, str]: + """ + Run convolution using cuDNN and return success status and message. + + Convention: + FPROP: inputs=X,W, output=Y (compute Y) + DGRAD: inputs=Y(dY),W, output=X (compute dX into X) + WGRAD: inputs=X,Y(dY), output=W (compute dW into W) + """ + try: + stream = torch.cuda.current_stream().cuda_stream + cudnn.set_stream(handle=cudnn_handle, stream=stream) + + # Determine compute and IO data types + if config.x_dtype == torch.float32: + io_dtype = cudnn.data_type.FLOAT + elif config.x_dtype == torch.bfloat16: + io_dtype = cudnn.data_type.BFLOAT16 + else: + io_dtype = cudnn.data_type.HALF + + # Create graph + graph = cudnn.pygraph( + handle=cudnn_handle, + io_data_type=io_dtype, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + # Build convolution operation based on conv_type + if config.conv_type == ConvType.FPROP: + # FPROP: Y = conv(X, W) + X_tensor = graph.tensor( + name="X", dim=list(X.size()), stride=list(X.stride()), + data_type=convert_to_cudnn_type(config.x_dtype) + ) + W_tensor = graph.tensor( + name="W", dim=list(W.size()), stride=list(W.stride()), + data_type=convert_to_cudnn_type(config.w_dtype) + ) + + conv_output = graph.conv_fprop( + image=X_tensor, + weight=W_tensor, + padding=list(config.padding), + stride=list(config.stride), + dilation=list(config.dilation), + ) + + # Apply epilogue + if config.epilogue in [EpilogueType.BIAS, EpilogueType.BIAS_RELU]: + B_tensor = graph.tensor( + name="B", dim=list(bias.size()), stride=list(bias.stride()), + data_type=convert_to_cudnn_type(config.y_dtype) + ) + conv_output = graph.bias(name="bias", input=conv_output, bias=B_tensor) + + if config.epilogue in [EpilogueType.RELU, EpilogueType.BIAS_RELU]: + conv_output = graph.relu(name="relu", input=conv_output) + + conv_output.set_output(True) + + # Execution dict: X,W are inputs, Y is output + exec_dict = {X_tensor: X, W_tensor: W, conv_output: Y} + if config.epilogue in [EpilogueType.BIAS, EpilogueType.BIAS_RELU]: + exec_dict[B_tensor] = bias + + elif config.conv_type == ConvType.DGRAD: + # DGRAD: dX = conv_dgrad(dY, W) + # Y contains dY (input), W contains weights, X is where we store dX (output) + dY_tensor = graph.tensor( + name="dY", dim=list(Y.size()), stride=list(Y.stride()), + data_type=convert_to_cudnn_type(config.y_dtype) + ) + W_tensor = graph.tensor( + name="W", dim=list(W.size()), stride=list(W.stride()), + data_type=convert_to_cudnn_type(config.w_dtype) + ) + + conv_output = graph.conv_dgrad( + loss=dY_tensor, + filter=W_tensor, + padding=list(config.padding), + stride=list(config.stride), + dilation=list(config.dilation), + ) + # Must set output dimensions explicitly for dgrad (cuDNN can't infer them) + conv_output.set_output(True).set_dim(list(X.size())).set_stride(list(X.stride())) + + # Execution dict: Y(dY),W are inputs, X(dX) is output + exec_dict = {dY_tensor: Y, W_tensor: W, conv_output: X} + + else: # WGRAD + # WGRAD: dW = conv_wgrad(X, dY) + # X contains input, Y contains dY, W is where we store dW (output) + X_tensor = graph.tensor( + name="X", dim=list(X.size()), stride=list(X.stride()), + data_type=convert_to_cudnn_type(config.x_dtype) + ) + dY_tensor = graph.tensor( + name="dY", dim=list(Y.size()), stride=list(Y.stride()), + data_type=convert_to_cudnn_type(config.y_dtype) + ) + + conv_output = graph.conv_wgrad( + image=X_tensor, + loss=dY_tensor, + padding=list(config.padding), + stride=list(config.stride), + dilation=list(config.dilation), + ) + conv_output.set_output(True).set_dim(list(W.size())).set_stride(list(W.stride())) + + # Execution dict: X,Y(dY) are inputs, W(dW) is output + exec_dict = {X_tensor: X, dY_tensor: Y, conv_output: W} + + # Validate and build + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + # Allocate workspace and fill with garbage to catch uninitialized memory bugs + workspace_size = graph.get_workspace_size() + workspace = torch.empty(workspace_size, device='cuda', dtype=torch.uint8) + if workspace_size > 0: + # Fill with random garbage + some NaN patterns to test proper workspace init + workspace.random_(0, 256) + # Sprinkle in NaN bit patterns (0x7FC00000 for float32 NaN) + nan_mask = torch.rand(workspace_size, device='cuda') < 0.1 + workspace[nan_mask] = 0xFF + + graph.execute(exec_dict, workspace, handle=cudnn_handle) + torch.cuda.synchronize() + + return True, "Success" + + except cudnn.cudnnGraphNotSupportedError as e: + return False, f"Graph not supported: {e}" + except Exception as e: + return False, f"Error: {e}" + + +def compare_results(actual: torch.Tensor, ref: torch.Tensor, _dtype: torch.dtype, + num_diffs: int = 10) -> Tuple[bool, str]: + """Compare cuDNN result with reference.""" + # Base tolerances - TF32/FP16/BF16 all have similar effective precision + # (TF32 and FP16 have 10-bit mantissa, BF16 has 7-bit but we use same tolerance) + # cuDNN uses TF32 for FP32 tensor core ops + # _dtype kept for potential future per-dtype tolerance tuning + rtol, atol = 1e-2, 1e-2 + + if ref.shape != actual.shape: + return False, f"Shape mismatch: actual={actual.shape}, ref={ref.shape}" + + # Compare + actual_f = actual.to(torch.float32).contiguous() + ref_f = ref.to(torch.float32).contiguous() + + diff = torch.abs(actual_f - ref_f) + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + # Relative difference + denom = torch.maximum(torch.abs(ref_f), torch.tensor(1e-6, device='cuda')) + rel_diff = diff / denom + max_rel_diff = rel_diff.max().item() + + # Find mismatches - element fails if it exceeds BOTH tolerances + # mismatch_mask = (diff > atol) & (rel_diff > rtol) + mismatch_mask = (diff > torch.abs(atol + rtol * ref_f)) + mismatch_indices = torch.nonzero(mismatch_mask) + num_mismatches = mismatch_indices.shape[0] + + # Pass if no elements fail both tolerance checks + passed = num_mismatches == 0 + + if passed: + return True, f"max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}, max_rel_diff={max_rel_diff:.2e}" + else: + msg = f"MISMATCH: {num_mismatches} elements differ (max_diff={max_diff:.2e}, max_rel_diff={max_rel_diff:.2e})\n" + for i in range(min(num_diffs, num_mismatches)): + idx = tuple(mismatch_indices[i].tolist()) + act_val = actual_f[idx].item() + ref_val = ref_f[idx].item() + d = diff[idx].item() + msg += f" [{idx}]: actual={act_val:.6f}, expected={ref_val:.6f}, diff={d:.2e} tol={atol + rtol * ref_f[idx].item():.2e}\n" + + return False, msg + + +def estimate_memory_mb(config: ConvConfig) -> float: + """Estimate GPU memory usage in MB for tensors (X, W, Y, Y_ref, bias).""" + dtype_bytes = { + torch.float32: 4, + torch.float16: 2, + torch.bfloat16: 2, + torch.int8: 1, + } + elem_size = dtype_bytes.get(config.x_dtype, 4) + + # X, W, Y tensors + Y_ref (float32 for comparison) + x_bytes = config.x_elems * elem_size + w_bytes = config.w_elems * elem_size + y_bytes = config.y_elems * elem_size + y_ref_bytes = config.y_elems * 4 # float32 + + total = x_bytes + w_bytes + y_bytes + y_ref_bytes + if config.bias_elems: + total += config.bias_elems * elem_size + + return total / (1024 * 1024) + + +def format_test_header(config: ConvConfig, test_num: int, total_tests: int, test_name: str) -> str: + """Format test header similar to matmul fuzzer.""" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + gpu_info = f"{get_gpu_arch()} ({get_sm_count()} SM-s, {get_gpu_name()})" + + spatial_str = "2D" if config.spatial_dims == 2 else "3D" + mem_mb = estimate_memory_mb(config) + + lines = [ + "", # Newline to separate from pytest's test name line + "=" * 90, + f"#### Test #{test_num} of {total_tests} at {timestamp} ", + "", + f"test_name = {test_name}", + f"platform_info = {gpu_info}, cudnn_ver={cudnn.backend_version()}", + f"rng_data_seed = {config.rng_seed}", + f"conv_type = {conv_type_name(config.conv_type)} ({spatial_str})", + f"basic_dims = [N={config.batch}, C_in={config.in_channels}, C_out={config.out_channels}]", + f"input_spatial = {config.input_spatial}", + f"filter_spatial = {config.filter_spatial}", + f"output_spatial = {config.output_spatial}", + f"padding = {config.padding}", + f"stride = {config.stride}", + f"dilation = {config.dilation}", + f"x(N,C,spatial) = dim={config.x_shape}, strides={config.x_strides}, elems={config.x_elems}, type={config.x_dtype}", + f"w(K,C,spatial) = dim={config.w_shape}, strides={config.w_strides}, elems={config.w_elems}, type={config.w_dtype}", + f"y(N,K,spatial) = dim={config.y_shape}, strides={config.y_strides}, elems={config.y_elems}, type={config.y_dtype}", + ] + + if config.bias_shape: + lines.append(f"bias(1,K,1...) = dim={config.bias_shape}, strides={config.bias_strides}, elems={config.bias_elems}, type={config.y_dtype}") + + lines.extend([ + f"epilogue = {epilogue_name(config.epilogue)}", + f"est_memory = {mem_mb:.1f} MB", + f"repro_cmd = pytest -vv -s -rA {__file__}::test_repro --repro \"{config.to_repro_dict()}\"", + " ", + ]) + + return "\n".join(lines) + + +# ============================================================================ +# Pytest Fixtures and Configuration +# ============================================================================ +# Note: pytest_addoption is defined in conftest.py +# Options used: --seed, --num-tests, --diffs, --repro + +@pytest.fixture +def num_diffs(request): + return request.config.getoption("--diffs") + + +# ============================================================================ +# Test Parameter Generation +# ============================================================================ + +def tlist_with_configs(*, num_tests: int, rng_seed: int, allow_unaligned: bool = False): + """Generate list of test parameters with pre-generated configs for descriptive test names.""" + rng = random.Random(rng_seed) + params = [] + for i in range(num_tests): + config_seed = rng.randint(65536, 2**31 - 1) + generator = ConfigGenerator(config_seed, allow_unaligned=allow_unaligned) + config = generator.generate() + params.append((i + 1, num_tests, config_seed, config)) + return params + + +def make_test_id(param, prefix: str = "t"): + """Create descriptive test ID from pre-generated config.""" + test_num, total_tests, config_seed, config = param + dtype_short = { + torch.float16: 'f16', + torch.bfloat16: 'bf16', + torch.float32: 'f32', + } + dt = dtype_short.get(config.x_dtype, 'unk') + spatial = '2d' if config.spatial_dims == 2 else '3d' + conv = conv_type_name(config.conv_type)[:2] # fp, dg, wg + epi = epilogue_name(config.epilogue)[:4] + # Example: t1_N2_C64x128_32x32_f16_2d_fp_none + spatial_str = 'x'.join(str(s) for s in config.input_spatial) + return f"{prefix}{test_num}_N{config.batch}_C{config.in_channels}x{config.out_channels}_{spatial_str}_{dt}_{spatial}_{conv}_{epi}" + + +# Pre-generated test parameter lists +DEFAULT_NUM_TESTS = 1024 +DEFAULT_SEED_L0 = 42 +DEFAULT_SEED_L1 = 12345 + +TEST_PARAMS_L0 = tlist_with_configs(num_tests=DEFAULT_NUM_TESTS, rng_seed=DEFAULT_SEED_L0, allow_unaligned=False) +TEST_PARAMS_L1 = tlist_with_configs(num_tests=DEFAULT_NUM_TESTS, rng_seed=DEFAULT_SEED_L1, allow_unaligned=True) + +SKIP_TEST_NUMS_L0 = {} + +# ============================================================================ +# Test Functions +# ============================================================================ + +@pytest.mark.L0 +@pytest.mark.parametrize("test_num,total_tests,config_seed,config", TEST_PARAMS_L0, + ids=[make_test_id(p) for p in TEST_PARAMS_L0]) +def test_conv_random_L0_0(test_num: int, total_tests: int, config_seed: int, config: ConvConfig, cudnn_handle, num_diffs, request): + """Random convolution tests (fprop/dgrad/wgrad) with aligned dimensions (L0).""" + # Skip known failing tests + if test_num in SKIP_TEST_NUMS_L0: + pytest.skip(f"Known failing test (dgrad f32 precision issue)") + + # Create tensors + rng = random.Random(config_seed) + X, W, Y, bias = create_tensors(config, rng) + ref = None + + try: + # Print test header + test_name = f"test_conv_random_L0_0[{make_test_id((test_num, total_tests, config_seed, config))}]" + print(format_test_header(config, test_num, total_tests, test_name)) + + # Run cuDNN + success, msg = run_cudnn_conv(config, X, W, Y, bias, cudnn_handle) + + if not success: + print(f"%%%% cuDNN execution failed: {msg}") + pytest.skip(f"cuDNN not supported: {msg}") + return + + # Compute reference and compare + ref = compute_reference(config, X, W, Y, bias) + + # Determine which tensor to compare based on conv_type + if config.conv_type == ConvType.FPROP: + actual, dtype, name = Y, config.y_dtype, "Y" + elif config.conv_type == ConvType.DGRAD: + actual, dtype, name = X, config.x_dtype, "dX" + else: # WGRAD + actual, dtype, name = W, config.w_dtype, "dW" + + passed, compare_msg = compare_results(actual, ref, dtype, num_diffs) + + if passed: + print(f"%%%% Numerical divergence of '{name}' within limits ({compare_msg})") + print("@@@@ Overall result: PASSED, everything looks good!") + else: + print(f"%%%% {compare_msg}") + print("@@@@ Overall result: FAILED, numerical mismatch!") + pytest.fail(f"Numerical mismatch: {compare_msg}") + finally: + # Explicit cleanup to prevent GPU memory accumulation + del X, W, Y + if bias is not None: + del bias + if ref is not None: + del ref + torch.cuda.empty_cache() + + +@pytest.mark.L1 +@pytest.mark.parametrize("test_num,total_tests,config_seed,config", TEST_PARAMS_L1, + ids=[make_test_id(p, prefix="u") for p in TEST_PARAMS_L1]) +def test_conv_random_L0_1(test_num: int, total_tests: int, config_seed: int, config: ConvConfig, cudnn_handle, num_diffs, request): + """Random convolution tests (fprop/dgrad/wgrad) with unaligned dimensions (L1).""" + # Create tensors + rng = random.Random(config_seed) + X, W, Y, bias = create_tensors(config, rng) + ref = None + + try: + # Print test header + test_name = f"test_conv_random_L0_1[{make_test_id((test_num, total_tests, config_seed, config), prefix='u')}]" + print(format_test_header(config, test_num, total_tests, test_name)) + + # Run cuDNN + success, msg = run_cudnn_conv(config, X, W, Y, bias, cudnn_handle) + + if not success: + print(f"%%%% cuDNN execution failed: {msg}") + pytest.skip(f"cuDNN not supported: {msg}") + return + + # Compute reference and compare + ref = compute_reference(config, X, W, Y, bias) + + # Determine which tensor to compare based on conv_type + if config.conv_type == ConvType.FPROP: + actual, dtype, name = Y, config.y_dtype, "Y" + elif config.conv_type == ConvType.DGRAD: + actual, dtype, name = X, config.x_dtype, "dX" + else: # WGRAD + actual, dtype, name = W, config.w_dtype, "dW" + + passed, compare_msg = compare_results(actual, ref, dtype, num_diffs) + + if passed: + print(f"%%%% Numerical divergence of '{name}' within limits ({compare_msg})") + print("@@@@ Overall result: PASSED, everything looks good!") + else: + print(f"%%%% {compare_msg}") + print("@@@@ Overall result: FAILED, numerical mismatch!") + pytest.fail(f"Numerical mismatch: {compare_msg}") + finally: + # Explicit cleanup to prevent GPU memory accumulation + del X, W, Y + if bias is not None: + del bias + if ref is not None: + del ref + torch.cuda.empty_cache() + + +@pytest.mark.L0 +def test_repro(cudnn_handle, num_diffs, request): + """Reproduce a specific test case from repro dict.""" + repro_str = request.config.getoption("--repro") + if repro_str is None: + pytest.skip("No --repro argument provided") + return + + import ast + repro = ast.literal_eval(repro_str) + + # Reconstruct config from repro dict + dtype_map = { + 'torch.float16': torch.float16, + 'torch.bfloat16': torch.bfloat16, + 'torch.float32': torch.float32, + 'torch.int8': torch.int8, + } + + config = ConvConfig( + spatial_dims=repro['spatial_dims'], + batch=repro['batch'], + in_channels=repro['in_channels'], + out_channels=repro['out_channels'], + input_spatial=tuple(repro['input_spatial']), + filter_spatial=tuple(repro['filter_spatial']), + padding=tuple(repro['padding']), + stride=tuple(repro['stride']), + dilation=tuple(repro['dilation']), + conv_type=ConvType(repro['conv_type']), + x_dtype=dtype_map[repro['x_dtype']], + w_dtype=dtype_map[repro['w_dtype']], + y_dtype=dtype_map[repro['y_dtype']], + epilogue=EpilogueType(repro['epilogue']), + rng_seed=repro['rng_seed'], + ) + + # Create tensors + rng = random.Random(config.rng_seed) + X, W, Y, bias = create_tensors(config, rng) + ref = None + + try: + # Print test header + print(format_test_header(config, 1, 1, "test_repro")) + + # Run cuDNN + success, msg = run_cudnn_conv(config, X, W, Y, bias, cudnn_handle) + + if not success: + print(f"%%%% cuDNN execution failed: {msg}") + pytest.fail(f"cuDNN failed: {msg}") + return + + # Compute reference and compare + ref = compute_reference(config, X, W, Y, bias) + + # Determine which tensor to compare based on conv_type + if config.conv_type == ConvType.FPROP: + actual, dtype, name = Y, config.y_dtype, "Y" + elif config.conv_type == ConvType.DGRAD: + actual, dtype, name = X, config.x_dtype, "dX" + else: # WGRAD + actual, dtype, name = W, config.w_dtype, "dW" + + passed, compare_msg = compare_results(actual, ref, dtype, num_diffs) + + if passed: + print(f"%%%% Numerical divergence of '{name}' within limits ({compare_msg})") + print("@@@@ Overall result: PASSED, everything looks good!") + else: + print(f"%%%% {compare_msg}") + print("@@@@ Overall result: FAILED, numerical mismatch!") + pytest.fail(f"Numerical mismatch: {compare_msg}") + finally: + # Explicit cleanup to prevent GPU memory accumulation + del X, W, Y + if bias is not None: + del bias + if ref is not None: + del ref + torch.cuda.empty_cache() diff --git a/test/python/test_conv_genstats.py b/test/python/test_conv_genstats.py index 8ec60072..dd3d0a44 100644 --- a/test/python/test_conv_genstats.py +++ b/test/python/test_conv_genstats.py @@ -7,17 +7,11 @@ class Conv_Genstats(torch.nn.Module): - def forward( - self, scale, bias, x, w, padding=[1, 1], stride=[1, 1], dilation=[1, 1] - ): + def forward(self, scale, bias, x, w, padding=[1, 1], stride=[1, 1], dilation=[1, 1]): x_conv = torch.relu(x * scale + bias) - conv_output = torch.nn.functional.conv2d( - x_conv, w, padding=padding, stride=stride, dilation=dilation - ) + conv_output = torch.nn.functional.conv2d(x_conv, w, padding=padding, stride=stride, dilation=dilation) sum = torch.sum(conv_output, dim=(0, 2, 3), dtype=torch.float32) - sq_sum = torch.sum( - torch.square(conv_output), dim=(0, 2, 3), dtype=torch.float32 - ) + sq_sum = torch.sum(torch.square(conv_output), dim=(0, 2, 3), dtype=torch.float32) return conv_output, sum, sq_sum @@ -40,27 +34,11 @@ def forward( def test_conv_genstats(cudnn_handle): # Reference - X_gpu = torch.randn( - n, c, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - W_gpu = torch.randn( - k, c, 3, 3, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - scale = ( - torch.randn(1, c, 1, 1, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) - * 0.01 - ) - bias = ( - torch.randn(1, c, 1, 1, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) - * 0.01 - ) - Y_expected, sum_expected, sq_sum_expected = model( - scale, bias, X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation - ) + X_gpu = torch.randn(n, c, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + W_gpu = torch.randn(k, c, 3, 3, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + scale = torch.randn(1, c, 1, 1, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) * 0.01 + bias = torch.randn(1, c, 1, 1, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) * 0.01 + Y_expected, sum_expected, sq_sum_expected = model(scale, bias, X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation) stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) @@ -73,26 +51,16 @@ def test_conv_genstats(cudnn_handle): handle=cudnn_handle, ) - X = graph.tensor( - name="X", dim=X_gpu.size(), stride=X_gpu.stride(), data_type=X_gpu.dtype - ) - W = graph.tensor( - name="W", dim=W_gpu.size(), stride=W_gpu.stride(), data_type=W_gpu.dtype - ) + X = graph.tensor(name="X", dim=X_gpu.size(), stride=X_gpu.stride(), data_type=X_gpu.dtype) + W = graph.tensor(name="W", dim=W_gpu.size(), stride=W_gpu.stride(), data_type=W_gpu.dtype) - S = graph.tensor( - name="S", dim=scale.size(), stride=scale.stride(), data_type=scale.dtype - ) - B = graph.tensor( - name="B", dim=bias.size(), stride=bias.stride(), data_type=bias.dtype - ) + S = graph.tensor(name="S", dim=scale.size(), stride=scale.stride(), data_type=scale.dtype) + B = graph.tensor(name="B", dim=bias.size(), stride=bias.stride(), data_type=bias.dtype) S_OUT = graph.scale(name="scale", input=X, scale=S) B_OUT = graph.bias(name="bias", input=S_OUT, bias=B) CONV_IN = graph.relu(name="relu", input=B_OUT) - Y = graph.conv_fprop( - image=CONV_IN, weight=W, padding=padding, stride=stride, dilation=dilation - ) + Y = graph.conv_fprop(image=CONV_IN, weight=W, padding=padding, stride=stride, dilation=dilation) Y.set_output(True) SUM, SQ_SUM = graph.genstats(name="genstats", input=Y) @@ -116,9 +84,7 @@ def test_conv_genstats(cudnn_handle): Y_actual = torch.zeros_like(Y_expected) # Below tests capability to run with just device pointers - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute( { X: X_gpu.data_ptr(), diff --git a/test/python/test_conv_reduction.py b/test/python/test_conv_reduction.py index 30c555bd..71872f4b 100644 --- a/test/python/test_conv_reduction.py +++ b/test/python/test_conv_reduction.py @@ -23,9 +23,7 @@ def conv_reduce_cache_key(handle, X_gpu, W_gpu): @cudnn.graph_cache(key_fn=conv_reduce_cache_key) def create_conv_reduce_graph(handle, X_gpu, W_gpu): with cudnn.graph(handle) as (g, _): - print( - f"Creating graph with X_gpu shape: {X_gpu.shape} and W_gpu shape: {W_gpu.shape}" - ) + print(f"Creating graph with X_gpu shape: {X_gpu.shape} and W_gpu shape: {W_gpu.shape}") X = g.tensor_like(X_gpu) W = g.tensor_like(W_gpu) Y_conv = g.conv_fprop(X, W, padding=[1, 1], stride=[1, 1], dilation=[1, 1]) @@ -45,17 +43,11 @@ def test_reduction(cudnn_handle): padding = stride = dilation = [1, 1] # Reference - X_gpu = torch.randn(N, C, H, W, dtype=torch.float16, device="cuda").to( - memory_format=torch.channels_last - ) - W_gpu = torch.randn(K, C, R, S, dtype=torch.float16, device="cuda").to( - memory_format=torch.channels_last - ) + X_gpu = torch.randn(N, C, H, W, dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) + W_gpu = torch.randn(K, C, R, S, dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) # Perform convolution using FP32 computation while input and filter remain in FP16 - with torch.cuda.amp.autocast(dtype=torch.float32): - conv_output = torch.nn.functional.conv2d( - X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation - ) + with torch.amp.autocast("cuda", dtype=torch.float32): + conv_output = torch.nn.functional.conv2d(X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation) Y_expected = conv_output.sum(dim=1) stream = torch.cuda.current_stream().cuda_stream @@ -67,12 +59,8 @@ def test_reduction(cudnn_handle): X_uid, W_uid, Y_uid = uids - X_gpu_2 = torch.randn(N, C, H, W, dtype=torch.float16, device="cuda").to( - memory_format=torch.channels_last - ) - W_gpu_2 = torch.randn(K, C, R, S, dtype=torch.float16, device="cuda").to( - memory_format=torch.channels_last - ) + X_gpu_2 = torch.randn(N, C, H, W, dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) + W_gpu_2 = torch.randn(K, C, R, S, dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) g2, uids2 = create_conv_reduce_graph(cudnn_handle, X_gpu_2, W_gpu_2) @@ -88,9 +76,7 @@ def test_reduction(cudnn_handle): workspace = torch.empty(g.get_workspace_size(), device="cuda", dtype=torch.uint8) - g.execute( - {X_uid: X_gpu, W_uid: W_gpu, Y_uid: Y_actual}, workspace, handle=cudnn_handle - ) + g.execute({X_uid: X_gpu, W_uid: W_gpu, Y_uid: Y_actual}, workspace, handle=cudnn_handle) # g.execute( # {X_uid: X_gpu_2, W_uid: W_gpu_2, Y_uid: Y_actual_2}, workspace, handle=cudnn_handle diff --git a/test/python/test_deviceless_aot_compilation.py b/test/python/test_deviceless_aot_compilation.py index e65ba7cb..8e23d97b 100644 --- a/test/python/test_deviceless_aot_compilation.py +++ b/test/python/test_deviceless_aot_compilation.py @@ -54,9 +54,7 @@ def test_device_properties(): dim=[K, C, R, S], stride=[C * R * S, 1, C * S, C], ) - Y_tensor = graph.conv_fprop( - X_tensor, W_tensor, padding=padding, stride=stride, dilation=dilation - ) + Y_tensor = graph.conv_fprop(X_tensor, W_tensor, padding=padding, stride=stride, dilation=dilation) Y_tensor.set_output(True) graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) @@ -64,16 +62,10 @@ def test_device_properties(): # Step 3 # Compute reference - X_gpu = torch.randn(N, C, H, W, dtype=torch.float32, device="cuda").to( - memory_format=torch.channels_last - ) - W_gpu = torch.randn(K, C, R, S, dtype=torch.float32, device="cuda").to( - memory_format=torch.channels_last - ) + X_gpu = torch.randn(N, C, H, W, dtype=torch.float32, device="cuda").to(memory_format=torch.channels_last) + W_gpu = torch.randn(K, C, R, S, dtype=torch.float32, device="cuda").to(memory_format=torch.channels_last) with torch.amp.autocast(device_type="cuda", dtype=torch.float32): - Y_ref = torch.nn.functional.conv2d( - X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation - ) + Y_ref = torch.nn.functional.conv2d(X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation) # Create handle only when needed (for graph execution) cudnn_handle = cudnn.create_handle() @@ -86,9 +78,7 @@ def test_device_properties(): Y_actual = torch.zeros_like(Y_ref) - workspace = torch.empty( - graph_deserialized.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph_deserialized.get_workspace_size(), device="cuda", dtype=torch.uint8) graph_deserialized.execute( {X_tensor: X_gpu, W_tensor: W_gpu, Y_tensor: Y_actual}, diff --git a/test/python/test_flexible_sdpa.py b/test/python/test_flexible_sdpa.py index a022d487..8165641d 100644 --- a/test/python/test_flexible_sdpa.py +++ b/test/python/test_flexible_sdpa.py @@ -28,20 +28,11 @@ def create_container_and_page_table(tensor, block_size): # Create the page table table_size = math.ceil(S / block_size) - page_table_temp = torch.linspace( - 0, B * table_size - 1, B * table_size, device="cuda", dtype=torch.int32 - ).reshape(table_size, 1, B, 1) + page_table_temp = torch.linspace(0, B * table_size - 1, B * table_size, device="cuda", dtype=torch.int32).reshape(table_size, 1, B, 1) page_table_temp = torch.transpose(page_table_temp, 0, 2) # Make batch size outer dimension (cuDNN backend requirement) - page_table = ( - torch.randn(blocks_per_batch * B) - .int() - .cuda() - .as_strided( - (B, 1, blocks_per_batch, 1), (blocks_per_batch, blocks_per_batch, 1, 1) - ) - ) + page_table = torch.randn(blocks_per_batch * B).int().cuda().as_strided((B, 1, blocks_per_batch, 1), (blocks_per_batch, blocks_per_batch, 1, 1)) page_table.copy_(page_table_temp) return (container, page_table) @@ -88,9 +79,7 @@ def padding_mask(sdpa_graph, q_kt_tensor, seq_len_q, seq_len_kv, neg_inf): ) padding_mask.set_data_type(cudnn.data_type.BOOLEAN) - out = sdpa_graph.binary_select( - input0=q_kt_tensor, input1=neg_inf, mask=padding_mask, name="binary_select" - ) + out = sdpa_graph.binary_select(input0=q_kt_tensor, input1=neg_inf, mask=padding_mask, name="binary_select") return out @@ -106,9 +95,7 @@ def softcap(sdpa_graph, q_kt_tensor, softcap_tensor): return out -def decode_mask( - sdpa_graph, q_kt_tensor, seq_len_kv, seq_len_q, neg_inf, softcap_tensor -): +def decode_mask(sdpa_graph, q_kt_tensor, seq_len_kv, seq_len_q, neg_inf, softcap_tensor): softcap_out = softcap(sdpa_graph, q_kt_tensor, softcap_tensor) @@ -125,9 +112,7 @@ def causal_mask(sdpa_graph, q_kt_tensor, neg_inf): col_index = sdpa_graph.gen_index(input=q_kt_tensor, axis=3) col_index.set_data_type(cudnn.data_type.INT32) - mask = sdpa_graph.cmp_ge( - input=row_index, comparison=col_index, compute_data_type=cudnn.data_type.BOOLEAN - ) + mask = sdpa_graph.cmp_ge(input=row_index, comparison=col_index, compute_data_type=cudnn.data_type.BOOLEAN) mask.set_data_type(cudnn.data_type.BOOLEAN) out = sdpa_graph.binary_select(input0=q_kt_tensor, input1=neg_inf, mask=mask) @@ -136,18 +121,14 @@ def causal_mask(sdpa_graph, q_kt_tensor, neg_inf): def constant_bound_mask(score_mod_graph, index, bound): - is_less_than_bound = score_mod_graph.cmp_lt( - input=index, comparison=bound, compute_data_type=cudnn.data_type.BOOLEAN - ) + is_less_than_bound = score_mod_graph.cmp_lt(input=index, comparison=bound, compute_data_type=cudnn.data_type.BOOLEAN) is_less_than_bound.set_data_type(cudnn.data_type.INT32) return is_less_than_bound def diag_bound_mask(score_mod_graph, row_index, col_index, diag_bound_0, diag_bound_1): - row_minus_col = score_mod_graph.sub( - a=row_index, b=col_index, compute_data_type=cudnn.data_type.INT32 - ) + row_minus_col = score_mod_graph.sub(a=row_index, b=col_index, compute_data_type=cudnn.data_type.INT32) row_minus_col.set_data_type(cudnn.data_type.INT32) is_larger_or_equal_to_diag_bound_0 = score_mod_graph.cmp_ge( input=row_minus_col, @@ -191,9 +172,7 @@ def arrow_mask( is_less_than_row_bound = constant_bound_mask(score_mod_graph, row_index, row_bound) is_less_than_col_bound = constant_bound_mask(score_mod_graph, col_index, col_bound) - is_within_diag_bound = diag_bound_mask( - score_mod_graph, row_index, col_index, diag_bound_0, diag_bound_1 - ) + is_within_diag_bound = diag_bound_mask(score_mod_graph, row_index, col_index, diag_bound_0, diag_bound_1) mask = score_mod_graph.logical_or( is_less_than_row_bound, @@ -202,9 +181,7 @@ def arrow_mask( ) mask.set_data_type(cudnn.data_type.INT32) - mask = score_mod_graph.logical_or( - mask, is_within_diag_bound, compute_data_type=cudnn.data_type.BOOLEAN - ) + mask = score_mod_graph.logical_or(mask, is_within_diag_bound, compute_data_type=cudnn.data_type.BOOLEAN) mask.set_data_type(cudnn.data_type.INT32) out = score_mod_graph.binary_select(input0=q_kt_tensor, input1=neg_inf, mask=mask) @@ -255,24 +232,16 @@ def test_sdpa_with_flexible_graph(cudnn_handle): ) q = graph.tensor_like(q_gpu) - container_k_gpu, page_table_k_gpu = create_container_and_page_table( - k_gpu, block_size_k - ) - container_v_gpu, page_table_v_gpu = create_container_and_page_table( - v_gpu, block_size_v - ) + container_k_gpu, page_table_k_gpu = create_container_and_page_table(k_gpu, block_size_k) + container_v_gpu, page_table_v_gpu = create_container_and_page_table(v_gpu, block_size_v) container_k = graph.tensor_like(container_k_gpu) container_v = graph.tensor_like(container_v_gpu) page_table_k = graph.tensor_like(page_table_k_gpu) page_table_v = graph.tensor_like(page_table_v_gpu) - seq_len_q_gpu = torch.randint( - 1, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda" - ) - seq_len_kv_gpu = torch.randint( - 1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda" - ) + seq_len_q_gpu = torch.randint(1, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") + seq_len_kv_gpu = torch.randint(1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") seq_len_q = graph.tensor_like(seq_len_q_gpu) seq_len_kv = graph.tensor_like(seq_len_kv_gpu) @@ -344,9 +313,7 @@ def test_sdpa_with_flexible_graph(cudnn_handle): seq_len_kv: seq_len_kv_gpu, } - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute(variant_pack, workspace) torch.cuda.synchronize() @@ -393,9 +360,7 @@ def document_mask(sdpa_graph, q_kt_tensor, document_tensor, document_tensor_t, n ) document_mask.set_data_type(cudnn.data_type.INT32) - out = sdpa_graph.binary_select( - input0=q_kt_tensor, input1=neg_inf, mask=document_mask, name="binary_select" - ) + out = sdpa_graph.binary_select(input0=q_kt_tensor, input1=neg_inf, mask=document_mask, name="binary_select") return out @@ -528,9 +493,7 @@ def test_sdpa_with_arrow_mask(cudnn_handle): diag_bound_1: diag_bound_1_cpu, } - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute(variant_pack, workspace) torch.cuda.synchronize() @@ -562,15 +525,9 @@ def test_sdpa_with_document_mask(cudnn_handle): cudnn_version = LooseVersion(cudnn.backend_version_string()) if cudnn_version < "9.9.0": - pytest.skip( - "SDPA fprop with document style mask requires cudnn 9.9.0 or higher" - ) - - document_tensor_gpu = ( - torch.randint(0, s_q, (1, 1, s_q, 1), device="cuda", dtype=torch.int32) - .sort(dim=2) - .values - ) + pytest.skip("SDPA fprop with document style mask requires cudnn 9.9.0 or higher") + + document_tensor_gpu = torch.randint(0, s_q, (1, 1, s_q, 1), device="cuda", dtype=torch.int32).sort(dim=2).values document_tensor_gpu_t = document_tensor_gpu.reshape(1, 1, 1, s_q) graph = cudnn.pygraph( @@ -637,8 +594,6 @@ def test_sdpa_with_document_mask(cudnn_handle): neg_inf_tensor: neg_inf_tensor_cpu, } - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute(variant_pack, workspace) torch.cuda.synchronize() diff --git a/test/python/test_instancenorm.py b/test/python/test_instancenorm.py index 143710f7..05096880 100644 --- a/test/python/test_instancenorm.py +++ b/test/python/test_instancenorm.py @@ -42,15 +42,9 @@ def test_in(param_extract, cudnn_handle): epsilon_value = 1e-5 - x_gpu = torch.randn( - (N, C, H, W), requires_grad=True, device="cuda", dtype=input_type - ).to(memory_format=torch.channels_last) - scale_gpu = torch.randn( - (1, C, 1, 1), requires_grad=True, device="cuda", dtype=input_type - ).to(memory_format=torch.channels_last) - bias_gpu = torch.randn( - (1, C, 1, 1), requires_grad=True, device="cuda", dtype=input_type - ).to(memory_format=torch.channels_last) + x_gpu = torch.randn((N, C, H, W), requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) + scale_gpu = torch.randn((1, C, 1, 1), requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) + bias_gpu = torch.randn((1, C, 1, 1), requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) epsilon_cpu = torch.full( (1, 1, 1, 1), epsilon_value, @@ -59,13 +53,9 @@ def test_in(param_extract, cudnn_handle): dtype=torch.float32, ) - Y_expected = torch.nn.functional.instance_norm( - x_gpu, weight=scale_gpu.view(C), bias=bias_gpu.view(C) - ) + Y_expected = torch.nn.functional.instance_norm(x_gpu, weight=scale_gpu.view(C), bias=bias_gpu.view(C)) mean_expected = x_gpu.to(torch.float32).mean(dim=(2, 3), keepdim=True) - inv_var_expected = torch.rsqrt( - torch.var(x_gpu.to(torch.float32), dim=(2, 3), keepdim=True) + epsilon_value - ) + inv_var_expected = torch.rsqrt(torch.var(x_gpu.to(torch.float32), dim=(2, 3), keepdim=True) + epsilon_value) stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) @@ -109,9 +99,7 @@ def test_in(param_extract, cudnn_handle): mean_actual = torch.empty_like(mean_expected) inv_var_actual = torch.empty_like(inv_var_expected) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute( { @@ -188,9 +176,7 @@ def test_in(param_extract, cudnn_handle): DScale_actual = torch.empty_like(scale_gpu) Dbias_actual = torch.empty_like(bias_gpu) - workspace = torch.empty( - bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8) bwd_graph.execute( { diff --git a/test/python/test_kernel_cache.py b/test/python/test_kernel_cache.py index 0eb5dc8f..bbf910c9 100644 --- a/test/python/test_kernel_cache.py +++ b/test/python/test_kernel_cache.py @@ -108,9 +108,7 @@ def test_kernel_cache(cudnn_handle): dtype=torch.bfloat16, ) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) print("Executing", shape) graph.execute({0: A, 1: B, 2: C}, workspace, handle=cudnn_handle) @@ -168,9 +166,7 @@ def create_my_graph(kernel_cache): start_time = time.time() graph.build([cudnn.heur_mode.FALLBACK]) build_time_ms = (time.time() - start_time) * 1000 - assert ( - build_time_ms <= EXECUTION_TIME_LIMIT_MS - ), f"Graph build time {build_time_ms:.2f}ms exceeded limit of {EXECUTION_TIME_LIMIT_MS}ms" + assert build_time_ms <= EXECUTION_TIME_LIMIT_MS, f"Graph build time {build_time_ms:.2f}ms exceeded limit of {EXECUTION_TIME_LIMIT_MS}ms" @pytest.mark.skipif( @@ -242,14 +238,10 @@ def create_tensors(m, n, k): return A_gpu, B_gpu, C_expected, C_actual A_gpu, B_gpu, C_expected, C_actual = create_tensors(8, 64, 128) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute({0: A_gpu, 1: B_gpu, 2: C_actual}, workspace, handle=cudnn_handle) torch.cuda.synchronize() - torch.testing.assert_close( - C_actual, C_expected, **global_assert_opts_defaults["default"] - ) + torch.testing.assert_close(C_actual, C_expected, **global_assert_opts_defaults["default"]) # try making a new one with the same kernel cache del graph @@ -259,15 +251,9 @@ def create_tensors(m, n, k): start_time = time.time() graph.build([cudnn.heur_mode.FALLBACK]) build_time_ms = (time.time() - start_time) * 1000 - assert ( - build_time_ms <= EXECUTION_TIME_LIMIT_MS - ), f"Graph build time {build_time_ms:.2f}ms exceeded limit of {EXECUTION_TIME_LIMIT_MS}ms" + assert build_time_ms <= EXECUTION_TIME_LIMIT_MS, f"Graph build time {build_time_ms:.2f}ms exceeded limit of {EXECUTION_TIME_LIMIT_MS}ms" A_gpu, B_gpu, C_expected, C_actual = create_tensors(8, 64, 128) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute({0: A_gpu, 1: B_gpu, 2: C_actual}, workspace, handle=cudnn_handle) torch.cuda.synchronize() - torch.testing.assert_close( - C_actual, C_expected, **global_assert_opts_defaults["default"] - ) + torch.testing.assert_close(C_actual, C_expected, **global_assert_opts_defaults["default"]) diff --git a/test/python/test_layernorm.py b/test/python/test_layernorm.py index 78639eee..f5ee920c 100644 --- a/test/python/test_layernorm.py +++ b/test/python/test_layernorm.py @@ -9,9 +9,7 @@ embedding_dim_options = [768, 1024, 1280, 1600] input_type_options = [torch.bfloat16, torch.float16] -all_options = [ - elem for elem in itertools.product(*[embedding_dim_options, input_type_options]) -] +all_options = [elem for elem in itertools.product(*[embedding_dim_options, input_type_options])] @pytest.fixture(params=all_options) @@ -39,27 +37,9 @@ def test_layernorm(param_extract, cudnn_handle): epsilon_value = 1e-3 - x_gpu = ( - 3 - * torch.randn( - N, C, H, W, requires_grad=True, device="cuda", dtype=input_type - ).to(memory_format=torch.channels_last) - - 0.5 - ) - scale_gpu = ( - 5 - * torch.randn( - 1, C, H, W, requires_grad=True, device="cuda", dtype=input_type - ).to(memory_format=torch.channels_last) - - 1 - ) - bias_gpu = ( - 7 - * torch.randn( - 1, C, H, W, requires_grad=True, device="cuda", dtype=input_type - ).to(memory_format=torch.channels_last) - - 2 - ) + x_gpu = 3 * torch.randn(N, C, H, W, requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) - 0.5 + scale_gpu = 5 * torch.randn(1, C, H, W, requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) - 1 + bias_gpu = 7 * torch.randn(1, C, H, W, requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) - 2 epsilon_cpu = torch.full( (1, 1, 1, 1), epsilon_value, @@ -76,9 +56,7 @@ def test_layernorm(param_extract, cudnn_handle): eps=epsilon_value, ) mean_expected = x_gpu.to(torch.float32).mean(dim=(1, 2, 3), keepdim=True) - inv_var_expected = torch.rsqrt( - torch.var(x_gpu.to(torch.float32), dim=(1, 2, 3), keepdim=True) + epsilon_value - ) + inv_var_expected = torch.rsqrt(torch.var(x_gpu.to(torch.float32), dim=(1, 2, 3), keepdim=True) + epsilon_value) stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) @@ -89,9 +67,7 @@ def test_layernorm(param_extract, cudnn_handle): handle=cudnn_handle, ) - X = graph.tensor( - name="X", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype - ) + X = graph.tensor(name="X", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype) scale = graph.tensor( name="scale", dim=scale_gpu.size(), @@ -141,9 +117,7 @@ def test_layernorm(param_extract, cudnn_handle): mean_actual = torch.empty_like(mean_expected) inv_var_actual = torch.empty_like(inv_var_expected) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute( { @@ -179,9 +153,7 @@ def test_layernorm(param_extract, cudnn_handle): compute_data_type=cudnn.data_type.FLOAT, ) - DY = bwd_graph.tensor( - name="DY", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype - ) + DY = bwd_graph.tensor(name="DY", dim=x_gpu.size(), stride=x_gpu.stride(), data_type=x_gpu.dtype) X_bwd = bwd_graph.tensor_like(X, name="X") scale_bwd = bwd_graph.tensor_like(scale, name="scale") mean_bwd = bwd_graph.tensor_like(mean, name="mean") @@ -216,9 +188,7 @@ def test_layernorm(param_extract, cudnn_handle): DScale_actual = torch.empty_like(scale_gpu) Dbias_actual = torch.empty_like(bias_gpu) - workspace = torch.empty( - bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8) bwd_graph.execute( { diff --git a/test/python/test_low_precision_matmul.py b/test/python/test_low_precision_matmul.py index 5dc9a272..0410009a 100644 --- a/test/python/test_low_precision_matmul.py +++ b/test/python/test_low_precision_matmul.py @@ -103,9 +103,7 @@ def _f32_to_floatx_unpacked(x: torch.Tensor, ebits: int, mbits: int) -> torch.Te denorm_mask_int = denorm_exp << MBITS_F32 # reinterpret int32 as float32 - denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view( - torch.float32 - ) + denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32) # save the sign # Note that we have torch.uint32, but some ops like cpu bit shifts @@ -244,17 +242,13 @@ def _floatx_unpacked_to_f32(x: torch.Tensor, ebits: int, mbits: int) -> torch.Te # left shift mantissa until it overflows (create an implicit 1) # subtract exponent by the same amount left_shift = mbits - i - mantissa_f32 = (mantissa_cmp - (1 << i)) << ( - left_shift + MBITS_F32 - mbits - ) + mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits) exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 # we can update this in-place since the values won't overlap # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int' # thus we use + instead of | here - mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = ( - exp_biased_f32 + mantissa_f32 - ) + mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 + mantissa_f32 result = torch.where(denormal_mask, mantissa_lp_int32, result) @@ -266,7 +260,7 @@ def _floatx_unpacked_to_f32(x: torch.Tensor, ebits: int, mbits: int) -> torch.Te def get_cc(): - (major, minor) = torch.cuda.get_device_capability() + major, minor = torch.cuda.get_device_capability() return major * 10 + minor @@ -279,9 +273,7 @@ def matmul_dequantize_cache_key(cudnn_handle, A, B, A_scale, B_scale, BLOCK_SIZE @cudnn.jit(heur_modes=[cudnn.heur_mode.A, cudnn.heur_mode.B]) @cudnn.graph_cache(key_fn=matmul_dequantize_cache_key) -def create_matmul_dequantize_graph( - cudnn_handle, A, B, A_descale, B_descale, BLOCK_SIZE -): +def create_matmul_dequantize_graph(cudnn_handle, A, B, A_descale, B_descale, BLOCK_SIZE): with cudnn.graph(cudnn_handle) as (g, _): @@ -322,12 +314,8 @@ def create_matmul_dequantize_graph( reordering_type=cudnn.tensor_reordering.F8_128x4, ) - after_descale_a = g.block_scale_dequantize( - A_cudnn_tensor, A_descale_tensor, block_size=[1, BLOCK_SIZE] - ) - after_descale_b = g.block_scale_dequantize( - B_cudnn_tensor, B_descale_tensor, block_size=[BLOCK_SIZE, 1] - ) + after_descale_a = g.block_scale_dequantize(A_cudnn_tensor, A_descale_tensor, block_size=[1, BLOCK_SIZE]) + after_descale_b = g.block_scale_dequantize(B_cudnn_tensor, B_descale_tensor, block_size=[BLOCK_SIZE, 1]) C = g.matmul( after_descale_a, @@ -366,9 +354,7 @@ def unpack_uint4(packed_data): packed_uint8 = packed_data.view(torch.uint8).contiguous().view(-1) # Create unpacked array - unpacked = torch.zeros( - packed_uint8.shape[0] * 2, dtype=torch.uint8, device=packed_data.device - ) + unpacked = torch.zeros(packed_uint8.shape[0] * 2, dtype=torch.uint8, device=packed_data.device) # Extract lower and upper 4 bits unpacked[::2] = packed_uint8 & 0x0F # Lower 4 bits @@ -427,16 +413,10 @@ def test_low_precision_fp4_matmul(cudnn_handle): A = _bfloat16_to_float4_e2m1fn_x2(A_ref) B = _bfloat16_to_float4_e2m1fn_x2(B_ref) - A_descale = torch.full( - (batch_size, M, K), 1.0, dtype=torch.float8_e4m3fn, device="cuda" - ) - B_descale = torch.full( - (batch_size, K, N), 1.0, device="cuda", dtype=torch.float8_e4m3fn - ) + A_descale = torch.full((batch_size, M, K), 1.0, dtype=torch.float8_e4m3fn, device="cuda") + B_descale = torch.full((batch_size, K, N), 1.0, device="cuda", dtype=torch.float8_e4m3fn) - g, uids = create_matmul_dequantize_graph( - cudnn_handle, A, B, A_descale, B_descale, BLOCK_SIZE - ) + g, uids = create_matmul_dequantize_graph(cudnn_handle, A, B, A_descale, B_descale, BLOCK_SIZE) A_uid, B_uid, A_descale_uid, B_descale_uid, C_uid = uids diff --git a/test/python/test_matmul_bias_relu.py b/test/python/test_matmul_bias_relu.py index d5c77b74..2bc00076 100644 --- a/test/python/test_matmul_bias_relu.py +++ b/test/python/test_matmul_bias_relu.py @@ -25,7 +25,7 @@ def convert_to_cudnn_type(torch_type): def get_cc(): - (major, minor) = torch.cuda.get_device_capability() + major, minor = torch.cuda.get_device_capability() return major * 10 + minor @@ -33,9 +33,7 @@ def get_cc(): LooseVersion(cudnn.backend_version_string()) < "8.9.6", reason="requires cudnn 8.9.6 or higher", ) -@pytest.mark.skipif( - torch.cuda.get_device_capability()[0] < 9, reason="requires Hopper or newer arch" -) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires Hopper or newer arch") @pytest.mark.L0 @torch_fork_set_rng(seed=0) def test_int8_bf16_matmul(cudnn_handle): @@ -44,17 +42,8 @@ def test_int8_bf16_matmul(cudnn_handle): B, M, N, K = 16, 32, 64, 128 # Initialize input tensors - A_gpu = ( - torch.randint( - 3, (B, M, K), requires_grad=False, device="cuda", dtype=torch.int8 - ) - - 2 - ) - B_gpu = ( - 3 - * torch.randn(B, K, N, requires_grad=False, device="cuda", dtype=torch.bfloat16) - - 1.25 - ) + A_gpu = torch.randint(3, (B, M, K), requires_grad=False, device="cuda", dtype=torch.int8) - 2 + B_gpu = 3 * torch.randn(B, K, N, requires_grad=False, device="cuda", dtype=torch.bfloat16) - 1.25 stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) @@ -71,9 +60,7 @@ def test_int8_bf16_matmul(cudnn_handle): A_casted = graph.identity(input=A, compute_data_type=cudnn.data_type.FLOAT) A_casted.set_data_type(cudnn.data_type.BFLOAT16) - C = graph.matmul( - name="matmul", A=A_casted, B=B, compute_data_type=cudnn.data_type.FLOAT - ) + C = graph.matmul(name="matmul", A=A_casted, B=B, compute_data_type=cudnn.data_type.FLOAT) C.set_output(True).set_data_type(cudnn.data_type.BFLOAT16) graph.validate() @@ -93,9 +80,7 @@ def test_int8_bf16_matmul(cudnn_handle): # Run cudnn graph C_actual = torch.zeros_like(C_expected) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute({A: A_gpu, B: B_gpu, C: C_actual}, workspace, handle=cudnn_handle) torch.cuda.synchronize() @@ -112,9 +97,7 @@ def test_int8_bf16_matmul(cudnn_handle): LooseVersion(cudnn.backend_version_string()) < "8.9.6", reason="requires cudnn 8.9.6 or higher", ) -@pytest.mark.skipif( - torch.cuda.get_device_capability()[0] < 9, reason="requires Hopper or newer arch" -) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires Hopper or newer arch") @pytest.mark.parametrize("A_data_type", A_data_type_options) @pytest.mark.parametrize("B_data_type", B_data_type_options) @pytest.mark.parametrize("MMA_data_type", MMA_data_type_options) @@ -127,36 +110,14 @@ def test_mixed_precision_matmul(A_data_type, B_data_type, MMA_data_type, cudnn_h # Initialize input tensors if A_data_type != torch.int8: - A_gpu = ( - 2 - * torch.randn( - B, M, K, requires_grad=False, device="cuda", dtype=A_data_type - ) - - 0.5 - ) + A_gpu = 2 * torch.randn(B, M, K, requires_grad=False, device="cuda", dtype=A_data_type) - 0.5 else: - A_gpu = ( - torch.randint( - 4, (B, M, K), requires_grad=False, device="cuda", dtype=A_data_type - ) - - 1 - ) + A_gpu = torch.randint(4, (B, M, K), requires_grad=False, device="cuda", dtype=A_data_type) - 1 if B_data_type != torch.int8: - B_gpu_strided = ( - 3 - * torch.randn( - B, K, N, requires_grad=False, device="cuda", dtype=B_data_type - ) - - 1.25 - ) + B_gpu_strided = 3 * torch.randn(B, K, N, requires_grad=False, device="cuda", dtype=B_data_type) - 1.25 else: - B_gpu_strided = ( - torch.randint( - 3, (B, K, N), requires_grad=False, device="cuda", dtype=B_data_type - ).contiguous() - - 2 - ) + B_gpu_strided = torch.randint(3, (B, K, N), requires_grad=False, device="cuda", dtype=B_data_type).contiguous() - 2 B_gpu = torch.as_strided(B_gpu_strided, (B, K, N), (N * K, 1, N)) @@ -176,10 +137,7 @@ def test_mixed_precision_matmul(A_data_type, B_data_type, MMA_data_type, cudnn_h A_casted.set_data_type(convert_to_cudnn_type(MMA_data_type)) # Casting input tensor B is only supported from cudnn v9 - if ( - B_data_type != MMA_data_type - and LooseVersion(cudnn.backend_version_string()) < "9" - ): + if B_data_type != MMA_data_type and LooseVersion(cudnn.backend_version_string()) < "9": pytest.skip("mixed precision on B only supported from cudnn v9.") if LooseVersion(cudnn.backend_version_string()) < "9": @@ -192,9 +150,7 @@ def test_mixed_precision_matmul(A_data_type, B_data_type, MMA_data_type, cudnn_h # CAUTION: Hardcodes to fp32 as tests today dont cover inputs that are casted to ints. # In case your usecase does cast inputs to int8, use int32 as compute type here. - C = graph.matmul( - name="matmul", A=A_casted, B=B_casted, compute_data_type=cudnn.data_type.FLOAT - ) + C = graph.matmul(name="matmul", A=A_casted, B=B_casted, compute_data_type=cudnn.data_type.FLOAT) C.set_output(True).set_data_type(convert_to_cudnn_type(MMA_data_type)) graph.validate() @@ -214,9 +170,7 @@ def test_mixed_precision_matmul(A_data_type, B_data_type, MMA_data_type, cudnn_h # Run cudnn graph C_actual = torch.zeros_like(C_expected) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute({A: A_gpu, B: B_gpu, C: C_actual}, workspace, handle=cudnn_handle) torch.cuda.synchronize() @@ -227,9 +181,7 @@ def test_mixed_precision_matmul(A_data_type, B_data_type, MMA_data_type, cudnn_h problem_size_options = [(1, 128, 768), (16, 512, 1600), (1, 128, 1024)] input_type_options = [torch.bfloat16, torch.float16] -all_options = [ - elem for elem in itertools.product(*[problem_size_options, input_type_options]) -] +all_options = [elem for elem in itertools.product(*[problem_size_options, input_type_options])] @pytest.fixture(params=all_options) @@ -252,15 +204,9 @@ def test_matmul_bias(param_extract, cudnn_handle): pytest.skip("matmul broadcast on ampere with 8.9.6 is not supported.") X_gpu = torch.randn(b, s, e, requires_grad=False, device="cuda", dtype=input_type) - W_gpu = torch.randn( - 1, e, e * 4, requires_grad=False, device="cuda", dtype=input_type - ) - B_gpu = torch.randn( - 1, 1, e * 4, requires_grad=False, device="cuda", dtype=input_type - ) - Y_expected = torch.nn.functional.linear( - X_gpu, W_gpu.squeeze().T, bias=B_gpu.squeeze() - ) + W_gpu = torch.randn(1, e, e * 4, requires_grad=False, device="cuda", dtype=input_type) + B_gpu = torch.randn(1, 1, e * 4, requires_grad=False, device="cuda", dtype=input_type) + Y_expected = torch.nn.functional.linear(X_gpu, W_gpu.squeeze().T, bias=B_gpu.squeeze()) stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) @@ -309,15 +255,11 @@ def test_matmul_bias(param_extract, cudnn_handle): notes = graph.get_behavior_notes() assert cudnn.behavior_note.RUNTIME_COMPILATION in notes - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) Y_actual = torch.zeros_like(Y_expected) - graph.execute( - {X: X_gpu, W: W_gpu, B: B_gpu, Y: Y_actual}, workspace, handle=cudnn_handle - ) + graph.execute({X: X_gpu, W: W_gpu, B: B_gpu, Y: Y_actual}, workspace, handle=cudnn_handle) atol = 0.0625 if get_cc() == 89 else 1e-3 rtol = 1e-2 if input_type == torch.bfloat16 else 1e-3 diff --git a/test/python/test_matmul_fuzzer.py b/test/python/test_matmul_fuzzer.py new file mode 100644 index 00000000..a4133c34 --- /dev/null +++ b/test/python/test_matmul_fuzzer.py @@ -0,0 +1,1046 @@ +""" +Matmul Fuzzer - Randomized stress testing for cuDNN matmul operations. + +This fuzzer tests matmul operations with randomized: +- Shapes (batch, M, N, K dimensions) +- Layouts (row-major, column-major, transposed, strided) +- Data types (fp16, bf16, fp32, int8) +- Epilogues (none, bias, relu, gelu) + +Run with: + pytest -vv -s -rA test_matmul_fuzzer.py + +Options: + --num-tests N Number of random tests to run (default: 100) + --seed N Random seed for reproducibility (default: random) + --diffs N Number of mismatches to display (default: 10) +""" + +import cudnn +import pytest +import random +import torch +import math +import os +import sys +import signal +from looseversion import LooseVersion +from datetime import datetime +from dataclasses import dataclass, asdict +from typing import Optional, Tuple, List +from enum import IntEnum + +# fmt: off + +# Handle Ctrl-C gracefully +def signal_handler(sig, frame): + print("\n\nInterrupted by user (Ctrl-C), exiting...") + # Force CUDA to sync and cleanup + if torch.cuda.is_available(): + torch.cuda.synchronize() + sys.exit(1) + +signal.signal(signal.SIGINT, signal_handler) + +if __name__ == "__main__": + print("This is pytest script. Run with: pytest -vv -s -rA test_matmul_fuzzer.py") + sys.exit(0) + + +# ============================================================================ +# Configuration and Constants +# ============================================================================ + +class LayoutType(IntEnum): + ROW_MAJOR_PACKED = 0 # Standard row-major packed (strides: [..., N, 1]) + COL_MAJOR_PACKED = 1 # Column-major packed (strides: [..., 1, M]) + STRIDED = 2 # Custom strided layout with gaps + +class EpilogueType(IntEnum): + NONE = 0 + BIAS = 1 + RELU = 2 + BIAS_RELU = 3 + GELU = 4 + BIAS_GELU = 5 + +SUPPORTED_DTYPES = [ + torch.float16, + torch.bfloat16, + torch.float32, + torch.int8, +] + +# Compute precisions +COMPUTE_DTYPES = [ + cudnn.data_type.FLOAT, + cudnn.data_type.HALF, + cudnn.data_type.BFLOAT16, +] + + +# ============================================================================ +# Utility Functions +# ============================================================================ + +def convert_to_cudnn_type(torch_type): + """Convert PyTorch dtype to cuDNN data type.""" + mapping = { + torch.float16: cudnn.data_type.HALF, + torch.bfloat16: cudnn.data_type.BFLOAT16, + torch.float32: cudnn.data_type.FLOAT, + torch.bool: cudnn.data_type.BOOLEAN, + torch.uint8: cudnn.data_type.UINT8, + torch.int8: cudnn.data_type.INT8, + torch.int32: cudnn.data_type.INT32, + torch.int64: cudnn.data_type.INT64, + } + if torch_type not in mapping: + raise ValueError(f"Unsupported tensor data type: {torch_type}") + return mapping[torch_type] + + +def get_gpu_arch(): + """Get GPU SM architecture version.""" + major, minor = torch.cuda.get_device_capability() + return f"SM_{major * 10 + minor}" + + +def get_sm_count(): + """Get number of SMs on the GPU.""" + props = torch.cuda.get_device_properties(0) + return props.multi_processor_count + + +def get_gpu_name(): + """Get GPU name.""" + return torch.cuda.get_device_name() + + +def layout_name(layout: LayoutType) -> str: + """Get human-readable layout name.""" + names = { + LayoutType.ROW_MAJOR_PACKED: "row_major_packed", + LayoutType.COL_MAJOR_PACKED: "col_major_packed", + LayoutType.STRIDED: "strided", + } + return names.get(layout, "unknown") + + +def epilogue_name(epilogue: EpilogueType) -> str: + """Get human-readable epilogue name.""" + names = { + EpilogueType.NONE: "none", + EpilogueType.BIAS: "bias", + EpilogueType.RELU: "relu", + EpilogueType.BIAS_RELU: "bias_relu", + EpilogueType.GELU: "gelu", + EpilogueType.BIAS_GELU: "bias_gelu", + } + return names.get(epilogue, "unknown") + + +def compute_strides(shape: Tuple[int, ...], layout: LayoutType, rng: random.Random) -> Tuple[int, ...]: + """Compute strides for a given shape and layout.""" + ndim = len(shape) + + if layout == LayoutType.ROW_MAJOR_PACKED: + # Standard row-major: last dim has stride 1 + strides = [] + stride = 1 + for dim in reversed(shape): + strides.insert(0, stride) + stride *= dim + return tuple(strides) + + elif layout == LayoutType.COL_MAJOR_PACKED: + # Column-major for the last two dimensions + if ndim < 2: + return (1,) + strides = [1] * ndim + # Last two dims are transposed + strides[-1] = shape[-2] # N stride = M + strides[-2] = 1 # M stride = 1 + # Batch dimensions + stride = shape[-1] * shape[-2] + for i in range(ndim - 3, -1, -1): + strides[i] = stride + stride *= shape[i] + return tuple(strides) + + elif layout == LayoutType.STRIDED: + # Random strided layout with potential gaps + strides = [] + stride = 1 + for dim in reversed(shape): + # Add random padding (1-4x the minimum stride) + padding_factor = rng.choice([1, 1, 1, 2, 2, 4]) + strides.insert(0, stride) + stride *= dim * padding_factor + return tuple(strides) + + return tuple([1] * ndim) + + +def compute_num_elements(shape: Tuple[int, ...], strides: Tuple[int, ...]) -> int: + """Compute number of elements needed for storage given shape and strides.""" + if not shape: + return 1 + max_offset = sum((d - 1) * s for d, s in zip(shape, strides)) + return max_offset + 1 + + +def fill_with_garbage(tensor: torch.Tensor, nan_probability: float = 0.1) -> None: + """ + Fill tensor with garbage values (mix of random values and NaNs). + This helps catch bugs where cuDNN doesn't write all output locations. + """ + # Choose range based on dtype to avoid overflow + if tensor.dtype in (torch.float16, torch.bfloat16): + lo, hi = -1e4, 1e4 # FP16 max is ~65504 + else: + lo, hi = -1e6, 1e6 + + # Fill with random garbage + tensor.uniform_(lo, hi) + + # Sprinkle in some NaNs (only for float types) + if nan_probability > 0 and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + nan_mask = torch.rand(tensor.shape, device=tensor.device) < nan_probability + tensor[nan_mask] = float('nan') + + +# ============================================================================ +# Test Configuration +# ============================================================================ + +@dataclass +class MatmulConfig: + """Configuration for a single matmul test.""" + # Basic dimensions + batch: int + M: int + N: int + K: int + + # Data types + a_dtype: torch.dtype + b_dtype: torch.dtype + c_dtype: torch.dtype + compute_dtype: cudnn.data_type + + # Layouts + a_layout: LayoutType + b_layout: LayoutType + c_layout: LayoutType + + # Transpose flags + a_transposed: bool + b_transposed: bool + + # Epilogue + epilogue: EpilogueType + + # Random seed for data generation + rng_seed: int + + # Computed strides (set during tensor creation) + a_shape: Tuple[int, ...] = None + b_shape: Tuple[int, ...] = None + c_shape: Tuple[int, ...] = None + a_strides: Tuple[int, ...] = None + b_strides: Tuple[int, ...] = None + c_strides: Tuple[int, ...] = None + a_elems: int = 0 + b_elems: int = 0 + c_elems: int = 0 + + # Bias tensor info (set when epilogue uses bias) + bias_shape: Tuple[int, ...] = None + bias_strides: Tuple[int, ...] = None + bias_elems: int = 0 + + def to_repro_dict(self) -> dict: + """Convert config to reproducible dictionary.""" + return { + 'batch': self.batch, + 'M': self.M, + 'N': self.N, + 'K': self.K, + 'a_dtype': str(self.a_dtype), + 'b_dtype': str(self.b_dtype), + 'c_dtype': str(self.c_dtype), + 'epilogue': int(self.epilogue), + 'rng_seed': self.rng_seed, + } + + +class ConfigGenerator: + """Generator for random matmul configurations.""" + + def __init__(self, seed: int, allow_unaligned: bool = False): + self.rng = random.Random(seed) + self.allow_unaligned = allow_unaligned + + def random_batch(self) -> int: + """Generate random batch size (no alignment requirement).""" + return self.rng.choice([1, 1, 2, 3, 4, 5, 7, 8, 16, 32]) + + def random_dim(self, min_val: int = 1, max_val: int = 4096) -> int: + """Generate random dimension size (M, N, or K).""" + val = self.rng.randint(int(math.sqrt(min_val)), int(math.sqrt(max_val))) + if self.allow_unaligned: + # Allow non-aligned sizes for stress testing + return val * val + else: + # Default: round up to next multiple of 8 for tensor core alignment + return ((val * val + 7) // 8) * 8 + + def random_dtype(self) -> torch.dtype: + """Generate random data type.""" + return self.rng.choice(SUPPORTED_DTYPES) + + def random_layout(self) -> LayoutType: + """Generate random layout type.""" + # Prefer row-major but test others too + weights = [0.6, 0.2, 0.2] + return self.rng.choices(list(LayoutType), weights=weights)[0] + + def random_epilogue(self) -> EpilogueType: + """Generate random epilogue type.""" + # Most tests without epilogue + weights = [0.5, 0.15, 0.1, 0.1, 0.075, 0.075] + return self.rng.choices(list(EpilogueType), weights=weights)[0] + + def random_compute_dtype(self) -> cudnn.data_type: + """Generate random compute data type.""" + return self.rng.choice([cudnn.data_type.FLOAT]) # Float is most compatible + + def generate(self) -> MatmulConfig: + """Generate a random matmul configuration.""" + batch = self.random_batch() + M = self.random_dim() + N = self.random_dim() + K = self.random_dim() + + # Data types - ensure compatible combinations + if self.rng.random() < 0.8: + # 80% of tests use same dtype for A and B (more stable) + a_dtype = self.random_dtype() + b_dtype = a_dtype + else: + # 20% test mixed precision + a_dtype = self.random_dtype() + b_dtype = self.random_dtype() + + # Output dtype should be float type (not int8) + if a_dtype == torch.float32 or b_dtype == torch.float32: + c_dtype = torch.float32 + elif a_dtype == torch.int8 or b_dtype == torch.int8: + # int8 inputs typically output to float32 + c_dtype = torch.float32 + else: + c_dtype = self.rng.choice([a_dtype, torch.float32]) + + # Layout selection - prefer row-major for stability + a_layout = self.rng.choices( + [LayoutType.ROW_MAJOR_PACKED, LayoutType.COL_MAJOR_PACKED, LayoutType.STRIDED], + weights=[0.7, 0.15, 0.15] + )[0] + b_layout = self.rng.choices( + [LayoutType.ROW_MAJOR_PACKED, LayoutType.COL_MAJOR_PACKED, LayoutType.STRIDED], + weights=[0.7, 0.15, 0.15] + )[0] + + # Transpose flags - disabled for now, using layouts instead for variety + # cuDNN matmul expects specific input layouts, transpose requires extra handling + a_transposed = False + b_transposed = False + + # Epilogue selection + epilogue = self.random_epilogue() + + config = MatmulConfig( + batch=batch, + M=M, + N=N, + K=K, + a_dtype=a_dtype, + b_dtype=b_dtype, + c_dtype=c_dtype, + compute_dtype=self.random_compute_dtype(), + a_layout=a_layout, + b_layout=b_layout, + c_layout=LayoutType.ROW_MAJOR_PACKED, # Output usually row-major + a_transposed=a_transposed, + b_transposed=b_transposed, + epilogue=epilogue, + rng_seed=self.rng.randint(0, 2**31 - 1), + ) + + return config + + +# ============================================================================ +# Test Execution +# ============================================================================ + +def create_tensors(config: MatmulConfig, rng: random.Random): + """Create input and output tensors based on configuration.""" + torch_rng = torch.Generator(device='cuda') + torch_rng.manual_seed(config.rng_seed) + + # Compute shapes for matmul: C = A @ B + # A: (batch, M, K) + # B: (batch, K, N) + # C: (batch, M, N) + # + # Transpose flags affect storage layout, not logical dimensions: + # - If a_transposed: A stored as (batch, K, M), transposed for matmul to (batch, M, K) + # - If b_transposed: B stored as (batch, N, K), transposed for matmul to (batch, K, N) + # + # For simplicity in this fuzzer, we don't use transpose - just vary layouts instead + + a_shape = (config.batch, config.M, config.K) + b_shape = (config.batch, config.K, config.N) + c_shape = (config.batch, config.M, config.N) + + # Compute strides + a_strides = compute_strides(a_shape, config.a_layout, rng) + b_strides = compute_strides(b_shape, config.b_layout, rng) + c_strides = compute_strides(c_shape, config.c_layout, rng) + + # Compute number of elements + a_elems = compute_num_elements(a_shape, a_strides) + b_elems = compute_num_elements(b_shape, b_strides) + c_elems = compute_num_elements(c_shape, c_strides) + + # Update config with computed values + config.a_shape = a_shape + config.b_shape = b_shape + config.c_shape = c_shape + config.a_strides = a_strides + config.b_strides = b_strides + config.c_strides = c_strides + config.a_elems = a_elems + config.b_elems = b_elems + config.c_elems = c_elems + + # Create tensors + if config.a_dtype == torch.int8: + a_storage = torch.randint(-2, 3, (a_elems,), device='cuda', dtype=torch.int8) + else: + a_storage = torch.empty(a_elems, device='cuda', dtype=config.a_dtype) + a_storage.normal_(mean=0.5, std=0.5, generator=torch_rng) + + if config.b_dtype == torch.int8: + b_storage = torch.randint(-2, 3, (b_elems,), device='cuda', dtype=torch.int8) + else: + b_storage = torch.empty(b_elems, device='cuda', dtype=config.b_dtype) + b_storage.normal_(mean=0.5, std=0.5,generator=torch_rng) + + # Create strided views + A = torch.as_strided(a_storage, a_shape, a_strides) + B = torch.as_strided(b_storage, b_shape, b_strides) + + # Output tensor - fill with garbage to catch bugs where cuDNN doesn't write all outputs + c_storage = torch.empty(c_elems, device='cuda', dtype=config.c_dtype) + fill_with_garbage(c_storage) + C = torch.as_strided(c_storage, c_shape, c_strides) + + # Bias tensor if needed + bias = None + if config.epilogue in [EpilogueType.BIAS, EpilogueType.BIAS_RELU, EpilogueType.BIAS_GELU]: + # Randomize bias shape: each dim can be 1 (broadcast) or match C + bias_rng = random.Random(config.rng_seed + 1) # Different seed for bias shape + bias_b = bias_rng.choice([1, config.batch]) + bias_m = bias_rng.choice([1, config.M]) + bias_n = bias_rng.choice([1, config.N]) + bias_shape = (bias_b, bias_m, bias_n) + bias_strides = compute_strides(bias_shape, LayoutType.ROW_MAJOR_PACKED, bias_rng) + bias_elems = compute_num_elements(bias_shape, bias_strides) + + config.bias_shape = bias_shape + config.bias_strides = bias_strides + config.bias_elems = bias_elems + + bias = torch.empty(bias_elems, device='cuda', dtype=config.c_dtype) + bias.normal_(mean=0.5, std=0.5, generator=torch_rng) + bias = torch.as_strided(bias, bias_shape, bias_strides) + + return A, B, C, bias + + +def compute_reference(config: MatmulConfig, A: torch.Tensor, B: torch.Tensor, bias: Optional[torch.Tensor]): + """Compute reference result using PyTorch.""" + # Convert to float for computation + compute_dtype = torch.float32 + + A_compute = A.to(compute_dtype) + B_compute = B.to(compute_dtype) + + try: + # Matmul: C = A @ B + # A: (batch, M, K), B: (batch, K, N), C: (batch, M, N) + C_ref = torch.matmul(A_compute, B_compute) + + # Epilogue + if bias is not None and config.epilogue in [EpilogueType.BIAS, EpilogueType.BIAS_RELU, EpilogueType.BIAS_GELU]: + C_ref = C_ref + bias.to(compute_dtype) + + if config.epilogue in [EpilogueType.RELU, EpilogueType.BIAS_RELU]: + C_ref = torch.relu(C_ref) + elif config.epilogue in [EpilogueType.GELU, EpilogueType.BIAS_GELU]: + C_ref = torch.nn.functional.gelu(C_ref) + + return C_ref.to(config.c_dtype) + finally: + del A_compute, B_compute + + +def run_cudnn_matmul(config: MatmulConfig, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + bias: Optional[torch.Tensor], cudnn_handle) -> Tuple[bool, str]: + """Run matmul using cuDNN and return success status and message.""" + try: + stream = torch.cuda.current_stream().cuda_stream + cudnn.set_stream(handle=cudnn_handle, stream=stream) + + # Create graph + graph = cudnn.pygraph( + handle=cudnn_handle, + compute_data_type=config.compute_dtype, + intermediate_data_type=cudnn.data_type.FLOAT, + ) + + # Create input tensors + A_tensor = graph.tensor_like(A) + B_tensor = graph.tensor_like(B) + + # Handle data type casting if needed + mma_dtype = convert_to_cudnn_type(config.c_dtype) + + if config.a_dtype != config.c_dtype: + A_casted = graph.identity(input=A_tensor, compute_data_type=cudnn.data_type.FLOAT) + A_casted.set_data_type(mma_dtype) + else: + A_casted = A_tensor + + if config.b_dtype != config.c_dtype: + B_casted = graph.identity(input=B_tensor, compute_data_type=cudnn.data_type.FLOAT) + B_casted.set_data_type(mma_dtype) + else: + B_casted = B_tensor + + # Matmul + result = graph.matmul(name="matmul", A=A_casted, B=B_casted) + + # Epilogue + if bias is not None and config.epilogue in [EpilogueType.BIAS, EpilogueType.BIAS_RELU, EpilogueType.BIAS_GELU]: + bias_tensor = graph.tensor_like(bias) + result = graph.bias(name="bias", input=result, bias=bias_tensor) + else: + bias_tensor = None + + if config.epilogue in [EpilogueType.RELU, EpilogueType.BIAS_RELU]: + result = graph.relu(name="relu", input=result) + elif config.epilogue in [EpilogueType.GELU, EpilogueType.BIAS_GELU]: + result = graph.gelu(name="gelu", input=result) + + result.set_output(True).set_data_type(mma_dtype) + + # Build and execute + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) + + # Allocate workspace and fill with garbage to catch uninitialized memory bugs + workspace_size = graph.get_workspace_size() + workspace = torch.empty(workspace_size, device='cuda', dtype=torch.uint8) + if workspace_size > 0: + # Fill with random garbage + some NaN patterns to test proper workspace init + workspace.random_(0, 256) + nan_mask = torch.rand(workspace_size, device='cuda') < 0.1 + workspace[nan_mask] = 0xFF + + # Build variant pack + variant_pack = {A_tensor: A, B_tensor: B, result: C} + if bias_tensor is not None: + variant_pack[bias_tensor] = bias + + # Execute + graph.execute(variant_pack, workspace, handle=cudnn_handle) + torch.cuda.synchronize() + + return True, "success" + + except cudnn.cudnnGraphNotSupportedError as e: + return False, f"graph not supported: {e}" + except Exception as e: + return False, f"error: {e}" + + +def compare_results(C_actual: torch.Tensor, C_expected: torch.Tensor, config: MatmulConfig, + max_diffs: int = 10) -> Tuple[bool, int, str]: + """Compare actual and expected results.""" + # Determine tolerances based on dtype + # Note: cuDNN uses TF32 for FP32 tensor core ops, which has same precision as FP16 (10-bit mantissa) + if config.c_dtype == torch.float32: + rtol, atol = 1e-2, 2e-2 # TF32 precision, not full FP32 + elif config.c_dtype == torch.float16: + rtol, atol = 1e-2, 2e-2 + else: # bfloat16 + rtol, atol = 1e-2, 2e-2 + + # Scale tolerances based on problem size (larger problems accumulate more error) + # Use max(1.0, ...) to only increase tolerances for large K, never decrease + scale_factor = max(1.0, math.sqrt(config.K / 128.0)) + rtol *= scale_factor + atol *= scale_factor + + try: + torch.testing.assert_close(C_actual, C_expected, rtol=rtol, atol=atol) + return True, 0, f"Numerical divergence within limits (rtol={rtol:.2e}, atol={atol:.2e})" + except AssertionError: + # Count mismatches + close_mask = torch.isclose(C_actual.float(), C_expected.float(), rtol=rtol, atol=atol) + mismatch_count = (~close_mask).sum().item() + total_elements = C_actual.numel() + percentage = 100.0 * mismatch_count / total_elements + + msg = f"Found {mismatch_count:,} mismatches ({percentage:.2f}%) out of {total_elements:,} elements" + + # Show some mismatches + if max_diffs > 0: + mismatches = torch.where(~close_mask) + for i in range(min(max_diffs, mismatch_count)): + idx = tuple(m[i].item() for m in mismatches) + actual = C_actual[idx].item() + expected = C_expected[idx].item() + diff = actual - expected + msg += f"\n idx{idx}: actual={actual:+.6e}, expected={expected:+.6e}, diff={diff:+.2e}" + + return False, mismatch_count, msg + + +# ============================================================================ +# Test Output Formatting +# ============================================================================ + +def format_test_header(test_num: int, total_tests: int, config: MatmulConfig) -> str: + """Format test header similar to sample log.""" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + gpu_arch = get_gpu_arch() + gpu_name = get_gpu_name() + sm_count = get_sm_count() + cudnn_ver = cudnn.backend_version() + + lines = [ + "", + "=" * 90, + f"#### Test #{test_num} of {total_tests} at {timestamp} ", + "", + f"test_name = test_matmul_fuzzer[test{test_num}]", + f"platform_info = {gpu_arch} ({sm_count} SM-s, {gpu_name}), cudnn_ver={cudnn_ver}", + f"rng_data_seed = {config.rng_seed}", + f"basic_dims = [b={config.batch}, M={config.M}, N={config.N}, K={config.K}]", + f"matrix_a(b,m,k) = dim={config.a_shape}, strides={config.a_strides}, elems={config.a_elems}, type={config.a_dtype}", + f"matrix_b(b,k,n) = dim={config.b_shape}, strides={config.b_strides}, elems={config.b_elems}, type={config.b_dtype}", + f"matrix_c(b,m,n) = dim={config.c_shape}, strides={config.c_strides}, elems={config.c_elems}, type={config.c_dtype}", + ] + + # Add bias info if epilogue uses bias + if config.epilogue in [EpilogueType.BIAS, EpilogueType.BIAS_RELU, EpilogueType.BIAS_GELU] and config.bias_shape: + lines.append(f"bias(b,m,n) = dim={config.bias_shape}, strides={config.bias_strides}, elems={config.bias_elems}, type={config.c_dtype}") + + lines += [ + f"epilogue = {epilogue_name(config.epilogue)}", + f"repro_cmd = pytest -vv -s -rA {__file__}::test_repro --repro \"{config.to_repro_dict()}\"", + " ", + ] + return "\n".join(lines) + + +def format_test_result(passed: bool, message: str) -> str: + """Format test result.""" + lines = [ + f"%%%% {message}", + ] + if passed: + lines.append("@@@@ Overall result: PASSED, everything looks good!") + else: + lines.append("@@@@ Overall result: FAILED") + return "\n".join(lines) + + +# ============================================================================ +# PyTest Infrastructure +# ============================================================================ + +def pytest_addoption(parser): + """Add custom command line options.""" + try: + parser.addoption("--num-tests", action="store", type=int, default=100, + help="Number of random tests to run") + parser.addoption("--fuzz-seed", action="store", type=int, default=None, + help="Random seed for test generation") + parser.addoption("--unaligned", action="store_true", default=False, + help="Allow M/N/K dimensions that are not multiples of 8") + except Exception: + pass # Options may already be added + + +def tlist(*, num_tests: int, rng_seed: int): + """Generate list of test parameters (legacy, without pre-generated configs).""" + rng = random.Random(rng_seed) + return [(i + 1, num_tests, rng.randint(65536, 2**31 - 1)) for i in range(num_tests)] + + +def tlist_with_configs(*, num_tests: int, rng_seed: int, allow_unaligned: bool = False): + """Generate list of test parameters with pre-generated configs for descriptive test names.""" + rng = random.Random(rng_seed) + params = [] + for i in range(num_tests): + config_seed = rng.randint(65536, 2**31 - 1) + generator = ConfigGenerator(config_seed, allow_unaligned=allow_unaligned) + config = generator.generate() + params.append((i + 1, num_tests, config_seed, config)) + return params + + +def make_test_id(param, prefix: str = "t"): + """Create descriptive test ID from pre-generated config.""" + test_num, total_tests, config_seed, config = param + dtype_short = { + torch.float16: 'f16', + torch.bfloat16: 'bf16', + torch.float32: 'f32', + torch.int8: 'i8', + } + dt = dtype_short.get(config.a_dtype, 'unk') + epi = epilogue_name(config.epilogue)[:4] # Truncate epilogue name + return f"{prefix}{test_num}_b{config.batch}_M{config.M}xN{config.N}xK{config.K}_{dt}_{epi}" + + +# Generate test list +def get_test_params(request): + """Get test parameters from pytest config.""" + try: + num_tests = request.config.getoption("--num-tests", default=100) + seed = request.config.getoption("--fuzz-seed", default=None) + except Exception: + num_tests = 100 + seed = None + + if seed is None: + seed = random.randint(0, 2**31 - 1) + + return num_tests, seed + + +# Fixed test list for default runs +DEFAULT_NUM_TESTS = 2048 +DEFAULT_SEED = 42 +TEST_PARAMS = tlist_with_configs(num_tests=DEFAULT_NUM_TESTS, rng_seed=DEFAULT_SEED) + + +@pytest.mark.L0 +@pytest.mark.parametrize("test_num,total_tests,config_seed,config", TEST_PARAMS, + ids=[make_test_id(p) for p in TEST_PARAMS]) +def test_matmul_fuzz(test_num: int, total_tests: int, config_seed: int, config: MatmulConfig, cudnn_handle, request): + """Fuzz test for matmul operations (M/N/K aligned to multiples of 8).""" + + # Skip if cuDNN handle not available + if cudnn_handle is None: + pytest.skip("cuDNN handle not available") + + # Get display options + try: + max_diffs = request.config.getoption("--diffs", default=10) + except Exception: + max_diffs = 10 + + # Create tensors + rng = random.Random(config_seed) + A, B, C, bias = create_tensors(config, rng) + C_expected = None + + try: + # Print test header + print(format_test_header(test_num, total_tests, config)) + + # Compute reference + C_expected = compute_reference(config, A, B, bias) + + # Run cuDNN + success, msg = run_cudnn_matmul(config, A, B, C, bias, cudnn_handle) + + if not success: + print(f"%%%% cuDNN execution failed: {msg}") + # Skip tests with unsupported configurations rather than failing + skip_keywords = ["not supported", "finalize failed", "mismatch", "invalid", "unsupported"] + if any(kw in msg.lower() for kw in skip_keywords): + print("@@@@ Overall result: SKIPPED (unsupported configuration)") + pytest.skip(f"Unsupported configuration: {msg}") + else: + print("@@@@ Overall result: FAILED") + pytest.fail(f"cuDNN execution failed: {msg}") + + # Compare results + passed, mismatch_count, compare_msg = compare_results(C, C_expected, config, max_diffs) + + print(format_test_result(passed, compare_msg)) + + if not passed: + pytest.fail(f"Numerical mismatch: {mismatch_count} elements differ") + finally: + # Explicit cleanup to prevent GPU memory accumulation + del A, B, C + if bias is not None: + del bias + if C_expected is not None: + del C_expected + torch.cuda.empty_cache() + + +# Separate test list for unaligned stress testing +UNALIGNED_TEST_PARAMS = tlist_with_configs(num_tests=1024, rng_seed=12345, allow_unaligned=True) + + +@pytest.mark.L1 # L1 for stress testing with unaligned dimensions +@pytest.mark.parametrize("test_num,total_tests,config_seed,config", UNALIGNED_TEST_PARAMS, + ids=[make_test_id(p, prefix="u") for p in UNALIGNED_TEST_PARAMS]) +def test_matmul_fuzz_unaligned(test_num: int, total_tests: int, config_seed: int, config: MatmulConfig, cudnn_handle, request): + """Fuzz test for matmul with unaligned M/N/K dimensions (stress test).""" + + # Skip if cuDNN handle not available + if cudnn_handle is None: + pytest.skip("cuDNN handle not available") + + # Get display options + try: + max_diffs = request.config.getoption("--diffs", default=10) + except Exception: + max_diffs = 10 + + # Create tensors + rng = random.Random(config_seed) + A, B, C, bias = create_tensors(config, rng) + C_expected = None + + try: + # Print test header + print(format_test_header(test_num, total_tests, config)) + + # Compute reference + C_expected = compute_reference(config, A, B, bias) + + # Run cuDNN + success, msg = run_cudnn_matmul(config, A, B, C, bias, cudnn_handle) + + if not success: + print(f"%%%% cuDNN execution failed: {msg}") + # Skip tests with unsupported configurations rather than failing + skip_keywords = ["not supported", "finalize failed", "mismatch", "invalid", "unsupported"] + if any(kw in msg.lower() for kw in skip_keywords): + print("@@@@ Overall result: SKIPPED (unsupported configuration)") + pytest.skip(f"Unsupported configuration: {msg}") + else: + print("@@@@ Overall result: FAILED") + pytest.fail(f"cuDNN execution failed: {msg}") + + # Compare results + passed, mismatch_count, compare_msg = compare_results(C, C_expected, config, max_diffs) + + print(format_test_result(passed, compare_msg)) + + if not passed: + pytest.fail(f"Numerical mismatch: {mismatch_count} elements differ") + finally: + # Explicit cleanup to prevent GPU memory accumulation + del A, B, C + if bias is not None: + del bias + if C_expected is not None: + del C_expected + torch.cuda.empty_cache() + + +@pytest.mark.L0 +def test_repro(cudnn_handle, request): + """Reproduction test for debugging specific configurations.""" + repro_str = request.config.getoption("--repro", default=None) + if repro_str is None: + pytest.skip("No --repro option provided. Use: pytest test_matmul_fuzzer.py::test_repro --repro ''") + + # Parse repro config + import ast + repro_dict = ast.literal_eval(repro_str) + + # Regenerate config using the same seed (ensures identical random choices) + generator = ConfigGenerator(repro_dict['rng_seed']) + config = generator.generate() + + # Override with explicit values from repro dict + config.batch = repro_dict['batch'] + config.M = repro_dict['M'] + config.N = repro_dict['N'] + config.K = repro_dict['K'] + config.a_dtype = eval(repro_dict['a_dtype']) + config.b_dtype = eval(repro_dict['b_dtype']) + config.c_dtype = eval(repro_dict['c_dtype']) + config.epilogue = EpilogueType(repro_dict['epilogue']) + config.rng_seed = repro_dict['rng_seed'] + + # Run test + rng = random.Random(config.rng_seed) + A, B, C, bias = create_tensors(config, rng) + C_expected = None + + try: + print(format_test_header(1, 1, config)) + + C_expected = compute_reference(config, A, B, bias) + success, msg = run_cudnn_matmul(config, A, B, C, bias, cudnn_handle) + + if not success: + pytest.fail(f"cuDNN execution failed: {msg}") + + passed, mismatch_count, compare_msg = compare_results(C, C_expected, config, max_diffs=20) + print(format_test_result(passed, compare_msg)) + + if not passed: + pytest.fail(f"Numerical mismatch: {mismatch_count} elements differ") + finally: + # Explicit cleanup to prevent GPU memory accumulation + del A, B, C + if bias is not None: + del bias + if C_expected is not None: + del C_expected + torch.cuda.empty_cache() + + +# ============================================================================ +# Quick Sanity Tests +# ============================================================================ + +@pytest.mark.L0 +def test_matmul_basic_fp16(cudnn_handle): + """Basic FP16 matmul sanity test.""" + if cudnn_handle is None: + pytest.skip("cuDNN handle not available") + + config = MatmulConfig( + batch=2, M=64, N=128, K=256, + a_dtype=torch.float16, b_dtype=torch.float16, c_dtype=torch.float16, + compute_dtype=cudnn.data_type.FLOAT, + a_layout=LayoutType.ROW_MAJOR_PACKED, b_layout=LayoutType.ROW_MAJOR_PACKED, c_layout=LayoutType.ROW_MAJOR_PACKED, + a_transposed=False, b_transposed=False, + epilogue=EpilogueType.NONE, + rng_seed=12345, + ) + + rng = random.Random(config.rng_seed) + A, B, C, bias = create_tensors(config, rng) + C_expected = None + + try: + print(format_test_header(1, 1, config)) + + C_expected = compute_reference(config, A, B, bias) + success, msg = run_cudnn_matmul(config, A, B, C, bias, cudnn_handle) + + assert success, f"cuDNN failed: {msg}" + + passed, _, compare_msg = compare_results(C, C_expected, config) + print(format_test_result(passed, compare_msg)) + assert passed + finally: + del A, B, C + if bias is not None: + del bias + if C_expected is not None: + del C_expected + torch.cuda.empty_cache() + + +@pytest.mark.L0 +def test_matmul_basic_bf16(cudnn_handle): + """Basic BF16 matmul sanity test.""" + if cudnn_handle is None: + pytest.skip("cuDNN handle not available") + + config = MatmulConfig( + batch=4, M=128, N=256, K=512, + a_dtype=torch.bfloat16, b_dtype=torch.bfloat16, c_dtype=torch.bfloat16, + compute_dtype=cudnn.data_type.FLOAT, + a_layout=LayoutType.ROW_MAJOR_PACKED, b_layout=LayoutType.ROW_MAJOR_PACKED, c_layout=LayoutType.ROW_MAJOR_PACKED, + a_transposed=False, b_transposed=False, + epilogue=EpilogueType.NONE, + rng_seed=54321, + ) + + rng = random.Random(config.rng_seed) + A, B, C, bias = create_tensors(config, rng) + C_expected = None + + try: + print(format_test_header(1, 1, config)) + + C_expected = compute_reference(config, A, B, bias) + success, msg = run_cudnn_matmul(config, A, B, C, bias, cudnn_handle) + + assert success, f"cuDNN failed: {msg}" + + passed, _, compare_msg = compare_results(C, C_expected, config) + print(format_test_result(passed, compare_msg)) + assert passed + finally: + del A, B, C + if bias is not None: + del bias + if C_expected is not None: + del C_expected + torch.cuda.empty_cache() + + +@pytest.mark.L0 +def test_matmul_with_bias(cudnn_handle): + """Matmul with bias epilogue test.""" + if cudnn_handle is None: + pytest.skip("cuDNN handle not available") + + config = MatmulConfig( + batch=1, M=256, N=512, K=128, + a_dtype=torch.float16, b_dtype=torch.float16, c_dtype=torch.float16, + compute_dtype=cudnn.data_type.FLOAT, + a_layout=LayoutType.ROW_MAJOR_PACKED, b_layout=LayoutType.ROW_MAJOR_PACKED, c_layout=LayoutType.ROW_MAJOR_PACKED, + a_transposed=False, b_transposed=False, + epilogue=EpilogueType.BIAS, + rng_seed=98765, + ) + + rng = random.Random(config.rng_seed) + A, B, C, bias = create_tensors(config, rng) + C_expected = None + + try: + print(format_test_header(1, 1, config)) + + C_expected = compute_reference(config, A, B, bias) + success, msg = run_cudnn_matmul(config, A, B, C, bias, cudnn_handle) + + assert success, f"cuDNN failed: {msg}" + + passed, _, compare_msg = compare_results(C, C_expected, config) + print(format_test_result(passed, compare_msg)) + assert passed + finally: + del A, B, C + if bias is not None: + del bias + if C_expected is not None: + del C_expected + torch.cuda.empty_cache() diff --git a/test/python/test_mhas.py b/test/python/test_mhas.py index ac82fcd1..0e13f69a 100644 --- a/test/python/test_mhas.py +++ b/test/python/test_mhas.py @@ -164,24 +164,15 @@ def compute_ref( causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) causal_mask.triu_(diagonal=1 + right_bound) s = s.masked_fill(causal_mask, float("-inf")) - elif ( - diagonal_alignment == diagonal_alignment.BOTTOM_RIGHT - and right_bound is not None - ): + elif diagonal_alignment == diagonal_alignment.BOTTOM_RIGHT and right_bound is not None: causal_mask_bottom_right = None if padding: - causal_mask_bottom_right = torch.ones( - b, 1, s_q, s_kv, dtype=torch.bool, device=device - ) + causal_mask_bottom_right = torch.ones(b, 1, s_q, s_kv, dtype=torch.bool, device=device) seq_len_q, seq_len_kv = padding for i in range(b): - causal_mask_bottom_right[i, :, :, :].triu_( - diagonal=seq_len_kv[i] - seq_len_q[i] + 1 + right_bound - ) + causal_mask_bottom_right[i, :, :, :].triu_(diagonal=seq_len_kv[i] - seq_len_q[i] + 1 + right_bound) else: - causal_mask_bottom_right = torch.ones( - s_q, s_kv, dtype=torch.bool, device=device - ) + causal_mask_bottom_right = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) causal_mask_bottom_right.triu_(diagonal=s_kv - s_q + 1 + right_bound) s = s.masked_fill(causal_mask_bottom_right, float("-inf")) if left_bound is not None: @@ -196,9 +187,7 @@ def compute_ref( swa_mask = torch.ones(b, 1, s_q, s_kv, dtype=torch.bool, device=device) seq_len_q, seq_len_kv = padding for i in range(b): - swa_mask[i, :, :, :].tril_( - diagonal=seq_len_kv[i] - seq_len_q[i] - left_bound - ) + swa_mask[i, :, :, :].tril_(diagonal=seq_len_kv[i] - seq_len_q[i] - left_bound) # BRCM + SWA for fixed sequence lengths else: swa_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) @@ -216,9 +205,7 @@ def compute_ref( # apply dropout mask over softmax outputs if dropout_prob != 0.0: - assert ( - dropout_mask != None - ), "PyTorch reference must have dropout_mask for dropout" + assert dropout_mask != None, "PyTorch reference must have dropout_mask for dropout" p = (p * dropout_mask) / (1 - dropout_prob) o = torch.einsum("bhqk,bhkd->bhqd", p, v) @@ -427,18 +414,10 @@ def compute_exclusive_prefix_sum(tensor): else: # sbh3d raise ValueError() - q_ragged_offset = q_ragged_offset.to( - dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32 - ) - k_ragged_offset = k_ragged_offset.to( - dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32 - ) - v_ragged_offset = v_ragged_offset.to( - dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32 - ) - o_ragged_offset = o_ragged_offset.to( - dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32 - ) + q_ragged_offset = q_ragged_offset.to(dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32) + k_ragged_offset = k_ragged_offset.to(dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32) + v_ragged_offset = v_ragged_offset.to(dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32) + o_ragged_offset = o_ragged_offset.to(dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32) return q_ragged_offset, k_ragged_offset, v_ragged_offset, o_ragged_offset @@ -446,13 +425,9 @@ def compute_exclusive_prefix_sum(tensor): # @brief Convert a padded page table into a packed page table # @return packed_page_table: packed page table # @return ragged_offset: offset into the packed page table -def convert_uniform_to_ragged_page_tables( - uniform_tensor, seq_len, block_size, cudnn_version -): +def convert_uniform_to_ragged_page_tables(uniform_tensor, seq_len, block_size, cudnn_version): [B, H, S, D] = uniform_tensor.size() - ragged_offset = torch.zeros( - B + 1, 1, 1, 1, dtype=torch.int32, device=uniform_tensor.device - ) # Initialize with first offset as 0 + ragged_offset = torch.zeros(B + 1, 1, 1, 1, dtype=torch.int32, device=uniform_tensor.device) # Initialize with first offset as 0 for i in range(1, B + 1): prev_seq_len = seq_len[i - 1] num_pages_prev_batch = (prev_seq_len + block_size - 1) // block_size @@ -462,17 +437,13 @@ def convert_uniform_to_ragged_page_tables( ragged_offset.to(dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32) # ragged_offset.to(dtype=torch.int32) - packed_page_table = torch.zeros(B * S, H, D).to( - dtype=uniform_tensor.dtype, device=uniform_tensor.device - ) + packed_page_table = torch.zeros(B * S, H, D).to(dtype=uniform_tensor.dtype, device=uniform_tensor.device) uniform_tensor_thd = torch.einsum("bhsd->bshd", uniform_tensor).reshape(B * S, H, D) t_0 = 0 for b, t_1 in enumerate(ragged_offset.flatten()[1:]): - packed_page_table[t_0:t_1, :, :] = uniform_tensor_thd[ - b * S : b * S + (t_1 - t_0), :, : - ] + packed_page_table[t_0:t_1, :, :] = uniform_tensor_thd[b * S : b * S + (t_1 - t_0), :, :] t_0 = t_1 packed_page_table = packed_page_table.reshape(B, S, H, D) @@ -496,9 +467,7 @@ def convert_ragged_to_uniform(ragged_tensor, seq_len): seq_len = seq_len.flatten() # convert bhsd to bshd and flatten - uniform_tensor = torch.zeros(b, s, h, d).to( - dtype=ragged_tensor.dtype, device=ragged_tensor.device - ) + uniform_tensor = torch.zeros(b, s, h, d).to(dtype=ragged_tensor.dtype, device=ragged_tensor.device) ragged_tensor_thd = torch.einsum("bhsd->bshd", ragged_tensor).reshape(b * s, h, d) # copy @@ -512,26 +481,18 @@ def convert_ragged_to_uniform(ragged_tensor, seq_len): return uniform_tensor -def generate_actual_seq_lens( - b, s_q, s_kv, layout, head_group, is_padding, force_sq_less_or_equal_than_skv -): +def generate_actual_seq_lens(b, s_q, s_kv, layout, head_group, is_padding, force_sq_less_or_equal_than_skv): seq_len_q_gpu = None seq_len_kv_gpu = None if is_padding: - seq_len_q_gpu = torch.randint( - 1, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda" - ) + seq_len_q_gpu = torch.randint(1, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") if not (layout == "bs3hd" and head_group == "multi_head"): - seq_len_kv_gpu = torch.randint( - 1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda" - ) + seq_len_kv_gpu = torch.randint(1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") # Avoid seq_len_q > seq_len_kv (known limitation): if force_sq_less_or_equal_than_skv: - seq_len_q_gpu = torch.max( - torch.tensor(1), seq_len_q_gpu % seq_len_kv_gpu - ) + seq_len_q_gpu = torch.max(torch.tensor(1), seq_len_q_gpu % seq_len_kv_gpu) else: seq_len_kv_gpu = seq_len_q_gpu diff --git a/test/python/test_mhas_v2.py b/test/python/test_mhas_v2.py index c453f003..9e3f6b97 100644 --- a/test/python/test_mhas_v2.py +++ b/test/python/test_mhas_v2.py @@ -8,17 +8,12 @@ import pytest import random import torch -import math -import os import sys -from looseversion import LooseVersion from datetime import datetime -from enum import IntEnum -from dataclasses import dataclass, asdict -from mha_v2_utils import ( - exec_cfg, - INVALID_BOUND, +from sdpa.random_config import ( + ExecConfig, + generate_test_seeds, RandomizationContext, RandomBatchSize, RandomBlockSize, @@ -27,9 +22,11 @@ RandomHeadGenerator, RandomChoice, SlidingWindowMaskGenerator, - time_execution, - profile_execution, ) +from sdpa.fp16 import exec_sdpa +from sdpa.fp8 import exec_sdpa_fp8 +from sdpa.blocked import fetch_blocked_tests +from sdpa.helpers import print_section_begin, print_section_end # fmt: off @@ -37,193 +34,7 @@ print("This is pytest script.") sys.exit(0) -def tlist(*, num_tests, rng_seed): - assert num_tests >= 1 and type(num_tests) == int, "wrong input" - rng = random.Random(rng_seed) - return [(i+1, num_tests, rng.randint(65536, 2147483647)) for i in range(num_tests)] - -def get_layout_name(string, indices): - assert len(string) == 4 and sorted(indices) == [0, 1, 2, 3], "wrong input" - chars = [string[i] for i in indices] - return ''.join(chars) - -def int_cli_option(org_val, request, cli_opt): - val = request.config.getoption(cli_opt) - return val if type(val) == int else org_val - -def implementation_cli_option(org_val, request, cli_opt): - str_val = request.config.getoption(cli_opt) - val = getattr(cudnn.attention_implementation, str_val, None) if str_val else None - return val if isinstance(val, cudnn.attention_implementation) else org_val - -def convert_to_cudnn_type(torch_type): - if torch_type == torch.float16: - return cudnn.data_type.HALF - elif torch_type == torch.bfloat16: - return cudnn.data_type.BFLOAT16 - elif torch_type == torch.float32: - return cudnn.data_type.FLOAT - elif torch_type == torch.int32: - return cudnn.data_type.INT32 - elif torch_type == torch.int64: - return cudnn.data_type.INT64 - else: - assert False, "unsupported tensor data type" - -def approx_equal(actual, expected, sepbuf, rawbuf, rtol, atol, tag, disp_elems): - mismatches = torch.where(torch.isclose(actual.float(), expected, rtol=rtol, atol=atol) == False) - mismatch_cnt = mismatches[0].numel() - num_elements = torch.numel(actual) - if mismatch_cnt != 0: - percentage = 100 * mismatch_cnt / num_elements - if disp_elems > 0: - print(f"Comparing '{tag}' using rtol={rtol:.4e}, atol={atol:.4e}") - combined = torch.stack(mismatches, dim=-1).tolist() - count = 0 - for index in combined: - diff = actual[tuple(index)] - expected[tuple(index)] - if math.isfinite(diff): - print(f"idx{index}: {tag}_gpu={actual[tuple(index)]:+.6e}, {tag}_ref={expected[tuple(index)]:+.6e}, diff={diff:+.2e}") - else: - print(f"idx{index}: {tag}_gpu={actual[tuple(index)]:+.6e}, {tag}_ref={expected[tuple(index)]:+.6e}") - count += 1 - if count >= disp_elems: - break - print(f"%%%% Total {mismatch_cnt:,} mismatches ({percentage:.1f}%) when validating '{tag}' results (first {count} mismatches displayed)") - else: - print(f"%%%% Total {mismatch_cnt:,} mismatches ({percentage:.1f}%) when validating '{tag}' results") - - num_nans = torch.isnan(actual).sum().item() - num_infs = torch.isinf(actual).sum().item() - num_zeros = num_elements - torch.count_nonzero(actual) - num_finites_nz = num_elements - num_nans - num_infs - num_zeros - - print(f"%%%% {tag}_gpu overview: elements={num_elements:,}, finites_nz={num_finites_nz:,}, zeros={num_zeros:,}, nans={num_nans:,}, infs={num_infs:,}") - - num_nans = torch.isnan(expected).sum().item() - num_infs = torch.isinf(expected).sum().item() - num_zeros = num_elements - torch.count_nonzero(expected) - num_finites_nz = num_elements - num_nans - num_infs - num_zeros - - print(f"%%%% {tag}_ref overview: elements={num_elements:,}, finites_nz={num_finites_nz:,}, zeros={num_zeros:,}, nans={num_nans:,}, infs={num_infs:,}") - else: - print(f"%%%% Numerical divergence of '{tag}' within limits") - - # Check if areas before and after the tensor were overwritten (treated as one numerical mismatch). - if sepbuf is not None and not torch.all(torch.isnan(sepbuf)).item(): - print(f"%%%% Buffer '{tag}' overwritten outside its boundaries") - print(sepbuf) - mismatch_cnt += 1 - - # Check if unused elements of the tensor were overwritten (treated as one numerical mismatch). - # Note that this check destroys computed data (overwrites them with NaN-s). - if rawbuf is not None: - actual.fill_(float('nan')) - if not torch.all(torch.isnan(rawbuf)).item(): - print(f"%%%% Unused gaps of '{tag}' tensor were overwritten") - mismatch_cnt += 1 - - return mismatch_cnt - -def alloc_tensor(shape, data_type, *, elems=None, strides=None, rng=None, mean=0.0, std=1.0, margins=512): - # Arguments elems/strides must be both specified or both None. - if elems is None and strides is None: - if hasattr(shape, '__iter__'): - strides = [] - prod = 1 - for dim in reversed(shape): - strides.insert(0, prod) - prod *= int(dim) - elems = prod - else: - elems = int(shape) - strides = (1,) - shape = (shape,) - else: - assert elems is not None and strides is not None, "wrong input" - - assert margins >= 0 and type(margins) == int, "wrong input" - - rawbuf = torch.empty(elems+2*margins, dtype=data_type, device="cuda") - if torch.is_floating_point(rawbuf): - rawbuf.fill_(float('nan')) - else: - rawbuf.fill_(-1) - - tensor = torch.as_strided(rawbuf, shape, strides, storage_offset=margins) - sepbuf = (torch.as_strided(rawbuf, (2, margins), (elems+margins, 1), storage_offset=0) if margins > 0 else None) - - # Use this initialization for floating point types only. - if rng is not None: - tensor.normal_(mean=mean, std=std, generator=rng) - - # Not returning the raw buffer, if the data tensor has no gaps between valid elements. - # If there are unused gaps, then we want to check that those gaps were not overwritten. - if math.prod(shape) == elems: - rawbuf = None - - return tensor, sepbuf, rawbuf - -def fetch_blocked_tests(file_path, gpu_arch, cudnn_ver): - assert type(gpu_arch) == type(cudnn_ver) == str, "expecting strings" - blocked_tests = [] - try: - line_number = None - with open(file_path, 'r') as file: - for line_number, line_buf in enumerate(file, 1): - line_buf = line_buf.split('#', 1)[0] # remove comments - line_buf = "".join(line_buf.split()) # remove whitespaces - if line_buf: - test,sms,libs = (line_buf+"::").split(':')[:3] - if not test: - raise ValueError("missing test name") - sms = sms.split(',') if sms else None - libs = libs.split(',') if libs else None - if (test not in blocked_tests) and (sms == None or gpu_arch in sms) and (libs == None or cudnn_ver in libs): - blocked_tests.append(test) - except Exception as e: - blocked_tests = [] - if line_number != None: - print(f"\n\nWARNING: {e} in {file_path}:{line_number}") - else: - print(f"\n\nWARNING: {e}") - return blocked_tests - -def show_blocked_tests(blocked_tests, gpu_arch, cudnn_ver): - print(f"\n\nBlocked tests on {gpu_arch} and cudnn_ver={cudnn_ver}:") - if blocked_tests: - for index, test in enumerate(blocked_tests): - assert type(test) == str, "test name must be string" - print(f"{index+1:<4} : {test}") - else: - print("[empty]") - -def is_test_blocked(test, blocked_tests): - assert type(test) == str, "test name must be string" - if not blocked_tests: - return False - return True if test in blocked_tests else False - -def truncated_list(beg, end, arr): - if len(arr) >= beg + 3 + end: - hi = max(arr) - lo = min(arr) - s = [*arr[:beg], '...', *arr[beg:][-end:]] - s = '['+', '.join(map(str, s))+'], min='+str(lo)+', max='+str(hi) - else: - s = '['+', '.join(map(str, arr))+']' - return s - -class knobNAR(IntEnum): - NEVER = 0 - ALWAYS = 1 - RANDOM = 2 - -class knobNA(IntEnum): - NEVER = 0 - ALWAYS = 1 - -class testConfig: +class SDPATestConfig: __slots__ = ['gpu_arch', 'gpu_info', 'cudnn_ver', 'blocked_tests', 'implementation', 'cfg'] def __init__(self, *, gpu_arch, gpu_info, cudnn_ver, blocked_tests, implementation): @@ -241,937 +52,20 @@ def __init__(self, *, gpu_arch, gpu_info, cudnn_ver, blocked_tests, implementati self.implementation = implementation - self.cfg = exec_cfg() - - - def showConfig(self, test_no, request, reg_run=True): - if request.config.option.dryrun == 0 or request.config.option.dryrun == 1: - if request.config.option.dryrun == 0: - print("\n" + "=" * 90) - else: - print("\n" + "=" * 40 + "Dry-RUN" + "=" * 40) - print(f"#### Test #{test_no[0]} of {test_no[1]} at", datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "\n") - print(f"test_name = {request.node.name}") - # print(f"geom_seed = {self.geom_seed}") - # print(f"data_seed = {self.data_seed}") - print(f"platform_info = {self.gpu_arch} ({self.gpu_info}), cudnn_ver={self.cudnn_ver}") - print(f"rng_data_seed = {self.cfg.rng_data_seed}") - # print(f"head_group = {self.cfg.head_group}") - # print(f"layout = {self.in_layout}->{self.out_layout}") - print(f"basic_dims = [b={self.cfg.batches}, h_q={self.cfg.h_q}, h_k={self.cfg.h_k}, h_v={self.cfg.h_v}, d_qk={self.cfg.d_qk}, d_v={self.cfg.d_v}, s_q={self.cfg.s_q}, s_kv={self.cfg.s_kv}]") - print(f"shape_q(b,h,s,d) = {self.cfg.shape_q}, strides={self.cfg.stride_q}, elems={self.cfg.elems_q}") - print(f"shape_k(b,h,s,d) = {self.cfg.shape_k}, strides={self.cfg.stride_k}, elems={self.cfg.elems_k}") - print(f"shape_v(b,h,s,d) = {self.cfg.shape_v}, strides={self.cfg.stride_v}, elems={self.cfg.elems_v}") - print(f"shape_o(b,h,s,d) = {self.cfg.shape_o}, strides={self.cfg.stride_o}, elems={self.cfg.elems_o}") - - print(f"is_infer = {self.cfg.is_infer}") - print(f"is_padding = {self.cfg.is_padding} ({'ragged' if self.cfg.is_ragged else 'no ragged'})") - print(f"is_alibi = {self.cfg.is_alibi}") - print(f"is_paged = {self.cfg.is_paged} (block_size={self.cfg.block_size})") - print(f"is_bias = {self.cfg.is_bias}") - print(f"is_block_mask = {self.cfg.is_block_mask}") - print(f"is_dropout = {self.cfg.is_dropout}") - if self.cfg.is_infer == False: - print(f"is_determin = {self.cfg.is_determin}") - print(f"diag_align = {self.cfg.diag_align}") - print(f"left_bound = {self.cfg.left_bound}", '(NO BOUND)' if self.cfg.left_bound == INVALID_BOUND else '') - print(f"right_bound = {self.cfg.right_bound}", '(NO BOUND)' if self.cfg.right_bound == INVALID_BOUND else '') - # print(f"seq_len_q = {truncated_list(20, 3, self.seq_len_q)}") - # print(f"seq_len_kv = {truncated_list(20, 3, self.seq_len_kv)}") - print(f"data_type = {self.cfg.data_type}") - print(f"implementation = {self.cfg.implementation.name}") - if reg_run: - # Convert enums to integers and handle torch dtypes for proper serialization - cfg_dict = asdict(self.cfg) - # Convert enum values to integers - if cfg_dict.get('diag_align') is not None: - cfg_dict['diag_align'] = cfg_dict['diag_align'].value - if cfg_dict.get('implementation') is not None: - cfg_dict['implementation'] = cfg_dict['implementation'].name - # Convert torch dtype to string - if cfg_dict.get('data_type') is not None: - cfg_dict['data_type'] = str(cfg_dict['data_type']) - print(f"repro_cmd = pytest -vv -s -rA {request.module.__file__}::test_repro --repro \"{repr(cfg_dict)}\"") - elif request.config.option.dryrun == 2: - print(f"\npytest -vv -s -rA {request.module.__file__}::{request.node.name} --geom_seed {self.geom_seed} --data_seed {self.data_seed}") - elif request.config.option.dryrun == 3: - print(f"repro_cmd = pytest -vv -s -rA {request.module.__file__}::{request.node.name} --geom_seed {self.geom_seed} --data_seed {self.data_seed}") - - else: - assert False, "wrong --dryrun command line option" - - # Make sure to flush everything out. - print(" ", flush=True) - - - def avoid_invalid_configs(self, avoid_invalid_configs): - if avoid_invalid_configs == avoid_invalid_configs.ALWAYS: - # LIMIT: always is_determin=True in inference. - if self.is_infer: - self.is_determin = True - - # LIMIT: Paged attention only in inference. - if not self.is_infer: - self.is_paged = False - - # LIMIT: Paged caches can only be used in combination with padding mask (variable sequence length). - if self.is_paged and not self.is_padding: - self.is_paged = False - - # LIMIT: Paged caches cannot be used with ragged offsets (packed variable sequence lengths). - if self.is_paged and self.is_ragged: - self.is_paged = False - - # LIMIT: left and right bounds are only supported with is_dropout=False, is_bias=False. - if self.left_bound != INVALID_BOUND and self.right_bound != INVALID_BOUND: - self.is_dropout = False - self.is_bias = False - - # LIMIT: when alibi mask is used, diagonal_band_right_bound needs to be exactly 0 (not INVALID_BOUND). - if self.is_alibi and self.right_bound != 0: - self.is_alibi = False - - # LIMIT: bottom right causal mask is only supported with is_bias=False, is_alibi=False, is_dropout=False. - if self.diag_align == self.diag_align.BOTTOM_RIGHT and (self.left_bound != INVALID_BOUND or self.right_bound != INVALID_BOUND): - self.is_bias = False - self.is_alibi = False - self.is_dropout = False - - # LIMIT: Left or right bounds are only supported with is_dropout=False, is_bias=False. - if self.left_bound != INVALID_BOUND or self.right_bound != INVALID_BOUND: - self.is_dropout = False - self.is_bias = False - - # LIMIT: Left bound (a.k.a sliding window) does not support s_q > s_kv - if self.left_bound != INVALID_BOUND and self.s_q.val > self.s_kv.val: - self.left_bound = INVALID_BOUND - - # LIMIT: Bottom right causal mask does not support s_q > s_kv. - if self.s_q.val > self.s_kv.val and self.diag_align == self.diag_align.BOTTOM_RIGHT and self.right_bound != INVALID_BOUND: - self.right_bound = INVALID_BOUND - - if not self.is_infer: - self.is_block_mask = False - -def compute_ref( - q, - k, - v, - attn_scale=None, - bias=None, - block_mask=None, - is_alibi=False, - padding=None, - diag_align=cudnn.diagonal_alignment.TOP_LEFT, - left_bound=INVALID_BOUND, - right_bound=INVALID_BOUND, - dropout_prob=0.0, - dropout_mask=None, - generate_stats=False, - device="cuda", -): - b, h_q, s_q, d_qk = q.shape - _, h_k, s_kv, _ = k.shape - _, h_v, _, d_v = v.shape - - assert k.shape == (b, h_k, s_kv, d_qk) - assert v.shape == (b, h_v, s_kv, d_v) - - # use float32 datatype and math for reference computation - q = q.to(dtype=torch.float32, device=device) - k = k.to(dtype=torch.float32, device=device) - v = v.to(dtype=torch.float32, device=device) - - # expand tensors for GQA and MQA - if h_q != h_k: - assert h_q % h_k == 0 - k = k.unsqueeze(2) - k = k.expand(-1, -1, h_q // h_k, -1, -1) - k = k.reshape(k.size(0), -1, k.size(3), k.size(4)) - if h_q != h_v: - assert h_q % h_v == 0 - v = v.unsqueeze(2) - v = v.expand(-1, -1, h_q // h_v, -1, -1) - v = v.reshape(v.size(0), -1, v.size(3), v.size(4)) - - if left_bound != INVALID_BOUND: - swa_mask_zero = torch.ones(1, 1, s_q, 1, dtype=torch.bool, device=device) - swa_mask_zero[:, :, s_kv + left_bound - 1 :, :] = False - q = q * swa_mask_zero - - # generate masks to compute reference values for padding mask (also called variable sequence length) - if padding is not None: - q_mask = torch.zeros(b, 1, s_q, 1, dtype=torch.bool, device=device) - k_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) - v_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) - s_mask = torch.zeros(b, 1, s_q, s_kv, dtype=torch.bool, device=device) - p_mask = torch.zeros(b, 1, s_q, s_kv, dtype=torch.bool, device=device) - seq_len_q, seq_len_kv = padding - for i, (m, n) in enumerate(zip(seq_len_q, seq_len_kv)): - q_mask[i, :, m:, :] = True - k_mask[i, :, n:, :] = True - v_mask[i, :, n:, :] = True - s_mask[i, :, :, n:] = True - p_mask[i, :, m:, :] = True - - q = q.masked_fill(q_mask, 0.0) - k = k.masked_fill(k_mask, 0.0) - v = v.masked_fill(v_mask, 0.0) - - s = torch.einsum("bhqd,bhkd->bhqk", q, k) - if attn_scale is not None: - s = s * attn_scale - - # Attention masks are applied in the following order: - # - Bias mask - # - Alibi mask - # - Padding mask - # - Causal mask - if bias is not None: - s = s + bias - if is_alibi: - index_row = torch.arange(s_q, dtype=torch.float32, device=device).view(-1, 1) - index_col = torch.arange(s_kv, dtype=torch.float32, device=device) - distance = index_col - index_row - - # Get the closest power of 2 to `n_heads`. - # If `n_heads` is not a power of 2, then we first calculate slopes to the closest (smaller) power of 2, - # and then add the remaining slopes. - n = 2 ** math.floor(math.log2(h_q)) - m_0 = 2.0 ** (-8.0 / n) - m = torch.pow(m_0, torch.arange(1, 1 + n)) - - # If `n_heads` is not a power of 2, then we add the remaining slopes. - # We calculate the remaining slopes for $n * 2$ (avoiding slopes added previously). - # And pick the slopes upto `n_heads`. - if n < h_q: - m_hat_0 = 2.0 ** (-4.0 / n) - m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (h_q - n), 2)) - # Concatenate the slopes with the remaining slopes. - m = torch.cat([m, m_hat]) - - # Reshape the tensor to [1, num_heads, 1, 1] - m = m.view(1, -1, 1, 1).to(device=device) - - alibi_mask = distance.to(dtype=torch.float32) * m - s = s + alibi_mask - - if padding is not None: - s = s.masked_fill(s_mask, float("-inf")) - - if diag_align == diag_align.TOP_LEFT and right_bound != INVALID_BOUND: - causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) - causal_mask.triu_(diagonal=1 + right_bound) - s = s.masked_fill(causal_mask, float("-inf")) - elif diag_align == diag_align.BOTTOM_RIGHT and right_bound != INVALID_BOUND: - causal_mask_bottom_right = None - if padding: - causal_mask_bottom_right = torch.ones(b, 1, s_q, s_kv, dtype=torch.bool, device=device) - seq_len_q, seq_len_kv = padding - for i in range(b): - causal_mask_bottom_right[i, :, :, :].triu_(diagonal=seq_len_kv[i] - seq_len_q[i] + 1 + right_bound) - else: - causal_mask_bottom_right = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) - causal_mask_bottom_right.triu_(diagonal=s_kv - s_q + 1 + right_bound) - s = s.masked_fill(causal_mask_bottom_right, float("-inf")) - - if left_bound != INVALID_BOUND: - assert diag_align is not None - if diag_align == diag_align.TOP_LEFT: - swa_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) - swa_mask.tril_(diagonal=-1 * left_bound) - elif diag_align == diag_align.BOTTOM_RIGHT: - # BRCM + SWA for variable sequence lengths - if padding: - swa_mask = torch.ones(b, 1, s_q, s_kv, dtype=torch.bool, device=device) - seq_len_q, seq_len_kv = padding - for i in range(b): - swa_mask[i, :, :, :].tril_(diagonal=seq_len_kv[i] - seq_len_q[i] - left_bound) - # BRCM + SWA for fixed sequence lengths - else: - swa_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) - swa_mask.tril_(diagonal=-1 * left_bound + (s_kv - s_q)) - swa_mask &= swa_mask_zero.view(s_q, 1) - s = s.masked_fill(swa_mask, float("-inf")) - - if block_mask is not None: - TILE_M = 128 - TILE_N = 128 - - block_mask = block_mask.to(dtype=torch.uint8, device=device) - block_mask = ((block_mask[..., None] & (1 << torch.arange(8, device=block_mask.device))) != 0).reshape(block_mask.shape[0], block_mask.shape[1], block_mask.shape[2], block_mask.shape[3] * 8) - block_mask = block_mask.unsqueeze(3).unsqueeze(5) - block_mask = block_mask.repeat(1, 1, 1, TILE_M, 1, TILE_N) - block_mask = block_mask.reshape(block_mask.shape[0], block_mask.shape[1], block_mask.shape[2] * TILE_M, block_mask.shape[4] * TILE_N) - block_mask = block_mask[:, :, :s_q, :s_kv] - s += torch.where(block_mask, torch.tensor(0.0), torch.tensor(float('-inf'))) - - p = torch.softmax(s, dim=-1) - - if block_mask is not None: - all_inf = torch.isneginf(s).all(dim=-1, keepdim=True) - if torch.any(all_inf): - p = torch.where(all_inf, torch.zeros_like(p), p) - - if left_bound != INVALID_BOUND: - p = p * swa_mask_zero - if padding is not None: - p = p.masked_fill(p_mask, 0.0) - - # apply dropout mask over softmax outputs - if dropout_prob != 0.0: - assert dropout_mask != None, "PyTorch reference must have dropout_mask for dropout" - p = (p * dropout_mask) / (1 - dropout_prob) - - o = torch.einsum("bhqk,bhkd->bhqd", p, v) - - # softmax stats is used for backwards computation - if generate_stats: - # amax (NOT absolute max) is used here to evenly distribute gradient - row_max = torch.amax(s, -1, True) - row_exp = torch.exp(s - row_max) - row_sum = torch.sum(row_exp, -1, True) - stats = row_max + torch.log(row_sum) - return o, stats - - return o - -# Compute the exclusive prefix sum for ragged sequence dimension -# input tensor has shape (B, 1, 1, 1) -# output tensor has shape (B+1, 1, 1, 1) -# example input seq_len: [2, 4, 1, 6] (along the B dimension) -# example output ragged_offset: [0, 2, 6, 7, 13] (along the B dimension) -def compute_exclusive_prefix_sum(tensor): - assert list(tensor.size())[1:]==[1,1,1] - # We need to provide a tuple of two tensors to torch.cat(). - return torch.cat((torch.zeros(1, 1, 1, 1, dtype=tensor.dtype, device=tensor.device), torch.cumsum(tensor, dim=0))) - -def generate_ragged_offset(h_q, h_k, h_v, d_qk, d_v, seq_len_q, seq_len_kv): - # Only for thd_thd_thd - q_ragged_offset = compute_exclusive_prefix_sum(seq_len_q) * h_q * d_qk - k_ragged_offset = compute_exclusive_prefix_sum(seq_len_kv) * h_k * d_qk - v_ragged_offset = compute_exclusive_prefix_sum(seq_len_kv) * h_v * d_v - o_ragged_offset = compute_exclusive_prefix_sum(seq_len_q) * h_q * d_v - - # Convert to int64 for cuDNN 9.6.0 - q_ragged_offset = q_ragged_offset.to(dtype=torch.int64) - k_ragged_offset = k_ragged_offset.to(dtype=torch.int64) - v_ragged_offset = v_ragged_offset.to(dtype=torch.int64) - o_ragged_offset = o_ragged_offset.to(dtype=torch.int64) - - return q_ragged_offset, k_ragged_offset, v_ragged_offset, o_ragged_offset - -def convert_ragged_to_uniform(ragged_tensor, seq_len): - # limitations: - # 1. tensor is bhsd dim order and bshd stride order (may be interleaved) - # 2. ragged tensor is packed and in-order, therefore - # ragged offset is monatomically increasing - assert ragged_tensor.dim() == 4 - b, h, s, d = ragged_tensor.size() - b_stride, h_stride, s_stride, d_stride = ragged_tensor.stride() - assert b_stride >= s_stride >= h_stride >= d_stride - assert seq_len.dim() == 4 and (b, 1, 1, 1) == seq_len.size() - - # ragged offset is given in 4D, convert to 1D locally - seq_len = seq_len.flatten() - - # convert bhsd to bshd and flatten - uniform_tensor = torch.zeros(b, s, h, d).to( - dtype=ragged_tensor.dtype, device=ragged_tensor.device - ) - ragged_tensor_thd = torch.einsum("bhsd->bshd", ragged_tensor).reshape(b * s, h, d) - - # copy - t = 0 - for b, s in enumerate(seq_len): - uniform_tensor[b, 0:s, :, :] = ragged_tensor_thd[t : t + s, :, :] - t += s - - # convert back to bshd to bhsd - uniform_tensor = torch.einsum("bshd->bhsd", uniform_tensor) - return uniform_tensor - -def create_container_and_page_table(tensor, block_size): - B, H, S, D = tensor.shape - # num_blocks = math.ceil(S/block_size) * B - blocks_per_batch = math.ceil(S/block_size) - - padding_seq = (blocks_per_batch * block_size) - S - if padding_seq > 0: - zeros = torch.zeros(B,H,padding_seq,D, device='cuda', dtype=tensor.dtype) - cat_tensor = torch.cat((tensor, zeros), axis = 2) - else: - cat_tensor = tensor - - reshaped = torch.cat((cat_tensor.clone()).chunk(blocks_per_batch, dim=2), dim=0) - - table_size = math.ceil(S/block_size) - page_table = torch.linspace(0, B*table_size-1, B*table_size, device='cuda', dtype=torch.int32).reshape(table_size,1,B,1) - page_table = torch.transpose(page_table,0,2) - - return(reshaped, page_table) - -def exec_sdpa(cfg, request, cudnn_handle): - # Do not run any test when --dryrun option is provided. - - if request.config.option.dryrun: - pytest.skip("dry run mode") - - # # Check if the test is temporarily blocked. - # if is_test_blocked(request.node.name, cfg.blocked_tests): - # print(f"\nWARNING: test '{request.node.name}' is blocked on {cfg.gpu_arch} and cuDNN {cfg.cudnn_ver}") - # print("@@@@ Overall result: SKIPPED, test blocked.") - # pytest.skip("test blocked") - - # ============================ - # Basic parameter check. - # ============================ - - if not all((x > 0 and type(x) == int) for x in (cfg.batches, cfg.d_qk, cfg.d_v, cfg.s_q, cfg.s_kv, cfg.h_q, cfg.h_k, cfg.h_v)): - assert False, "tensor dimensions must be integer and positive" - - assert cfg.shape_q == (cfg.batches, cfg.h_q, cfg.s_q, cfg.d_qk), f"wrong shape_q={cfg.shape_q}" - assert cfg.shape_k == (cfg.batches, cfg.h_k, cfg.s_kv, cfg.d_qk), f"wrong shape_k={cfg.shape_k}" - assert cfg.shape_v == (cfg.batches, cfg.h_v, cfg.s_kv, cfg.d_v), f"wrong shape_v={cfg.shape_v}" - assert cfg.shape_o == (cfg.batches, cfg.h_q, cfg.s_q, cfg.d_v), f"wrong shape_o={cfg.shape_o}" - - if not cfg.is_infer: - assert cfg.is_paged == False and cfg.block_size == None, "paged attention not allowed in backward pass" - - if cfg.is_ragged: - assert cfg.is_padding == True, "is_ragged=True and is_padding=False not allowed" - - assert isinstance(cfg.seq_len_q, (list, tuple)), "input 'seq_len_q' must be list or tuple" - if cfg.is_padding: - assert len(cfg.seq_len_q) == cfg.batches, f"wrong 'seq_len_q' length" - else: - assert len(cfg.seq_len_q) == 0, f"wrong 'seq_len_q' length, expecting 0" - - assert isinstance(cfg.seq_len_kv, (list, tuple)), "input 'seq_len_kv' must be list or tuple" - if cfg.is_padding: - assert len(cfg.seq_len_kv) == cfg.batches, f"wrong 'seq_len_kv' length, expecting {cfg.batches}" - else: - assert len(cfg.seq_len_kv) == 0, f"wrong 'seq_len_kv' length, expecting 0" - - assert all(x >= 0 and type(x) == int for x in cfg.seq_len_q), f"wrong seq_len_q={cfg.seq_len_q}" - assert all(x >= 0 and type(x) == int for x in cfg.seq_len_kv), f"wrong seq_len_kv={cfg.seq_len_kv}" - - cudnn_version = LooseVersion(cudnn.backend_version_string()) - if cudnn_version < "9.10.0": - print("@@@@ Overall result: WAIVED, test_mhas_v2.py supports cudnn 9.10.0 or higher.") - pytest.skip("test_mhas_v2.py requires cudnn 9.10.0 or higher") - - if cudnn_version < "9.13.1" and cfg.implementation == cudnn.attention_implementation.UNIFIED: - print("@@@@ Overall result: WAIVED, unified SDPA implementation requires cudnn 9.13.1 or higher.") - pytest.skip("unified SDPA implementation requires cudnn 9.13.1 or higher") - - if cfg.s_q == cfg.s_kv == 1: - print("@@@@ Overall result: WAIVED, skipping known issue of s_q == s_kv == 1.") - pytest.skip("skipping known issue of s_q == s_kv == 1") - - qkv_num_elems = cfg.elems_q + cfg.elems_k + cfg.elems_v - - rng_data_gen = torch.Generator(device="cuda").manual_seed(cfg.rng_data_seed) - - (q_gpu, _, _) = alloc_tensor(cfg.shape_q, cfg.data_type, elems=cfg.elems_q, strides=cfg.stride_q, rng=rng_data_gen, mean=-0.5, std=1.0) - (k_gpu, _, _) = alloc_tensor(cfg.shape_k, cfg.data_type, elems=cfg.elems_k, strides=cfg.stride_k, rng=rng_data_gen, mean=-0.5, std=1.0) - (v_gpu, _, _) = alloc_tensor(cfg.shape_v, cfg.data_type, elems=cfg.elems_v, strides=cfg.stride_v, rng=rng_data_gen, mean=-0.5, std=1.0) - (bias_gpu, _, _) = (alloc_tensor((1, cfg.h_q, cfg.s_q, cfg.s_kv), cfg.data_type, rng=rng_data_gen, mean=0.0, std=1.0) if cfg.is_bias else (None, None, None)) - - TILE_M = 128 - TILE_N = 128 - block_mask_gpu = torch.randint(0, 256, (cfg.batches, cfg.h_q, (cfg.s_q + TILE_M - 1) // TILE_M, ((cfg.s_kv + TILE_N - 1) // TILE_N + 7) // 8), dtype=torch.uint8, device="cuda") - - if not cfg.is_infer: - (dQ_gpu, dQ_sep, dQ_raw) = alloc_tensor(cfg.shape_q, cfg.data_type, elems=cfg.elems_q, strides=cfg.stride_q) - (dK_gpu, dK_sep, dK_raw) = alloc_tensor(cfg.shape_k, cfg.data_type, elems=cfg.elems_k, strides=cfg.stride_k) - (dV_gpu, dV_sep, dV_raw) = alloc_tensor(cfg.shape_v, cfg.data_type, elems=cfg.elems_v, strides=cfg.stride_v) - (dBias_gpu, dBias_sep, dBias_raw) = (alloc_tensor((1, cfg.h_q, cfg.s_q, cfg.s_kv), cfg.data_type) if cfg.is_bias else (None, None, None)) - (dO_gpu, dO_sep, dO_raw) = alloc_tensor(cfg.shape_o, cfg.data_type, elems=cfg.elems_o, strides=cfg.stride_o, rng=rng_data_gen, mean=0.0, std=0.1) - - # Sequence lengths for gpu, must be a four dimensional tensor. - seq_len_q_gpu = seq_len_kv_gpu = None - if len(cfg.seq_len_q) > 0: - seq_len_q_gpu = torch.tensor(cfg.seq_len_q, dtype=torch.int32, device="cuda") - seq_len_q_gpu = seq_len_q_gpu[:, None, None, None] # batches x 1 x 1 x 1 - if len(cfg.seq_len_kv) > 0: - seq_len_kv_gpu = torch.tensor(cfg.seq_len_kv, dtype=torch.int32, device="cuda") - seq_len_kv_gpu = seq_len_kv_gpu[:, None, None, None] # batches x 1 x 1 x 1 - - # maxT = next_multiple_of_64(sum(seq_len)) - max_t_q = ((torch.sum(seq_len_q_gpu).item() + 63) // 64) * 64 if cfg.is_ragged else None - max_t_kv = ((torch.sum(seq_len_kv_gpu).item() + 63) // 64) * 64 if cfg.is_ragged else None - - if cfg.is_dropout: - seed_gpu = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") - offset_gpu = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") - - rng_dump_gpu = torch.zeros((cfg.batches, cfg.h_q, cfg.s_q, cfg.s_kv), dtype=torch.float32, device="cuda") if cfg.is_dropout else None - - if cfg.is_ragged: - q_ragged_offset_gpu, k_ragged_offset_gpu, v_ragged_offset_gpu, o_ragged_offset_gpu = generate_ragged_offset(cfg.h_q, cfg.h_k, cfg.h_v, cfg.d_qk, cfg.d_v, seq_len_q_gpu, seq_len_kv_gpu) - - (o_gpu, o_sep, o_raw) = alloc_tensor(cfg.shape_o, cfg.data_type, elems=cfg.elems_o, strides=cfg.stride_o) - (stats_gpu, stats_sep, stats_raw) = (alloc_tensor((cfg.batches, cfg.h_q, cfg.s_q, 1), torch.float32) if not cfg.is_infer else (None, None, None)) - - container_k_gpu = None - container_v_gpu = None - page_table_k_gpu = None - page_table_v_gpu = None - - if cfg.is_paged: - container_k_gpu, page_table_k_gpu = create_container_and_page_table(k_gpu, cfg.block_size) - container_v_gpu, page_table_v_gpu = create_container_and_page_table(v_gpu, cfg.block_size) - - stream = torch.cuda.current_stream().cuda_stream - cudnn.set_stream(handle=cudnn_handle, stream=stream) - - # Forward cuDNN graph - graph = cudnn.pygraph( - io_data_type=convert_to_cudnn_type(cfg.data_type), - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - handle=cudnn_handle, - ) - - q = graph.tensor_like(q_gpu) - k = graph.tensor_like(k_gpu) if not cfg.is_paged else graph.tensor_like(container_k_gpu) - v = graph.tensor_like(v_gpu) if not cfg.is_paged else graph.tensor_like(container_v_gpu) - - page_table_k = graph.tensor_like(page_table_k_gpu) if cfg.is_paged else None - page_table_v = graph.tensor_like(page_table_v_gpu) if cfg.is_paged else None - - bias = graph.tensor_like(bias_gpu) if cfg.is_bias else None - block_mask = graph.tensor_like(block_mask_gpu) if cfg.is_block_mask else None - - seq_len_q = graph.tensor_like(seq_len_q_gpu) if cfg.is_padding else None - seq_len_kv = graph.tensor_like(seq_len_kv_gpu) if cfg.is_padding else None - - if cfg.is_dropout: - seed = graph.tensor_like(seed_gpu) - offset = graph.tensor_like(offset_gpu) - dropout_tuple = (cfg.dropout_prob, seed, offset) - - rng_dump = graph.tensor_like(rng_dump_gpu) if cfg.is_dropout else None - - q_ragged_offset = graph.tensor_like(q_ragged_offset_gpu) if cfg.is_ragged else None - k_ragged_offset = graph.tensor_like(k_ragged_offset_gpu) if cfg.is_ragged else None - v_ragged_offset = graph.tensor_like(v_ragged_offset_gpu) if cfg.is_ragged else None - o_ragged_offset = graph.tensor_like(o_ragged_offset_gpu) if cfg.is_ragged else None - - if cfg.is_ragged: - q.set_ragged_offset(q_ragged_offset) - k.set_ragged_offset(k_ragged_offset) - v.set_ragged_offset(v_ragged_offset) - - attn_scale = 0.125 - - o, stats = graph.sdpa( - name="sdpa_forward", - q=q, - k=k, - v=v, - generate_stats=not cfg.is_infer, - attn_scale=attn_scale, - bias=bias, - block_mask=block_mask, - use_alibi_mask=cfg.is_alibi, - use_padding_mask=cfg.is_padding, - seq_len_q=seq_len_q, - seq_len_kv=seq_len_kv, - diagonal_band_left_bound=cfg.left_bound if cfg.left_bound != INVALID_BOUND else None, - diagonal_band_right_bound=cfg.right_bound if cfg.right_bound != INVALID_BOUND else None, - diagonal_alignment=cfg.diag_align, - dropout=dropout_tuple if cfg.is_dropout else None, - rng_dump=rng_dump, - paged_attention_k_table=page_table_k, - paged_attention_v_table=page_table_v, - paged_attention_max_seq_len_kv=cfg.s_kv if cfg.is_paged else None, - implementation=cfg.implementation, - ) - - o.set_output(True).set_dim(cfg.shape_o).set_stride(cfg.stride_o) - if cfg.is_ragged: - o.set_ragged_offset(o_ragged_offset) - - if cfg.is_infer == False: - stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - try: - graph.validate() - except cudnn.cudnnGraphNotSupportedError as e: - print(f"@@@@ Overall result: WAIVED, not supported forward graph. {e}") - pytest.skip("not supported forward graph") - except Exception as e: - print(f"@@@@ Overall result: FAILED, unexpected '{e.__class__.__name__}' exception during forward graph validate. {e}") - pytest.fail("unexpected exception during forward graph validate", pytrace=False) - - try: - graph.build_operation_graph() - graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph.check_support() - graph.build_plans() - except cudnn.cudnnGraphNotSupportedError as e: - print(f"@@@@ Overall result: WAIVED, not supported forward graph after validate. {e}") - pytest.skip("not supported forward graph after validate") - except Exception as e: - print(f"@@@@ Overall result: FAILED, unexpected '{e.__class__.__name__}' exception after forward validate. {e}") - pytest.fail("unexpected exception after forward validate", pytrace=False) - - variant_pack = { - q: q_gpu, - k: k_gpu if not cfg.is_paged else container_k_gpu, - v: v_gpu if not cfg.is_paged else container_v_gpu, - bias: bias_gpu, - block_mask: block_mask_gpu if cfg.is_block_mask else None, - seq_len_q: seq_len_q_gpu, - seq_len_kv: seq_len_kv_gpu, - q_ragged_offset: q_ragged_offset_gpu if cfg.is_ragged else None, - k_ragged_offset: k_ragged_offset_gpu if cfg.is_ragged else None, - v_ragged_offset: v_ragged_offset_gpu if cfg.is_ragged else None, - o_ragged_offset: o_ragged_offset_gpu if cfg.is_ragged else None, - o: o_gpu, - stats: stats_gpu, - rng_dump: rng_dump_gpu, - page_table_k: page_table_k_gpu, - page_table_v: page_table_v_gpu - } - - if cfg.is_dropout: - variant_pack[seed] = seed_gpu - variant_pack[offset] = offset_gpu - - # Allocate workspace for the forward call. - (workspace, ws_sep, _) = alloc_tensor(graph.get_workspace_size(), torch.uint8) - - # Display available memory. - # torch.cuda.empty_cache() - # free_mem, total_mem = torch.cuda.mem_get_info() - # print(f"Free GPU memory (before forward): {free_mem / (1024**3):.4f} GB of {total_mem / (1024**3):.4f} GB") - - if request.config.getoption("--perf"): - forward_times_ms = time_execution(graph.execute, variant_pack, workspace, cudnn_handle) - print(f"@@@@ Forward graph.execute avg_time_ms={forward_times_ms.mean().item():.3f}") - profile_execution(graph.execute, variant_pack, workspace, cudnn_handle) - - # Execute forward cuDNN graph - graph.execute(variant_pack, workspace, cudnn_handle) - torch.cuda.synchronize() - - if ws_sep is not None and not torch.all(ws_sep==-1).item(): - print("@@@@ Overall result: FAILED, forward workspace overwritten outside its boundaries.") - print(ws_sep) - pytest.fail("forward workspace overwritten outside boundaries", pytrace=False) - - if not cfg.is_infer: - if cudnn_version < "8.9.6" and cfg.is_padding: - # zero out padded region of the output and stats - for i, m in enumerate(seq_len_q_gpu): - o_gpu[i, :, m:, :] = 0 - stats_gpu[i, :, m:, :] = 0 - - stream = torch.cuda.current_stream().cuda_stream #2 - cudnn.set_stream(handle=cudnn_handle, stream=stream) - sm_version = torch.cuda.get_device_capability()[0] * 10 + torch.cuda.get_device_capability()[1] - - # Backward cuDNN graph - graph = cudnn.pygraph( - io_data_type=convert_to_cudnn_type(cfg.data_type), - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - handle=cudnn_handle, - sm_version = sm_version - ) - - q = graph.tensor_like(q_gpu) - k = graph.tensor_like(k_gpu) - v = graph.tensor_like(v_gpu) - o = graph.tensor_like(o_gpu) - dO = graph.tensor_like(dO_gpu) - stats = graph.tensor_like(stats_gpu) - - bias = graph.tensor_like(bias_gpu) if cfg.is_bias else None - dBias = (graph.tensor_like(dBias_gpu).set_stride((cfg.h_q * cfg.s_q * cfg.s_kv, cfg.s_q * cfg.s_kv, cfg.s_kv, 1)) if cfg.is_bias else None) - - seq_len_q = graph.tensor_like(seq_len_q_gpu) if cfg.is_padding else None - seq_len_kv = graph.tensor_like(seq_len_kv_gpu) if cfg.is_padding else None - - if cfg.is_dropout: - seed = graph.tensor_like(seed_gpu) - offset = graph.tensor_like(offset_gpu) - dropout_tuple = (cfg.dropout_prob, seed, offset) - - q_ragged_offset = graph.tensor_like(q_ragged_offset_gpu) if cfg.is_ragged else None - k_ragged_offset = graph.tensor_like(k_ragged_offset_gpu) if cfg.is_ragged else None - v_ragged_offset = graph.tensor_like(v_ragged_offset_gpu) if cfg.is_ragged else None - o_ragged_offset = graph.tensor_like(o_ragged_offset_gpu) if cfg.is_ragged else None - - if cfg.is_ragged: - q.set_ragged_offset(q_ragged_offset) - k.set_ragged_offset(k_ragged_offset) - v.set_ragged_offset(v_ragged_offset) - o.set_ragged_offset(o_ragged_offset) - dO.set_ragged_offset(o_ragged_offset) - - dQ, dK, dV = graph.sdpa_backward( - name="sdpa_backward", - q=q, - k=k, - v=v, - o=o, - dO=dO, - stats=stats, - attn_scale=attn_scale, - bias=bias, - dBias=dBias, - use_alibi_mask=cfg.is_alibi, - use_padding_mask=cfg.is_padding, - seq_len_q=seq_len_q, - seq_len_kv=seq_len_kv, - max_total_seq_len_q=max_t_q, - max_total_seq_len_kv=max_t_kv, - diagonal_band_left_bound=cfg.left_bound if cfg.left_bound != INVALID_BOUND else None, - diagonal_band_right_bound=cfg.right_bound if cfg.right_bound != INVALID_BOUND else None, - diagonal_alignment=cfg.diag_align, - dropout=dropout_tuple if cfg.is_dropout else None, - use_deterministic_algorithm=cfg.is_determin, - ) - - dQ.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride()) - dK.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride()) - dV.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride()) - if cfg.is_ragged: - dQ.set_ragged_offset(q_ragged_offset) - dK.set_ragged_offset(k_ragged_offset) - dV.set_ragged_offset(v_ragged_offset) - - try: - graph.validate() - except cudnn.cudnnGraphNotSupportedError as e: - print(f"@@@@ Overall result: WAIVED, not supported backward graph. {e}") - pytest.skip("not supported backward graph") - except Exception as e: - print(f"@@@@ Overall result: FAILED, unexpected '{e.__class__.__name__}' exception during backward graph validate. {e}") - pytest.fail("unexpected exception during backward graph validate", pytrace=False) - - try: - graph.build_operation_graph() - graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph.check_support() - graph.build_plans() - except cudnn.cudnnGraphNotSupportedError as e: - print(f"@@@@ Overall result: WAIVED, not supported backward graph after validate. {e}") - pytest.skip("not supported backward graph after validate") - except Exception as e: - print(f"@@@@ Overall result: FAILED, unexpected '{e.__class__.__name__}' exception after backward validate. {e}") - pytest.fail("unexpected exception after backward validate", pytrace=False) - - variant_pack = { - q: q_gpu, - k: k_gpu, - v: v_gpu, - o: o_gpu, - dO: dO_gpu, - stats: stats_gpu, - dQ: dQ_gpu, - dK: dK_gpu, - dV: dV_gpu, - bias: bias_gpu, - dBias: dBias_gpu, - seq_len_q: seq_len_q_gpu, - seq_len_kv: seq_len_kv_gpu, - q_ragged_offset: q_ragged_offset_gpu if cfg.is_ragged else None, - k_ragged_offset: k_ragged_offset_gpu if cfg.is_ragged else None, - v_ragged_offset: v_ragged_offset_gpu if cfg.is_ragged else None, - o_ragged_offset: o_ragged_offset_gpu if cfg.is_ragged else None, - } - - if cfg.is_dropout: - variant_pack[seed] = seed_gpu - variant_pack[offset] = offset_gpu - - # Allocate workspace for the backward call. - (workspace, ws_sep, _) = alloc_tensor(graph.get_workspace_size(), torch.uint8) - - # Display available memory. - # torch.cuda.empty_cache() - # free_mem, total_mem = torch.cuda.mem_get_info() - # print(f"Free GPU memory (before backward): {free_mem / (1024**3):.4f} GB of {total_mem / (1024**3):.4f} GB") - - if request.config.getoption("--perf"): - backward_times_ms = time_execution(graph.execute, variant_pack, workspace, cudnn_handle) - print(f"@@@@ Backward graph.execute avg_time_ms={backward_times_ms.mean().item():.3f}") - profile_execution(graph.execute, variant_pack, workspace, cudnn_handle) - - # Execute backward cuDNN graph - graph.execute(variant_pack, workspace, cudnn_handle) - torch.cuda.synchronize() - - if ws_sep is not None and not torch.all(ws_sep==-1).item(): - print("@@@@ Overall result: FAILED, backward workspace overwritten outside its boundaries.") - print(ws_sep) - pytest.fail("backward workspace overwritten outside boundaries", pytrace=False) - - bias_ref = None - rng_dump_ref = None - - if not cfg.is_infer: - # Using torch autograd reference in the backward pass. - q_ref = q_gpu.detach().float().requires_grad_() - k_ref = k_gpu.detach().float().requires_grad_() - v_ref = v_gpu.detach().float().requires_grad_() - dO_ref = dO_gpu.detach().float() - if cfg.is_ragged: - dO_ref = convert_ragged_to_uniform(dO_ref, seq_len_q_gpu.detach()) - if cfg.is_bias: - bias_ref = bias_gpu.detach().float().requires_grad_() - else: - # No autograd in the forward pass. - q_ref = q_gpu.detach().float() - k_ref = k_gpu.detach().float() - v_ref = v_gpu.detach().float() - dO_ref = None - if cfg.is_bias: - bias_ref = bias_gpu.detach().float() - - if cfg.is_ragged: - q_ref = convert_ragged_to_uniform(q_ref, seq_len_q_gpu.detach()) - k_ref = convert_ragged_to_uniform(k_ref, seq_len_kv_gpu.detach()) - v_ref = convert_ragged_to_uniform(v_ref, seq_len_kv_gpu.detach()) - - if cfg.is_padding: - seq_len_q_ref = seq_len_q_gpu.detach().flatten() - seq_len_kv_ref = seq_len_kv_gpu.detach().flatten() - - if cfg.is_dropout: - rng_dump_ref = rng_dump_gpu.detach().float() - - # Compute forward reference output. - ret = compute_ref( - q_ref, - k_ref, - v_ref, - attn_scale=attn_scale, - bias=bias_ref, - block_mask=block_mask_gpu if cfg.is_block_mask else None, - is_alibi=cfg.is_alibi, - padding=(seq_len_q_ref, seq_len_kv_ref) if cfg.is_padding else None, - left_bound=cfg.left_bound, - right_bound=cfg.right_bound, - diag_align=cfg.diag_align, - dropout_prob=cfg.dropout_prob, - dropout_mask=rng_dump_ref, - generate_stats=(cfg.is_infer == False), - ) - - if not cfg.is_infer: - o_ref, stats_ref = ret - else: - o_ref = ret - - if cfg.is_ragged: - o_gpu = convert_ragged_to_uniform(o_gpu, seq_len_q_gpu.detach()) - - err_count = 0 - - if cfg.is_padding: - # zero out padded region of the output for comparison - for i, m in enumerate(seq_len_q_ref): - o_ref[i, :, m:, :] = 0 - o_gpu[i, :, m:, :] = 0 - if cfg.is_infer == False: - if cudnn_version < "9.14.0": - stats_ref[i, :, m:, :] = 0 - stats_gpu[i, :, m:, :] = 0 - else: - stats_ref[i, :, m:, :] = -float("inf") - - diffs = int_cli_option(10, request, "--diffs") - - err_count += approx_equal(o_gpu, o_ref, o_sep, o_raw, atol=2e-2, rtol=2e-2, tag="o", disp_elems=diffs) - - if not cfg.is_infer: - err_count += approx_equal(stats_gpu, stats_ref, stats_sep, stats_raw, atol=2e-2, rtol=2e-2, tag="stats", disp_elems=diffs) - - inputs_ref = [q_ref, k_ref, v_ref] - if cfg.is_bias: - inputs_ref.append(bias_ref) - - [dQ_ref, dK_ref, dV_ref, *opt_refs] = list( - torch.autograd.grad(outputs=o_ref, inputs=inputs_ref, grad_outputs=dO_ref) - ) - - if cfg.is_bias: - dBias_ref = opt_refs.pop(0) - - if cfg.is_ragged: - dQ_gpu = convert_ragged_to_uniform(dQ_gpu, seq_len_q_gpu.detach()) - dK_gpu = convert_ragged_to_uniform(dK_gpu, seq_len_kv_gpu.detach()) - dV_gpu = convert_ragged_to_uniform(dV_gpu, seq_len_kv_gpu.detach()) - - if cfg.is_padding: - # zero out padded region of the output for comparison - for i, (m, n) in enumerate(zip(seq_len_q_ref, seq_len_kv_ref)): - dQ_ref[i, :, m:, :] = 0 - dQ_gpu[i, :, m:, :] = 0 - dK_ref[i, :, n:, :] = 0 - dK_gpu[i, :, n:, :] = 0 - dV_ref[i, :, n:, :] = 0 - dV_gpu[i, :, n:, :] = 0 - - torch.cuda.synchronize() - - err_count += approx_equal(dQ_gpu, dQ_ref, dQ_sep, dQ_raw, atol=2e-2, rtol=2e-2, tag="dQ", disp_elems=diffs) - err_count += approx_equal(dK_gpu, dK_ref, dK_sep, dK_raw, atol=2e-2 if cfg.data_type != torch.bfloat16 else 7e-2, rtol=2e-2, tag="dK", disp_elems=diffs) - err_count += approx_equal(dV_gpu, dV_ref, dV_sep, dV_raw, atol=2e-2 if cfg.data_type != torch.bfloat16 else 7e-2, rtol=2e-2, tag="dV", disp_elems=diffs) - if cfg.is_bias: - err_count += approx_equal(dBias_gpu, dBias_ref, dBias_sep, dBias_raw, atol=2e-2, rtol=2e-2, tag="dBias", disp_elems=diffs) - - if err_count != 0: - print("@@@@ Overall result: FAILED, disallowed mismatches") - pytest.fail("disallowed mismatches", pytrace=False) - else: - print("@@@@ Overall result: PASSED, everything looks good!") - - del workspace - del graph - del variant_pack - - if cfg.is_paged: - del container_k_gpu, container_v_gpu, page_table_k_gpu, page_table_v_gpu - if cfg.is_ragged: - del q_ragged_offset_gpu, k_ragged_offset_gpu, v_ragged_offset_gpu, o_ragged_offset_gpu - if cfg.is_dropout: - del seed_gpu, offset_gpu - del rng_dump_gpu - del rng_dump_ref - if cfg.is_padding: - del seq_len_q_gpu, seq_len_kv_gpu - del seq_len_q_ref, seq_len_kv_ref - - del q_gpu, k_gpu, v_gpu, o_gpu - if cfg.is_bias: - del bias_gpu - if not cfg.is_infer: - del dQ_gpu, dK_gpu, dV_gpu, dO_gpu, stats_gpu - if cfg.is_bias: - del dBias_gpu - - del q_ref, k_ref, v_ref, dO_ref, o_ref, stats_ref - if cfg.is_bias: - del dBias_ref, bias_ref - del dQ_ref, dK_ref, dV_ref - else: - del q_ref, k_ref, v_ref, o_ref - if cfg.is_bias: - del bias_ref - - del o_sep, o_raw - if not cfg.is_infer: - del dQ_sep, dQ_raw, dK_sep, dK_raw, dV_sep, dV_raw - del stats_sep, stats_raw - - torch.cuda.empty_cache() + self.cfg = ExecConfig() + + + def showConfig(self, test_no, request): + is_dryrun = request.config.option.dryrun + print() + print_section_begin("DRY-RUN" if is_dryrun else "") + print(f"#### Test #{test_no[0]} of {test_no[1]} at", datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "\n") + print(f"test_name = {request.node.name}") + print(f"platform_info = {self.gpu_arch} ({self.gpu_info}), cudnn_ver={self.cudnn_ver}") + print() + print(self.cfg.to_repro_cmd(request.module.__file__)) + print(flush=True) + @pytest.fixture(scope="package") def env_info(request): @@ -1185,11 +79,8 @@ def env_info(request): gpu_arch = f"SM_{gpu_type[0]}{gpu_type[1]}" gpu_info = f"{sm_count} SM-s, {gpu_name}" cudnn_ver = str(torch.backends.cudnn.version()) - blocked_file = str(request.path) - blocked_file = blocked_file[:-3] + ".block" - blocked_tests = fetch_blocked_tests(blocked_file, gpu_arch, cudnn_ver) - show_blocked_tests(blocked_tests, gpu_arch, cudnn_ver) + blocked_tests = fetch_blocked_tests(gpu_arch, cudnn_ver) return {"gpu_arch": gpu_arch, "gpu_info": gpu_info, "cudnn_ver": cudnn_ver, "blocked_tests": blocked_tests} @@ -1202,13 +93,11 @@ def env_info(request): # # ================================== # # L0 fprop tests # # ================================== -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_fwd_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1227,20 +116,18 @@ def test_sdpa_random_fwd_L0(env_info, test_no, request, cudnn_handle): is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 1, "full" : 1}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) + test.cfg = randomization_ctx(rng, data_seed, geom_seed) - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) -@pytest.mark.parametrize("test_no", tlist(num_tests=32, rng_seed=888), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=32, rng_seed=888), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_fwd_unified_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1259,10 +146,10 @@ def test_sdpa_random_fwd_unified_L0(env_info, test_no, request, cudnn_handle): is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 1, "full" : 1}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) - test.cfg.implementation = implementation_cli_option(cudnn.attention_implementation.UNIFIED, request, "--implementation") + test.cfg = randomization_ctx(rng, data_seed, geom_seed) + test.cfg.implementation = getattr(cudnn.attention_implementation, request.config.getoption("--implementation") or "", cudnn.attention_implementation.UNIFIED) - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) @@ -1271,13 +158,11 @@ def test_sdpa_random_fwd_unified_L0(env_info, test_no, request, cudnn_handle): # # L0 bprop tests # # ================================== -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=844), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=256, rng_seed=844), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_bwd_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1286,21 +171,21 @@ def test_sdpa_random_bwd_L0(env_info, test_no, request, cudnn_handle): # Create the randomization context within the test with RandomizationContext( - batches=RandomBatchSize(min=1, max=8, with_high_probability=[1,4]), + batches=RandomBatchSize(min=8, max=16), s_q_s_kv = RandomSequenceLength(s_q_min=1, s_q_max=1024, s_kv_min=1, s_kv_max=1024, s_q_distribution={"s_q=1":0, "s_q=s_kv":5, "s_q=random":10}), - d_qk_d_v=RandomHiddenDimSize(d_qk_min=1, d_qk_max=128, d_v_min=1, d_v_max=128, head_dim_distribution={"d_qk=d_v":1, "d_qk=random":1}, with_high_probability=[(64,64), (128,128), (192,128)]), + d_qk_d_v=RandomHiddenDimSize(d_qk_min=1, d_qk_max=192, d_v_min=1, d_v_max=128, head_dim_distribution={"d_qk=d_v":5, "d_qk=random":1}, with_high_probability=[(64,64), (128,128), (192,128)]), head_count=RandomHeadGenerator(min=1, max=8, head_group_options=(1, 4, 1)), data_type=RandomChoice({torch.float16 : 1, torch.bfloat16 : 2}), with_sliding_mask=SlidingWindowMaskGenerator(causal=10, left_window_only=5, right_window_only=5, band_around_diag=10, no_mask=10), diag_align=RandomChoice({cudnn.diagonal_alignment.TOP_LEFT : 1, cudnn.diagonal_alignment.BOTTOM_RIGHT : 1}), - is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 1, "full" : 1}), + is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 4, "full" : 1}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), - is_deterministic=RandomChoice({True : 1, False : 1}), + is_deterministic=RandomChoice({True : 3, False : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) + test.cfg = randomization_ctx(rng, data_seed, geom_seed) test.cfg.is_infer = False - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) @@ -1309,13 +194,11 @@ def test_sdpa_random_bwd_L0(env_info, test_no, request, cudnn_handle): # # L0 fprop tests with s_q=1 # # ================================== -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=111), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=128, rng_seed=111), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_sq1_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1334,20 +217,18 @@ def test_sdpa_random_sq1_L0(env_info, test_no, request, cudnn_handle): is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 0, "full" : 1}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) + test.cfg = randomization_ctx(rng, data_seed, geom_seed) - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) -@pytest.mark.parametrize("test_no", tlist(num_tests=32, rng_seed=111), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=32, rng_seed=111), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_sq1_unified_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1366,10 +247,10 @@ def test_sdpa_random_sq1_unified_L0(env_info, test_no, request, cudnn_handle): is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 0, "full" : 1}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) - test.cfg.implementation = implementation_cli_option(cudnn.attention_implementation.UNIFIED, request, "--implementation") + test.cfg = randomization_ctx(rng, data_seed, geom_seed) + test.cfg.implementation = getattr(cudnn.attention_implementation, request.config.getoption("--implementation") or "", cudnn.attention_implementation.UNIFIED) - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) @@ -1378,13 +259,11 @@ def test_sdpa_random_sq1_unified_L0(env_info, test_no, request, cudnn_handle): # # L0 lean attention, s_kv=513..2048 # # ===================================================== -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=222), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=128, rng_seed=222), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_lean_attn_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1403,20 +282,18 @@ def test_sdpa_random_lean_attn_L0(env_info, test_no, request, cudnn_handle): is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 1, "full" : 1}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) + test.cfg = randomization_ctx(rng, data_seed, geom_seed) - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=222), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=128, rng_seed=222), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_lean_attn_unified_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1435,10 +312,10 @@ def test_sdpa_random_lean_attn_unified_L0(env_info, test_no, request, cudnn_hand is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 1, "full" : 1}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) - test.cfg.implementation = implementation_cli_option(cudnn.attention_implementation.UNIFIED, request, "--implementation") + test.cfg = randomization_ctx(rng, data_seed, geom_seed) + test.cfg.implementation = getattr(cudnn.attention_implementation, request.config.getoption("--implementation") or "", cudnn.attention_implementation.UNIFIED) - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) @@ -1446,13 +323,11 @@ def test_sdpa_random_lean_attn_unified_L0(env_info, test_no, request, cudnn_hand # # L0 ragged tests # # ================================== -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_fwd_ragged_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1471,20 +346,18 @@ def test_sdpa_random_fwd_ragged_L0(env_info, test_no, request, cudnn_handle): is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 1, "padded" : 0, "full" : 0}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) + test.cfg = randomization_ctx(rng, data_seed, geom_seed) - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_fwd_ragged_unified_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1503,21 +376,19 @@ def test_sdpa_random_fwd_ragged_unified_L0(env_info, test_no, request, cudnn_han is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 1, "padded" : 0, "full" : 0}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) - test.cfg.implementation = implementation_cli_option(cudnn.attention_implementation.UNIFIED, request, "--implementation") + test.cfg = randomization_ctx(rng, data_seed, geom_seed) + test.cfg.implementation = getattr(cudnn.attention_implementation, request.config.getoption("--implementation") or "", cudnn.attention_implementation.UNIFIED) - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=256, rng_seed=888), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_bwd_ragged_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1526,21 +397,24 @@ def test_sdpa_random_bwd_ragged_L0(env_info, test_no, request, cudnn_handle): # Create the randomization context within the test with RandomizationContext( - batches=RandomBatchSize(min=1, max=8, with_high_probability=[1,4]), + batches=RandomBatchSize(min=8, max=16), s_q_s_kv = RandomSequenceLength(s_q_min=1, s_q_max=1024, s_kv_min=1, s_kv_max=1024, s_q_distribution={"s_q=1":0, "s_q=s_kv":5, "s_q=random":10}), - d_qk_d_v=RandomHiddenDimSize(d_qk_min=1, d_qk_max=128, d_v_min=1, d_v_max=128, head_dim_distribution={"d_qk=d_v":1, "d_qk=random":1}, with_high_probability=[(64,64), (128,128), (192,128)]), + d_qk_d_v=RandomHiddenDimSize(d_qk_min=1, d_qk_max=192, d_v_min=1, d_v_max=128, head_dim_distribution={"d_qk=d_v":5, "d_qk=random":1}, with_high_probability=[(64,64), (128,128), (192,128)]), head_count=RandomHeadGenerator(min=1, max=8, head_group_options=(1, 4, 1)), data_type=RandomChoice({torch.float16 : 1, torch.bfloat16 : 2}), with_sliding_mask=SlidingWindowMaskGenerator(causal=10, left_window_only=5, right_window_only=5, band_around_diag=10, no_mask=10), diag_align=RandomChoice({cudnn.diagonal_alignment.TOP_LEFT : 1, cudnn.diagonal_alignment.BOTTOM_RIGHT : 1}), is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 1, "padded" : 0, "full" : 0}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), + is_deterministic=RandomChoice({True : 3, False : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) + test.cfg = randomization_ctx(rng, data_seed, geom_seed) test.cfg.is_infer = False - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) + if request.node.name in test.blocked_tests: + pytest.skip(f"blocked test: {request.node.name}") exec_sdpa(test.cfg, request, cudnn_handle) @@ -1548,13 +422,11 @@ def test_sdpa_random_bwd_ragged_L0(env_info, test_no, request, cudnn_handle): # # L0 paged tests # # ================================== -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_fwd_paged_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1574,22 +446,20 @@ def test_sdpa_fwd_paged_L0(env_info, test_no, request, cudnn_handle): stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), block_size=RandomBlockSize(min=1, max=1024, with_high_probability=[1,32,128]), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) + test.cfg = randomization_ctx(rng, data_seed, geom_seed) test.cfg.is_paged = True test.cfg.implementation=cudnn.attention_implementation.COMPOSITE # FIXNOW - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) -@pytest.mark.parametrize("test_no", tlist(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=128, rng_seed=888), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_fwd_paged_unified_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1609,21 +479,23 @@ def test_sdpa_fwd_paged_unified_L0(env_info, test_no, request, cudnn_handle): stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), block_size=RandomBlockSize(min=1, max=1024, with_high_probability=[1,32,128]), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) + test.cfg = randomization_ctx(rng, data_seed, geom_seed) test.cfg.is_paged = True - test.cfg.implementation = implementation_cli_option(cudnn.attention_implementation.UNIFIED, request, "--implementation") + test.cfg.implementation = getattr(cudnn.attention_implementation, request.config.getoption("--implementation") or "", cudnn.attention_implementation.UNIFIED) - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) -@pytest.mark.parametrize("test_no", tlist(num_tests=32, rng_seed=888), ids=lambda p: f"test{p[0]}") +# # ================================== +# # L0 fprop block mask tests +# # ================================== + +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=32, rng_seed=888), ids=lambda p: f"test{p[0]}") @pytest.mark.L0 def test_sdpa_random_fwd_unified_block_mask_L0(env_info, test_no, request, cudnn_handle): - test = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - - print(f"test: {test} hash {abs(hash(test_no))}") + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) geom_seed = abs(hash(test_no)) data_seed = test_no[2] @@ -1642,14 +514,193 @@ def test_sdpa_random_fwd_unified_block_mask_L0(env_info, test_no, request, cudnn is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 0, "full" : 1}), stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), ) as randomization_ctx: - test.cfg = randomization_ctx(rng, data_seed) + test.cfg = randomization_ctx(rng, data_seed, geom_seed) test.cfg.is_block_mask = True test.cfg.implementation = cudnn.attention_implementation.UNIFIED - test.showConfig(test_no, request, reg_run=True) + test.showConfig(test_no, request) + + exec_sdpa(test.cfg, request, cudnn_handle) + +# # ================================== +# # L0 fprop bias tests +# # ================================== + +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=32, rng_seed=888), ids=lambda p: f"test{p[0]}") +@pytest.mark.L0 +def test_sdpa_random_fwd_bias_L0(env_info, test_no, request, cudnn_handle): + + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) + + geom_seed = abs(hash(test_no)) + data_seed = test_no[2] + + rng = random.Random(geom_seed) + + # Create the randomization context within the test + with RandomizationContext( + batches=RandomBatchSize(min=1, max=8, with_high_probability=[1,4]), + s_q_s_kv = RandomSequenceLength(s_q_min=1, s_q_max=1024, s_kv_min=1, s_kv_max=1024, s_q_distribution={"s_q=1":0, "s_q=s_kv":5, "s_q=random":10}), + d_qk_d_v=RandomHiddenDimSize(d_qk_min=1, d_qk_max=128, d_v_min=1, d_v_max=128, head_dim_distribution={"d_qk=d_v":1, "d_qk=random":1}, with_high_probability=[(64,64), (128,128), (192,128)]), + head_count=RandomHeadGenerator(min=1, max=8, head_group_options=(1, 4, 1)), + data_type=RandomChoice({torch.float16 : 1, torch.bfloat16 : 2}), + with_sliding_mask=SlidingWindowMaskGenerator(no_mask=1), + diag_align=RandomChoice({cudnn.diagonal_alignment.TOP_LEFT : 1}), + is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 1, "full" : 1}), + stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), + is_bias=RandomChoice({True : 1}), + ) as randomization_ctx: + test.cfg = randomization_ctx(rng, data_seed, geom_seed) + + test.showConfig(test_no, request) exec_sdpa(test.cfg, request, cudnn_handle) +# # ================================== +# # L0 bprop bias tests +# # ================================== + +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=32, rng_seed=888), ids=lambda p: f"test{p[0]}") +@pytest.mark.L0 +def test_sdpa_random_bwd_bias_L0(env_info, test_no, request, cudnn_handle): + + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) + + geom_seed = abs(hash(test_no)) + data_seed = test_no[2] + + rng = random.Random(geom_seed) + + # Create the randomization context within the test + with RandomizationContext( + batches=RandomBatchSize(min=8, max=16), + s_q_s_kv = RandomSequenceLength(s_q_min=1, s_q_max=1024, s_kv_min=1, s_kv_max=1024, s_q_distribution={"s_q=1":0, "s_q=s_kv":5, "s_q=random":10}), + d_qk_d_v=RandomHiddenDimSize(d_qk_min=1, d_qk_max=192, d_v_min=1, d_v_max=128, head_dim_distribution={"d_qk=d_v":5, "d_qk=random":1}, with_high_probability=[(64,64), (128,128), (192,128)]), + head_count=RandomHeadGenerator(min=1, max=8, head_group_options=(1, 4, 1)), + data_type=RandomChoice({torch.float16 : 1, torch.bfloat16 : 2}), + with_sliding_mask=SlidingWindowMaskGenerator(no_mask=10), + diag_align=RandomChoice({cudnn.diagonal_alignment.TOP_LEFT : 1}), + is_q_ragged_or_padded_or_full=RandomChoice({"ragged" : 0, "padded" : 4, "full" : 1}), + stats_layout=RandomChoice({"ragged" : 0, "full" : 0, "disabled" : 1}), + is_bias=RandomChoice({True : 1}), + ) as randomization_ctx: + test.cfg = randomization_ctx(rng, data_seed, geom_seed) + + test.cfg.is_infer = False + test.showConfig(test_no, request) + + exec_sdpa(test.cfg, request, cudnn_handle) + + +# # ================================== +# # L0 FP8 fprop tests +# # ================================== + +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=128, rng_seed=999), ids=lambda p: f"test{p[0]}") +@pytest.mark.L0 +def test_sdpa_fp8_fwd_L0(env_info, test_no, request, cudnn_handle): + + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) + + geom_seed = abs(hash(test_no)) + data_seed = test_no[2] + + rng = random.Random(geom_seed) + + with RandomizationContext( + batches=RandomBatchSize(min=1, max=4, with_high_probability=[1, 2]), + s_q_s_kv=RandomSequenceLength(s_q_min=1, s_q_max=256, s_kv_min=64, s_kv_max=1024, s_q_distribution={"s_q=1": 3, "s_q=s_kv": 5, "s_q=random": 2}), + d_qk_d_v=RandomHiddenDimSize(d_qk_min=64, d_qk_max=192, d_v_min=64, d_v_max=128, head_dim_distribution={"d_qk=d_v": 2, "d_qk=random": 1}, with_high_probability=[(64, 64), (128, 128), (192, 128)]), + head_count=RandomHeadGenerator(min=1, max=16, head_group_options=(1, 5, 2)), + data_type=RandomChoice({torch.float8_e4m3fn: 2, torch.float8_e5m2: 1}), + output_type=RandomChoice({torch.float8_e4m3fn: 1, torch.float8_e5m2: 1, torch.float16: 2}), + with_sliding_mask=SlidingWindowMaskGenerator(no_mask=10), + diag_align=RandomChoice({cudnn.diagonal_alignment.TOP_LEFT: 1}), + is_q_ragged_or_padded_or_full=RandomChoice({"ragged": 0, "padded": 0, "full": 1}), + stats_layout=RandomChoice({"disabled": 1}), + ) as randomization_ctx: + test.cfg = randomization_ctx(rng, data_seed, geom_seed) + test.showConfig(test_no, request) + + if request.node.name in test.blocked_tests: + pytest.skip(f"blocked test: {request.node.name}") + exec_sdpa_fp8(test.cfg, request, cudnn_handle) + + +# # ================================== +# # L0 FP8 bprop tests +# # ================================== + +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=64, rng_seed=998), ids=lambda p: f"test{p[0]}") +@pytest.mark.L0 +def test_sdpa_fp8_bwd_L0(env_info, test_no, request, cudnn_handle): + + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) + + geom_seed = abs(hash(test_no)) + data_seed = test_no[2] + + rng = random.Random(geom_seed) + + with RandomizationContext( + batches=RandomBatchSize(min=1, max=4, with_high_probability=[1, 2]), + s_q_s_kv=RandomSequenceLength(s_q_min=64, s_q_max=256, s_kv_min=64, s_kv_max=256, s_q_distribution={"s_q=1": 0, "s_q=s_kv": 5, "s_q=random": 5}), + d_qk_d_v=RandomHiddenDimSize(d_qk_min=64, d_qk_max=128, d_v_min=64, d_v_max=128, head_dim_distribution={"d_qk=d_v": 1, "d_qk=random": 0}, with_high_probability=[(64, 64), (128, 128)]), + head_count=RandomHeadGenerator(min=1, max=8, head_group_options=(1, 4, 1)), + data_type=RandomChoice({torch.float8_e4m3fn: 1}), + output_type=RandomChoice({torch.float8_e4m3fn: 1, torch.float16: 1}), + with_sliding_mask=SlidingWindowMaskGenerator(no_mask=10), + diag_align=RandomChoice({cudnn.diagonal_alignment.TOP_LEFT: 1}), + is_q_ragged_or_padded_or_full=RandomChoice({"ragged": 0, "padded": 0, "full": 1}), + stats_layout=RandomChoice({"disabled": 1}), + is_deterministic=RandomChoice({True: 1, False: 1}), + ) as randomization_ctx: + test.cfg = randomization_ctx(rng, data_seed, geom_seed) + + test.cfg.is_infer = False + test.showConfig(test_no, request) + + if request.node.name in test.blocked_tests: + pytest.skip(f"blocked test: {request.node.name}") + exec_sdpa_fp8(test.cfg, request, cudnn_handle) + + +# # ================================== +# # L0 FP8 paged attention tests +# # ================================== + +@pytest.mark.parametrize("test_no", generate_test_seeds(num_tests=32, rng_seed=997), ids=lambda p: f"test{p[0]}") +@pytest.mark.L0 +def test_sdpa_fp8_fwd_paged_L0(env_info, test_no, request, cudnn_handle): + + test = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) + + geom_seed = abs(hash(test_no)) + data_seed = test_no[2] + + rng = random.Random(geom_seed) + + with RandomizationContext( + batches=RandomBatchSize(min=1, max=4, with_high_probability=[1, 2]), + s_q_s_kv=RandomSequenceLength(s_q_min=64, s_q_max=256, s_kv_min=64, s_kv_max=512, s_q_distribution={"s_q=1": 0, "s_q=s_kv": 5, "s_q=random": 5}), + d_qk_d_v=RandomHiddenDimSize(d_qk_min=64, d_qk_max=128, d_v_min=64, d_v_max=128, head_dim_distribution={"d_qk=d_v": 1, "d_qk=random": 0}, with_high_probability=[(64, 64), (128, 128)]), + head_count=RandomHeadGenerator(min=1, max=4, head_group_options=(1, 2, 0)), + data_type=RandomChoice({torch.float8_e4m3fn: 2, torch.float8_e5m2: 1}), + output_type=RandomChoice({torch.float8_e4m3fn: 1, torch.float8_e5m2: 1, torch.float16: 1}), + with_sliding_mask=SlidingWindowMaskGenerator(no_mask=10), + diag_align=RandomChoice({cudnn.diagonal_alignment.TOP_LEFT: 1}), + is_q_ragged_or_padded_or_full=RandomChoice({"ragged": 0, "padded": 1, "full": 0}), + stats_layout=RandomChoice({"disabled": 1}), + block_size=RandomBlockSize(min=16, max=128, with_high_probability=[16, 32, 64]), + ) as randomization_ctx: + test.cfg = randomization_ctx(rng, data_seed, geom_seed) + test.cfg.is_paged = True + test.showConfig(test_no, request) + + if request.node.name in test.blocked_tests: + pytest.skip(f"blocked test: {request.node.name}") + exec_sdpa_fp8(test.cfg, request, cudnn_handle) + # # =================== # # Single repro test @@ -1662,30 +713,9 @@ def test_sdpa_random_fwd_unified_block_mask_L0(env_info, test_no, request, cudnn @pytest.mark.L3 @pytest.mark.L4 def test_repro(env_info, request, cudnn_handle): - repro_str = request.config.getoption("--repro") - cfg = testConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) - print(f"repro_str: {repro_str}") - - # Parse the dictionary string and reconstruct the exec_cfg object import ast - repro_dict = ast.literal_eval(repro_str) - - # Convert integer enum values back to enum objects - if 'diag_align' in repro_dict and repro_dict['diag_align'] is not None: - repro_dict['diag_align'] = cudnn.diagonal_alignment(repro_dict['diag_align']) - if 'implementation' in repro_dict and repro_dict['implementation'] is not None: - repro_dict['implementation'] = getattr(cudnn.attention_implementation, repro_dict['implementation']) - # Convert string dtype back to torch dtype - if 'data_type' in repro_dict and repro_dict['data_type'] is not None: - if 'torch.float16' in repro_dict['data_type']: - repro_dict['data_type'] = torch.float16 - elif 'torch.bfloat16' in repro_dict['data_type']: - repro_dict['data_type'] = torch.bfloat16 - elif 'torch.float32' in repro_dict['data_type']: - repro_dict['data_type'] = torch.float32 - - cfg.cfg = exec_cfg(**repro_dict) - print(f"cfg.cfg: {cfg.cfg}") - - cfg.showConfig((1,1), request, False) + repro_str = request.config.getoption("--repro") + cfg = SDPATestConfig(**env_info, implementation=cudnn.attention_implementation.AUTO) + cfg.cfg = ExecConfig.deserialize(ast.literal_eval(repro_str)) + cfg.showConfig((1,1), request) exec_sdpa(cfg.cfg, request, cudnn_handle) diff --git a/test/python/test_rmsnorm.py b/test/python/test_rmsnorm.py index b87195d8..43d1997f 100644 --- a/test/python/test_rmsnorm.py +++ b/test/python/test_rmsnorm.py @@ -21,9 +21,7 @@ def __init__(self, dim: int = -1, eps: float = 1e-5) -> None: self.eps = eps self.dim = dim - def forward( - self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None - ) -> torch.Tensor: + def forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor: # NOTE: the original RMSNorm paper implementation is not equivalent norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) inv_var = torch.rsqrt(norm_x.float() + self.eps) @@ -38,12 +36,7 @@ def forward( input_type_options = [torch.float16, torch.bfloat16] bias_options = [True, False] -all_options = [ - elem - for elem in itertools.product( - *[embedding_dim_options, input_type_options, bias_options] - ) -] +all_options = [elem for elem in itertools.product(*[embedding_dim_options, input_type_options, bias_options])] @pytest.fixture(params=all_options) @@ -66,17 +59,9 @@ def test_rmsnorm(param_extract, cudnn_handle): epsilon_value = 1e-3 - x_gpu = ( - 2 * torch.randn(N, C, H, W, requires_grad=True, device="cuda", dtype=input_type) - - 1.25 - ) - scale_gpu = ( - 3 * torch.randn(1, C, H, W, requires_grad=True, device="cuda", dtype=input_type) - - 2.75 - ) - bias_gpu = torch.randn( - 1, C, H, W, requires_grad=True, device="cuda", dtype=input_type - ) + x_gpu = 2 * torch.randn(N, C, H, W, requires_grad=True, device="cuda", dtype=input_type) - 1.25 + scale_gpu = 3 * torch.randn(1, C, H, W, requires_grad=True, device="cuda", dtype=input_type) - 2.75 + bias_gpu = torch.randn(1, C, H, W, requires_grad=True, device="cuda", dtype=input_type) epsilon_cpu = torch.full( (1, 1, 1, 1), epsilon_value, @@ -88,9 +73,7 @@ def test_rmsnorm(param_extract, cudnn_handle): print("Running reference") model = RMSNorm(eps=epsilon_value, dim=(1, 2, 3)).float() - Y_expected, inv_var_expected = model( - x_gpu, scale_gpu, bias_gpu if has_bias else None - ) + Y_expected, inv_var_expected = model(x_gpu, scale_gpu, bias_gpu if has_bias else None) print("Building cudnn graph") @@ -135,9 +118,7 @@ def test_rmsnorm(param_extract, cudnn_handle): Y_actual = torch.empty_like(x_gpu) inv_var_actual = torch.empty_like(inv_var_expected) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) print("Executing cudnn graph") graph.execute( @@ -208,9 +189,7 @@ def test_rmsnorm(param_extract, cudnn_handle): DScale_actual = torch.empty_like(scale_gpu) Dbias_actual = torch.empty_like(bias_gpu) - workspace = torch.empty( - bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8) print("Executing cudnn bwd_graph") bwd_graph.execute( diff --git a/test/python/test_sdpa_chunked_prefill.py b/test/python/test_sdpa_chunked_prefill.py new file mode 100644 index 00000000..492906e8 --- /dev/null +++ b/test/python/test_sdpa_chunked_prefill.py @@ -0,0 +1,619 @@ +""" +Test for SDPA with chunked prefill using THD (Token-Head-Dimension) layout. + +Chunked prefill processes long sequences by splitting them into smaller chunks +to reduce memory usage. For a sequence of 4096 tokens with chunk_size=1024: +- Chunk 0: Q[0:1024] attends to K[0:1024], V[0:1024] +- Chunk 1: Q[1024:2048] attends to K[0:2048], V[0:2048] +- Chunk 2: Q[2048:3072] attends to K[0:3072], V[0:3072] +- Chunk 3: Q[3072:4096] attends to K[0:4096], V[0:4096] + +THD layout is a ragged/packed format where: +- Q: [chunk_tokens, num_heads, head_dim] - packed Q tensor for current chunk +- K/V: BHSD format [batch, heads, accumulated_seq_len, head_dim] +- O: [chunk_tokens, num_heads, head_dim] - packed output tensor + +The recommended way to run tests: +> pytest -vv -s -rA test_sdpa_chunked_prefill.py +""" + +import cudnn +import pytest +import torch +import math +from looseversion import LooseVersion +from dataclasses import dataclass +from typing import List, Optional, Tuple +from test_utils import torch_fork_set_rng +from enum import Enum, auto + + +class UIDs(Enum): + Q_UID = auto() + K_UID = auto() + V_UID = auto() + O_UID = auto() + RAGGED_Q_UID = auto() + RAGGED_O_UID = auto() + ACTUAL_SEQ_LENS_Q_UID = auto() + ACTUAL_SEQ_LENS_KV_UID = auto() + + +@dataclass +class ChunkedPrefillConfig: + batch_size: int + num_heads_q: int + num_heads_k: int + num_heads_v: int + head_dim_qk: int + head_dim_v: int + total_seq_len: int + chunk_size: int + dtype: torch.dtype = torch.bfloat16 + is_causal: bool = False + attn_scale: Optional[float] = None + + def __post_init__(self): + if self.attn_scale is None: + self.attn_scale = 1.0 / math.sqrt(self.head_dim_qk) + assert self.total_seq_len % self.chunk_size == 0 + + @property + def num_chunks(self) -> int: + return self.total_seq_len // self.chunk_size + + +def convert_to_cudnn_type(torch_type): + type_map = { + torch.float16: cudnn.data_type.HALF, + torch.bfloat16: cudnn.data_type.BFLOAT16, + torch.float32: cudnn.data_type.FLOAT, + torch.int32: cudnn.data_type.INT32, + torch.int64: cudnn.data_type.INT64, + } + return type_map[torch_type] + + +def compute_ragged_offsets(seq_lens, num_heads, head_dim): + batch_size = seq_lens.shape[0] + elements_per_batch = seq_lens * num_heads * head_dim + ragged_offset = torch.zeros(batch_size + 1, dtype=torch.int64, device=seq_lens.device) + ragged_offset[1:] = torch.cumsum(elements_per_batch, dim=0) + return ragged_offset.view(-1, 1, 1, 1) + + +def create_thd_tensor(seq_lens, num_heads, head_dim, dtype, rng, mean=0.0, std=1.0): + total_tokens = int(seq_lens.sum().item()) + tensor = torch.empty(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda") + tensor.normal_(mean=mean, std=std, generator=rng) + ragged_offset = compute_ragged_offsets(seq_lens, num_heads, head_dim) + return tensor, ragged_offset + + +def create_bhsd_tensor(batch_size, num_heads, max_seq_len, head_dim, dtype, rng, mean=0.0, std=1.0): + total_elements = batch_size * max_seq_len * num_heads * head_dim + storage = torch.empty(total_elements, dtype=dtype, device="cuda") + storage.normal_(mean=mean, std=std, generator=rng) + strides = (max_seq_len * num_heads * head_dim, head_dim, num_heads * head_dim, 1) + return torch.as_strided(storage, (batch_size, num_heads, max_seq_len, head_dim), strides) + + +def thd_to_bhsd(thd_tensor, seq_lens, max_seq_len): + batch_size = seq_lens.shape[0] + _, num_heads, head_dim = thd_tensor.shape + storage = torch.zeros(batch_size, max_seq_len, num_heads, head_dim, dtype=thd_tensor.dtype, device=thd_tensor.device) + offset = 0 + for i in range(batch_size): + seq_len = int(seq_lens[i].item()) + storage[i, :seq_len, :, :] = thd_tensor[offset : offset + seq_len] + offset += seq_len + return storage.permute(0, 2, 1, 3) + + +def bhsd_to_thd(bhsd_tensor, seq_lens): + batch_size = seq_lens.shape[0] + _, num_heads, _, head_dim = bhsd_tensor.shape + total_tokens = int(seq_lens.sum().item()) + thd_tensor = torch.empty(total_tokens, num_heads, head_dim, dtype=bhsd_tensor.dtype, device=bhsd_tensor.device) + bshd_tensor = bhsd_tensor.permute(0, 2, 1, 3) + offset = 0 + for i in range(batch_size): + seq_len = int(seq_lens[i].item()) + thd_tensor[offset : offset + seq_len] = bshd_tensor[i, :seq_len, :, :] + offset += seq_len + return thd_tensor + + +def compute_sdpa_reference_with_offset(q_bhsd, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, attn_scale, is_causal=False, causal_offset=0): + batch_size, num_heads_q, max_seq_q, head_dim_qk = q_bhsd.shape + _, num_heads_k, max_seq_kv, _ = k_bhsd.shape + _, num_heads_v, _, head_dim_v = v_bhsd.shape + + q = q_bhsd.to(dtype=torch.float32) + k = k_bhsd.to(dtype=torch.float32) + v = v_bhsd.to(dtype=torch.float32) + + if num_heads_q != num_heads_k: + k = k.unsqueeze(2).expand(-1, -1, num_heads_q // num_heads_k, -1, -1).reshape(batch_size, num_heads_q, max_seq_kv, head_dim_qk) + if num_heads_q != num_heads_v: + v = v.unsqueeze(2).expand(-1, -1, num_heads_q // num_heads_v, -1, -1).reshape(batch_size, num_heads_q, max_seq_kv, head_dim_v) + + scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * attn_scale + + device = q.device + q_mask = torch.zeros(batch_size, 1, max_seq_q, 1, dtype=torch.bool, device=device) + kv_mask = torch.zeros(batch_size, 1, 1, max_seq_kv, dtype=torch.bool, device=device) + for i in range(batch_size): + q_mask[i, :, seq_len_q[i] :, :] = True + kv_mask[i, :, :, seq_len_kv[i] :] = True + + scores = scores.masked_fill(kv_mask, float("-inf")) + + if is_causal: + # For chunked prefill, Q position q (in chunk) corresponds to global position (causal_offset + q) + # Q can attend to K[0:causal_offset+q+1], so mask K[k] when k > causal_offset + q + # This means k - q > causal_offset, or k - q >= causal_offset + 1 + # triu_(diagonal=d) sets True where (col - row) >= d + causal_mask = torch.ones(max_seq_q, max_seq_kv, dtype=torch.bool, device=device) + causal_mask.triu_(diagonal=1 + causal_offset) + scores = scores.masked_fill(causal_mask, float("-inf")) + + attn_weights = torch.softmax(scores, dim=-1) + attn_weights = attn_weights.masked_fill(q_mask, 0.0) + output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, v) + + output_mask = torch.zeros(batch_size, 1, max_seq_q, 1, dtype=torch.bool, device=device) + for i in range(batch_size): + output_mask[i, :, seq_len_q[i] :, :] = True + return output.masked_fill(output_mask, 0.0) + + +def compute_chunked_prefill_reference(q_bhsd, k_bhsd, v_bhsd, config, attn_scale): + batch_size, num_heads_q, total_seq, _ = q_bhsd.shape + head_dim_v = v_bhsd.shape[3] + chunk_size = config.chunk_size + num_chunks = config.num_chunks + device = q_bhsd.device + + output = torch.zeros(batch_size, num_heads_q, total_seq, head_dim_v, dtype=torch.float32, device=device) + + for chunk_idx in range(num_chunks): + q_start = chunk_idx * chunk_size + q_end = q_start + chunk_size + kv_end = q_end + + q_chunk = q_bhsd[:, :, q_start:q_end, :] + k_chunk = k_bhsd[:, :, :kv_end, :] + v_chunk = v_bhsd[:, :, :kv_end, :] + + seq_len_q = torch.full((batch_size,), chunk_size, dtype=torch.int32, device=device) + seq_len_kv = torch.full((batch_size,), kv_end, dtype=torch.int32, device=device) + + o_chunk = compute_sdpa_reference_with_offset( + q_chunk, k_chunk, v_chunk, seq_len_q, seq_len_kv, attn_scale, is_causal=config.is_causal, causal_offset=q_start + ) + output[:, :, q_start:q_end, :] = o_chunk + + return output + + +graph_cache = {} + + +def build_cudnn_sdpa_chunk_graph(cudnn_handle, batch_size, h_q, h_k, h_v, d_qk, d_v, chunk_size, kv_seq_len, dtype, attn_scale, is_causal, causal_offset=0): + cudnn_dtype = convert_to_cudnn_type(dtype) + cache_key = (batch_size, h_q, h_k, h_v, d_qk, d_v, chunk_size, kv_seq_len, is_causal, causal_offset) + + if cache_key in graph_cache: + return graph_cache[cache_key] + + graph = cudnn.pygraph( + io_data_type=cudnn_dtype, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=cudnn_handle, + is_dynamic_shape_enabled=True, + ) + + q = graph.tensor(dim=(batch_size, h_q, chunk_size, d_qk), stride=(h_q * d_qk, d_qk, h_q * d_qk, 1), data_type=cudnn_dtype, name="Q", uid=UIDs.Q_UID.value) + q_ragged = graph.tensor( + dim=(batch_size + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64, name="Q_ragged_offset", uid=UIDs.RAGGED_Q_UID.value + ) + q.set_ragged_offset(q_ragged) + + k = graph.tensor( + dim=(batch_size, h_k, kv_seq_len, d_qk), stride=(h_k * kv_seq_len * d_qk, d_qk, h_k * d_qk, 1), data_type=cudnn_dtype, name="K", uid=UIDs.K_UID.value + ) + v = graph.tensor( + dim=(batch_size, h_v, kv_seq_len, d_v), stride=(h_v * kv_seq_len * d_v, d_v, h_v * d_v, 1), data_type=cudnn_dtype, name="V", uid=UIDs.V_UID.value + ) + + seq_len_q_tensor = graph.tensor( + dim=(batch_size, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32, name="seq_len_q", uid=UIDs.ACTUAL_SEQ_LENS_Q_UID.value + ) + seq_len_kv_tensor = graph.tensor( + dim=(batch_size, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32, name="seq_len_kv", uid=UIDs.ACTUAL_SEQ_LENS_KV_UID.value + ) + + # For chunked prefill with causal masking, use diagonal_band_right_bound to shift the causal diagonal + # right_bound = causal_offset means Q[i] can attend to K[0:causal_offset+i+1] + # This correctly handles chunks where Q positions represent later positions in the full sequence + o, stats = graph.sdpa( + name="sdpa_chunk", + q=q, + k=k, + v=v, + attn_scale=attn_scale, + use_padding_mask=True, + seq_len_q=seq_len_q_tensor, + seq_len_kv=seq_len_kv_tensor, + diagonal_band_right_bound=causal_offset if is_causal else None, + generate_stats=False, + ) + + o.set_output(True).set_dim((batch_size, h_q, chunk_size, d_v)).set_stride((h_q * d_v, d_v, h_q * d_v, 1)).set_data_type(cudnn_dtype) + o.set_uid(UIDs.O_UID.value) + + o_ragged = graph.tensor( + dim=(batch_size + 1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT64, name="O_ragged_offset", uid=UIDs.RAGGED_O_UID.value + ) + o.set_ragged_offset(o_ragged) + + try: + graph.validate() + except cudnn.cudnnGraphNotSupportedError as e: + pytest.skip(f"Graph not supported: {e}") + + try: + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + except cudnn.cudnnGraphNotSupportedError as e: + pytest.skip(f"Graph not supported after validation: {e}") + + graph_cache[cache_key] = graph + return graph + + +def execute_cudnn_sdpa_chunk( + cudnn_handle, graph, q_chunk, k_chunk, v_chunk, seq_len_q, seq_len_kv, q_ragged_offset, o_ragged_offset, dtype, h_q, d_qk, d_v, batch_size, chunk_size +): + total_q_tokens = q_chunk.shape[0] + o_chunk = torch.empty(total_q_tokens, h_q, d_v, dtype=dtype, device="cuda") + + seq_len_q_4d = seq_len_q.view(-1, 1, 1, 1) + seq_len_kv_4d = seq_len_kv.view(-1, 1, 1, 1) + + variant_pack = { + UIDs.Q_UID.value: q_chunk, + UIDs.RAGGED_Q_UID.value: q_ragged_offset, + UIDs.K_UID.value: k_chunk, + UIDs.V_UID.value: v_chunk, + UIDs.ACTUAL_SEQ_LENS_Q_UID.value: seq_len_q_4d, + UIDs.ACTUAL_SEQ_LENS_KV_UID.value: seq_len_kv_4d, + UIDs.O_UID.value: o_chunk, + UIDs.RAGGED_O_UID.value: o_ragged_offset, + } + + workspace = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda") + cudnn.set_stream(handle=cudnn_handle, stream=torch.cuda.current_stream().cuda_stream) + + q_chunk_shape = (batch_size, h_q, chunk_size, d_qk) + o_chunk_shape = (batch_size, h_q, chunk_size, d_v) + + override_uids = [ + UIDs.Q_UID.value, + UIDs.RAGGED_Q_UID.value, + UIDs.K_UID.value, + UIDs.V_UID.value, + UIDs.ACTUAL_SEQ_LENS_Q_UID.value, + UIDs.ACTUAL_SEQ_LENS_KV_UID.value, + UIDs.O_UID.value, + UIDs.RAGGED_O_UID.value, + ] + override_shapes = [ + q_chunk_shape, + q_ragged_offset.shape, + k_chunk.shape, + v_chunk.shape, + seq_len_q_4d.shape, + seq_len_kv_4d.shape, + o_chunk_shape, + o_ragged_offset.shape, + ] + override_strides = [ + q_chunk.stride(), + q_ragged_offset.stride(), + k_chunk.stride(), + v_chunk.stride(), + seq_len_q_4d.stride(), + seq_len_kv_4d.stride(), + o_chunk.stride(), + o_ragged_offset.stride(), + ] + + graph.execute(variant_pack, workspace, handle=cudnn_handle, override_uids=override_uids, override_shapes=override_shapes, override_strides=override_strides) + torch.cuda.synchronize() + return o_chunk + + +def create_bhsd_view(tensor, batch_size, num_heads, seq_len, head_dim): + if tensor.shape == (batch_size, num_heads, seq_len, head_dim): + bshd = tensor.permute(0, 2, 1, 3).contiguous() + strides = (seq_len * num_heads * head_dim, head_dim, num_heads * head_dim, 1) + return torch.as_strided(bshd.view(-1), (batch_size, num_heads, seq_len, head_dim), strides) + storage = tensor.contiguous().view(-1) + strides = (seq_len * num_heads * head_dim, head_dim, num_heads * head_dim, 1) + return torch.as_strided(storage, (batch_size, num_heads, seq_len, head_dim), strides) + + +def extract_thd_chunk(thd_tensor, batch_size, total_seq_len, chunk_idx, chunk_size): + """ + Extract a chunk from THD tensor for all batches. + + THD layout packs all tokens of batch 0 first, then batch 1, etc. + For chunk_idx, we need positions [chunk_idx*chunk_size : (chunk_idx+1)*chunk_size] + from each batch. + + Args: + thd_tensor: [total_tokens, num_heads, head_dim] - packed tensor + batch_size: number of batches + total_seq_len: sequence length per batch + chunk_idx: which chunk to extract (0-indexed) + chunk_size: size of each chunk + + Returns: + chunk: [batch_size * chunk_size, num_heads, head_dim] + """ + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + + chunks = [] + for b in range(batch_size): + batch_offset = b * total_seq_len + chunks.append(thd_tensor[batch_offset + chunk_start : batch_offset + chunk_end, :, :]) + + return torch.cat(chunks, dim=0) + + +def store_thd_chunk(o_full_thd, o_chunk, batch_size, total_seq_len, chunk_idx, chunk_size): + """ + Store a chunk back into the full THD output tensor. + + Args: + o_full_thd: [total_tokens, num_heads, head_dim] - output tensor to fill + o_chunk: [batch_size * chunk_size, num_heads, head_dim] - chunk output + batch_size: number of batches + total_seq_len: sequence length per batch + chunk_idx: which chunk (0-indexed) + chunk_size: size of each chunk + """ + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + + for b in range(batch_size): + batch_offset = b * total_seq_len + chunk_offset = b * chunk_size + o_full_thd[batch_offset + chunk_start : batch_offset + chunk_end, :, :] = o_chunk[chunk_offset : chunk_offset + chunk_size, :, :] + + +def execute_chunked_prefill_cudnn(cudnn_handle, config, q_full_thd, k_full_bhsd, v_full_bhsd): + batch_size, chunk_size, num_chunks = config.batch_size, config.chunk_size, config.num_chunks + total_seq_len = config.total_seq_len + h_q, h_k, h_v, d_qk, d_v, dtype = config.num_heads_q, config.num_heads_k, config.num_heads_v, config.head_dim_qk, config.head_dim_v, config.dtype + + o_full_thd = torch.empty(q_full_thd.shape[0], h_q, d_v, dtype=dtype, device="cuda") + + for chunk_idx in range(num_chunks): + print(f" Processing chunk {chunk_idx + 1}/{num_chunks}...") + kv_end = (chunk_idx + 1) * chunk_size + + # Extract Q chunk from THD tensor (properly handling batch layout) + q_chunk = extract_thd_chunk(q_full_thd, batch_size, total_seq_len, chunk_idx, chunk_size) + + # K/V up to current chunk end + k_chunk_bhsd = create_bhsd_view(k_full_bhsd[:, :, :kv_end, :].contiguous(), batch_size, h_k, kv_end, d_qk) + v_chunk_bhsd = create_bhsd_view(v_full_bhsd[:, :, :kv_end, :].contiguous(), batch_size, h_v, kv_end, d_v) + + seq_len_q = torch.full((batch_size,), chunk_size, dtype=torch.int32, device="cuda") + seq_len_kv = torch.full((batch_size,), kv_end, dtype=torch.int32, device="cuda") + q_ragged_offset = compute_ragged_offsets(seq_len_q, h_q, d_qk) + o_ragged_offset = compute_ragged_offsets(seq_len_q, h_q, d_v) + + causal_offset = chunk_idx * chunk_size if config.is_causal else 0 + graph = build_cudnn_sdpa_chunk_graph( + cudnn_handle, batch_size, h_q, h_k, h_v, d_qk, d_v, chunk_size, kv_end, dtype, config.attn_scale, config.is_causal, causal_offset + ) + o_chunk = execute_cudnn_sdpa_chunk( + cudnn_handle, + graph, + q_chunk, + k_chunk_bhsd, + v_chunk_bhsd, + seq_len_q, + seq_len_kv, + q_ragged_offset, + o_ragged_offset, + dtype, + h_q, + d_qk, + d_v, + batch_size, + chunk_size, + ) + + # Store output chunk back into full THD tensor + store_thd_chunk(o_full_thd, o_chunk, batch_size, total_seq_len, chunk_idx, chunk_size) + + return o_full_thd + + +def compare_outputs(output_gpu, output_ref, atol=0.02, rtol=0.02, tag="output"): + actual, expected = output_gpu.float(), output_ref.float() + mismatches = torch.where(~torch.isclose(actual, expected, rtol=rtol, atol=atol)) + mismatch_cnt = mismatches[0].numel() + if mismatch_cnt > 0: + print(f"\n{tag}: {mismatch_cnt:,} mismatches ({100 * mismatch_cnt / actual.numel():.2f}%)") + for idx in range(min(10, mismatch_cnt)): + pos = tuple(m[idx].item() for m in mismatches) + print(f" idx{pos}: gpu={actual[pos]:+.6e}, ref={expected[pos]:+.6e}, diff={actual[pos] - expected[pos]:+.2e}") + else: + print(f"{tag}: All values match within tolerance") + return mismatch_cnt + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=42) +def test_chunked_prefill_basic(cudnn_handle): + if LooseVersion(cudnn.backend_version_string()) < "9.10.0": + pytest.skip("THD layout requires cuDNN 9.10.0+") + if torch.cuda.get_device_capability()[0] < 8: + pytest.skip("Requires SM80+") + + print("\n" + "=" * 80 + "\nTest: Chunked Prefill (non-causal)\n" + "=" * 80) + config = ChunkedPrefillConfig( + batch_size=2, + num_heads_q=8, + num_heads_k=8, + num_heads_v=8, + head_dim_qk=128, + head_dim_v=128, + total_seq_len=4096, + chunk_size=1024, + dtype=torch.bfloat16, + is_causal=False, + ) + rng = torch.Generator(device="cuda").manual_seed(42) + seq_lens = torch.full((config.batch_size,), config.total_seq_len, dtype=torch.int32, device="cuda") + + q_thd, _ = create_thd_tensor(seq_lens, config.num_heads_q, config.head_dim_qk, config.dtype, rng) + k_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_k, config.total_seq_len, config.head_dim_qk, config.dtype, rng) + v_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_v, config.total_seq_len, config.head_dim_v, config.dtype, rng) + + o_thd_gpu = execute_chunked_prefill_cudnn(cudnn_handle, config, q_thd, k_bhsd, v_bhsd) + q_bhsd_ref = thd_to_bhsd(q_thd, seq_lens, config.total_seq_len) + o_bhsd_ref = compute_chunked_prefill_reference(q_bhsd_ref, k_bhsd, v_bhsd, config, config.attn_scale) + o_thd_ref = bhsd_to_thd(o_bhsd_ref, seq_lens) + + if compare_outputs(o_thd_gpu, o_thd_ref) > 0: + pytest.fail("Test failed") + print("\nTEST PASSED") + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=123) +def test_chunked_prefill_causal(cudnn_handle): + if LooseVersion(cudnn.backend_version_string()) < "9.10.0": + pytest.skip("THD layout requires cuDNN 9.10.0+") + if torch.cuda.get_device_capability()[0] < 8: + pytest.skip("Requires SM80+") + + print("\n" + "=" * 80 + "\nTest: Chunked Prefill (causal)\n" + "=" * 80) + config = ChunkedPrefillConfig( + batch_size=2, + num_heads_q=8, + num_heads_k=8, + num_heads_v=8, + head_dim_qk=128, + head_dim_v=128, + total_seq_len=4096, + chunk_size=1024, + dtype=torch.bfloat16, + is_causal=True, + ) + rng = torch.Generator(device="cuda").manual_seed(123) + seq_lens = torch.full((config.batch_size,), config.total_seq_len, dtype=torch.int32, device="cuda") + + q_thd, _ = create_thd_tensor(seq_lens, config.num_heads_q, config.head_dim_qk, config.dtype, rng) + k_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_k, config.total_seq_len, config.head_dim_qk, config.dtype, rng) + v_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_v, config.total_seq_len, config.head_dim_v, config.dtype, rng) + + o_thd_gpu = execute_chunked_prefill_cudnn(cudnn_handle, config, q_thd, k_bhsd, v_bhsd) + q_bhsd_ref = thd_to_bhsd(q_thd, seq_lens, config.total_seq_len) + o_bhsd_ref = compute_chunked_prefill_reference(q_bhsd_ref, k_bhsd, v_bhsd, config, config.attn_scale) + o_thd_ref = bhsd_to_thd(o_bhsd_ref, seq_lens) + + if compare_outputs(o_thd_gpu, o_thd_ref) > 0: + pytest.fail("Test failed") + print("\nTEST PASSED") + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=456) +def test_chunked_prefill_gqa(cudnn_handle): + if LooseVersion(cudnn.backend_version_string()) < "9.10.0": + pytest.skip("THD layout requires cuDNN 9.10.0+") + if torch.cuda.get_device_capability()[0] < 8: + pytest.skip("Requires SM80+") + + print("\n" + "=" * 80 + "\nTest: Chunked Prefill (GQA)\n" + "=" * 80) + config = ChunkedPrefillConfig( + batch_size=2, + num_heads_q=8, + num_heads_k=2, + num_heads_v=2, + head_dim_qk=128, + head_dim_v=128, + total_seq_len=4096, + chunk_size=1024, + dtype=torch.bfloat16, + is_causal=False, + ) + rng = torch.Generator(device="cuda").manual_seed(456) + seq_lens = torch.full((config.batch_size,), config.total_seq_len, dtype=torch.int32, device="cuda") + + q_thd, _ = create_thd_tensor(seq_lens, config.num_heads_q, config.head_dim_qk, config.dtype, rng) + k_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_k, config.total_seq_len, config.head_dim_qk, config.dtype, rng) + v_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_v, config.total_seq_len, config.head_dim_v, config.dtype, rng) + + o_thd_gpu = execute_chunked_prefill_cudnn(cudnn_handle, config, q_thd, k_bhsd, v_bhsd) + q_bhsd_ref = thd_to_bhsd(q_thd, seq_lens, config.total_seq_len) + o_bhsd_ref = compute_chunked_prefill_reference(q_bhsd_ref, k_bhsd, v_bhsd, config, config.attn_scale) + o_thd_ref = bhsd_to_thd(o_bhsd_ref, seq_lens) + + if compare_outputs(o_thd_gpu, o_thd_ref) > 0: + pytest.fail("Test failed") + print("\nTEST PASSED") + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=789) +def test_chunked_prefill_gqa_causal(cudnn_handle): + if LooseVersion(cudnn.backend_version_string()) < "9.10.0": + pytest.skip("THD layout requires cuDNN 9.10.0+") + if torch.cuda.get_device_capability()[0] < 8: + pytest.skip("Requires SM80+") + + print("\n" + "=" * 80 + "\nTest: Chunked Prefill (GQA + causal)\n" + "=" * 80) + config = ChunkedPrefillConfig( + batch_size=2, + num_heads_q=8, + num_heads_k=2, + num_heads_v=2, + head_dim_qk=128, + head_dim_v=128, + total_seq_len=4096, + chunk_size=1024, + dtype=torch.bfloat16, + is_causal=True, + ) + rng = torch.Generator(device="cuda").manual_seed(789) + seq_lens = torch.full((config.batch_size,), config.total_seq_len, dtype=torch.int32, device="cuda") + + q_thd, _ = create_thd_tensor(seq_lens, config.num_heads_q, config.head_dim_qk, config.dtype, rng) + k_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_k, config.total_seq_len, config.head_dim_qk, config.dtype, rng) + v_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_v, config.total_seq_len, config.head_dim_v, config.dtype, rng) + + o_thd_gpu = execute_chunked_prefill_cudnn(cudnn_handle, config, q_thd, k_bhsd, v_bhsd) + q_bhsd_ref = thd_to_bhsd(q_thd, seq_lens, config.total_seq_len) + o_bhsd_ref = compute_chunked_prefill_reference(q_bhsd_ref, k_bhsd, v_bhsd, config, config.attn_scale) + o_thd_ref = bhsd_to_thd(o_bhsd_ref, seq_lens) + + if compare_outputs(o_thd_gpu, o_thd_ref) > 0: + pytest.fail("Test failed") + print("\nTEST PASSED") + + +if __name__ == "__main__": + print("Run with: pytest -vv -s -rA test_sdpa_chunked_prefill.py") diff --git a/test/python/test_sdpa_fp8.py b/test/python/test_sdpa_fp8.py deleted file mode 100644 index da3fe439..00000000 --- a/test/python/test_sdpa_fp8.py +++ /dev/null @@ -1,811 +0,0 @@ -# fmt: off - -import torch -import cudnn -import pytest -import argparse -from enum import IntEnum -from looseversion import LooseVersion -import math - -from test_utils import torch_fork_set_rng - -torch.nans = lambda *size, **kwargs: torch.full(size, float('nan'), **kwargs) - -# sq1_*, sq4_*, sq32_*, sq64_*: BUG mismatches -TEST_CONFIGS_FWD = { - "d128_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "d64_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 64, "d_vo": 64, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "d128_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - - # cudnnTest replica: - # ./cudnnTest -RgraphRunner -jsonTestName=LLM_paged_attention_fp8 -kv=dim_b:2 -kv=dim_qh:4 -kv=dim_qs:256 -kv=dim_kvs:256 -kv=dim_d:128 -kv=dim_kvh:4 -kv=Tin:CUDNN_DATA_FP8_E4M3 -kv=Tout:CUDNN_DATA_FP8_E4M3 -kv=atol:0.08 -kv=rtol:0.2 -minDevVer800 -backendEngine-1 -b -gpuRef -kv=block_size:16 -kv=table_size:16 -kv=max_block_num:31 -kv=dim_num_blocks:32 - "d128_f8e4m3_paged": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2, "kv_block_size": 16}, - - "d64_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 64, "d_vo": 64, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - "d128_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.4}, - "d64_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 64, "d_vo": 64, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.4}, - - "gqa_f16": {"b": 2, "h_q": 15, "h_k": 5, "h_v": 3, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "gqa_f8e4m3": {"b": 2, "h_q": 15, "h_k": 5, "h_v": 3, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - "gqa_f8e5m2": {"b": 2, "h_q": 15, "h_k": 5, "h_v": 3, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.4}, - - "sq1_skv256_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 1, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq1_skv1024_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 1, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq1_skv256_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 1, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - "sq1_skv1024_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 1, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - "sq1_skv256_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 1, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq1_skv1024_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 1, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - - "sq4_skv256_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 4, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq4_skv1024_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 4, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq4_skv256_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 4, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq4_skv1024_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 4, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq4_skv256_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 4, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq4_skv1024_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 4, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - - "sq8_skv128_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 8, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq8_skv256_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 8, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq8_skv1024_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 8, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq8_skv128_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 8, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq8_skv256_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 8, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq8_skv1024_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 8, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq8_skv128_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 8, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq8_skv256_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 8, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq8_skv1024_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 8, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - - "sq16_skv128_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 16, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq16_skv256_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 16, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq16_skv1024_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 16, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq16_skv128_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 16, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq16_skv256_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 16, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq16_skv1024_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 16, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq16_skv128_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 16, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq16_skv256_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 16, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq16_skv1024_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 16, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - - "sq32_skv128_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 32, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq32_skv256_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 32, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq32_skv1024_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 32, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq32_skv128_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 32, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq32_skv256_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 32, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq32_skv1024_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 32, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq32_skv128_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 32, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq32_skv256_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 32, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq32_skv1024_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 32, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - - "sq64_skv128_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 64, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq64_skv256_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 64, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq64_skv1024_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 64, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq64_skv128_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 64, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq64_skv256_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 64, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq64_skv1024_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 64, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.16, "rtol": 0.2}, - "sq64_skv128_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 64, "s_kv": 128, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq64_skv256_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 64, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - "sq64_skv1024_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 64, "s_kv": 1024, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.2}, - - "sq65_skv256_f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 65, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.04, "rtol": 0.1}, - "sq65_skv256_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 65, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - "sq65_skv256_f8e5m2": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 65, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e5m2", "otype": "fp8_e5m2", "atol": 0.16, "rtol": 0.4}, -} - -TEST_CONFIGS_BWD = { - "d64_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 64, "d_vo": 64, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - "d64_f8e4m3_gqa": {"b": 2, "h_q": 4, "h_k": 2, "h_v": 2, "s_qo": 256, "s_kv": 256, "d_qk": 64, "d_vo": 64, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - "d64_f8e4m3_o-f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 64, "d_vo": 64, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.08, "rtol": 0.2}, - "d64_f8e4m3_o-f16_gqa": {"b": 2, "h_q": 4, "h_k": 2, "h_v": 2, "s_qo": 256, "s_kv": 256, "d_qk": 64, "d_vo": 64, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.08, "rtol": 0.2}, - "d128_f8e4m3": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - "d128_f8e4m3_gqa": {"b": 2, "h_q": 4, "h_k": 2, "h_v": 2, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp8_e4m3", "atol": 0.08, "rtol": 0.2}, - "d128_f8e4m3_o-f16": {"b": 2, "h_q": 4, "h_k": 4, "h_v": 4, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.08, "rtol": 0.2}, - "d128_f8e4m3_o-f16_gqa": {"b": 2, "h_q": 4, "h_k": 2, "h_v": 2, "s_qo": 256, "s_kv": 256, "d_qk": 128, "d_vo": 128, "itype": "fp8_e4m3", "otype": "fp16", "atol": 0.08, "rtol": 0.2}, -} - -BLOCKED_CONFIGS_FWD = [ - "sq1_skv1024_f8e5m2", # fails on prefill as well -] - -BLOCKED_CONFIGS_BWD = [] - -def get_torch_and_cudnn_type(type_str): - if type_str == "fp8_e4m3": - return torch.float8_e4m3fn, cudnn.data_type.FP8_E4M3 - elif type_str == "fp8_e5m2": - return torch.float8_e5m2, cudnn.data_type.FP8_E5M2 - elif type_str == "fp16": - return torch.float16, cudnn.data_type.HALF - elif type_str == "bf16": - return torch.bfloat16, cudnn.data_type.BFLOAT16 - else: - return None, None - -def section_begin(msg, width=80): - print(f" {msg} ".center(width, "=")) - -def section_end(width=80): - print("=" * width) - -def get_fp8_largest_po2(dtype: torch.dtype): - if dtype == torch.float8_e4m3fn: - return 128.0 # max representable value: 0x1.e00000p+7 - elif dtype == torch.float8_e5m2: - return 32768.0 # max representable value: 0x1.c00000p+15 - else: - raise ValueError(f"Unsupported dtype: {dtype}") - -def get_fp8_scale_factor(amax: float, dtype: torch.dtype, fudge_factor: float = 0.25, epsilon = 0.0625): - if dtype == torch.float16 or dtype == torch.bfloat16: - return 1.0 - po2_next = 2 ** math.ceil(math.log2(max(amax, epsilon))) - return get_fp8_largest_po2(dtype) / po2_next * fudge_factor - -def get_fp8_descale_factor(amax: float, dtype: torch.dtype, fudge_factor: float = 0.25, epsilon = 0.0625): - return 1.0 / get_fp8_scale_factor(amax, dtype, fudge_factor, epsilon) - -class GraphFwdUid(IntEnum): - q = 0 - k = 1 - v = 2 - - q_descale = 5 - k_descale = 6 - v_descale = 7 - s_scale = 9 - s_descale = 8 - o_scale = 10 - - o = 3 - stats = 4 - - s_amax = 11 - o_amax = 12 - - kv_seq_len = 13 - q_seq_len = 14 - - k_block_table = 15 - v_block_table = 16 - -class GraphBwdUid(IntEnum): - q = 100 - k = 101 - v = 102 - o = 103 - dO = 104 - stats = 105 - - q_descale = 106 - k_descale = 107 - v_descale = 108 - o_descale = 109 - dO_descale = 110 - s_descale = 111 - dP_descale = 112 - s_scale = 113 - dQ_scale = 114 - dK_scale = 115 - dV_scale = 116 - dP_scale = 117 - - dQ = 118 - dK = 119 - dV = 120 - - dQ_amax = 121 - dK_amax = 122 - dV_amax = 123 - dP_amax = 124 - -def generate_graph_fwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale, block_size): - graph_fwd = cudnn.pygraph(io_data_type=cudnn_itype, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT) - - # Variable sequence lenths are required for paged attention - use_padding_mask = None - kv_seq_len = None - q_seq_len = None - k_block_table = None - v_block_table = None - - if block_size == 0: - q = graph_fwd.tensor(uid=GraphFwdUid.q, dim=(b, h_q, s_qo, d_qk), stride=(s_qo * h_q * d_qk, d_qk, h_q * d_qk, 1), data_type=cudnn_itype) - k = graph_fwd.tensor(uid=GraphFwdUid.k, dim=(b, h_k, s_kv, d_qk), stride=(s_kv * h_k * d_qk, d_qk, h_k * d_qk, 1), data_type=cudnn_itype) - v = graph_fwd.tensor(uid=GraphFwdUid.v, dim=(b, h_v, s_kv, d_vo), stride=(s_kv * h_v * d_vo, d_vo, h_v * d_vo, 1), data_type=cudnn_itype) - else: - table_size = math.ceil(s_kv / block_size) - num_blocks = table_size * b - - q = graph_fwd.tensor(uid=GraphFwdUid.q, dim=(b, h_q, s_qo, d_qk), stride=(s_qo * h_q * d_qk, d_qk, h_q * d_qk, 1), data_type=cudnn_itype) - k = graph_fwd.tensor(uid=GraphFwdUid.k, dim=(num_blocks, h_k, block_size, d_qk), stride=(block_size * h_k * d_qk, block_size * d_qk, d_qk, 1), data_type=cudnn_itype) - v = graph_fwd.tensor(uid=GraphFwdUid.v, dim=(num_blocks, h_v, block_size, d_vo), stride=(block_size * h_v * d_vo, block_size * d_vo, d_vo, 1), data_type=cudnn_itype) - - use_padding_mask = True - kv_seq_len = graph_fwd.tensor(uid=GraphFwdUid.kv_seq_len, dim=(b, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32) - q_seq_len = graph_fwd.tensor(uid=GraphFwdUid.q_seq_len, dim=(b, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.INT32) - k_block_table = graph_fwd.tensor(uid=GraphFwdUid.k_block_table, dim=(b, 1, table_size, 1), stride=(table_size, table_size, 1, 1), data_type=cudnn.data_type.INT32) - v_block_table = graph_fwd.tensor(uid=GraphFwdUid.v_block_table, dim=(b, 1, table_size, 1), stride=(table_size, table_size, 1, 1), data_type=cudnn.data_type.INT32) - - q_descale = graph_fwd.tensor(uid=GraphFwdUid.q_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - k_descale = graph_fwd.tensor(uid=GraphFwdUid.k_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - v_descale = graph_fwd.tensor(uid=GraphFwdUid.v_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - s_scale = graph_fwd.tensor(uid=GraphFwdUid.s_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - s_descale = graph_fwd.tensor(uid=GraphFwdUid.s_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - o_scale = graph_fwd.tensor(uid=GraphFwdUid.o_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - - o, stats, amax_s, amax_o = graph_fwd.sdpa_fp8( - q=q, - k=k, - v=v, - descale_q=q_descale, - descale_k=k_descale, - descale_v=v_descale, - scale_s=s_scale, - descale_s=s_descale, - scale_o=o_scale, - generate_stats=True, - attn_scale=attn_scale, - use_causal_mask=False, - use_padding_mask=use_padding_mask, - seq_len_kv=kv_seq_len, - seq_len_q=q_seq_len, - paged_attention_k_table=k_block_table, # Block Table K: Tensor containing offsets to the container with K blocks - paged_attention_v_table=v_block_table, # Block Table V: Tensor containing offsets to the container with V blocks - paged_attention_max_seq_len_kv=s_kv, # The maximum sequence length for K caches (this is optional, but recommended) - ) - - o.set_uid(GraphFwdUid.o).set_output(True).set_dim((b, h_q, s_qo, d_vo)).set_stride((s_qo * h_q * d_vo, d_vo, h_q * d_vo, 1)).set_data_type(cudnn_otype) - stats.set_uid(GraphFwdUid.stats).set_output(True).set_dim((b, h_q, s_qo, 1)).set_stride((s_qo * h_q, s_qo, 1, 1)).set_data_type(cudnn.data_type.FLOAT) - amax_s.set_uid(GraphFwdUid.s_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) - amax_o.set_uid(GraphFwdUid.o_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) - - return graph_fwd - -def generate_graph_bwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale): - graph_bwd = cudnn.pygraph(io_data_type=cudnn_itype, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT) - - q = graph_bwd.tensor(uid=GraphBwdUid.q, dim=(b, h_q, s_qo, d_qk), stride=(s_qo * h_q * d_qk, d_qk, h_q * d_qk, 1), data_type=cudnn_itype) - k = graph_bwd.tensor(uid=GraphBwdUid.k, dim=(b, h_k, s_kv, d_qk), stride=(s_kv * h_k * d_qk, d_qk, h_k * d_qk, 1), data_type=cudnn_itype) - v = graph_bwd.tensor(uid=GraphBwdUid.v, dim=(b, h_v, s_kv, d_vo), stride=(s_kv * h_v * d_vo, d_vo, h_v * d_vo, 1), data_type=cudnn_itype) - o = graph_bwd.tensor(uid=GraphBwdUid.o, dim=(b, h_q, s_qo, d_vo), stride=(s_qo * h_q * d_vo, d_vo, h_q * d_vo, 1), data_type=cudnn_otype) - dO = graph_bwd.tensor(uid=GraphBwdUid.dO, dim=(b, h_q, s_qo, d_vo), stride=(s_qo * h_q * d_vo, d_vo, h_q * d_vo, 1), data_type=cudnn_itype) - stats = graph_bwd.tensor(uid=GraphBwdUid.stats, dim=(b, h_q, s_qo, 1), stride=(s_qo * h_q, s_qo, 1, 1), data_type=cudnn.data_type.FLOAT) - - q_descale = graph_bwd.tensor(uid=GraphBwdUid.q_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - k_descale = graph_bwd.tensor(uid=GraphBwdUid.k_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - v_descale = graph_bwd.tensor(uid=GraphBwdUid.v_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - o_descale = graph_bwd.tensor(uid=GraphBwdUid.o_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - dO_descale = graph_bwd.tensor(uid=GraphBwdUid.dO_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - s_descale = graph_bwd.tensor(uid=GraphBwdUid.s_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - dP_descale = graph_bwd.tensor(uid=GraphBwdUid.dP_descale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - - s_scale = graph_bwd.tensor(uid=GraphBwdUid.s_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - dQ_scale = graph_bwd.tensor(uid=GraphBwdUid.dQ_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - dK_scale = graph_bwd.tensor(uid=GraphBwdUid.dK_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - dV_scale = graph_bwd.tensor(uid=GraphBwdUid.dV_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - dP_scale = graph_bwd.tensor(uid=GraphBwdUid.dP_scale, dim=(1, 1, 1, 1), stride=(1, 1, 1, 1), data_type=cudnn.data_type.FLOAT) - - dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP = graph_bwd.sdpa_fp8_backward( - q=q, - k=k, - v=v, - o=o, - dO=dO, - stats=stats, - descale_q=q_descale, - descale_k=k_descale, - descale_v=v_descale, - descale_o=o_descale, - descale_dO=dO_descale, - descale_s=s_descale, - descale_dP=dP_descale, - scale_s=s_scale, - scale_dQ=dQ_scale, - scale_dK=dK_scale, - scale_dV=dV_scale, - scale_dP=dP_scale, - attn_scale=attn_scale, - use_padding_mask=False, - ) - - dQ.set_uid(GraphBwdUid.dQ).set_output(True).set_dim((b, h_q, s_qo, d_qk)).set_stride((s_qo * h_q * d_qk, d_qk, h_q * d_qk, 1)).set_data_type(cudnn_itype) - dK.set_uid(GraphBwdUid.dK).set_output(True).set_dim((b, h_k, s_kv, d_qk)).set_stride((s_kv * h_k * d_qk, d_qk, h_k * d_qk, 1)).set_data_type(cudnn_itype) - dV.set_uid(GraphBwdUid.dV).set_output(True).set_dim((b, h_v, s_kv, d_vo)).set_stride((s_kv * h_v * d_vo, d_vo, h_v * d_vo, 1)).set_data_type(cudnn_itype) - - amax_dQ.set_uid(GraphBwdUid.dQ_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) - amax_dK.set_uid(GraphBwdUid.dK_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) - amax_dV.set_uid(GraphBwdUid.dV_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) - amax_dP.set_uid(GraphBwdUid.dP_amax).set_output(True).set_dim((1, 1, 1, 1)).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) - - return graph_bwd - -def create_paged_container_and_block_table(tensor, block_size): - B, H, S, D = tensor.shape - blocks_per_batch = math.ceil(S / block_size) - - padding_seq = blocks_per_batch * block_size - S - if padding_seq > 0: - zeros = torch.zeros(B, H, padding_seq, D, device="cuda", dtype=tensor.dtype) - cat_tensor = torch.cat((tensor, zeros), dim=2) - else: - cat_tensor = tensor - - container = torch.cat(cat_tensor.chunk(blocks_per_batch, dim=2), dim=0) - - table_size = math.ceil(S / block_size) - block_table_temp = torch.linspace(0, B * table_size - 1, B * table_size, device="cuda", dtype=torch.int32).reshape(table_size, 1, B, 1) - block_table_temp = torch.transpose(block_table_temp, 0, 2) - - block_table = (torch.zeros(blocks_per_batch * B, device="cuda", dtype=torch.int32).as_strided((B, 1, blocks_per_batch, 1), (blocks_per_batch, blocks_per_batch, 1, 1))) - block_table.copy_(block_table_temp) - - return (container, block_table) - -def compute_ref(q, k, v, attn_scale=1.0, return_type="o"): - b, s_q, h_q, d_qk = q.shape - _, s_kv, h_k, _ = k.shape - _, _, h_v, d_v = v.shape - - assert k.shape == (b, s_kv, h_k, d_qk) - assert v.shape == (b, s_kv, h_v, d_v) - - if h_q != h_k: - k = k.repeat_interleave(h_q // h_k, dim=2) - if h_q != h_v: - v = v.repeat_interleave(h_q // h_v, dim=2) - - s = torch.einsum("bqhd,bkhd->bhqk", q, k) * attn_scale - p = s.softmax(dim=-1) - o = torch.einsum("bhqk,bkhd->bqhd", p, v) - - if return_type == "o": - return o - if return_type == "o_stats": - # TODO implement - return o, torch.zeros() - elif return_type == "amax": - return p.abs().max().item(), o.abs().max().item() - else: - raise ValueError(f"Unsupported return type: {return_type}") - -@pytest.mark.parametrize("name", TEST_CONFIGS_FWD.keys()) -@pytest.mark.L0 -@torch_fork_set_rng(seed=0) -def test_sdpa_fwd_fp8(name): - print() - section_begin(f"Running {name}") - config = TEST_CONFIGS_FWD[name] - - if name in BLOCKED_CONFIGS_FWD: - pytest.skip("TEST WAIVED: blocked config") - - cudnn_version = LooseVersion(cudnn.backend_version_string()) - if cudnn_version < "9.14.0": - pytest.skip("TEST WAIVED: SDPA FP8 fprop testing is limited to cuDNN 9.14.0 or higher") - if torch.cuda.get_device_capability()[0] < 10: - pytest.skip("TEST WAIVED: SDPA FP8 fprop testing is limited to Blackwell or higher") - - torch_itype, cudnn_itype = get_torch_and_cudnn_type(config["itype"]) - torch_otype, cudnn_otype = get_torch_and_cudnn_type(config["otype"]) - assert torch_itype is not None and cudnn_itype is not None - assert torch_otype is not None and cudnn_otype is not None - - b = config["b"] - h_q = config["h_q"] - h_k = config["h_k"] - h_v = config["h_v"] - s_qo = config["s_qo"] - s_kv = config["s_kv"] - d_qk = config["d_qk"] - d_vo = config["d_vo"] - - attn_scale = 0.125 - block_size = config.get("kv_block_size", 0) - - is_paged_attention = block_size > 0 - - section_begin("Building Graph") - try: - graph_fwd = generate_graph_fwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale, block_size) - graph_fwd.validate() - graph_fwd.build_operation_graph() - graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph_fwd.check_support() - graph_fwd.build_plans() - except cudnn.cudnnGraphNotSupportedError as e: - print(f"TEST WAIVED: unsupported graph. {e}") - pytest.skip("TEST WAIVED: unsupported graph.") - except Exception as e: - print(f"Error building graph: {e}") - pytest.fail(f"Error building graph: {e}") - section_end() - - section_begin("Allocate and Generate") - q_gen = torch.clamp(torch.randn(b, s_qo, h_q, d_qk, dtype=torch.float, device="cuda"), min=-2.0, max=2.0) - k_gen = torch.clamp(torch.randn(b, s_kv, h_k, d_qk, dtype=torch.float, device="cuda"), min=-2.0, max=2.0) - v_gen = torch.clamp(torch.randn(b, s_kv, h_v, d_vo, dtype=torch.float, device="cuda"), min=-2.0, max=2.0) - - q_amax = q_gen.abs().max().item() - k_amax = k_gen.abs().max().item() - v_amax = v_gen.abs().max().item() - s_amax, o_amax = compute_ref(q_gen, k_gen, v_gen, attn_scale, return_type="amax") - - q_gpu = (q_gen * get_fp8_scale_factor(q_amax, torch_itype)).to(torch_itype) - k_gpu = (k_gen * get_fp8_scale_factor(k_amax, torch_itype)).to(torch_itype) - v_gpu = (v_gen * get_fp8_scale_factor(v_amax, torch_itype)).to(torch_itype) - - if is_paged_attention: - k_gpu_bhsd = torch.einsum('bshd->bhsd', k_gpu).contiguous() - v_gpu_bhsd = torch.einsum('bshd->bhsd', v_gpu).contiguous() - container_k_gpu, k_block_table_gpu = create_paged_container_and_block_table(k_gpu_bhsd, block_size) - container_v_gpu, v_block_table_gpu = create_paged_container_and_block_table(v_gpu_bhsd, block_size) - - kv_seq_len_gpu = torch.full((b, 1, 1, 1), s_kv, device="cuda", dtype=torch.int32) - q_seq_len_gpu = torch.full((b, 1, 1, 1), s_qo, device="cuda", dtype=torch.int32) - o_gpu = torch.nans(b, s_qo, h_q, d_vo, dtype=torch_otype, device="cuda") - stats_gpu = torch.nans(b, h_q, s_qo, 1, dtype=torch.float, device="cuda") - - q_descale_gpu = torch.tensor([get_fp8_descale_factor(q_amax, torch_itype)], dtype=torch.float, device="cuda") - k_descale_gpu = torch.tensor([get_fp8_descale_factor(k_amax, torch_itype)], dtype=torch.float, device="cuda") - v_descale_gpu = torch.tensor([get_fp8_descale_factor(v_amax, torch_itype)], dtype=torch.float, device="cuda") - s_scale_gpu = torch.tensor([get_fp8_scale_factor(s_amax, torch_itype)], dtype=torch.float, device="cuda") - s_descale_gpu = torch.tensor([get_fp8_descale_factor(s_amax, torch_itype)], dtype=torch.float, device="cuda") - o_scale_gpu = torch.tensor([get_fp8_scale_factor(o_amax, torch_otype)], dtype=torch.float, device="cuda") - - s_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") - o_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") - section_end() - - section_begin("Execute") - # execute forward and backward graph - variant_pack = { - int(GraphFwdUid.q): q_gpu, - int(GraphFwdUid.k): k_gpu, - int(GraphFwdUid.v): v_gpu, - - int(GraphFwdUid.q_descale): q_descale_gpu, - int(GraphFwdUid.k_descale): k_descale_gpu, - int(GraphFwdUid.v_descale): v_descale_gpu, - int(GraphFwdUid.s_descale): s_descale_gpu, - int(GraphFwdUid.s_scale): s_scale_gpu, - int(GraphFwdUid.o_scale): o_scale_gpu, - - int(GraphFwdUid.o): o_gpu, - int(GraphFwdUid.stats): stats_gpu, - - int(GraphFwdUid.s_amax): s_amax_gpu, - int(GraphFwdUid.o_amax): o_amax_gpu, - } - - if is_paged_attention: - variant_pack[int(GraphFwdUid.k)] = container_k_gpu - variant_pack[int(GraphFwdUid.v)] = container_v_gpu - variant_pack[int(GraphFwdUid.kv_seq_len)] = kv_seq_len_gpu - variant_pack[int(GraphFwdUid.q_seq_len)] = q_seq_len_gpu - variant_pack[int(GraphFwdUid.k_block_table)] = k_block_table_gpu - variant_pack[int(GraphFwdUid.v_block_table)] = v_block_table_gpu - - workspace = torch.empty(graph_fwd.get_workspace_size(), dtype=torch.uint8, device="cuda") - cudnn_handle = cudnn.create_handle() - graph_fwd.execute(variant_pack, workspace, handle=cudnn_handle) - torch.cuda.synchronize() - cudnn.destroy_handle(cudnn_handle) - section_end() - - section_begin("Run Reference and Compare Output") - q_ref = q_gpu.detach().float() * get_fp8_descale_factor(q_amax, torch_itype) - k_ref = k_gpu.detach().float() * get_fp8_descale_factor(k_amax, torch_itype) - v_ref = v_gpu.detach().float() * get_fp8_descale_factor(v_amax, torch_itype) - o_ref = compute_ref(q_ref, k_ref, v_ref, attn_scale=attn_scale) - - o_ref_comp = o_ref - o_gpu_comp = o_gpu.detach().float() * get_fp8_descale_factor(o_amax, torch_otype) - - print("o_ref_comp.numel()", o_ref_comp.numel()) - print("o_gpu_comp.numel()", o_gpu_comp.numel()) - print("Number of zeros in o_ref_comp:", (o_ref_comp == 0).sum().item()) - print("Number of zeros in o_gpu_comp:", (o_gpu_comp == 0).sum().item()) - print("Number of non-finite elements in o_ref_comp:", (~torch.isfinite(o_ref_comp)).sum().item()) - print("Number of non-finite elements in o_gpu_comp:", (~torch.isfinite(o_gpu_comp)).sum().item()) - - for _ in range(3): - coord = tuple(torch.randint(0, numel, (1,)).item() for numel in o_ref_comp.size()) - print(f"o_ref_comp{coord}:", float(o_ref_comp[coord].item()).hex()) - print(f"o_gpu_comp{coord}:", float(o_gpu_comp[coord].item()).hex()) - - print(f"s_amax_gpu={s_amax_gpu.item()}, s_amax={s_amax}") - print(f"o_amax_gpu={o_amax_gpu.item()}, o_amax={o_amax}") - - failed = [] - try: - torch.testing.assert_close(o_gpu_comp, o_ref_comp, atol=config["atol"], rtol=config["rtol"]) - except Exception as e: - print("\033[91m" + f"o_gpu: {e}" + "\033[0m\n"); failed.append("o_gpu") - try: - torch.testing.assert_close(s_amax_gpu.item(), s_amax, atol=0.04, rtol=0.10) - except Exception as e: - print("\033[91m" + f"s_amax_gpu: {e}" + "\033[0m\n"); failed.append("s_amax_gpu") - try: - torch.testing.assert_close(o_amax_gpu.item(), o_amax, atol=0.04, rtol=0.10) - except Exception as e: - print("\033[91m" + f"o_amax_gpu: {e}" + "\033[0m\n"); failed.append("o_amax_gpu") - - if len(failed) > 0: - print("\033[91m" + "Failed!" + "\033[0m"); pytest.fail(f"Failed: mismatches in {', '.join(failed)}") - print("\033[92m" + "Passed!" + "\033[0m") - - # # used to debug tolerances - # x = o_ref_comp.abs() - # y = o_ref_comp - o_gpu_comp - # import plotly.express as px - # import plotly.io as pio - # fig = px.scatter( - # x=x.cpu().flatten().numpy(), - # y=y.cpu().flatten().numpy(), - # labels={"x": "Absolute value", "y": "Absolute Error"}, - # title="Absolute value vs absolute error" - # ) - # pio.write_html(fig, file=f"scatter_{name}.html", auto_open=False) - # print(f"wrote scatter_{name}.html") - - section_end() - print() - -@pytest.mark.parametrize("name", TEST_CONFIGS_BWD.keys()) -@pytest.mark.L0 -@torch_fork_set_rng(seed=0) -def test_sdpa_bwd_fp8(name): - print() - section_begin(f"Running {name} (backward)") - config = TEST_CONFIGS_BWD[name] - - if name in BLOCKED_CONFIGS_BWD: - pytest.skip("TEST WAIVED: blocked config") - - cudnn_version = LooseVersion(cudnn.backend_version_string()) - if cudnn_version < "9.14.0": - pytest.skip("TEST WAIVED: SDPA FP8 bprop testing is limited to cuDNN 9.14.0 or higher") - if torch.cuda.get_device_capability()[0] < 10: - pytest.skip("TEST WAIVED: SDPA FP8 bprop testing is limited to Blackwell or higher") - - torch_itype, cudnn_itype = get_torch_and_cudnn_type(config["itype"]) - torch_otype, cudnn_otype = get_torch_and_cudnn_type(config["otype"]) - assert torch_itype is not None and cudnn_itype is not None - assert torch_otype is not None and cudnn_otype is not None - - b = config["b"] - h_q = config["h_q"] - h_k = config["h_k"] - h_v = config["h_v"] - s_qo = config["s_qo"] - s_kv = config["s_kv"] - d_qk = config["d_qk"] - d_vo = config["d_vo"] - - attn_scale = 0.125 - - section_begin("Build Graphs") - graph_fwd = generate_graph_fwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale, 0) - graph_bwd = generate_graph_bwd(cudnn_itype, cudnn_otype, b, h_q, h_k, h_v, s_qo, s_kv, d_qk, d_vo, attn_scale) - - try: - graph_fwd.validate(); graph_fwd.build_operation_graph(); graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]); graph_fwd.check_support(); graph_fwd.build_plans() - except cudnn.cudnnGraphNotSupportedError as e: - print(f"TEST WAIVED: unsupported fwd graph. {e}") - pytest.skip("TEST WAIVED: unsupported fwd graph.") - try: - graph_bwd.validate(); graph_bwd.build_operation_graph(); graph_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]); graph_bwd.check_support(); graph_bwd.build_plans() - except cudnn.cudnnGraphNotSupportedError as e: - print(f"TEST WAIVED: unsupported bwd graph. {e}") - pytest.skip("TEST WAIVED: unsupported bwd graph.") - - section_end() - - section_begin("Allocate and Generate") - q_gen = torch.clamp(torch.randn(b, s_qo, h_q, d_qk, dtype=torch.float, device="cuda"), min=-2.0, max=2.0) - k_gen = torch.clamp(torch.randn(b, s_kv, h_k, d_qk, dtype=torch.float, device="cuda"), min=-2.0, max=2.0) - v_gen = torch.clamp(torch.randn(b, s_kv, h_v, d_vo, dtype=torch.float, device="cuda"), min=-2.0, max=2.0) - dO_gen = torch.clamp(torch.randn(b, s_qo, h_q, d_vo, dtype=torch.float, device="cuda"), min=-2.0, max=2.0) - - q_amax = q_gen.abs().max().item() - k_amax = k_gen.abs().max().item() - v_amax = v_gen.abs().max().item() - s_amax, o_amax = compute_ref(q_gen, k_gen, v_gen, attn_scale, return_type="amax") - dO_amax = dO_gen.abs().max().item() - - # q_gpu = (q_gen * get_fp8_scale_factor(q_amax, torch_itype)).to(torch_itype) - # k_gpu = (k_gen * get_fp8_scale_factor(k_amax, torch_itype)).to(torch_itype) - # v_gpu = (v_gen * get_fp8_scale_factor(v_amax, torch_itype)).to(torch_itype) - q_gpu = q_gen.to(torch_itype) - k_gpu = k_gen.to(torch_itype) - v_gpu = v_gen.to(torch_itype) - - o_gpu = torch.nans(b, s_qo, h_q, d_vo, dtype=torch_otype, device="cuda") - stats_gpu = torch.nans(b, h_q, s_qo, 1, dtype=torch.float, device="cuda") - - # dO_gpu = (dO_gen * get_fp8_scale_factor(dO_amax, torch_itype)).to(torch_itype) - dO_gpu = dO_gen.to(torch_itype) - - q_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - k_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - v_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - - s_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - s_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - o_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - - s_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") - o_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") - - o_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - dO_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - dP_descale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - - dQ_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - dK_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - dV_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - dP_scale_gpu = torch.tensor([1.0], dtype=torch.float, device="cuda") - - dQ_gpu = torch.nans(b, s_qo, h_q, d_qk, dtype=torch_itype, device="cuda") - dK_gpu = torch.nans(b, s_kv, h_k, d_qk, dtype=torch_itype, device="cuda") - dV_gpu = torch.nans(b, s_kv, h_v, d_vo, dtype=torch_itype, device="cuda") - - dQ_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") - dK_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") - dV_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") - dP_amax_gpu = torch.tensor([float('nan')], dtype=torch.float, device="cuda") - section_end() - - section_begin("Execute FWD") - variant_pack_fwd = { - int(GraphFwdUid.q): q_gpu, - int(GraphFwdUid.k): k_gpu, - int(GraphFwdUid.v): v_gpu, - - int(GraphFwdUid.q_descale): q_descale_gpu, - int(GraphFwdUid.k_descale): k_descale_gpu, - int(GraphFwdUid.v_descale): v_descale_gpu, - int(GraphFwdUid.s_descale): s_descale_gpu, - int(GraphFwdUid.s_scale): s_scale_gpu, - int(GraphFwdUid.o_scale): o_scale_gpu, - - int(GraphFwdUid.o): o_gpu, - int(GraphFwdUid.stats): stats_gpu, - - int(GraphFwdUid.s_amax): s_amax_gpu, - int(GraphFwdUid.o_amax): o_amax_gpu, - } - - workspace = torch.empty(graph_fwd.get_workspace_size(), dtype=torch.uint8, device="cuda") - cudnn_handle = cudnn.create_handle() - graph_fwd.execute(variant_pack_fwd, workspace, handle=cudnn_handle) - torch.cuda.synchronize() - cudnn.destroy_handle(cudnn_handle) - section_end() - - section_begin("Execute BWD") - variant_pack_bwd = { - int(GraphBwdUid.q): q_gpu, - int(GraphBwdUid.k): k_gpu, - int(GraphBwdUid.v): v_gpu, - int(GraphBwdUid.o): o_gpu, - int(GraphBwdUid.dO): dO_gpu, - int(GraphBwdUid.stats): stats_gpu, - - int(GraphBwdUid.q_descale): q_descale_gpu, - int(GraphBwdUid.k_descale): k_descale_gpu, - int(GraphBwdUid.v_descale): v_descale_gpu, - int(GraphBwdUid.o_descale): o_descale_gpu, - int(GraphBwdUid.dO_descale): dO_descale_gpu, - int(GraphBwdUid.s_descale): s_descale_gpu, - int(GraphBwdUid.s_scale): s_scale_gpu, - int(GraphBwdUid.dP_descale): dP_descale_gpu, - int(GraphBwdUid.dP_scale): dP_scale_gpu, - int(GraphBwdUid.dQ_scale): dQ_scale_gpu, - int(GraphBwdUid.dK_scale): dK_scale_gpu, - int(GraphBwdUid.dV_scale): dV_scale_gpu, - - int(GraphBwdUid.dQ): dQ_gpu, - int(GraphBwdUid.dK): dK_gpu, - int(GraphBwdUid.dV): dV_gpu, - - int(GraphBwdUid.dQ_amax): dQ_amax_gpu, - int(GraphBwdUid.dK_amax): dK_amax_gpu, - int(GraphBwdUid.dV_amax): dV_amax_gpu, - int(GraphBwdUid.dP_amax): dP_amax_gpu, - } - - workspace_b = torch.empty(graph_bwd.get_workspace_size(), dtype=torch.uint8, device="cuda") - cudnn_handle = cudnn.create_handle() - graph_bwd.execute(variant_pack_bwd, workspace_b, handle=cudnn_handle) - torch.cuda.synchronize() - cudnn.destroy_handle(cudnn_handle) - section_end() - - section_begin("Run Reference and Compare Output") - q_ref = q_gpu.detach().float() - k_ref = k_gpu.detach().float() - v_ref = v_gpu.detach().float() - # q_ref = q_gpu.detach().float() * get_fp8_descale_factor(q_amax, torch_itype) - # k_ref = k_gpu.detach().float() * get_fp8_descale_factor(k_amax, torch_itype) - # v_ref = v_gpu.detach().float() * get_fp8_descale_factor(v_amax, torch_itype) - o_ref = compute_ref(q_ref, k_ref, v_ref, attn_scale=attn_scale) - - dO_ref = dO_gpu.detach().float() - # dO_ref = dO_gpu.detach().float() * get_fp8_descale_factor(dO_amax, torch_itype) - - q_ref.requires_grad_(True) - k_ref.requires_grad_(True) - v_ref.requires_grad_(True) - o_tmp = compute_ref(q_ref, k_ref, v_ref, attn_scale=attn_scale) - dQ_ref, dK_ref, dV_ref = torch.autograd.grad(outputs=o_tmp, inputs=[q_ref, k_ref, v_ref], grad_outputs=dO_gen) - - dQ_amax_ref = dQ_ref.abs().max().item() - dK_amax_ref = dK_ref.abs().max().item() - dV_amax_ref = dV_ref.abs().max().item() - - dQ_out = dQ_gpu.detach().float() - dK_out = dK_gpu.detach().float() - dV_out = dV_gpu.detach().float() - # dQ_out = dQ_gpu.detach().float() * get_fp8_descale_factor(dQ_amax, torch_itype) - # dK_out = dK_gpu.detach().float() * get_fp8_descale_factor(dK_amax, torch_itype) - # dV_out = dV_gpu.detach().float() * get_fp8_descale_factor(dV_amax, torch_itype) - - print("dQ_out.numel()", dQ_out.numel()) - print("dK_out.numel()", dK_out.numel()) - print("dV_out.numel()", dV_out.numel()) - print("Number of zeros in dQ_out:", (dQ_out == 0).sum().item()) - print("Number of zeros in dK_out:", (dK_out == 0).sum().item()) - print("Number of zeros in dV_out:", (dV_out == 0).sum().item()) - print("Number of non-finite elements in dQ_out:", (~torch.isfinite(dQ_out)).sum().item()) - print("Number of non-finite elements in dK_out:", (~torch.isfinite(dK_out)).sum().item()) - print("Number of non-finite elements in dV_out:", (~torch.isfinite(dV_out)).sum().item()) - - print() - for _ in range(3): - coord_q = tuple(torch.randint(0, numel, (1,)).item() for numel in dQ_out.size()) - coord_k = tuple(torch.randint(0, numel, (1,)).item() for numel in dK_out.size()) - coord_v = tuple(torch.randint(0, numel, (1,)).item() for numel in dV_out.size()) - print(f"dQ_out{coord_q}:", float(dQ_out[coord_q].item()).hex()) - print(f"dQ_ref{coord_q}:", float(dQ_ref[coord_q].item()).hex()) - print(f"dK_out{coord_k}:", float(dK_out[coord_k].item()).hex()) - print(f"dK_ref{coord_k}:", float(dK_ref[coord_k].item()).hex()) - print(f"dV_out{coord_v}:", float(dV_out[coord_v].item()).hex()) - print(f"dV_ref{coord_v}:", float(dV_ref[coord_v].item()).hex()) - - print(f"dQ_amax_gpu={dQ_amax_gpu.item()}, dQ_amax_ref={dQ_amax_ref}") - print(f"dK_amax_gpu={dK_amax_gpu.item()}, dK_amax_ref={dK_amax_ref}") - print(f"dV_amax_gpu={dV_amax_gpu.item()}, dV_amax_ref={dV_amax_ref}") - print(f"dP_amax_gpu={dP_amax_gpu.item()}, dP_amax_ref=TODO") - - failed = [] - try: - torch.testing.assert_close(dQ_out, dQ_ref, atol=config["atol"], rtol=config["rtol"]) - except Exception as e: - print("\033[91m" + f"dQ: {e}" + "\033[0m\n"); failed.append("dQ") - try: - torch.testing.assert_close(dK_out, dK_ref, atol=config["atol"], rtol=config["rtol"]) - except Exception as e: - print("\033[91m" + f"dK: {e}" + "\033[0m\n"); failed.append("dK") - try: - torch.testing.assert_close(dV_out, dV_ref, atol=config["atol"], rtol=config["rtol"]) - except Exception as e: - print("\033[91m" + f"dV: {e}" + "\033[0m\n"); failed.append("dV") - - # disable amax due to NaNs currently - try: - torch.testing.assert_close(dQ_amax_gpu.item(), dQ_amax_ref, atol=0.04, rtol=0.10) - except Exception as e: - print("\033[91m" + f"amax_dQ: {e}" + "\033[0m\n"); failed.append("amax_dQ") - try: - torch.testing.assert_close(dK_amax_gpu.item(), dK_amax_ref, atol=0.04, rtol=0.10) - except Exception as e: - print("\033[91m" + f"amax_dK: {e}" + "\033[0m\n"); failed.append("amax_dK") - try: - torch.testing.assert_close(dV_amax_gpu.item(), dV_amax_ref, atol=0.04, rtol=0.10) - except Exception as e: - print("\033[91m" + f"amax_dV: {e}" + "\033[0m\n"); failed.append("amax_dV") - - if len(failed) > 0: - print("\033[91m" + "Failed!" + "\033[0m"); pytest.fail(f"Failed: mismatches in {', '.join(failed)}") - print("\033[92m" + "Passed!" + "\033[0m") - - section_end() - print() diff --git a/test/python/test_sdpa_thd.py b/test/python/test_sdpa_thd.py new file mode 100644 index 00000000..d764d130 --- /dev/null +++ b/test/python/test_sdpa_thd.py @@ -0,0 +1,1119 @@ +""" +Test for SDPA with dynamic shapes and THD (Token-Head-Dimension) layout. + +THD layout is a ragged/packed format where: +- Q: [total_q_tokens, num_heads, head_dim] - packed Q tensor +- K/V: can be BHSD or THD format +- O: [total_q_tokens, num_heads, head_dim] - packed output tensor + +This is similar to FlashInfer's cuDNN prefill implementation. + +The recommended way to run tests: +> pytest -vv -s -rA test_sdpa_dynamic_shapes.py +""" + +import cudnn +import pytest +import torch +import math +from looseversion import LooseVersion +from dataclasses import dataclass +from typing import List, Optional, Tuple +from test_utils import torch_fork_set_rng + +# ========================================= +# Helper Functions and Data Classes +# ========================================= + +from enum import Enum, auto + + +class UIDs(Enum): + Q_UID = auto() + K_UID = auto() + V_UID = auto() + O_UID = auto() + RAGGED_Q_UID = auto() + RAGGED_O_UID = auto() + ACTUAL_SEQ_LENS_Q_UID = auto() + ACTUAL_SEQ_LENS_KV_UID = auto() + + +@dataclass +class SDPAConfig: + """Configuration for SDPA test.""" + + batch_size: int + num_heads_q: int + num_heads_k: int + num_heads_v: int + head_dim_qk: int + head_dim_v: int + max_seq_len_q: int + max_seq_len_kv: int + dtype: torch.dtype = torch.bfloat16 + is_causal: bool = False + attn_scale: Optional[float] = None + + def __post_init__(self): + if self.attn_scale is None: + self.attn_scale = 1.0 / math.sqrt(self.head_dim_qk) + + +def convert_to_cudnn_type(torch_type: torch.dtype) -> cudnn.data_type: + """Convert PyTorch dtype to cuDNN data type.""" + type_map = { + torch.float16: cudnn.data_type.HALF, + torch.bfloat16: cudnn.data_type.BFLOAT16, + torch.float32: cudnn.data_type.FLOAT, + torch.int32: cudnn.data_type.INT32, + torch.int64: cudnn.data_type.INT64, + } + if torch_type not in type_map: + raise ValueError(f"Unsupported tensor data type: {torch_type}") + return type_map[torch_type] + + +def generate_variable_seq_lens( + batch_size: int, + max_seq_len_q: int, + max_seq_len_kv: int, + rng: torch.Generator, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate variable sequence lengths for Q and KV.""" + # Generate random sequence lengths, ensuring seq_len_q <= seq_len_kv + seq_len_q = torch.randint(1, max_seq_len_q + 1, (batch_size,), generator=rng, dtype=torch.int32, device="cuda") + seq_len_kv = torch.randint(1, max_seq_len_kv + 1, (batch_size,), generator=rng, dtype=torch.int32, device="cuda") + + # Ensure seq_len_q <= seq_len_kv for each batch + seq_len_q = torch.minimum(seq_len_q, seq_len_kv) + + return seq_len_q, seq_len_kv + + +def compute_ragged_offsets( + seq_lens: torch.Tensor, + num_heads: int, + head_dim: int, +) -> torch.Tensor: + """ + Compute exclusive prefix sum (ragged offsets) for THD layout. + + For THD layout, the ragged offset for batch i is the cumulative sum of + (seq_len[0:i] * num_heads * head_dim). + + Args: + seq_lens: [batch_size] - sequence lengths + num_heads: number of attention heads + head_dim: dimension per head + + Returns: + ragged_offset: [batch_size + 1, 1, 1, 1] - exclusive prefix sum + """ + batch_size = seq_lens.shape[0] + + # Compute element counts per batch + elements_per_batch = seq_lens * num_heads * head_dim + + # Exclusive prefix sum: [0, elem0, elem0+elem1, ...] + ragged_offset = torch.zeros(batch_size + 1, dtype=torch.int64, device=seq_lens.device) + ragged_offset[1:] = torch.cumsum(elements_per_batch, dim=0) + + # Reshape to [batch_size+1, 1, 1, 1] as expected by cuDNN + ragged_offset = ragged_offset.view(-1, 1, 1, 1) + + return ragged_offset + + +def create_thd_tensor( + seq_lens: torch.Tensor, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + rng: torch.Generator, + mean: float = 0.0, + std: float = 1.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Create a THD (Token-Head-Dimension) layout tensor. + + THD layout: [total_tokens, num_heads, head_dim] + The tensor is packed - all sequences are concatenated. + + Args: + seq_lens: [batch_size] - sequence length per batch + num_heads: number of attention heads + head_dim: dimension per head + dtype: tensor data type + rng: random number generator + mean: mean for random initialization + std: std for random initialization + + Returns: + tensor: [total_tokens, num_heads, head_dim] + ragged_offset: [batch_size+1, 1, 1, 1] + """ + total_tokens = int(seq_lens.sum().item()) + + # Create the packed tensor + tensor = torch.empty(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda") + tensor.normal_(mean=mean, std=std, generator=rng) + + # Compute ragged offsets + ragged_offset = compute_ragged_offsets(seq_lens, num_heads, head_dim) + + return tensor, ragged_offset + + +def create_bhsd_tensor( + batch_size: int, + num_heads: int, + max_seq_len: int, + head_dim: int, + dtype: torch.dtype, + rng: torch.Generator, + mean: float = 0.0, + std: float = 1.0, +) -> torch.Tensor: + """ + Create a BHSD (Batch-Head-Seq-Dim) shape tensor with BSHD strides. + + Shape: [batch_size, num_heads, max_seq_len, head_dim] + Strides: [seq*heads*dim, dim, heads*dim, 1] (BSHD stride order) + + This is an interleaved format where memory is laid out as BSHD. + + Args: + batch_size: batch size + num_heads: number of attention heads + max_seq_len: maximum sequence length + head_dim: dimension per head + dtype: tensor data type + rng: random number generator + mean: mean for random initialization + std: std for random initialization + + Returns: + tensor: [batch_size, num_heads, max_seq_len, head_dim] with BSHD strides + """ + # Allocate contiguous storage in BSHD order + total_elements = batch_size * max_seq_len * num_heads * head_dim + storage = torch.empty(total_elements, dtype=dtype, device="cuda") + storage.normal_(mean=mean, std=std, generator=rng) + + # Create view with BHSD shape but BSHD strides + # Strides: [S*H*D, D, H*D, 1] + strides = ( + max_seq_len * num_heads * head_dim, # batch stride + head_dim, # head stride + num_heads * head_dim, # seq stride + 1, # dim stride + ) + tensor = torch.as_strided(storage, (batch_size, num_heads, max_seq_len, head_dim), strides) + return tensor + + +def thd_to_bhsd( + thd_tensor: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, +) -> torch.Tensor: + """ + Convert THD layout tensor to BHSD shape with BSHD strides (for reference computation). + + THD layout: [total_tokens, num_heads, head_dim] with strides [H*D, D, 1] + BHSD with BSHD strides: [B, H, S, D] with strides [S*H*D, D, H*D, 1] + + Both layouts have the same underlying memory pattern per batch: + - THD[t, h, d] -> memory[t*H*D + h*D + d] + - BHSD[b, h, s, d] with BSHD strides -> memory[b*S*H*D + s*H*D + h*D + d] + + So within a batch, the relative offsets are identical - no transpose needed! + + Args: + thd_tensor: [total_tokens, num_heads, head_dim] + seq_lens: [batch_size] - sequence length per batch + max_seq_len: maximum sequence length for padding + + Returns: + bhsd_tensor: [batch_size, num_heads, max_seq_len, head_dim] with BSHD strides + """ + batch_size = seq_lens.shape[0] + _, num_heads, head_dim = thd_tensor.shape + + # Allocate storage in BSHD physical order: [B, S, H, D] contiguous + storage = torch.zeros(batch_size, max_seq_len, num_heads, head_dim, dtype=thd_tensor.dtype, device=thd_tensor.device) + + # Copy data batch by batch using direct copy - no transpose needed! + # THD [seq, H, D] has same memory layout as BSHD [B, S, H, D] within each batch + offset = 0 + for i in range(batch_size): + seq_len = int(seq_lens[i].item()) + # Direct copy: THD [seq_len, H, D] -> BSHD storage[i, :seq_len, H, D] + storage[i, :seq_len, :, :] = thd_tensor[offset : offset + seq_len] + offset += seq_len + + # Create BHSD view with BSHD strides from the BSHD storage + # storage is [B, S, H, D] contiguous, we want [B, H, S, D] view + bhsd_tensor = storage.permute(0, 2, 1, 3) # [B, S, H, D] -> [B, H, S, D] + + return bhsd_tensor + + +def bhsd_to_thd( + bhsd_tensor: torch.Tensor, + seq_lens: torch.Tensor, +) -> torch.Tensor: + """ + Convert BHSD shape tensor (with BSHD strides) to THD layout. + + BHSD with BSHD strides: [B, H, S, D] with strides [S*H*D, D, H*D, 1] + THD layout: [total_tokens, num_heads, head_dim] with strides [H*D, D, 1] + + Both layouts have the same underlying memory pattern per batch - no transpose needed! + + Args: + bhsd_tensor: [batch_size, num_heads, max_seq_len, head_dim] with BSHD strides + seq_lens: [batch_size] - sequence length per batch + + Returns: + thd_tensor: [total_tokens, num_heads, head_dim] + """ + batch_size = seq_lens.shape[0] + _, num_heads, _, head_dim = bhsd_tensor.shape + total_tokens = int(seq_lens.sum().item()) + + # Create output tensor + thd_tensor = torch.empty(total_tokens, num_heads, head_dim, dtype=bhsd_tensor.dtype, device=bhsd_tensor.device) + + # Convert BHSD [B, H, S, D] back to BSHD view [B, S, H, D] for direct copy + bshd_tensor = bhsd_tensor.permute(0, 2, 1, 3) # [B, H, S, D] -> [B, S, H, D] + + # Copy data batch by batch - no transpose needed! + # BSHD[b, s, h, d] has same memory layout as THD[t, h, d] + offset = 0 + for i in range(batch_size): + seq_len = int(seq_lens[i].item()) + # Direct copy: BSHD [1, seq_len, H, D] -> THD [seq_len, H, D] + thd_tensor[offset : offset + seq_len] = bshd_tensor[i, :seq_len, :, :] + offset += seq_len + + return thd_tensor + + +# ========================================= +# Reference Implementation +# ========================================= + + +def compute_sdpa_reference( + q_bhsd: torch.Tensor, + k_bhsd: torch.Tensor, + v_bhsd: torch.Tensor, + seq_len_q: torch.Tensor, + seq_len_kv: torch.Tensor, + attn_scale: float, + is_causal: bool = False, +) -> torch.Tensor: + """ + Compute SDPA reference output in float32. + + Args: + q_bhsd: [batch, heads_q, max_seq_q, head_dim_qk] + k_bhsd: [batch, heads_k, max_seq_kv, head_dim_qk] + v_bhsd: [batch, heads_v, max_seq_kv, head_dim_v] + seq_len_q: [batch] - actual sequence lengths for Q + seq_len_kv: [batch] - actual sequence lengths for K/V + attn_scale: attention scaling factor + is_causal: whether to apply causal masking + + Returns: + o_bhsd: [batch, heads_q, max_seq_q, head_dim_v] + """ + batch_size, num_heads_q, max_seq_q, head_dim_qk = q_bhsd.shape + _, num_heads_k, max_seq_kv, _ = k_bhsd.shape + _, num_heads_v, _, head_dim_v = v_bhsd.shape + + # Convert to float32 for reference computation + q = q_bhsd.to(dtype=torch.float32) + k = k_bhsd.to(dtype=torch.float32) + v = v_bhsd.to(dtype=torch.float32) + + # Handle GQA/MQA by expanding K and V + if num_heads_q != num_heads_k: + assert num_heads_q % num_heads_k == 0, "num_heads_q must be divisible by num_heads_k" + k = k.unsqueeze(2).expand(-1, -1, num_heads_q // num_heads_k, -1, -1) + k = k.reshape(batch_size, num_heads_q, max_seq_kv, head_dim_qk) + if num_heads_q != num_heads_v: + assert num_heads_q % num_heads_v == 0, "num_heads_q must be divisible by num_heads_v" + v = v.unsqueeze(2).expand(-1, -1, num_heads_q // num_heads_v, -1, -1) + v = v.reshape(batch_size, num_heads_q, max_seq_kv, head_dim_v) + + # Compute attention scores: [batch, heads, seq_q, seq_kv] + scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * attn_scale + + # Create padding mask + device = q.device + q_mask = torch.zeros(batch_size, 1, max_seq_q, 1, dtype=torch.bool, device=device) + kv_mask = torch.zeros(batch_size, 1, 1, max_seq_kv, dtype=torch.bool, device=device) + for i in range(batch_size): + q_mask[i, :, seq_len_q[i] :, :] = True + kv_mask[i, :, :, seq_len_kv[i] :] = True + + # Apply padding mask + scores = scores.masked_fill(kv_mask, float("-inf")) + + # Apply causal mask if requested + if is_causal: + causal_mask = torch.ones(max_seq_q, max_seq_kv, dtype=torch.bool, device=device) + causal_mask.triu_(diagonal=1) + scores = scores.masked_fill(causal_mask, float("-inf")) + + # Softmax + attn_weights = torch.softmax(scores, dim=-1) + + # Mask out padded Q positions (set to 0) + attn_weights = attn_weights.masked_fill(q_mask, 0.0) + + # Compute output + output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, v) + + # Zero out padded positions in output + output_mask = torch.zeros(batch_size, 1, max_seq_q, 1, dtype=torch.bool, device=device) + for i in range(batch_size): + output_mask[i, :, seq_len_q[i] :, :] = True + output = output.masked_fill(output_mask, 0.0) + + return output + + +# ========================================= +# cuDNN Graph Builder +# ========================================= + +graph_cache = {} + + +def lookup_graph_from_cache(batch_size: int, h_q: int, h_k: int, h_v: int, d_qk: int, d_v: int, max_s_kv: int, causal: bool) -> cudnn.pygraph: + """ + Lookup a graph from the cuDNN graph cache. + """ + key = (batch_size, h_q, h_k, h_v, d_qk, d_v, max_s_kv, causal) + + if key in graph_cache: + return graph_cache[key] + + return None + + +def add_to_cudnn_graph_cache(batch_size: int, h_q: int, h_k: int, h_v: int, d_qk: int, d_v: int, max_s_kv: int, causal: bool, graph: cudnn.pygraph) -> None: + """ + Add a graph to the cuDNN graph cache. + """ + + key = (batch_size, h_q, h_k, h_v, d_qk, d_v, max_s_kv, causal) + + graph_cache[key] = graph + return None + + +def build_cudnn_sdpa_thd_graph( + cudnn_handle, + config: SDPAConfig, + seq_len_q: torch.Tensor, + seq_len_kv: torch.Tensor, + q_ragged_offset: torch.Tensor, + o_ragged_offset: torch.Tensor, + q_gpu: torch.Tensor, + k_gpu: torch.Tensor, + v_gpu: torch.Tensor, + o_gpu: torch.Tensor, +): + """ + Build cuDNN graph for SDPA with THD layout for Q and O. + + Q and O are in THD (Token-Head-Dimension) ragged layout. + K and V are in BHSD (Batch-Head-Seq-Dim) layout. + """ + batch_size = config.batch_size + h_q = config.num_heads_q + h_k = config.num_heads_k + h_v = config.num_heads_v + d_qk = config.head_dim_qk + d_v = config.head_dim_v + max_s_q = config.max_seq_len_q + max_s_kv = config.max_seq_len_kv + + cudnn_dtype = convert_to_cudnn_type(config.dtype) + + # Look up pre-built graph from cache + + graph = lookup_graph_from_cache(batch_size, h_q, h_k, h_v, d_qk, d_v, max_s_kv, config.is_causal) + + if graph is not None: + print("Returning existing graph since it already exists") + return graph + + # Create the graph + graph = cudnn.pygraph( + io_data_type=cudnn_dtype, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=cudnn_handle, + is_dynamic_shape_enabled=True, + ) + + # Q tensor in THD layout with BHSD logical shape + # Physical shape: [total_q_tokens, h_q, d_qk] + # Logical shape for cuDNN: [batch, heads, max_seq_q, head_dim] + # Stride: THD -> [h_q * d_qk, d_qk, h_q * d_qk, 1] (bshd stride order) + q = graph.tensor( + dim=(batch_size, h_q, max_s_q, d_qk), + stride=(h_q * d_qk, d_qk, h_q * d_qk, 1), # bshd stride order for THD + data_type=cudnn_dtype, + name="Q", + uid=UIDs.Q_UID.value, + ) + + # Q ragged offset tensor + q_ragged = graph.tensor( + dim=(batch_size + 1, 1, 1, 1), + stride=(1, 1, 1, 1), + data_type=cudnn.data_type.INT64, + name="Q_ragged_offset", + uid=UIDs.RAGGED_Q_UID.value, + ) + q.set_ragged_offset(q_ragged) + + # K tensor in BHSD layout + k = graph.tensor( + dim=(batch_size, h_k, max_s_kv, d_qk), + stride=(h_k * max_s_kv * d_qk, d_qk, h_k * d_qk, 1), # bshd stride order + data_type=cudnn_dtype, + name="K", + uid=UIDs.K_UID.value, + ) + + # V tensor in BHSD layout + v = graph.tensor( + dim=(batch_size, h_v, max_s_kv, d_v), + stride=(h_v * max_s_kv * d_v, d_v, h_v * d_v, 1), # bshd stride order + data_type=cudnn_dtype, + name="V", + uid=UIDs.V_UID.value, + ) + + # Sequence length tensors + seq_len_q_tensor = graph.tensor( + dim=(batch_size, 1, 1, 1), + stride=(1, 1, 1, 1), + data_type=cudnn.data_type.INT32, + name="seq_len_q", + uid=UIDs.ACTUAL_SEQ_LENS_Q_UID.value, + ) + + seq_len_kv_tensor = graph.tensor( + dim=(batch_size, 1, 1, 1), + stride=(1, 1, 1, 1), + data_type=cudnn.data_type.INT32, + name="seq_len_kv", + uid=UIDs.ACTUAL_SEQ_LENS_KV_UID.value, + ) + + # Call SDPA + o, stats = graph.sdpa( + name="sdpa_thd", + q=q, + k=k, + v=v, + attn_scale=config.attn_scale, + use_padding_mask=True, + seq_len_q=seq_len_q_tensor, + seq_len_kv=seq_len_kv_tensor, + use_causal_mask=config.is_causal, + generate_stats=False, + ) + + # Output tensor in THD layout + o.set_output(True).set_dim((batch_size, h_q, max_s_q, d_v)).set_stride((h_q * d_v, d_v, h_q * d_v, 1)).set_data_type( # bshd stride order for THD + cudnn_dtype + ) + o.set_uid(UIDs.O_UID.value) + + # O ragged offset tensor (reuse Q's ragged offset structure for d_qk == d_v) + o_ragged = graph.tensor( + dim=(batch_size + 1, 1, 1, 1), + stride=(1, 1, 1, 1), + data_type=cudnn.data_type.INT64, + name="O_ragged_offset", + uid=UIDs.RAGGED_O_UID.value, + ) + o.set_ragged_offset(o_ragged) + + # Validate and build the graph + try: + graph.validate() + except cudnn.cudnnGraphNotSupportedError as e: + pytest.skip(f"Graph not supported: {e}") + except Exception as e: + pytest.fail(f"Unexpected error during graph validation: {e}") + + try: + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + except cudnn.cudnnGraphNotSupportedError as e: + pytest.skip(f"Graph not supported after validation: {e}") + except Exception as e: + pytest.fail(f"Unexpected error after graph validation: {e}") + + add_to_cudnn_graph_cache(batch_size, h_q, h_k, h_v, d_qk, d_v, max_s_kv, config.is_causal, graph) + + return graph + + +def execute_cudnn_sdpa_thd( + cudnn_handle, + config: SDPAConfig, + q_gpu: torch.Tensor, + k_gpu: torch.Tensor, + v_gpu: torch.Tensor, + seq_len_q: torch.Tensor, + seq_len_kv: torch.Tensor, + q_ragged_offset: torch.Tensor, + o_ragged_offset: torch.Tensor, +) -> torch.Tensor: + """ + Execute cuDNN SDPA with THD layout. + + Args: + cudnn_handle: cuDNN handle + config: SDPA configuration + q_gpu: [total_q_tokens, num_heads_q, head_dim_qk] - Q in THD layout + k_gpu: [batch, num_heads_k, max_seq_kv, head_dim_qk] - K in BHSD layout + v_gpu: [batch, num_heads_v, max_seq_kv, head_dim_v] - V in BHSD layout + seq_len_q: [batch] - actual Q sequence lengths + seq_len_kv: [batch] - actual KV sequence lengths + q_ragged_offset: [batch+1, 1, 1, 1] - Q ragged offsets + o_ragged_offset: [batch+1, 1, 1, 1] - O ragged offsets + + Returns: + o_gpu: [total_q_tokens, num_heads_q, head_dim_v] - output in THD layout + """ + total_q_tokens = q_gpu.shape[0] + + # Allocate output tensor + o_gpu = torch.empty(total_q_tokens, config.num_heads_q, config.head_dim_v, dtype=config.dtype, device="cuda") + + # Reshape seq_len tensors to [batch, 1, 1, 1] + seq_len_q_4d = seq_len_q.view(-1, 1, 1, 1) + seq_len_kv_4d = seq_len_kv.view(-1, 1, 1, 1) + + # Build the graph + graph = build_cudnn_sdpa_thd_graph(cudnn_handle, config, seq_len_q, seq_len_kv, q_ragged_offset, o_ragged_offset, q_gpu, k_gpu, v_gpu, o_gpu) + + q_shape = [config.batch_size, config.num_heads_q, config.max_seq_len_q, config.head_dim_qk] + o_shape = [config.batch_size, config.num_heads_q, config.max_seq_len_q, config.head_dim_v] + # Create variant pack + variant_pack = { + UIDs.Q_UID.value: q_gpu, + UIDs.RAGGED_Q_UID.value: q_ragged_offset, + UIDs.K_UID.value: k_gpu, + UIDs.V_UID.value: v_gpu, + UIDs.ACTUAL_SEQ_LENS_Q_UID.value: seq_len_q_4d, + UIDs.ACTUAL_SEQ_LENS_KV_UID.value: seq_len_kv_4d, + UIDs.O_UID.value: o_gpu, + UIDs.RAGGED_O_UID.value: o_ragged_offset, + } + + # Allocate workspace + workspace_size = graph.get_workspace_size() + workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda") + + # Execute + stream = torch.cuda.current_stream().cuda_stream + cudnn.set_stream(handle=cudnn_handle, stream=stream) + override_uids = [ + UIDs.Q_UID.value, + UIDs.RAGGED_Q_UID.value, + UIDs.K_UID.value, + UIDs.V_UID.value, + UIDs.ACTUAL_SEQ_LENS_Q_UID.value, + UIDs.ACTUAL_SEQ_LENS_KV_UID.value, + UIDs.O_UID.value, + UIDs.RAGGED_O_UID.value, + ] + override_shapes = [ + q_shape, + q_ragged_offset.shape, + k_gpu.shape, + v_gpu.shape, + seq_len_q_4d.shape, + seq_len_kv_4d.shape, + o_shape, + o_ragged_offset.shape, + ] + override_strides = [ + q_gpu.stride(), + q_ragged_offset.stride(), + k_gpu.stride(), + v_gpu.stride(), + seq_len_q_4d.stride(), + seq_len_kv_4d.stride(), + o_gpu.stride(), + o_ragged_offset.stride(), + ] + graph.execute(variant_pack, workspace, handle=cudnn_handle, override_uids=override_uids, override_shapes=override_shapes, override_strides=override_strides) + torch.cuda.synchronize() + + return o_gpu + + +# ========================================= +# Test Functions +# ========================================= + + +def compare_outputs( + output_gpu: torch.Tensor, + output_ref: torch.Tensor, + seq_lens: torch.Tensor, + atol: float = 0.02, + rtol: float = 0.02, + tag: str = "output", +) -> int: + """ + Compare GPU output with reference, accounting for padding. + + Returns number of mismatches. + """ + # Convert THD output to BHSD for comparison if needed + if output_gpu.dim() == 3: # THD layout + # Both should be in THD already for this comparison + actual = output_gpu.float() + expected = output_ref.float() + else: + actual = output_gpu.float() + expected = output_ref.float() + + mismatches = torch.where(torch.isclose(actual, expected, rtol=rtol, atol=atol) == False) + mismatch_cnt = mismatches[0].numel() + + if mismatch_cnt > 0: + percentage = 100 * mismatch_cnt / actual.numel() + print(f"\n{tag}: {mismatch_cnt:,} mismatches ({percentage:.2f}%)") + + # Show first few mismatches + for idx in range(min(10, mismatch_cnt)): + pos = tuple(m[idx].item() for m in mismatches) + diff = actual[pos] - expected[pos] + print(f" idx{pos}: gpu={actual[pos]:+.6e}, ref={expected[pos]:+.6e}, diff={diff:+.2e}") + else: + print(f"{tag}: All values match within tolerance (atol={atol}, rtol={rtol})") + + return mismatch_cnt + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=42) +def test_sdpa_thd_dynamic_shapes(cudnn_handle): + """Basic test for SDPA with THD layout.""" + cudnn_version = LooseVersion(cudnn.backend_version_string()) + if cudnn_version < "9.10.0": + pytest.skip("THD layout requires cuDNN 9.10.0 or higher") + + if torch.cuda.get_device_capability()[0] < 9: + pytest.skip("SDPA with THD layout requires SM90 or higher") + + print("\n" + "=" * 80) + print("Test: SDPA with THD layout (basic)") + print("=" * 80) + + # Configurations + configs = [ + SDPAConfig( + batch_size=4, + num_heads_q=8, + num_heads_k=8, + num_heads_v=8, + head_dim_qk=128, + head_dim_v=128, + max_seq_len_q=256, + max_seq_len_kv=512, + dtype=torch.bfloat16, + is_causal=False, + ), + SDPAConfig( + batch_size=4, + num_heads_q=8, + num_heads_k=8, + num_heads_v=8, + head_dim_qk=128, + head_dim_v=128, + max_seq_len_q=512, + max_seq_len_kv=512, + dtype=torch.bfloat16, + is_causal=False, + ), + SDPAConfig( + batch_size=4, + num_heads_q=8, + num_heads_k=8, + num_heads_v=8, + head_dim_qk=128, + head_dim_v=128, + max_seq_len_q=384, + max_seq_len_kv=512, + dtype=torch.bfloat16, + is_causal=False, + ), + ] + + rng = torch.Generator(device="cuda").manual_seed(42) + for config in configs: + print( + f"Config: batch={config.batch_size}, h_q={config.num_heads_q}, " + f"h_k={config.num_heads_k}, h_v={config.num_heads_v}, " + f"d_qk={config.head_dim_qk}, d_v={config.head_dim_v}, " + f"max_s_q={config.max_seq_len_q}, max_s_kv={config.max_seq_len_kv}" + ) + + # Generate variable sequence lengths + seq_len_q, seq_len_kv = generate_variable_seq_lens(config.batch_size, config.max_seq_len_q, config.max_seq_len_kv, rng) + print(f"seq_len_q: {seq_len_q.tolist()}") + print(f"seq_len_kv: {seq_len_kv.tolist()}") + + # Create Q in THD layout + q_thd, q_ragged_offset = create_thd_tensor(seq_len_q, config.num_heads_q, config.head_dim_qk, config.dtype, rng) + print(f"Q shape (THD): {q_thd.shape}") + + # Create K, V in BHSD layout + k_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_k, config.max_seq_len_kv, config.head_dim_qk, config.dtype, rng) + v_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_v, config.max_seq_len_kv, config.head_dim_v, config.dtype, rng) + print(f"K shape (BHSD): {k_bhsd.shape}") + print(f"V shape (BHSD): {v_bhsd.shape}") + + # Compute O ragged offsets (same structure as Q for d_qk == d_v) + o_ragged_offset = compute_ragged_offsets(seq_len_q, config.num_heads_q, config.head_dim_v) + + # Execute cuDNN SDPA + print("\nExecuting cuDNN SDPA with THD layout...") + o_thd_gpu = execute_cudnn_sdpa_thd(cudnn_handle, config, q_thd, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, q_ragged_offset, o_ragged_offset) + print(f"Output shape (THD): {o_thd_gpu.shape}") + + # Compute reference + print("\nComputing reference output...") + # Convert Q from THD to BHSD for reference + q_bhsd_ref = thd_to_bhsd(q_thd, seq_len_q, config.max_seq_len_q) + + # print(f"Q BHSD shape: {q_bhsd_ref.shape} {q_bhsd_ref[0, :, :, 0:10]}") + # print(f"Q THD shape: {q_thd.shape} {q_thd[0, :, 0:10]}") + + o_bhsd_ref = compute_sdpa_reference(q_bhsd_ref, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, config.attn_scale, config.is_causal) + + # Convert reference output from BHSD to THD + o_thd_ref = bhsd_to_thd(o_bhsd_ref, seq_len_q) + + # Compare outputs + print("\nComparing outputs...") + err_count = compare_outputs(o_thd_gpu, o_thd_ref, seq_len_q, atol=0.02, rtol=0.02) + + if err_count > 0: + pytest.fail(f"SDPA THD test failed with {err_count} mismatches") + else: + print("\n" + "=" * 80) + print("TEST PASSED: SDPA with THD layout") + print("=" * 80) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=123) +def test_sdpa_thd_gqa(cudnn_handle): + """Test SDPA with THD layout and GQA (Grouped Query Attention).""" + cudnn_version = LooseVersion(cudnn.backend_version_string()) + if cudnn_version < "9.10.0": + pytest.skip("THD layout requires cuDNN 9.10.0 or higher") + + if torch.cuda.get_device_capability()[0] < 8: + pytest.skip("SDPA with THD layout requires SM80 or higher") + + print("\n" + "=" * 80) + print("Test: SDPA with THD layout + GQA") + print("=" * 80) + + # Configuration with GQA (h_q > h_k = h_v) + config = SDPAConfig( + batch_size=4, + num_heads_q=8, + num_heads_k=2, # GQA: 4 Q heads per K head + num_heads_v=2, + head_dim_qk=64, + head_dim_v=64, + max_seq_len_q=128, + max_seq_len_kv=256, + dtype=torch.bfloat16, + is_causal=False, + ) + + print( + f"Config: batch={config.batch_size}, h_q={config.num_heads_q}, " + f"h_k={config.num_heads_k}, h_v={config.num_heads_v}, " + f"d_qk={config.head_dim_qk}, d_v={config.head_dim_v}" + ) + + rng = torch.Generator(device="cuda").manual_seed(123) + + # Generate variable sequence lengths + seq_len_q, seq_len_kv = generate_variable_seq_lens(config.batch_size, config.max_seq_len_q, config.max_seq_len_kv, rng) + print(f"seq_len_q: {seq_len_q.tolist()}") + print(f"seq_len_kv: {seq_len_kv.tolist()}") + + # Create Q in THD layout + q_thd, q_ragged_offset = create_thd_tensor(seq_len_q, config.num_heads_q, config.head_dim_qk, config.dtype, rng) + + # Create K, V in BHSD layout + k_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_k, config.max_seq_len_kv, config.head_dim_qk, config.dtype, rng) + v_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_v, config.max_seq_len_kv, config.head_dim_v, config.dtype, rng) + + # Compute O ragged offsets + o_ragged_offset = compute_ragged_offsets(seq_len_q, config.num_heads_q, config.head_dim_v) + + # Execute cuDNN SDPA + print("\nExecuting cuDNN SDPA with THD + GQA...") + o_thd_gpu = execute_cudnn_sdpa_thd(cudnn_handle, config, q_thd, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, q_ragged_offset, o_ragged_offset) + + # Compute reference + print("\nComputing reference output...") + q_bhsd_ref = thd_to_bhsd(q_thd, seq_len_q, config.max_seq_len_q) + o_bhsd_ref = compute_sdpa_reference(q_bhsd_ref, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, config.attn_scale, config.is_causal) + o_thd_ref = bhsd_to_thd(o_bhsd_ref, seq_len_q) + + # Compare outputs + print("\nComparing outputs...") + err_count = compare_outputs(o_thd_gpu, o_thd_ref, seq_len_q, atol=0.02, rtol=0.02) + + if err_count > 0: + pytest.fail(f"SDPA THD + GQA test failed with {err_count} mismatches") + else: + print("\n" + "=" * 80) + print("TEST PASSED: SDPA with THD layout + GQA") + print("=" * 80) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=456) +def test_sdpa_thd_causal(cudnn_handle): + """Test SDPA with THD layout and causal masking.""" + cudnn_version = LooseVersion(cudnn.backend_version_string()) + if cudnn_version < "9.10.0": + pytest.skip("THD layout requires cuDNN 9.10.0 or higher") + + if torch.cuda.get_device_capability()[0] < 8: + pytest.skip("SDPA with THD layout requires SM80 or higher") + + print("\n" + "=" * 80) + print("Test: SDPA with THD layout + Causal Masking") + print("=" * 80) + + # Configuration + config = SDPAConfig( + batch_size=4, + num_heads_q=8, + num_heads_k=8, + num_heads_v=8, + head_dim_qk=64, + head_dim_v=64, + max_seq_len_q=256, + max_seq_len_kv=256, + dtype=torch.bfloat16, + is_causal=True, # Enable causal masking + ) + + print(f"Config: batch={config.batch_size}, h_q={config.num_heads_q}, " f"is_causal={config.is_causal}") + + rng = torch.Generator(device="cuda").manual_seed(456) + + # Generate variable sequence lengths (ensure seq_len_q == seq_len_kv for causal) + seq_len_q, seq_len_kv = generate_variable_seq_lens(config.batch_size, config.max_seq_len_q, config.max_seq_len_kv, rng) + # For causal attention, typically s_q == s_kv + seq_len_kv = seq_len_q.clone() + + print(f"seq_len_q: {seq_len_q.tolist()}") + print(f"seq_len_kv: {seq_len_kv.tolist()}") + + # Create tensors + q_thd, q_ragged_offset = create_thd_tensor(seq_len_q, config.num_heads_q, config.head_dim_qk, config.dtype, rng) + k_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_k, config.max_seq_len_kv, config.head_dim_qk, config.dtype, rng) + v_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_v, config.max_seq_len_kv, config.head_dim_v, config.dtype, rng) + o_ragged_offset = compute_ragged_offsets(seq_len_q, config.num_heads_q, config.head_dim_v) + + # Execute cuDNN SDPA + print("\nExecuting cuDNN SDPA with THD + Causal...") + o_thd_gpu = execute_cudnn_sdpa_thd(cudnn_handle, config, q_thd, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, q_ragged_offset, o_ragged_offset) + + # Compute reference + print("\nComputing reference output...") + q_bhsd_ref = thd_to_bhsd(q_thd, seq_len_q, config.max_seq_len_q) + o_bhsd_ref = compute_sdpa_reference(q_bhsd_ref, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, config.attn_scale, config.is_causal) + o_thd_ref = bhsd_to_thd(o_bhsd_ref, seq_len_q) + + # Compare outputs + print("\nComparing outputs...") + err_count = compare_outputs(o_thd_gpu, o_thd_ref, seq_len_q, atol=0.02, rtol=0.02) + + if err_count > 0: + pytest.fail(f"SDPA THD + Causal test failed with {err_count} mismatches") + else: + print("\n" + "=" * 80) + print("TEST PASSED: SDPA with THD layout + Causal Masking") + print("=" * 80) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=789) +def test_sdpa_thd_seq1(cudnn_handle): + """Test SDPA with THD layout for seq_len_q=1 (decode-like).""" + cudnn_version = LooseVersion(cudnn.backend_version_string()) + if cudnn_version < "9.10.0": + pytest.skip("THD layout requires cuDNN 9.10.0 or higher") + + if torch.cuda.get_device_capability()[0] < 8: + pytest.skip("SDPA with THD layout requires SM80 or higher") + + print("\n" + "=" * 80) + print("Test: SDPA with THD layout (seq_q=1, decode-like)") + print("=" * 80) + + # Configuration for decode-like scenario (s_q=1) + config = SDPAConfig( + batch_size=8, + num_heads_q=8, + num_heads_k=8, + num_heads_v=8, + head_dim_qk=128, + head_dim_v=128, + max_seq_len_q=1, # Single token query + max_seq_len_kv=512, # Long context + dtype=torch.bfloat16, + is_causal=False, + ) + + print( + f"Config: batch={config.batch_size}, h={config.num_heads_q}, " + f"d={config.head_dim_qk}, max_s_q={config.max_seq_len_q}, max_s_kv={config.max_seq_len_kv}" + ) + + rng = torch.Generator(device="cuda").manual_seed(789) + + # All batches have seq_len_q = 1 + seq_len_q = torch.ones(config.batch_size, dtype=torch.int32, device="cuda") + # Variable KV lengths + seq_len_kv = torch.randint(1, config.max_seq_len_kv + 1, (config.batch_size,), generator=rng, dtype=torch.int32, device="cuda") + print(f"seq_len_q: {seq_len_q.tolist()}") + print(f"seq_len_kv: {seq_len_kv.tolist()}") + + # Create tensors + q_thd, q_ragged_offset = create_thd_tensor(seq_len_q, config.num_heads_q, config.head_dim_qk, config.dtype, rng) + k_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_k, config.max_seq_len_kv, config.head_dim_qk, config.dtype, rng) + v_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_v, config.max_seq_len_kv, config.head_dim_v, config.dtype, rng) + o_ragged_offset = compute_ragged_offsets(seq_len_q, config.num_heads_q, config.head_dim_v) + + print(f"Q shape (THD): {q_thd.shape}") + + # Execute cuDNN SDPA + print("\nExecuting cuDNN SDPA with THD (seq_q=1)...") + o_thd_gpu = execute_cudnn_sdpa_thd(cudnn_handle, config, q_thd, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, q_ragged_offset, o_ragged_offset) + + # Compute reference + print("\nComputing reference output...") + q_bhsd_ref = thd_to_bhsd(q_thd, seq_len_q, config.max_seq_len_q) + o_bhsd_ref = compute_sdpa_reference(q_bhsd_ref, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, config.attn_scale, config.is_causal) + o_thd_ref = bhsd_to_thd(o_bhsd_ref, seq_len_q) + + # Compare outputs + print("\nComparing outputs...") + err_count = compare_outputs(o_thd_gpu, o_thd_ref, seq_len_q, atol=0.02, rtol=0.02) + + if err_count > 0: + pytest.fail(f"SDPA THD seq_q=1 test failed with {err_count} mismatches") + else: + print("\n" + "=" * 80) + print("TEST PASSED: SDPA with THD layout (seq_q=1)") + print("=" * 80) + + +@pytest.mark.L0 +@torch_fork_set_rng(seed=999) +def test_sdpa_thd_large_batch(cudnn_handle): + """Test SDPA with THD layout with larger batch sizes.""" + cudnn_version = LooseVersion(cudnn.backend_version_string()) + if cudnn_version < "9.10.0": + pytest.skip("THD layout requires cuDNN 9.10.0 or higher") + + if torch.cuda.get_device_capability()[0] < 8: + pytest.skip("SDPA with THD layout requires SM80 or higher") + + print("\n" + "=" * 80) + print("Test: SDPA with THD layout (large batch)") + print("=" * 80) + + # Configuration with larger batch + config = SDPAConfig( + batch_size=32, + num_heads_q=8, + num_heads_k=2, # GQA + num_heads_v=2, + head_dim_qk=128, + head_dim_v=128, + max_seq_len_q=512, + max_seq_len_kv=512, + dtype=torch.bfloat16, + is_causal=False, + ) + + print(f"Config: batch={config.batch_size}, h_q={config.num_heads_q}, " f"h_k={config.num_heads_k}, d={config.head_dim_qk}") + + rng = torch.Generator(device="cuda").manual_seed(999) + + # Generate variable sequence lengths + seq_len_q, seq_len_kv = generate_variable_seq_lens(config.batch_size, config.max_seq_len_q, config.max_seq_len_kv, rng) + print(f"seq_len_q range: [{seq_len_q.min().item()}, {seq_len_q.max().item()}]") + print(f"seq_len_kv range: [{seq_len_kv.min().item()}, {seq_len_kv.max().item()}]") + print(f"Total Q tokens: {seq_len_q.sum().item()}") + + # Create tensors + q_thd, q_ragged_offset = create_thd_tensor(seq_len_q, config.num_heads_q, config.head_dim_qk, config.dtype, rng) + k_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_k, config.max_seq_len_kv, config.head_dim_qk, config.dtype, rng) + v_bhsd = create_bhsd_tensor(config.batch_size, config.num_heads_v, config.max_seq_len_kv, config.head_dim_v, config.dtype, rng) + o_ragged_offset = compute_ragged_offsets(seq_len_q, config.num_heads_q, config.head_dim_v) + + # Execute cuDNN SDPA + print("\nExecuting cuDNN SDPA with THD (large batch)...") + o_thd_gpu = execute_cudnn_sdpa_thd(cudnn_handle, config, q_thd, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, q_ragged_offset, o_ragged_offset) + + # Compute reference + print("\nComputing reference output...") + q_bhsd_ref = thd_to_bhsd(q_thd, seq_len_q, config.max_seq_len_q) + o_bhsd_ref = compute_sdpa_reference(q_bhsd_ref, k_bhsd, v_bhsd, seq_len_q, seq_len_kv, config.attn_scale, config.is_causal) + o_thd_ref = bhsd_to_thd(o_bhsd_ref, seq_len_q) + + # Compare outputs + print("\nComparing outputs...") + err_count = compare_outputs(o_thd_gpu, o_thd_ref, seq_len_q, atol=0.02, rtol=0.02) + + if err_count > 0: + pytest.fail(f"SDPA THD large batch test failed with {err_count} mismatches") + else: + print("\n" + "=" * 80) + print("TEST PASSED: SDPA with THD layout (large batch)") + print("=" * 80) + + +# ========================================= +# Main Entry Point +# ========================================= + +if __name__ == "__main__": + print("This is a pytest script.") + print("Run with: pytest -vv -s -rA test_sdpa_dynamic_shapes.py") diff --git a/test/python/test_sdpa_with_caching.py b/test/python/test_sdpa_with_caching.py index 51437be5..c5640ee5 100644 --- a/test/python/test_sdpa_with_caching.py +++ b/test/python/test_sdpa_with_caching.py @@ -28,9 +28,7 @@ cuda_graphs = {} H_Q = H_K = H_V = 6 -D_QK = D_VO = ( - 128 # If you are changing D_VO != D_QK, you need to change the code in create_qkv_tensors for ragged offsets of O -) +D_QK = D_VO = 128 # If you are changing D_VO != D_QK, you need to change the code in create_qkv_tensors for ragged offsets of O MAX_SEQ_LEN_Q = 1024 MAX_SEQ_LEN_KV = 1024 @@ -174,11 +172,9 @@ def lookup_or_create_sdpa_graph(handle, batch_size): compute_data_type=cudnn.data_type.FLOAT, ) - O.set_uid(UIDs.O_UID.value).set_output(True).set_dim( - [batch_size, H_Q, MAX_SEQ_LEN_Q, D_VO] - ).set_stride([MAX_SEQ_LEN_Q * D_VO * H_Q, D_VO, D_VO * H_Q, 1]).set_data_type( - cudnn.data_type.BFLOAT16 - ) + O.set_uid(UIDs.O_UID.value).set_output(True).set_dim([batch_size, H_Q, MAX_SEQ_LEN_Q, D_VO]).set_stride( + [MAX_SEQ_LEN_Q * D_VO * H_Q, D_VO, D_VO * H_Q, 1] + ).set_data_type(cudnn.data_type.BFLOAT16) O.set_ragged_offset(ragged_q) @@ -193,9 +189,7 @@ def lookup_or_create_sdpa_graph(handle, batch_size): def pad_batch_size(batch_size, actual_seq_lens_q, actual_seq_lens_kv, ragged_offset_q): batch_buckets_keys = list(batch_buckets.keys()) - batch_size_padded = next( - (b for b in batch_buckets_keys if b >= batch_size), batch_buckets_keys[-1] - ) + batch_size_padded = next((b for b in batch_buckets_keys if b >= batch_size), batch_buckets_keys[-1]) zeros = torch.zeros( (batch_size_padded - batch_size, 1, 1, 1), dtype=actual_seq_lens_q.dtype, @@ -242,9 +236,7 @@ def test_ragged_sdpa_with_caching(cudnn_handle): # For example, you can bucket by sequence length, or masking pattern, etc. for _batch_size in batch_buckets.keys(): - batch_buckets[_batch_size] = lookup_or_create_sdpa_graph( - cudnn_handle, _batch_size - ) + batch_buckets[_batch_size] = lookup_or_create_sdpa_graph(cudnn_handle, _batch_size) logger.info(f"Buckets initialized") @@ -271,9 +263,7 @@ def test_ragged_sdpa_with_caching(cudnn_handle): device=device, ) - q_gpu, k_gpu, v_gpu, ragged_offset_q, out_gpu = create_qkv_tensors( - batch_size, actual_seq_lens_q, actual_seq_lens_kv - ) + q_gpu, k_gpu, v_gpu, ragged_offset_q, out_gpu = create_qkv_tensors(batch_size, actual_seq_lens_q, actual_seq_lens_kv) samples.append( ( @@ -313,14 +303,10 @@ def test_ragged_sdpa_with_caching(cudnn_handle): padded_actual_seq_lens_q, padded_actual_seq_lens_kv, padded_ragged_offset_q, - ) = pad_batch_size( - batch_size, actual_seq_lens_q, actual_seq_lens_kv, ragged_offset_q - ) + ) = pad_batch_size(batch_size, actual_seq_lens_q, actual_seq_lens_kv, ragged_offset_q) torch.cuda.nvtx.range_pop() - logger.info( - f"Executing the sample with actual batch_size: {batch_size} and padded_batch_size: {padded_batch_size}" - ) + logger.info(f"Executing the sample with actual batch_size: {batch_size} and padded_batch_size: {padded_batch_size}") # This will not create a new graph, it will return the graph from the bucket by the key function torch.cuda.nvtx.range_push("Look up the graph") diff --git a/test/python/test_silu_and_mul.py b/test/python/test_silu_and_mul.py index ad7bcd54..c6bc472c 100644 --- a/test/python/test_silu_and_mul.py +++ b/test/python/test_silu_and_mul.py @@ -30,15 +30,9 @@ def test_gemm_silu_and_mul(cudnn_handle): compute_data_type=cudnn.data_type.FLOAT, ) - X_gpu = torch.randint(-8, 8, (1, M, K), requires_grad=False, device="cuda").to( - dtype=torch.float8_e4m3fn - ) - W_gpu = torch.randint(-8, 8, (2, K, N), requires_grad=False, device="cuda").to( - dtype=torch.float8_e4m3fn - ) - C_gpu = torch.zeros(1, M, N, requires_grad=False, device="cuda").to( - dtype=torch.float - ) + X_gpu = torch.randint(-8, 8, (1, M, K), requires_grad=False, device="cuda").to(dtype=torch.float8_e4m3fn) + W_gpu = torch.randint(-8, 8, (2, K, N), requires_grad=False, device="cuda").to(dtype=torch.float8_e4m3fn) + C_gpu = torch.zeros(1, M, N, requires_grad=False, device="cuda").to(dtype=torch.float) scale = 0.5 X_DQ_cpu = torch.full((1, 1, 1), scale, dtype=torch.float32, device="cpu") @@ -89,9 +83,7 @@ def test_gemm_silu_and_mul(cudnn_handle): C_combined = graph.binary_select(C2, C3, B_mask) C = graph.reduction(C_combined, mode=cudnn.reduction_mode.MUL) - C.set_dim([1, M, N]).set_stride([M * N, N, 1]).set_output(True).set_data_type( - cudnn.data_type.FLOAT - ) + C.set_dim([1, M, N]).set_stride([M * N, N, 1]).set_output(True).set_data_type(cudnn.data_type.FLOAT) # The output of reductino operation has to be fp32. # Plus, the data is in global memory so its not possible to fuse anything now. @@ -112,9 +104,7 @@ def test_gemm_silu_and_mul(cudnn_handle): except Exception as e: pytest.fail(repr(e)) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) with profile(activities=[ProfilerActivity.CUDA]) as prof: graph.execute( @@ -154,15 +144,9 @@ def test_silu_and_mul_and_quantization(cudnn_handle): compute_data_type=cudnn.data_type.FLOAT, ) - C2a_gpu = torch.randint(-8, 8, (1, M, N), requires_grad=False, device="cuda").to( - dtype=torch.float8_e4m3fn - ) - C2b_gpu = torch.randint(-8, 8, (1, M, N), requires_grad=False, device="cuda").to( - dtype=torch.float8_e4m3fn - ) - C_gpu = torch.empty(1, M, N, requires_grad=False, device="cuda").to( - dtype=torch.float8_e4m3fn - ) + C2a_gpu = torch.randint(-8, 8, (1, M, N), requires_grad=False, device="cuda").to(dtype=torch.float8_e4m3fn) + C2b_gpu = torch.randint(-8, 8, (1, M, N), requires_grad=False, device="cuda").to(dtype=torch.float8_e4m3fn) + C_gpu = torch.empty(1, M, N, requires_grad=False, device="cuda").to(dtype=torch.float8_e4m3fn) scale = 0.5 C2_DQ_cpu = torch.full((1, 1, 1), scale, dtype=torch.float32, device="cpu") @@ -214,9 +198,7 @@ def test_silu_and_mul_and_quantization(cudnn_handle): graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) with profile(activities=[ProfilerActivity.CUDA]) as prof: graph.execute( diff --git a/test/python/test_slice.py b/test/python/test_slice.py index eb7d80ae..745d91ed 100644 --- a/test/python/test_slice.py +++ b/test/python/test_slice.py @@ -7,9 +7,7 @@ from test_utils import torch_fork_set_rng -@pytest.mark.skipif( - torch.cuda.get_device_capability()[0] < 9, reason="requires Hopper or newer arch" -) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires Hopper or newer arch") @pytest.mark.L0 @torch_fork_set_rng(seed=0) def test_int8_bf16_matmul_slice(cudnn_handle): @@ -22,22 +20,10 @@ def test_int8_bf16_matmul_slice(cudnn_handle): slice_K = slice(None, None) # Initialize input tensors - A_gpu = ( - 2 - * torch.randn( - Batch, M, K, requires_grad=False, device="cuda", dtype=torch.bfloat16 - ) - - 0.25 - ) + A_gpu = 2 * torch.randn(Batch, M, K, requires_grad=False, device="cuda", dtype=torch.bfloat16) - 0.25 A_slice_gpu = A_gpu[slice_B, slice_M, :] - B_gpu = ( - 3 - * torch.randn( - Batch, K, N, requires_grad=False, device="cuda", dtype=torch.bfloat16 - ) - - 1.25 - ) + B_gpu = 3 * torch.randn(Batch, K, N, requires_grad=False, device="cuda", dtype=torch.bfloat16) - 1.25 B_slice_gpu = B_gpu[slice_B, :, slice_N] stream = torch.cuda.current_stream().cuda_stream @@ -53,9 +39,7 @@ def test_int8_bf16_matmul_slice(cudnn_handle): B = graph.tensor_like(B_gpu) B_slice = graph.slice(B, [slice_B, slice_K, slice_N], name="B_slice") - C = graph.matmul( - name="matmul", A=A_slice, B=B_slice, compute_data_type=cudnn.data_type.FLOAT - ) + C = graph.matmul(name="matmul", A=A_slice, B=B_slice, compute_data_type=cudnn.data_type.FLOAT) C.set_output(True).set_data_type(cudnn.data_type.BFLOAT16) graph.validate() @@ -71,15 +55,11 @@ def test_int8_bf16_matmul_slice(cudnn_handle): graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) # Run pyt reference - C_expected = torch.matmul( - A_slice_gpu.to(torch.bfloat16), B_slice_gpu.to(torch.bfloat16) - ) + C_expected = torch.matmul(A_slice_gpu.to(torch.bfloat16), B_slice_gpu.to(torch.bfloat16)) # Run cudnn graph C_actual = torch.zeros_like(C_expected) - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute({A: A_gpu, B: B_gpu, C: C_actual}, workspace, handle=cudnn_handle) print(A_gpu.data_ptr()) torch.cuda.synchronize() diff --git a/test/python/test_wgrads.py b/test/python/test_wgrads.py index c3f030bb..2d1a74b5 100644 --- a/test/python/test_wgrads.py +++ b/test/python/test_wgrads.py @@ -7,13 +7,13 @@ def is_ampere_arch(): - (major, minor) = torch.cuda.get_device_capability() + major, minor = torch.cuda.get_device_capability() cc = major * 10 + minor return 80 <= cc and cc < 89 def is_hopper_arch(): - (major, minor) = torch.cuda.get_device_capability() + major, minor = torch.cuda.get_device_capability() cc = major * 10 + minor return 90 <= cc @@ -39,27 +39,11 @@ def test_scale_bias_relu_wgrad(cudnn_handle): pytest.skip("SBR Wgrad is only supported on ampere and hopper.") # Reference - X_gpu = torch.randn( - n, c, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - DY_gpu = torch.randn( - n, k, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) - scale = ( - torch.randn(1, c, 1, 1, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) - * 0.01 - ) - bias = ( - torch.randn(1, c, 1, 1, device="cuda", dtype=torch.float16).to( - memory_format=torch.channels_last - ) - * 0.01 - ) - DW_actual = torch.randn( - k, c, 3, 3, requires_grad=False, device="cuda", dtype=torch.float16 - ).to(memory_format=torch.channels_last) + X_gpu = torch.randn(n, c, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + DY_gpu = torch.randn(n, k, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + scale = torch.randn(1, c, 1, 1, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) * 0.01 + bias = torch.randn(1, c, 1, 1, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) * 0.01 + DW_actual = torch.randn(k, c, 3, 3, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) @@ -72,18 +56,10 @@ def test_scale_bias_relu_wgrad(cudnn_handle): ) # X = graph.tensor(name = "X", dim = X_gpu.size(), stride = X_gpu.stride(), data_type = cudnn._compiled_module.data_type.DOUBLE) - X = graph.tensor( - name="X", dim=X_gpu.size(), stride=X_gpu.stride(), data_type=X_gpu.dtype - ) - DY = graph.tensor( - name="DY", dim=DY_gpu.size(), stride=DY_gpu.stride(), data_type=DY_gpu.dtype - ) - B = graph.tensor( - name="B", dim=bias.size(), stride=bias.stride(), data_type=bias.dtype - ) - S = graph.tensor( - name="S", dim=scale.size(), stride=scale.stride(), data_type=scale.dtype - ) + X = graph.tensor(name="X", dim=X_gpu.size(), stride=X_gpu.stride(), data_type=X_gpu.dtype) + DY = graph.tensor(name="DY", dim=DY_gpu.size(), stride=DY_gpu.stride(), data_type=DY_gpu.dtype) + B = graph.tensor(name="B", dim=bias.size(), stride=bias.stride(), data_type=bias.dtype) + S = graph.tensor(name="S", dim=scale.size(), stride=scale.stride(), data_type=scale.dtype) scale_output = graph.scale(name="scale", input=X, scale=S) bias_output = graph.bias(name="bias", input=scale_output, bias=B) @@ -106,9 +82,7 @@ def test_scale_bias_relu_wgrad(cudnn_handle): graph.check_support() graph.build_plans() - workspace = torch.empty( - graph.get_workspace_size(), device="cuda", dtype=torch.uint8 - ) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) DW_actual = torch.zeros_like(X_gpu) diff --git a/tools/json_reproducer/json_parser.py b/tools/json_reproducer/json_parser.py index 6df3d411..81e568a1 100644 --- a/tools/json_reproducer/json_parser.py +++ b/tools/json_reproducer/json_parser.py @@ -31,7 +31,7 @@ graph.deserialize(data) - graph.build([cudnn.heur_mode.A]) + graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) print("Graph built successfully and can be executed.")