diff --git a/.gitmodules b/.gitmodules index 470cf466..107505e0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "third_party/eigen"] path = third_party/eigen url = git@github.com:InfiniTensor/eigen-mirror.git +[submodule "third_party/flash_attention"] + path = third_party/flash_attention + url = https://github.com/Dao-AILab/flash-attention.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 74536707..a12258e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,15 +75,24 @@ if(USE_CUDA) add_compile_definitions(USE_CUDA=1) enable_language(CUDA) find_package(CUDAToolkit REQUIRED) + + # ========== cuDNN 库 ========== + find_library(CUDNN_LIBRARY cudnn REQUIRED) + message(STATUS "Found cuDNN at: ${CUDNN_LIBRARY}") + # ======================================== + include_directories(${CUDAToolkit_INCLUDE_DIRS}) # CUDA compilation options set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") - # Only compile CUDA kernels / cuda sources here (your original used src/*.cu) + # Only compile CUDA kernels / cuda sources here file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu) add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS}) + target_include_directories(infini_train_cuda_kernels PUBLIC + ${PROJECT_SOURCE_DIR}/third_party/cudnn-frontend/include + ) set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90") target_link_libraries(infini_train_cuda_kernels @@ -92,6 +101,7 @@ if(USE_CUDA) CUDA::cudart CUDA::cublas CUDA::cuda_driver + ${CUDNN_LIBRARY} ) if(USE_NCCL) @@ -116,8 +126,6 @@ target_link_libraries(infini_train ) if(USE_CUDA) - # infini_train contains cuda runtime wrappers (*.cc) like cuda_blas_handle.cc/cuda_guard.cc - # Those may need CUDA runtime/driver/cublas symbols at final link, so attach them here too. 
target_link_libraries(infini_train PUBLIC infini_train_cuda_kernels @@ -127,15 +135,12 @@ if(USE_CUDA) ) if(USE_NCCL) - # If your core library code also directly references NCCL symbols (not only kernels), - # keep this. Otherwise it's harmless. target_link_libraries(infini_train PUBLIC nccl) endif() endif() # ------------------------------------------------------------------------------ # Helper: link libraries in a group to fix static lib one-pass resolution -# (THIS is what fixes "undefined reference" from cuda_kernels -> core symbols) # ------------------------------------------------------------------------------ function(link_infini_train_exe target_name) if(USE_CUDA) @@ -160,7 +165,6 @@ function(link_infini_train_exe target_name) endif() endfunction() - # ------------------------------------------------------------------------------ # Examples # ------------------------------------------------------------------------------ @@ -199,4 +203,4 @@ add_executable(test_hook test/hook/test_hook.cc) target_link_libraries(test_hook infini_train) add_executable(test_precision_check test/hook/test_precision_check.cc) -target_link_libraries(test_precision_check infini_train) +target_link_libraries(test_precision_check infini_train) \ No newline at end of file diff --git "a/InfiniTrain\346\212\245\345\221\212.md" "b/InfiniTrain\346\212\245\345\221\212.md" new file mode 100644 index 00000000..17c815a8 --- /dev/null +++ "b/InfiniTrain\346\212\245\345\221\212.md" @@ -0,0 +1,88 @@ +# InfiniTrain 作业报告 + +## 1. 功能正确性验证 +gpt2_1_bfloat16 +![alt text](image-3.png) +gpt2_bfloat16_flash +![alt text](image-4.png) +llama3_1_bfloat16 +![alt text](image-2.png) +llama3_1_bfloat16_flash +![alt text](image-5.png) + + +## 2. 
性能评估报告 +### 2.1 实验环境说明 + +**硬件环境** +- GPU 型号:NVIDIA A100-SXM4-80GB +- 单卡显存:81920 MiB(80GB) +- 机器总卡数:8 张(index 0~7) +- 本次测试可见设备:`CUDA_VISIBLE_DEVICES=4,5,6,7` +- 实际并行配置:日志中 `DP=1, TP=1, SP=1, PP=1`,即单进程单卡执行 + +**软件环境** +- CUDA:12.8(`nvcc` build `cuda_12.8.r12.8`) +- Driver:570.133.20 +- C++ 编译器:`c++ (Ubuntu 13.3.0) 13.3.0` +- CMake:3.31.4 +- 编译命令:`cmake -DUSE_CUDA=ON -DUSE_NCCL=ON .. && make -j` + +### 2.2 实验配置 + +基于四个日志文件: +- `gpt2_1_bfloat16.log`(baseline) +- `gpt2_1_bfloat16_fla.log`(FlashAttention) +- `llama3_1_bfloat16.log`(baseline) +- `llama3_1_bfloat16_fla.log`(FlashAttention) + +关键参数(由程序默认参数与命令行确认): +- `dtype=bfloat16` +- `batch_size=4` +- `sequence_length=64` +- `total_batch_size=256 tokens/step` +- 训练步数:10 steps +- baseline:小算子拼接版本(不加 `--flash true`) +- 实验组:FlashAttention 融合算子版本(`--flash true`) + +> 说明:为减少首步冷启动影响,下面主表采用 **step 2~10** 的均值作为稳态指标。 + +### 2.3 性能指标定义 + +- 平均时延(avg latency):每步迭代耗时均值(ms) +- 吞吐率(tokens/s):日志中的每步 tokens/s 均值 +- GPU 显存占用(MB):日志 `peak used` 的峰值(max) +- 加速比:$\text{Speedup} = \frac{\text{Latency}_{baseline}}{\text{Latency}_{flash}}$ +- 显存节省比例:$\text{MemSaving} = \frac{\text{Mem}_{baseline}-\text{Mem}_{flash}}{\text{Mem}_{baseline}} \times 100\%$ + +### 2.4 结果展示(baseline vs FlashAttention) + +| 模型 | 方案 | Avg Latency (ms) | Throughput (tok/s) | Peak Used (MB) | +|---|---|---:|---:|---:| +| GPT2 | baseline | 119.71 | 2153.67 | 1914 | +| GPT2 | FlashAttention | 63.58 | 4057.67 | 3056 | +| LLaMA3 | baseline | 768.33 | 333.78 | 24561 | +| LLaMA3 | FlashAttention | 336.90 | 765.33 | 26552 | + +**汇总指标(按模型聚合)** + +| 模型 | Speedup (baseline/flash) | 吞吐提升 (flash/baseline) | 显存节省比例 | +|---|---:|---:|---:| +| GPT2 | 1.88x | 1.88x | -59.67% | +| LLaMA3 | 2.28x | 2.29x | -8.11% | + +### 2.5 结论分析 + +1. **GPT2 上 FlashAttention 提升明显**: + - 时延从 119.71 ms 降到 63.58 ms,Speedup 为 **1.88x**; + - 吞吐从 2153.67 提升到 4057.67 tok/s(约 **1.88x**)。 + +2. 
**LLaMA3 上收益显著**: + - 时延从 768.33 ms 降到 336.90 ms,Speedup 为 **2.28x**; + - 吞吐从 333.78 提升到 765.33 tok/s(约 **2.29x**)。 + +3. **显存占用现象**: + - GPT2 在本次日志中 FlashAttention 的 `peak used` 更高(1914 MB -> 3056 MB,显存节省比例 -59.67%); + - LLaMA3 在本次日志中 FlashAttention 的 `peak used` 也更高(24561 MB -> 26552 MB,显存节省比例 -8.11%); + - 说明本次实验里 FlashAttention 的收益主要体现在计算效率(时延/吞吐),而非显存降低。 + diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index a007dff1..156fdcab 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -78,6 +78,7 @@ DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)") DEFINE_string( precision_check, "", "precision check config: level=N,format=simple|table,output_md5=true|false,output_path=PATH,baseline=PATH"); +DEFINE_bool(flash, false, "Whether to enable flash attention"); using namespace infini_train; @@ -140,6 +141,7 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsParallel()) { device = Device(Device::DeviceType::kCUDA, rank.thread_rank()); + // auto *pg_factory = ProcessGroupFactory::Instance(device.type()); if (ddp_world_size > 1) { @@ -322,6 +324,10 @@ void Train(const nn::parallel::Rank &rank) { } for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) { + if (auto dist_optimizer = std::dynamic_pointer_cast(optimizer)) { + dist_optimizer->SetIsLastMicrobatch(micro_step == grad_accum_steps - 1); + } + // enable autocast for the current step infini_train::AutocastGuard autocast_guard(device.type(), dtype); diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index 8d497797..6a1e414f 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -12,6 +12,7 @@ #include #include "glog/logging.h" +#include "gflags/gflags.h" #include "example/common/utils.h" #include "infini_train/include/device.h" @@ -29,6 +30,7 @@ #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/tensor.h" + using namespace infini_train; namespace nn = infini_train::nn; @@ -78,6 +80,7 @@ 
CausalSelfAttention::CausalSelfAttention(const GPT2Config &config) ->View({1, 1, config_.block_size, config_.block_size}); } +DECLARE_bool(flash); std::vector> CausalSelfAttention::Forward(const std::vector> &x) { auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); @@ -96,7 +99,7 @@ CausalSelfAttention::Forward(const std::vectorseq_len const auto T = q->Dims()[1]; // View to multi-head: local_n_head * head_dim == local_C @@ -105,18 +108,39 @@ CausalSelfAttention::Forward(const std::vectorView({B, T, local_n_head_, head_dim})->Transpose(1, 2); v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - // (B, h_l, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); - // (1, 1, T, T) - auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); - // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) - att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); - // (B, h_l, T, T) - att = nn::function::Softmax(att, -1); - // (B, h_l, T, Dh) - auto y = att->Matmul(v); - // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) - y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); + std::shared_ptr y; + if (FLAGS_flash) { + // cuDNN SDPA path: causal masking should be enabled by `is_causal=true`. + // Do not pass the 0/1 tril mask as additive bias (it is not -inf mask). 
+ auto q_flash = q; + auto k_flash = k; + auto v_flash = v; + if (q->Dtype() == DataType::kFLOAT32) { + q_flash = std::make_shared(q->To(DataType::kBFLOAT16)); + k_flash = std::make_shared(k->To(DataType::kBFLOAT16)); + v_flash = std::make_shared(v->To(DataType::kBFLOAT16)); + } + y = nn::function::ScaledDotProductAttention(q_flash, k_flash, v_flash, nullptr, 0.0, true, std::nullopt, + false); + if (y->Dtype() != q->Dtype()) { + y = std::make_shared(y->To(q->Dtype())); + } + // ensure expected layout: (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); + } else { + // (B, h_l, T, T) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); + // (1, 1, T, T) + auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); + // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) + att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); + // (B, h_l, T, T) + att = nn::function::Softmax(att, -1); + // (B, h_l, T, Dh) + y = att->Matmul(v); + // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); + } // Get full tensor // (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C) diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 2b1e2121..eaf96b8d 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -91,6 +91,7 @@ constexpr char kDtypeBF16[] = "bfloat16"; DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_bool(flash, false, "Whether to enable flash attention"); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -298,6 +299,10 @@ void Train(const nn::parallel::Rank &rank) { } for (int micro_step = 0; micro_step < 
grad_accum_steps; ++micro_step) { + if (auto dist_optimizer = std::dynamic_pointer_cast(optimizer)) { + dist_optimizer->SetIsLastMicrobatch(micro_step == grad_accum_steps - 1); + } + // enable autocast for the current step infini_train::AutocastGuard autocast_guard(device.type(), dtype); diff --git a/example/llama3/net.cc b/example/llama3/net.cc index a50fb831..5285f53f 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -11,6 +11,7 @@ #include #include +#include "gflags/gflags.h" #include "glog/logging.h" #include "example/common/utils.h" @@ -138,6 +139,7 @@ std::vector> RMSNorm::Forward(const std::vector> CausalSelfAttention::Forward(const std::vec k = k->Transpose(1, 2); v = v->Transpose(1, 2); - // TODO(zbl): support flash attention later - // if (flash_) { ... } - - // manual implementation of attention - // this materializes the large (T,T) matrix for all the queries and keys - - // q: (B, H_local, T, D) - // k: (B, H_local, T, D) -> (B, H_local, D, T) - // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(D))); - if (mask) { - // mask: (1, 1, T, T) - att = att->MaskedFill(mask, std::numeric_limits::lowest()); + std::shared_ptr y; + if (FLAGS_flash) { + // cuDNN SDPA path: causal masking should be enabled by `is_causal=true`. + // Do not pass Triu(ones, 1) mask as additive bias. 
+ auto q_flash = q; + auto k_flash = k; + auto v_flash = v; + if (q->Dtype() == DataType::kFLOAT32) { + q_flash = std::make_shared(q->To(DataType::kBFLOAT16)); + k_flash = std::make_shared(k->To(DataType::kBFLOAT16)); + v_flash = std::make_shared(v->To(DataType::kBFLOAT16)); + } + y = nn::function::ScaledDotProductAttention(q_flash, k_flash, v_flash, nullptr, 0.0, true, std::nullopt, + false); + if (y->Dtype() != q->Dtype()) { + y = std::make_shared(y->To(q->Dtype())); + } + // ensure expected layout: (B, H_local, T, D) -> (B, T, H_local, D) -> (B, T, C_local) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local}); + } else { + // manual implementation of attention + // this materializes the large (T,T) matrix for all the queries and keys + + // q: (B, H_local, T, D) + // k: (B, H_local, T, D) -> (B, H_local, D, T) + // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(D))); + if (mask) { + // mask: (1, 1, T, T) + att = att->MaskedFill(mask, std::numeric_limits::lowest()); + } + // (B, H_local, T, T) + att = nn::function::Softmax(att, -1); + // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D) + y = att->Matmul(v); + // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local}); } - // (B, H_local, T, T) - att = nn::function::Softmax(att, -1); - // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D) - auto y = att->Matmul(v); - // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local) - y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local}); // output projection // (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C) y = (*modules_[kCProjLayerName])({y})[0]; diff --git a/example/mnist/main.cc b/example/mnist/main.cc index e62257d7..4cd7b8f6 100644 --- a/example/mnist/main.cc +++ b/example/mnist/main.cc @@ 
-35,6 +35,7 @@ constexpr char kDeviceCUDA[] = "cuda"; DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_bool(flash, false, "Whether to enable flash attention"); int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); diff --git a/image-1.png b/image-1.png new file mode 100644 index 00000000..8cddbc4f Binary files /dev/null and b/image-1.png differ diff --git a/image-2.png b/image-2.png new file mode 100644 index 00000000..41e7a99c Binary files /dev/null and b/image-2.png differ diff --git a/image-3.png b/image-3.png new file mode 100644 index 00000000..58b53abf Binary files /dev/null and b/image-3.png differ diff --git a/image-4.png b/image-4.png new file mode 100644 index 00000000..220f11f3 Binary files /dev/null and b/image-4.png differ diff --git a/image-5.png b/image-5.png new file mode 100644 index 00000000..77f6eb92 Binary files /dev/null and b/image-5.png differ diff --git a/image.png b/image.png new file mode 100644 index 00000000..cf3fa348 Binary files /dev/null and b/image.png differ diff --git a/infini_train/include/autocast.h b/infini_train/include/autocast.h index 499c586f..a10a084c 100644 --- a/infini_train/include/autocast.h +++ b/infini_train/include/autocast.h @@ -48,7 +48,7 @@ enum class CastPolicy : uint8_t { }; // Cast-policy maps and their associated operations. The op names should match the ones defined in the op registry. -inline constexpr std::array kLowerPrecisionOps = {"Matmul", "Linear"}; +inline constexpr std::array kLowerPrecisionOps = {"Matmul"}; inline constexpr std::array kFP32Ops = {"Sin", "Cos", "Tan", "Asin", "Acos", "Atan", "Sinh", "Cosh", "Tanh", "Asinh", "Acosh", "Atanh", "Exp", "Log", @@ -59,7 +59,7 @@ inline constexpr std::array kFP32Ops // op names should match the ones defined in the op registry. 
inline const std::unordered_map kOpCastPolicyMap = { {"Matmul", CastPolicy::kLowerPrecision}, - {"Linear", CastPolicy::kLowerPrecision}, + {"Linear", CastPolicy::kFP32}, {"Sin", CastPolicy::kFP32}, {"Cos", CastPolicy::kFP32}, {"Tan", CastPolicy::kFP32}, diff --git a/infini_train/include/autograd/scaled_dot_product_attention.h b/infini_train/include/autograd/scaled_dot_product_attention.h new file mode 100644 index 00000000..e48f900a --- /dev/null +++ b/infini_train/include/autograd/scaled_dot_product_attention.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class ScaledDotProductAttention : public Function { +public: + static constexpr char kType[] = "ScaledDotProductAttention"; + + ScaledDotProductAttention(double dropout_p, bool is_causal, + std::optional scale, bool enable_gqa) + : Function(kType), dropout_p_(dropout_p), is_causal_(is_causal), scale_(scale), + enable_gqa_(enable_gqa) {} + + std::vector> Forward( + const std::vector> &input_tensors) override; + + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + + std::vector> Backward( + const std::vector> &grad_outputs) override; + +private: + double dropout_p_ = 0.0; + + bool is_causal_ = false; + std::optional scale_ = std::nullopt; + bool enable_gqa_ = false; + bool has_attn_mask_input_ = false; + std::shared_ptr forward_out_ = nullptr; + std::shared_ptr forward_lse_ = nullptr; + // Saved tensors for backward can be managed via Function's SaveForBackward helper +}; +} // namespace infini_train::autograd diff --git a/infini_train/include/common/common.h b/infini_train/include/common/common.h index b6a02543..ea04c790 100644 --- a/infini_train/include/common/common.h +++ b/infini_train/include/common/common.h @@ -13,6 +13,8 @@ LOG_LOC(FATAL, WRAP(CONTEXT_IDENTIFIER << ": Unsupported data type: " \ + 
kDataTypeToDesc.at(static_cast(dtype)))) + +//compute strides for a given shape. inline std::vector ComputeStrides(const std::vector &dims) { std::vector strides(dims.size(), 1); for (int i = dims.size() - 2; i >= 0; --i) { strides[i] = strides[i + 1] * dims[i + 1]; } diff --git a/infini_train/include/nn/functional.h b/infini_train/include/nn/functional.h index e4354fd1..ba92b981 100644 --- a/infini_train/include/nn/functional.h +++ b/infini_train/include/nn/functional.h @@ -2,6 +2,7 @@ #include #include +#include #include namespace infini_train { @@ -162,6 +163,19 @@ std::shared_ptr Softmax(const std::shared_ptr &input, int64_t di std::shared_ptr Slice(const std::shared_ptr &input, const std::vector &starts, const std::vector &ends, const std::vector &steps); +// Scaled dot-product attention interface matching PyTorch's scaled_dot_product_attention. +// - query, key, value: tensors with shape (..., seq_len, head_dim) +// - attn_mask: optional additive mask (same broadcasting semantics as PyTorch) +// - dropout_p: dropout probability (0.0 disables) +// - is_causal: whether to apply causal mask +// - scale: optional scale factor; if not provided, use 1/sqrt(head_dim) +// - enable_gqa: grouped query attention flag +std::shared_ptr ScaledDotProductAttention( + const std::shared_ptr &query, const std::shared_ptr &key, + const std::shared_ptr &value, const std::shared_ptr &attn_mask = nullptr, + double dropout_p = 0.0, bool is_causal = false, + const std::optional &scale = std::nullopt, bool enable_gqa = false); + // Concatenates a sequence of tensors along a new dimension. 
// // Args: diff --git a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h index bc31442e..81f640fc 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h +++ b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h @@ -31,6 +31,9 @@ class DistributedOptimizer final : public infini_train::Optimizer { void StartGradSync(); void FinishGradSync(); + // Forward microbatch boundary info to bucket groups. + void SetIsLastMicrobatch(bool is_last_microbatch); + void StartParamSync(bool force_sync = false); void FinishParamSync(bool skip_next_bucket_dispatch = false); diff --git a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h index c83fe9a5..b4a2aa9d 100644 --- a/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h +++ b/infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h @@ -70,6 +70,9 @@ class ParamAndGradBucketGroup { // When all params in a bucket group are ready, will call StartGradSync() void RegisterGradReady(const std::shared_ptr ¶meter); + // Mark whether current backward corresponds to the last microbatch in a gradient accumulation window. 
+ void SetIsLastMicrobatch(bool is_last_microbatch); + // Start grad reduce void StartGradSync(); diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index 42a95729..ff2cad58 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -18,6 +18,7 @@ namespace infini_train::autograd { std::vector> Function::Apply(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 1); auto device = input_tensors[0]->GetDevice(); + // *:Switch to the device where the input tensor is located core::DeviceGuard guard(device); // Register precision check hooks if enabled (before forward) @@ -29,13 +30,15 @@ std::vector> Function::Apply(const std::vector> output_tensors; { autograd::NoGradGuard no_grad; @@ -78,6 +81,7 @@ std::vector> Function::Apply(const std::vectorset_requires_grad(output_requires_grad); output_tensor->set_grad_fn(output_requires_grad ? shared_from_this() : nullptr); + //条件二含义:需要梯度,但是它没有生父算子(即它是用户手动创建的原始参数,不是算出来的)。 output_tensor->set_is_leaf(!output_requires_grad || ((output_tensor->grad_fn() == nullptr) && output_requires_grad)); output_tensor->set_output_idx(output_idx); diff --git a/infini_train/src/autograd/scaled_dot_product_attention.cc b/infini_train/src/autograd/scaled_dot_product_attention.cc new file mode 100644 index 00000000..e15b1579 --- /dev/null +++ b/infini_train/src/autograd/scaled_dot_product_attention.cc @@ -0,0 +1,75 @@ +#include "infini_train/include/autograd/scaled_dot_product_attention.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> ScaledDotProductAttention::Forward( + const std::vector> &input_tensors) { + CHECK(input_tensors.size() == 3 || input_tensors.size() == 4); + const auto &q = input_tensors[0]; + const auto &k = input_tensors[1]; + const auto &v = input_tensors[2]; + const auto mask = input_tensors.size() == 4 ? 
input_tensors[3] : nullptr; + + auto device = q->GetDevice().type(); + // Call device kernel. Kernel name: ScaledDotProductAttentionForward + auto out_and_lse = Dispatcher::Instance().Call, std::shared_ptr>>( + {device, "ScaledDotProductAttentionForward"}, q, k, v, mask, dropout_p_, is_causal_, scale_, + enable_gqa_); + forward_out_ = std::get<0>(out_and_lse); + forward_lse_ = std::get<1>(out_and_lse); + auto out = forward_out_; + return {out}; +} + +void ScaledDotProductAttention::SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) { + (void)output_tensors; + // Save q,k,v and mask (mask may be nullptr) + const auto &q = input_tensors[0]; + const auto &k = input_tensors[1]; + const auto &v = input_tensors[2]; + std::shared_ptr mask = nullptr; + has_attn_mask_input_ = (input_tensors.size() == 4); + if (input_tensors.size() == 4) { + mask = input_tensors[3]; + } + saved_tensors_ = {q, k, v, mask}; +} + +std::vector> ScaledDotProductAttention::Backward( + const std::vector> &grad_outputs) { + CHECK(saved_tensors_.size() == 4); + const auto &q = saved_tensors_[0]; + const auto &k = saved_tensors_[1]; + const auto &v = saved_tensors_[2]; + const auto &mask = saved_tensors_[3]; + + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + + CHECK(forward_out_ != nullptr); + CHECK(forward_lse_ != nullptr); + + auto grads = Dispatcher::Instance().Call, std::shared_ptr, + std::shared_ptr>>( + {device, "ScaledDotProductAttentionBackward"}, grad_output, q, k, v, mask, forward_out_, forward_lse_, + dropout_p_, is_causal_, scale_, enable_gqa_); + + forward_out_ = nullptr; + forward_lse_ = nullptr; + + if (has_attn_mask_input_) { + return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads), nullptr}; + } + + return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads)}; +} + +} // namespace infini_train::autograd diff --git 
a/infini_train/src/kernels/cuda/flash_attention.cu b/infini_train/src/kernels/cuda/flash_attention.cu new file mode 100644 index 00000000..d12fa6a3 --- /dev/null +++ b/infini_train/src/kernels/cuda/flash_attention.cu @@ -0,0 +1,564 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/common/cuda/kernel_helper.cuh" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" +//#include "infini_train/src/core/cuda/cuda_stream.h" +#include "infini_train/include/common/common.h" // ComputeStrides +#include // cudaStream_t + +// 强烈建议使用 NVIDIA 提供的 frontend 库,否则原始 API 会写到手软 +#include +namespace fe = cudnn_frontend; + + +namespace infini_train::kernels::cuda { + +namespace { +constexpr int64_t Q_UID = 101; +constexpr int64_t K_UID = 102; +constexpr int64_t V_UID = 103; +constexpr int64_t MASK_UID = 104; +constexpr int64_t O_UID = 201; +constexpr int64_t STATS_UID = 202; + +constexpr int64_t dO_UID = 301; +constexpr int64_t dQ_UID = 401; +constexpr int64_t dK_UID = 402; +constexpr int64_t dV_UID = 403; + +struct WorkspaceCache { + void *ptr = nullptr; + size_t size = 0; +}; + +static inline std::size_t hash_combine(std::size_t seed, std::size_t v) { + return seed ^ (v + 0x9e3779b97f4a7c15ULL + (seed << 6U) + (seed >> 2U)); +} + +static inline uint32_t float_to_bits(float x) { + uint32_t bits; + std::memcpy(&bits, &x, sizeof(float)); + return bits; +} + +static std::size_t hash_dims(std::vector const &dims) { + std::size_t h = 0; + for (auto d : dims) { + h = hash_combine(h, std::hash{}(d)); + } + return h; +} + +struct FwdPlanKey { + std::vector q_dims; + std::vector k_dims; + std::vector v_dims; + std::vector mask_dims; + 
int dtype = 0; + bool is_causal = false; + bool has_mask = false; + uint32_t attn_scale_bits = 0; + + bool operator==(FwdPlanKey const &other) const { + return q_dims == other.q_dims && + k_dims == other.k_dims && + v_dims == other.v_dims && + mask_dims == other.mask_dims && + dtype == other.dtype && + is_causal == other.is_causal && + has_mask == other.has_mask && + attn_scale_bits == other.attn_scale_bits; + } +}; + +struct FwdPlanKeyHash { + std::size_t operator()(FwdPlanKey const &k) const { + std::size_t h = 0; + h = hash_combine(h, hash_dims(k.q_dims)); + h = hash_combine(h, hash_dims(k.k_dims)); + h = hash_combine(h, hash_dims(k.v_dims)); + h = hash_combine(h, hash_dims(k.mask_dims)); + h = hash_combine(h, std::hash{}(k.dtype)); + h = hash_combine(h, std::hash{}(k.is_causal)); + h = hash_combine(h, std::hash{}(k.has_mask)); + h = hash_combine(h, std::hash{}(k.attn_scale_bits)); + return h; + } +}; + +struct BwdPlanKey { + std::vector q_dims; + std::vector k_dims; + std::vector v_dims; + std::vector o_dims; + std::vector do_dims; + std::vector lse_dims; + std::vector mask_dims; + int dtype = 0; + bool is_causal = false; + bool has_mask = false; + uint32_t attn_scale_bits = 0; + + bool operator==(BwdPlanKey const &other) const { + return q_dims == other.q_dims && + k_dims == other.k_dims && + v_dims == other.v_dims && + o_dims == other.o_dims && + do_dims == other.do_dims && + lse_dims == other.lse_dims && + mask_dims == other.mask_dims && + dtype == other.dtype && + is_causal == other.is_causal && + has_mask == other.has_mask && + attn_scale_bits == other.attn_scale_bits; + } +}; + +struct BwdPlanKeyHash { + std::size_t operator()(BwdPlanKey const &k) const { + std::size_t h = 0; + h = hash_combine(h, hash_dims(k.q_dims)); + h = hash_combine(h, hash_dims(k.k_dims)); + h = hash_combine(h, hash_dims(k.v_dims)); + h = hash_combine(h, hash_dims(k.o_dims)); + h = hash_combine(h, hash_dims(k.do_dims)); + h = hash_combine(h, hash_dims(k.lse_dims)); + h = 
hash_combine(h, hash_dims(k.mask_dims)); + h = hash_combine(h, std::hash{}(k.dtype)); + h = hash_combine(h, std::hash{}(k.is_causal)); + h = hash_combine(h, std::hash{}(k.has_mask)); + h = hash_combine(h, std::hash{}(k.attn_scale_bits)); + return h; + } +}; + +struct CachedPlan { + std::shared_ptr graph; + int64_t workspace_size = 0; +}; + +using FwdPlanCache = std::unordered_map; +using BwdPlanCache = std::unordered_map; +} + +// helpers for cuDNN frontend path +static cudaStream_t get_cuda_stream(const ::infini_train::Device &device) { + auto impl = ::infini_train::core::GetDeviceGuardImpl(device.type()); + auto stream_obj = impl->GetStream(device); + auto cuda_stream = dynamic_cast(stream_obj)->cuda_stream(); + return cuda_stream; +} + +static cudnnHandle_t get_cudnn_handle(const ::infini_train::Device &device) { + //用来记录现在thread正在使用哪个cuda device,cudnn handle是和device绑定的,所以需要这个信息 + int cuda_device = 0; + CUDA_CHECK(cudaGetDevice(&cuda_device)); + + static thread_local std::unordered_map handles; + auto it = handles.find(cuda_device); + if (it == handles.end()) { + cudnnHandle_t handle; + cudnnCreate(&handle); + it = handles.emplace(cuda_device, handle).first; + } + + auto cuda_stream = get_cuda_stream(device); + cudnnSetStream(it->second, cuda_stream); + + return it->second; +} + +static void *acquire_workspace(WorkspaceCache &cache, size_t requested_bytes) { + if (requested_bytes == 0) { + return nullptr; + } + if (cache.ptr == nullptr || cache.size < requested_bytes) { + if (cache.ptr != nullptr) { + CUDA_CHECK(cudaFree(cache.ptr)); + } + CUDA_CHECK(cudaMalloc(&cache.ptr, requested_bytes)); + cache.size = requested_bytes; + } + return cache.ptr; +} + +static WorkspaceCache &forward_workspace_cache() { + static thread_local WorkspaceCache cache; + return cache; +} + +static WorkspaceCache &backward_workspace_cache() { + static thread_local WorkspaceCache cache; + return cache; +} + +static FwdPlanCache &forward_plan_cache() { + static thread_local FwdPlanCache 
cache; + return cache; +} + +static BwdPlanCache &backward_plan_cache() { + static thread_local BwdPlanCache cache; + return cache; +} + +static fe::DataType_t get_cudnn_dtype(const ::infini_train::DataType dtype); +static std::shared_ptr make_graph_tensor( + const std::shared_ptr &graph, + const std::shared_ptr &tensor, + const std::string &name, + int64_t uid); +static void check_fe_status(fe::error_t status, const char *stage); +static CachedPlan const &get_or_create_fwd_plan(const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + bool is_causal, + float attn_scale, + cudnnHandle_t handle); +static CachedPlan const &get_or_create_bwd_plan(const std::shared_ptr &grad_out, + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + const std::shared_ptr &out, + const std::shared_ptr &lse, + bool is_causal, + float attn_scale, + cudnnHandle_t handle); + +static std::tuple, std::shared_ptr> ExecuteSdpaForwardWithLse( + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool /*enable_gqa*/) { + if (dropout_p > 0.0) { + throw std::runtime_error("cuDNN frontend SDPA path currently does not support dropout in this minimal kernel"); + } + + auto out = std::make_shared(q->Dims(), q->Dtype(), q->GetDevice()); + + auto q_dims = q->Dims(); + CHECK_EQ(q_dims.size(), 4) << "SDPA expects 4D Q/K/V tensor layout [B, H, S, D]"; + std::vector lse_dims = {q_dims[0], q_dims[1], q_dims[2], 1}; + //lse(Log-sum-exp) + auto lse = std::make_shared(lse_dims, DataType::kFLOAT32, q->GetDevice()); + + cudnnHandle_t handle = get_cudnn_handle(q->GetDevice()); + + float attn_scale = scale.has_value() ? 
static_cast(scale.value()) + : 1.0f / std::sqrt(static_cast(q->Dims().back())); + + auto const &plan = get_or_create_fwd_plan(q, k, v, attn_mask, is_causal, attn_scale, handle); + void *workspace = acquire_workspace(forward_workspace_cache(), static_cast(plan.workspace_size)); + + std::unordered_map variant_pack = { + {Q_UID, q->DataPtr()}, + {K_UID, k->DataPtr()}, + {V_UID, v->DataPtr()}, + {O_UID, out->DataPtr()}, + {STATS_UID, lse->DataPtr()}, + }; + if (attn_mask) { + variant_pack[MASK_UID] = attn_mask->DataPtr(); + } + + auto exec_status = plan.graph->execute(handle, variant_pack, workspace); + check_fe_status(exec_status, "graph->execute"); + + return {out, lse}; +} + +static fe::DataType_t get_cudnn_dtype(const ::infini_train::DataType dtype) { + switch (dtype) { + case ::infini_train::DataType::kFLOAT32: + return fe::DataType_t::FLOAT; + case ::infini_train::DataType::kFLOAT16: + return fe::DataType_t::HALF; + case ::infini_train::DataType::kBFLOAT16: + return fe::DataType_t::BFLOAT16; + default: + throw std::runtime_error("unsupported dtype for cuDNN SDP"); + } +} + +static std::shared_ptr make_graph_tensor( + const std::shared_ptr &graph, + const std::shared_ptr &tensor, + const std::string &name, + int64_t uid) { + return graph->tensor(fe::graph::Tensor_attributes() + .set_name(name) + .set_uid(uid) + .set_dim(tensor->Dims()) + .set_stride(ComputeStrides(tensor->Dims())) + .set_data_type(get_cudnn_dtype(tensor->Dtype()))); +} + +static void check_fe_status(fe::error_t status, const char *stage) { + if (status.is_bad()) { + throw std::runtime_error(std::string(stage) + ": " + status.get_message()); + } +} + +static CachedPlan const &get_or_create_fwd_plan(const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + bool is_causal, + float attn_scale, + cudnnHandle_t handle) { + FwdPlanKey key; + key.q_dims = q->Dims(); + key.k_dims = k->Dims(); + key.v_dims = v->Dims(); + key.has_mask = (attn_mask 
!= nullptr); + if (attn_mask) { + key.mask_dims = attn_mask->Dims(); + } + key.dtype = static_cast(q->Dtype()); + key.is_causal = is_causal; + key.attn_scale_bits = float_to_bits(attn_scale); + + //cache ——FwdPlanCache::map,根据key查找是否已经存在对应的plan,如果存在就直接返回,如果不存在就创建新的plan并插入cache + auto &cache = forward_plan_cache(); + auto it = cache.find(key); + //若能直接找到就返回对应的plan,优化速度 + if (it != cache.end()) { + return it->second; + } + + auto graph = std::make_shared(); + graph->set_io_data_type(get_cudnn_dtype(q->Dtype())) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto q_tensor = make_graph_tensor(graph, q, "Q", Q_UID); + auto k_tensor = make_graph_tensor(graph, k, "K", K_UID); + auto v_tensor = make_graph_tensor(graph, v, "V", V_UID); + + auto sdpa_options = fe::graph::SDPA_attributes() + .set_name("flash_attention") + .set_generate_stats(true) + .set_attn_scale(attn_scale); + + if (is_causal) { + sdpa_options.set_diagonal_alignment(cudnn_frontend::DiagonalAlignment_t::TOP_LEFT) + .set_diagonal_band_right_bound(0); + } + + if (attn_mask) { + auto mask_tensor = make_graph_tensor(graph, attn_mask, "Bias", MASK_UID); + sdpa_options.set_bias(mask_tensor); + } + + auto [out_tensor, stats_tensor] = graph->sdpa(q_tensor, k_tensor, v_tensor, sdpa_options); + out_tensor->set_output(true) + .set_uid(O_UID) + .set_dim(q->Dims()) + .set_stride(ComputeStrides(q->Dims())); + std::vector lse_dims = {q->Dims()[0], q->Dims()[1], q->Dims()[2], 1}; + stats_tensor->set_output(true) + .set_uid(STATS_UID) + .set_dim(lse_dims) + .set_stride(ComputeStrides(lse_dims)) + .set_data_type(fe::DataType_t::FLOAT); + + check_fe_status(graph->build(handle, {fe::HeurMode_t::A}), "graph->build (fwd cache build)"); + + int64_t workspace_size = 0; + check_fe_status(graph->get_workspace_size(workspace_size), "graph->get_workspace_size (fwd cache build)"); + + CachedPlan plan; + plan.graph = graph; + plan.workspace_size = workspace_size; + auto inserted 
= cache.emplace(std::move(key), std::move(plan)); + return inserted.first->second; +} + +static CachedPlan const &get_or_create_bwd_plan(const std::shared_ptr &grad_out, + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + const std::shared_ptr &out, + const std::shared_ptr &lse, + bool is_causal, + float attn_scale, + cudnnHandle_t handle) { + BwdPlanKey key; + key.q_dims = q->Dims(); + key.k_dims = k->Dims(); + key.v_dims = v->Dims(); + key.o_dims = out->Dims(); + key.do_dims = grad_out->Dims(); + key.lse_dims = lse->Dims(); + key.has_mask = (attn_mask != nullptr); + if (attn_mask) { + key.mask_dims = attn_mask->Dims(); + } + key.dtype = static_cast(q->Dtype()); + key.is_causal = is_causal; + key.attn_scale_bits = float_to_bits(attn_scale); + + auto &cache = backward_plan_cache(); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + + auto graph = std::make_shared(); + graph->set_io_data_type(get_cudnn_dtype(q->Dtype())) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto q_tensor = make_graph_tensor(graph, q, "Q", Q_UID); + auto k_tensor = make_graph_tensor(graph, k, "K", K_UID); + auto v_tensor = make_graph_tensor(graph, v, "V", V_UID); + auto o_tensor = make_graph_tensor(graph, out, "O", O_UID); + auto dO_tensor = make_graph_tensor(graph, grad_out, "dO", dO_UID); + auto lse_tensor = make_graph_tensor(graph, lse, "Stats", STATS_UID); + + auto sdpa_bwd_options = fe::graph::SDPA_backward_attributes() + .set_name("flash_attention_backward") + .set_attn_scale(attn_scale) + .set_deterministic_algorithm(true); + + if (is_causal) { + sdpa_bwd_options.set_diagonal_alignment(cudnn_frontend::DiagonalAlignment_t::TOP_LEFT) + .set_diagonal_band_right_bound(0); + } + + if (attn_mask) { + auto mask_tensor = make_graph_tensor(graph, attn_mask, "Bias", MASK_UID); + sdpa_bwd_options.set_bias(mask_tensor); + } + + auto 
[dQ_tensor, dK_tensor, dV_tensor] = graph->sdpa_backward( + q_tensor, k_tensor, v_tensor, o_tensor, dO_tensor, lse_tensor, sdpa_bwd_options); + + dQ_tensor->set_output(true) + .set_uid(dQ_UID) + .set_dim(q->Dims()) + .set_stride(ComputeStrides(q->Dims())); + dK_tensor->set_output(true) + .set_uid(dK_UID) + .set_dim(k->Dims()) + .set_stride(ComputeStrides(k->Dims())); + dV_tensor->set_output(true) + .set_uid(dV_UID) + .set_dim(v->Dims()) + .set_stride(ComputeStrides(v->Dims())); + + check_fe_status(graph->build(handle, {fe::HeurMode_t::A}), "graph->build (bwd cache build)"); + + int64_t workspace_size = 0; + check_fe_status(graph->get_workspace_size(workspace_size), "graph->get_workspace_size (bwd cache build)"); + + CachedPlan plan; + plan.graph = graph; + plan.workspace_size = workspace_size; + auto inserted = cache.emplace(std::move(key), std::move(plan)); + return inserted.first->second; +} + +std::tuple, std::shared_ptr> ScaledDotProductAttentionForward( + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + return ExecuteSdpaForwardWithLse(q, k, v, attn_mask, dropout_p, is_causal, scale, enable_gqa); +} + +std::tuple, std::shared_ptr, std::shared_ptr> +ScaledDotProductAttentionBackward( + const std::shared_ptr &grad_out, + const std::shared_ptr &q, + const std::shared_ptr &k, + const std::shared_ptr &v, + const std::shared_ptr &attn_mask, + const std::shared_ptr &out, + const std::shared_ptr &lse, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + + auto dq = std::make_shared(q->Dims(), q->Dtype(), q->GetDevice()); + auto dk = std::make_shared(k->Dims(), k->Dtype(), k->GetDevice()); + auto dv = std::make_shared(v->Dims(), v->Dtype(), v->GetDevice()); + + if (dropout_p > 0.0) { + throw std::runtime_error("cuDNN frontend SDPA path currently does not support dropout in this minimal 
kernel"); + } + (void)enable_gqa; + + + // ---------- cuDNN frontend implementation ---------- + cudnnHandle_t handle = get_cudnn_handle(grad_out->GetDevice()); + + float attn_scale = scale.has_value() ? static_cast(scale.value()) + : 1.0f / std::sqrt(static_cast(q->Dims().back())); + + auto const &plan = get_or_create_bwd_plan(grad_out, q, k, v, attn_mask, out, lse, is_causal, attn_scale, handle); + void *workspace = acquire_workspace(backward_workspace_cache(), static_cast(plan.workspace_size)); + + std::unordered_map variant_pack = { + {Q_UID, q->DataPtr()}, + {K_UID, k->DataPtr()}, + {V_UID, v->DataPtr()}, + {O_UID, out->DataPtr()}, + {dO_UID, grad_out->DataPtr()}, + {STATS_UID, lse->DataPtr()}, + {dQ_UID, dq->DataPtr()}, + {dK_UID, dk->DataPtr()}, + {dV_UID, dv->DataPtr()}, + }; + if (attn_mask) { + variant_pack[MASK_UID] = attn_mask->DataPtr(); + } + + auto exec_status = plan.graph->execute(handle, variant_pack, workspace); + check_fe_status(exec_status, "graph->execute (backward)"); + + return {dq, dk, dv}; +} + +} + + +#define REGISTER_CUDA_LINEAR_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_LINEAR_KERNEL(ScaledDotProductAttentionBackward) +REGISTER_CUDA_LINEAR_KERNEL(ScaledDotProductAttentionForward) + +#undef REGISTER_CUDA_LINEAR_KERNEL diff --git a/infini_train/src/kernels/cuda/no_op.cu b/infini_train/src/kernels/cuda/no_op.cu index ef2c9566..9b8b04e3 100644 --- a/infini_train/src/kernels/cuda/no_op.cu +++ b/infini_train/src/kernels/cuda/no_op.cu @@ -3,6 +3,8 @@ #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" + +//这里不用用到并行算法,直接把输入张量的视图返回即可 namespace infini_train::kernels::cuda { std::shared_ptr NoOpForward(const std::shared_ptr &input, const std::vector &dims) { const int64_t num_elements = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); diff --git 
a/infini_train/src/kernels/cuda/ysyx.code-workspace b/infini_train/src/kernels/cuda/ysyx.code-workspace new file mode 100644 index 00000000..afc35437 --- /dev/null +++ b/infini_train/src/kernels/cuda/ysyx.code-workspace @@ -0,0 +1,8 @@ +{ + "folders": [ + { + "path": "../../../../.." + } + ], + "settings": {} +} \ No newline at end of file diff --git a/infini_train/src/nn/functional.cc b/infini_train/src/nn/functional.cc index b02f185a..a159f6b9 100644 --- a/infini_train/src/nn/functional.cc +++ b/infini_train/src/nn/functional.cc @@ -3,6 +3,9 @@ #include #include #include +#include +#include +#include #include "infini_train/include/autograd/activations.h" #include "infini_train/include/autograd/elementwise.h" @@ -10,6 +13,7 @@ #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/softmax.h" #include "infini_train/include/autograd/transform.h" +#include "infini_train/include/autograd/scaled_dot_product_attention.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/tensor.h" @@ -79,4 +83,22 @@ std::shared_ptr Softmax(const std::shared_ptr &input, int64_t di std::shared_ptr Sigmoid(const std::shared_ptr &input) { return std::make_shared()->Apply({input})[0]; } -} // namespace infini_train::nn::function + +std::shared_ptr ScaledDotProductAttention( + const std::shared_ptr &query, + const std::shared_ptr &key, + const std::shared_ptr &value, + const std::shared_ptr &attn_mask, + double dropout_p, + bool is_causal, + const std::optional &scale, + bool enable_gqa) { + std::vector> inputs = {query, key, value}; + if (attn_mask) inputs.push_back(attn_mask); + auto fn = std::make_shared( + dropout_p, is_causal, scale, enable_gqa); + return fn->Apply(inputs)[0]; +} + +} +// namespace infini_train::nn::function diff --git a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc index 002fe318..c8a4ee3b 100644 --- 
a/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -24,7 +25,11 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod const DistributedDataParallelConfig ddp_config) : ddp_config_(ddp_config), ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) { + std::unordered_set seen_params; for (auto ¶m : module->Parameters()) { + if (!param || !seen_params.insert(param.get()).second) { + continue; + } auto device = param->GetDevice(); CHECK_EQ(device.index(), rank.thread_rank()) << "All parameters must be on the same device as the module"; if (!ddp_config.gradient_bucketing_enabled && !ddp_config.use_distributed_optimizer) { @@ -130,7 +135,11 @@ void DistributedDataParallel::RegisterBackwardHooks() { }; auto &module = modules_.at(kModuleName); + std::unordered_set seen_params; for (auto ¶m : module->Parameters()) { + if (!param || !seen_params.insert(param.get()).second) { + continue; + } if (!param->requires_grad()) { continue; } diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index 55e5800b..a96f58d7 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -91,6 +91,10 @@ void DistributedOptimizer::FinishGradSync() { for (auto &group : bucket_groups_) { group->FinishGradSync(); } } +void DistributedOptimizer::SetIsLastMicrobatch(bool is_last_microbatch) { + for (auto &group : bucket_groups_) { group->SetIsLastMicrobatch(is_last_microbatch); } +} + void DistributedOptimizer::StartParamSync(bool force_sync) { for (auto &group : bucket_groups_) { group->StartParamSync(force_sync); } } diff --git a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc 
b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc index 75a21f63..e2a1eccf 100644 --- a/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc +++ b/infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc @@ -147,6 +147,10 @@ void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr &p } } +void ParamAndGradBucketGroup::SetIsLastMicrobatch(bool is_last_microbatch) { + is_last_microbatch_ = is_last_microbatch; +} + void ParamAndGradBucketGroup::StartGradSync() { if (!collective_pg_) { LOG(FATAL) << "ParamAndGradBucketGroup: StartGradSync() called with null collective_pg_."; diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 2c9b218a..890c9c73 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -25,10 +25,14 @@ void SGD::Step() { LOG(INFO) << "Skipping param with null grad."; continue; } + auto grad = param->grad(); + if (grad->Dtype() != param->Dtype()) { + grad = std::make_shared(grad->To(param->Dtype())); + } auto device = param->GetDevice(); core::DeviceGuard guard(device); auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"}); - kernel.Call(param->grad(), -learning_rate_, param); + kernel.Call(grad, -learning_rate_, param); } } @@ -53,11 +57,14 @@ void Adam::Step() { for (size_t i = 0; i < params_.size(); ++i) { auto ¶m = params_[i]; - const auto &grad = param->grad(); + auto grad = param->grad(); if (!grad) { LOG(INFO) << "Skipping param with null grad."; continue; } + if (grad->Dtype() != param->Dtype()) { + grad = std::make_shared(grad->To(param->Dtype())); + } auto &m = m_[i]; auto &v = v_[i]; diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 6c243fea..56da32e5 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -98,6 +98,7 @@ size_t Tensor::SizeInBytes() const { return kDataTypeToSize.at(dtype_) * num_ele const std::vector &Tensor::Dims() const { return dims_; } + size_t Tensor::NumElements() 
const { return num_elements_; } DataType Tensor::Dtype() const { return dtype_; } diff --git a/scripts/precision_check/precision_compare.py b/scripts/precision_check/precision_compare.py deleted file mode 100755 index 40c91308..00000000 --- a/scripts/precision_check/precision_compare.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python3 -""" -Precision comparison tool for InfiniTrain tensor outputs. - -Usage: - python precision_compare.py --dir1 ./run1 --dir2 ./run2 [--atol 1e-5] [--rtol 1e-3] - -Compares .npy files between two directories and reports differences. -""" - -import argparse -import os -import sys -from pathlib import Path - -import numpy as np - - -def find_npy_files(directory: str) -> dict[str, Path]: - """Find all .npy files in directory (recursively).""" - files = {} - for path in Path(directory).rglob("*.npy"): - rel_path = path.relative_to(directory) - files[str(rel_path)] = path - return files - - -def compare_tensors(file1: Path, file2: Path, atol: float, rtol: float) -> dict: - """Compare two tensor files and return comparison results.""" - arr1 = np.load(file1) - arr2 = np.load(file2) - - result = { - "file": str(file1.name), - "shape1": arr1.shape, - "shape2": arr2.shape, - "dtype1": str(arr1.dtype), - "dtype2": str(arr2.dtype), - "match": False, - "error": None, - } - - if arr1.shape != arr2.shape: - result["error"] = f"Shape mismatch: {arr1.shape} vs {arr2.shape}" - return result - - if arr1.dtype != arr2.dtype: - result["error"] = f"Dtype mismatch: {arr1.dtype} vs {arr2.dtype}" - return result - - arr1_flat = arr1.astype(np.float64).flatten() - arr2_flat = arr2.astype(np.float64).flatten() - - abs_diff = np.abs(arr1_flat - arr2_flat) - max_abs_diff = np.max(abs_diff) - mean_abs_diff = np.mean(abs_diff) - - with np.errstate(divide="ignore", invalid="ignore"): - rel_diff = abs_diff / (np.abs(arr2_flat) + 1e-12) - rel_diff = np.where(np.isfinite(rel_diff), rel_diff, 0) - max_rel_diff = np.max(rel_diff) - mean_rel_diff = 
np.mean(rel_diff) - - result["max_abs_diff"] = float(max_abs_diff) - result["mean_abs_diff"] = float(mean_abs_diff) - result["max_rel_diff"] = float(max_rel_diff) - result["mean_rel_diff"] = float(mean_rel_diff) - result["match"] = np.allclose(arr1, arr2, atol=atol, rtol=rtol) - - return result - - -def main(): - parser = argparse.ArgumentParser(description="Compare precision check outputs") - parser.add_argument("--dir1", required=True, help="First directory") - parser.add_argument("--dir2", required=True, help="Second directory") - parser.add_argument("--atol", type=float, default=1e-5, help="Absolute tolerance") - parser.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance") - parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") - args = parser.parse_args() - - if not os.path.isdir(args.dir1): - print(f"Error: {args.dir1} is not a directory") - sys.exit(1) - if not os.path.isdir(args.dir2): - print(f"Error: {args.dir2} is not a directory") - sys.exit(1) - - files1 = find_npy_files(args.dir1) - files2 = find_npy_files(args.dir2) - - print(f"Directory 1: {args.dir1} ({len(files1)} files)") - print(f"Directory 2: {args.dir2} ({len(files2)} files)") - print(f"Tolerance: atol={args.atol}, rtol={args.rtol}") - print() - - only_in_1 = set(files1.keys()) - set(files2.keys()) - only_in_2 = set(files2.keys()) - set(files1.keys()) - common = set(files1.keys()) & set(files2.keys()) - - if only_in_1: - print(f"Files only in dir1 ({len(only_in_1)}):") - for f in sorted(only_in_1): - print(f" {f}") - print() - - if only_in_2: - print(f"Files only in dir2 ({len(only_in_2)}):") - for f in sorted(only_in_2): - print(f" {f}") - print() - - if not common: - print("No common files to compare") - sys.exit(1) - - print(f"Comparing {len(common)} common files...") - print() - - passed = 0 - failed = 0 - errors = 0 - - for rel_path in sorted(common): - result = compare_tensors(files1[rel_path], files2[rel_path], args.atol, args.rtol) - - 
if result["error"]: - errors += 1 - print(f"ERROR: {rel_path}") - print(f" {result['error']}") - elif result["match"]: - passed += 1 - if args.verbose: - print(f"PASS: {rel_path}") - print(f" max_abs={result['max_abs_diff']:.2e} max_rel={result['max_rel_diff']:.2e}") - else: - failed += 1 - print(f"FAIL: {rel_path}") - print(f" shape={result['shape1']} dtype={result['dtype1']}") - print(f" max_abs={result['max_abs_diff']:.2e} mean_abs={result['mean_abs_diff']:.2e}") - print(f" max_rel={result['max_rel_diff']:.2e} mean_rel={result['mean_rel_diff']:.2e}") - - print() - print("=" * 50) - print(f"Summary: {passed} passed, {failed} failed, {errors} errors") - print(f"Missing: {len(only_in_1)} in dir1 only, {len(only_in_2)} in dir2 only") - - if failed > 0 or errors > 0: - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index 1cf27935..f99761c9 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -17,12 +17,19 @@ read_var() { jq -r --arg k "$key" '.variables[$k] // empty' "$CONFIG_FILE" } -BUILD_DIR="$(read_var BUILD_DIR)"; : "${BUILD_DIR:=../build}" -LOG_DIR="$(read_var LOG_DIR)"; : "${LOG_DIR:=logs}" -PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_logs}" -COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" - +BUILD_DIR="$(read_var BUILD_DIR)"; : "${BUILD_DIR:=../build}" +LOG_DIR="$(read_var LOG_DIR)"; : "${LOG_DIR:=logs}" +PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_logs}" +COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" +FLASH="$(read_var FLASH)"; : "${FLASH:=}" + +# --- 关键修改 1: 初始化大容量分区的绝对路径临时目录 --- +# 先确保 build 目录存在,以便获取其绝对路径 mkdir -p "$BUILD_DIR" "$LOG_DIR" "$PROFILE_LOG_DIR" +# 获取绝对路径,防止 CMake 切换目录后找不到相对路径 +export CUSTOM_TMP="$(readlink -f "$BUILD_DIR")/tmp_cache" +mkdir -p "$CUSTOM_TMP" +export TMPDIR="$CUSTOM_TMP" # export 
custom PATHs export BUILD_DIR LOG_DIR PROFILE_LOG_DIR @@ -34,11 +41,16 @@ done < <(jq -r '.variables | to_entries[] | "\(.key)=\(.value)"' "$CONFIG_FILE") # Global variable to save the last cmake command LAST_CMAKE_CMD="" -# Clean the build directory +# --- 关键修改 2: 在清理函数中重新创建临时目录 --- clean_build_dir() { echo -e "\033[1;31m[CLEAN] Removing all contents in: ${BUILD_DIR}\033[0m" - mkdir -p "$BUILD_DIR" + # 删除 build 下所有内容(这会删掉旧的 tmp_cache) rm -rf "${BUILD_DIR:?}/"* + # 重新创建 build 目录 + mkdir -p "$BUILD_DIR" + # 核心:必须重新创建 TMPDIR 目录,否则编译器的路径会失效 + mkdir -p "$TMPDIR" + echo -e "\033[1;34m[TMP] Re-created temp space at: $TMPDIR\033[0m" } # Run a command and log output @@ -52,38 +64,27 @@ run_and_log() { echo -e "\033[1;32m============================================================\033[0m" echo -e "\033[1;36m[$timestamp] [Running] ${log_name}\033[0m" - - # Print the command being executed echo -e "\033[1;33mCommand:\033[0m $cmd" - - # Print the most recent CMake command if [[ -n "$LAST_CMAKE_CMD" ]]; then echo -e "\033[1;34mLast CMake Command:\033[0m $LAST_CMAKE_CMD" fi - echo -e "\033[1;33mLog file:\033[0m $log_path" - - # Notify if profiling mode is enabled if [[ "$is_profile" == "yes" ]]; then echo -e "\033[1;35m[PROFILE MODE ON] Profiling logs will be saved to: ${PROFILE_LOG_DIR}\033[0m" fi - echo -e "\033[1;32m============================================================\033[0m" pushd "$BUILD_DIR" > /dev/null - # Write the last cmake command into the log file if available if [[ -n "$LAST_CMAKE_CMD" ]]; then echo "[LAST_CMAKE] $LAST_CMAKE_CMD" > "$log_path" else - # If no cmake command has been run yet, clear the log > "$log_path" fi - # Write the current run command to the log echo "[COMMAND] $cmd" >> "$log_path" - # Run the command and append both stdout and stderr to the log file + # 执行命令并重定向输出 if ! 
eval "$cmd" >> "$log_path" 2>&1; then echo -e "\033[1;31m============================================================\033[0m" echo -e "\033[1;31m[ERROR] Command failed: ${cmd}\033[0m" @@ -97,39 +98,29 @@ run_and_log() { popd > /dev/null - # If profiling is enabled, move profiling files to the target directory if [[ "$is_profile" == "yes" ]]; then move_profile_logs "$log_name" fi } - # Move profiling output logs move_profile_logs() { local prefix="$1" - - # Move *.report.rankN files for report_file in "${BUILD_DIR}"/*.report.rank*; do if [[ -f "$report_file" ]]; then - local base_name - base_name=$(basename "$report_file") + local base_name=$(basename "$report_file") mv "$report_file" "${PROFILE_LOG_DIR}/${prefix}_${base_name}" - echo "Moved $base_name to ${PROFILE_LOG_DIR}/${prefix}_${base_name}" fi done - - # Move *.records.log.rankN files for record_file in "${BUILD_DIR}"/*.records.log.rank*; do if [[ -f "$record_file" ]]; then - local base_name - base_name=$(basename "$record_file") + local base_name=$(basename "$record_file") mv "$record_file" "${PROFILE_LOG_DIR}/${prefix}_${base_name}" - echo "Moved $base_name to ${PROFILE_LOG_DIR}/${prefix}_${base_name}" fi done } -# Build "--key value" arg string from tests[i].args (shell-escaped) +# Build args string args_string_for_test() { local idx="$1" jq -r --argjson i "$idx" ' @@ -150,11 +141,10 @@ for ((id=0; id' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 4 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true 
+SortIncludes: false +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Auto +TabWidth: 4 +UseTab: Never +--- +Language: Json +# Don't format .json files. +DisableFormat: true +... + diff --git a/third_party/cudnn-frontend/.github/ISSUE_TEMPLATE/bug_report.md b/third_party/cudnn-frontend/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..cf4b0051 --- /dev/null +++ b/third_party/cudnn-frontend/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,44 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**System Environment (please complete the following information):** + - cudnn_frontend version: [e.g. v1.4.0] + - cudnn_backend version: [e.g. v9.1.0] + - GPU arch: [e.g. RTX 4090] + - cuda runtime version: [e.g. 12.4] + - cuda driver version: [e.g. 553.04] + - host compiler: [e.g. clang19] + - OS: [e.g. ubuntu22.04] + +**API logs** +Please attach API logs for both cudnn_frontend and cudnn_backend. +``` +// For cudnn_frontend +export CUDNN_FRONTEND_LOG_FILE=fe.log +export CUDNN_FRONTEND_LOG_INFO=1 + +// For cudnn_backend +export CUDNN_LOGLEVEL_DBG=3 +export CUDNN_LOGDEST_DBG=be.log +``` + +**To Reproduce** +Steps to reproduce the behavior: +1. '...' +2. '....' +3. '....' + +**Additional context** +Add any other context about the problem here. 
diff --git a/third_party/cudnn-frontend/CMakeLists.txt b/third_party/cudnn-frontend/CMakeLists.txt new file mode 100644 index 00000000..ff0438a6 --- /dev/null +++ b/third_party/cudnn-frontend/CMakeLists.txt @@ -0,0 +1,107 @@ +cmake_minimum_required(VERSION 3.23) + +project(cudnn_frontend VERSION 1.18.0) + +option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF) +option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON) +option(CUDNN_FRONTEND_BUILD_TESTS "Defines if unittests are built or not." ON) +option(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS "Defines if python bindings are built or not." OFF) + +if(MSVC OR MSYS OR MINGW) + add_compile_options(/W4 /WX) +else() + add_compile_options(-Wall -Wextra -Wpedantic -Werror -Wno-error=attributes -Wno-attributes -Wno-error=unused-function -Wno-unused-function) +endif() + +add_library(cudnn_frontend INTERFACE) + +# Add header files to library +file(GLOB_RECURSE CUDNN_FRONTEND_INCLUDE_FILES "include/*") +target_sources( + cudnn_frontend PUBLIC FILE_SET HEADERS + BASE_DIRS "$" + FILES "${CUDNN_FRONTEND_INCLUDE_FILES}" +) +unset(CUDNN_FRONTEND_INCLUDE_FILES) + +target_compile_definitions( + cudnn_frontend INTERFACE + $<$:CUDNN_FRONTEND_SKIP_JSON_LIB> +) + +target_include_directories( + cudnn_frontend INTERFACE + $ + $ +) + +# Find the cuda compiler +find_package(CUDAToolkit REQUIRED) + +target_include_directories( + cudnn_frontend INTERFACE + ${CUDAToolkit_INCLUDE_DIRS} +) + +target_compile_features(cudnn_frontend INTERFACE cxx_std_17) + +# Make PCH for targets to link against +add_library(_cudnn_frontend_pch INTERFACE) +target_precompile_headers(_cudnn_frontend_pch INTERFACE ${PROJECT_SOURCE_DIR}/include/cudnn_frontend.h) + +if (CUDNN_FRONTEND_BUILD_SAMPLES) + add_subdirectory(samples) +endif() + +if (CUDNN_FRONTEND_BUILD_TESTS) + add_subdirectory(test) +endif() + +if (CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS) + add_subdirectory(python) +endif() + +# Introduce 
variables: +# * CMAKE_INSTALL_LIBDIR +# * CMAKE_INSTALL_BINDIR +# * CMAKE_INSTALL_INCLUDEDIR +include(GNUInstallDirs) + +# See https://cmake.org/cmake/help/latest/module/CMakePackageConfigHelpers.html#example-generating-package-files +include(CMakePackageConfigHelpers) + +# Install the components +install( + TARGETS cudnn_frontend + EXPORT cudnn_frontend_targets FILE_SET HEADERS +) + +if (CUDNN_FRONTEND_BUILD_SAMPLES) + install(TARGETS legacy_samples samples RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() + +if (CUDNN_FRONTEND_BUILD_TESTS) + install(TARGETS tests RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() + +# Export the targets +export( + EXPORT cudnn_frontend_targets + FILE "${CMAKE_CURRENT_BINARY_DIR}/cudnn_frontend/cudnn_frontend-targets.cmake" +) +install( + EXPORT cudnn_frontend_targets + FILE cudnn_frontend-targets.cmake + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/cudnn_frontend" +) + +# Install the CMake configuration file for header discovery +configure_package_config_file( + cudnn_frontend-config.cmake.in + "${CMAKE_CURRENT_BINARY_DIR}/cudnn_frontend-config.cmake" + INSTALL_DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/cudnn_frontend" +) +install( + FILES "${CMAKE_CURRENT_BINARY_DIR}/cudnn_frontend-config.cmake" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/cudnn_frontend" +) diff --git a/third_party/cudnn-frontend/CONTRIBUTING.md b/third_party/cudnn-frontend/CONTRIBUTING.md new file mode 100644 index 00000000..5ccb14ac --- /dev/null +++ b/third_party/cudnn-frontend/CONTRIBUTING.md @@ -0,0 +1,55 @@ +# Contributing to cudnn-frontend + +If you are interested in contributing to cudnn-frontend, your contributions will fall +into three categories: +1. You want to report a bug, feature request, or documentation issue + - File an [issue](https://github.com/NVIDIA/cudnn-frontend/issues) + describing what you encountered or what you want to see changed. + - The cudnn team will evaluate the issues and triage them, scheduling + them for a release. 
If you believe the issue needs priority attention + comment on the issue to notify the team. +2. You want to propose a new Feature and implement it + - Post about your intended feature, and we shall discuss the design and + implementation. + - Once we agree that the plan looks good, go ahead and implement it, using + the [code contributions](#code-contributions) guide below. +3. You want to implement a feature or bug-fix for an outstanding issue + - Follow the [code contributions](#code-contributions) guide below. + - If you need more context on a particular issue, please ask and we shall + provide. + +## Code contributions + +### Your first issue + +1. Read the project's [README.md](https://github.com/NVIDIA/cudnn-frontend/blob/main/README.md) + to learn how to setup the development environment. +2. Comment on the issue saying you are going to work on it and what changes you are going to make. +3. Code! Make sure to update unit tests! +4. When done, [create your pull request](https://github.com/NVIDIA/cudnn-frontend/compare). +5. Wait for other developers to review your code and update code as needed. +6. Once reviewed and approved, a cudnn-frontend developer will merge your pull request. +7. At this time, we are accepting only small fixes, changes. Once merged to main this will be an untagged version. A release tag will be assigned along with future frontend release by cudnn team. + +Remember, if you are unsure about anything, don't hesitate to comment on issues and ask for clarifications! + +## Code Formatting + +Consistent code formatting is important in the cudnn-frontend project to ensure +readability, maintainability, and thus simplifies collaboration. + +### Branches and Versions + +The cudnn-frontend repository has one main branch. Please submit a PR to this branch. We will update the doc as the policy changes. 
+ +### Branch naming + +Branches used to create PRs should have a name of the form `<name>-issue-<issue_number>` +which conforms to the following conventions: + +- Name: + - A name to convey what is being worked on + - Please use dashes or underscores between words as opposed to spaces. + +## Attribution +Portions of contribution guide adopted from [https://github.com/rapidsai/cuml/blob/branch-24.04/CONTRIBUTING.md](https://github.com/rapidsai/cuml/blob/branch-24.04/CONTRIBUTING.md) diff --git a/third_party/cudnn-frontend/LICENSE.txt b/third_party/cudnn-frontend/LICENSE.txt new file mode 100644 index 00000000..eef9c446 --- /dev/null +++ b/third_party/cudnn-frontend/LICENSE.txt @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE.
+ */ diff --git a/third_party/cudnn-frontend/README.md b/third_party/cudnn-frontend/README.md new file mode 100644 index 00000000..2616f208 --- /dev/null +++ b/third_party/cudnn-frontend/README.md @@ -0,0 +1,130 @@ +# cuDNN FrontEnd (FE) + +**cuDNN FE** is the modern, open-source entry point to the NVIDIA cuDNN library and high performance open-source kernels. It provides a C++ header-only library and a Python interface to access the powerful cuDNN Graph API and open-source kernels. + +## Key Features + +* **Unified Graph API:** Create reusable, persistent `cudnn_frontend::graph::Graph` objects to describe complex subgraphs. +* **Ease of Use:** Simplified C++ and Python bindings (via `pybind11`) that abstract away the boilerplate of the backend API. +* **Performance:** Built-in autotuning and support for the latest NVIDIA GPU architectures. + +## Benchmarks + +To run the SDPA benchmarks, refer to the [benchmarks/sdpa](https://github.com/NVIDIA/cudnn-frontend/blob/main/benchmark/sdpa_benchmark_training/README.md) folder.
Current results: + +### GB200 - Llama 3.1 Causal (top_left) +![Llama 3.1 Causal on GB200](https://raw.githubusercontent.com/NVIDIA/cudnn-frontend/main/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB200 GPU + +### GB200 - Llama 3.1 Non-Causal (no_mask) +![Llama 3.1 Non-Causal on GB200](https://raw.githubusercontent.com/NVIDIA/cudnn-frontend/main/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_no_mask.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=False` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB200 GPU + +### GB200 - DeepSeek V3 Causal (top_left) +![DeepSeek V3 Causal on GB200](https://raw.githubusercontent.com/NVIDIA/cudnn-frontend/main/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=128; num_kv_heads=128; head_dim_qk=192; head_dim_vo=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB200 GPU + +### GB300 - Llama 3.1 Causal (top_left) +![Llama 3.1 Causal on GB300](https://raw.githubusercontent.com/NVIDIA/cudnn-frontend/main/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB300 GPU + +### GB300 - Llama 3.1 Non-Causal (no_mask) +![Llama 3.1 Non-Causal on GB300](https://raw.githubusercontent.com/NVIDIA/cudnn-frontend/main/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_no_mask.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=False` +- Sequence lengths shown on x-axis +- Results obtained on 
NVIDIA GB300 GPU + +### GB300 - DeepSeek V3 Causal (top_left) +![DeepSeek V3 Causal on GB300](https://raw.githubusercontent.com/NVIDIA/cudnn-frontend/main/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=128; num_kv_heads=128; head_dim_qk=192; head_dim_vo=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB300 GPU + + +## Installation + +### 🐍 Python + +The easiest way to get started is via pip: + +```bash +pip install nvidia_cudnn_frontend +``` + +**Requirements:** +* Python 3.8+ +* NVIDIA driver and CUDA Toolkit + +### ⚙️ C++ (Header Only) + +Since the C++ API is header-only, integration is seamless. Simply include the header in your compilation unit: + +```cpp +#include <cudnn_frontend.h> +``` + +Ensure your include path points to the `include/` directory of this repository. + +## Building from Source + +If you want to build the Python bindings from source or run the C++ samples: + +**1. Dependencies** +* `python-dev` (e.g., `apt-get install python-dev`) +* Dependencies listed in `requirements.txt` (`pip install -r requirements.txt`) + +**2. Python Source Build** +```bash +pip install -v git+https://github.com/NVIDIA/cudnn-frontend.git +``` +*Environment variables `CUDAToolkit_ROOT` and `CUDNN_PATH` can be used to override default paths.* + +**3. C++ Samples Build** +```bash +mkdir build && cd build +cmake -DCUDNN_PATH=/path/to/cudnn -DCUDAToolkit_ROOT=/path/to/cuda ../ +cmake --build . -j16 +./bin/samples +``` + +## Documentation & Examples + +* **Developer Guide:** [Official NVIDIA Documentation](https://docs.nvidia.com/deeplearning/cudnn/frontend/v1.9.0/developer/overview.html) + +* **C++ Samples:** See `samples/cpp` for comprehensive usage examples. +* **Python Samples:** See `samples/python` for pythonic implementations. + +## 🤝 Contributing + +We welcome contributions!
Whether you are fixing a bug, improving documentation, or optimizing one of our new OSS kernels, your help makes cuDNN better for everyone. + +1. Check the [Contribution Guide](CONTRIBUTING.md) for details. +2. Fork the repo and create your branch. +3. Submit a Pull Request. + +## Debugging + +To view the execution flow and debug issues, you can enable logging via environment variables: + +```bash +# Log to stdout +export CUDNN_FRONTEND_LOG_INFO=1 +export CUDNN_FRONTEND_LOG_FILE=stdout + +# Log to a file +export CUDNN_FRONTEND_LOG_INFO=1 +export CUDNN_FRONTEND_LOG_FILE=execution_log.txt +``` + +Alternatively, you can control logging programmatically via `cudnn_frontend::isLoggingEnabled()` + +## License + +This project is licensed under the [MIT License](LICENSE). diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/Dockerfile b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/Dockerfile new file mode 100644 index 00000000..8a0efd6d --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/Dockerfile @@ -0,0 +1,28 @@ +FROM nvcr.io/nvidia/pytorch:25.12-py3 + +# Set working directory +WORKDIR /workspace + +# Update libcudnn9-cuda-13 +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb && \ + dpkg -i cuda-keyring_1.1-1_all.deb && \ + apt-get remove -y *cudnn9* && \ + apt-get update && \ + apt-get -y install cudnn && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Clone cudnn_frontend and install latest cudnn +RUN git clone https://github.com/NVIDIA/cudnn-frontend.git +RUN pip install -v cudnn-frontend + +# Clone flash-attention +RUN pip uninstall -y flash-attn && \ + git clone https://github.com/Dao-AILab/flash-attention.git && \ + cd flash-attention && \ + sed -i 's/^ import flash_attn_2_cuda as flash_attn_gpu$/ pass/' /workspace/flash-attention/flash_attn/flash_attn_interface.py +RUN pip install nvidia-cutlass-dsl apache-tvm-ffi quack-kernels 
+ENV PYTHONPATH=/workspace/flash-attention + +# Install additional dependencies for benchmarking +RUN pip install seaborn \ No newline at end of file diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/README.md b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/README.md new file mode 100644 index 00000000..87368634 --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/README.md @@ -0,0 +1,234 @@ +# Scaled Dot Product Attention Benchmark + +## Introduction + +This directory contains benchmarking tools for Scaled Dot Product Attention (SDPA) operations across various backends. The benchmarks target training use cases with support for causal masking and grouped query attention (GQA). + +## Contents + +- `Dockerfile` - Docker container setup for running benchmarks +- `benchmark_single_sdpa.py` - Single SDPA benchmark script +- `configs/` - Benchmark configuration files + - `llama.py` - Llama 3.1 GQA benchmarks (causal + non-causal) + - `dsv3.py` - DeepSeek V3 MHA benchmarks (causal only) +- `runner.py` - Configuration-based benchmark runner +- `config_types.py` - Data types for benchmark configuration +- `charts.py` - Chart generation utilities +- `../results/` - Benchmark outputs (CSV and charts) + +## Quick Start + +### 1. Build Docker Container + +```bash +docker build -t cudnn_attention_benchmark . + +docker run -it --gpus all --rm cudnn_attention_benchmark +``` + +### 2. 
Run Benchmarks + +```bash +# Run Llama 3.1 benchmark suite +python -m benchmark.sdpa_benchmark_training.runner --config llama + +# Run DeepSeek V3 benchmark suite +python -m benchmark.sdpa_benchmark_training.runner --config dsv3 + +# Dry run (show what would be executed) +python -m benchmark.sdpa_benchmark_training.runner --config llama --dry-run + +# Filter by backend +python -m benchmark.sdpa_benchmark_training.runner --config llama --backend cudnn + +# Filter by data type +python -m benchmark.sdpa_benchmark_training.runner --config llama --dtype bfloat16 +``` + +## Configuration-Based Benchmarking + +### Creating Custom Configurations + +1. Copy the template: + ```bash + cp configs/llama.py configs/my_config.py + ``` + +2. Edit your config: + ```python + from ..config_types import ModelPreset, BenchmarkConfig + + MY_MODEL = ModelPreset( + name="my_model", + num_q_heads=32, + num_kv_heads=8, + head_dim=128, + ) + + CONFIG = BenchmarkConfig( + name="my_benchmark", + models=[MY_MODEL], + seqlens=[(4096, 4096), (8192, 8192)], + backends=["cudnn", "flash_attention_4"], + data_types=["bfloat16", "fp8"], + attn_masks=["top_left", "no_mask"], + profile_pass="fwd", # "fwd", "bwd", or "both" + num_iterations=10, + ) + ``` + +3. 
Run: + ```bash + python -m benchmark.sdpa_benchmark_training.runner --config my_config + ``` + +### Configuration Options + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `models` | List of `ModelPreset` to benchmark | Required | +| `seqlens` | List of `(q_seqlen, kv_seqlen)` tuples | Required | +| `backends` | Backends to compare | `["cudnn"]` | +| `data_types` | Data types to test | `["bfloat16"]` | +| `attn_masks` | Attention masks (`top_left`, `no_mask`, `bottom_right`) | `["top_left"]` | +| `profile_pass` | Which pass to profile (`fwd`, `bwd`, `both`) | `"fwd"` | +| `batch_size` | Batch size | `1` | +| `num_iterations` | Iterations per benchmark | `10` | +| `deterministic_bwd` | Deterministic modes for backward | `[False]` | + +### Model Presets + +Standard model: +```python +LLAMA3_1 = ModelPreset( + name="llama3.1", + num_q_heads=64, + num_kv_heads=8, + head_dim=128, +) +``` + +Asymmetric head dimensions (DeepSeek V3): +```python +DSV3 = ModelPreset( + name="dsv3", + num_q_heads=128, + num_kv_heads=128, + head_dim_qk=192, # Q/K head dimension + head_dim_vo=128, # V/O head dimension +) +``` + +### Output + +The runner produces (in `benchmark/results/`): +- **CSV**: `_.csv` +- **Charts**: Separate chart per mask type: + - `_top_left.png` (causal) + - `_no_mask.png` (non-causal) +- Charts show backends side-by-side with distinct colors for BF16 vs FP8 + +## Single Benchmark Script + +For running individual benchmarks: + +```bash +# cuDNN Frontend (BF16) +python benchmark_single_sdpa.py \ + --batch_size 1 --q_seqlen 8192 --kv_seqlen 8192 \ + --num_q_heads 64 --num_kv_heads 8 --head_dim 128 \ + --sdpa_backend cudnn --data_type bfloat16 \ + --attn_mask top_left --fwd_bwd + +# cuDNN Frontend (FP8) +python benchmark_single_sdpa.py \ + --batch_size 1 --q_seqlen 8192 --kv_seqlen 8192 \ + --num_q_heads 64 --num_kv_heads 8 --head_dim 128 \ + --sdpa_backend cudnn --data_type fp8 \ + --attn_mask top_left --fwd_bwd + +# FlashAttention 4 
+python benchmark_single_sdpa.py \ + --batch_size 1 --q_seqlen 8192 --kv_seqlen 8192 \ + --num_q_heads 64 --num_kv_heads 8 --head_dim 128 \ + --sdpa_backend flash_attention_4 --data_type bfloat16 \ + --attn_mask top_left --fwd_bwd +``` + +Run `python benchmark_single_sdpa.py --help` for all options. + +## Programmatic Usage + +```python +from benchmark.sdpa_benchmark_training import ( + BenchmarkRunner, + BenchmarkConfig, + ModelPreset, + load_config, +) + +# Load existing config +config = load_config("llama") + +# Or create programmatically +config = BenchmarkConfig( + name="custom", + models=[ModelPreset("test", 64, 8, 128)], + seqlens=[(4096, 4096)], + backends=["cudnn"], +) + +runner = BenchmarkRunner() +results = runner.run_config(config) +runner.save_csv(results, config) +``` + +## Supported Backends + +| Backend | Description | +|---------|-------------| +| `cudnn` | cuDNN (native, via cuDNN Frontend) | +| `flash_attention_4` | FlashAttention 4 | +| `flash_attention_3` | FlashAttention 3 | +| `pyt_flash_attention` | PyTorch FlashAttention | +| `pyt_cudnn` | PyTorch cuDNN backend | +| `pyt_efficient_attention` | PyTorch xFormers | + +## Benchmark Results + +### GB200 - Llama 3.1 Causal (top_left) +![Llama 3.1 Causal on GB200](results/gb200_918_only_cudnn/llama3.1_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB200 GPU + +### GB200 - Llama 3.1 Non-Causal (no_mask) +![Llama 3.1 Non-Causal on GB200](results/gb200_918_only_cudnn/llama3.1_no_mask.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=False` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB200 GPU + +### GB200 - DeepSeek V3 Causal (top_left) +![DeepSeek V3 Causal on GB200](results/gb200_918_only_cudnn/dsv3_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=128; num_kv_heads=128; head_dim_qk=192; 
head_dim_vo=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB200 GPU + +### GB300 - Llama 3.1 Causal (top_left) +![Llama 3.1 Causal on GB300](results/gb300_918_only_cudnn/llama3.1_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB300 GPU + +### GB300 - Llama 3.1 Non-Causal (no_mask) +![Llama 3.1 Non-Causal on GB300](results/gb300_918_only_cudnn/llama3.1_no_mask.png) +- SDPA parameters: `batch=1; num_q_heads=64; num_kv_heads=8; head_dim=128; is_causal=False` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB300 GPU + +### GB300 - DeepSeek V3 Causal (top_left) +![DeepSeek V3 Causal on GB300](results/gb300_918_only_cudnn/dsv3_top_left_causal.png) +- SDPA parameters: `batch=1; num_q_heads=128; num_kv_heads=128; head_dim_qk=192; head_dim_vo=128; is_causal=True` +- Sequence lengths shown on x-axis +- Results obtained on NVIDIA GB300 GPU + diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/__init__.py b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/__init__.py new file mode 100644 index 00000000..dd665e2d --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/__init__.py @@ -0,0 +1,40 @@ +""" +SDPA Benchmark Training Package + +This package provides a flexible benchmark configuration system for +Scaled Dot Product Attention (SDPA) operations. 
+ +Usage: + # Run benchmarks from command line + python -m benchmark.sdpa_benchmark_training.runner --config mlperf + + # Dry run to see what would be executed + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --dry-run + + # Import and use programmatically + from benchmark.sdpa_benchmark_training import ( + BenchmarkRunner, + BenchmarkConfig, + BenchmarkResult, + ModelPreset, + load_config, + ) + + config = load_config("mlperf") + runner = BenchmarkRunner() + results = runner.run_config(config) + runner.save_csv(results, config) +""" + +from .config_types import ModelPreset, BenchmarkConfig, BenchmarkResult +from .configs import load_config, list_configs +from .runner import BenchmarkRunner + +__all__ = [ + "ModelPreset", + "BenchmarkConfig", + "BenchmarkResult", + "BenchmarkRunner", + "load_config", + "list_configs", +] diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py new file mode 100644 index 00000000..b9d83810 --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py @@ -0,0 +1,1366 @@ +""" +Scaled Dot Product Attention (SDPA) benchmark + +This script benchmarks a single SDPA compute instance. +The SDPA backend can be chosen. Performance is measured using torch profiler. + +Can be used as CLI or imported as a module: + + # CLI usage + python benchmark_single_sdpa.py --batch_size 1 --q_seqlen 8192 ... + + # Module usage + from benchmark_single_sdpa import run_benchmark + result = run_benchmark(batch_size=1, q_seqlen=8192, ...) 
+""" + +import argparse +import torch +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.nn.attention.bias import causal_lower_right +import os +import numpy as np +import functools +import time +import math +from typing import Optional, Dict, Any + +from torch.profiler import profile, record_function, ProfilerActivity + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--batch_size", default=1, type=int, help="Batch size to input to the layer") + parser.add_argument("--q_seqlen", default=8192, type=int, help="Sequence length to input to the layer") + parser.add_argument("--kv_seqlen", default=8192, type=int, help="Sequence length to input to the layer") + parser.add_argument( + "--num_q_heads", + default=16, + type=int, + help="Number of query heads to input to the layer", + ) + parser.add_argument( + "--num_kv_heads", + default=8, + type=int, + help="Number of key/value heads to input to the layer", + ) + parser.add_argument("--head_dim", default=128, type=int, help="Head dimension to input to the layer") + parser.add_argument( + "--head_dim_qk", + default=None, + type=int, + help="Optional: head dimension for Q/K. If set, must also set --head_dim_vo", + ) + parser.add_argument( + "--head_dim_vo", + default=None, + type=int, + help="Optional: head dimension for V/O. If set, must also set --head_dim_qk", + ) + parser.add_argument( + "--data_type", + default="bfloat16", + type=str, + help="Data type to input to the layer. 
Can be bfloat16, float16, or fp8", + ) + parser.add_argument( + "--num_iterations", + default=20, + type=int, + help="Number of iterations to run the layer for performance measurement", + ) + parser.add_argument( + "--num_warmup_iterations", + default=0, + type=int, + help="Number of warmup iterations to run before measuring performance", + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--fwd_bwd", + action="store_true", + help="Run both forward and backward pass (fwd only by default)", + ) + parser.add_argument( + "--profile_pass", + default=None, + type=str, + choices=["fwd", "bwd", "both"], + help="Which pass to profile (default: fwd unless --fwd_bwd is set).", + ) + parser.add_argument( + "--deterministic_bwd", + action="store_true", + help="Use deterministic algorithm for backward pass where supported (cudnn FP16/BF16/FP8)", + ) + parser.add_argument( + "--attn_mask", + default="no_mask", + type=str, + help="Attn mask to use. Can be 'top_left', 'bottom_right', or 'no_mask'.", + choices=["top_left", "bottom_right", "no_mask"], + ) + parser.add_argument( + "--sdpa_backend", + default="pyt_cudnn", + type=str, + help="SDPA backend to use", + choices=[ + "pyt_math", + "pyt_cudnn", + "pyt_efficient_attention", + "pyt_flash_attention", + "flash_attention", + "flash_attention_3", + "flash_attention_4", + "cudnn", + ], + ) + parser.add_argument("--format_output", action="store_true", help="Format output to be used in benchmark") + parser.add_argument( + "--case_tag", + default="", + type=str, + help="Tag to identify the case. Not used in calculations. 
Only for formatted output", + ) + parser.add_argument( + "--skip_ref", + action="store_true", + help="Skip reference SDPA implementation", + ) + return parser.parse_args() + + +def run_benchmark( + batch_size: int, + q_seqlen: int, + kv_seqlen: int, + num_q_heads: int, + num_kv_heads: int, + head_dim: int = 128, + head_dim_qk: Optional[int] = None, + head_dim_vo: Optional[int] = None, + data_type: str = "bfloat16", + backend: str = "cudnn", + attn_mask: str = "no_mask", + profile_pass: str = "fwd", + num_iterations: int = 10, + num_warmup_iterations: int = 0, + skip_ref: bool = True, + deterministic_bwd: bool = False, + verbose: bool = False, +) -> Dict[str, Any]: + """ + Run a single SDPA benchmark. + + This function can be called directly when using the module as a library. + Internally uses subprocess to call this script with the appropriate arguments. + + Args: + batch_size: Batch size + q_seqlen: Query sequence length + kv_seqlen: Key/value sequence length + num_q_heads: Number of query heads + num_kv_heads: Number of key/value heads + head_dim: Head dimension (used if head_dim_qk/vo not specified) + head_dim_qk: Head dimension for Q/K (optional, for asymmetric) + head_dim_vo: Head dimension for V/O (optional, for asymmetric) + data_type: Data type ("bfloat16", "float16", "fp8") + backend: Backend name ("cudnn", "flash_attention_4", etc.) 
+ attn_mask: Attention mask ("no_mask", "top_left", "bottom_right") + profile_pass: Which pass to profile ("fwd", "bwd", "both") + num_iterations: Number of benchmark iterations + num_warmup_iterations: Warmup iterations before measurement + skip_ref: Skip reference validation + deterministic_bwd: Use deterministic backward algorithm + verbose: Print verbose output + + Returns: + Dict with keys: + - fwd_time_ms: Median forward time in milliseconds + - bwd_time_ms: Median backward time in milliseconds (0 if not run) + - fwd_tflops: Forward TFLOPS + - bwd_tflops: Backward TFLOPS + - max_diff: Maximum difference vs reference + - gpu_name: GPU name string + - cudnn_version: cuDNN version (if available) + + Raises: + RuntimeError: If the benchmark subprocess fails + """ + import subprocess + import sys + + # Build command + script_path = os.path.abspath(__file__) + cmd = [ + sys.executable, + script_path, + "--batch_size", + str(batch_size), + "--q_seqlen", + str(q_seqlen), + "--kv_seqlen", + str(kv_seqlen), + "--num_q_heads", + str(num_q_heads), + "--num_kv_heads", + str(num_kv_heads), + "--data_type", + data_type, + "--sdpa_backend", + backend, + "--attn_mask", + attn_mask, + "--num_iterations", + str(num_iterations), + "--num_warmup_iterations", + str(num_warmup_iterations), + "--format_output", # Get CSV-formatted output for parsing + ] + + # Handle head dimensions + if head_dim_qk is not None and head_dim_vo is not None: + cmd.extend(["--head_dim_qk", str(head_dim_qk)]) + cmd.extend(["--head_dim_vo", str(head_dim_vo)]) + else: + cmd.extend(["--head_dim", str(head_dim)]) + + # Handle profile pass + if profile_pass == "both": + cmd.append("--fwd_bwd") + elif profile_pass in ("fwd", "bwd"): + cmd.extend(["--profile_pass", profile_pass]) + + # Handle flags + if skip_ref: + cmd.append("--skip_ref") + if deterministic_bwd: + cmd.append("--deterministic_bwd") + if verbose: + cmd.append("--verbose") + + # Run benchmark + result = subprocess.run( + cmd, + 
capture_output=True, + text=True, + check=False, + ) + + if result.returncode != 0: + raise RuntimeError(f"Benchmark failed with return code {result.returncode}.\n" f"stderr: {result.stderr}\n" f"stdout: {result.stdout}") + + # Parse CSV output + # Format: case_tag,backend,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim,fwd_time,bwd_time,fwd_tflops,bwd_tflops,max_diff,num_iters + output_line = result.stdout.strip().split("\n")[-1] + parts = output_line.split(",") + + if len(parts) < 12: + raise RuntimeError(f"Unexpected output format: {output_line}") + + # Get GPU name from torch + gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()) if torch.cuda.is_available() else "Unknown" + + # Try to get cudnn version + cudnn_version = None + cudnn_backend_version = None + try: + import cudnn + + cudnn_version = cudnn.__version__ + cudnn_backend_version = cudnn.backend_version() + except ImportError: + pass + + return { + "fwd_time_ms": float(parts[8]), + "bwd_time_ms": float(parts[9]), + "fwd_tflops": float(parts[10]), + "bwd_tflops": float(parts[11]), + "max_diff": float(parts[12]) if len(parts) > 12 else 0.0, + "gpu_name": gpu_name, + "cudnn_version": cudnn_version, + "cudnn_backend_version": cudnn_backend_version, + } + + +# ============================================================================ +# Main benchmark implementation (runs when script is executed directly) +# ============================================================================ + +# Note: All code below this point is only executed when running as a script. +# When imported as a module, use the run_benchmark() function above. 
+ +if __name__ != "__main__": + # Stop here when imported as module + pass +else: + # Parse command line arguments + args = parse_args() + + if args.data_type == "bfloat16": + target_dtype = torch.bfloat16 + elif args.data_type == "float16": + target_dtype = torch.float16 + elif args.data_type == "float": + target_dtype = torch.float + elif args.data_type == "fp8": + target_dtype = None + else: + raise ValueError(f"Invalid data type: {args.data_type}") + + if args.data_type == "fp8": + if args.sdpa_backend not in ["cudnn", "flash_attention_3"]: + raise ValueError(f"FP8 is only supported for cudnn and flash_attention_3 backends") + + # Parse input arguments + num_iters = args.num_iterations + dry_run_iters = args.num_warmup_iterations + batch_size = args.batch_size + q_seqlen = args.q_seqlen + kv_seqlen = args.kv_seqlen + num_q_heads = args.num_q_heads + num_kv_heads = args.num_kv_heads + if args.head_dim_qk is None and args.head_dim_vo is None: + head_dim_qk = args.head_dim + head_dim_vo = args.head_dim + elif args.head_dim_qk is not None and args.head_dim_vo is not None: + head_dim_qk = args.head_dim_qk + head_dim_vo = args.head_dim_vo + else: + raise ValueError("Both --head_dim_qk and --head_dim_vo must be provided together when using asymmetric head dims.") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + assert device.type == "cuda", "Requires CUDA device" + if args.profile_pass is not None: + run_fwd = args.profile_pass in ("fwd", "both") + run_bwd = args.profile_pass in ("bwd", "both") + elif args.fwd_bwd: + run_fwd = True + run_bwd = True + else: + run_fwd = True + run_bwd = False + enable_gqa = num_q_heads != num_kv_heads + assert args.attn_mask != "bottom_right" or q_seqlen <= kv_seqlen, "Bottom right causal mask not supported when q_seqlen > kv_seqlen" + # if args.sdpa_backend in ["flash_attention", "flash_attention_3", "pyt_flash_attention"]: + # assert args.attn_mask != "top_left", "Flash Attention does not support top left 
causal mask" + + l2_flush_size_mb = 256 + l2_flush_size = l2_flush_size_mb * 1024 * 1024 + l2_flush_buffer = torch.empty(l2_flush_size, device=device, dtype=torch.int8) + + ############################################################# + ########### Set up SDPA function for each backend ########### + + ## If using cuDNN FE, set up cuDNN graph. + if args.sdpa_backend == "cudnn": + is_dropout = False # Hard coded + dropout_prob = dropout_p if is_dropout else 0.0 # Hard coded to 0 + is_infer = False # Hard coded + attn_scale = head_dim_qk ** (-0.5) + + try: + import cudnn + except ImportError: + cudnn = None + assert cudnn is not None + + if args.verbose: + print(f"[INFO] cuDNN Backend Version: {cudnn.backend_version() = }") + print(f"[INFO] cuDNN Frontend Version: {cudnn.__version__ = }") + + # Helper function: Convert torch type to cuDNN type + def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + + ## Will define tensors to set up cuDNN graph once. 
+ if args.data_type == "fp8": + query = torch.randint( + 256, + (batch_size, q_seqlen, num_q_heads, head_dim_qk), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + key = torch.randint( + 256, + (batch_size, kv_seqlen, num_kv_heads, head_dim_qk), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + value = torch.randint( + 256, + (batch_size, kv_seqlen, num_kv_heads, head_dim_vo), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + output = torch.empty( + batch_size, + q_seqlen, + num_q_heads, + head_dim_vo, + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + + descale_q_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_k_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_v_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_dO_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dQ_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dK_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dV_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + amax_s_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_o_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dQ_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dK_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dV_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dP_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, 
device=device) + else: + query = torch.randn( + batch_size, + q_seqlen, + num_q_heads, + head_dim_qk, + dtype=target_dtype, + device=device, + ).transpose(1, 2) + key = torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_qk, + dtype=target_dtype, + device=device, + ).transpose(1, 2) + value = torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_vo, + dtype=target_dtype, + device=device, + ).transpose(1, 2) + output = torch.empty( + batch_size, + q_seqlen, + num_q_heads, + head_dim_vo, + dtype=target_dtype, + device=device, + ).transpose(1, 2) + + dQuery = torch.empty_like(query) + dKey = torch.empty_like(key) + dValue = torch.empty_like(value) + if args.data_type == "fp8": + # Create as bfloat16, convert to FP8, then view as uint8 to avoid DLPack issues + dOutput_bf16 = torch.randn(output.shape, dtype=torch.bfloat16, device=device) + dOutput_fp8 = dOutput_bf16.to(torch.float8_e4m3fn) + dOutput = dOutput_fp8.view(torch.uint8) + else: + dOutput = torch.randn_like(output) + stats = torch.randn(batch_size, q_seqlen, num_q_heads, 1, dtype=torch.float32, device=device).transpose(1, 2) + if is_dropout: + dropout_seed = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") + dropout_offset = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") + + # cuDNN graph forward + graph_fwd = cudnn.pygraph( + io_data_type=(cudnn.data_type.FP8_E4M3 if args.data_type == "fp8" else convert_to_cudnn_type(target_dtype)), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + if is_dropout: + seed_fwd = graph_fwd.tensor_like(dropout_seed) + offset_fwd = graph_fwd.tensor_like(dropout_offset) + dropout_tuple = (dropout_prob, seed_fwd, offset_fwd) + + if args.data_type == "fp8": + q_fwd = graph_fwd.tensor_like(query).set_data_type(cudnn.data_type.FP8_E4M3) + k_fwd = graph_fwd.tensor_like(key).set_data_type(cudnn.data_type.FP8_E4M3) + v_fwd = 
graph_fwd.tensor_like(value).set_data_type(cudnn.data_type.FP8_E4M3) + + descale_q_fwd = graph_fwd.tensor_like(descale_q_gpu) + descale_k_fwd = graph_fwd.tensor_like(descale_k_gpu) + descale_v_fwd = graph_fwd.tensor_like(descale_v_gpu) + descale_s_fwd = graph_fwd.tensor_like(descale_s_gpu) + scale_s_fwd = graph_fwd.tensor_like(scale_s_gpu) + scale_o_fwd = graph_fwd.tensor_like(scale_o_gpu) + + o_fwd, stats_fwd, amax_s_fwd, amax_o_fwd = graph_fwd.sdpa_fp8( + q=q_fwd, + k=k_fwd, + v=v_fwd, + descale_q=descale_q_fwd, + descale_k=descale_k_fwd, + descale_v=descale_v_fwd, + descale_s=descale_s_fwd, + scale_s=scale_s_fwd, + scale_o=scale_o_fwd, + # generate_stats=not is_infer, + is_inference=is_infer, + attn_scale=attn_scale, + diagonal_alignment=(cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask == "bottom_right" else cudnn.diagonal_alignment.TOP_LEFT), + right_bound=None if args.attn_mask == "no_mask" else 0, + # dropout=dropout_tuple if is_dropout else None, + ) + else: + q_fwd = graph_fwd.tensor_like(query) + k_fwd = graph_fwd.tensor_like(key) + v_fwd = graph_fwd.tensor_like(value) + o_fwd, stats_fwd = graph_fwd.sdpa( + q=q_fwd, + k=k_fwd, + v=v_fwd, + # generate_stats=not is_infer, + is_inference=is_infer, + attn_scale=attn_scale, + diagonal_alignment=(cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask == "bottom_right" else cudnn.diagonal_alignment.TOP_LEFT), + diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0, + dropout=dropout_tuple if is_dropout else None, + ) + + if run_bwd: + if args.data_type == "fp8": + o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + (stats_fwd.set_output(True).set_dim(stats.size()).set_stride(stats.stride()).set_data_type(cudnn.data_type.FLOAT) if not is_infer else None) + else: + o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()) + 
(stats_fwd.set_output(True).set_dim(stats.size()).set_stride(stats.stride()).set_data_type(cudnn.data_type.FLOAT) if not is_infer else None) + else: + if args.data_type == "fp8": + o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + else: + o_fwd.set_output(True).set_dim(output.size()).set_stride(output.stride()) + + if args.data_type == "fp8": + amax_s_fwd.set_output(True).set_dim(amax_s_gpu.size()).set_stride(amax_s_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + amax_o_fwd.set_output(True).set_dim(amax_o_gpu.size()).set_stride(amax_o_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + graph_fwd.validate() + graph_fwd.build_operation_graph() + graph_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph_fwd.check_support() + graph_fwd.build_plans() + + # If backward is requested, set up backward graph. + if run_bwd: + graph_bwd = cudnn.pygraph( + io_data_type=(cudnn.data_type.FP8_E4M3 if args.data_type == "fp8" else convert_to_cudnn_type(target_dtype)), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + stats_bwd = graph_bwd.tensor_like(stats) + if is_dropout: + seed_bwd = graph_bwd.tensor_like(dropout_seed) + offset_bwd = graph_bwd.tensor_like(dropout_offset) + dropout_tuple = (dropout_prob, seed_bwd, offset_bwd) + + if args.data_type == "fp8": + q_bwd = graph_bwd.tensor_like(query).set_data_type(cudnn.data_type.FP8_E4M3) + k_bwd = graph_bwd.tensor_like(key).set_data_type(cudnn.data_type.FP8_E4M3) + v_bwd = graph_bwd.tensor_like(value).set_data_type(cudnn.data_type.FP8_E4M3) + o_bwd = graph_bwd.tensor_like(output).set_data_type(cudnn.data_type.FP8_E4M3) + dO_bwd = graph_bwd.tensor_like(dOutput).set_data_type(cudnn.data_type.FP8_E4M3) + + descale_q_bwd = graph_bwd.tensor_like(descale_q_gpu) + descale_k_bwd = graph_bwd.tensor_like(descale_k_gpu) + descale_v_bwd = graph_bwd.tensor_like(descale_v_gpu) + descale_o_bwd = 
graph_bwd.tensor_like(descale_o_gpu) + descale_dO_bwd = graph_bwd.tensor_like(descale_dO_gpu) + descale_s_bwd = graph_bwd.tensor_like(descale_s_gpu) + descale_dP_bwd = graph_bwd.tensor_like(descale_dP_gpu) + scale_s_bwd = graph_bwd.tensor_like(scale_s_gpu) + scale_dQ_bwd = graph_bwd.tensor_like(scale_dQ_gpu) + scale_dK_bwd = graph_bwd.tensor_like(scale_dK_gpu) + scale_dV_bwd = graph_bwd.tensor_like(scale_dV_gpu) + scale_dP_bwd = graph_bwd.tensor_like(scale_dP_gpu) + + ( + dQ_bwd, + dK_bwd, + dV_bwd, + amax_dQ_bwd, + amax_dK_bwd, + amax_dV_bwd, + amax_dP_bwd, + ) = graph_bwd.sdpa_fp8_backward( + q=q_bwd, + k=k_bwd, + v=v_bwd, + o=o_bwd, + dO=dO_bwd, + stats=stats_bwd, + descale_q=descale_q_bwd, + descale_k=descale_k_bwd, + descale_v=descale_v_bwd, + descale_o=descale_o_bwd, + descale_dO=descale_dO_bwd, + descale_s=descale_s_bwd, + descale_dP=descale_dP_bwd, + scale_s=scale_s_bwd, + scale_dQ=scale_dQ_bwd, + scale_dK=scale_dK_bwd, + scale_dV=scale_dV_bwd, + scale_dP=scale_dP_bwd, + attn_scale=attn_scale, + use_causal_mask=args.attn_mask != "no_mask" and args.attn_mask != "bottom_right", + use_causal_mask_bottom_right=args.attn_mask == "bottom_right", + dropout=dropout_tuple if is_dropout else None, + use_deterministic_algorithm=args.deterministic_bwd, + ) + else: + q_bwd = graph_bwd.tensor_like(query) + k_bwd = graph_bwd.tensor_like(key) + v_bwd = graph_bwd.tensor_like(value) + o_bwd = graph_bwd.tensor_like(output) + dO_bwd = graph_bwd.tensor_like(dOutput) + + dQ_bwd, dK_bwd, dV_bwd = graph_bwd.sdpa_backward( + q=q_bwd, + k=k_bwd, + v=v_bwd, + o=o_bwd, + dO=dO_bwd, + stats=stats_bwd, + attn_scale=attn_scale, + diagonal_alignment=(cudnn.diagonal_alignment.BOTTOM_RIGHT if args.attn_mask == "bottom_right" else cudnn.diagonal_alignment.TOP_LEFT), + diagonal_band_right_bound=None if args.attn_mask == "no_mask" else 0, + dropout=dropout_tuple if is_dropout else None, + use_deterministic_algorithm=args.deterministic_bwd, + ) + + if args.data_type == "fp8": + 
dQ_bwd.set_output(True).set_dim(dQuery.size()).set_stride(dQuery.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + dK_bwd.set_output(True).set_dim(dKey.size()).set_stride(dKey.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + dV_bwd.set_output(True).set_dim(dValue.size()).set_stride(dValue.stride()).set_data_type(cudnn.data_type.FP8_E4M3) + amax_dQ_bwd.set_output(True).set_dim(amax_dQ_gpu.size()).set_stride(amax_dQ_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + amax_dK_bwd.set_output(True).set_dim(amax_dK_gpu.size()).set_stride(amax_dK_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + amax_dV_bwd.set_output(True).set_dim(amax_dV_gpu.size()).set_stride(amax_dV_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + amax_dP_bwd.set_output(True).set_dim(amax_dP_gpu.size()).set_stride(amax_dP_gpu.stride()).set_data_type(cudnn.data_type.FLOAT) + else: + dQ_bwd.set_output(True).set_dim(dQuery.size()).set_stride(dQuery.stride()) + dK_bwd.set_output(True).set_dim(dKey.size()).set_stride(dKey.stride()) + dV_bwd.set_output(True).set_dim(dValue.size()).set_stride(dValue.stride()) + + graph_bwd.validate() + graph_bwd.build_operation_graph() + graph_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph_bwd.check_support() + graph_bwd.build_plans() + + if args.data_type == "fp8": + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + stats_fwd: stats, + descale_q_fwd: descale_q_gpu, + descale_k_fwd: descale_k_gpu, + descale_v_fwd: descale_v_gpu, + descale_s_fwd: descale_s_gpu, + scale_s_fwd: scale_s_gpu, + scale_o_fwd: scale_o_gpu, + amax_s_fwd: amax_s_gpu, + amax_o_fwd: amax_o_gpu, + } + + variant_pack_bwd = { + q_bwd: query, + k_bwd: key, + v_bwd: value, + o_bwd: output, + dQ_bwd: dQuery, + dK_bwd: dKey, + dV_bwd: dValue, + dO_bwd: dOutput, + stats_bwd: stats, + descale_q_bwd: descale_q_gpu, + descale_k_bwd: descale_k_gpu, + descale_v_bwd: descale_v_gpu, + descale_o_bwd: descale_o_gpu, + descale_s_bwd: 
descale_s_gpu, + descale_dP_bwd: descale_dP_gpu, + descale_dO_bwd: descale_dO_gpu, + scale_s_bwd: scale_s_gpu, + scale_dQ_bwd: scale_dQ_gpu, + scale_dK_bwd: scale_dK_gpu, + scale_dV_bwd: scale_dV_gpu, + scale_dP_bwd: scale_dP_gpu, + amax_dQ_bwd: amax_dQ_gpu, + amax_dK_bwd: amax_dK_gpu, + amax_dV_bwd: amax_dV_gpu, + amax_dP_bwd: amax_dP_gpu, + } + + workspace = torch.empty( + max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), + device="cuda", + dtype=torch.uint8, + ) + else: + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + stats_fwd: stats, + } + variant_pack_bwd = { + q_bwd: query, + k_bwd: key, + v_bwd: value, + o_bwd: output, + dO_bwd: dOutput, + stats_bwd: stats, + dQ_bwd: dQuery, + dK_bwd: dKey, + dV_bwd: dValue, + } + workspace = torch.empty( + max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), + device="cuda", + dtype=torch.uint8, + ) + else: + if args.data_type == "fp8": + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + stats_fwd: stats, + descale_q_fwd: descale_q_gpu, + descale_k_fwd: descale_k_gpu, + descale_v_fwd: descale_v_gpu, + descale_s_fwd: descale_s_gpu, + scale_s_fwd: scale_s_gpu, + scale_o_fwd: scale_o_gpu, + amax_s_fwd: amax_s_gpu, + amax_o_fwd: amax_o_gpu, + } + workspace = torch.empty(graph_fwd.get_workspace_size(), device="cuda", dtype=torch.uint8) + else: + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + } + workspace = torch.empty(graph_fwd.get_workspace_size(), device="cuda", dtype=torch.uint8) + if is_dropout: + variant_pack_fwd[seed_fwd] = dropout_seed + variant_pack_fwd[offset_fwd] = dropout_offset + if run_bwd: + variant_pack_bwd[seed_bwd] = dropout_seed + variant_pack_bwd[offset_bwd] = dropout_offset + ## Done setting up cuDNN graph. 
+ + # For backends MATH, EFFICIENT_ATTENTION, CUDNN_ATTENTION, PYTORCH_FLASH_ATTENTION + def pyt_backend_sdpa(query, key, value, backend): + with sdpa_kernel(backends=[backend]): + return torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + enable_gqa=enable_gqa, + is_causal=args.attn_mask == "top_left", + attn_mask=causal_lower_right(q_seqlen, kv_seqlen) if args.attn_mask == "bottom_right" else None, + ) + + if args.sdpa_backend == "flash_attention": + import flash_attn + from flash_attn import flash_attn_func + + # Flash Attention Native + def flash_attention_sdpa(query, key, value): + return flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask") + + if args.sdpa_backend == "flash_attention_3": + import flash_attn_interface + + def flash_attention_3_sdpa(query, key, value): + output, _ = flash_attn_interface.flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask") + return output + + if args.sdpa_backend == "flash_attention_4" or (not args.skip_ref): + import flash_attn.cute.interface as flash_attn_interface + + def flash_attention_4_sdpa(query, key, value): + output, _ = flash_attn_interface.flash_attn_func(query, key, value, causal=args.attn_mask != "no_mask") + return output + + def get_sdpa_function(backend): + if backend == "pyt_math": + return functools.partial(pyt_backend_sdpa, backend=SDPBackend.MATH) + elif backend == "pyt_efficient_attention": + return functools.partial(pyt_backend_sdpa, backend=SDPBackend.EFFICIENT_ATTENTION) + elif backend == "pyt_flash_attention": + return functools.partial(pyt_backend_sdpa, backend=SDPBackend.FLASH_ATTENTION) + elif backend == "pyt_cudnn": + return functools.partial(pyt_backend_sdpa, backend=SDPBackend.CUDNN_ATTENTION) + elif backend == "flash_attention": + return flash_attention_sdpa + elif backend == "flash_attention_3": + return flash_attention_3_sdpa + elif backend == "flash_attention_4": + return flash_attention_4_sdpa + elif backend == "cudnn": + return 
None # Will be set up separately + else: + raise ValueError(f"Invalid backend: {backend}") + + # Util function for addressing different qkv formats for each backend + def preprocess_qkv(query, key, value, backend): + if backend.startswith("pyt_") or backend == "cudnn": + return query, key, value + elif backend.startswith("flash_attention"): + query = torch.swapaxes(query, 1, 2) + key = torch.swapaxes(key, 1, 2) + value = torch.swapaxes(value, 1, 2) + return query, key, value + else: + raise ValueError(f"Invalid backend: {backend}") + + # Util function addressing different qkvo formats for each backend + def postprocess_qkvo(query, key, value, output, backend): + if backend.startswith("pyt_") or backend == "cudnn": + return query, key, value, output + elif backend.startswith("flash_attention"): + output = torch.swapaxes(output, 1, 2) + query = torch.swapaxes(query, 1, 2) + key = torch.swapaxes(key, 1, 2) + value = torch.swapaxes(value, 1, 2) + return query, key, value, output + else: + raise ValueError(f"Invalid backend: {backend}") + + def postprocess_dqdkdvdo(dQuery, dKey, dValue, dOutput, backend): + if backend.startswith("pyt_") or backend == "cudnn": + return dQuery, dKey, dValue, dOutput + elif backend.startswith("flash_attention"): + dQuery = torch.swapaxes(dQuery, 1, 2) + dKey = torch.swapaxes(dKey, 1, 2) + dValue = torch.swapaxes(dValue, 1, 2) + dOutput = torch.swapaxes(dOutput, 1, 2) + return dQuery, dKey, dValue, dOutput + else: + raise ValueError(f"Invalid backend: {backend}") + + # Util functions for calculating flops and tflops/s achieved + def flops( + batch_size, + q_seqlen, + kv_seqlen, + head_dim_qk, + head_dim_vo, + num_q_heads, + attn_mask, + mode="fwd", + ): + assert mode in ["fwd", "bwd", "fwd_bwd"] + + if attn_mask == "no_mask": + num_nonmasked_elems = q_seqlen * kv_seqlen + elif attn_mask == "top_left": + num_nonmasked_elems = torch.tril(torch.ones((q_seqlen, kv_seqlen), dtype=torch.bool)).sum() + elif attn_mask == "bottom_right": + 
diagonal_offset = kv_seqlen - q_seqlen + num_nonmasked_elems = torch.tril( + torch.ones((q_seqlen, kv_seqlen), dtype=torch.bool), + diagonal=diagonal_offset, + ).sum() + # BMM FLOPs: 2 * M * N * K. + # Here, M*N = num_nonmasked_elems per head; add batch_size * num_q_heads multiplier. + # Forward: 2 BMMs => (1 x head_dim_qk) + (1 x head_dim_vo) + # Backward: 5 BMMs => (3 x head_dim_qk) + (2 x head_dim_vo) + base = batch_size * num_q_heads * num_nonmasked_elems * 2 + if mode == "fwd": + result = base * (head_dim_qk + head_dim_vo) + elif mode == "bwd": + result = base * (3 * head_dim_qk + 2 * head_dim_vo) + else: # fwd_bwd + result = base * (4 * head_dim_qk + 3 * head_dim_vo) + return result + + def tflops_per_sec( + batch_size, + q_seqlen, + kv_seqlen, + head_dim_qk, + head_dim_vo, + num_q_heads, + attn_mask, + time, + mode="fwd", + ): + assert mode in ["fwd", "bwd", "fwd_bwd"] + f = flops( + batch_size, + q_seqlen, + kv_seqlen, + head_dim_qk, + head_dim_vo, + num_q_heads, + attn_mask, + mode, + ) + return f / time / 1e9 if not math.isnan(time) else 0.0 # Assume time is in msec + + ###### Done setting up SDPA function for each backend ####### + ############################################################# + + ###### SDPA Benchmark -- Run ###### + ## Print System Info + if args.verbose: + print(f"[INFO] {torch.__version__ = }") + print(f"[INFO] {torch.version.cuda = }") + print(f"[INFO] {torch.cuda.is_available() = }") + print(f"[INFO] {torch.cuda.device_count() = }") + print(f"[INFO] {torch.cuda.current_device() = }") + print(f"[INFO] {torch.cuda.get_device_name(torch.cuda.current_device()) = }") + if args.sdpa_backend == "pyt_cudnn": + print(f"[INFO] {torch.backends.cudnn.version() = }") + print(f"[INFO] {torch.backends.cudnn.enabled = }") + elif args.sdpa_backend == "flash_attention": + print(f"[INFO] {flash_attn.__version__ = }") + + forward_times = [] + backward_times = [] + forward_diffs = [] + + total_iters = num_iters + dry_run_iters + + first_error = True # 
For suppressing error message beyond first error + sdpa_function = get_sdpa_function(args.sdpa_backend) + for i in range(total_iters): + if args.data_type == "fp8" and args.sdpa_backend == "cudnn": + query = torch.randint( + 256, + (batch_size, q_seqlen, num_q_heads, head_dim_qk), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + key = torch.randint( + 256, + (batch_size, kv_seqlen, num_kv_heads, head_dim_qk), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + value = torch.randint( + 256, + (batch_size, kv_seqlen, num_kv_heads, head_dim_vo), + dtype=torch.uint8, + device=device, + ).transpose(1, 2) + descale_q_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_k_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_v_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_dO_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + descale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_s_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_o_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dQ_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dK_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dV_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + scale_dP_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float, device=device) + amax_s_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_o_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dQ_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dK_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dV_gpu = torch.zeros(1, 1, 1, 1, dtype=torch.float, device=device) + amax_dP_gpu = torch.zeros(1, 1, 1, 
1, dtype=torch.float, device=device) + elif args.data_type == "fp8" and args.sdpa_backend == "flash_attention_3": + query = ( + torch.randn( + batch_size, + q_seqlen, + num_q_heads, + head_dim_qk, + dtype=torch.bfloat16, + device=device, + requires_grad=True, + ) + .to(torch.float8_e4m3fn) + .transpose(1, 2) + ) + key = ( + torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_qk, + dtype=torch.bfloat16, + device=device, + requires_grad=True, + ) + .to(torch.float8_e4m3fn) + .transpose(1, 2) + ) + value = ( + torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_vo, + dtype=torch.bfloat16, + device=device, + requires_grad=True, + ) + .to(torch.float8_e4m3fn) + .transpose(1, 2) + ) + else: + query = torch.randn( + batch_size, + q_seqlen, + num_q_heads, + head_dim_qk, + dtype=target_dtype, + device=device, + requires_grad=True, + ).transpose(1, 2) + key = torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_qk, + dtype=target_dtype, + device=device, + requires_grad=True, + ).transpose(1, 2) + value = torch.randn( + batch_size, + kv_seqlen, + num_kv_heads, + head_dim_vo, + dtype=target_dtype, + device=device, + requires_grad=True, + ).transpose(1, 2) + + query, key, value = preprocess_qkv(query, key, value, args.sdpa_backend) + if args.data_type == "fp8" and args.sdpa_backend == "cudnn": + # Create as bfloat16, convert to FP8, then view as uint8 to avoid DLPack issues + dOutput_bf16 = torch.randn(query.shape, dtype=torch.bfloat16, device=device) + dOutput_fp8 = dOutput_bf16.to(torch.float8_e4m3fn) + dOutput = dOutput_fp8.view(torch.uint8) + else: + dOutput = torch.randn_like(query) + + if args.sdpa_backend == "cudnn": + output = torch.empty( + batch_size, + q_seqlen, + num_q_heads, + head_dim_vo, + dtype=torch.uint8 if args.data_type == "fp8" else target_dtype, + device=device, + ).transpose(1, 2) + dQuery = torch.empty_like(query) + dKey = torch.empty_like(key) + dValue = torch.empty_like(value) + stats = 
torch.randn(batch_size, q_seqlen, num_q_heads, 1, dtype=torch.float32, device=device).transpose(1, 2) + if is_dropout: + dropout_seed = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") + dropout_offset = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") + + # Only variant pack and workspace need to be updated for each iteration. + if run_bwd: + if args.data_type == "fp8": + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + stats_fwd: stats, + descale_q_fwd: descale_q_gpu, + descale_k_fwd: descale_k_gpu, + descale_v_fwd: descale_v_gpu, + descale_s_fwd: descale_s_gpu, + scale_s_fwd: scale_s_gpu, + scale_o_fwd: scale_o_gpu, + amax_s_fwd: amax_s_gpu, + amax_o_fwd: amax_o_gpu, + } + variant_pack_bwd = { + q_bwd: query, + k_bwd: key, + v_bwd: value, + o_bwd: output, + dQ_bwd: dQuery, + dK_bwd: dKey, + dV_bwd: dValue, + dO_bwd: dOutput, + stats_bwd: stats, + descale_q_bwd: descale_q_gpu, + descale_k_bwd: descale_k_gpu, + descale_v_bwd: descale_v_gpu, + descale_o_bwd: descale_o_gpu, + descale_s_bwd: descale_s_gpu, + descale_dP_bwd: descale_dP_gpu, + descale_dO_bwd: descale_dO_gpu, + scale_s_bwd: scale_s_gpu, + scale_dQ_bwd: scale_dQ_gpu, + scale_dK_bwd: scale_dK_gpu, + scale_dV_bwd: scale_dV_gpu, + scale_dP_bwd: scale_dP_gpu, + amax_dQ_bwd: amax_dQ_gpu, + amax_dK_bwd: amax_dK_gpu, + amax_dV_bwd: amax_dV_gpu, + amax_dP_bwd: amax_dP_gpu, + } + else: + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + stats_fwd: stats, + } + variant_pack_bwd = { + q_bwd: query, + k_bwd: key, + v_bwd: value, + o_bwd: output, + dO_bwd: dOutput, + stats_bwd: stats, + dQ_bwd: dQuery, + dK_bwd: dKey, + dV_bwd: dValue, + } + workspace = torch.empty( + max(graph_fwd.get_workspace_size(), graph_bwd.get_workspace_size()), + device="cuda", + dtype=torch.uint8, + ) + else: + if args.data_type == "fp8": + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + stats_fwd: 
stats, + descale_q_fwd: descale_q_gpu, + descale_k_fwd: descale_k_gpu, + descale_v_fwd: descale_v_gpu, + descale_s_fwd: descale_s_gpu, + scale_s_fwd: scale_s_gpu, + scale_o_fwd: scale_o_gpu, + amax_s_fwd: amax_s_gpu, + amax_o_fwd: amax_o_gpu, + } + else: + variant_pack_fwd = { + q_fwd: query, + k_fwd: key, + v_fwd: value, + o_fwd: output, + } + workspace = torch.empty(graph_fwd.get_workspace_size(), device="cuda", dtype=torch.uint8) + + if is_dropout: + variant_pack_fwd[seed_fwd] = dropout_seed + variant_pack_fwd[offset_fwd] = dropout_offset + if run_bwd: + variant_pack_bwd[seed_bwd] = dropout_seed + variant_pack_bwd[offset_bwd] = dropout_offset + + l2_flush_buffer.zero_() + + # Run kernel with profiler for forward if requested, else run unprofiled to prep for backward + if run_fwd: + with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("sdpa.forward"): # Custom marker + if args.sdpa_backend == "cudnn": + graph_fwd.execute(variant_pack_fwd, workspace) + else: + output = sdpa_function(query, key, value) + torch.cuda.synchronize() # Ensure all kernels finish + + # Filter profiler results by kernel name prefix + matched_kernels = [ + item + for item in prof.key_averages() + if item.key.startswith("cudnn") + or item.key.startswith("kernel_cutlass") + or "pytorch_flash::" in item.key + or "flash::" in item.key + or "at::native::" in item.key + or "cutlass3x" in item.key + or "(anonymous namespace)::" in item.key + or item.key.startswith("fmha_") + ] + if len(matched_kernels) >= 1: + fwd_time = sum(item.device_time for item in matched_kernels) / 1000 + if i >= dry_run_iters: + forward_times.append(fwd_time) + else: + if args.sdpa_backend == "cudnn": + graph_fwd.execute(variant_pack_fwd, workspace) + else: + output = sdpa_function(query, key, value) + torch.cuda.synchronize() + + # Sleep for some time proportional to fwd_time for stable measurements + sleep_time = np.min([fwd_time / 100, 1.0]) if run_fwd and 
len(matched_kernels) >= 1 else 0.0 + time.sleep(sleep_time) + + if run_bwd: + # Run backward pass + + l2_flush_buffer.zero_() + + with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("sdpa.backward"): # Custom marker + if args.sdpa_backend == "cudnn": + graph_bwd.execute(variant_pack_bwd, workspace) + else: + query.retain_grad() + key.retain_grad() + value.retain_grad() + output.backward(dOutput) + + dQuery = query.grad + dKey = key.grad + dValue = value.grad + + query.grad = None + key.grad = None + value.grad = None + torch.cuda.synchronize() + + matched_kernels = [ + item + for item in prof.key_averages() + if "cudnn" in item.key + or item.key.startswith("kernel_cutlass") + or "pytorch_flash::" in item.key + or "flash::" in item.key + or "at::native::" in item.key + or "cutlass3x" in item.key + or "(anonymous namespace)::" in item.key + or item.key.startswith("fmha_") + ] + if len(matched_kernels) >= 1: + bwd_time = sum(item.device_time for item in matched_kernels) / 1000 + if i >= dry_run_iters: + backward_times.append(bwd_time) + + sleep_time = np.min([bwd_time / 100, 1.0]) if run_bwd and len(matched_kernels) >= 1 else 0.0 + time.sleep(sleep_time) + + dQuery, dKey, dValue, dOutput = postprocess_dqdkdvdo(dQuery, dKey, dValue, dOutput, args.sdpa_backend) + + ( + query, + key, + value, + output, + ) = postprocess_qkvo(query, key, value, output, args.sdpa_backend) + if args.data_type != "fp8" and not args.skip_ref and run_fwd: + try: + output_ref = flash_attention_4_sdpa(query, key, value) + if run_bwd: + query.retain_grad() + key.retain_grad() + value.retain_grad() + output_ref.backward(dOutput) + + torch.testing.assert_close(dQuery, query.grad, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(dKey, key.grad, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(dValue, value.grad, rtol=2e-2, atol=2e-2) + + torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2) + 
forward_diffs.append(torch.max(torch.abs(output.detach() - output_ref.detach())).item()) + except Exception as e: + if first_error: + print( + f"[WARN] Failed reference check. Target backend has been run, but output has not been validated. Failure may be due to incorrect output or reference function failure." + ) + print(f"[WARN] See error message: {e}") + first_error = False + forward_diffs.append(0.0) + else: + forward_diffs.append(0.0) + + time.sleep(sleep_time) + + if args.sdpa_backend == "cudnn": + del query, key, value, output, dQuery, dKey, dValue, dOutput, stats + else: + del query, key, value, output + + ## print results + fwd_median_time = ( + np.median(np.array(forward_times[5:])) if len(forward_times) > 5 else (np.median(np.array(forward_times)) if len(forward_times) > 0 else 0.0) + ) + fwd_tflops = 0.0 + if run_fwd and fwd_median_time > 0: + fwd_tflops = tflops_per_sec( + args.batch_size, + args.q_seqlen, + args.kv_seqlen, + head_dim_qk, + head_dim_vo, + args.num_q_heads, + args.attn_mask, + fwd_median_time, + "fwd", + ) + + bwd_median_time = ( + np.median(np.array(backward_times[5:])) if len(backward_times) > 5 else (np.median(np.array(backward_times)) if len(backward_times) > 0 else 0.0) + ) + bwd_tflops = 0.0 + if run_bwd and bwd_median_time > 0: + bwd_tflops = tflops_per_sec( + args.batch_size, + args.q_seqlen, + args.kv_seqlen, + head_dim_qk, + head_dim_vo, + args.num_q_heads, + args.attn_mask, + bwd_median_time, + "bwd", + ) + + if args.format_output: + print( + f"{args.case_tag},{args.sdpa_backend},{args.batch_size},{args.q_seqlen},{args.kv_seqlen},{args.num_q_heads},{args.num_kv_heads},{head_dim_qk},{fwd_median_time:.3f},{bwd_median_time:.3f},{fwd_tflops:.0f},{bwd_tflops:.0f},{(np.max(np.array(forward_diffs[5:])) if len(forward_diffs) > 5 else (np.max(np.array(forward_diffs)) if len(forward_diffs) > 0 else 0.0)):.6f},{num_iters}" + ) + else: + if run_fwd and run_bwd: + print( + f"{args.sdpa_backend}:: Median (fwd, bwd) Execution Times: 
{fwd_median_time:.3f} ms ({fwd_tflops:.0f} TFLOPS), {bwd_median_time:.3f} ms ({bwd_tflops:.0f} TFLOPS)" + ) + elif run_fwd: + print(f"{args.sdpa_backend}:: Median (fwd) Execution Time: {fwd_median_time:.3f} ms ({fwd_tflops:.0f} TFLOPS)") + elif run_bwd: + print(f"{args.sdpa_backend}:: Median (bwd) Execution Time: {bwd_median_time:.3f} ms ({bwd_tflops:.0f} TFLOPS)") diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/charts.py b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/charts.py new file mode 100644 index 00000000..7c9872d5 --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/charts.py @@ -0,0 +1,441 @@ +""" +Chart generation for SDPA benchmark results. + +Generates comparison bar charts showing backend performance side-by-side. +""" + +from pathlib import Path +from typing import Optional, TYPE_CHECKING +import logging + +if TYPE_CHECKING: + import pandas as pd + from .config_types import BenchmarkConfig + +logger = logging.getLogger(__name__) + +# Backend display configuration +# Each backend has a base color; FP8 variants get a darker/different shade +BACKEND_CONFIG = { + "cudnn": {"name": "cudnn", "color": "#76b900", "color_fp8": "#4a7500", "order": 0}, + "pyt_cudnn": {"name": "cuDNN (PyTorch)", "color": "#90EE90", "color_fp8": "#228B22", "order": 1}, + "pyt_flash_attention": {"name": "FAv2 (PyTorch)", "color": "#6495ED", "color_fp8": "#0000CD", "order": 2}, + "pyt_efficient_attention": {"name": "xFormers (PyTorch)", "color": "#FF00FF", "color_fp8": "#8B008B", "order": 3}, + "pyt_math": {"name": "Standard Attention", "color": "#FF8C00", "color_fp8": "#D2691E", "order": 4}, + "flash_attention": {"name": "FAv2 (Native)", "color": "#F08080", "color_fp8": "#CD5C5C", "order": 5}, + "flash_attention_3": {"name": "FAv3", "color": "#FFA500", "color_fp8": "#FF6600", "order": 6}, + "flash_attention_4": {"name": "FAv4", "color": "#FFD700", "color_fp8": "#DAA520", "order": 7}, +} + +# Font sizes for plot 
elements +LABEL_FONT_SIZE = 10 +LEGEND_FONT_SIZE = 8 +TITLE_FONT_SIZE = 12 +BAR_LABEL_FONT_SIZE = 6 + + +def get_backend_display_name(backend: str, data_type: str) -> str: + """ + Get display name for backend+dtype combination. + + Args: + backend: Backend name (e.g., "cudnn") + data_type: Data type (e.g., "bfloat16", "fp8") + + Returns: + Display name for legend (e.g., "cuDNN FE (FP8)") + """ + base_name = BACKEND_CONFIG.get(backend, {}).get("name", backend) + if data_type == "fp8": + return f"{base_name} (FP8)" + elif data_type == "float16": + return f"{base_name} (FP16)" + return base_name + + +def get_backend_color(backend: str, data_type: str) -> str: + """ + Get color for backend+dtype combination. + + Args: + backend: Backend name + data_type: Data type + + Returns: + Color string for matplotlib + """ + config = BACKEND_CONFIG.get(backend, {}) + if data_type == "fp8" and "color_fp8" in config: + return config["color_fp8"] + return config.get("color", "gray") + + +def generate_comparison_chart( + df: "pd.DataFrame", + config: "BenchmarkConfig", + output_path: Optional[Path] = None, +) -> Path: + """ + Generate comparison bar chart with multiple backends side-by-side. + + Creates a figure with: + - Left subplot: Forward pass TFLOPS by configuration + - Right subplot: Backward pass TFLOPS by configuration + - Each backend+dtype combo as a separate bar group + + Args: + df: DataFrame with benchmark results (from BenchmarkRunner.results_to_dataframe) + config: BenchmarkConfig used for the run + output_path: Optional path for output file. 
If None, uses config.output_dir + + Returns: + Path to the saved chart file + """ + import matplotlib.pyplot as plt + import seaborn as sns + import numpy as np + + # Filter to successful results only + df = df[df["success"] == True].copy() + + if df.empty: + raise ValueError("No successful results to plot") + + # Create backend+dtype display name for legend + df["backend_display"] = df.apply(lambda r: get_backend_display_name(r["backend"], r["data_type"]), axis=1) + + # Create config label for x-axis (model/seqlen/mask) + df["config_label"] = df.apply( + lambda r: f"{r['model_name']}\n{r['q_seqlen']}x{r['kv_seqlen']}\n{r['attn_mask']}", + axis=1, + ) + + # Sort by backend order for consistent legend + df["backend_order"] = df["backend"].map(lambda b: BACKEND_CONFIG.get(b, {}).get("order", 99)) + df.sort_values(["model_name", "q_seqlen", "attn_mask", "backend_order"], inplace=True) + + # Build color palette based on unique backend+dtype combinations + # Get unique (backend, data_type, backend_display) tuples to map colors correctly + unique_combos = df[["backend", "data_type", "backend_display"]].drop_duplicates() + palette = {} + for _, row in unique_combos.iterrows(): + palette[row["backend_display"]] = get_backend_color(row["backend"], row["data_type"]) + + # Determine if we have fwd/bwd data + has_fwd = (df["fwd_tflops"] > 0).any() + has_bwd = (df["bwd_tflops"] > 0).any() + + if has_fwd and has_bwd: + fig, axes = plt.subplots(1, 2, figsize=(14, 6), dpi=150) + ax_fwd, ax_bwd = axes + elif has_fwd: + fig, ax_fwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_bwd = None + elif has_bwd: + fig, ax_bwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_fwd = None + else: + raise ValueError("No forward or backward TFLOPS data to plot") + + # Calculate y-axis limit + max_tflops = max( + df["fwd_tflops"].max() if has_fwd else 0, + df["bwd_tflops"].max() if has_bwd else 0, + ) + ylim_max = max_tflops * 1.15 # Add 15% headroom for labels + + # Plot forward pass + if 
ax_fwd is not None: + fwd_df = df[df["fwd_tflops"] > 0] + if not fwd_df.empty: + sns.barplot( + data=fwd_df, + x="config_label", + y="fwd_tflops", + hue="backend_display", + ax=ax_fwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_fwd.set_xlabel("Configuration", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_title("SDPA Forward Pass", fontsize=TITLE_FONT_SIZE) + ax_fwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_fwd.tick_params(axis="x", rotation=45, labelsize=8) + ax_fwd.tick_params(axis="y", labelsize=LABEL_FONT_SIZE) + ax_fwd.set_ylim(0, ylim_max) + + # Add value labels on bars + for container in ax_fwd.containers: + ax_fwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + # Plot backward pass + if ax_bwd is not None: + bwd_df = df[df["bwd_tflops"] > 0] + if not bwd_df.empty: + sns.barplot( + data=bwd_df, + x="config_label", + y="bwd_tflops", + hue="backend_display", + ax=ax_bwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_bwd.set_xlabel("Configuration", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_title("SDPA Backward Pass", fontsize=TITLE_FONT_SIZE) + ax_bwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_bwd.tick_params(axis="x", rotation=45, labelsize=8) + ax_bwd.tick_params(axis="y", labelsize=LABEL_FONT_SIZE) + ax_bwd.set_ylim(0, ylim_max) + + # Add value labels on bars + for container in ax_bwd.containers: + ax_bwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + plt.tight_layout() + + # Determine output path + if output_path is None: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"{config.name}_comparison.png" + + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + + logger.info(f"Chart saved to {output_path}") + return output_path + + +def generate_charts_by_mask( + df: 
"pd.DataFrame", + config: "BenchmarkConfig", + output_dir: Optional[Path] = None, +) -> list: + """ + Generate separate charts for each mask type. + + This creates cleaner charts when benchmarking both causal and non-causal masks. + Each chart shows seqlen on x-axis and backends as grouped bars. + + Args: + df: DataFrame with benchmark results + config: BenchmarkConfig used for the run + output_dir: Directory for output files + + Returns: + List of paths to saved chart files + """ + import matplotlib.pyplot as plt + import seaborn as sns + + df = df[df["success"] == True].copy() + + if df.empty: + raise ValueError("No successful results to plot") + + if output_dir is None: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + saved_paths = [] + masks = df["attn_mask"].unique() + + for mask in masks: + mask_df = df[df["attn_mask"] == mask].copy() + + # Create display names + mask_df["backend_display"] = mask_df.apply(lambda r: get_backend_display_name(r["backend"], r["data_type"]), axis=1) + mask_df["seqlen_label"] = mask_df.apply(lambda r: f"{r['q_seqlen']}x{r['kv_seqlen']}", axis=1) + + # Build palette + unique_combos = mask_df[["backend", "data_type", "backend_display"]].drop_duplicates() + palette = {} + for _, row in unique_combos.iterrows(): + palette[row["backend_display"]] = get_backend_color(row["backend"], row["data_type"]) + + # Sort + mask_df["backend_order"] = mask_df["backend"].map(lambda b: BACKEND_CONFIG.get(b, {}).get("order", 99)) + mask_df.sort_values(["q_seqlen", "backend_order"], inplace=True) + + has_fwd = (mask_df["fwd_tflops"] > 0).any() + has_bwd = (mask_df["bwd_tflops"] > 0).any() + + if has_fwd and has_bwd: + fig, (ax_fwd, ax_bwd) = plt.subplots(1, 2, figsize=(14, 6), dpi=150) + elif has_fwd: + fig, ax_fwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_bwd = None + else: + fig, ax_bwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_fwd = None + + mask_title = "Causal" if mask == "top_left" else 
"Non-Causal" if mask == "no_mask" else mask + + if ax_fwd is not None: + fwd_df = mask_df[mask_df["fwd_tflops"] > 0] + if not fwd_df.empty: + sns.barplot( + data=fwd_df, + x="seqlen_label", + y="fwd_tflops", + hue="backend_display", + ax=ax_fwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_fwd.set_xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_title(f"{config.name} Forward ({mask_title})", fontsize=TITLE_FONT_SIZE) + ax_fwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_fwd.tick_params(axis="x", rotation=45) + for container in ax_fwd.containers: + ax_fwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + if ax_bwd is not None: + bwd_df = mask_df[mask_df["bwd_tflops"] > 0] + if not bwd_df.empty: + sns.barplot( + data=bwd_df, + x="seqlen_label", + y="bwd_tflops", + hue="backend_display", + ax=ax_bwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_bwd.set_xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_title(f"{config.name} Backward ({mask_title})", fontsize=TITLE_FONT_SIZE) + ax_bwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_bwd.tick_params(axis="x", rotation=45) + for container in ax_bwd.containers: + ax_bwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + plt.tight_layout() + output_path = output_dir / f"{config.name}_{mask}.png" + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + saved_paths.append(output_path) + logger.info(f"Chart saved to {output_path}") + + return saved_paths + + +def generate_seqlen_scaling_chart( + df: "pd.DataFrame", + config: "BenchmarkConfig", + output_path: Optional[Path] = None, +) -> Path: + """ + Generate a chart showing performance scaling with sequence length. + + This chart is useful when benchmarking multiple sequence lengths with + the same model configuration. 
+ + Args: + df: DataFrame with benchmark results + config: BenchmarkConfig used for the run + output_path: Optional path for output file + + Returns: + Path to the saved chart file + """ + import matplotlib.pyplot as plt + import seaborn as sns + + # Filter to successful results only + df = df[df["success"] == True].copy() + + if df.empty: + raise ValueError("No successful results to plot") + + # Create backend+dtype display name + df["backend_display"] = df.apply(lambda r: get_backend_display_name(r["backend"], r["data_type"]), axis=1) + + # Use q_seqlen for x-axis (assuming symmetric seqlens for this chart) + df["seqlen"] = df["q_seqlen"] + + # Build color palette based on unique backend+dtype combinations + unique_combos = df[["backend", "data_type", "backend_display"]].drop_duplicates() + palette = {} + for _, row in unique_combos.iterrows(): + palette[row["backend_display"]] = get_backend_color(row["backend"], row["data_type"]) + + # Create figure + has_fwd = (df["fwd_tflops"] > 0).any() + has_bwd = (df["bwd_tflops"] > 0).any() + + if has_fwd and has_bwd: + fig, axes = plt.subplots(1, 2, figsize=(14, 6), dpi=150) + ax_fwd, ax_bwd = axes + elif has_fwd: + fig, ax_fwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_bwd = None + else: + fig, ax_bwd = plt.subplots(1, 1, figsize=(10, 6), dpi=150) + ax_fwd = None + + # Plot forward + if ax_fwd is not None and has_fwd: + fwd_df = df[df["fwd_tflops"] > 0] + sns.barplot( + data=fwd_df, + x="seqlen", + y="fwd_tflops", + hue="backend_display", + ax=ax_fwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_fwd.set_xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_fwd.set_title("SDPA Forward Pass", fontsize=TITLE_FONT_SIZE) + ax_fwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_fwd.tick_params(axis="x", rotation=45) + + for container in ax_fwd.containers: + ax_fwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + # 
Plot backward + if ax_bwd is not None and has_bwd: + bwd_df = df[df["bwd_tflops"] > 0] + sns.barplot( + data=bwd_df, + x="seqlen", + y="bwd_tflops", + hue="backend_display", + ax=ax_bwd, + palette=palette, + edgecolor="black", + linewidth=0.5, + ) + ax_bwd.set_xlabel("Sequence Length", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_ylabel("TFLOPS", fontsize=LABEL_FONT_SIZE) + ax_bwd.set_title("SDPA Backward Pass", fontsize=TITLE_FONT_SIZE) + ax_bwd.legend(title="Backend", fontsize=LEGEND_FONT_SIZE) + ax_bwd.tick_params(axis="x", rotation=45) + + for container in ax_bwd.containers: + ax_bwd.bar_label(container, fmt="%.0f", fontsize=BAR_LABEL_FONT_SIZE) + + plt.tight_layout() + + # Determine output path + if output_path is None: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"{config.name}_seqlen_scaling.png" + + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + + logger.info(f"Chart saved to {output_path}") + return output_path diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/config_types.py b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/config_types.py new file mode 100644 index 00000000..ee9df6e5 --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/config_types.py @@ -0,0 +1,183 @@ +""" +Core types for the SDPA benchmark configuration system. + +This module defines the dataclasses used to configure and collect results +from SDPA benchmarks. +""" + +from dataclasses import dataclass, field +from typing import Optional, List, Tuple + + +@dataclass +class ModelPreset: + """ + Represents a named model configuration preset. + + Defines the attention head configuration for a specific model architecture. + Can use either symmetric head dimensions (head_dim) or asymmetric + (head_dim_qk, head_dim_vo) for models like DeepSeek V3. 
+ + Attributes: + name: Identifier for this preset (e.g., "llama3.1", "dsv3") + num_q_heads: Number of query heads + num_kv_heads: Number of key/value heads (differs from num_q_heads for GQA) + head_dim: Head dimension (used if head_dim_qk/vo not specified) + head_dim_qk: Head dimension for Q/K tensors (optional, for asymmetric) + head_dim_vo: Head dimension for V/O tensors (optional, for asymmetric) + + Example: + # Symmetric head dimensions (Llama 3.1) + LLAMA3_1 = ModelPreset( + name="llama3.1", + num_q_heads=64, + num_kv_heads=8, + head_dim=128, + ) + + # Asymmetric head dimensions (DeepSeek V3) + DSV3 = ModelPreset( + name="dsv3", + num_q_heads=128, + num_kv_heads=128, + head_dim_qk=192, + head_dim_vo=128, + ) + """ + + name: str + num_q_heads: int + num_kv_heads: int + head_dim: int = 128 + head_dim_qk: Optional[int] = None + head_dim_vo: Optional[int] = None + + def __post_init__(self): + """Resolve head dimensions after initialization.""" + if self.head_dim_qk is None: + self.head_dim_qk = self.head_dim + if self.head_dim_vo is None: + self.head_dim_vo = self.head_dim + + +@dataclass +class BenchmarkConfig: + """ + Configuration for a benchmark suite. + + Defines a set of benchmarks to run. 
The runner will expand this into + individual benchmark cases via cartesian product of: + models x seqlens x backends x data_types x attn_masks x deterministic_bwd + + Attributes: + name: Identifier for this config (used in output filenames) + models: List of ModelPreset to benchmark + seqlens: List of (q_seqlen, kv_seqlen) tuples + backends: List of backend names (e.g., ["cudnn", "flash_attention_4"]) + data_types: List of data types (e.g., ["bfloat16", "fp8"]) + attn_masks: List of attention masks (e.g., ["top_left", "no_mask"]) + profile_pass: Which pass to profile ("fwd", "bwd", or "both") + batch_size: Batch size for all benchmarks + num_iterations: Number of iterations per benchmark + num_warmup_iterations: Warmup iterations before measurement + skip_ref: Skip reference validation + deterministic_bwd: List of deterministic modes to test for backward pass + output_dir: Directory for output files + + Example: + CONFIG = BenchmarkConfig( + name="my_benchmark", + models=[LLAMA3_1, DSV3], + seqlens=[(4096, 4096), (8192, 8192)], + backends=["cudnn", "flash_attention_4"], + data_types=["bfloat16", "fp8"], + attn_masks=["top_left", "no_mask"], + profile_pass="fwd", + ) + """ + + name: str + models: List[ModelPreset] + seqlens: List[Tuple[int, int]] + backends: List[str] = field(default_factory=lambda: ["cudnn"]) + data_types: List[str] = field(default_factory=lambda: ["bfloat16"]) + attn_masks: List[str] = field(default_factory=lambda: ["top_left"]) + profile_pass: str = "fwd" + batch_size: int = 1 + num_iterations: int = 10 + num_warmup_iterations: int = 0 + skip_ref: bool = True + deterministic_bwd: List[bool] = field(default_factory=lambda: [False]) + output_dir: str = "../results" + + +@dataclass +class BenchmarkResult: + """ + Result from a single benchmark execution. + + Contains both the configuration that was run and the measured results. 
+ + Attributes: + config_name: Name of the BenchmarkConfig this result belongs to + model_name: Name of the ModelPreset used + backend: Backend that was used + data_type: Data type that was used + attn_mask: Attention mask that was used + batch_size: Batch size + q_seqlen: Query sequence length + kv_seqlen: Key/value sequence length + num_q_heads: Number of query heads + num_kv_heads: Number of key/value heads + head_dim_qk: Head dimension for Q/K + head_dim_vo: Head dimension for V/O + profile_pass: Which pass was profiled + deterministic_bwd: Whether deterministic backward was used + fwd_time_ms: Forward pass time in milliseconds + bwd_time_ms: Backward pass time in milliseconds (0 if not run) + fwd_tflops: Forward pass throughput in TFLOPS + bwd_tflops: Backward pass throughput in TFLOPS + max_diff: Maximum difference vs reference (if validated) + num_iterations: Number of iterations run + success: Whether the benchmark completed successfully + error_message: Error message if benchmark failed + gpu_name: Name of the GPU used + cudnn_version: cuDNN version string + """ + + # Config identification + config_name: str + model_name: str + backend: str + data_type: str + attn_mask: str + + # Dimensions + batch_size: int + q_seqlen: int + kv_seqlen: int + num_q_heads: int + num_kv_heads: int + head_dim_qk: int + head_dim_vo: int + + # Execution options + profile_pass: str + deterministic_bwd: bool + + # Results + fwd_time_ms: float + bwd_time_ms: float + fwd_tflops: float + bwd_tflops: float + max_diff: float + num_iterations: int + + # Status + success: bool = True + error_message: Optional[str] = None + + # Metadata + gpu_name: Optional[str] = None + cudnn_version: Optional[str] = None + cudnn_backend_version: Optional[int] = None diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/configs/__init__.py b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/configs/__init__.py new file mode 100644 index 00000000..37bf2d67 --- /dev/null +++ 
b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/configs/__init__.py @@ -0,0 +1,62 @@ +""" +Benchmark configuration loading utilities. + +This module provides functions to load benchmark configurations by name. +""" + +import importlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..config_types import BenchmarkConfig + + +def load_config(name: str) -> "BenchmarkConfig": + """ + Load a benchmark configuration by name. + + Configurations are Python modules in the configs directory. + Each module should define a CONFIG variable of type BenchmarkConfig. + + Args: + name: Name of the config (without .py extension) + + Returns: + BenchmarkConfig instance + + Raises: + ValueError: If config not found or doesn't define CONFIG + + Example: + config = load_config("mlperf") + print(config.name) # "mlperf" + """ + try: + module = importlib.import_module(f".{name}", package=__package__) + except ModuleNotFoundError: + raise ValueError(f"Config '{name}' not found. " f"Create a file at configs/{name}.py with a CONFIG variable.") + + if not hasattr(module, "CONFIG"): + raise ValueError(f"Config module '{name}' must define a CONFIG variable of type BenchmarkConfig") + + return module.CONFIG + + +def list_configs() -> list: + """ + List available config names. 
+ + Returns: + List of config names (without .py extension) + """ + import os + from pathlib import Path + + configs_dir = Path(__file__).parent + configs = [] + + for f in configs_dir.iterdir(): + if f.suffix == ".py" and f.stem != "__init__": + configs.append(f.stem) + + return sorted(configs) diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/configs/dsv3.py b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/configs/dsv3.py new file mode 100644 index 00000000..404842c7 --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/configs/dsv3.py @@ -0,0 +1,41 @@ +""" +DeepSeek V3 SDPA Benchmark Configuration + +Benchmarks DeepSeek V3-style MHA with asymmetric head dimensions. +Only causal (top_left) mask - no non-causal benchmarks needed. +Includes forward and backward pass benchmarking with deterministic mode options. + +Usage: + python -m benchmark.sdpa_benchmark_training.runner --config dsv3 + python -m benchmark.sdpa_benchmark_training.runner --config dsv3 --dry-run +""" + +from ..config_types import ModelPreset, BenchmarkConfig + +DSV3 = ModelPreset( + name="dsv3", + num_q_heads=128, + num_kv_heads=128, + head_dim_qk=192, + head_dim_vo=128, +) + +CONFIG = BenchmarkConfig( + name="dsv3", + models=[DSV3], + seqlens=[ + (32768, 32768), + (16384, 16384), + (8192, 8192), + (4096, 4096), + (2048, 2048), + ], + backends=["cudnn", "flash_attention_4"], + data_types=["bfloat16", "fp8"], + attn_masks=["top_left"], # Causal only + profile_pass="both", # Forward and backward + deterministic_bwd=[True], + batch_size=1, + num_iterations=10, + output_dir="results", +) diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/configs/llama.py b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/configs/llama.py new file mode 100644 index 00000000..803db837 --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/configs/llama.py @@ -0,0 +1,39 @@ +""" +Llama 3.1 SDPA 
Benchmark Configuration + +Benchmarks Llama 3.1 405B-style GQA attention with both causal and non-causal masks. +Includes forward and backward pass benchmarking with deterministic mode options. + +Usage: + python -m benchmark.sdpa_benchmark_training.runner --config llama + python -m benchmark.sdpa_benchmark_training.runner --config llama --dry-run +""" + +from ..config_types import ModelPreset, BenchmarkConfig + +LLAMA3_1 = ModelPreset( + name="llama3.1", + num_q_heads=64, + num_kv_heads=8, + head_dim=128, +) + +CONFIG = BenchmarkConfig( + name="llama3.1", + models=[LLAMA3_1], + seqlens=[ + (32768, 32768), + (16384, 16384), + (8192, 8192), + (4096, 4096), + (2048, 2048), + ], + backends=["cudnn", "flash_attention_4"], + data_types=["bfloat16", "fp8"], + attn_masks=["top_left", "no_mask"], # Both causal and non-causal + profile_pass="both", # Forward and backward + deterministic_bwd=[False], + batch_size=1, + num_iterations=10, + output_dir="results", +) diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_20260126_110621.csv b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_20260126_110621.csv new file mode 100644 index 00000000..cabd5912 --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_20260126_110621.csv @@ -0,0 +1,41 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,fwd_time_ms,bwd_time_ms,fwd_tflops,bwd_tflops,max_diff,num_iterations,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +dsv3,dsv3,cudnn,bfloat16,top_left,1,32768,32768,128,128,192,128,both,True,24.538,87.230,1792.000,1311.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,32768,32768,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return 
code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,16384,16384,128,128,192,128,both,True,6.476,22.025,1698.000,1298.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,16384,16384,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,8192,8192,128,128,192,128,both,True,1.831,5.875,1501.000,1217.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,8192,8192,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. 
+stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,4096,4096,128,128,192,128,both,True,0.519,1.650,1324.000,1083.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,4096,4096,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,2048,2048,128,128,192,128,both,True,0.163,0.520,1053.000,859.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,2048,2048,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. 
+stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_top_left_causal.png b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_top_left_causal.png new file mode 100644 index 00000000..a79e4809 Binary files /dev/null and b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/dsv3_top_left_causal.png differ diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_20260126_110503.csv b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_20260126_110503.csv new file mode 100644 index 00000000..11f9fa9f --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_20260126_110503.csv @@ -0,0 +1,21 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,fwd_time_ms,bwd_time_ms,fwd_tflops,bwd_tflops,max_diff,num_iterations,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,32768,32768,64,8,128,128,both,False,10.436,30.513,1686.000,1441.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,32768,32768,64,8,128,128,both,False,20.041,59.879,1756.000,1469.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 
+llama3.1,llama3.1,cudnn,fp8,top_left,1,32768,32768,64,8,128,128,both,False,8.317,25.675,2115.000,1713.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,32768,32768,64,8,128,128,both,False,16.521,49.482,2130.000,1778.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,16384,16384,64,8,128,128,both,False,2.672,8.018,1646.000,1371.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,16384,16384,64,8,128,128,both,False,5.037,15.384,1746.000,1429.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,16384,16384,64,8,128,128,both,False,2.182,6.730,2016.000,1634.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,16384,16384,64,8,128,128,both,False,4.240,12.707,2075.000,1731.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,8192,8192,64,8,128,128,both,False,0.704,2.150,1563.000,1279.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,8192,8192,64,8,128,128,both,False,1.313,3.980,1675.000,1381.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,8192,8192,64,8,128,128,both,False,0.591,1.851,1862.000,1485.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,8192,8192,64,8,128,128,both,False,1.133,3.385,1941.000,1624.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,4096,4096,64,8,128,128,both,False,0.212,0.622,1297.000,1105.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,4096,4096,64,8,128,128,both,False,0.350,1.090,1569.000,1261.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,4096,4096,64,8,128,128,both,False,0.172,0.555,1602.000,1239.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 
+llama3.1,llama3.1,cudnn,fp8,no_mask,1,4096,4096,64,8,128,128,both,False,0.299,0.941,1841.000,1461.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,2048,2048,64,8,128,128,both,False,0.067,0.209,1022.000,824.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,2048,2048,64,8,128,128,both,False,0.112,0.321,1232.000,1070.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,2048,2048,64,8,128,128,both,False,0.057,0.190,1215.000,905.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,2048,2048,64,8,128,128,both,False,0.090,0.284,1521.000,1210.000,0.000,10,True,,NVIDIA GB200,1.18.0,91801 diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_no_mask.png b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_no_mask.png new file mode 100644 index 00000000..ac4bced3 Binary files /dev/null and b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_no_mask.png differ diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_top_left_causal.png b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_top_left_causal.png new file mode 100644 index 00000000..f5f3a306 Binary files /dev/null and b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb200_918_only_cudnn/llama3.1_top_left_causal.png differ diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_20260126_110622.csv b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_20260126_110622.csv new file mode 100644 index 00000000..9dac38b4 --- /dev/null +++ 
b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_20260126_110622.csv @@ -0,0 +1,41 @@ +config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,fwd_time_ms,bwd_time_ms,fwd_tflops,bwd_tflops,max_diff,num_iterations,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +dsv3,dsv3,cudnn,bfloat16,top_left,1,32768,32768,128,128,192,128,both,True,21.319,80.520,2063.000,1420.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,32768,32768,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,16384,16384,128,128,192,128,both,True,5.584,20.381,1969.000,1403.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,16384,16384,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. 
+stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,8192,8192,128,128,192,128,both,True,1.518,5.412,1811.000,1321.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,8192,8192,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,4096,4096,128,128,192,128,both,True,0.438,1.541,1570.000,1160.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,4096,4096,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. 
+stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, +dsv3,dsv3,cudnn,bfloat16,top_left,1,2048,2048,128,128,192,128,both,True,0.148,0.493,1158.000,906.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801.000 +dsv3,dsv3,cudnn,fp8,top_left,1,2048,2048,128,128,192,128,both,True,inf,inf,0.000,0.000,0.000,10,False,"Benchmark failed with return code 1. +stderr: Traceback (most recent call last): + File ""/workspace/cudnn_frontend/benchmark/sdpa_benchmark_training/benchmark_single_sdpa.py"", line 560, in + graph_fwd.validate() +cudnn._compiled_module.cudnnGraphNotSupportedError: hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+) + +stdout: ",,, diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_top_left_causal.png b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_top_left_causal.png new file mode 100644 index 00000000..39cdbe44 Binary files /dev/null and b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/dsv3_top_left_causal.png differ diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_20260126_110426.csv b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_20260126_110426.csv new file mode 100644 index 00000000..433b02ce --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_20260126_110426.csv @@ -0,0 +1,21 @@ 
+config_name,model_name,backend,data_type,attn_mask,batch_size,q_seqlen,kv_seqlen,num_q_heads,num_kv_heads,head_dim_qk,head_dim_vo,profile_pass,deterministic_bwd,fwd_time_ms,bwd_time_ms,fwd_tflops,bwd_tflops,max_diff,num_iterations,success,error_message,gpu_name,cudnn_version,cudnn_backend_version +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,32768,32768,64,8,128,128,both,False,8.663,28.331,2031.000,1552.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,32768,32768,64,8,128,128,both,False,17.400,56.680,2022.000,1552.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,32768,32768,64,8,128,128,both,False,5.942,23.707,2961.000,1855.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,32768,32768,64,8,128,128,both,False,11.782,45.618,2986.000,1928.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,16384,16384,64,8,128,128,both,False,2.202,7.361,1998.000,1494.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,16384,16384,64,8,128,128,both,False,4.396,14.124,2001.000,1557.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,16384,16384,64,8,128,128,both,False,1.577,6.233,2789.000,1764.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,16384,16384,64,8,128,128,both,False,3.025,11.772,2907.000,1868.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,8192,8192,64,8,128,128,both,False,0.571,1.976,1927.000,1391.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,8192,8192,64,8,128,128,both,False,1.118,3.670,1967.000,1498.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,8192,8192,64,8,128,128,both,False,0.434,1.728,2534.000,1591.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 
+llama3.1,llama3.1,cudnn,fp8,no_mask,1,8192,8192,64,8,128,128,both,False,0.807,3.154,2724.000,1743.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,4096,4096,64,8,128,128,both,False,0.164,0.574,1679.000,1198.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,4096,4096,64,8,128,128,both,False,0.289,1.016,1901.000,1352.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,4096,4096,64,8,128,128,both,False,0.129,0.527,2136.000,1305.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,4096,4096,64,8,128,128,both,False,0.213,0.884,2580.000,1554.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,top_left,1,2048,2048,64,8,128,128,both,False,0.054,0.191,1265.000,900.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,bfloat16,no_mask,1,2048,2048,64,8,128,128,both,False,0.088,0.299,1559.000,1151.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,top_left,1,2048,2048,64,8,128,128,both,False,0.044,0.181,1574.000,947.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 +llama3.1,llama3.1,cudnn,fp8,no_mask,1,2048,2048,64,8,128,128,both,False,0.066,0.275,2086.000,1251.000,0.000,10,True,,NVIDIA GB300,1.18.0,91801 diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_no_mask.png b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_no_mask.png new file mode 100644 index 00000000..312bab6f Binary files /dev/null and b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_no_mask.png differ diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_top_left_causal.png 
b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_top_left_causal.png new file mode 100644 index 00000000..dc6fd4d4 Binary files /dev/null and b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/results/gb300_918_only_cudnn/llama3.1_top_left_causal.png differ diff --git a/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/runner.py b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/runner.py new file mode 100644 index 00000000..e1201eae --- /dev/null +++ b/third_party/cudnn-frontend/benchmark/sdpa_benchmark_training/runner.py @@ -0,0 +1,505 @@ +""" +Benchmark runner with configuration expansion, execution, and result collection. + +This module provides the BenchmarkRunner class for running SDPA benchmarks +from configuration files, and a CLI entry point. + +Usage: + # Run from command line + python -m benchmark.sdpa_benchmark_training.runner --config mlperf + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --dry-run + + # Import and use programmatically + from benchmark.sdpa_benchmark_training.runner import BenchmarkRunner + from benchmark.sdpa_benchmark_training.configs import load_config + + config = load_config("mlperf") + runner = BenchmarkRunner() + results = runner.run_config(config) + runner.save_csv(results, config) +""" + +import itertools +import logging +import sys +from dataclasses import asdict +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Iterator, List, Optional + +from .config_types import BenchmarkConfig, BenchmarkResult, ModelPreset + +logger = logging.getLogger(__name__) + + +def log_environment_info(): + """Log environment information (torch, CUDA, cuDNN, flash_attn versions).""" + try: + import torch + + logger.info(f"torch.__version__ = '{torch.__version__}'") + logger.info(f"torch.version.cuda = '{torch.version.cuda}'") + logger.info(f"torch.cuda.is_available() = {torch.cuda.is_available()}") + if 
torch.cuda.is_available(): + logger.info(f"torch.cuda.device_count() = {torch.cuda.device_count()}") + logger.info(f"torch.cuda.current_device() = {torch.cuda.current_device()}") + logger.info(f"torch.cuda.get_device_name(torch.cuda.current_device()) = '{torch.cuda.get_device_name(torch.cuda.current_device())}'") + logger.info(f"torch.backends.cudnn.enabled = {torch.backends.cudnn.enabled}") + except ImportError: + logger.warning("torch not available") + + try: + import cudnn + + logger.info(f"cuDNN Backend Version: cudnn.backend_version() = {cudnn.backend_version()}") + logger.info(f"cuDNN Frontend Version: cudnn.__version__ = '{cudnn.__version__}'") + except ImportError: + logger.warning("cudnn not available") + + try: + import flash_attn + + logger.info(f"flash_attn.__version__ = '{flash_attn.__version__}'") + except ImportError: + pass # flash_attn is optional + + +class BenchmarkRunner: + """ + Runs benchmarks from configurations with cartesian product expansion. + + The runner takes a BenchmarkConfig and expands it into individual benchmark + cases via cartesian product of all configuration dimensions. Each case is + then executed and results are collected. + + Attributes: + verbose: Whether to print progress information + + Example: + runner = BenchmarkRunner(verbose=True) + config = load_config("mlperf") + + # Dry run to see what would be executed + for case in runner.expand_config(config): + print(case) + + # Actually run the benchmarks + results = runner.run_config(config) + runner.save_csv(results, config) + """ + + def __init__(self, verbose: bool = True): + """ + Initialize the runner. 
+ + Args: + verbose: Whether to print progress information + """ + self.verbose = verbose + self._setup_logging() + + def _setup_logging(self): + """Configure logging based on verbosity setting.""" + level = logging.INFO if self.verbose else logging.WARNING + logging.basicConfig( + level=level, + format="[%(levelname)s] %(message)s", + stream=sys.stderr, + ) + + def expand_config(self, config: BenchmarkConfig) -> Iterator[Dict[str, Any]]: + """ + Expand a BenchmarkConfig into individual benchmark cases. + + Performs cartesian product expansion over: + models x seqlens x backends x data_types x attn_masks x deterministic_bwd + + Args: + config: BenchmarkConfig to expand + + Yields: + Dict containing all parameters for a single benchmark run + """ + for model, (q_seqlen, kv_seqlen), backend, data_type, attn_mask, det_bwd in itertools.product( + config.models, + config.seqlens, + config.backends, + config.data_types, + config.attn_masks, + config.deterministic_bwd, + ): + # Skip deterministic mode for forward-only runs + if det_bwd and config.profile_pass == "fwd": + continue + + yield { + "config_name": config.name, + "model": model, + "q_seqlen": q_seqlen, + "kv_seqlen": kv_seqlen, + "backend": backend, + "data_type": data_type, + "attn_mask": attn_mask, + "profile_pass": config.profile_pass, + "batch_size": config.batch_size, + "num_iterations": config.num_iterations, + "num_warmup_iterations": config.num_warmup_iterations, + "skip_ref": config.skip_ref, + "deterministic_bwd": det_bwd, + } + + def run_single(self, case: Dict[str, Any]) -> BenchmarkResult: + """ + Run a single benchmark case. + + Calls the run_benchmark() function from benchmark_single_sdpa.py + and wraps the result in a BenchmarkResult. 
+ + Args: + case: Dict containing benchmark parameters (from expand_config) + + Returns: + BenchmarkResult with timing data or error information + """ + model: ModelPreset = case["model"] + + try: + # Import here to avoid circular imports and allow the module to be + # used even if torch/cudnn aren't installed (for dry-run mode) + from .benchmark_single_sdpa import run_benchmark + + result = run_benchmark( + batch_size=case["batch_size"], + q_seqlen=case["q_seqlen"], + kv_seqlen=case["kv_seqlen"], + num_q_heads=model.num_q_heads, + num_kv_heads=model.num_kv_heads, + head_dim_qk=model.head_dim_qk, + head_dim_vo=model.head_dim_vo, + data_type=case["data_type"], + backend=case["backend"], + attn_mask=case["attn_mask"], + profile_pass=case["profile_pass"], + num_iterations=case["num_iterations"], + num_warmup_iterations=case["num_warmup_iterations"], + skip_ref=case["skip_ref"], + deterministic_bwd=case["deterministic_bwd"], + ) + + return BenchmarkResult( + config_name=case["config_name"], + model_name=model.name, + backend=case["backend"], + data_type=case["data_type"], + attn_mask=case["attn_mask"], + batch_size=case["batch_size"], + q_seqlen=case["q_seqlen"], + kv_seqlen=case["kv_seqlen"], + num_q_heads=model.num_q_heads, + num_kv_heads=model.num_kv_heads, + head_dim_qk=model.head_dim_qk, + head_dim_vo=model.head_dim_vo, + profile_pass=case["profile_pass"], + deterministic_bwd=case["deterministic_bwd"], + fwd_time_ms=result["fwd_time_ms"], + bwd_time_ms=result["bwd_time_ms"], + fwd_tflops=result["fwd_tflops"], + bwd_tflops=result["bwd_tflops"], + max_diff=result["max_diff"], + num_iterations=case["num_iterations"], + success=True, + gpu_name=result.get("gpu_name"), + cudnn_version=result.get("cudnn_version"), + cudnn_backend_version=result.get("cudnn_backend_version"), + ) + + except Exception as e: + logger.error(f"Benchmark failed: {e}") + return BenchmarkResult( + config_name=case["config_name"], + model_name=model.name, + backend=case["backend"], + 
data_type=case["data_type"], + attn_mask=case["attn_mask"], + batch_size=case["batch_size"], + q_seqlen=case["q_seqlen"], + kv_seqlen=case["kv_seqlen"], + num_q_heads=model.num_q_heads, + num_kv_heads=model.num_kv_heads, + head_dim_qk=model.head_dim_qk, + head_dim_vo=model.head_dim_vo, + profile_pass=case["profile_pass"], + deterministic_bwd=case["deterministic_bwd"], + fwd_time_ms=float("inf"), + bwd_time_ms=float("inf"), + fwd_tflops=0.0, + bwd_tflops=0.0, + max_diff=0.0, + num_iterations=case["num_iterations"], + success=False, + error_message=str(e), + ) + + def run_config( + self, + config: BenchmarkConfig, + filter_model: Optional[str] = None, + filter_backend: Optional[str] = None, + filter_dtype: Optional[str] = None, + ) -> List[BenchmarkResult]: + """ + Run all benchmarks from a configuration. + + Args: + config: BenchmarkConfig to run + filter_model: Optional model name filter (substring match) + filter_backend: Optional backend filter (exact match) + filter_dtype: Optional data type filter (exact match) + + Returns: + List of BenchmarkResult for all executed cases + """ + # Log environment info at the start + log_environment_info() + logger.info("") # Blank line for readability + + results = [] + cases = list(self.expand_config(config)) + + # Apply filters + if filter_model: + cases = [c for c in cases if filter_model in c["model"].name] + if filter_backend: + cases = [c for c in cases if c["backend"] == filter_backend] + if filter_dtype: + cases = [c for c in cases if c["data_type"] == filter_dtype] + + if not cases: + logger.warning("No benchmark cases to run after applying filters") + return results + + logger.info(f"Running {len(cases)} benchmark cases from config '{config.name}'") + + for i, case in enumerate(cases, 1): + model = case["model"] + det_str = "det" if case["deterministic_bwd"] else "non-det" + logger.info( + f"[{i}/{len(cases)}] {model.name} | " + f"seq={case['q_seqlen']}x{case['kv_seqlen']} | " + f"{case['backend']} | 
{case['data_type']} | " + f"{case['attn_mask']} | {det_str}" + ) + + result = self.run_single(case) + results.append(result) + + if result.success: + fwd_info = f"fwd: {result.fwd_time_ms:.3f}ms ({result.fwd_tflops:.0f} TFLOPS)" + bwd_info = f"bwd: {result.bwd_time_ms:.3f}ms ({result.bwd_tflops:.0f} TFLOPS)" + logger.info(f" -> {fwd_info}, {bwd_info}") + else: + logger.warning(f" -> FAILED: {result.error_message}") + + return results + + def results_to_dataframe(self, results: List[BenchmarkResult]): + """ + Convert results to a pandas DataFrame. + + Args: + results: List of BenchmarkResult + + Returns: + pandas DataFrame with all result fields as columns + """ + import pandas as pd + + return pd.DataFrame([asdict(r) for r in results]) + + def save_csv( + self, + results: List[BenchmarkResult], + config: BenchmarkConfig, + output_path: Optional[Path] = None, + ) -> Path: + """ + Save results to a CSV file. + + Args: + results: List of BenchmarkResult + config: BenchmarkConfig (used for default filename) + output_path: Optional explicit output path + + Returns: + Path to the saved CSV file + """ + import pandas as pd + + df = self.results_to_dataframe(results) + + if output_path is None: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = output_dir / f"{config.name}_{timestamp}.csv" + + df.to_csv(output_path, index=False, float_format="%.3f") + logger.info(f"Results saved to {output_path}") + + return output_path + + +def main(): + """CLI entry point.""" + import argparse + + parser = argparse.ArgumentParser( + description="Run SDPA benchmarks from configuration files", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Run all benchmarks from mlperf config + python -m benchmark.sdpa_benchmark_training.runner --config mlperf + + # Dry run (show what would be executed) + python -m benchmark.sdpa_benchmark_training.runner --config 
mlperf --dry-run + + # Filter by model name + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --filter llama3.1 + + # Filter by backend + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --backend cudnn + + # Skip chart generation + python -m benchmark.sdpa_benchmark_training.runner --config mlperf --no-chart + """, + ) + + parser.add_argument( + "--config", + required=True, + help="Config name (e.g., 'mlperf'). Must be a Python file in configs/", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print benchmark cases without executing", + ) + parser.add_argument( + "--filter", + dest="filter_model", + help="Filter by model name (substring match)", + ) + parser.add_argument( + "--backend", + dest="filter_backend", + help="Filter by backend (exact match)", + ) + parser.add_argument( + "--dtype", + dest="filter_dtype", + help="Filter by data type (exact match)", + ) + parser.add_argument( + "--output", + type=Path, + help="Output path for CSV (default: artifacts/_.csv)", + ) + parser.add_argument( + "--no-chart", + action="store_true", + help="Skip chart generation", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List available configurations and exit", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Reduce output verbosity", + ) + + args = parser.parse_args() + + # Handle --list-configs + if args.list_configs: + from .configs import list_configs + + configs = list_configs() + print("Available configurations:") + for name in configs: + print(f" {name}") + return + + # Load config + from .configs import load_config + + try: + config = load_config(args.config) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + runner = BenchmarkRunner(verbose=not args.quiet) + + # Dry run mode + if args.dry_run: + cases = list(runner.expand_config(config)) + + # Apply filters for display + if args.filter_model: + cases = [c for c in cases if 
args.filter_model in c["model"].name] + if args.filter_backend: + cases = [c for c in cases if c["backend"] == args.filter_backend] + if args.filter_dtype: + cases = [c for c in cases if c["data_type"] == args.filter_dtype] + + print(f"Would run {len(cases)} benchmark cases from config '{config.name}':") + print() + for i, case in enumerate(cases, 1): + model = case["model"] + det_str = "det" if case["deterministic_bwd"] else "non-det" + print( + f" [{i}] {model.name} | " + f"seq={case['q_seqlen']}x{case['kv_seqlen']} | " + f"{case['backend']} | {case['data_type']} | " + f"{case['attn_mask']} | {det_str}" + ) + return + + # Run benchmarks + results = runner.run_config( + config, + filter_model=args.filter_model, + filter_backend=args.filter_backend, + filter_dtype=args.filter_dtype, + ) + + if not results: + print("No results to save", file=sys.stderr) + sys.exit(1) + + # Save CSV + csv_path = runner.save_csv(results, config, args.output) + + # Generate charts (separate chart per mask type for clarity) + if not args.no_chart: + try: + from .charts import generate_charts_by_mask + + df = runner.results_to_dataframe(results) + chart_paths = generate_charts_by_mask(df, config) + for path in chart_paths: + print(f"Chart saved to {path}") + except ImportError as e: + logger.warning(f"Could not generate chart (missing dependency): {e}") + except Exception as e: + logger.warning(f"Could not generate chart: {e}") + + print(f"Results saved to {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/third_party/cudnn-frontend/cmake/cuDNN.cmake b/third_party/cudnn-frontend/cmake/cuDNN.cmake new file mode 100644 index 00000000..0ab86363 --- /dev/null +++ b/third_party/cudnn-frontend/cmake/cuDNN.cmake @@ -0,0 +1,115 @@ +add_library(CUDNN::cudnn_all INTERFACE IMPORTED) + +find_path( + CUDNN_INCLUDE_DIR cudnn.h + HINTS $ENV{CUDNN_INCLUDE_PATH} ${CUDNN_INCLUDE_PATH} $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS} + PATH_SUFFIXES 
include + REQUIRED +) + +file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header) +string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}") +string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}") + +function(find_cudnn_library NAME) + if(NOT "${ARGV1}" STREQUAL "OPTIONAL") + set(_cudnn_required "REQUIRED") + else() + set(_cudnn_required "") + endif() + + find_library( + ${NAME}_LIBRARY + NAMES ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}" + NAMES_PER_DIR + HINTS $ENV{CUDNN_LIBRARY_PATH} ${CUDNN_LIBRARY_PATH} $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR} + PATH_SUFFIXES lib64 lib/x64 lib + ${_cudnn_required} + ) + + if(${NAME}_LIBRARY) + add_library(CUDNN::${NAME} UNKNOWN IMPORTED) + set_target_properties( + CUDNN::${NAME} PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR} + IMPORTED_LOCATION ${${NAME}_LIBRARY} + ) + message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.") + else() + message(STATUS "${NAME} not found.") + endif() +endfunction() + +find_cudnn_library(cudnn) + +include (FindPackageHandleStandardArgs) +find_package_handle_standard_args( + LIBRARY REQUIRED_VARS + CUDNN_INCLUDE_DIR cudnn_LIBRARY +) + +if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY) + + message(STATUS "cuDNN: ${cudnn_LIBRARY}") + message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}") + + set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found") + +else() + + set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found") + +endif() + +target_include_directories( + CUDNN::cudnn_all + INTERFACE + $ + $ +) + +target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn +) + +if(CUDNN_MAJOR_VERSION EQUAL 8) + find_cudnn_library(cudnn_adv_infer) + find_cudnn_library(cudnn_adv_train) + find_cudnn_library(cudnn_cnn_infer) + find_cudnn_library(cudnn_cnn_train) + find_cudnn_library(cudnn_ops_infer) + find_cudnn_library(cudnn_ops_train) + + target_link_libraries( + CUDNN::cudnn_all + INTERFACE + 
CUDNN::cudnn_adv_train + CUDNN::cudnn_ops_train + CUDNN::cudnn_cnn_train + CUDNN::cudnn_adv_infer + CUDNN::cudnn_cnn_infer + CUDNN::cudnn_ops_infer + ) +elseif(CUDNN_MAJOR_VERSION EQUAL 9) + find_cudnn_library(cudnn_graph) + find_cudnn_library(cudnn_engines_runtime_compiled) + find_cudnn_library(cudnn_ops OPTIONAL) + find_cudnn_library(cudnn_cnn OPTIONAL) + find_cudnn_library(cudnn_adv OPTIONAL) + find_cudnn_library(cudnn_engines_precompiled OPTIONAL) + find_cudnn_library(cudnn_heuristic OPTIONAL) + + target_link_libraries( + CUDNN::cudnn_all + INTERFACE + CUDNN::cudnn_graph + CUDNN::cudnn_engines_runtime_compiled + CUDNN::cudnn_ops + CUDNN::cudnn_cnn + CUDNN::cudnn_adv + CUDNN::cudnn_engines_precompiled + CUDNN::cudnn_heuristic + ) +endif() diff --git a/third_party/cudnn-frontend/cudnn_frontend-config.cmake.in b/third_party/cudnn-frontend/cudnn_frontend-config.cmake.in new file mode 100644 index 00000000..8b2d8430 --- /dev/null +++ b/third_party/cudnn-frontend/cudnn_frontend-config.cmake.in @@ -0,0 +1,3 @@ +@PACKAGE_INIT@ + +include(${CMAKE_CURRENT_LIST_DIR}/cudnn_frontend-targets.cmake) diff --git a/third_party/cudnn-frontend/dlpack_version.txt b/third_party/cudnn-frontend/dlpack_version.txt new file mode 100644 index 00000000..9459d4ba --- /dev/null +++ b/third_party/cudnn-frontend/dlpack_version.txt @@ -0,0 +1 @@ +1.1 diff --git a/third_party/cudnn-frontend/include/cudnn_backend_base.h b/third_party/cudnn-frontend/include/cudnn_backend_base.h new file mode 100644 index 00000000..bae2de8a --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_backend_base.h @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#pragma once + +#include + +namespace cudnn_frontend { + +/// +/// OpaqueBackendPointer class +/// Holds the raws pointer to backend_descriptor +/// Usage is to wrap this into a smart pointer as +/// it helps to create and destroy the backendpointer + +class OpaqueBackendPointer { + cudnnBackendDescriptor_t m_desc = nullptr; //!< Raw void pointer + cudnnStatus_t status = CUDNN_STATUS_SUCCESS; //!< status of creation of the Descriptor + + public: + OpaqueBackendPointer(const OpaqueBackendPointer&) = delete; //!< Delete the copy constructor to prevent bad copies + OpaqueBackendPointer& + operator=(const OpaqueBackendPointer&) = delete; + OpaqueBackendPointer(OpaqueBackendPointer&&) = default; + + /** + * OpaqueBackendPointer constructor. + * Calls the cudnnBackendCreateDescriptor. Allocates memory according to the type. 
+ */ + OpaqueBackendPointer(cudnnBackendDescriptorType_t type) { status = detail::create_descriptor(type, &m_desc); } + /** + * OpaqueBackendPointer destructor. + * Calls the cudnnBackendDestroyDescriptor. Frees memory allocated in the constructor. + */ + ~OpaqueBackendPointer() { detail::destroy_descriptor(m_desc); }; + /** + * Accessor. + * Returns the const reference to raw underlying descriptor. + * Treat it like the data() function of a smart pointer. Can be freed behind the back. + */ + cudnnBackendDescriptor_t const& + get_backend_descriptor() const { + return m_desc; + } + /** + * Accessor. + * Queries the status of the descriptor after calling the cudnnCreate. + */ + cudnnStatus_t + get_status() const { + return status; + } + /** + * Accessor. + * Queries the status of the descriptor returns true if all good. + */ + bool + is_good() const { + return status == CUDNN_STATUS_SUCCESS; + } +}; + +/*! \var A shared_ptr wrapper on top of the OpaqueBackendPointer */ +using ManagedOpaqueDescriptor = std::shared_ptr; + +/*! \fn A wrapper on top of the std::make_shared for the OpaqueBackendPointer */ +static ManagedOpaqueDescriptor +make_shared_backend_pointer(cudnnBackendDescriptorType_t type) { + return std::make_shared(type); +} + +/// +/// BackendDescriptor class +/// Holds a Managed pointer to OpaqueBackendPointer class +/// Contains the status and error message if set after any operation. +/// If exception is disabled the user must query the status after +/// build operation in order to check if the cudnn construct was built +/// correctly. +class BackendDescriptor { + public: + //! Return a string describing the backend Descriptor + virtual std::string + describe() const = 0; + + //! Get a copy of the raw descriptor pointer. Ownership is reatined and + //! gets deleted when out of scope + cudnnBackendDescriptor_t + get_raw_desc() const { + return pointer->get_backend_descriptor(); + } + + //! 
Current status of the descriptor + cudnnStatus_t + get_status() const { + return status; + } + + //! Set status of the descriptor + void + set_status(cudnnStatus_t const status_) const { + status = status_; + } + + //! Set Diagonistic error message. + void + set_error(const char* message) const { + err_msg = message; + } + + //! Diagonistic error message if any + const char* + get_error() const { + return err_msg.c_str(); + } + + //! Returns a copy of underlying managed descriptor + ManagedOpaqueDescriptor + get_desc() const { + return pointer; + } + + //! Initializes the underlying managed descriptor + cudnnStatus_t + initialize_managed_backend_pointer(cudnnBackendDescriptorType_t type) { + pointer = make_shared_backend_pointer(type); + return pointer->get_status(); + } + + protected: + /** + * BackendDescriptor constructor. + * Initializes the member variables as passed. + */ + BackendDescriptor(ManagedOpaqueDescriptor pointer_, cudnnStatus_t status_, std::string err_msg_) + : pointer(pointer_), status(status_), err_msg(err_msg_) {} + BackendDescriptor() = default; + + virtual ~BackendDescriptor() {}; + + ManagedOpaqueDescriptor pointer; //! Shared pointer of the OpaqueBackendPointer + + mutable cudnnStatus_t status = CUDNN_STATUS_SUCCESS; //!< Error code if any being set + mutable std::string err_msg; //!< Error message if any being set +}; + +} // namespace cudnn_frontend diff --git a/third_party/cudnn-frontend/include/cudnn_frontend.h b/third_party/cudnn-frontend/include/cudnn_frontend.h new file mode 100644 index 00000000..fe1a3500 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend.h @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#pragma once + +// Suppress MSVC warning C4756 (overflow in constant arithmetic) that occurs +// in MSVC's header with certain compiler versions +#ifdef _MSC_VER +#pragma warning(disable : 4756) +#endif + +/*! \mainpage CUDNN FRONTEND API + * + * \section Introduction + * + * The cuDNN Frontend API is a C++ header-only library that demonstrates how to use the cuDNN C backend API. The cuDNN C + * backend API is documented in the cuDNN developer guide. + * + * \section Why use Frontend API + * + * Consider the following code snippet which showcases cudnnBackendTensor creation using the backend API and its + * equivalent front-end API code. Many among the backend constructs follow similar pattern. 
+ * + * ~~~~~~~~~~~~~~~{.cpp} + * + * =========================================================================================== + * auto check_status = [](cudnnStatus_t status) { assert (status == CUDNN_STATUS_SUCCESS); }; + * =========================================================================================== + * // Backend code for Tensor Creation. + * cudnnBackendDescriptor_t tensor; + * + * check_status (cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &tensor)); + * + * check_status (cudnnBackendSetAttribute(tensor, + * CUDNN_ATTR_TENSOR_DATA_TYPE, + * CUDNN_TYPE_DATA_TYPE, + * 1, + * &data_type)); + * check_status (cudnnBackendSetAttribute(tensor, + * CUDNN_ATTR_TENSOR_DIMENSIONS, + * CUDNN_TYPE_INT64, + * tensor_dim.size(), + * tensor_dim.data())); + * check_status (cudnnBackendSetAttribute(tensor, + * CUDNN_ATTR_TENSOR_STRIDES, + * CUDNN_TYPE_INT64, + * tensor_str.size(), + * tensor_str.data())); + * check_status (cudnnBackendSetAttribute(tensor, + * CUDNN_ATTR_TENSOR_UNIQUE_ID, + * CUDNN_TYPE_INT64, + * 1, + * &id)); + * check_status (cudnnBackendSetAttribute(tensor, + * CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, + * CUDNN_TYPE_INT64, + * 1, + * &alignment)); + * check_status (cudnnBackendFinalize(tensor)); + * + * check_status (cudnnBackendDestroyDescriptor(tensor)); + * =========================================================================================== + * // FrontEnd equivalent code. + * auto tensor = cudnn_frontend::TensorBuilder() + * .setDim(tensor_dim.size(), tensor_dim.data()) + * .setStrides(tensor_str.size(), tensor_str.data()) + * .setId(id) + * .setAlignment(alignment) + * .setDataType(data_type) + * .build(); + * check_status(tensor.get_status()); + * =========================================================================================== + * + * ~~~~~~~~~~~~~~~ + * + * Frontend API serves two major purpose as a companion to the backend API. + * - Functional additions: + * - Support for auto-tuning. 
(cudnnGet and cudnnFind) + * - Errata filters. + * - Programmatic ease: + * - Easy memory management for the cudnnBackendDescriptor_t (RAII based classes). + * - Error handling with optional exception support. Better error messages. + * - Fewer lines of code (5-10x reduction in LOC). + * - Simpler samples on how to use the new API. + */ + +#include + +#include "cudnn_frontend_ConvDesc.h" +#include "cudnn_frontend_Heuristics.h" +#include "cudnn_frontend_Engine.h" +#include "cudnn_frontend_EngineConfig.h" +#include "cudnn_frontend_EngineFallbackList.h" +#include "cudnn_frontend_Errata.h" +#include "cudnn_frontend_ExecutionPlan.h" +#include "cudnn_frontend_Filters.h" +#include "cudnn_frontend_Operation.h" +#include "cudnn_frontend_OperationGraph.h" +#include "cudnn_frontend_Tensor.h" +#include "cudnn_frontend_VariantPack.h" +#include "cudnn_frontend_PointWiseDesc.h" +#include "cudnn_frontend_MatMulDesc.h" +#include "cudnn_frontend_Logging.h" +#include "cudnn_frontend_Reorder_Tensor.h" +#include "cudnn_frontend_ExecutionPlanCache.h" +#include "cudnn_frontend_utils.h" + +#include "cudnn_frontend_Resample.h" + +#include "cudnn_frontend/graph_interface.h" +#include "cudnn_frontend/utils/serialize.h" +#include "cudnn_frontend/backend/kernel_cache.h" +#include "cudnn_frontend/utils/attn_score_modifiers.h" +#include "cudnn_frontend/backend/device_properties.h" + +#include "cudnn_frontend_version.h" + +namespace cudnn_frontend { +using ConvDesc = ConvDesc_v8; +using ConvDescBuilder = ConvDescBuilder_v8; +using ReductionDesc = ReductionDesc_v8; +using ReductionDescBuilder = ReductionDescBuilder_v8; +using EngineHeuristicsBuilder = EngineHeuristicsBuilder_v8; +using EngineHeuristics = EngineHeuristics_v8; +using EngineBuilder = EngineBuilder_v8; +using Engine = Engine_v8; +using EngineConfig = EngineConfig_v8; +using EngineConfigBuilder = EngineConfigBuilder_v8; +using EngineFallbackList = EngineFallbackList_v8; +using EngineFallbackListBuilder = EngineFallbackListBuilder_v8; 
+using ResampleDesc = ResampleDesc_v8; +using ResampleDescBuilder = ResampleDescBuilder_v8; +} // namespace cudnn_frontend diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/backend/backend_descriptor.h b/third_party/cudnn-frontend/include/cudnn_frontend/backend/backend_descriptor.h new file mode 100644 index 00000000..2cd68c0f --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/backend/backend_descriptor.h @@ -0,0 +1,138 @@ +#pragma once + +#include + +#include "../graph_helpers.h" +#include "cudnn.h" + +namespace cudnn_frontend::detail { + +/** + * @brief RAII wrapper around a `cudnnBackendDescriptor_t` object. + * + * This class provides a convenient way to manage the lifetime of a `cudnnBackendDescriptor_t` + * object using the RAII (Resource Acquisition Is Initialization) idiom. It automatically + * creates the descriptor when the object is constructed and destroys it when the object + * is destroyed, ensuring proper resource management and preventing memory leaks. + * + * @note The constructor of this class does not throw exceptions. Instead, it stores the + * status of the descriptor creation operation in the `status` member variable. Callers + * should check this status and handle any errors accordingly. + */ +class backend_descriptor { + public: + /** + * @brief Constructs a `backend_descriptor` object. + * + * @param type The type of the backend descriptor to create. + */ + backend_descriptor(cudnnBackendDescriptorType_t type) { status = detail::create_descriptor(type, &desc); } + + /** + * @brief Move constructor. + * + * Transfers the ownership of the `cudnnBackendDescriptor_t` object to the new + * `backend_descriptor` instance. + * + * @param other The source `backend_descriptor` object. + */ + backend_descriptor(backend_descriptor&& other) noexcept : desc(other.desc), status(other.status) { + other.desc = nullptr; + other.status = CUDNN_STATUS_NOT_INITIALIZED; + } + + /** + * @brief Move assignment operator. 
+ * + * Transfers the ownership of the `cudnnBackendDescriptor_t` object to the new + * `backend_descriptor` instance. + * + * @param other The source `backend_descriptor` object. + * @return A reference to the current `backend_descriptor` object. + */ + backend_descriptor& + operator=(backend_descriptor&& other) noexcept { + if (this != &other) { + desc = other.desc; + status = other.status; + + other.desc = nullptr; + } + return *this; + } + + /** + * @brief Destructor. + * + * Destroys the `cudnnBackendDescriptor_t` object and frees the associated resources. + */ + ~backend_descriptor() { + if (desc) { + detail::destroy_descriptor(desc); + } + } + + /** + * @brief Deleted copy constructor and assignment operator. + * + * `backend_descriptor` objects are not copyable to prevent unintended resource + * sharing and potential memory leaks. + */ + backend_descriptor(backend_descriptor const&) = delete; + backend_descriptor& + operator=(backend_descriptor const&) = delete; + + /** + * @brief Initializes a `backend_descriptor` object. + * + * @param type The type of the backend descriptor to create. + */ + error_t + initialize(cudnnBackendDescriptorType_t type) { + _CUDNN_CHECK_CUDNN_ERROR(detail::create_descriptor(type, &desc)); + return {error_code_t::OK, ""}; + } + + /** + * @brief Finalizes a `backend_descriptor` object. + * + */ + error_t + finalize() { + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(desc)); + return {error_code_t::OK, ""}; + } + + /** + * @brief Accessor for the underlying `cudnnBackendDescriptor_t` object. + * + * @return A const reference to `cudnnBackendDescriptor_t`, the raw pointer to the backend descriptor. + */ + cudnnBackendDescriptor_t const& + get_ptr() const { + return desc; + } + + /** + * @brief Accessor for the status of the backend descriptor creation. + * + * @return `cudnnStatus_t` The status of the backend descriptor creation operation. 
+ */ + cudnnStatus_t + get_status() const { + return status; + } + + /** + * @brief Constructs a default `backend_descriptor` object, but without initializing descriptor + * + * Used to return an error code to user for incorrect cuDNN version + */ + backend_descriptor() = default; + + private: + cudnnBackendDescriptor_t desc = nullptr; //!< Raw pointer to the backend descriptor. + cudnnStatus_t status = CUDNN_STATUS_NOT_INITIALIZED; //!< Status of the descriptor creation operation. +}; + +} // namespace cudnn_frontend::detail \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/backend/device_properties.h b/third_party/cudnn-frontend/include/cudnn_frontend/backend/device_properties.h new file mode 100644 index 00000000..d31c5550 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/backend/device_properties.h @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "../graph_helpers.h" +#include "backend_descriptor.h" + +namespace cudnn_frontend { +/// +/// DeviceProperties Class +/// Wraps the device_properties backend descriptor +/// Wraps backend utility functions for user's convenience +/// Backend accessor functions: size() +/// Contains internal utilities for device properties finalization and operation graph attributes +/// +class DeviceProperties : public detail::backend_descriptor { + public: + // Uses the default backend constructor so that we can check for initialization error during build() + DeviceProperties() = default; + + std::string + describe() const { + std::stringstream ss; + ss << "CUDNN_BACKEND_DEVICEPROP_DESCRIPTOR : " << std::endl; + return ss.str(); + } + + inline DeviceProperties& + set_device_id(int32_t device_id) { + this->device_id = device_id; + return *this; + } + + inline DeviceProperties& + set_handle(cudnnHandle_t handle) { + this->handle = handle; + return *this; + } + + // Used to check device properties status (particularly after initialization) + error_t + status() const { + if (get_status() != CUDNN_STATUS_SUCCESS) { + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_BACKEND_DEVICEPROP_DESCRIPTOR: Check CUDNN_VERSION >= 9.8"}; + } + return {}; + } + + error_t + serialize(std::vector& serialization_buf) const { +#if (CUDNN_VERSION >= 90800) + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90800, + error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_DEVICEPROP_JSON_REPRESENTATION is only available starting 9.8."); + + int64_t serializationSize; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute( + 
get_ptr(), CUDNN_ATTR_DEVICEPROP_JSON_REPRESENTATION, CUDNN_TYPE_CHAR, 0, &serializationSize, nullptr)); + serialization_buf.resize(static_cast(serializationSize)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute(get_ptr(), + CUDNN_ATTR_DEVICEPROP_JSON_REPRESENTATION, + CUDNN_TYPE_CHAR, + serializationSize, + &serializationSize, + serialization_buf.data())); + return {}; +#else + (void)serialization_buf; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_DEVICEPROP_JSON_REPRESENTATION is only available starting 9.8."}; +#endif + } + + error_t + deserialize(const std::vector& serialized_buf) { +#if (CUDNN_VERSION >= 90800) + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90800, + error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_DEVICEPROP_JSON_REPRESENTATION is only available starting 9.8."); + + // Check if the device properties is already initialized + RETURN_CUDNN_FRONTEND_ERROR_IF( + get_ptr() != nullptr, error_code_t::CUDNN_BACKEND_API_FAILED, "Device properties is already initialized."); + + // Initialize the device properties descriptor + CHECK_CUDNN_FRONTEND_ERROR(initialize(CUDNN_BACKEND_DEVICEPROP_DESCRIPTOR)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(get_ptr(), + CUDNN_ATTR_DEVICEPROP_JSON_REPRESENTATION, + CUDNN_TYPE_CHAR, + serialized_buf.size(), + serialized_buf.data())); + + CHECK_CUDNN_FRONTEND_ERROR(finalize()); + return {}; +#else + (void)serialized_buf; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_DEVICEPROP_JSON_REPRESENTATION is only available starting 9.8."}; +#endif + } + + // Check for both compile-time and runtime cuDNN version + error_t + build() { +#if (CUDNN_VERSION >= 90800) + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90800, + error_code_t::GRAPH_NOT_SUPPORTED, + "CUDNN_BACKEND_DEVICEPROP_DESCRIPTOR is only available starting 9.8."); + if (get_ptr() == nullptr) { + CHECK_CUDNN_FRONTEND_ERROR(initialize(CUDNN_BACKEND_DEVICEPROP_DESCRIPTOR)); + } + + if 
(handle != nullptr) { + _CUDNN_CHECK_CUDNN_ERROR( + detail::set_attribute(get_ptr(), CUDNN_ATTR_DEVICEPROP_HANDLE, CUDNN_TYPE_HANDLE, 1, &handle)); + } + + if (device_id >= 0) { + _CUDNN_CHECK_CUDNN_ERROR( + detail::set_attribute(get_ptr(), CUDNN_ATTR_DEVICEPROP_DEVICE_ID, CUDNN_TYPE_INT32, 1, &device_id)); + } + + CHECK_CUDNN_FRONTEND_ERROR(finalize()); + return {}; +#else + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_BACKEND_DEVICEPROP_DESCRIPTOR is only available starting 9.8."}; +#endif + } + + private: + cudnnHandle_t handle = nullptr; + int32_t device_id = 0; +}; +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/backend/execution_helpers.h b/third_party/cudnn-frontend/include/cudnn_frontend/backend/execution_helpers.h new file mode 100644 index 00000000..cfc139e6 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/backend/execution_helpers.h @@ -0,0 +1,99 @@ +#pragma once + +#include + +#include "cudnn.h" + +#include "backend_descriptor.h" + +namespace cudnn_frontend::detail { +/** + * @brief Creates a CUDNN backend variant pack descriptor. + * + * This function creates a `backend_descriptor` object representing a CUDNN backend variant pack + * descriptor. The variant pack descriptor is configured with the provided device pointers, unique + * IDs, and a workspace pointer. + * + * @param[out] variant_pack The created `backend_descriptor` object representing the variant pack. + * @param device_ptrs A vector of device pointers to be associated with the variant pack. + * @param uids A vector of unique IDs to be associated with the variant pack. + * @param workspace_ptr A pointer to the workspace memory to be associated with the variant pack. + * @return `error_t` A tuple containing the error code and an optional error message. + * The error code is `error_code_t::OK` on success, or an appropriate error code on failure. 
+ */ +inline error_t +create_variant_pack(backend_descriptor& variant_pack, + std::vector& device_ptrs, + std::vector const& uids, + void* workspace_ptr) { + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + variant_pack.get_ptr(), CUDNN_ATTR_VARIANT_PACK_WORKSPACE, CUDNN_TYPE_VOID_PTR, 1, &workspace_ptr)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(variant_pack.get_ptr(), + CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, + CUDNN_TYPE_VOID_PTR, + device_ptrs.size(), + device_ptrs.data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + variant_pack.get_ptr(), CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, CUDNN_TYPE_INT64, uids.size(), uids.data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(variant_pack.get_ptr())); + + return {error_code_t::OK, ""}; +} + +inline error_t +create_variant_pack(backend_descriptor& variant_pack, + std::vector& device_ptrs, + std::vector const& uids, + void* workspace_ptr, + std::vector const& override_uids, + std::vector> const& override_shapes, + std::vector> const& override_strides) { + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Dynamic shapes requires cuDNN v9.18.0"}; + + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(91800, cudnn_ver_error); + + CUDNN_FRONTEND_UNUSED(override_uids); + CUDNN_FRONTEND_UNUSED(override_shapes); + CUDNN_FRONTEND_UNUSED(override_strides); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + variant_pack.get_ptr(), CUDNN_ATTR_VARIANT_PACK_WORKSPACE, CUDNN_TYPE_VOID_PTR, 1, &workspace_ptr)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(variant_pack.get_ptr(), + CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, + CUDNN_TYPE_VOID_PTR, + device_ptrs.size(), + device_ptrs.data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + variant_pack.get_ptr(), CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, CUDNN_TYPE_INT64, uids.size(), uids.data())); + +#if (CUDNN_VERSION >= 91800) + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(variant_pack.get_ptr(), + CUDNN_ATTR_VARIANT_PACK_OVERRIDE_UNIQUE_IDS, + 
CUDNN_TYPE_INT64, + override_uids.size(), + override_uids.data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(variant_pack.get_ptr(), + CUDNN_ATTR_VARIANT_PACK_OVERRIDE_SHAPES, + CUDNN_TYPE_VOID_PTR, + 1, + (void*)&override_shapes)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(variant_pack.get_ptr(), + CUDNN_ATTR_VARIANT_PACK_OVERRIDE_STRIDES, + CUDNN_TYPE_VOID_PTR, + 1, + (void*)&override_strides)); +#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(variant_pack.get_ptr())); + + return {error_code_t::OK, ""}; +} + +} // namespace cudnn_frontend::detail diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/backend/kernel_cache.h b/third_party/cudnn-frontend/include/cudnn_frontend/backend/kernel_cache.h new file mode 100644 index 00000000..ef173b2f --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/backend/kernel_cache.h @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "../graph_helpers.h" +#include "backend_descriptor.h" + +namespace cudnn_frontend { +namespace graph { +class Graph; +} // namespace graph +/// +/// KernelCache Class +/// Wraps the kernel_cache backend descriptor +/// Wraps backend utility functions for user's convenience +/// Backend accessor functions: size() +/// Contains internal utilities for kernel cache finalization and operation graph attributes +/// +class KernelCache : public detail::backend_descriptor { + public: + friend class graph::Graph; + // Uses the default backend constructor so that we can check for initialization error during build() + KernelCache() : backend_descriptor() {} + + std::string + describe() const { + std::stringstream ss; + ss << "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR : " << std::endl; + return ss.str(); + } + + bool + is_finalized() { + return finalized; + } + + // Used to check kernel cache status (particularly after initialization) + error_t + status() { + if (get_status() != CUDNN_STATUS_SUCCESS) { + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR: Check CUDNN_VERSION >= 9.4"}; + } + return {}; + } + + error_t + to_json(std::string &str_json) const { + str_json.clear(); +#if (CUDNN_VERSION >= 91000) + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 91000, + error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION is only available starting 9.10."); + + int64_t serializationSize; + std::vector serialization_buf; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute( + get_ptr(), CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION, 
CUDNN_TYPE_CHAR, 0, &serializationSize, nullptr)); + serialization_buf.resize(static_cast(serializationSize)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute(get_ptr(), + CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION, + CUDNN_TYPE_CHAR, + serializationSize, + &serializationSize, + serialization_buf.data())); + std::string json_string(serialization_buf.begin(), serialization_buf.end()); + str_json = std::move(json_string); + return {}; +#else + (void)str_json; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION is only available starting 9.10."}; +#endif + } + + error_t + from_json(const std::string &json_cache) { +#if (CUDNN_VERSION >= 91000) + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 91000, + error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION is only available starting 9.10."); + + // Check if the kernel cache is already initialized + RETURN_CUDNN_FRONTEND_ERROR_IF( + get_ptr() != nullptr, error_code_t::CUDNN_BACKEND_API_FAILED, "Kernel cache is already initialized."); + + // // Initialize the kernel cache descriptor + CHECK_CUDNN_FRONTEND_ERROR(initialize(CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR)); + + std::vector serialization_buf; + serialization_buf.assign(json_cache.begin(), json_cache.end()); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(get_ptr(), + CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION, + CUDNN_TYPE_CHAR, + serialization_buf.size(), + serialization_buf.data())); + return {}; +#else + (void)json_cache; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_KERNEL_CACHE_JSON_REPRESENTATION is only available starting 9.10."}; +#endif + } + + // Responsible for initializing, setting operation graph attribute, and finalizing kernel cache + // Check for both compile-time and runtime cuDNN version + error_t + build(cudnnBackendDescriptor_t op_graph) { +#if (CUDNN_VERSION >= 90400) + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 
90400, + error_code_t::GRAPH_NOT_SUPPORTED, + "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR is only available starting 9.4."); + if (get_ptr() == nullptr) { + CHECK_CUDNN_FRONTEND_ERROR(initialize(CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR)); + } +#if (CUDNN_VERSION >= 90500) + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500, + error_code_t::GRAPH_NOT_SUPPORTED, + "CUDNN_ATTR_KERNEL_CACHE_OPERATION_GRAPH is only available starting 9.5."); + if (op_graph) { + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + get_ptr(), CUDNN_ATTR_KERNEL_CACHE_OPERATION_GRAPH, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph)); + } +#else + (void)op_graph; +#endif + CHECK_CUDNN_FRONTEND_ERROR(finalize()); + finalized = true; + return {}; +#else + (void)op_graph; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_BACKEND_KERNEL_CACHE_DESCRIPTOR is only available starting 9.4."}; +#endif + } + + private: + bool finalized = false; +}; +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/backend/plan_helpers.h b/third_party/cudnn-frontend/include/cudnn_frontend/backend/plan_helpers.h new file mode 100644 index 00000000..4ef47e22 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/backend/plan_helpers.h @@ -0,0 +1,212 @@ +#pragma once + +#include + +#include "cudnn.h" + +#include "backend_descriptor.h" +#include "../knobs.h" + +namespace cudnn_frontend::detail { +/** + * @brief Creates a CUDNN backend variant pack descriptor. + * + * This function creates a `backend_descriptor` object representing a CUDNN backend variant pack + * descriptor. The variant pack descriptor is configured with the provided device pointers, unique + * IDs, and a workspace pointer. + * + * @param[out] variant_pack The created `backend_descriptor` object representing the variant pack. + * @param device_ptrs A vector of device pointers to be associated with the variant pack. 
+ * @param uids A vector of unique IDs to be associated with the variant pack. + * @param workspace_ptr A pointer to the workspace memory to be associated with the variant pack. + * @return `error_t` A tuple containing the error code and an optional error message. + * The error code is `error_code_t::OK` on success, or an appropriate error code on failure. + */ +inline error_t +get_workspace_size(ManagedOpaqueDescriptor& engine_config, int64_t& workspace) { +#if CUDNN_VERSION >= 90200 + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute(engine_config->get_backend_descriptor(), + CUDNN_ATTR_ENGINECFG_WORKSPACE_SIZE, + CUDNN_TYPE_INT64, + 1, + nullptr, + &workspace)); + return {error_code_t::OK, ""}; +#else + (void)engine_config; + (void)workspace; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_ENGINECFG_WORKSPACE_SIZE is only available starting 9.2."}; +#endif +} + +inline error_t +get_shared_memory_size(ManagedOpaqueDescriptor& engine_config, int32_t& shared_memory_size) { +#if CUDNN_VERSION >= 90200 + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute(engine_config->get_backend_descriptor(), + CUDNN_ATTR_ENGINECFG_SHARED_MEMORY_USED, + CUDNN_TYPE_INT32, + 1, + nullptr, + &shared_memory_size)); + return {error_code_t::OK, ""}; +#else + (void)engine_config; + (void)shared_memory_size; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_ENGINECFG_SHARED_MEMORY_USED is only available starting 9.2."}; +#endif +} + +inline error_t +create_engine(backend_descriptor& engine, + int64_t const engine_id, + cudnnBackendDescriptor_t op_graph, + std::shared_ptr device_properties = nullptr) { + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + engine.get_ptr(), CUDNN_ATTR_ENGINE_OPERATION_GRAPH, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph)); + + // Validate before setting + int64_t count; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute( + op_graph, CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT, CUDNN_TYPE_INT64, 1, nullptr, &count)); + 
RETURN_CUDNN_FRONTEND_ERROR_IF( + engine_id >= count || engine_id < 0, error_code_t::INVALID_VALUE, "Invalid engine id."); + + _CUDNN_CHECK_CUDNN_ERROR( + detail::set_attribute(engine.get_ptr(), CUDNN_ATTR_ENGINE_GLOBAL_INDEX, CUDNN_TYPE_INT64, 1, &engine_id)); + + if (device_properties != nullptr) { +#if (CUDNN_VERSION >= 90800) + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(engine.get_ptr(), + CUDNN_ATTR_ENGINE_DEVICEPROP, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &device_properties->get_ptr())); +#endif + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(engine.get_ptr())); + + return {error_code_t::OK, ""}; +} + +inline error_t +query_knobs(int64_t const engine_id, cudnnBackendDescriptor_t op_graph, std::vector& knobs) { + detail::backend_descriptor engine(CUDNN_BACKEND_ENGINE_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(engine.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create engine's backend descriptor."); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_engine(engine, engine_id, op_graph)); + + // Initialize a backend descriptor for each knob type + // The size of the array should be CUDNN_KNOB_TYPE_COUNTS, as currently we dont know how many knobs the engine will + // support + std::array frontend_knobs; + for (size_t i = 0; i < CUDNN_KNOB_TYPE_COUNTS; i++) { + backend_descriptor frontend_knob(CUDNN_BACKEND_KNOB_INFO_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(frontend_knob.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create knob's backend descriptor."); + frontend_knobs[i] = std::move(frontend_knob); + } + + // Create an auxillary array to hold the raw knob descriptors + std::array backend_knobs; + for (size_t i = 0; i < CUDNN_KNOB_TYPE_COUNTS; i++) { + backend_knobs[i] = frontend_knobs[i].get_ptr(); + } + + // This is the actual number of knobs that is supported by the engine + int64_t knobs_size; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute(engine.get_ptr(), 
+ CUDNN_ATTR_ENGINE_KNOB_INFO, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + CUDNN_KNOB_TYPE_COUNTS, + &knobs_size, + backend_knobs.data())); + + for (int64_t i = 0; i < knobs_size; i++) { + cudnnBackendKnobType_t type; + int64_t elemCount; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute( + frontend_knobs[i].get_ptr(), CUDNN_ATTR_KNOB_INFO_TYPE, CUDNN_TYPE_KNOB_TYPE, 1, &elemCount, &type)); + + int64_t maxValue; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute(frontend_knobs[i].get_ptr(), + CUDNN_ATTR_KNOB_INFO_MAXIMUM_VALUE, + CUDNN_TYPE_INT64, + 1, + &elemCount, + &maxValue)); + + int64_t minValue; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute(frontend_knobs[i].get_ptr(), + CUDNN_ATTR_KNOB_INFO_MINIMUM_VALUE, + CUDNN_TYPE_INT64, + 1, + &elemCount, + &minValue)); + + int64_t stride; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute( + frontend_knobs[i].get_ptr(), CUDNN_ATTR_KNOB_INFO_STRIDE, CUDNN_TYPE_INT64, 1, &elemCount, &stride)); + + auto frontend_knob_type = convert_from_backend_knob_type(type); + knobs.emplace_back(frontend_knob_type, maxValue, minValue, stride); + } + + return {error_code_t::OK, ""}; +} + +inline error_t +set_knob_choices(std::unordered_map const& user_choices, + std::vector& knob_choices) { + for (auto const& [type, choice] : user_choices) { + backend_descriptor knob_choice(CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(knob_choice.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create knob_choice's backend descriptor."); + + cudnnBackendKnobType_t backend_type; + _CUDNN_CHECK_CUDNN_ERROR(convert_to_backend_knob_type(type, backend_type)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + knob_choice.get_ptr(), CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE, CUDNN_TYPE_KNOB_TYPE, 1, &backend_type)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + knob_choice.get_ptr(), CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE, CUDNN_TYPE_INT64, 1, &choice)); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::finalize(knob_choice.get_ptr())); + + knob_choices.push_back(std::move(knob_choice)); + } + + return {error_code_t::OK, ""}; +} + +inline error_t +create_engine_config(ManagedOpaqueDescriptor& engine_config, + backend_descriptor& engine, + std::vector& knob_choices) { + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(engine_config->get_backend_descriptor(), + CUDNN_ATTR_ENGINECFG_ENGINE, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &(engine.get_ptr()))); + + std::vector backend_knob_choices(CUDNN_KNOB_TYPE_COUNTS); + for (size_t i = 0; i < knob_choices.size(); i++) { + backend_knob_choices[i] = knob_choices[i].get_ptr(); + } + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(engine_config->get_backend_descriptor(), + CUDNN_ATTR_ENGINECFG_KNOB_CHOICES, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + knob_choices.size(), + backend_knob_choices.data())); + + // Finalizing the descriptor + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(engine_config->get_backend_descriptor())); + + return {error_code_t::OK, ""}; +} + +} // namespace cudnn_frontend::detail diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/context.h b/third_party/cudnn-frontend/include/cudnn_frontend/context.h new file mode 100644 index 00000000..8c894b22 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/context.h @@ -0,0 +1,110 @@ +#pragma once + +#include "../cudnn_frontend_utils.h" + +namespace cudnn_frontend::detail { + +class Context { + DataType_t compute_data_type = DataType_t::NOT_SET; + DataType_t intermediate_data_type = DataType_t::NOT_SET; + DataType_t io_data_type = DataType_t::NOT_SET; + int32_t target_sm_count = -1; + int32_t target_sm_version = -1; + bool is_dynamic_shape_enabled = false; + + std::string name = ""; + + public: + Context& + set_intermediate_data_type(DataType_t const type) { + intermediate_data_type = type; + return *this; + } + + Context& + set_io_data_type(DataType_t const type) { + io_data_type = type; + return *this; + } + + 
Context& + set_compute_data_type(DataType_t const type) { + compute_data_type = type; + return *this; + } + + DataType_t + get_io_data_type() const { + return io_data_type; + } + + DataType_t + get_intermediate_data_type() const { + return intermediate_data_type; + } + + DataType_t + get_compute_data_type() const { + return compute_data_type; + } + + Context& + set_name(std::string const& name_) { + name = name_; + return *this; + } + + std::string + get_name() const { + return name; + } + + Context& + set_target_sm_count(int32_t count) { + target_sm_count = count; + return *this; + } + + Context& + set_sm_version(int32_t version) { + target_sm_version = version; + return *this; + } + + Context& + set_dynamic_shape_enabled(bool is_enabled) { + is_dynamic_shape_enabled = is_enabled; + return *this; + } + + bool + get_dynamic_shape_enabled() const { + return is_dynamic_shape_enabled; + } + + int32_t + get_target_sm_count() const { + return target_sm_count; + } + + int32_t + get_sm_version() const { + return target_sm_version; + } + + Context& + fill_missing_properties(Context const& global_context) { + if (get_compute_data_type() == DataType_t::NOT_SET) { + set_compute_data_type(global_context.get_compute_data_type()); + } + if (get_intermediate_data_type() == DataType_t::NOT_SET) { + set_intermediate_data_type(global_context.get_intermediate_data_type()); + } + if (get_io_data_type() == DataType_t::NOT_SET) { + set_io_data_type(global_context.get_io_data_type()); + } + return *this; + } +}; + +} // namespace cudnn_frontend::detail \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/cudnn_interface.h b/third_party/cudnn-frontend/include/cudnn_frontend/cudnn_interface.h new file mode 100644 index 00000000..c142c61f --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/cudnn_interface.h @@ -0,0 +1,223 @@ +#pragma once + +#include +#include +#include + +#include "../cudnn_frontend_Tensor.h" +#include 
"../cudnn_frontend_Operation.h" +#include "../cudnn_frontend_OperationGraph.h" +#include "../cudnn_frontend_EngineConfig.h" +#include "../cudnn_frontend_ExecutionPlan.h" +#include "../cudnn_frontend_VariantPack.h" + +#include "graph_properties.h" +#include "graph_helpers.h" +#include "plans.h" + +namespace cudnn_frontend { + +namespace detail { +inline void +assign_uid(graph::Tensor_attributes* const tensor, + int64_t& potential_uid, + std::unordered_set const& used_uids) { + // get_next_potential_uid + while (used_uids.find(potential_uid) != used_uids.end()) { + ++potential_uid; + } + + tensor->set_uid(potential_uid); + ++potential_uid; // increment, as used its used now +} + +// TODO: Always returns OK. Can the status and error message be accessed from tensor descriptor? +inline error_t +create_cudnn_tensor( + std::shared_ptr const& props, + std::unordered_map>& tensors, + int64_t& potential_uid, + std::unordered_set const& used_uids) { + // Assign tensor a uid + if (props->has_uid() == false) { + assign_uid(props.get(), potential_uid, used_uids); + } + + // Check whether backend tensor already created + auto tensor_uid = props->get_uid(); + if (tensors.find(tensor_uid) != tensors.end()) { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Backend Tensor named '" << props->get_name() << "' with UID " << tensor_uid + << " already created."); + return {error_code_t::OK, ""}; + } + CUDNN_FE_LOG_LABEL_ENDL("INFO: Backend Tensor named '" << props->get_name() << "' with UID " << tensor_uid + << " being created."); + + auto&& tensor_builder = cudnn_frontend::TensorBuilder(); + + tensor_builder.setDim(props->get_dim().size(), props->get_dim().data()) + .setStrides(props->get_stride().size(), props->get_stride().data()) + .setId(tensor_uid) + .setAlignment(props->get_alignment()) + .setDataType(props->get_data_type()) + .setVirtual(props->get_is_virtual()) + .setByValue(props->get_is_pass_by_value()) + .setReorderType(props->get_reordering_type()); + + // Set vector count and dimension if 
they are non-default + if (props->get_vector_count() > 1 || props->get_vector_dimension() >= 0) { + tensor_builder.setVectorCountAndDimension(props->get_vector_count(), props->get_vector_dimension()); + } + + if (auto ragged_offset_props = props->get_ragged_offset()) { + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(ragged_offset_props, tensors, potential_uid, used_uids)); + tensor_builder.setRaggedOffset(tensors.at(ragged_offset_props->get_uid())); + } + +#ifdef NV_CUDNN_DISABLE_EXCEPTION + // disable exception macro is defined. Calling build will not throw. + // Check status of desc and return error. + auto tensor = tensor_builder.build(); + RETURN_CUDNN_FRONTEND_ERROR_IF( + tensor.get_status() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, tensor.get_error()); + tensors.emplace(tensor_uid, std::make_shared(std::move(tensor))); +#else + // build() can throw + // wrap in try catch + try { + auto tensor = tensor_builder.build(); + tensors.emplace(tensor_uid, std::make_shared(std::move(tensor))); + } catch (cudnn_frontend::cudnnException& e) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + } +#endif + + return {error_code_t::OK, ""}; +} +} // namespace detail + +class ICudnn { + protected: + using uid_t = int64_t; + + //// Store tensors and operations as they (probably?) need to be kept alive. + // + // The tensor mapping from fe::Tensor to be::Tensor. + // + // sub nodes share fe::Tensor. Example, in a conv-bias graph, conv output Y and bias input IN_0 are the same + // fe::Tensor. But both sub ndoes need to work together to make sure only one be::Tensor is created. Hence this + // uid_to_backend_tensors acts as the global registry for each sub node to use. + // + // Key cannot be fe::Tensor, or shared_ptr, or underlying object address of fe::Tensor. + // Hence using uid, as that uniquely identifies both types of tensors. 
+ std::unordered_map> uid_to_tensors; + std::vector> operations; + graph::managed_backend_descriptor_t raw_operations; + + std::shared_ptr operation_graph; + std::unordered_set variant_pack_uids; + + graph::Execution_plan_list plans; + + bool is_dynamic_shape_enabled = false; + std::shared_ptr kernel_cache = nullptr; + + std::shared_ptr device_properties = nullptr; + + error_t + create_cudnn_operation_graph(cudnnHandle_t handle) { + std::vector cudnn_operations; + for (std::shared_ptr operation : operations) { + cudnn_operations.push_back(operation.get()); + } + + auto&& cudnn_operation_graph_builder = cudnn_frontend::OperationGraphBuilder(); + cudnn_operation_graph_builder.setHandle(handle) + .setOperationGraph(cudnn_operations.size(), cudnn_operations.data()) + .setIsDynamicShapeEnabled(is_dynamic_shape_enabled); + for (auto& op : raw_operations) { + cudnn_operation_graph_builder.addOperation(op); + } + +#ifdef NV_CUDNN_DISABLE_EXCEPTION + // disable exception macro is defined. Calling build will not throw. + // Check status of desc and return error. 
+ auto cudnn_operation_graph = cudnn_operation_graph_builder.build(); + RETURN_CUDNN_FRONTEND_ERROR_IF(cudnn_operation_graph.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + cudnn_operation_graph.get_error()); + operation_graph = std::make_shared(std::move(cudnn_operation_graph)); +#else + // build() can throw + // wrap in try catch + try { + auto cudnn_operation_graph = cudnn_operation_graph_builder.build(); + operation_graph = std::make_shared(std::move(cudnn_operation_graph)); + } catch (cudnn_frontend::cudnnException& e) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + } +#endif + return {error_code_t::OK, "Successfully built Operation Graph."}; + } + + public: + error_t + get_cudnn_workspace_size_node(int64_t const plan_index, int64_t& cudnn_workspace_size) const { + CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(plan_index)); + + cudnn_workspace_size = std::max(cudnn_workspace_size, plans.execution_plans[plan_index]->getWorkspaceSize()); + + return {error_code_t::OK, ""}; + } + + int64_t + get_max_cudnn_workspace_size_node() const { + return plans.get_autotune_workspace(); + } + + error_t + execute_cudnn_plan_with_uid(cudnnHandle_t handle, + std::unordered_map const& tensor_uid_to_pointer_map, + void* workspace_ptr, + int64_t plan_index, + std::vector const& override_uids, + std::vector> const& override_shapes, + std::vector> const& override_strides) const { + // Make sure device pointer is provided for all uids expected for this plan + std::vector device_ptrs; + std::vector uids; + for (auto const& uid : variant_pack_uids) { + auto search = tensor_uid_to_pointer_map.find(uid); + RETURN_CUDNN_FRONTEND_ERROR_IF(search == tensor_uid_to_pointer_map.end(), + error_code_t::INVALID_VARIANT_PACK, + "Uid " + std::to_string(uid) + " does not exist in variant pack."); + device_ptrs.push_back(search->second); + uids.push_back(uid); + } + + 
CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(plan_index)); + + CUDNN_FE_LOG_LABEL_ENDL("INFO: Executing plan at index " << plan_index + << " with override uids: " << override_uids.size()); + + if (override_uids.size() == 0) { + CHECK_CUDNN_FRONTEND_ERROR( + detail::execute(handle, plans.execution_plans[plan_index].get(), device_ptrs, uids, workspace_ptr)); + } else { + CHECK_CUDNN_FRONTEND_ERROR(detail::execute(handle, + plans.execution_plans[plan_index].get(), + device_ptrs, + uids, + workspace_ptr, + override_uids, + override_shapes, + override_strides)); + } + + return {error_code_t::OK, ""}; + } +}; + +} // namespace cudnn_frontend diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/graph_helpers.h b/third_party/cudnn-frontend/include/cudnn_frontend/graph_helpers.h new file mode 100644 index 00000000..9cf45c16 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/graph_helpers.h @@ -0,0 +1,599 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +namespace cudnn_frontend { + +enum class [[nodiscard]] error_code_t { + OK, + ATTRIBUTE_NOT_SET, + SHAPE_DEDUCTION_FAILED, + INVALID_TENSOR_NAME, + INVALID_VARIANT_PACK, + GRAPH_NOT_SUPPORTED, + GRAPH_EXECUTION_PLAN_CREATION_FAILED, + GRAPH_EXECUTION_FAILED, + HEURISTIC_QUERY_FAILED, + UNSUPPORTED_GRAPH_FORMAT, + CUDA_API_FAILED, + CUDNN_BACKEND_API_FAILED, + INVALID_CUDA_DEVICE, + HANDLE_ERROR, + INVALID_VALUE +}; + +typedef struct [[nodiscard]] error_object { + error_code_t code; + std::string err_msg; + error_object() : code(error_code_t::OK), err_msg("") {}; + error_object(error_code_t err, std::string msg) : code(err), err_msg(msg) {}; + + error_code_t + get_code() { + return code; + } + + std::string + get_message() { + return err_msg; + } + + bool + is_good() const { + return code == error_code_t::OK; + } + + bool + is_bad() const { + return !is_good(); + } + + bool + operator==(error_code_t compare_code) { + return code == compare_code; + } + + bool + operator!=(error_code_t compare_code) { + return code != compare_code; + } + +} error_t; + +#ifdef WIN32 +#define CUDNN_FRONTEND_WHILE_FALSE \ + __pragma(warning(push)) __pragma(warning(disable : 4127)) while (0) __pragma(warning(pop)) +#else +#define CUDNN_FRONTEND_WHILE_FALSE while (0) +#endif + +#define CHECK_CUDNN_FRONTEND_ERROR(x) \ + do { \ + if (auto retval = x; retval.is_bad()) { \ + CUDNN_FE_LOG_LABEL_ENDL("ERROR: " << #x << " at " << __FILE__ << ":" << __LINE__); \ + return retval; \ + } \ + } \ + CUDNN_FRONTEND_WHILE_FALSE + +#define RETURN_CUDNN_FRONTEND_ERROR_IF(cond, retval, message) \ + do { \ + if (cond) { \ + if (retval == 
error_code_t::OK) { \ + CUDNN_FE_LOG_LABEL("INFO: "); \ + } else { \ + CUDNN_FE_LOG_LABEL("ERROR: "); \ + } \ + CUDNN_FE_LOG(message << ". " << retval << " because (" << #cond ") at " << __FILE__ << ":" << __LINE__ \ + << "\n"); \ + return {retval, message}; \ + } \ + } \ + CUDNN_FRONTEND_WHILE_FALSE + +#define _CUDNN_CHECK_CUDNN_ERROR(x) \ + do { \ + if (auto cudnn_retval = x; cudnn_retval != CUDNN_STATUS_SUCCESS) { \ + std::stringstream error_msg; \ + error_msg << #x << " failed with message: " << detail::get_last_error_string_() \ + << ", and code: " << detail::get_error_string(cudnn_retval); \ + CUDNN_FE_LOG_LABEL_ENDL("ERROR: " << error_msg.str() << " at " << __FILE__ << ":" << __LINE__); \ + return {error_code_t::CUDNN_BACKEND_API_FAILED, error_msg.str()}; \ + } \ + } \ + CUDNN_FRONTEND_WHILE_FALSE + +#define _CUDNN_CHECK_CUDA_ERROR(x) \ + do { \ + if (auto cuda_retval = x; cuda_retval != cudaSuccess) { \ + std::stringstream error_msg; \ + error_msg << #x << " failed with " << detail::cuda_get_error_string(cuda_retval); \ + CUDNN_FE_LOG_LABEL_ENDL("ERROR: " << error_msg.str() << " at " << __FILE__ << ":" << __LINE__); \ + return {error_code_t::CUDA_API_FAILED, error_msg.str()}; \ + } \ + } \ + CUDNN_FRONTEND_WHILE_FALSE + +#define CHECK_CU_ERROR(x) \ + do { \ + if (auto cu_retval = x; cu_retval != CUDA_SUCCESS) { \ + std::stringstream error_msg; \ + const char* error_code_string; \ + detail::cu_get_error_string(cu_retval, &error_code_string); \ + error_msg << #x << " failed with " << error_code_string; \ + getLogger() << "[cudnn_frontend] ERROR: " << error_msg.str() << " at " << __FILE__ << ":" << __LINE__ \ + << std::endl; \ + return {error_code_t::CUDA_API_FAILED, error_msg.str()}; \ + } \ + } \ + CUDNN_FRONTEND_WHILE_FALSE + +NLOHMANN_JSON_SERIALIZE_ENUM(error_code_t, + { + {error_code_t::OK, "OK"}, + {error_code_t::ATTRIBUTE_NOT_SET, "ATTRIBUTE_NOT_SET"}, + {error_code_t::SHAPE_DEDUCTION_FAILED, "SHAPE_DEDUCTION_FAILED"}, + 
{error_code_t::INVALID_TENSOR_NAME, "INVALID_TENSOR_NAME"}, + {error_code_t::INVALID_VARIANT_PACK, "INVALID_VARIANT_PACK"}, + {error_code_t::GRAPH_NOT_SUPPORTED, "GRAPH_NOT_SUPPORTED"}, + {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "GRAPH_EXECUTION_PLAN_CREATION_FAILED"}, + {error_code_t::GRAPH_EXECUTION_FAILED, "GRAPH_EXECUTION_FAILED"}, + {error_code_t::HEURISTIC_QUERY_FAILED, "HEURISTIC_QUERY_FAILED"}, + {error_code_t::CUDNN_BACKEND_API_FAILED, "CUDNN_BACKEND_API_FAILED"}, + {error_code_t::CUDA_API_FAILED, "CUDA_API_FAILED"}, + {error_code_t::INVALID_CUDA_DEVICE, "INVALID_CUDA_DEVICE"}, + {error_code_t::UNSUPPORTED_GRAPH_FORMAT, "UNSUPPORTED_GRAPH_FORMAT"}, + {error_code_t::HANDLE_ERROR, "HANDLE_ERROR"}, + {error_code_t::INVALID_VALUE, "INVALID_VALUE"}, + }) + +static inline std::ostream& +operator<<(std::ostream& os, const error_code_t& mode) { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + os << json{mode}; +#else + os << int(mode); +#endif + return os; +} + +static inline std::ostream& +operator<<(std::ostream& os, cudnn_frontend::error_object& err) { + os << err.get_code() << err.get_message(); + return os; +} + +static bool +allowAllConfig(cudnnBackendDescriptor_t engine_config) { + (void)engine_config; + return false; +} + +namespace detail { + +inline bool +is_activation_backward_mode(PointwiseMode_t const mode) { + return ((mode == PointwiseMode_t::RELU_BWD) || (mode == PointwiseMode_t::TANH_BWD) || + (mode == PointwiseMode_t::SIGMOID_BWD) || (mode == PointwiseMode_t::ELU_BWD) || + (mode == PointwiseMode_t::GELU_BWD) || (mode == PointwiseMode_t::GELU_APPROX_TANH_BWD) || + (mode == PointwiseMode_t::SOFTPLUS_BWD) || (mode == PointwiseMode_t::SWISH_BWD)); +} + +// Creates dense, non-overlapping strides from given dim and stride_order. +// For example, if a is a 4D tensor with dimensions labeled NCHW, then strided(a, (3, 0, 2, 1)) produces +// strides where the C dimension has a corresponding stride of one. 
+inline std::vector +generate_stride(std::vector const& dim, std::vector const& stride_order) { + size_t num_dims = dim.size(); + std::vector stride(num_dims); + + // Sort the dimensions according to strides from least to greatest. + // Example, dim = (2, 3, 4, 5) stride_order = (3, 1, 2, 0) + // sorted_stride_order = ((0, (3, 5)), (1, (1, 3)), (2, (2, 4)), (3, (0, 2))) + std::vector>> sorted_stride_order; + for (size_t i = 0; i < num_dims; ++i) { + sorted_stride_order.push_back({stride_order[i], {i, dim[i]}}); + } + std::sort(sorted_stride_order.begin(), sorted_stride_order.end()); + + // As dims have now been ordered starting from fastest changing, + // just fill in strides by iterating linearly over them. + int64_t product = 1; + for (size_t i = 0; i < num_dims; ++i) { + stride[sorted_stride_order[i].second.first] = product; + product *= sorted_stride_order[i].second.second; + } + + return stride; +} + +// Generate NHWC stride_order +inline std::vector +generate_NHWC_stride_order(int64_t const num_dims) { + std::vector stride_order(num_dims); + + int64_t order = 0; + stride_order[1] = order++; + for (size_t i = num_dims - 1; i > 1; --i) { + stride_order[i] = order++; + } + stride_order[0] = order; + + return stride_order; +} + +// Generate row major stride_order for matrices +// dim = (*, M, N) where * is batch dimsensions +// strides should be (..., N, 1) +inline std::vector +generate_row_major_stride_order(int64_t const num_dims) { + std::vector stride_order(num_dims); + + int64_t order = num_dims - 1; + std::generate(stride_order.begin(), stride_order.end(), [&order] { return order--; }); + + return stride_order; +} + +// Generate column major stride_order for matrices +// dim = (*, M, N) +// strides should be (*, 1, M) +inline std::vector +generate_column_major_stride_order(int64_t const num_dims) { + std::vector stride_order = generate_row_major_stride_order(num_dims); + if (num_dims > 2) { + std::swap(stride_order[num_dims - 1], stride_order[num_dims - 
2]); + } + return stride_order; +} + +/** + * @brief Computes the common shape with the fewest dimensions that all input shapes can be broadcast to. + * + * This function takes a vector of shapes and calculates a common shape that all input shapes + * can be broadcast to. It follows broadcasting rules similar to those used in NumPy. + * + * @param _shapes A vector of vectors, where each inner vector represents a shape. + * Each shape is a sequence of dimension sizes. + * @param[out] common_shape The computed broadcast shape is stored in this vector. + * It will be cleared and resized as necessary. + * + * @return error_t An error code indicating the result of the operation + * + * @note + * - Shapes are processed from right to left (last dimension to first). + * - A dimension of size 1 can be broadcast to any size. + * - Non-1 dimensions must match exactly for broadcasting. + * - The resulting shape will have the maximum number of dimensions among all input shapes. + * + * @example + * std::vector> shapes = {{3, 1, 4}, {1, 2, 4}, {2, 4}}; + * std::vector result; + * error_t err = compute_broadcast_shape(shapes, result); + * // If err == error_code_t::OK, result will be {3, 2, 4} + */ +inline error_t +compute_broadcast_shape(const std::vector>& _shapes, std::vector& common_shape) { + // Filter out empty shapes + std::vector> shapes; + std::copy_if(_shapes.begin(), _shapes.end(), std::back_inserter(shapes), [](const std::vector& shape) { + return !shape.empty(); + }); + + // Short-circuits if there are no input shapes + RETURN_CUDNN_FRONTEND_ERROR_IF( + shapes.empty(), error_code_t::SHAPE_DEDUCTION_FAILED, "All input shapes provided are empty."); + + // Find the maximum dimension + int64_t max_dim = std::max_element(shapes.begin(), + shapes.end(), + [](const std::vector& a, const std::vector& b) { + return a.size() < b.size(); + }) + ->size(); + + // Initialize common_shape with 1s + common_shape.assign(max_dim, 1); + + for (const auto& shape : shapes) { + for (int 
idx = -1; idx >= -static_cast(shape.size()); --idx) { + int64_t common_idx = common_shape.size() + idx; + int64_t shape_idx = shape.size() + idx; + + if (common_shape[common_idx] == 1) { + common_shape[common_idx] = shape[shape_idx]; + } + + RETURN_CUDNN_FRONTEND_ERROR_IF((shape[shape_idx] != 1) && (common_shape[common_idx] != shape[shape_idx]), + error_code_t::SHAPE_DEDUCTION_FAILED, + "dimensions mismatch as broadcasting 2 non-one dimension sizes."); + } + } + + return {error_code_t::OK, ""}; +} +/** + * @brief Generates a stride order preserving the format of the input tensor. + * + * This function derives the exact stride order from the input tensor's strides. + * It returns the indices of the strides in ascending order of stride values. + * + * @param input_stride The stride of the input tensor + * @param output_dim_size The number of dimensions in the output tensor + * @return std::vector The generated stride order + */ +inline error_t +generate_stride_order_preserving_format(const std::vector& input_stride, + size_t output_dim_size, + std::vector& stride_order) { + std::vector indices(input_stride.size()); + std::iota(indices.begin(), indices.end(), 0); + + // Sort indices based on stride values in descending order + std::sort(indices.begin(), indices.end(), [&input_stride](int64_t i, int64_t j) { + return input_stride[i] < input_stride[j]; + }); + + // Enable this after further debug + // std::set stride_set(input_stride.begin(), input_stride.end()); + // RETURN_CUDNN_FRONTEND_ERROR_IF((stride_set.size() != input_stride.size()), + // error_code_t::SHAPE_DEDUCTION_FAILED, + // "Have multiple stride with same value. 
Cant determine stride order"); + + // Create the stride order + stride_order.resize(input_stride.size()); + for (size_t i = 0; i < indices.size(); ++i) { + stride_order[indices[i]] = i; + } + + // If output_dim_size is larger, pad with remaining dimensions + if (output_dim_size > input_stride.size()) { + size_t start = stride_order.size(); + stride_order.resize(output_dim_size); + std::iota(stride_order.begin() + start, stride_order.end(), start); + } + + return {error_code_t::OK, ""}; +} + +/** + * @brief Infers the output dimensions for a matrix multiplication operation. + * + * This function calculates the output dimensions of a matrix multiplication + * based on the input dimensions of tensors A and B. It uses compute_broadcast_shape + * for batch dimensions and ensures the last two dimensions are correct for matrix multiplication. + * + * @param a_dim Dimensions of the first input tensor (A). + * @param b_dim Dimensions of the second input tensor (B). + * @param output_dim Reference to the vector where the output dimensions will be stored. + * @return error_t An error code indicating the result of the operation. 
+ */ +inline error_t +generate_matmul_output_dim(const std::vector& a_dim, + const std::vector& b_dim, + std::vector& output_dim) { + // Ensure a_dim and b_dim have at least 2 dimensions + if (a_dim.size() < 2 || b_dim.size() < 2) { + return {error_code_t::SHAPE_DEDUCTION_FAILED, "Input tensors must have at least 2 dimensions for matmul."}; + } + + // Check if inner dimensions are compatible + if (a_dim[a_dim.size() - 1] != b_dim[b_dim.size() - 2]) { + return {error_code_t::SHAPE_DEDUCTION_FAILED, + "Inner dimensions of input tensors are not compatible for matmul."}; + } + + // Prepare shapes for broadcasting + std::vector a_batch_dim(a_dim.begin(), a_dim.end() - 2); + std::vector b_batch_dim(b_dim.begin(), b_dim.end() - 2); + + // Compute broadcast shape for batch dimensions + std::vector broadcasted_batch; + CHECK_CUDNN_FRONTEND_ERROR(detail::compute_broadcast_shape({a_batch_dim, b_batch_dim}, broadcasted_batch)); + + // Construct final output shape + output_dim = broadcasted_batch; + output_dim.push_back(a_dim[a_dim.size() - 2]); // M from A + output_dim.push_back(b_dim[b_dim.size() - 1]); // N from B + + return {error_code_t::OK, ""}; +} + +inline std::string +to_hex(const void* data, size_t num_elements, size_t elem_size) { + const auto* bytes = static_cast(data); + std::stringstream ss; + ss << "["; + for (size_t i = 0; i < num_elements; ++i) { + if (i > 0) ss << ", "; + ss << "0x" << std::hex << std::uppercase; + switch (elem_size) { + case 1: + ss << static_cast(bytes[i]); + break; + case 2: + ss << *reinterpret_cast(&bytes[i * 2]); + break; + case 4: + ss << *reinterpret_cast(&bytes[i * 4]); + break; + case 8: + ss << *reinterpret_cast(&bytes[i * 8]); + break; + default: + ss << "?"; + } + } + ss << "]"; + return ss.str(); +} + +inline std::string +to_decimal(const void* data, size_t num_elements, size_t elem_size) { + const auto* bytes = static_cast(data); + std::stringstream ss; + ss << "["; + for (size_t i = 0; i < num_elements; ++i) { + if (i > 0) ss 
<< ", "; + switch (elem_size) { + case 1: + ss << static_cast(bytes[i]); + break; + case 2: + ss << *reinterpret_cast(&bytes[i * 2]); + break; + case 4: + ss << *reinterpret_cast(&bytes[i * 4]); + break; + case 8: + ss << *reinterpret_cast(&bytes[i * 8]); + break; + default: + ss << "?"; + } + } + ss << "]"; + return ss.str(); +} + +inline std::string +to_base64(const void* data, size_t total_bytes) { + static const char table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + const auto* bytes = static_cast(data); + std::string result; + result.reserve(((total_bytes + 2) / 3) * 4); + for (size_t i = 0; i < total_bytes; i += 3) { + uint32_t n = static_cast(bytes[i]) << 16; + if (i + 1 < total_bytes) n |= static_cast(bytes[i + 1]) << 8; + if (i + 2 < total_bytes) n |= static_cast(bytes[i + 2]); + result.push_back(table[(n >> 18) & 0x3F]); + result.push_back(table[(n >> 12) & 0x3F]); + result.push_back((i + 1 < total_bytes) ? table[(n >> 6) & 0x3F] : '='); + result.push_back((i + 2 < total_bytes) ? 
table[n & 0x3F] : '='); + } + return result; +} + +inline error_t +log_dump_tensor_content(int64_t uid, + std::string const& name, + void* ptr, + size_t num_elements, + size_t elem_size, + char fmt, + cudaStream_t stream) { + if (!isLoggingEnabled()) return {error_code_t::OK, ""}; + + size_t total_bytes = num_elements * elem_size; + + cudaPointerAttributes attr; + _CUDNN_CHECK_CUDA_ERROR(cuda_pointer_get_attributes(&attr, ptr)); + + std::vector host_buf(total_bytes); + if (attr.type == cudaMemoryTypeDevice || attr.type == cudaMemoryTypeManaged) { + _CUDNN_CHECK_CUDA_ERROR(cuda_mem_cpy_async(host_buf.data(), ptr, total_bytes, cudaMemcpyDeviceToHost, stream)); + _CUDNN_CHECK_CUDA_ERROR(cuda_stream_synchronize(stream)); + } else { + std::memcpy(host_buf.data(), ptr, total_bytes); + } + + std::string data_str; + switch (fmt) { + case 'x': + data_str = to_hex(host_buf.data(), num_elements, elem_size); + break; + case 'd': + data_str = to_decimal(host_buf.data(), num_elements, elem_size); + break; + case 'b': + data_str = to_base64(host_buf.data(), total_bytes); + break; + default: + data_str = to_hex(host_buf.data(), num_elements, elem_size); + } + CUDNN_FE_LOG_LABEL_ENDL("Tensor Dump Uid: " << uid << " Name: " << name << " Data: " << data_str); + return {error_code_t::OK, ""}; +} + +inline error_t +log_variant_pack_memory_type(int64_t uid, void* ptr) { + if (!isLoggingEnabled()) return {error_code_t::OK, ""}; + + cudaPointerAttributes attributes; + _CUDNN_CHECK_CUDA_ERROR(cuda_pointer_get_attributes(&attributes, ptr)); + + auto memory_type_to_string = [](cudaMemoryType type) { + switch (type) { + case cudaMemoryTypeHost: + return std::string("Host"); + case cudaMemoryTypeDevice: + return std::string("Device"); + case cudaMemoryTypeManaged: + return std::string("Managed"); + case cudaMemoryTypeUnregistered: + return std::string("Unregistered"); + default: + return "UNKNOWN cudaMemoryType (" + std::to_string(type) + ")"; + } + }; + + auto ptr_to_string = [](void* p) { + 
std::stringstream ss; + ss << "0x" << std::hex << std::setw(sizeof(void*) * 2) << std::setfill('0') << reinterpret_cast(p); + return ss.str(); + }; + + // clang-format off + CUDNN_FE_LOG_LABEL_ENDL("Variant Pack" << std::setw(0) << " Uid: " << std::setw(20) << uid + << std::setw(0) << " MemoryType: " << std::setw(12) << memory_type_to_string(attributes.type) + << std::setw(0) << " Device: " << std::setw(4) << attributes.device + << std::setw(0) << " UnifiedPtr: " << std::setw(20) << ptr_to_string(ptr) + << std::setw(0) << " DevicePtr: " << std::setw(20) << ptr_to_string(attributes.devicePointer) + << std::setw(0) << " HostPtr: " << std::setw(20) << ptr_to_string(attributes.hostPointer)); + // clang-format on + return {error_code_t::OK, ""}; +} + +} // namespace detail + +class cudnnGraphNotSupportedException : public std::runtime_error { + public: + cudnnGraphNotSupportedException(const char* message) throw() : std::runtime_error(message) {} + + virtual const char* + what() const throw() { + return std::runtime_error::what(); + } +}; + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/graph_interface.h b/third_party/cudnn-frontend/include/cudnn_frontend/graph_interface.h new file mode 100644 index 00000000..74f699c4 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/graph_interface.h @@ -0,0 +1,2708 @@ +#pragma once + +#include +#include +#include + +#include "../cudnn_frontend_version.h" +#include "node/batchnorm.h" +#include "node/batchnorm_inference.h" +#include "node/bn_finalize.h" +#include "node/conv_fprop.h" +#include "node/conv_dgrad.h" +#include "node/conv_wgrad.h" +#include "node/dbn.h" +#include "node/dln.h" +#include "node/dbn_weight.h" +#include "node/genstats.h" +#include "node/layernorm.h" +#include "node/adaptive_layernorm.h" +#include "node/instancenorm.h" +#include "node/rmsnorm.h" +#include "node/resample.h" +#include "node/reshape.h" +#include 
"node/slice.h" +#include "node/scaled_dot_product_flash_attention.h" +#include "node/sdpa_fp8_bwd.h" +#include "node/block_scale_quantize.h" +#include "node/block_scale_dequantize.h" +#include "node/concatenate.h" +#include "node/moe_grouped_matmul.h" + +#include "backend/backend_descriptor.h" +#include "plans.h" +#include "knobs.h" +#include "graph_helpers.h" +#include "backend/kernel_cache.h" + +namespace cudnn_frontend::graph { + +class Graph : public ICudnn, public INode { + private: + std::unordered_set> full_graph_inputs; + std::unordered_set used_uids; + int64_t fe_workspace_size = 0; + + std::unordered_set> deserialized_tensor_properties; + std::unordered_map deserialized_pass_by_value; + std::unordered_map>> deserialized_workspace_modifications; + + // Cached values computed during build/deserialize, used during execute to avoid repeated collection. + // These are mutable because execute() is const but needs non-const access for pointer extraction. + mutable std::unordered_map cached_pass_by_value; + mutable std::unordered_map>> cached_workspace_modifications; + + // char: 'x'=hex, 'd'=decimal, 'b'=base64 + std::vector, char>> tensors_to_dump; + + error_t + get_pre_assigned_uids(std::unordered_set &used_uids) { + for (auto const &input : full_graph_inputs) { + if (input->has_uid()) { + auto uid = input->get_uid(); + auto iter = used_uids.find(uid); + RETURN_CUDNN_FRONTEND_ERROR_IF(iter != used_uids.end(), + error_code_t::INVALID_VALUE, + "uid " + std::to_string(uid) + " for tensor named " + input->get_name() + + " has been already assigned to another tensor."); + used_uids.insert(uid); + } + } + for (auto const &output : full_graph_outputs) { + if (output->has_uid()) { + auto uid = output->get_uid(); + auto iter = used_uids.find(uid); + RETURN_CUDNN_FRONTEND_ERROR_IF(iter != used_uids.end(), + error_code_t::INVALID_VALUE, + "uid " + std::to_string(uid) + " for tensor named " + + output->get_name() + + " has been already assigned to another tensor."); + 
used_uids.insert(uid); + } + } + + return {error_code_t::OK, ""}; + } + + error_t + pre_validate_node() const override final { + RETURN_CUDNN_FRONTEND_ERROR_IF( + (context.get_dynamic_shape_enabled() || kernel_cache != nullptr) && detail::get_backend_version() < 90400, + error_code_t::GRAPH_NOT_SUPPORTED, + "Dynamic shapes or kernel caching enabled, but cuDNN version < 9.4!"); + RETURN_CUDNN_FRONTEND_ERROR_IF(((context.get_dynamic_shape_enabled() == false) && (kernel_cache != nullptr)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Kernel caching enabled but dynamic shapes is disabled"); + if (detail::get_backend_version() != detail::get_compiled_version()) { + CUDNN_FE_LOG_LABEL_ENDL("INFO: The cuDNN version used at compilation (" + << detail::get_compiled_version() << ") and the one used at runtime (" + << detail::get_backend_version() << ") differ."); + } + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + return {error_code_t::OK, ""}; + } + + error_t + post_validate_node() const override final { + return {error_code_t::OK, ""}; + } + + virtual error_t + collect_pass_by_value_tensors_node( + std::unordered_map &pass_by_values) const override final { + for (auto [uid, value] : deserialized_pass_by_value) { + pass_by_values.emplace(uid, value); + } + return {error_code_t::OK, ""}; + } + + virtual error_t + collect_tensors_in_workspace_node( + std::unordered_map>> + &worskspace_modifications, + int64_t &) const override { + for (auto [uid, value] : deserialized_workspace_modifications) { + worskspace_modifications.emplace(uid, value); + } + return {error_code_t::OK, ""}; + } + + virtual error_t + create_cudnn_tensors_node(std::unordered_map> &, + int64_t &, + std::unordered_set const &) const override final { + return {error_code_t::OK, ""}; + } + + error_t + extend_tensor_map_with_workspace_tensors_( + std::unordered_map &tensor_to_pointer_map, + void *workspace, + std::unordered_map>> const &worskspace_modifications) + const { 
+ for (auto const &[uid, data] : worskspace_modifications) { + tensor_to_pointer_map.emplace(uid, static_cast(workspace) + std::get<1>(data)); + } + return {error_code_t::OK, ""}; + } + + error_t + extend_tensor_map_with_pass_by_value_tensors_( + std::unordered_map &tensor_to_pointer_map, + std::unordered_map &tensor_to_pass_by_value) const { + for (auto &[uid, value] : tensor_to_pass_by_value) { + if (half *half_value_ptr = std::get_if(&value)) { + tensor_to_pointer_map.emplace(uid, half_value_ptr); + } else if (nv_bfloat16 *nv_bfloat16_value_ptr = std::get_if(&value)) { + tensor_to_pointer_map.emplace(uid, nv_bfloat16_value_ptr); + } else if (int32_t *int32_t_value_ptr = std::get_if(&value)) { + tensor_to_pointer_map.emplace(uid, int32_t_value_ptr); + } else if (int64_t *int64_t_value_ptr = std::get_if(&value)) { + tensor_to_pointer_map.emplace(uid, int64_t_value_ptr); + } else if (float *float_value_ptr = std::get_if(&value)) { + tensor_to_pointer_map.emplace(uid, float_value_ptr); + } else { + RETURN_CUDNN_FRONTEND_ERROR_IF( + true, error_code_t::INVALID_VARIANT_PACK, "Unexpected type for pass by value tensor."); + } + } + return {error_code_t::OK, ""}; + } + + error_t + make_variant_pack_replacements( + std::unordered_map &tensor_to_pointer_map, + std::unordered_map> replacements) const { + for (auto &[from_uid, value] : replacements) { + const auto &[to_uid, start_offset] = value; + + // Check if from_uid exists in the map + auto it = tensor_to_pointer_map.find(from_uid); + RETURN_CUDNN_FRONTEND_ERROR_IF(it == tensor_to_pointer_map.end(), + error_code_t::INVALID_VARIANT_PACK, + "Variant pack expected uid " + std::to_string(from_uid) + " but not found."); + + // Perform pointer arithmetic + tensor_to_pointer_map[to_uid] = static_cast(static_cast(it->second) + start_offset); + } + return {error_code_t::OK, ""}; + } + + int64_t + get_max_cudnn_workspace_size() const { + return get_max_cudnn_workspace_size_node(); + } + + // Key: uid to replace in variant pack + 
// Value: uid to replace with, start offset to add to pointer + std::unordered_map> + variant_pack_replacements; + + error_t + run_auxiliary_kernels( + cudnnHandle_t handle, + void *fe_workspace, + std::unordered_map>> &workspace_modifications) const { + cudaStream_t stream; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_stream(handle, &stream)); + char *workspace = static_cast(fe_workspace); + + for (auto [uid, data] : workspace_modifications) { + (void)uid; + if (std::get<0>(data) == 0) { + auto &vec_data = std::get<2>(data); + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_mem_cpy_async(workspace + std::get<1>(data), + vec_data.data(), + vec_data.size() * sizeof(float), + cudaMemcpyHostToDevice, + stream)); + } else if (std::get<0>(data) == 1) { + int64_t memset_size = (int64_t)std::get<2>(data)[0]; + _CUDNN_CHECK_CUDA_ERROR( + detail::cuda_mem_set_async(workspace + std::get<1>(data), 0, memset_size, stream)); + } + } + return {error_code_t::OK, ""}; + } + + size_t + key(bool remove_shape) { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + json j; + serialize(j); + if (remove_shape) { + for (auto &tensor : j["tensors"]) { + tensor["dim"].clear(); + tensor["stride"].clear(); + } + } + return std::hash{}(j); +#else + CUDNN_FRONTEND_UNUSED(remove_shape); + return 1; +#endif + } + + // Private unified sdpa method - internal implementation for both FP16 and FP8 modes + inline SDPA_attributes::SDPA_outputs + sdpa_internal(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + SDPA_attributes &&attributes) { + // Set inputs + attributes.inputs[SDPA_attributes::input_names::Q] = q; + attributes.inputs[SDPA_attributes::input_names::K] = k; + attributes.inputs[SDPA_attributes::input_names::V] = v; + + // Make required output tensors + SDPA_attributes::SDPA_outputs sdpa_outputs; + + sdpa_outputs.O = attributes.outputs[SDPA_attributes::output_names::O] = output_tensor(attributes.name + "::O"); + + if (attributes.generate_stats == true) { + sdpa_outputs.Stats = 
attributes.outputs[SDPA_attributes::output_names::Stats] = + output_tensor(attributes.name + "::Stats"); + } + + // Dropout mask dump (created conditionally based on dropout parameters) + if (attributes.outputs.find(SDPA_attributes::output_names::RNG_DUMP) != attributes.outputs.end() && + attributes.outputs.at(SDPA_attributes::output_names::RNG_DUMP) != nullptr) { + sdpa_outputs.RNG_DUMP = attributes.outputs[SDPA_attributes::output_names::RNG_DUMP]; + } + + // FP8-specific outputs (created conditionally based on FP8 scaling parameters) + if (attributes.inputs.find(SDPA_attributes::input_names::Descale_S) != attributes.inputs.end() && + attributes.inputs.at(SDPA_attributes::input_names::Descale_S) != nullptr) { + sdpa_outputs.Amax_S = attributes.outputs[SDPA_attributes::output_names::Amax_S] = + output_tensor(attributes.name + "::Amax_S"); + } + if (attributes.inputs.find(SDPA_attributes::input_names::Scale_O) != attributes.inputs.end() && + attributes.inputs.at(SDPA_attributes::input_names::Scale_O) != nullptr) { + sdpa_outputs.Amax_O = attributes.outputs[SDPA_attributes::output_names::Amax_O] = + output_tensor(attributes.name + "::Amax_O"); + } + + auto seq_len_q_it = attributes.inputs.find(SDPA_attributes::input_names::SEQ_LEN_Q); + auto seq_len_kv_it = attributes.inputs.find(SDPA_attributes::input_names::SEQ_LEN_KV); + if (seq_len_q_it != attributes.inputs.end() && seq_len_q_it->second != nullptr) { + tensors_to_dump.emplace_back(seq_len_q_it->second, 'd'); + } + if (seq_len_kv_it != attributes.inputs.end() && seq_len_kv_it->second != nullptr) { + tensors_to_dump.emplace_back(seq_len_kv_it->second, 'd'); + } + + for (auto t : {q, k, v, sdpa_outputs.O}) { + if (auto ragged = t->get_ragged_offset()) { + tensors_to_dump.emplace_back(ragged, 'd'); + } + } + + if (attributes.implementation == AttentionImplementation_t::AUTO) { + // Sets attributes.implementation to a supporting implementation, + // or leaves as AUTO if none found + 
attributes._auto_select_implementation(context); + } + + switch (attributes.implementation) { + case AttentionImplementation_t::AUTO: + throw std::runtime_error("No suitable implementation for given SDPA_attributes"); + break; + case AttentionImplementation_t::COMPOSITE: + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + break; + case AttentionImplementation_t::UNIFIED: + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + break; + } + + return sdpa_outputs; + } + + public: + Graph() : INode(detail::Context{}) {} + + error_t + update_cuda_graph(cudnnHandle_t handle, + std::unordered_map, void *> &tensor_to_pointer_map, + void *workspace, + cudaGraph_t cudnn_cuda_graph) { + // First get all the uids from the map + std::unordered_map tensor_uid_to_pointer_map; + tensor_uid_to_pointer_map.reserve(tensor_to_pointer_map.size()); + for (auto const &[tensor, pointer] : tensor_to_pointer_map) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), pointer); + } + + return update_cuda_graph(handle, tensor_uid_to_pointer_map, workspace, cudnn_cuda_graph); + } + + error_t + update_cuda_graph(cudnnHandle_t handle, + std::unordered_map &uid_to_device_ptrs, + void *workspace, + cudaGraph_t cudnn_cuda_graph) { + // Initializes this cudnn graph + RETURN_CUDNN_FRONTEND_ERROR_IF( + cudnn_cuda_graph == nullptr, error_code_t::INVALID_VALUE, "cudnn_cuda_graph should not be a nullptr"); + + size_t num_root_nodes; + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_get_root_nodes(cudnn_cuda_graph, nullptr, &num_root_nodes)); + RETURN_CUDNN_FRONTEND_ERROR_IF( + num_root_nodes != 1, error_code_t::INVALID_VALUE, "cudnn_cuda_graph should have exactly 1 root node."); + + cudaGraphNode_t current_node = nullptr; + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_get_root_nodes(cudnn_cuda_graph, ¤t_node, &num_root_nodes)); + + /////////////////////////////////////// + //// PASS BY VALUE TENSOR HANDLING //// + /////////////////////////////////////// + // Add 
pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. + // cuda graph will keep a copy of the kernel parameters, meaning that at the time of + // launching the cuda_graph executable, cached values being deallocated does not affect these cpu values. + // No cuda graph nodes are required for handling fe owned pass by value tensors. + CHECK_CUDNN_FRONTEND_ERROR( + extend_tensor_map_with_pass_by_value_tensors_(uid_to_device_ptrs, cached_pass_by_value)); + + //////////////////////////// + //// WORKSPACE HANDLING //// + //////////////////////////// + // Using cached workspace modifications to avoid repeated tree traversal. + for (auto const &[uid, data] : cached_workspace_modifications) { + const auto &[operation_type, offset, vec_data] = data; + uid_to_device_ptrs[uid] = static_cast(workspace) + offset; + + // 0 means memcpy + if (operation_type == 0) { + _CUDNN_CHECK_CUDA_ERROR( + detail::cuda_graph_add_memcpy_node_set_params_1D(current_node, + static_cast(workspace) + offset, + vec_data.data(), + vec_data.size() * sizeof(float), + cudaMemcpyHostToDevice)); + } + // 1 means memset + else if (operation_type == 1) { + // offset from workspace + void *device_ptr = static_cast(workspace) + offset; + int64_t memset_size = static_cast(vec_data[0]); + + cudaMemsetParams params; + params.dst = device_ptr; + params.elementSize = sizeof(char); + params.value = 0x0; + params.width = memset_size; + params.height = 1; // 1D memset currently + params.pitch = 0; // unused + + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_add_memset_node_set_params(current_node, ¶ms)); + } + // Other values do not correspond to CUDA graph nodes + else { + continue; + } + + size_t num_dependent_nodes; + _CUDNN_CHECK_CUDA_ERROR( + detail::cuda_graph_node_get_dependent_nodes(current_node, nullptr, &num_dependent_nodes)); + RETURN_CUDNN_FRONTEND_ERROR_IF( + num_dependent_nodes != 1, + error_code_t::INVALID_VALUE, + "Each node of 
cudnn_cuda_graph before the backend graph node should have exactly 1 dependent node."); + _CUDNN_CHECK_CUDA_ERROR( + detail::cuda_graph_node_get_dependent_nodes(current_node, ¤t_node, &num_dependent_nodes)); + } + + // Make sure device pointer is provided for all uids expected for this plan + std::vector device_ptrs; + std::vector uids; + + device_ptrs.reserve(variant_pack_uids.size()); + uids.reserve(variant_pack_uids.size()); + + for (auto const &uid : variant_pack_uids) { + auto search = uid_to_device_ptrs.find(uid); + RETURN_CUDNN_FRONTEND_ERROR_IF(search == uid_to_device_ptrs.end(), + error_code_t::INVALID_VARIANT_PACK, + "Uid " + std::to_string(uid) + " does not exist in variant pack."); + device_ptrs.push_back(search->second); + uids.push_back(uid); + } + + /////////////////// + //// BE GRAPH //// + /////////////////// + cudaGraph_t backend_cuda_graph; + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_child_graph_node_get_graph(current_node, &backend_cuda_graph)); + + detail::backend_descriptor variant_pack_descriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(variant_pack_descriptor.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create variant pack's backend descriptor."); + + // offset workspace by the already used fe graph workspace + // this is where cudnn backend can start using workspace for its execution plans + void *cudnn_workspace = static_cast(workspace) + fe_workspace_size; + CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack(variant_pack_descriptor, device_ptrs, uids, cudnn_workspace)); + + int64_t candidate = plans.candidate; + CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(candidate)); + _CUDNN_CHECK_CUDNN_ERROR(detail::update_cuda_graph(handle, + plans.execution_plans[candidate]->get_raw_desc(), + variant_pack_descriptor.get_ptr(), + backend_cuda_graph)); + + // There should be nothing after the backend graph + size_t num_dependent_nodes; + _CUDNN_CHECK_CUDA_ERROR( 
+ detail::cuda_graph_node_get_dependent_nodes(current_node, nullptr, &num_dependent_nodes)); + RETURN_CUDNN_FRONTEND_ERROR_IF(num_dependent_nodes != 0, + error_code_t::INVALID_VALUE, + "cudnn_cuda_graph should have no graph nodes after the backend graph node."); + + return {error_code_t::OK, ""}; + } + + error_t + populate_cuda_graph(cudnnHandle_t handle, + std::unordered_map, void *> &tensor_to_pointer_map, + void *workspace, + cudaGraph_t cudnn_cuda_graph) { + // First get all the uids from the map + std::unordered_map tensor_uid_to_pointer_map; + tensor_uid_to_pointer_map.reserve(tensor_to_pointer_map.size()); + for (auto const &[tensor, pointer] : tensor_to_pointer_map) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), pointer); + } + + return populate_cuda_graph(handle, tensor_uid_to_pointer_map, workspace, cudnn_cuda_graph); + } + + error_t + populate_cuda_graph(cudnnHandle_t handle, + std::unordered_map &uid_to_device_ptrs, + void *workspace, + cudaGraph_t cudnn_cuda_graph) { + // Check if the cuda graph is empty + size_t numNodes = 0; + CHECK_CU_ERROR(detail::cu_graph_get_nodes(cudnn_cuda_graph, nullptr, &numNodes)); + RETURN_CUDNN_FRONTEND_ERROR_IF(numNodes != 0, + error_code_t::INVALID_VALUE, + "cuda graph provided to populate is not empty. cuDNN requires it to be empty " + "for the corresponding update APIs to work correctly."); + + // This function makes linear cuda graphs. And that makes it easy to walk + // the graph when updating it. + // So just keeping track of the last node in the cuda graph is sufficient. + cudaGraphNode_t last_node = nullptr; + + /////////////////////////////////////// + //// PASS BY VALUE TENSOR HANDLING //// + /////////////////////////////////////// + // Add pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. 
+ // cuda graph will keep a copy of the kernel parameters, meaning that at the time of + // launching the cuda_graph executable, cached values being deallocated does not affect these cpu values. + // No cuda graph nodes are required for handling fe owned pass by value tensors. + CHECK_CUDNN_FRONTEND_ERROR( + extend_tensor_map_with_pass_by_value_tensors_(uid_to_device_ptrs, cached_pass_by_value)); + + ///////////////////////////////// + //// WORKSPACE HANDLING //// + ///////////////////////////////// + // Using cached workspace modifications to avoid repeated tree traversal. + for (auto const &[uid, data] : cached_workspace_modifications) { + const auto &[operation_type, offset, vec_data] = data; + uid_to_device_ptrs[uid] = static_cast(workspace) + offset; + + cudaGraphNode_t node = nullptr; + + // 0 means memcpy + if (operation_type == 0) { + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_add_memcpy_node_1D(&node, + cudnn_cuda_graph, + &last_node, + last_node != nullptr, + static_cast(workspace) + offset, + vec_data.data(), + vec_data.size() * sizeof(float), + cudaMemcpyHostToDevice)); + } + // 1 means memset + else if (operation_type == 1) { + // offset from workspace + void *device_ptr = static_cast(workspace) + offset; + int64_t memset_size = static_cast(vec_data[0]); + + cudaMemsetParams params; + params.dst = device_ptr; + params.elementSize = sizeof(char); + params.value = 0x0; + params.width = memset_size; + params.height = 1; // 1D memset currently + params.pitch = 0; // unused + + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_add_memset_node( + &node, cudnn_cuda_graph, &last_node, last_node != nullptr, ¶ms)); + } + // Other values do not correspond to CUDA graph nodes + else { + continue; + } + + last_node = node; + } + + ////////////// + // BE graph // + ////////////// + + // Get the BE's cuda graph + + // Make sure device pointer is provided for all uids expected for this plan + std::vector device_ptrs; + device_ptrs.reserve(variant_pack_uids.size()); + 
std::vector uids; + uids.reserve(variant_pack_uids.size()); + for (auto const &uid : variant_pack_uids) { + auto search = uid_to_device_ptrs.find(uid); + RETURN_CUDNN_FRONTEND_ERROR_IF(search == uid_to_device_ptrs.end(), + error_code_t::INVALID_VARIANT_PACK, + "Uid " + std::to_string(uid) + " does not exist in variant pack."); + device_ptrs.push_back(search->second); + uids.push_back(uid); + } + + // Create the variant pack to pass to backend + detail::backend_descriptor variant_pack_descriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(variant_pack_descriptor.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create variant pack's backend descriptor."); + + // offset workspace by the already used fe graph workspace + // this is where cudnn backend can start using workspace for its execution plans + void *cudnn_workspace = static_cast(workspace) + fe_workspace_size; + CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack(variant_pack_descriptor, device_ptrs, uids, cudnn_workspace)); + + // Get the plan candidate. It only makes to sense to make cuda graph after execution plan has been built. + // And in that case the candidate would have been set. + int64_t candidate = plans.candidate; + CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(candidate)); + + // Finally get the backend cuda graph. + cudaGraph_t backend_cuda_graph; + // Initialize the cudnn cuda graph. + // The responsibility to destroy is on the user. 
+ detail::cu_graph_create(&backend_cuda_graph, 0); // 0 is just what the API says to pass + + _CUDNN_CHECK_CUDNN_ERROR(detail::populate_cuda_graph(handle, + plans.execution_plans[candidate]->get_raw_desc(), + variant_pack_descriptor.get_ptr(), + backend_cuda_graph)); + + // Clone BE graph into a graph_node + // This same call also places the newly created into FE's graph + // TODO: BE graph is at the end, so put in appropriate dependencies + cudaGraphNode_t backend_cuda_graph_node; + detail::cuda_graph_add_child_graph_node( + &backend_cuda_graph_node, cudnn_cuda_graph, &last_node, last_node != nullptr, backend_cuda_graph); + + // Destroy the BE graph as it now has been cloned into a node + // It was initialized by internals of backend, but the responsibility to destroy it is on FE. + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_destroy(backend_cuda_graph)); + + return {error_code_t::OK, ""}; + } + + error_t + validate() { + CUDNN_FE_LOG_BANNER(" VALIDATING GRAPH "); + CUDNN_FE_LOG(*this << std::endl;); + + // First validate all inputs that the user set. + for (auto const &input : full_graph_inputs) { + CHECK_CUDNN_FRONTEND_ERROR(input->validate()); + } + + // Validate the nodes, which in turn also infers missing tensor attributes. + CHECK_CUDNN_FRONTEND_ERROR(validate_subtree()); + // Validate all outputs, which should now have everything set to be lowered to backend. 
+ for (auto const &output : full_graph_outputs) { + CHECK_CUDNN_FRONTEND_ERROR(output->validate()); + } + + // Get all the pre assigned uids + CHECK_CUDNN_FRONTEND_ERROR(get_pre_assigned_uids(used_uids)); + // Clear state + used_uids.clear(); + + CUDNN_FE_LOG_BANNER(" VALIDATED ALL OK "); + + return {error_code_t::OK, ""}; + } + + // overload for deviceless AoT compilation + error_t + build_operation_graph() { + CUDNN_FE_LOG_BANNER(" BUILD OP GRAPH WITHOUT HANDLE "); + + if (device_properties == nullptr) { + return {error_code_t::ATTRIBUTE_NOT_SET, "Device properties are not set."}; + } + CUDNN_FE_LOG_BANNER(" BUILT OP GRAPH WITHOUT HANDLE "); + return build_operation_graph(nullptr); + } + + error_t + build_operation_graph(cudnnHandle_t handle) { + CUDNN_FE_LOG_BANNER(" BUILD OP GRAPH "); + + CUDNN_FE_LOG_BANNER(" 1/4 INFER PROPERTIES OF NODES "); + + // expand composite nodes + CHECK_CUDNN_FRONTEND_ERROR(expand_subtree()); + + // Get all the pre assigned uids + CHECK_CUDNN_FRONTEND_ERROR(get_pre_assigned_uids(used_uids)); + + CUDNN_FE_LOG_BANNER(" 2/4 CREATE TENSORS "); + + Tensor_attributes::uid_t start_uid = 1; + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensors_subtree(uid_to_tensors, start_uid, used_uids)); + + CUDNN_FE_LOG_BANNER(" 3/4 CREATE OPERATIONS "); + // INode keeps track of all uids that an operation graph uses. + // This helps to return errors to user during execution, without relying on backend to do so. + // Also, as uid in a variant pack have to be unique, keep a set of them. + CHECK_CUDNN_FRONTEND_ERROR( + create_cudnn_operations(variant_pack_uids, operations, raw_operations, uid_to_tensors)); + + // Collect variant pack modifiers when lowering to backend. + // The collected map is used everytime when execute is called. 
+ CHECK_CUDNN_FRONTEND_ERROR(collect_variant_pack_replacements_subtree(variant_pack_replacements)); + + fe_workspace_size = get_fe_workspace_size_subtree(); + + // Cache pass_by_value tensors and workspace modifications for fast execution. + // These are collected once here and reused in every execute() call to avoid + // repeated tree traversal and map allocation overhead. + CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(cached_pass_by_value)); + { + int64_t temp_offset = 0; + CHECK_CUDNN_FRONTEND_ERROR( + collect_tensors_in_workspace_subtree(cached_workspace_modifications, temp_offset)); + } + + CUDNN_FE_LOG_BANNER(" 4/4 LOWERING TO BACKEND OPERATION GRAPH "); + + // The method here fuses all operations. There will be 1 operation graph in total. + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_operation_graph(handle)); + + if (context.get_dynamic_shape_enabled() && kernel_cache && !kernel_cache->is_finalized()) { + CUDNN_FE_LOG_BANNER(" BUILD KERNEL CACHE "); + CHECK_CUDNN_FRONTEND_ERROR(kernel_cache->build(operation_graph->get_raw_desc())); + } + + CUDNN_FE_LOG_BANNER(" BUILD OP GRAPH ALL OK === "); + + return {error_code_t::OK, ""}; + } + + error_t + get_plan_name(std::string &name) const { + return get_plan_name_at_index(plans.candidate, name); + } + + error_t + get_plan_name_at_index(int64_t plan_index, std::string &name) const { + auto ret_val = plans.get_name_at_index(plan_index, name); + CUDNN_FE_LOG_LABEL_ENDL("INFO: get_plan_name_at_index(" << plan_index << ") is " + name); + return ret_val; + } + + error_t + get_workspace_size(int64_t &cudnn_workspace_size) const { + return get_workspace_size_plan_at_index(plans.candidate, cudnn_workspace_size); + } + + error_t + get_workspace_size_plan_at_index(int64_t plan_index, int64_t &cudnn_workspace_size) const { + // There are two workspaces: + // - cudnn execution plan workspace + // - FE node workspace (example: alibiSlope for fmha) + int64_t cudnn_ws = 0; + 
CHECK_CUDNN_FRONTEND_ERROR(get_cudnn_workspace_size_node(plan_index, cudnn_ws)); + cudnn_workspace_size = cudnn_ws + fe_workspace_size; + CUDNN_FE_LOG_LABEL_ENDL("INFO: get_workspace_size() is " << cudnn_workspace_size); + return {error_code_t::OK, ""}; + } + + int64_t + get_workspace_size() const { + return get_workspace_size_plan_at_index(plans.candidate); + } + + int64_t + get_workspace_size_plan_at_index(int64_t plan_index) const { + int64_t cudnn_workspace = 0; + auto status = get_workspace_size_plan_at_index(plan_index, cudnn_workspace); + if (status.is_bad()) { + CUDNN_FE_LOG_LABEL_ENDL("ERROR: Querying workspace failed."); + } + return cudnn_workspace; + } + + int64_t + get_autotune_workspace_size() const { + // There are two workspaces: + // - cudnn execution plan workspace + // - FE node workspace (example: alibiSlope for fmha) + return fe_workspace_size + get_max_cudnn_workspace_size(); + } + + error_t + autotune(cudnnHandle_t handle, + std::unordered_map &tensor_uid_to_pointer_map, + void *workspace, + void *user_impl = nullptr) { + // Add pass_by_value data pointers to tensor_uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. 
+ CHECK_CUDNN_FRONTEND_ERROR( + extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, cached_pass_by_value)); + + CHECK_CUDNN_FRONTEND_ERROR( + make_variant_pack_replacements(tensor_uid_to_pointer_map, variant_pack_replacements)); + + CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, cached_workspace_modifications)); + + CHECK_CUDNN_FRONTEND_ERROR(extend_tensor_map_with_workspace_tensors_( + tensor_uid_to_pointer_map, workspace, cached_workspace_modifications)); + + // offset workspace by the already used fe graph workspace + // this is where cudnn backend can start using workspace for its execution plans + void *cudnn_workspace = static_cast(workspace) + fe_workspace_size; + + CHECK_CUDNN_FRONTEND_ERROR(plans.autotune(handle, tensor_uid_to_pointer_map, cudnn_workspace, user_impl)); + return {error_code_t::OK, ""}; + } + + error_t + autotune(cudnnHandle_t handle, + std::unordered_map, void *> &tensor_to_pointer_map, + void *workspace, + void *user_impl = nullptr) { + // First get all the uids from the map + std::unordered_map tensor_uid_to_pointer_map; + for (auto const &[tensor, pointer] : tensor_to_pointer_map) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), pointer); + } + + return autotune(handle, tensor_uid_to_pointer_map, workspace, user_impl); + } + + error_t + execute_plan_at_index(cudnnHandle_t handle, + std::unordered_map, void *> &tensor_to_pointer_map, + void *workspace, + int64_t plan_index) const { + CUDNN_FE_LOG_BANNER(" EXECUTE PLAN AT INDEX for plan index (with Tensor keys) " << plan_index << " "); + // First get all the uids from the map + std::unordered_map tensor_uid_to_pointer_map; + for (auto const &[tensor, pointer] : tensor_to_pointer_map) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), pointer); + } + + return execute_plan_at_index(handle, tensor_uid_to_pointer_map, workspace, plan_index); + } + + error_t + execute(cudnnHandle_t handle, + std::unordered_map, void *> 
&tensor_to_pointer_map, + void *workspace) const { + CUDNN_FE_LOG_BANNER(" EXECUTE PLAN (with Tensor keys) "); + + // First get all the uids from the map + std::unordered_map tensor_uid_to_pointer_map; + for (auto const &[tensor, pointer] : tensor_to_pointer_map) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), pointer); + } + + return execute(handle, tensor_uid_to_pointer_map, workspace); + } + error_t + execute_plan_at_index(cudnnHandle_t handle, + std::unordered_map &tensor_uid_to_pointer_map, + void *workspace, + int64_t plan_index, + std::vector const &override_uids, + std::vector> const &override_shapes, + std::vector> const &override_strides) const { + // Add pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. + // Object lifetime is controlled by cached_pass_by_value which persists for the Graph's lifetime. + CUDNN_FE_LOG_BANNER(" EXECUTE PLAN AT INDEX for plan index " << plan_index << " "); + + CHECK_CUDNN_FRONTEND_ERROR( + extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, cached_pass_by_value)); + + CHECK_CUDNN_FRONTEND_ERROR( + make_variant_pack_replacements(tensor_uid_to_pointer_map, variant_pack_replacements)); + + CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, cached_workspace_modifications)); + + CHECK_CUDNN_FRONTEND_ERROR(extend_tensor_map_with_workspace_tensors_( + tensor_uid_to_pointer_map, workspace, cached_workspace_modifications)); + // offset workspace by the already used fe graph workspace + // this is where cudnn backend can start using workspace for its execution plans + void *cudnn_workspace = static_cast(workspace) + fe_workspace_size; + + if (isLoggingEnabled()) { + cudaStream_t stream; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_stream(handle, &stream)); + for (auto const &[uid, ptr] : tensor_uid_to_pointer_map) { + CHECK_CUDNN_FRONTEND_ERROR(detail::log_variant_pack_memory_type(uid, ptr)); + } + for (auto const &[tensor, 
fmt] : tensors_to_dump) { + auto it = tensor_uid_to_pointer_map.find(tensor->get_uid()); + if (it != tensor_uid_to_pointer_map.end()) { + auto const &dims = tensor->get_dim(); + size_t num_elements = 1; + for (auto d : dims) num_elements *= static_cast(d); + size_t elem_size = detail::get_data_type_size(tensor->get_data_type()); + CHECK_CUDNN_FRONTEND_ERROR(detail::log_dump_tensor_content( + it->first, tensor->get_name(), it->second, num_elements, elem_size, fmt, stream)); + } + } + } + + CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plan_with_uid(handle, + tensor_uid_to_pointer_map, + cudnn_workspace, + plan_index, + override_uids, + override_shapes, + override_strides)); + + CUDNN_FE_LOG_BANNER(" EXECUTE PLAN AT INDEX ALL OK for plan index " << plan_index << " "); + return {error_code_t::OK, ""}; + } + + error_t + execute(cudnnHandle_t handle, + std::unordered_map &tensor_uid_to_pointer_map, + void *workspace, + std::vector const &override_uids, + std::vector> const &override_shapes, + std::vector> const &override_strides) const { + // Add pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. 
+ CUDNN_FE_LOG_BANNER(" EXECUTE PLAN "); + + CHECK_CUDNN_FRONTEND_ERROR( + extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, cached_pass_by_value)); + CHECK_CUDNN_FRONTEND_ERROR( + make_variant_pack_replacements(tensor_uid_to_pointer_map, variant_pack_replacements)); + + CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, cached_workspace_modifications)); + + CHECK_CUDNN_FRONTEND_ERROR(extend_tensor_map_with_workspace_tensors_( + tensor_uid_to_pointer_map, workspace, cached_workspace_modifications)); + // offset workspace by the already used fe graph workspace + // this is where cudnn backend can start using workspace for its execution plans + void *cudnn_workspace = static_cast(workspace) + fe_workspace_size; + + if (isLoggingEnabled()) { + cudaStream_t stream; + _CUDNN_CHECK_CUDNN_ERROR(detail::get_stream(handle, &stream)); + for (auto const &[uid, ptr] : tensor_uid_to_pointer_map) { + CHECK_CUDNN_FRONTEND_ERROR(detail::log_variant_pack_memory_type(uid, ptr)); + } + for (auto const &[tensor, fmt] : tensors_to_dump) { + auto it = tensor_uid_to_pointer_map.find(tensor->get_uid()); + if (it != tensor_uid_to_pointer_map.end()) { + auto const &dims = tensor->get_dim(); + size_t num_elements = 1; + for (auto d : dims) num_elements *= static_cast(d); + size_t elem_size = detail::get_data_type_size(tensor->get_data_type()); + CHECK_CUDNN_FRONTEND_ERROR(detail::log_dump_tensor_content( + it->first, tensor->get_name(), it->second, num_elements, elem_size, fmt, stream)); + } + } + } + + CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plan_with_uid(handle, + tensor_uid_to_pointer_map, + cudnn_workspace, + plans.candidate, + override_uids, + override_shapes, + override_strides)); + + CUDNN_FE_LOG_BANNER(" EXECUTE PLAN ALL OK "); + return {error_code_t::OK, ""}; + } + + error_t + execute_plan_at_index(cudnnHandle_t handle, + std::unordered_map &tensor_uid_to_pointer_map, + void *workspace, + int64_t plan_index) const { + // Add pass_by_value 
data pointers to uid_to_pointer map + // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid during + // execute. + CHECK_CUDNN_FRONTEND_ERROR( + execute_plan_at_index(handle, tensor_uid_to_pointer_map, workspace, plan_index, {}, {}, {})); + return {error_code_t::OK, ""}; + } + + error_t + execute(cudnnHandle_t handle, + std::unordered_map &tensor_uid_to_pointer_map, + void *workspace) const { + // Add pass_by_value data pointers to uid_to_pointer map. + // Using cached values to avoid repeated tree traversal overhead. + CUDNN_FE_LOG_BANNER(" EXECUTE PLAN "); + + CHECK_CUDNN_FRONTEND_ERROR( + extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, cached_pass_by_value)); + CHECK_CUDNN_FRONTEND_ERROR( + make_variant_pack_replacements(tensor_uid_to_pointer_map, variant_pack_replacements)); + + CHECK_CUDNN_FRONTEND_ERROR(run_auxiliary_kernels(handle, workspace, cached_workspace_modifications)); + + //逻辑注册 把刚才 Workspace 里那些新变量的地址正式告诉执行器。 + CHECK_CUDNN_FRONTEND_ERROR(extend_tensor_map_with_workspace_tensors_( + tensor_uid_to_pointer_map, workspace, cached_workspace_modifications)); + // offset workspace by the already used fe graph workspace + // this is where cudnn backend can start using workspace for its execution plans + void *cudnn_workspace = static_cast(workspace) + fe_workspace_size; + + CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plan_with_uid( + handle, tensor_uid_to_pointer_map, cudnn_workspace, plans.candidate, {}, {}, {})); + + CUDNN_FE_LOG_BANNER(" EXECUTE PLAN ALL OK "); + return {error_code_t::OK, ""}; + } + + error_t + warmup(cudnnHandle_t handle) { + cudaStream_t fake_stream; + + cudaStream_t original_stream; + + _CUDNN_CHECK_CUDNN_ERROR(detail::get_stream(handle, &original_stream)); + + CUDNN_FE_LOG_BANNER("WARMUP (BEGIN FAKE GRAPH CAPTURE) "); + + if (original_stream == nullptr) { + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_stream_create(&fake_stream)); + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_stream(handle, fake_stream)); + } else { + fake_stream = original_stream; + } + + cudaGraph_t graph_obj; + + cudaStreamCaptureStatus capture_status; + + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_stream_is_capturing(fake_stream, &capture_status)); + + CUDNN_FE_LOG_LABEL_ENDL("INFO: capture_status " + << capture_status << " original_stream " + << ((original_stream == nullptr) ? "DEFAULT (NULL) Stream" : "NON-DEFAULT Stream")); + + if (capture_status != cudaStreamCaptureStatusNone) { + CUDNN_FE_LOG_LABEL_ENDL("INFO: cuda graph capture active, aborting warmup"); + return {error_code_t::OK, "cuda graph capture active, aborting warmup"}; + } + + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_begin_capture(fake_stream, cudaStreamCaptureModeRelaxed)); + + std::unordered_map tensor_uid_to_pointer_map; + + void *tmp_pointer = reinterpret_cast(0x7f0000000000llu); + + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_malloc((void **)&tmp_pointer, 1024 * 1024)); + + float tmp_double = 1.0f; + void *cpu_pointer = reinterpret_cast(&tmp_double); + + for (auto const &tensor : deserialized_tensor_properties) { + if (tensor->get_is_virtual() == false) { + if (tensor->get_is_pass_by_value() == false) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), tmp_pointer); + } else { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), cpu_pointer); + } + } + } + + CUDNN_FE_LOG_LABEL_ENDL("INFO: full_graph_inputs: " << full_graph_inputs.size() << " elements"); + for (auto const &tensor : full_graph_inputs) { + CUDNN_FE_LOG_LABEL_ENDL("\tuid: " << tensor->get_uid() + << ", is_pass_by_value = " << tensor->get_is_pass_by_value()); + if (tensor->get_is_pass_by_value() == false) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), tmp_pointer); + } else { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), cpu_pointer); + } + } + CUDNN_FE_LOG_LABEL_ENDL("INFO: full_graph_outputs: " << full_graph_outputs.size() << " elements"); + for (auto const &tensor : 
full_graph_outputs) { + CUDNN_FE_LOG_LABEL_ENDL("\tuid: " << tensor->get_uid()); + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), tmp_pointer); + } + + CHECK_CUDNN_FRONTEND_ERROR( + extend_tensor_map_with_pass_by_value_tensors_(tensor_uid_to_pointer_map, deserialized_pass_by_value)); + + auto cudnn_status = execute(handle, tensor_uid_to_pointer_map, tmp_pointer); + (void)cudnn_status; // No need to check bad executes + + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_end_capture(fake_stream, &graph_obj)); + + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_graph_destroy(graph_obj)); + + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_free(tmp_pointer)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_stream(handle, original_stream)); + + if (original_stream == nullptr) { + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_stream_destroy(fake_stream)); + } + + CUDNN_FE_LOG_BANNER("WARMUP (END FAKE GRAPH CAPTURE) "); + + return {error_code_t::OK, ""}; + } + + error_t + serialize(std::vector &data) const { + CUDNN_FE_LOG_BANNER(" SERIALIZE PLAN "); +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + json j; + serialize(j); + + auto const candidate = plans.candidate; + auto execution_plan = plans.execution_plans[candidate]; + if (execution_plan != nullptr) { + auto serialized_plan = execution_plan->getJsonRepresentation(); + j["cudnn_backend_data"] = serialized_plan; + j["variant_pack_uids"] = variant_pack_uids; + } + + j["behavior_notes"] = plans.behavior_notes; + + std::unordered_map tensor_to_pass_by_value; + CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); + j["pass_by_values"] = tensor_to_pass_by_value; + + std::unordered_map>> workspace_modifications; + int64_t workspace_offset = 0; + CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset)); + j["workspace_modifications"] = workspace_modifications; + + j["variant_pack_replacements"] = variant_pack_replacements; + + j["fe_workspace_size"] = fe_workspace_size; + + 
std::vector> tensors_to_dump_uids; + for (auto const &[tensor, fmt] : tensors_to_dump) { + tensors_to_dump_uids.emplace_back(tensor->get_uid(), fmt); + } + j["tensors_to_dump"] = tensors_to_dump_uids; + + data = json::to_ubjson(j); + CUDNN_FE_LOG_BANNER(" SERIALIZE PLAN (ALL OK) "); + return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(data); + return {error_code_t::GRAPH_NOT_SUPPORTED, "unavailable when compiled with CUDNN_FRONTEND_SKIP_JSON_LIB"}; +#endif + } + + error_t + deserialize(cudnnHandle_t handle, std::vector const &data) { + CUDNN_FE_LOG_BANNER(" DESERIALIZE PLAN WITH HANDLE "); + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + json j = json::from_ubjson(data); + + if (j.contains("tensors")) { + auto tensor_map = j["tensors"].get>(); + for (const auto &tensor_info : tensor_map) { + auto tensor_attributes = std::make_shared(); + from_json(tensor_info.second, *tensor_attributes); + deserialized_tensor_properties.insert(tensor_attributes); + } + } + + auto serialized_plan = j["cudnn_backend_data"]; + + CHECK_CUDNN_FRONTEND_ERROR(plans.build_plans(handle, serialized_plan)); + + plans.behavior_notes = j["behavior_notes"].get>>(); + + variant_pack_uids = j["variant_pack_uids"].get>(); + + deserialized_pass_by_value = j["pass_by_values"]; + + deserialized_workspace_modifications = j["workspace_modifications"]; + + variant_pack_replacements = j["variant_pack_replacements"]; + + fe_workspace_size = j["fe_workspace_size"]; + + // Initialize the execution caches from deserialized data + cached_pass_by_value = deserialized_pass_by_value; + cached_workspace_modifications = deserialized_workspace_modifications; + + if (j.contains("tensors_to_dump")) { + auto dump_uids = j["tensors_to_dump"].get>>(); + for (auto const &[uid, fmt] : dump_uids) { + for (auto const &tensor : deserialized_tensor_properties) { + if (tensor->get_uid() == uid) { + tensors_to_dump.emplace_back(tensor, fmt); + break; + } + } + } + } + + CHECK_CUDNN_FRONTEND_ERROR(warmup(handle)); + + 
CUDNN_FE_LOG_BANNER(" DESERIALIZE PLAN WITH HANDLE (ALL OK) "); + + return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(handle); + CUDNN_FRONTEND_UNUSED(data); + return {error_code_t::GRAPH_NOT_SUPPORTED, "unavailable when compiled with CUDNN_FRONTEND_SKIP_JSON_LIB"}; +#endif + } + + Type + getType() override { + return Type::COMPOSITE; + } + + Graph & + set_intermediate_data_type(DataType_t type); + Graph & + set_io_data_type(DataType_t type); + Graph & + set_compute_data_type(DataType_t type); + Graph & + set_dynamic_shape_enabled(bool is_enabled); + Graph & + set_sm_count(int32_t type); + Graph & + set_sm_version(int32_t version); + Graph & + set_kernel_cache(std::shared_ptr cache); + Graph & + set_device_properties(std::shared_ptr device_prop); + + Graph & + set_name(std::string const &name) { + context.set_name(name); + return *this; + } + + error_t + query_tensor_attributes_of_uid(int64_t const uid, Tensor_attributes &tensor) const; + + std::shared_ptr + tensor(Tensor_attributes const &tensor); + + std::shared_ptr + tensor_like(std::shared_ptr const &tensor, std::string const &name = std::string{}); + + std::array, 3> layernorm(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Layernorm_attributes); + + std::array, 3> adalayernorm(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + AdaLayernorm_attributes); + + std::array, 3> instancenorm(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Instancenorm_attributes); + + std::array, 5> batchnorm(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Batchnorm_attributes); + + std::shared_ptr batchnorm_inference(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Batchnorm_inference_attributes); + + std::array, 6> bn_finalize(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + BN_finalize_attributes); + + std::shared_ptr conv_fprop(std::shared_ptr, + std::shared_ptr, + 
Conv_fprop_attributes); + + std::shared_ptr conv_dgrad(std::shared_ptr, + std::shared_ptr, + Conv_dgrad_attributes); + + std::shared_ptr conv_wgrad(std::shared_ptr, + std::shared_ptr, + Conv_wgrad_attributes); + + std::array, 5> dbn_weight(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + DBN_weight_attributes); + + std::array, 3> batchnorm_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Batchnorm_backward_attributes); + + std::array, 3> layernorm_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Layernorm_backward_attributes); + + std::array, 3> adalayernorm_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + AdaLayernorm_backward_attributes); + + std::array, 3> instancenorm_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Instancenorm_backward_attributes); + std::array, 2> genstats(std::shared_ptr, Genstats_attributes); + + std::array, 2> rmsnorm(std::shared_ptr, + std::shared_ptr, + Rmsnorm_attributes); + + std::array, 3> rmsnorm_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Rmsnorm_backward_attributes); + + std::array, 2> sdpa(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + SDPA_attributes); + + std::array, 4> sdpa_fp8(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + SDPA_fp8_attributes); + + inline std::array, 7> sdpa_fp8_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + SDPA_fp8_backward_attributes); + + std::array, 3> sdpa_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + 
std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + SDPA_backward_attributes); + + std::shared_ptr slice(std::shared_ptr, Slice_attributes); + + std::array, 2> block_scale_quantize(std::shared_ptr, + Block_scale_quantize_attributes); + + std::shared_ptr block_scale_dequantize(std::shared_ptr, + std::shared_ptr, + Block_scale_dequantize_attributes); + + std::shared_ptr concatenate(std::vector>, + Concatenate_attributes); + + std::shared_ptr moe_grouped_matmul(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Moe_grouped_matmul_attributes); + + [[deprecated]] std::array, 2> + scaled_dot_product_flash_attention(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + SDPA_attributes attributes) { + return sdpa(q, k, v, attributes); + } + [[deprecated]] std::array, 3> + scaled_dot_product_flash_attention_backward(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr o, + std::shared_ptr dO, + std::shared_ptr stats, + SDPA_backward_attributes attributes) { + return sdpa_backward(q, k, v, o, dO, stats, attributes); + } + + error_t + create_execution_plans(std::vector const &mode); + + error_t + create_execution_plan(int64_t const engine_id, std::unordered_map const &knobs); + + int64_t + get_execution_plan_count() const; + + inline error_t + get_engine_count(int64_t &count); + + inline error_t + get_knobs_for_engine(int64_t const engine, std::vector &); + + error_t + check_support(cudnnHandle_t h) { + // handle not required anymore + // TODO: remove this function in next release + (void)h; + return check_support(); + } + + // overload for deviceless AoT compilation + error_t + check_support() { + CHECK_CUDNN_FRONTEND_ERROR(plans.check_support()); + return {error_code_t::OK, ""}; + } + + // TODO: remove this function in next release + error_t + build(cudnnHandle_t const &handle, + std::vector const &mode, + BuildPlanPolicy_t const policy = BuildPlanPolicy_t::HEURISTICS_CHOICE, + bool const 
do_multithreaded_builds = false); + + // overload for deviceless AoT compilation + error_t + build(std::vector const &mode, + BuildPlanPolicy_t const policy = BuildPlanPolicy_t::HEURISTICS_CHOICE, + bool const do_multithreaded_builds = false); + + error_t + build_plans(cudnnHandle_t const &handle, + BuildPlanPolicy_t const policy = BuildPlanPolicy_t::HEURISTICS_CHOICE, + bool const do_multithreaded_builds = false) { + // handle not required anymore + // TODO: remove this function in next release + (void)handle; + return build_plans(policy, do_multithreaded_builds); + } + + // overload for deviceless AoT compilation + error_t + build_plans(BuildPlanPolicy_t const policy = BuildPlanPolicy_t::HEURISTICS_CHOICE, + bool const do_multithreaded_builds = false); + + error_t + build_plan_at_index(cudnnHandle_t const &handle, int64_t index) { + // handle not required anymore + // TODO: remove this function in next release + (void)handle; + return build_plan_at_index(index); + } + + // overload for deviceless AoT compilation + error_t + build_plan_at_index(int64_t index); + + Graph & + deselect_workspace_greater_than(int64_t const workspace) { + plans.set_max_workspace_allowed(workspace); + return *this; + } + + Graph & + deselect_shared_mem_greater_than(int64_t const workspace) { + plans.set_max_shared_mem_allowed(workspace); + return *this; + } + + Graph & + deselect_engines(std::vector const &engine_names) { + plans.set_barred_names(engine_names); + return *this; + } + + Graph & + select_behavior_notes(std::vector const ¬es) { + auto status = plans.filter_behavior_notes(notes, true); + if (status.is_bad()) { + CUDNN_FE_LOG(status.get_message() << std::endl); + } + return *this; + } + + Graph & + select_numeric_notes(std::vector const ¬es) { + auto status = plans.filter_numeric_notes(notes, true); + if (status.is_bad()) { + CUDNN_FE_LOG(status.get_message() << std::endl); + } + return *this; + } + + Graph & + deselect_behavior_notes(std::vector const ¬es) { + auto status = 
plans.filter_behavior_notes(notes, false); + if (status.is_bad()) { + CUDNN_FE_LOG(status.get_message() << std::endl); + } + return *this; + } + + Graph & + deselect_numeric_notes(std::vector const ¬es) { + auto status = plans.filter_numeric_notes(notes, false); + if (status.is_bad()) { + CUDNN_FE_LOG(status.get_message() << std::endl); + } + return *this; + } + + error_t + get_behavior_notes_for_plan_at_index(int64_t const index, std::vector ¬es) const; + + error_t + get_behavior_notes(std::vector ¬es) const; + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json &j) const override final { + // Different from serialization of other INodes. + // Go over each subnode and serialize them. + json full_json; + + full_json["context"]["name"] = context.get_name(); + full_json["context"]["compute_data_type"] = context.get_compute_data_type(); + full_json["context"]["intermediate_data_type"] = context.get_intermediate_data_type(); + full_json["context"]["io_data_type"] = context.get_io_data_type(); + full_json["context"]["sm_count"] = context.get_target_sm_count(); + full_json["context"]["is_dynamic_shape_enabled"] = context.get_dynamic_shape_enabled(); + + full_json.update(R"( {"tag": "GRAPH"})"_json); + full_json["nodes"]; + for (auto const &sub_node : sub_nodes) { + json j_sub_node; + sub_node->serialize(j_sub_node); + full_json["nodes"].push_back(j_sub_node); + } + + j["context"] = full_json["context"]; + + j["json_version"] = "1.0"; + j["cudnn_backend_version"] = detail::get_backend_version_string(); + j["cudnn_frontend_version"] = CUDNN_FRONTEND_VERSION; + j["nodes"]; + j["tensors"]; + std::unordered_set tensors; + for (const auto &sub_node : full_json["nodes"]) { + // Create a short version of the node + auto short_node = sub_node; + short_node["inputs"] = {}; + short_node["outputs"] = {}; + + auto node_name = sub_node["tag"].get(); + auto i = 0; + // Process node inputs + for (const auto &input : sub_node["inputs"]) { + std::string port_name; + json 
tensor_info; + + if (node_name == "CONCATENATE") { + // Extract port_name and tensor_name + port_name = std::to_string(i); + tensor_info = input; + i++; + } else { + // Extract port_name and tensor_name + port_name = input[0].get(); + tensor_info = input[1]; + } + + if (tensor_info.is_null()) { + continue; + } + + std::string tensor_name = tensor_info["name"].get(); + // Update short_node inputs + short_node["inputs"][port_name] = tensor_name; + + // Check if the tensor is already in the tensors map + if (tensors.find(tensor_name) == tensors.end()) { + // If not, add it to the j["tensors"] + j["tensors"][tensor_name] = tensor_info; + } + } + + // Process node outputs + for (const auto &output : sub_node["outputs"]) { + // Extract port_name and tensor_name + auto port_name = output[0].get(); + auto tensor_info = output[1]; + + if (tensor_info.is_null()) { + continue; + } + + std::string tensor_name = tensor_info["name"].get(); + + // Update short_node outputs + short_node["outputs"][port_name] = tensor_name; + + // Check if the tensor is already in the tensors map + if (tensors.find(tensor_name) == tensors.end()) { + // If not, add it to the j["tensors"] + j["tensors"][tensor_name] = tensor_info; + } + } + + // Add the short_node to j["nodes"] + j["nodes"].push_back(short_node); + } + }; +#endif + + size_t + key() override final { + return key(context.get_dynamic_shape_enabled()); + } + + // TODO: temparorily placed in graphs class. This function needs to be a free standing function. 
+#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + error_t + deserialize(const json &j) { + if (j.contains("context")) { + const auto &j_context = j["context"]; + if (j_context.contains("compute_data_type") && !j_context["compute_data_type"].is_null()) { + context.set_compute_data_type(j_context["compute_data_type"].get()); + } + if (j_context.contains("intermediate_data_type") && !j_context["intermediate_data_type"].is_null()) { + context.set_intermediate_data_type(j_context["intermediate_data_type"].get()); + } + if (j_context.contains("io_data_type") && !j_context["io_data_type"].is_null()) { + context.set_io_data_type(j_context["io_data_type"].get()); + } + if (j_context.contains("name") && !j_context["name"].is_null()) { + context.set_name(j_context["name"].get()); + } + if (j_context.contains("sm_count") && !j_context["sm_count"].is_null()) { + context.set_target_sm_count(j_context["sm_count"].get()); + } + if (j_context.contains("is_dynamic_shape_enabled") && !j_context["is_dynamic_shape_enabled"].is_null()) { + context.set_dynamic_shape_enabled(j_context["is_dynamic_shape_enabled"].get()); + } + } + + std::map> created_tensors; + // Iterate through each sub-node in the full JSON + if (j.contains("nodes") && j["nodes"].is_array()) { + for (auto j_sub_node : j["nodes"]) { + // Create a JSON object for inputs + json inputs; + + // Iterate through each input of the sub-node + if (j_sub_node.contains("inputs") && j_sub_node["inputs"].is_object()) { + for (auto &[port_name, tensor_name] : j_sub_node["inputs"].items()) { + if (j.contains("tensors") && j["tensors"].contains(tensor_name)) { + // Add the input to the inputs JSON object + inputs.push_back({port_name, j["tensors"][tensor_name]}); + } + } + } + + // Create a JSON object for outputs + json outputs; + + // Iterate through each output of the sub-node + if (j_sub_node.contains("outputs") && j_sub_node["outputs"].is_object()) { + for (auto &[port_name, tensor_name] : j_sub_node["outputs"].items()) { + if 
(j.contains("tensors") && j["tensors"].contains(tensor_name)) { + // Add the output to the outputs JSON object + outputs.push_back({port_name, j["tensors"][tensor_name]}); + } + } + } + + // Replace the original inputs and outputs of the sub-node with the new JSON objects + j_sub_node["inputs"] = inputs; + j_sub_node["outputs"] = outputs; + + auto check_if_pre_created_tensor = [&created_tensors](std::shared_ptr t) { + if (t == nullptr) { + return t; + } + + if (created_tensors.find(t->get_name()) == created_tensors.end()) { + created_tensors.insert({t->get_name(), t}); + return t; + } else { + return created_tensors[t->get_name()]; + } + }; + +#define CHECK_TENSORS(attributes) \ + for (const auto &[key, tensor] : attributes.inputs) { \ + attributes.inputs[key] = check_if_pre_created_tensor(tensor); \ + } \ + for (const auto &[key, tensor] : attributes.outputs) { \ + attributes.outputs[key] = check_if_pre_created_tensor(tensor); \ + } + +#define FILL_GLOBAL_IO_TENSOR_MAP(attributes) \ + for (auto input_name_to_attr_pair : attributes.inputs) { \ + if (input_name_to_attr_pair.second != nullptr && \ + (input_name_to_attr_pair.second->get_is_virtual() == false)) { \ + full_graph_inputs.emplace(input_name_to_attr_pair.second); \ + } \ + } \ + for (auto output_name_to_attr_pair : attributes.outputs) { \ + if (output_name_to_attr_pair.second != nullptr) { \ + full_graph_outputs.emplace(output_name_to_attr_pair.second); \ + } \ + } + if (j_sub_node.contains("tag") && j_sub_node["tag"].is_string()) { + auto tag = j_sub_node["tag"].get(); + if (tag == "CONV_FPROP") { + auto conv_fprop_attributes = j_sub_node.get(); + CHECK_TENSORS(conv_fprop_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(conv_fprop_attributes); + sub_nodes.emplace_back( + std::make_unique(std::move(conv_fprop_attributes), context)); + } else if (tag == "POINTWISE") { + auto pointwise_attributes = j_sub_node.get(); + CHECK_TENSORS(pointwise_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(pointwise_attributes); + 
sub_nodes.emplace_back( + std::make_unique(std::move(pointwise_attributes), context)); + } else if (tag == "REDUCTION") { + auto reduction_attributes = j_sub_node.get(); + CHECK_TENSORS(reduction_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(reduction_attributes); + sub_nodes.emplace_back( + std::make_unique(std::move(reduction_attributes), context)); + } else if (tag == "SDPA_FWD") { + auto sdpa_attributes = j_sub_node.get(); + CHECK_TENSORS(sdpa_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(sdpa_attributes); + switch (sdpa_attributes.implementation) { + case AttentionImplementation_t::AUTO: + return {error_code_t::INVALID_VALUE, + "Implementation cannot be AUTO in serialized form"}; + case AttentionImplementation_t::COMPOSITE: + sub_nodes.emplace_back( + std::make_unique(std::move(sdpa_attributes), context)); + break; + case AttentionImplementation_t::UNIFIED: + sub_nodes.emplace_back( + std::make_unique(std::move(sdpa_attributes), context)); + } + } else if (tag == "SDPA_BWD") { + auto sdpa_bwd_attributes = j_sub_node.get(); + CHECK_TENSORS(sdpa_bwd_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(sdpa_bwd_attributes); + sub_nodes.emplace_back( + std::make_unique(std::move(sdpa_bwd_attributes), context)); + } else if (tag == "MATMUL") { + auto matmul_attributes = j_sub_node.get(); + CHECK_TENSORS(matmul_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(matmul_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(matmul_attributes), context)); + } else if (tag == "SLICE") { + auto slice_attributes = j_sub_node.get(); + CHECK_TENSORS(slice_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(slice_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(slice_attributes), context)); + } else if (tag == "RESAMPLE") { + auto resample_attributes = j_sub_node.get(); + CHECK_TENSORS(resample_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(resample_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(resample_attributes), context)); + } else if (tag == "CONV_DGRAD") { + auto 
dgrad_attributes = j_sub_node.get(); + CHECK_TENSORS(dgrad_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(dgrad_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(dgrad_attributes), context)); + } else if (tag == "CONV_WGRAD") { + auto wgrad_attributes = j_sub_node.get(); + CHECK_TENSORS(wgrad_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(wgrad_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(wgrad_attributes), context)); + } else if (tag == "MOE_GROUPED_MATMUL") { + auto moe_grouped_matmul_attributes = j_sub_node.get(); + CHECK_TENSORS(moe_grouped_matmul_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(moe_grouped_matmul_attributes); + sub_nodes.emplace_back( + std::make_unique(std::move(moe_grouped_matmul_attributes), context)); + } + } +#undef CHECK_TENSORS + } + } + + return {error_code_t::OK, ""}; + } +#endif + + std::string + print(void) const { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + std::stringstream ss; + json j = *this; + ss << j; + return ss.str(); +#else + return "print is unavailable when compiled with CUDNN_FRONTEND_SKIP_JSON_LIB"; +#endif + } +}; + +inline error_t +Graph::get_behavior_notes_for_plan_at_index(int64_t const index, std::vector ¬es) const { + CHECK_CUDNN_FRONTEND_ERROR(plans.get_behavior_notes_at_index(index, notes)); + return {error_code_t::OK, ""}; +} + +inline error_t +Graph::get_behavior_notes(std::vector ¬es) const { + int64_t const candidate = plans.candidate; + RETURN_CUDNN_FRONTEND_ERROR_IF( + candidate == -1, + error_code_t::INVALID_VALUE, + "No candiate plan set for the graph. You can set one by building a plan, which in turn sets the " + "candidate internally. 
Do note that you also query behaviour notes for a created-but-not-built plan by using " + "get_behavior_notes_for_plan_at_index API."); + + CHECK_CUDNN_FRONTEND_ERROR(get_behavior_notes_for_plan_at_index(candidate, notes)); + return {error_code_t::OK, ""}; +} + +inline int64_t +Graph::get_execution_plan_count() const { + return plans.execution_plans.size(); +} + +inline error_t +Graph::get_engine_count(int64_t &count) { + _CUDNN_CHECK_CUDNN_ERROR(detail::get_attribute(operation_graph->get_raw_desc(), + CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT, + CUDNN_TYPE_INT64, + 1, + nullptr, + &count)); + + return {error_code_t::OK, ""}; +} + +inline error_t +Graph::get_knobs_for_engine(int64_t const engine, std::vector &knobs) { + CHECK_CUDNN_FRONTEND_ERROR(detail::query_knobs(engine, operation_graph->get_raw_desc(), knobs)); + + return {error_code_t::OK, ""}; +} + +inline error_t +Graph::create_execution_plans(std::vector const &mode) { + CUDNN_FE_LOG_BANNER(" CREATE EXECUTION PLANS (HEURISTICS QUERY) "); + + // CHECK IF NEED TO OVERRIDE HEURISTICS QUERY + for (auto &sub_node : sub_nodes) { + if (auto [engine_id, user_knobs] = sub_node->override_heuristics_query(); engine_id != -1) { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + CUDNN_FE_LOG_LABEL_ENDL("INFO: Overriding heuristics query with engine ID " + << engine_id << " and user knobs " << nlohmann::json(user_knobs).dump()); +#else + CUDNN_FE_LOG_LABEL_ENDL("INFO: Overriding heuristics query with engine ID " + << engine_id << " and user knobs " << static_cast(user_knobs.size())); +#endif + CHECK_CUDNN_FRONTEND_ERROR(create_execution_plan(engine_id, user_knobs)); + return {error_code_t::OK, ""}; + } + } + + EngineConfigList op_graph_to_configs; + CHECK_CUDNN_FRONTEND_ERROR(detail::query_cudnn_heuristics_impl( + operation_graph, op_graph_to_configs, mode, context.get_target_sm_count(), device_properties)); + + CUDNN_FE_LOG_LABEL_ENDL("INFO: Extracting engine configs."); + + plans.set_tag(operation_graph->getTag()); + 
plans.enqueue_engine_configs(op_graph_to_configs); + plans.set_kernel_cache(kernel_cache); + + CUDNN_FE_LOG_LABEL_ENDL("INFO: Querying engine config properties."); + CHECK_CUDNN_FRONTEND_ERROR(plans.query_properties()); + + CUDNN_FE_LOG_BANNER(" HEURISTICS QUERY ALL OK "); + return {error_code_t::OK, ""}; +} + +inline error_t +Graph::create_execution_plan(int64_t const engine_id, std::unordered_map const &user_knobs) { + // first create the engine + // this just uses the global engine id and operation graph + CUDNN_FE_LOG_BANNER(" CREATE EXECUTION PLAN for engine id " << engine_id << " "); + detail::backend_descriptor engine(CUDNN_BACKEND_ENGINE_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(engine.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create engine's backend descriptor."); + CHECK_CUDNN_FRONTEND_ERROR( + detail::create_engine(engine, engine_id, operation_graph->get_raw_desc(), device_properties)); + + // Create an array of knob choices + std::vector knob_choices; + CHECK_CUDNN_FRONTEND_ERROR(detail::set_knob_choices(user_knobs, knob_choices)); + + auto engine_config = make_shared_backend_pointer((cudnnBackendDescriptorType_t)CUDNN_BACKEND_ENGINECFG_DESCRIPTOR); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_engine_config(engine_config, engine, knob_choices)); + plans.enqueue_engine_configs({engine_config}); + CHECK_CUDNN_FRONTEND_ERROR(plans.query_properties()); + + CUDNN_FE_LOG_BANNER(" CREATE EXECUTION PLAN ALL OK "); + + return {error_code_t::OK, ""}; +} + +inline error_t +Graph::build_plan_at_index(int64_t plan_index) { + CHECK_CUDNN_FRONTEND_ERROR(plans.build_plan_at_index(plan_index)); + return {error_code_t::OK, ""}; +} + +inline error_t +Graph::build_plans(BuildPlanPolicy_t const policy, bool const do_multithreaded_builds) { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + CUDNN_FE_LOG_BANNER(" BUILD PLANS for policy " << nlohmann::json(policy).dump() << " "); +#else + CUDNN_FE_LOG_BANNER(" BUILD PLANS for policy " 
<< static_cast(policy) << " "); +#endif + CHECK_CUDNN_FRONTEND_ERROR(plans.build_plans(policy, do_multithreaded_builds)); + CUDNN_FE_LOG_BANNER(" BUILD PLANS ALL OK "); + return {error_code_t::OK, ""}; +} + +inline error_t +Graph::build(cudnnHandle_t const &handle, + std::vector const &modes, + BuildPlanPolicy_t const policy, + bool const do_multithreaded_builds) { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + CUDNN_FE_LOG_BANNER(" BUILD with handle " << nlohmann::json(policy).dump()); +#else + CUDNN_FE_LOG_BANNER(" BUILD with handle " << static_cast(policy) << " "); +#endif + CHECK_CUDNN_FRONTEND_ERROR(this->validate()); + CHECK_CUDNN_FRONTEND_ERROR(this->build_operation_graph(handle)); + CHECK_CUDNN_FRONTEND_ERROR(this->create_execution_plans(modes)); + CHECK_CUDNN_FRONTEND_ERROR(this->check_support()); + CHECK_CUDNN_FRONTEND_ERROR(this->build_plans(policy, do_multithreaded_builds)); + CUDNN_FE_LOG_BANNER(" BUILD ALL OK (with handle) "); + return {error_code_t::OK, ""}; +} + +inline error_t +Graph::build(std::vector const &modes, BuildPlanPolicy_t const policy, bool const do_multithreaded_builds) { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + CUDNN_FE_LOG_BANNER(" BUILD PLANS without handle " << nlohmann::json(policy).dump() << " "); +#else + CUDNN_FE_LOG_BANNER(" BUILD PLANS without handle " << static_cast(policy) << " "); +#endif + CHECK_CUDNN_FRONTEND_ERROR(this->validate()); + CHECK_CUDNN_FRONTEND_ERROR(this->build_operation_graph()); + CHECK_CUDNN_FRONTEND_ERROR(this->create_execution_plans(modes)); + CHECK_CUDNN_FRONTEND_ERROR(this->check_support()); + CHECK_CUDNN_FRONTEND_ERROR(this->build_plans(policy, do_multithreaded_builds)); + CUDNN_FE_LOG_BANNER(" BUILD PLANS ALL OK (no handle) "); + return {error_code_t::OK, ""}; +} + +inline Graph & +Graph::set_intermediate_data_type(DataType_t const type) { + context.set_intermediate_data_type(type); + return *this; +} + +inline Graph & +Graph::set_io_data_type(DataType_t const type) { + context.set_io_data_type(type); + 
return *this; +} + +inline Graph & +Graph::set_compute_data_type(DataType_t const type) { + context.set_compute_data_type(type); + return *this; +} + +inline Graph & +Graph::set_dynamic_shape_enabled(bool is_enabled) { + context.set_dynamic_shape_enabled(is_enabled); + this->is_dynamic_shape_enabled = is_enabled; + return *this; +} + +inline Graph & +Graph::set_kernel_cache(std::shared_ptr cache) { + kernel_cache = cache; + return *this; +} + +inline Graph & +Graph::set_device_properties(std::shared_ptr device_prop) { + device_properties = device_prop; + return *this; +} + +inline Graph & +Graph::set_sm_count(int32_t count) { + context.set_target_sm_count(count); + return *this; +} + +inline Graph & +Graph::set_sm_version(int32_t version) { + context.set_sm_version(version); + return *this; +} + +inline std::shared_ptr +Graph::tensor(Tensor_attributes const &tensor) { + auto tensor_ptr = std::make_shared(tensor); + full_graph_inputs.emplace(tensor_ptr); + return tensor_ptr; +} + +inline error_t +Graph::query_tensor_attributes_of_uid(int64_t const uid, Tensor_attributes &tensor) const { + for (auto const &o_tensor : full_graph_outputs) { + if (uid == o_tensor->get_uid()) { + tensor = *o_tensor; + return {error_code_t::OK, ""}; + } + } + + for (auto const &i_tensor : full_graph_inputs) { + if (uid == i_tensor->get_uid()) { + tensor = *i_tensor; + return {error_code_t::OK, ""}; + } + } + + for (auto const &d_tensor : deserialized_tensor_properties) { + if (uid == d_tensor->get_uid()) { + tensor = *d_tensor; + return {error_code_t::OK, ""}; + } + } + + return {error_code_t::INVALID_VALUE, "No matching tensor for this UID"}; +} + +// tensor_like is meant to create "useable" copies of a tensor. +// By usable, it means not copying over the uids, as uids are FE-level(internal) detail. +// It also means not copying over names, which are user-level(external) detail. But user is given option to provide a +// new name. 
+inline std::shared_ptr +Graph::tensor_like(std::shared_ptr const &tensor, std::string const &name) { + auto tensor_ptr = std::make_shared(*tensor); + + // reset the uid of the cloned tensor + // uids are not meant to be copied by tensor_like + // When lowering to cudnn backend, both tensors involved here will get unique uids. + tensor_ptr->clear_uid(); + + // reset the name too. Defaults to empty string. + tensor_ptr->set_name(name); + full_graph_inputs.emplace(tensor_ptr); + + return tensor_ptr; +} + +inline std::array, 6> +Graph::bn_finalize(std::shared_ptr sum, + std::shared_ptr sq_sum, + std::shared_ptr scale, + std::shared_ptr bias, + std::shared_ptr epsilon, + std::shared_ptr accum_count, + BN_finalize_attributes attributes) { + // Set outputs + auto EQ_SCALE = attributes.outputs[BN_finalize_attributes::output_names::EQ_SCALE] = + output_tensor(attributes.name + "::EQ_SCALE"); + auto EQ_BIAS = attributes.outputs[BN_finalize_attributes::output_names::EQ_BIAS] = + output_tensor(attributes.name + "::EQ_BIAS"); + auto MEAN = attributes.outputs[BN_finalize_attributes::output_names::MEAN] = + output_tensor(attributes.name + "::MEAN"); + auto INV_VARIANCE = attributes.outputs[BN_finalize_attributes::output_names::INV_VARIANCE] = + output_tensor(attributes.name + "::INV_VARIANCE"); + std::shared_ptr NEXT_RUNNING_MEAN = nullptr; + std::shared_ptr NEXT_RUNNING_VAR = nullptr; + if (attributes.inputs[BN_finalize_attributes::input_names::PREV_RUNNING_MEAN] && + attributes.inputs[BN_finalize_attributes::input_names::PREV_RUNNING_VAR] && + attributes.inputs[BN_finalize_attributes::input_names::MOMENTUM]) { + NEXT_RUNNING_MEAN = output_tensor(attributes.name + "::NEXT_RUNNING_MEAN"); + NEXT_RUNNING_VAR = output_tensor(attributes.name + "::NEXT_RUNNING_VAR"); + } + attributes.outputs[BN_finalize_attributes::output_names::NEXT_RUNNING_MEAN] = NEXT_RUNNING_MEAN; + attributes.outputs[BN_finalize_attributes::output_names::NEXT_RUNNING_VAR] = NEXT_RUNNING_VAR; + + // Set inputs + 
attributes.inputs[BN_finalize_attributes::input_names::SUM] = sum; + attributes.inputs[BN_finalize_attributes::input_names::SQ_SUM] = sq_sum; + attributes.inputs[BN_finalize_attributes::input_names::SCALE] = scale; + attributes.inputs[BN_finalize_attributes::input_names::BIAS] = bias; + attributes.inputs[BN_finalize_attributes::input_names::EPSILON] = epsilon; + attributes.inputs[BN_finalize_attributes::input_names::ACCUM_COUNT] = accum_count; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {EQ_SCALE, EQ_BIAS, MEAN, INV_VARIANCE, NEXT_RUNNING_MEAN, NEXT_RUNNING_VAR}; +} + +inline std::array, 3> +Graph::layernorm(std::shared_ptr x, + std::shared_ptr scale, + std::shared_ptr bias, + Layernorm_attributes attributes) { + // Set outputs + auto Y = attributes.outputs[Layernorm_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + std::shared_ptr MEAN = nullptr; + std::shared_ptr INV_VARIANCE = nullptr; + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + MEAN = attributes.outputs[Layernorm_attributes::output_names::MEAN] = output_tensor(attributes.name + "::MEAN"); + INV_VARIANCE = attributes.outputs[Layernorm_attributes::output_names::INV_VARIANCE] = + output_tensor(attributes.name + "::INV_VARIANCE"); + } + // Set inputs + attributes.inputs[Layernorm_attributes::input_names::X] = x; + attributes.inputs[Layernorm_attributes::input_names::SCALE] = scale; + attributes.inputs[Layernorm_attributes::input_names::BIAS] = bias; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {Y, MEAN, INV_VARIANCE}; +} + +inline std::array, 3> +Graph::adalayernorm(std::shared_ptr x, + std::shared_ptr scale, + std::shared_ptr bias, + AdaLayernorm_attributes attributes) { + // Set outputs + auto Y = attributes.outputs[AdaLayernorm_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + std::shared_ptr MEAN = nullptr; + std::shared_ptr INV_VARIANCE = nullptr; + if 
(attributes.forward_phase == NormFwdPhase_t::TRAINING) { + MEAN = attributes.outputs[AdaLayernorm_attributes::output_names::MEAN] = + output_tensor(attributes.name + "::MEAN"); + INV_VARIANCE = attributes.outputs[AdaLayernorm_attributes::output_names::INV_VARIANCE] = + output_tensor(attributes.name + "::INV_VARIANCE"); + } + // Set inputs + attributes.inputs[AdaLayernorm_attributes::input_names::X] = x; + attributes.inputs[AdaLayernorm_attributes::input_names::SCALE] = scale; + attributes.inputs[AdaLayernorm_attributes::input_names::BIAS] = bias; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {std::move(Y), std::move(MEAN), std::move(INV_VARIANCE)}; +} + +inline std::array, 3> +Graph::instancenorm(std::shared_ptr x, + std::shared_ptr scale, + std::shared_ptr bias, + Instancenorm_attributes attributes) { + // Set outputs + auto Y = attributes.outputs[Instancenorm_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + std::shared_ptr MEAN = nullptr; + std::shared_ptr INV_VARIANCE = nullptr; + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + MEAN = attributes.outputs[Instancenorm_attributes::output_names::MEAN] = + output_tensor(attributes.name + "::MEAN"); + INV_VARIANCE = attributes.outputs[Instancenorm_attributes::output_names::INV_VARIANCE] = + output_tensor(attributes.name + "::INV_VARIANCE"); + } + // Set inputs + attributes.inputs[Instancenorm_attributes::input_names::X] = x; + attributes.inputs[Instancenorm_attributes::input_names::SCALE] = scale; + attributes.inputs[Instancenorm_attributes::input_names::BIAS] = bias; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {Y, MEAN, INV_VARIANCE}; +} + +inline std::array, 5> +Graph::batchnorm(std::shared_ptr x, + std::shared_ptr scale, + std::shared_ptr bias, + Batchnorm_attributes attributes) { + // Set outputs + auto Y = attributes.outputs[Batchnorm_attributes::output_names::Y] = 
output_tensor(attributes.name + "::Y"); + auto MEAN = attributes.outputs[Batchnorm_attributes::output_names::MEAN] = + output_tensor(attributes.name + "::MEAN"); + auto INV_VARIANCE = attributes.outputs[Batchnorm_attributes::output_names::INV_VARIANCE] = + output_tensor(attributes.name + "::INV_VARIANCE"); + std::shared_ptr NEXT_RUNNING_MEAN = nullptr; + std::shared_ptr NEXT_RUNNING_VAR = nullptr; + if (attributes.inputs[Batchnorm_attributes::input_names::PREV_RUNNING_MEAN] && + attributes.inputs[Batchnorm_attributes::input_names::PREV_RUNNING_VAR] && + attributes.inputs[Batchnorm_attributes::input_names::MOMENTUM]) { + NEXT_RUNNING_MEAN = output_tensor(attributes.name + "::NEXT_RUNNING_MEAN"); + NEXT_RUNNING_VAR = output_tensor(attributes.name + "::NEXT_RUNNING_VAR"); + } + attributes.outputs[Batchnorm_attributes::output_names::NEXT_RUNNING_MEAN] = NEXT_RUNNING_MEAN; + attributes.outputs[Batchnorm_attributes::output_names::NEXT_RUNNING_VAR] = NEXT_RUNNING_VAR; + + // Set inputs + attributes.inputs[Batchnorm_attributes::input_names::X] = x; + attributes.inputs[Batchnorm_attributes::input_names::SCALE] = scale; + attributes.inputs[Batchnorm_attributes::input_names::BIAS] = bias; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {Y, MEAN, INV_VARIANCE, NEXT_RUNNING_MEAN, NEXT_RUNNING_VAR}; +} + +inline std::shared_ptr +Graph::batchnorm_inference(std::shared_ptr x, + std::shared_ptr mean, + std::shared_ptr inv_variance, + std::shared_ptr scale, + std::shared_ptr bias, + Batchnorm_inference_attributes attributes) { + // Set outputs + auto Y = attributes.outputs[Batchnorm_inference_attributes::output_names::Y] = + output_tensor(attributes.name + "::Y"); + + // Set inputs + attributes.inputs[Batchnorm_inference_attributes::input_names::X] = x; + attributes.inputs[Batchnorm_inference_attributes::input_names::MEAN] = mean; + attributes.inputs[Batchnorm_inference_attributes::input_names::INV_VARIANCE] = inv_variance; + 
attributes.inputs[Batchnorm_inference_attributes::input_names::SCALE] = scale; + attributes.inputs[Batchnorm_inference_attributes::input_names::BIAS] = bias; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return Y; +} + +inline std::array, 3> +Graph::batchnorm_backward(std::shared_ptr dy, + std::shared_ptr x, + std::shared_ptr scale, + Batchnorm_backward_attributes attributes) { + // Set outputs + auto DX = attributes.outputs[Batchnorm_backward_attributes::output_names::DX] = + output_tensor(attributes.name + "::DX"); + auto DSCALE = attributes.outputs[Batchnorm_backward_attributes::output_names::DSCALE] = + output_tensor(attributes.name + "::DSCALE"); + auto DBIAS = attributes.outputs[Batchnorm_backward_attributes::output_names::DBIAS] = + output_tensor(attributes.name + "::DBIAS"); + + // Set inputs + attributes.inputs[Batchnorm_backward_attributes::input_names::DY] = dy; + attributes.inputs[Batchnorm_backward_attributes::input_names::X] = x; + attributes.inputs[Batchnorm_backward_attributes::input_names::SCALE] = scale; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {DX, DSCALE, DBIAS}; +} + +inline std::array, 3> +Graph::instancenorm_backward(std::shared_ptr dy, + std::shared_ptr x, + std::shared_ptr scale, + Instancenorm_backward_attributes attributes) { + // Set outputs + auto DX = attributes.outputs[Instancenorm_backward_attributes::output_names::DX] = + output_tensor(attributes.name + "::DX"); + auto DSCALE = attributes.outputs[Instancenorm_backward_attributes::output_names::DSCALE] = + output_tensor(attributes.name + "::DSCALE"); + auto DBIAS = attributes.outputs[Instancenorm_backward_attributes::output_names::DBIAS] = + output_tensor(attributes.name + "::DBIAS"); + + // Set inputs + attributes.inputs[Instancenorm_backward_attributes::input_names::DY] = dy; + attributes.inputs[Instancenorm_backward_attributes::input_names::X] = x; + 
attributes.inputs[Instancenorm_backward_attributes::input_names::SCALE] = scale; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {DX, DSCALE, DBIAS}; +} + +inline std::array, 3> +Graph::layernorm_backward(std::shared_ptr dy, + std::shared_ptr x, + std::shared_ptr scale, + Layernorm_backward_attributes attributes) { + // Set outputs + auto DX = attributes.outputs[Layernorm_backward_attributes::output_names::DX] = + output_tensor(attributes.name + "::DX"); + auto DSCALE = attributes.outputs[Layernorm_backward_attributes::output_names::DSCALE] = + output_tensor(attributes.name + "::DSCALE"); + auto DBIAS = attributes.outputs[Layernorm_backward_attributes::output_names::DBIAS] = + output_tensor(attributes.name + "::DBIAS"); + + // Set inputs + attributes.inputs[Layernorm_backward_attributes::input_names::DY] = dy; + attributes.inputs[Layernorm_backward_attributes::input_names::X] = x; + attributes.inputs[Layernorm_backward_attributes::input_names::SCALE] = scale; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {DX, DSCALE, DBIAS}; +} + +inline std::array, 3> +Graph::adalayernorm_backward(std::shared_ptr dy, + std::shared_ptr x, + std::shared_ptr scale, + AdaLayernorm_backward_attributes attributes) { + // Set outputs + auto DX = attributes.outputs[AdaLayernorm_backward_attributes::output_names::DX] = + output_tensor(attributes.name + "::DX"); + auto DSCALE = attributes.outputs[AdaLayernorm_backward_attributes::output_names::DSCALE] = + output_tensor(attributes.name + "::DSCALE"); + auto DBIAS = attributes.outputs[AdaLayernorm_backward_attributes::output_names::DBIAS] = + output_tensor(attributes.name + "::DBIAS"); + // Set inputs + attributes.inputs[AdaLayernorm_backward_attributes::input_names::DY] = dy; + attributes.inputs[AdaLayernorm_backward_attributes::input_names::X] = x; + attributes.inputs[AdaLayernorm_backward_attributes::input_names::SCALE] = scale; + + 
sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {std::move(DX), std::move(DSCALE), std::move(DBIAS)}; +} + +inline std::shared_ptr +Graph::conv_fprop(std::shared_ptr x, + std::shared_ptr w, + Conv_fprop_attributes attributes) { + // Make required output tensors + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } + auto Y = output_tensor(attributes.name + "::Y"); + attributes.outputs[Conv_fprop_attributes::output_names::Y] = Y; + + // Set inputs + attributes.inputs[Conv_fprop_attributes::input_names::X] = x; + attributes.inputs[Conv_fprop_attributes::input_names::W] = w; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return Y; +} + +inline std::array, 5> +Graph::dbn_weight(std::shared_ptr dy, + std::shared_ptr x, + std::shared_ptr mean, + std::shared_ptr inv_variance, + std::shared_ptr scale, + DBN_weight_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } + // Make required output tensors + auto DBIAS = attributes.outputs[DBN_weight_attributes::output_names::DBIAS] = + output_tensor(attributes.name + "::DBIAS"); + auto DSCALE = attributes.outputs[DBN_weight_attributes::output_names::DSCALE] = + output_tensor(attributes.name + "::DSCALE"); + auto EQ_BIAS = attributes.outputs[DBN_weight_attributes::output_names::EQ_BIAS] = + output_tensor(attributes.name + "::EQ_BIAS"); + auto EQ_SCALE_DY = attributes.outputs[DBN_weight_attributes::output_names::EQ_SCALE_DY] = + output_tensor(attributes.name + "::EQ_SCALE_DY"); + auto EQ_SCALE_X = attributes.outputs[DBN_weight_attributes::output_names::EQ_SCALE_X] = + output_tensor(attributes.name + "::EQ_SCALE_X"); + + // Set inputs + attributes.inputs[DBN_weight_attributes::input_names::DY] = dy; + attributes.inputs[DBN_weight_attributes::input_names::X] = x; + attributes.inputs[DBN_weight_attributes::input_names::SCALE] = scale; + 
attributes.inputs[DBN_weight_attributes::input_names::MEAN] = mean; + attributes.inputs[DBN_weight_attributes::input_names::INV_VARIANCE] = inv_variance; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {DSCALE, DBIAS, EQ_SCALE_DY, EQ_SCALE_X, EQ_BIAS}; +} + +inline std::shared_ptr +Graph::conv_dgrad(std::shared_ptr dy, + std::shared_ptr w, + Conv_dgrad_attributes attributes) { + // Make required output tensors + auto DX = attributes.outputs[Conv_dgrad_attributes::output_names::DX] = output_tensor(attributes.name + "::DX"); + + // Set inputs + attributes.inputs[Conv_dgrad_attributes::input_names::DY] = dy; + attributes.inputs[Conv_dgrad_attributes::input_names::W] = w; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return DX; +} + +inline std::array, 2> +Graph::genstats(std::shared_ptr x, Genstats_attributes attributes) { + // Set outputs + auto SUM = attributes.outputs[Genstats_attributes::output_names::SUM] = + output_tensor(attributes.name + "_sum_output"); + auto SQ_SUM = attributes.outputs[Genstats_attributes::output_names::SQ_SUM] = + output_tensor(attributes.name + "_sq_sum_output"); + + // Set inputs + attributes.inputs[Genstats_attributes::input_names::X] = x; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {SUM, SQ_SUM}; +} + +inline std::shared_ptr +Graph::conv_wgrad(std::shared_ptr dy, + std::shared_ptr x, + Conv_wgrad_attributes attributes) { + // Make required output tensors + auto DW = attributes.outputs[Conv_wgrad_attributes::output_names::DW] = output_tensor(attributes.name + "::DW"); + + // Set inputs + attributes.inputs[Conv_wgrad_attributes::input_names::X] = x; + attributes.inputs[Conv_wgrad_attributes::input_names::DY] = dy; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return DW; +} + +inline std::array, 2> +Graph::rmsnorm(std::shared_ptr x, + std::shared_ptr scale, + Rmsnorm_attributes 
attributes) { + // Set outputs + auto Y = attributes.outputs[Rmsnorm_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + std::shared_ptr INV_VARIANCE = nullptr; + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + INV_VARIANCE = attributes.outputs[Rmsnorm_attributes::output_names::INV_VARIANCE] = + output_tensor(attributes.name + "::INV_VARIANCE"); + } + // Set inputs + attributes.inputs[Rmsnorm_attributes::input_names::X] = x; + attributes.inputs[Rmsnorm_attributes::input_names::SCALE] = scale; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {Y, INV_VARIANCE}; +} + +inline std::array, 3> +Graph::rmsnorm_backward(std::shared_ptr dy, + std::shared_ptr x, + std::shared_ptr scale, + std::shared_ptr inv_variance, + Rmsnorm_backward_attributes attributes) { + // Set outputs + auto DX = attributes.outputs[Rmsnorm_backward_attributes::output_names::DX] = + output_tensor(attributes.name + "::DX"); + auto DScale = attributes.outputs[Rmsnorm_backward_attributes::output_names::DSCALE] = + output_tensor(attributes.name + "::Dscale"); + std::shared_ptr DBias = nullptr; + if (attributes.use_dbias.value_or(true)) { + DBias = attributes.outputs[Rmsnorm_backward_attributes::output_names::DBIAS] = + output_tensor(attributes.name + "::Dbias"); + } + + // Set inputs + attributes.inputs[Rmsnorm_backward_attributes::input_names::DY] = dy; + attributes.inputs[Rmsnorm_backward_attributes::input_names::X] = x; + attributes.inputs[Rmsnorm_backward_attributes::input_names::SCALE] = scale; + attributes.inputs[Rmsnorm_backward_attributes::input_names::INV_VARIANCE] = inv_variance; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {DX, DScale, DBias}; +} + +inline std::array, 2> +Graph::sdpa(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + SDPA_attributes attributes) { + //优化性能 + if (attributes.mma_core_mode == DataType_t::NOT_SET) { + 
attributes._set_mma_core_mode(DataType_t::HALF); + } + + // Call internal implementation and return only the O and Stats outputs for backward compatibility + auto internal_result = sdpa_internal(q, k, v, std::move(attributes)); + return {internal_result.O, internal_result.Stats}; +} + +inline std::array, 4> +Graph::sdpa_fp8(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr descale_q, + std::shared_ptr descale_k, + std::shared_ptr descale_v, + std::shared_ptr descale_s, + std::shared_ptr scale_s, + std::shared_ptr scale_o, + SDPA_fp8_attributes attributes) { + if (attributes.mma_core_mode == DataType_t::NOT_SET) { + attributes._set_mma_core_mode(DataType_t::FP8_E4M3); + } + + // Set FP8 scaling inputs + attributes.inputs[SDPA_fp8_attributes::input_names::Descale_Q] = descale_q; + attributes.inputs[SDPA_fp8_attributes::input_names::Descale_K] = descale_k; + attributes.inputs[SDPA_fp8_attributes::input_names::Descale_V] = descale_v; + attributes.inputs[SDPA_fp8_attributes::input_names::Descale_S] = descale_s; + attributes.inputs[SDPA_fp8_attributes::input_names::Scale_S] = scale_s; + attributes.inputs[SDPA_fp8_attributes::input_names::Scale_O] = scale_o; + + // Call internal implementation and return {Output, Stats, Amax_S, Amax_O} as array for backward compatibility + auto internal_result = sdpa_internal(q, k, v, std::move(attributes)); + return {internal_result.O, internal_result.Stats, internal_result.Amax_S, internal_result.Amax_O}; +} + +inline std::array, 7> +Graph::sdpa_fp8_backward(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr o, + std::shared_ptr dO, + std::shared_ptr Stats, + std::shared_ptr descale_q, + std::shared_ptr descale_k, + std::shared_ptr descale_v, + std::shared_ptr descale_o, + std::shared_ptr descale_do, + std::shared_ptr descale_s, + std::shared_ptr descale_dp, + std::shared_ptr scale_s, + std::shared_ptr scale_dq, + std::shared_ptr scale_dk, + std::shared_ptr scale_dv, + 
std::shared_ptr scale_dp, + SDPA_fp8_backward_attributes attributes) { + // Make required output tensors + auto dQ = attributes.outputs[SDPA_fp8_backward_attributes::output_names::dQ] = + output_tensor(attributes.name + "::dQ"); + auto dK = attributes.outputs[SDPA_fp8_backward_attributes::output_names::dK] = + output_tensor(attributes.name + "::dK"); + auto dV = attributes.outputs[SDPA_fp8_backward_attributes::output_names::dV] = + output_tensor(attributes.name + "::dV"); + auto Amax_dQ = attributes.outputs[SDPA_fp8_backward_attributes::output_names::Amax_dQ] = + output_tensor(attributes.name + "::Amax_dQ"); + auto Amax_dK = attributes.outputs[SDPA_fp8_backward_attributes::output_names::Amax_dK] = + output_tensor(attributes.name + "::Amax_dK"); + auto Amax_dV = attributes.outputs[SDPA_fp8_backward_attributes::output_names::Amax_dV] = + output_tensor(attributes.name + "::Amax_dV"); + auto Amax_dP = attributes.outputs[SDPA_fp8_backward_attributes::output_names::Amax_dP] = + output_tensor(attributes.name + "::Amax_dP"); + + // Set inputs + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Q] = q; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::K] = k; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::V] = v; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::O] = o; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Stats] = Stats; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::dO] = dO; + + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_Q] = descale_q; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_K] = descale_k; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_V] = descale_v; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_S] = descale_s; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_O] = descale_o; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_dO] 
= descale_do; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_dP] = descale_dp; + + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_dQ] = scale_dq; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_dK] = scale_dk; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_dV] = scale_dv; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_S] = scale_s; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_dP] = scale_dp; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {dQ, dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP}; +} + +inline std::array, 3> +Graph::sdpa_backward(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr o, + std::shared_ptr dO, + std::shared_ptr stats, + SDPA_backward_attributes attributes) { + // Set inputs + attributes.inputs[SDPA_backward_attributes::input_names::Q] = q; + attributes.inputs[SDPA_backward_attributes::input_names::K] = k; + attributes.inputs[SDPA_backward_attributes::input_names::V] = v; + attributes.inputs[SDPA_backward_attributes::input_names::O] = o; + attributes.inputs[SDPA_backward_attributes::input_names::dO] = dO; + attributes.inputs[SDPA_backward_attributes::input_names::Stats] = stats; + + // Make required output tensors + auto dQ = attributes.outputs[SDPA_backward_attributes::output_names::dQ] = output_tensor(attributes.name + "::dQ"); + auto dK = attributes.outputs[SDPA_backward_attributes::output_names::dK] = output_tensor(attributes.name + "::dK"); + auto dV = attributes.outputs[SDPA_backward_attributes::output_names::dV] = output_tensor(attributes.name + "::dV"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {dQ, dK, dV}; +} + +inline std::shared_ptr +Graph::slice(std::shared_ptr input, Slice_attributes attributes) { + attributes.inputs[Slice_attributes::input_names::X] = input; + auto Y = 
attributes.outputs[Slice_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return Y; +} + +inline std::array, 2> +Graph::block_scale_quantize(std::shared_ptr x, Block_scale_quantize_attributes attributes) { + // Set outputs + auto Y = attributes.outputs[Block_scale_quantize_attributes::output_names::Y] = + output_tensor(attributes.name + "::Y"); + auto scale = attributes.outputs[Block_scale_quantize_attributes::output_names::scale] = + output_tensor(attributes.name + "::scale"); + + // Set inputs + attributes.inputs[Block_scale_quantize_attributes::input_names::X] = x; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {Y, scale}; +} + +inline std::shared_ptr +Graph::block_scale_dequantize(std::shared_ptr x, + std::shared_ptr scale, + Block_scale_dequantize_attributes attributes) { + // Set outputs + auto Y = attributes.outputs[Block_scale_dequantize_attributes::output_names::Y] = + output_tensor(attributes.name + "::Y"); + + // Set inputs + attributes.inputs[Block_scale_dequantize_attributes::input_names::X] = x; + attributes.inputs[Block_scale_dequantize_attributes::input_names::scale] = scale; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return Y; +} + +inline std::shared_ptr +Graph::concatenate(std::vector> x, Concatenate_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } + + // Set outputs + auto Y = attributes.outputs[Concatenate_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + + // Set inputs + for (auto &element : x) { + attributes.inputs.push_back(element); + } + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return Y; +} + +inline std::shared_ptr +Graph::moe_grouped_matmul(std::shared_ptr token, + std::shared_ptr weight, + std::shared_ptr first_token_offset, + 
std::shared_ptr token_index, + std::shared_ptr token_ks, + Moe_grouped_matmul_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } + + auto output = attributes.outputs[Moe_grouped_matmul_attributes::output_names::Output] = + output_tensor(attributes.name + "::Output"); + + attributes.inputs[Moe_grouped_matmul_attributes::input_names::Token] = token; + attributes.inputs[Moe_grouped_matmul_attributes::input_names::Weight] = weight; + attributes.inputs[Moe_grouped_matmul_attributes::input_names::FirstTokenOffset] = first_token_offset; + attributes.inputs[Moe_grouped_matmul_attributes::input_names::TokenIndex] = token_index; + attributes.inputs[Moe_grouped_matmul_attributes::input_names::TokenKs] = token_ks; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return output; +} + +static inline std::ostream & +operator<<(std::ostream &os, Graph const &graph) { + os << graph.print(); + return os; +} + +} // namespace cudnn_frontend::graph diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/graph_properties.h b/third_party/cudnn-frontend/include/cudnn_frontend/graph_properties.h new file mode 100644 index 00000000..03b31b56 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/graph_properties.h @@ -0,0 +1,2655 @@ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "context.h" + +#include "../cudnn_frontend_utils.h" + +namespace cudnn_frontend { + +namespace graph { + +using managed_backend_descriptor_t = std::vector; + +// simple structure to hold all properties of a tensor. +// Each property has a getter setter. +class Tensor_attributes { + public: + using uid_t = int64_t; + + // There are two usecases of pass by value tensors: + // 1. Fused scalar constants + // 2. Scalar passed during execution + // In approach 1, users provide a value to embed into the graph. 
+ // In approach 2, users set is_pass_by_value boolean and then pass a pointer to scalar value with execute() API. + // A closed set of types that are allowed to be passed by value. + using pass_by_values_t = std::variant; + + error_t + validate() const { + RETURN_CUDNN_FRONTEND_ERROR_IF( + dim.empty(), error_code_t::ATTRIBUTE_NOT_SET, "Tensor '" + name + "' dims not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + stride.empty(), error_code_t::ATTRIBUTE_NOT_SET, "Tensor '" + name + "' strides not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF(dim.size() != stride.size(), + error_code_t::ATTRIBUTE_NOT_SET, + "Tensor '" + name + "' does not equal dimensionality in dim and stride."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + is_virtual && is_pass_by_value, + error_code_t::ATTRIBUTE_NOT_SET, + "Tensor '" + name + "' can't be both virutal and pass_by_value at the same time."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + pass_by_value.has_value() & (!is_pass_by_value), + error_code_t::ATTRIBUTE_NOT_SET, + "Tensor '" + name + "' can't be a fused scalar and not a pass_by_value tensor at the same time."); + + return {error_code_t::OK, ""}; + } + + private: + template + friend class Attributes; + + std::string name; + DataType_t data_type = DataType_t::NOT_SET; + std::vector dim = {}; + std::vector stride = {}; + bool is_virtual = false; + + std::optional pass_by_value = std::nullopt; + bool is_pass_by_value = false; + + TensorReordering_t reordering_type = TensorReordering_t::NONE; + uid_t uid = 0; + bool uid_assigned = false; + + std::shared_ptr ragged_offset; + int64_t alignment = 16; // Default to 16 bytes + int64_t vector_count = 1; // Default to 1 (no vectorization) + int64_t vector_dimension = -1; // Default to -1 (not set) + + auto + fill_from_context(detail::Context const& context) -> Tensor_attributes& { + if (get_data_type() == DataType_t::NOT_SET) { + if (get_is_virtual()) { + set_data_type(context.get_intermediate_data_type()); + } else { + set_data_type(context.get_io_data_type()); 
+ } + } + return *this; + } + + public: + // Serialization functions +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + friend void + to_json(nlohmann::json& j, const Tensor_attributes& ta); + friend void + from_json(const nlohmann::json& j, Tensor_attributes& ta); +#endif + + Tensor_attributes() = default; + + Tensor_attributes(float const& scalar) { + pass_by_value = scalar; + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::FLOAT; + } + + Tensor_attributes(half const& scalar) { + pass_by_value = scalar; + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::HALF; + } + + Tensor_attributes(nv_bfloat16 const& scalar) { + pass_by_value = scalar; + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::BFLOAT16; + } + + Tensor_attributes(int32_t const& scalar) { + pass_by_value = scalar; + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::INT32; + } + + Tensor_attributes(int64_t const& scalar) { + pass_by_value = scalar; + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::INT64; + } + + std::string + get_name() const { + return name; + } + + auto + set_name(std::string const& value) -> Tensor_attributes& { + name = value; + return *this; + } + + DataType_t + get_data_type() const { + return data_type; + } + + auto + set_data_type(DataType_t const value) -> Tensor_attributes& { + data_type = value; + return *this; + } + + std::vector + get_dim() const { + return dim; + } + + auto + set_dim(std::vector const& value) -> Tensor_attributes& { + dim = value; + return *this; + } + + int64_t + get_volume() const { + int64_t volume = 1ul; + for (int64_t d : dim) { + volume *= d; + } + return volume; + } + + std::vector + get_stride() const { + return stride; + } + + auto + set_stride(std::vector const& value) -> Tensor_attributes& { + stride = value; + return *this; + } + + bool + get_is_virtual() const { + return is_virtual; + } + + std::shared_ptr + get_ragged_offset() { + 
return ragged_offset; + } + + auto + set_is_virtual(bool const value) -> Tensor_attributes& { + is_virtual = value; + return *this; + } + + auto + set_output(bool const value) -> Tensor_attributes& { + return set_is_virtual(!value); + } + + std::optional + get_pass_by_value() const { + return pass_by_value; + } + + bool + get_is_pass_by_value() const { + return is_pass_by_value; + } + + auto + set_is_pass_by_value(bool const value) -> Tensor_attributes& { + is_pass_by_value = value; + return *this; + } + + TensorReordering_t + get_reordering_type() const { + return reordering_type; + } + + auto + set_reordering_type(TensorReordering_t const value) -> Tensor_attributes& { + reordering_type = value; + return *this; + } + + int64_t + get_alignment() const { + return alignment; + } + + auto + set_alignment(int64_t const value) -> Tensor_attributes& { + alignment = value; + return *this; + } + + int64_t + get_vector_count() const { + return vector_count; + } + + int64_t + get_vector_dimension() const { + return vector_dimension; + } + + auto + set_vector_count_and_dimension(int64_t const count, int64_t const dimension) -> Tensor_attributes& { + vector_count = count; + vector_dimension = dimension; + return *this; + } + + uid_t + get_uid() const { + return uid; + } + + uid_t + has_uid() const { + return uid_assigned; + } + + auto + clear_uid(void) -> Tensor_attributes& { + uid = 0; + uid_assigned = false; + return *this; + } + + auto + set_uid(uid_t value) -> Tensor_attributes& { + uid = value; + uid_assigned = true; + return *this; + } + + auto + set_ragged_offset(std::shared_ptr const& value) -> Tensor_attributes& { + ragged_offset = value; + return *this; + } +}; + +class Batchnorm_attributes; +class Batchnorm_backward_attributes; +class Concatenate_attributes; + +template +class Attributes { + DerivedT& + self() { + return *static_cast(this); + } + DerivedT const& + self() const { + return *static_cast(this); + } + + protected: + std::vector + get_non_virtual_uids() 
const { + std::vector non_virtual_uids; + auto derived = static_cast(this); + if constexpr (std::is_same_v) { + for (auto tensor : derived->inputs) { + if (tensor && tensor->get_is_virtual() == false) { + non_virtual_uids.push_back(tensor->get_uid()); + if (auto ragged_offset = tensor->get_ragged_offset()) { + non_virtual_uids.push_back(ragged_offset->get_uid()); + } + } + } + } else { + for (auto& [name, tensor] : derived->inputs) { + (void)name; + if (tensor && tensor->get_is_virtual() == false) { + non_virtual_uids.push_back(tensor->get_uid()); + if (auto ragged_offset = tensor->get_ragged_offset()) { + non_virtual_uids.push_back(ragged_offset->get_uid()); + } + } + } + } + + for (auto& [name, tensor] : derived->outputs) { + (void)name; + if (tensor && tensor->get_is_virtual() == false) { + non_virtual_uids.push_back(tensor->get_uid()); + if (auto ragged_offset = tensor->get_ragged_offset()) { + non_virtual_uids.push_back(ragged_offset->get_uid()); + } + } + } + + // Handle special case of BN where peer_stats is also an input + if constexpr (std::is_same_v || + std::is_same_v) { + for (auto& tensor : derived->peer_stats) { + if (tensor && tensor->get_is_virtual() == false) { + non_virtual_uids.push_back(tensor->get_uid()); + if (auto ragged_offset = tensor->get_ragged_offset()) { + non_virtual_uids.push_back(ragged_offset->get_uid()); + } + } + } + } + + return non_virtual_uids; + } + + public: + error_t + fill_pass_by_value(std::unordered_map& + tensor_to_pass_by_value) const { + auto derived = static_cast(this); + if constexpr (std::is_same_v) { + for (auto& tensor : derived->inputs) { + if (tensor && tensor->get_pass_by_value().has_value()) { + tensor_to_pass_by_value.emplace(tensor->get_uid(), tensor->get_pass_by_value().value()); + } + } + } else { + for (auto& [name, tensor] : derived->inputs) { + (void)name; + if (tensor && tensor->get_pass_by_value().has_value()) { + tensor_to_pass_by_value.emplace(tensor->get_uid(), tensor->get_pass_by_value().value()); 
+ } + } + } + + return {error_code_t::OK, ""}; + } + + void + fill_from_context(detail::Context const& context) { + auto derived = static_cast(this); + + if constexpr (std::is_same_v) { + for (auto& tensor : derived->inputs) { + if (tensor) { + tensor->fill_from_context(context); + } + } + } else { + for (auto& [name, tensor] : derived->inputs) { + (void)name; + if (tensor) { + tensor->fill_from_context(context); + } + } + } + for (auto& [name, tensor] : derived->outputs) { + (void)name; + if (tensor) { + tensor->fill_from_context(context); + } + } + // Handle special case of BN where peer_stats is also an input + if constexpr (std::is_same_v || + std::is_same_v) { + for (auto& tensor : derived->peer_stats) { + if (tensor) { + tensor->fill_from_context(context); + } + } + } + + if (compute_data_type == DataType_t::NOT_SET) { + set_compute_data_type(context.get_compute_data_type()); + } + + // Handle shape and stride inferencing for fused scalars. + // Pick number of dimensions from anyone of non-fused-scalar input/output tensors + // In case, all tensors are fused scalars, just keep them 1D. + int64_t number_of_dims = 1; + if constexpr (std::is_same_v) { + for (auto tensor : derived->inputs) { + if (tensor && (tensor->get_pass_by_value().has_value() == false)) { + number_of_dims = tensor->get_dim().size(); + break; + } + } + } else { + for (auto [name, tensor] : derived->inputs) { + (void)name; + if (tensor && (tensor->get_pass_by_value().has_value() == false)) { + number_of_dims = tensor->get_dim().size(); + break; + } + } + } + + // If number of dims is still 1, try to see if user set output dims. 
+ if (number_of_dims == 1) { + for (auto [name, tensor] : derived->outputs) { + (void)name; + if (tensor && (tensor->get_pass_by_value().has_value() == false)) { + number_of_dims = tensor->get_dim().size(); + break; + } + } + } + + if constexpr (std::is_same_v) { + for (auto tensor : derived->inputs) { + if (tensor && tensor->get_pass_by_value().has_value()) { + tensor->set_dim(std::vector(number_of_dims, 1)); + tensor->set_stride(std::vector(number_of_dims, 1)); + } + } + } else { + for (auto [name, tensor] : derived->inputs) { + (void)name; + if (tensor && tensor->get_pass_by_value().has_value()) { + tensor->set_dim(std::vector(number_of_dims, 1)); + tensor->set_stride(std::vector(number_of_dims, 1)); + } + } + } + } + + std::string name; + DataType_t compute_data_type = DataType_t::NOT_SET; + + DerivedT& + set_name(std::string const& value) { + name = value; + return self(); + } + + DerivedT& + set_compute_data_type(DataType_t value) { + compute_data_type = value; + return self(); + } +}; + +class BN_finalize_attributes : public Attributes { + friend class Attributes; + friend class BatchNormFinalizeNode; + friend class Graph; + + public: + enum class input_names { + SUM, + SQ_SUM, + SCALE, + BIAS, + EPSILON, + ACCUM_COUNT, + PREV_RUNNING_MEAN, + PREV_RUNNING_VAR, + MOMENTUM + }; + std::unordered_map> inputs; + enum class output_names { EQ_SCALE, EQ_BIAS, MEAN, INV_VARIANCE, NEXT_RUNNING_MEAN, NEXT_RUNNING_VAR }; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(BN_finalize_attributes, name, compute_data_type, inputs, outputs) + std::unordered_map> outputs; + + BN_finalize_attributes& + set_previous_running_stats(std::shared_ptr& mean, + std::shared_ptr& variance, + std::shared_ptr& momentum) { + inputs[BN_finalize_attributes::input_names::PREV_RUNNING_MEAN] = mean; + inputs[BN_finalize_attributes::input_names::PREV_RUNNING_VAR] = variance; + inputs[BN_finalize_attributes::input_names::MOMENTUM] = momentum; + return *this; + } +}; + +class Genstats_attributes : public 
Attributes { + friend class Attributes; + friend class GenstatsNode; + friend class Graph; + + public: + enum class input_names { X }; + std::unordered_map> inputs; + + enum class output_names { SUM, SQ_SUM }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Genstats_attributes, name, compute_data_type, inputs, outputs) +}; + +class Conv_fprop_attributes : public Attributes { + friend class Attributes; + friend class ConvolutionNode; + friend class Graph; + + std::vector pre_padding; + std::vector post_padding; + std::vector stride; + std::vector dilation; + + ConvolutionMode_t math_mode = ConvolutionMode_t::CROSS_CORRELATION; + + public: + enum class input_names { X, W }; + std::unordered_map> inputs; + enum class output_names { Y }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Conv_fprop_attributes, + name, + compute_data_type, + inputs, + outputs, + pre_padding, + post_padding, + stride, + dilation, + math_mode) + + ConvolutionMode_t + get_convolution_mode() const { + return math_mode; + } + + std::vector + get_pre_padding() const { + return pre_padding; + } + + std::vector + get_post_padding() const { + return post_padding; + } + + Conv_fprop_attributes& + set_padding(std::vector value) { + pre_padding = value; + post_padding = value; + return *this; + } + + Conv_fprop_attributes& + set_pre_padding(std::vector value) { + pre_padding = value; + return *this; + } + + Conv_fprop_attributes& + set_post_padding(std::vector value) { + post_padding = value; + return *this; + } + + Conv_fprop_attributes& + set_convolution_mode(ConvolutionMode_t mode_) { + math_mode = mode_; + return *this; + } + + std::vector + get_stride() const { + return stride; + } + + Conv_fprop_attributes& + set_stride(std::vector value) { + stride = value; + return *this; + } + + std::vector + get_dilation() const { + return dilation; + } + + Conv_fprop_attributes& + set_dilation(std::vector value) { + dilation = value; + return *this; + } +}; + +class 
Batchnorm_backward_attributes : public Attributes { + friend class Attributes; + friend class DBNNode; + friend class Graph; + + public: + enum class input_names { DY, X, SCALE, MEAN, INV_VARIANCE }; + std::unordered_map> inputs; + // Only special case where one of the inputs is a vector. + std::vector> peer_stats; + enum class output_names { DX, DSCALE, DBIAS }; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Batchnorm_backward_attributes, name, compute_data_type, inputs, peer_stats, outputs) + std::unordered_map> outputs; + + Batchnorm_backward_attributes& + set_saved_mean_and_inv_variance(std::shared_ptr mean, + std::shared_ptr inv_variance) { + inputs[Batchnorm_backward_attributes::input_names::MEAN] = mean; + inputs[Batchnorm_backward_attributes::input_names::INV_VARIANCE] = inv_variance; + return *this; + } + + Batchnorm_backward_attributes& + set_peer_stats(std::vector> const& input_peer_stats) { + peer_stats = input_peer_stats; + return *this; + } +}; + +class DBN_weight_attributes : public Attributes { + friend class Attributes; + friend class DBNWeightNode; + friend class Graph; + + public: + enum class input_names { DY, X, SCALE, MEAN, INV_VARIANCE }; + std::unordered_map> inputs; + enum class output_names { DSCALE, DBIAS, EQ_BIAS, EQ_SCALE_DY, EQ_SCALE_X }; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(DBN_weight_attributes, name, compute_data_type, inputs, outputs) + std::unordered_map> outputs; +}; + +class Conv_dgrad_attributes : public Attributes { + friend class Attributes; + friend class DgradNode; + friend class Graph; + + std::vector pre_padding; + std::vector post_padding; + std::vector stride; + std::vector dilation; + + ConvolutionMode_t math_mode = ConvolutionMode_t::CROSS_CORRELATION; + + public: + enum class input_names { DY, W }; + std::unordered_map> inputs; + enum class output_names { DX }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Conv_dgrad_attributes, + name, + compute_data_type, + inputs, + outputs, + pre_padding, + post_padding, + 
stride, + dilation, + math_mode) + + ConvolutionMode_t + get_convolution_mode() const { + return math_mode; + } + + std::vector + get_pre_padding() const { + return pre_padding; + } + + std::vector + get_post_padding() const { + return post_padding; + } + + Conv_dgrad_attributes& + set_padding(std::vector value) { + pre_padding = value; + post_padding = value; + return *this; + } + + Conv_dgrad_attributes& + set_pre_padding(std::vector value) { + pre_padding = value; + return *this; + } + + Conv_dgrad_attributes& + set_post_padding(std::vector value) { + post_padding = value; + return *this; + } + + std::vector + get_stride() const { + return stride; + } + + Conv_dgrad_attributes& + set_convolution_mode(ConvolutionMode_t mode_) { + math_mode = mode_; + ; + return *this; + } + + Conv_dgrad_attributes& + set_stride(std::vector value) { + stride = value; + return *this; + } + + std::vector + get_dilation() const { + return dilation; + } + + Conv_dgrad_attributes& + set_dilation(std::vector value) { + dilation = value; + return *this; + } +}; + +class Matmul_fp8_attributes : public Attributes { + friend class Attributes; + friend class MatmulFP8Node; + friend class INode; + + double padding_value = 0.0; + + public: + enum class input_names { Descale_A, Descale_B, A, B, M_override, N_override, K_override, Scale_C }; + std::unordered_map> inputs; + enum class output_names { C, Amax_C }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Matmul_fp8_attributes, name, compute_data_type, inputs, outputs) + + Matmul_fp8_attributes& + set_m_override(std::shared_ptr const& value) { + inputs[input_names::M_override] = value; + return *this; + } + + Matmul_fp8_attributes& + set_n_override(std::shared_ptr const& value) { + inputs[input_names::N_override] = value; + return *this; + } + + Matmul_fp8_attributes& + set_k_override(std::shared_ptr const& value) { + inputs[input_names::K_override] = value; + return *this; + } + + Matmul_fp8_attributes& + set_padding(double 
const padding_val) { + padding_value = padding_val; + return *this; + } + + double + get_padding() const { + return padding_value; + } +}; + +class Matmul_attributes : public Attributes { + friend class Attributes; + friend class MatmulNode; + friend class INode; + + double padding_value = 0.0; + + public: + enum class input_names { A, B, M_override, N_override, K_override }; + std::unordered_map> inputs; + enum class output_names { C }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Matmul_attributes, name, compute_data_type, inputs, outputs, padding_value) + + Matmul_attributes& + clone_fp8_attributes(Matmul_fp8_attributes const& attributes) { + auto m_override = attributes.inputs.find(Matmul_fp8_attributes::input_names::M_override); + if (m_override != attributes.inputs.end()) { + set_m_override(m_override->second); + } + auto n_override = attributes.inputs.find(Matmul_fp8_attributes::input_names::N_override); + if (n_override != attributes.inputs.end()) { + set_n_override(n_override->second); + } + auto k_override = attributes.inputs.find(Matmul_fp8_attributes::input_names::K_override); + if (k_override != attributes.inputs.end()) { + set_k_override(k_override->second); + } + + set_padding(attributes.get_padding()); + + return *this; + } + + Matmul_attributes& + set_m_override(std::shared_ptr const& value) { + inputs[input_names::M_override] = value; + return *this; + } + + Matmul_attributes& + set_n_override(std::shared_ptr const& value) { + inputs[input_names::N_override] = value; + return *this; + } + + Matmul_attributes& + set_k_override(std::shared_ptr const& value) { + inputs[input_names::K_override] = value; + return *this; + } + + Matmul_attributes& + set_padding(double const padding_val) { + padding_value = padding_val; + return *this; + } +}; + +class Pointwise_attributes : public Attributes { + friend class Attributes; + friend class PointwiseNode; + friend class SoftmaxNode; + friend class INode; + + PointwiseMode_t mode = 
PointwiseMode_t::NOT_SET; + + std::optional axis; + + std::optional relu_lower_clip; + std::optional relu_upper_clip; + std::optional relu_lower_clip_slope; + + std::optional swish_beta; + std::optional elu_alpha; + std::optional softplus_beta; + + public: + enum class input_names { IN_0, IN_1, IN_2 }; + std::unordered_map> inputs; + enum class output_names { OUT_0 }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Pointwise_attributes, + name, + compute_data_type, + inputs, + outputs, + mode, + axis, + relu_lower_clip, + relu_upper_clip, + relu_lower_clip_slope, + swish_beta, + elu_alpha, + softplus_beta) + + Pointwise_attributes& + set_mode(PointwiseMode_t const value) { + mode = value; + return *this; + } + + std::optional + get_axis() const { + return axis; + } + + Pointwise_attributes& + set_axis(int64_t const axis) { + this->axis = axis; + return *this; + } + + Pointwise_attributes& + set_relu_lower_clip_slope(float const negative_slope) { + this->relu_lower_clip_slope = negative_slope; + return *this; + } + + Pointwise_attributes& + set_relu_lower_clip(float const value) { + this->relu_lower_clip = value; + return *this; + } + + Pointwise_attributes& + set_relu_upper_clip(float const value) { + this->relu_upper_clip = value; + return *this; + } + + Pointwise_attributes& + set_swish_beta(float const value) { + this->swish_beta = value; + return *this; + } + + Pointwise_attributes& + set_elu_alpha(float const value) { + this->elu_alpha = value; + return *this; + } + + Pointwise_attributes& + set_softplus_beta(float const value) { + this->softplus_beta = value; + return *this; + } +}; + +class Instancenorm_backward_attributes : public Attributes { + friend class Attributes; + friend class DINNode; + friend class Graph; + + public: + enum class input_names { DY, X, SCALE, MEAN, INV_VARIANCE }; + std::unordered_map> inputs; + enum class output_names { DX, DSCALE, DBIAS }; + std::unordered_map> outputs; + 
NLOHMANN_DEFINE_TYPE_INTRUSIVE(Instancenorm_backward_attributes, name, compute_data_type, inputs, outputs) + + Instancenorm_backward_attributes& + set_saved_mean_and_inv_variance(std::shared_ptr mean, + std::shared_ptr inv_variance) { + inputs[Instancenorm_backward_attributes::input_names::MEAN] = mean; + inputs[Instancenorm_backward_attributes::input_names::INV_VARIANCE] = inv_variance; + return *this; + } +}; + +class Layernorm_backward_attributes : public Attributes { + friend class Attributes; + friend class DLNNode; + friend class Graph; + + public: + enum class input_names { DY, X, SCALE, MEAN, INV_VARIANCE, EPSILON }; + std::unordered_map> inputs; + enum class output_names { DX, DSCALE, DBIAS }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Layernorm_backward_attributes, name, compute_data_type, inputs, outputs) + + Layernorm_backward_attributes& + set_saved_mean_and_inv_variance(std::shared_ptr mean, + std::shared_ptr inv_variance) { + inputs[Layernorm_backward_attributes::input_names::MEAN] = mean; + inputs[Layernorm_backward_attributes::input_names::INV_VARIANCE] = inv_variance; + return *this; + } +}; + +class Layernorm_attributes : public Attributes { + friend class Attributes; + friend class LayerNormNode; + friend class Graph; + + NormFwdPhase_t forward_phase = NormFwdPhase_t::NOT_SET; + + public: + enum class input_names { X, SCALE, BIAS, EPSILON }; + std::unordered_map> inputs; + enum class output_names { Y, MEAN, INV_VARIANCE }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Layernorm_attributes, name, compute_data_type, inputs, outputs, forward_phase) + + Layernorm_attributes& + set_forward_phase(NormFwdPhase_t const value) { + forward_phase = value; + return *this; + } + + Layernorm_attributes& + set_epsilon(std::shared_ptr& value) { + inputs[Layernorm_attributes::input_names::EPSILON] = value; + return *this; + } +}; + +class AdaLayernorm_attributes : public Attributes { + friend class Attributes; + friend 
class AdaLayerNormNode; + friend class Graph; + + NormFwdPhase_t forward_phase = NormFwdPhase_t::NOT_SET; + + public: + enum class input_names { X, SCALE, BIAS, EPSILON }; + std::unordered_map> inputs; + enum class output_names { Y, MEAN, INV_VARIANCE }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(AdaLayernorm_attributes, name, compute_data_type, inputs, outputs, forward_phase) + + AdaLayernorm_attributes& + set_forward_phase(NormFwdPhase_t const value) { + forward_phase = value; + return *this; + } + + AdaLayernorm_attributes& + set_epsilon(std::shared_ptr value) { + inputs[AdaLayernorm_attributes::input_names::EPSILON] = std::move(value); + return *this; + } +}; + +class AdaLayernorm_backward_attributes : public Attributes { + friend class Attributes; + friend class DAdaLayerNormNode; + friend class Graph; + + public: + enum class input_names { DY, X, SCALE, MEAN, INV_VARIANCE, EPSILON }; + std::unordered_map> inputs; + enum class output_names { DX, DSCALE, DBIAS }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(AdaLayernorm_backward_attributes, name, compute_data_type, inputs, outputs) + + AdaLayernorm_backward_attributes& + set_saved_mean_and_inv_variance(std::shared_ptr mean, + std::shared_ptr inv_variance) { + inputs[AdaLayernorm_backward_attributes::input_names::MEAN] = mean; + inputs[AdaLayernorm_backward_attributes::input_names::INV_VARIANCE] = inv_variance; + return *this; + } +}; + +class Instancenorm_attributes : public Attributes { + friend class Attributes; + friend class InstanceNormNode; + friend class Graph; + + NormFwdPhase_t forward_phase = NormFwdPhase_t::NOT_SET; + + public: + enum class input_names { X, SCALE, BIAS, EPSILON }; + std::unordered_map> inputs; + enum class output_names { Y, MEAN, INV_VARIANCE }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Instancenorm_attributes, name, compute_data_type, inputs, outputs, forward_phase) + + Instancenorm_attributes& + 
set_forward_phase(NormFwdPhase_t const value) { + forward_phase = value; + return *this; + } + + Instancenorm_attributes& + set_epsilon(std::shared_ptr& value) { + inputs[Instancenorm_attributes::input_names::EPSILON] = value; + return *this; + } +}; + +class Batchnorm_attributes : public Attributes { + friend class Attributes; + friend class BatchNormNode; + friend class Graph; + + public: + enum class input_names { X, SCALE, BIAS, PREV_RUNNING_MEAN, PREV_RUNNING_VAR, EPSILON, MOMENTUM }; + std::unordered_map> inputs; + // Only special case where one of the inputs is a vector. + std::vector> peer_stats; + enum class output_names { Y, MEAN, INV_VARIANCE, NEXT_RUNNING_MEAN, NEXT_RUNNING_VAR }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Batchnorm_attributes, name, compute_data_type, inputs, peer_stats, outputs) + + Batchnorm_attributes& + set_previous_running_stats(std::shared_ptr& mean, + std::shared_ptr& variance, + std::shared_ptr& momentum) { + inputs[input_names::PREV_RUNNING_MEAN] = mean; + inputs[input_names::PREV_RUNNING_VAR] = variance; + inputs[input_names::MOMENTUM] = momentum; + return *this; + } + + Batchnorm_attributes& + set_epsilon(std::shared_ptr& value) { + inputs[input_names::EPSILON] = value; + return *this; + } + + Batchnorm_attributes& + set_peer_stats(std::vector> const& input_peer_stats) { + peer_stats = input_peer_stats; + return *this; + } +}; + +class Batchnorm_inference_attributes : public Attributes { + friend class Attributes; + friend class BatchnormInferenceNode; + friend class Graph; + + public: + enum class input_names { X, MEAN, INV_VARIANCE, SCALE, BIAS }; + std::unordered_map> inputs; + enum class output_names { Y }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Batchnorm_inference_attributes, name, compute_data_type, inputs, outputs) +}; + +class Reduction_attributes : public Attributes { + friend class Attributes; + friend class ReductionNode; + friend class INode; + + std::optional mode; + 
bool is_deterministic = false; + + public: + enum class input_names { X }; + std::unordered_map> inputs; + enum class output_names { Y }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reduction_attributes, + name, + compute_data_type, + inputs, + outputs, + mode, + is_deterministic) + + std::optional + get_mode() const { + return mode; + } + + Reduction_attributes& + set_mode(ReductionMode_t value) { + mode = value; + return *this; + } + + bool + get_is_deterministic() const { + return is_deterministic; + } + + Reduction_attributes& + set_is_deterministic(bool value) { + is_deterministic = value; + return *this; + } +}; + +class Rng_attributes : public Attributes { + friend class Attributes; + friend class RngNode; + friend class INode; + + RngDistribution_t distribution = RngDistribution_t::NOT_SET; + std::vector dim = {}; + std::vector stride = {}; + std::optional seed; + std::optional bernoulli_probability; + + public: + enum class input_names { Seed, Offset }; + std::unordered_map> inputs; + enum class output_names { Y }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rng_attributes, + name, + inputs, + outputs, + distribution, + dim, + stride, + seed, + bernoulli_probability) + + std::vector + get_dim() const { + return dim; + } + + auto + set_dim(std::vector const& value) -> Rng_attributes& { + dim = value; + return *this; + } + + std::vector + get_stride() const { + return stride; + } + + auto + set_stride(std::vector const& value) -> Rng_attributes& { + stride = value; + return *this; + } + + RngDistribution_t + get_distribution() const { + return distribution; + } + + Rng_attributes& + set_distribution(RngDistribution_t value) { + distribution = value; + return *this; + } + + std::optional + get_seed() const { + return seed; + } + + Rng_attributes& + set_seed(std::optional value) { + seed = value; + return *this; + } + + std::optional + get_bernoulli_probability() const { + return bernoulli_probability; + } + + 
Rng_attributes& + set_bernoulli_probability(std::optional value) { + bernoulli_probability = value; + return *this; + } +}; + +class Resample_attributes : public Attributes { + friend class Attributes; + friend class ResampleNode; + friend class INode; + + std::optional generate_index; + ResampleMode_t resample_mode; + PaddingMode_t padding_mode; + std::vector pre_padding; + std::vector post_padding; + std::vector stride; + std::vector window; + + public: + enum class input_names { X }; + std::unordered_map> inputs; + + enum class output_names { Y, Index }; + std::unordered_map> outputs; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Resample_attributes, + name, + inputs, + outputs, + generate_index, + resample_mode, + padding_mode, + pre_padding, + post_padding, + stride, + window) + + auto + set_resampling_mode(ResampleMode_t const& value) -> Resample_attributes& { + resample_mode = value; + return *this; + } + + auto + set_padding_mode(PaddingMode_t const& value) -> Resample_attributes& { + padding_mode = value; + return *this; + } + + auto + set_window(std::vector const& value) -> Resample_attributes& { + window.resize(value.size()); + for (auto i = 0u; i < value.size(); i++) { + window[i].numerator = value[i]; + window[i].denominator = 1; + } + return *this; + } + + auto + set_window(std::vector const& value) -> Resample_attributes& { + window = value; + return *this; + } + + auto + set_stride(std::vector const& value) -> Resample_attributes& { + stride.resize(value.size()); + for (auto i = 0u; i < value.size(); i++) { + stride[i].numerator = value[i]; + stride[i].denominator = 1; + } + return *this; + } + + auto + set_stride(std::vector const& value) -> Resample_attributes& { + stride = value; + return *this; + } + + auto + set_pre_padding(std::vector const& value) -> Resample_attributes& { + pre_padding.resize(value.size()); + for (auto i = 0u; i < value.size(); i++) { + pre_padding[i].numerator = value[i]; + pre_padding[i].denominator = 1; + } + return *this; + } + + 
auto + set_pre_padding(std::vector const& value) -> Resample_attributes& { + pre_padding = value; + return *this; + } + + auto + set_post_padding(std::vector const& value) -> Resample_attributes& { + post_padding.resize(value.size()); + for (auto i = 0u; i < value.size(); i++) { + post_padding[i].numerator = value[i]; + post_padding[i].denominator = 1; + } + return *this; + } + + auto + set_post_padding(std::vector const& value) -> Resample_attributes& { + post_padding = value; + return *this; + } + + auto + set_generate_index(bool const value) -> Resample_attributes& { + generate_index = value; + return *this; + } + + [[deprecated]] auto + set_is_inference(bool const value) -> Resample_attributes& { + return set_generate_index(!value); + } +}; + +class Reshape_attributes : public Attributes { + friend class Attributes; + friend class ReshapeNode; + friend class INode; + + std::vector dim = {}; + std::vector stride = {}; + + public: + enum class input_names { X }; + std::unordered_map> inputs; + enum class output_names { Y }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reshape_attributes, name, compute_data_type, inputs, outputs, dim, stride) + + std::vector + get_dim() const { + return dim; + } + + auto + set_dim(std::vector const& value) -> Reshape_attributes& { + dim = value; + return *this; + } + + std::vector + get_stride() const { + return stride; + } + + auto + set_stride(std::vector const& value) -> Reshape_attributes& { + stride = value; + return *this; + } +}; + +class Rmsnorm_attributes : public Attributes { + friend class Attributes; + friend class RMSNormNode; + friend class Graph; + + NormFwdPhase_t forward_phase = NormFwdPhase_t::NOT_SET; + + public: + enum class input_names { X, SCALE, BIAS, EPSILON }; + std::unordered_map> inputs; + enum class output_names { Y, INV_VARIANCE }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rmsnorm_attributes, name, compute_data_type, inputs, outputs, forward_phase) + + 
Rmsnorm_attributes& + set_forward_phase(NormFwdPhase_t const value) { + forward_phase = value; + return *this; + } + + Rmsnorm_attributes& + set_bias(std::shared_ptr& value) { + inputs[Rmsnorm_attributes::input_names::BIAS] = value; + return *this; + } + + Rmsnorm_attributes& + set_epsilon(std::shared_ptr& value) { + inputs[Rmsnorm_attributes::input_names::EPSILON] = value; + return *this; + } +}; + +class Rmsnorm_backward_attributes : public Attributes { + friend class Attributes; + friend class DRMSNormNode; + friend class Graph; + + std::optional use_dbias; + + public: + enum class input_names { DY, X, SCALE, INV_VARIANCE }; + std::unordered_map> inputs; + enum class output_names { DX, DSCALE, DBIAS }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rmsnorm_backward_attributes, name, compute_data_type, inputs, outputs) + + Rmsnorm_backward_attributes& + has_dbias(bool value) { + use_dbias = value; + return *this; + } +}; + +// class Scaled_dot_product_attention_attributes : public Operation { +// public: +// struct Inputs { +// std::shared_ptr Q; +// std::shared_ptr K; +// std::shared_ptr Attn_scale; +// std::shared_ptr Bias; // Optional bias after bmm1 +// std::shared_ptr V; +// std::shared_ptr SEQ_LEN_Q; +// std::shared_ptr SEQ_LEN_KV; +// std::shared_ptr Mask; +// std::shared_ptr Dropout_mask; +// std::shared_ptr Dropout_scale; +// } inputs; + +// struct Outputs { +// std::shared_ptr O; +// std::shared_ptr +// S; // softmax output dumped when is_inference false. Users first need to check whether its nullptr. 
+// } outputs; + +// std::optional is_inference; +// bool padding_mask = false; +// bool causal_mask = false; +// std::optional dropout_probability; +// int64_t seed; +// float dropout_scale = 1.f; + +// public: +// Scaled_dot_product_attention_attributes() : Operation(Tag::Scaled_dot_product_attention), is_inference(false) {} + +// Scaled_dot_product_attention_attributes& +// set_is_inference(bool const value) { +// is_inference = value; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_seq_len_q(std::shared_ptr value) { +// inputs.SEQ_LEN_Q = value; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_seq_len_kv(std::shared_ptr value) { +// inputs.SEQ_LEN_KV = value; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_padding_mask(bool const value) { +// padding_mask = value; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_causal_mask(bool const value) { +// causal_mask = value; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_attn_scale(std::shared_ptr value) { +// inputs.Attn_scale = value; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_bias(std::shared_ptr bias) { +// inputs.Bias = bias; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_dropout(float const probability, int64_t const seed_) { +// dropout_probability = probability; +// seed = seed_; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_dropout(std::shared_ptr mask, std::shared_ptr scale) { +// inputs.Dropout_mask = mask; +// inputs.Dropout_scale = scale; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_compute_data_type(DataType_t const value) { +// compute_data_type = value; +// return *this; +// } + +// Scaled_dot_product_attention_attributes& +// set_name(std::string const& value) { +// name = value; +// return *this; +// } + +// 
Scaled_dot_product_attention_attributes& +// fill_from_context(detail::Context const& context) { +// // Fill node's tensors +// inputs.Q->fill_from_context(context); +// inputs.K->fill_from_context(context); +// inputs.V->fill_from_context(context); +// inputs.SEQ_LEN_Q->fill_from_context(context); +// inputs.SEQ_LEN_KV->fill_from_context(context); +// outputs.O->fill_from_context(context); + +// // Fill this node +// if (get_compute_data_type() == DataType_t::NOT_SET) { +// set_compute_data_type(context.get_compute_data_type()); +// } +// return *this; +// } +// }; +template +class SDPANodeBase; +class CompositeSDPANode; +class UnifiedSDPANode; + +class SDPA_attributes : public Attributes { + friend class Attributes; + friend class SDPANodeBase; + friend class CompositeSDPANode; + friend class SDPANodeBase; + friend class UnifiedSDPANode; + friend class Graph; + + using Tensor_t = std::shared_ptr; + using Graph_t = std::shared_ptr; + + using AttentionScoreModifier_t = + std::function, std::shared_ptr)>; + + std::optional generate_stats; + bool alibi_mask = false; + bool padding_mask = false; + std::optional left_bound; + std::optional right_bound; + DiagonalAlignment_t diagonal_alignment = DiagonalAlignment_t::TOP_LEFT; + std::optional dropout_probability; + std::optional attn_scale_value; + std::optional max_seq_len_kv; + AttentionScoreModifier_t attention_score_modifier = nullptr; + DataType_t mma_core_mode = DataType_t::NOT_SET; + + // Deprecated fields for backward compatibility with SDPA_fp8_attributes + bool causal_mask = false; + bool causal_mask_bottom_right = false; + + AttentionImplementation_t implementation = AttentionImplementation_t::AUTO; + + bool + has_causal_like_masking() const { + return right_bound.has_value(); + } + + bool + has_causal_mask_bottom_right() const { + return right_bound.has_value() && diagonal_alignment == DiagonalAlignment_t::BOTTOM_RIGHT; + } + + public: + enum class input_names { + Q, + K, + V, + Attn_scale, + Bias, + 
SEQ_LEN_Q, + SEQ_LEN_KV, + Seed, + Offset, + Dropout_mask, + Dropout_scale, + Page_table_K, + Page_table_V, + Block_mask, + // FP8-specific scaling inputs + Descale_Q, + Descale_K, + Descale_V, + Descale_S, + Scale_S, + Scale_O, + SINK_TOKEN, + }; + std::unordered_map> inputs; + enum class output_names { O, Stats, Max, Sum_exp, RNG_DUMP, Amax_S, Amax_O }; + std::unordered_map> outputs; + // Convenience struct for named access to SDPA outputs + struct SDPA_outputs { + std::shared_ptr O; ///< Main attention output tensor + std::shared_ptr Stats; ///< Statistics/softmax output (when generate_stats=true) + std::shared_ptr Max; ///< Max output tensor + std::shared_ptr Sum_exp; ///< Sum_exp output tensor + std::shared_ptr RNG_DUMP; ///< Random number generator dump for dropout + ///< check why we don't return RNG_DUMP this way + std::shared_ptr Amax_S; ///< FP8 absolute maximum for attention scores + std::shared_ptr Amax_O; ///< FP8 absolute maximum for output tensor + }; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_attributes, + name, + inputs, + outputs, + generate_stats, + alibi_mask, + padding_mask, + dropout_probability, + attn_scale_value, + max_seq_len_kv, + mma_core_mode, + left_bound, + right_bound, + diagonal_alignment, + causal_mask, + causal_mask_bottom_right, + implementation) + + SDPA_attributes& + set_generate_stats(bool const value) { + generate_stats = value; + return *this; + } + + SDPA_attributes& + set_logit_max(std::shared_ptr value) { + outputs[SDPA_attributes::output_names::Max] = std::move(value); + return *this; + } + + SDPA_attributes& + set_score_sum_exp(std::shared_ptr value) { + outputs[SDPA_attributes::output_names::Sum_exp] = std::move(value); + return *this; + } + + [[deprecated]] SDPA_attributes& + set_is_inference(bool const value) { + return set_generate_stats(!value); + } + + SDPA_attributes& + set_attn_scale(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::Attn_scale] = std::move(value); + return *this; + } + + 
SDPA_attributes& + set_attn_scale(float const value) { + attn_scale_value = value; + return *this; + } + + SDPA_attributes& + set_bias(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::Bias] = std::move(value); + return *this; + } + + SDPA_attributes& + set_block_mask(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::Block_mask] = std::move(value); + return *this; + } + + SDPA_attributes& + set_alibi_mask(bool const value) { + alibi_mask = value; + return *this; + } + + SDPA_attributes& + set_padding_mask(bool const value) { + padding_mask = value; + return *this; + } + + // Internal function - do not use directly in application code + SDPA_attributes& + _set_mma_core_mode(DataType_t const value) { + mma_core_mode = value; + return *this; + } + + SDPA_attributes& + set_seq_len_q(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::SEQ_LEN_Q] = std::move(value); + return *this; + } + + SDPA_attributes& + set_seq_len_kv(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::SEQ_LEN_KV] = std::move(value); + return *this; + } + + SDPA_attributes& + set_diagonal_alignment(DiagonalAlignment_t const alignment) { + diagonal_alignment = alignment; + return *this; + } + + // Sets the diagonal position to top left and + // calls set_diagonal_band_right_bound(0) if no right_bound was specified + // TODO: Deprecate + SDPA_attributes& + set_causal_mask(bool const value) { + if (value) { + set_diagonal_alignment(DiagonalAlignment_t::TOP_LEFT); + if (!right_bound.has_value()) { + set_diagonal_band_right_bound(0); + } + } + causal_mask = value; + return *this; + } + + // Sets the diagonal position to the bottom right (on a per-sequence basis) + // and calls set_diagonal_band_right_bound(0) if no right_bound was specified + // TODO: Deprecate + SDPA_attributes& + set_causal_mask_bottom_right(bool const value) { + if (value) { + set_diagonal_alignment(DiagonalAlignment_t::BOTTOM_RIGHT); + if (!right_bound.has_value()) { + 
set_diagonal_band_right_bound(0); + } + } + causal_mask_bottom_right = value; + return *this; + } + + SDPA_attributes& + set_score_mod(AttentionScoreModifier_t fn) { + attention_score_modifier = std::move(fn); + return *this; + } + + // calls set_diagonal_band_left_bound(value) + // TODO: Deprecate + SDPA_attributes& + set_sliding_window_length(int const value) { + return set_diagonal_band_left_bound(value); + } + + SDPA_attributes& + set_diagonal_band_left_bound(int const value) { + left_bound = value; + return *this; + } + + SDPA_attributes& + set_diagonal_band_right_bound(int const value) { + right_bound = value; + return *this; + } + + SDPA_attributes& + set_dropout(float const probability, + std::shared_ptr seed, + std::shared_ptr offset) { + dropout_probability = probability; + inputs[SDPA_attributes::input_names::Seed] = std::move(seed); + inputs[SDPA_attributes::input_names::Offset] = std::move(offset); + return *this; + } + + SDPA_attributes& + set_dropout(std::shared_ptr mask, std::shared_ptr scale) { + inputs[SDPA_attributes::input_names::Dropout_mask] = std::move(mask); + inputs[SDPA_attributes::input_names::Dropout_scale] = std::move(scale); + return *this; + } + + // For debugging purposes only. 
+ SDPA_attributes& + set_rng_dump(std::shared_ptr value) { + outputs[SDPA_attributes::output_names::RNG_DUMP] = std::move(value); + return *this; + } + + SDPA_attributes& + set_paged_attention_k_table(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::Page_table_K] = std::move(value); + return *this; + } + + SDPA_attributes& + set_paged_attention_v_table(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::Page_table_V] = std::move(value); + return *this; + } + + SDPA_attributes& + set_paged_attention_max_seq_len_kv(int const value) { + max_seq_len_kv = value; + return *this; + } + + SDPA_attributes& + set_sink_token(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::SINK_TOKEN] = std::move(value); + return *this; + } + + SDPA_attributes& + set_implementation(AttentionImplementation_t value) { + implementation = value; + return *this; + } + + // Implementation is in sdpa_support_surface.h + error_t + validate_sdpa_support_surface(const detail::Context& context, int64_t s_kv, bool is_paged_k, bool is_paged_v) const; + + // Internal function - do not use directly in application code + void + _auto_select_implementation(const detail::Context& context) { + if (verify_sdpa_support_surface_for_implementation(context, AttentionImplementation_t::UNIFIED).is_good()) { + implementation = AttentionImplementation_t::UNIFIED; + CUDNN_FE_LOG_LABEL_ENDL("INFO: Auto-selected SDPA implementation UNIFIED"); + } else if (verify_sdpa_support_surface_for_implementation(context, AttentionImplementation_t::COMPOSITE) + .is_good()) { + implementation = AttentionImplementation_t::COMPOSITE; + CUDNN_FE_LOG_LABEL_ENDL("INFO: Auto-selected SDPA implementation COMPOSITE"); + } else { + // Leave `implementation` with its previous value (usually AUTO). + CUDNN_FE_LOG_LABEL_ENDL("ERROR: No suitable SDPA implementation for given SDPA_attributes"); + } + } + + private: + // Check whether implementation `impl` supports the requested features. 
`impl` must not be AUTO. + // (The `implementation` member variable is ignored.) + error_t + verify_sdpa_support_surface_for_implementation(const detail::Context& context, + AttentionImplementation_t impl) const; +}; + +// Type alias for backward compatibility - SDPA_fp8_attributes is now an alias to SDPA_attributes +// All FP8 functionality is unified in SDPA_attributes with the mma_core_mode field +using SDPA_fp8_attributes = SDPA_attributes; + +class SDPA_backward_attributes : public Attributes { + friend class Attributes; + friend class CompositeSDPABackwardNode; + friend class Graph; + using Tensor_t = std::shared_ptr; + using Graph_t = std::shared_ptr; + + using AttentionScoreModifier_t = + std::function, std::shared_ptr)>; + + bool alibi_mask = false; + bool padding_mask = false; + std::optional left_bound; + std::optional right_bound; + DiagonalAlignment_t diagonal_alignment = DiagonalAlignment_t::TOP_LEFT; + + std::optional dropout_probability; + std::optional attn_scale_value; + + std::optional max_total_seq_len_q; + std::optional max_total_seq_len_kv; + + bool is_deterministic_algorithm = false; + AttentionScoreModifier_t attention_score_modifier = nullptr; + AttentionScoreModifier_t attention_score_modifier_bprop = nullptr; + + bool + has_causal_like_masking() const { + return right_bound.has_value(); + } + + bool + has_causal_mask_bottom_right() const { + return right_bound.has_value() && diagonal_alignment == DiagonalAlignment_t::BOTTOM_RIGHT; + } + + public: + enum class input_names { + Q, + K, + V, + O, + dO, + Stats, + Attn_scale, + Bias, + SEQ_LEN_Q, + SEQ_LEN_KV, + Seed, + Offset, + Dropout_mask, + Dropout_scale, + Dropout_scale_inv, + SINK_TOKEN, + }; + std::unordered_map> inputs; + enum class output_names { dQ, dK, dV, dBias, RNG_DUMP, DSINK_TOKEN }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_backward_attributes, + name, + inputs, + outputs, + alibi_mask, + padding_mask, + dropout_probability, + attn_scale_value, + 
left_bound, + right_bound, + diagonal_alignment, + max_total_seq_len_q, + max_total_seq_len_kv, + is_deterministic_algorithm) + + SDPA_backward_attributes& + set_attn_scale(std::shared_ptr value) { + inputs[SDPA_backward_attributes::input_names::Attn_scale] = value; + return *this; + } + + SDPA_backward_attributes& + set_attn_scale(float const value) { + attn_scale_value = value; + return *this; + } + + SDPA_backward_attributes& + set_bias(std::shared_ptr value) { + inputs[SDPA_backward_attributes::input_names::Bias] = value; + return *this; + } + + SDPA_backward_attributes& + set_dbias(std::shared_ptr value) { + outputs[SDPA_backward_attributes::output_names::dBias] = value; + return *this; + } + + SDPA_backward_attributes& + set_alibi_mask(bool const value) { + alibi_mask = value; + return *this; + } + + SDPA_backward_attributes& + set_padding_mask(bool const value) { + padding_mask = value; + return *this; + } + + SDPA_backward_attributes& + set_score_mod(AttentionScoreModifier_t fn) { + attention_score_modifier = std::move(fn); + return *this; + } + + SDPA_backward_attributes& + set_score_mod_bprop(AttentionScoreModifier_t fn) { + attention_score_modifier_bprop = std::move(fn); + return *this; + } + + SDPA_backward_attributes& + set_seq_len_q(std::shared_ptr value) { + inputs[SDPA_backward_attributes::input_names::SEQ_LEN_Q] = value; + return *this; + } + + SDPA_backward_attributes& + set_seq_len_kv(std::shared_ptr value) { + inputs[SDPA_backward_attributes::input_names::SEQ_LEN_KV] = value; + return *this; + } + + SDPA_backward_attributes& + set_max_total_seq_len_q(int64_t const value) { + max_total_seq_len_q = value; + return *this; + } + + SDPA_backward_attributes& + set_max_total_seq_len_kv(int64_t const value) { + max_total_seq_len_kv = value; + return *this; + } + + SDPA_backward_attributes& + set_diagonal_alignment(DiagonalAlignment_t const alignment) { + diagonal_alignment = alignment; + return *this; + } + + // Sets the diagonal position to top left 
and + // calls set_diagonal_band_right_bound(0) if no right_bound was specified + // TODO: Deprecate + SDPA_backward_attributes& + set_causal_mask(bool const value) { + if (value) { + set_diagonal_alignment(DiagonalAlignment_t::TOP_LEFT); + if (!right_bound.has_value()) { + set_diagonal_band_right_bound(0); + } + } + return *this; + } + + // Sets the diagonal position to the bottom right (on a per-sequence basis) + // and calls set_diagonal_band_right_bound(0) if no right_bound was specified + // TODO: Deprecate + SDPA_backward_attributes& + set_causal_mask_bottom_right(bool const value) { + if (value) { + set_diagonal_alignment(DiagonalAlignment_t::BOTTOM_RIGHT); + if (!right_bound.has_value()) { + set_diagonal_band_right_bound(0); + } + } + return *this; + } + + // calls set_diagonal_band_left_bound(value) + // TODO: Deprecate + SDPA_backward_attributes& + set_sliding_window_length(int const value) { + return set_diagonal_band_left_bound(value); + } + + SDPA_backward_attributes& + set_diagonal_band_left_bound(int const value) { + left_bound = value; + return *this; + } + + SDPA_backward_attributes& + set_diagonal_band_right_bound(int const value) { + right_bound = value; + return *this; + } + + SDPA_backward_attributes& + set_dropout(float const probability, + std::shared_ptr seed, + std::shared_ptr offset) { + dropout_probability = probability; + inputs[SDPA_backward_attributes::input_names::Seed] = seed; + inputs[SDPA_backward_attributes::input_names::Offset] = offset; + return *this; + } + + SDPA_backward_attributes& + set_dropout(std::shared_ptr mask, + std::shared_ptr scale, + std::shared_ptr scale_inv) { + inputs[SDPA_backward_attributes::input_names::Dropout_mask] = mask; + inputs[SDPA_backward_attributes::input_names::Dropout_scale] = scale; + inputs[SDPA_backward_attributes::input_names::Dropout_scale_inv] = scale_inv; + return *this; + } + + // For debugging purposes only. 
+ SDPA_backward_attributes& + set_rng_dump(std::shared_ptr value) { + outputs[SDPA_backward_attributes::output_names::RNG_DUMP] = value; + return *this; + } + + SDPA_backward_attributes& + set_deterministic_algorithm(bool const value) { + is_deterministic_algorithm = value; + return *this; + } + + SDPA_backward_attributes& + set_sink_token(std::shared_ptr value) { + inputs[SDPA_backward_attributes::input_names::SINK_TOKEN] = value; + return *this; + } + + SDPA_backward_attributes& + set_dsink_token(std::shared_ptr value) { + outputs[SDPA_backward_attributes::output_names::DSINK_TOKEN] = value; + return *this; + } +}; + +class SDPA_fp8_backward_attributes : public Attributes { + friend class Attributes; + friend class SDPAFP8BackwardNode; + friend class Graph; + + bool padding_mask = false; + bool causal_mask = false; + bool causal_mask_bottom_right = false; + bool is_deterministic_algorithm = false; + + std::optional dropout_probability; + std::optional attn_scale_value; + + public: + enum class input_names { + Q, + K, + V, + O, + dO, + Stats, + Attn_scale, + Bias, + SEQ_LEN_Q, + SEQ_LEN_KV, + Seed, + Offset, + Dropout_mask, + Dropout_scale, + Dropout_scale_inv, + + Descale_Q, + Descale_K, + Descale_V, + Descale_O, + Descale_dO, + Descale_S, + Descale_dP, + Scale_dQ, + Scale_dK, + Scale_dV, + Scale_S, + Scale_dP, + }; + std::unordered_map> inputs; + + enum class output_names { dQ, dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP }; + std::unordered_map> outputs; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_fp8_backward_attributes, + name, + compute_data_type, + inputs, + outputs, + padding_mask, + causal_mask, + dropout_probability, + causal_mask_bottom_right, + attn_scale_value, + is_deterministic_algorithm) + + SDPA_fp8_backward_attributes& + set_attn_scale(std::shared_ptr value) { + inputs[SDPA_fp8_backward_attributes::input_names::Attn_scale] = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_attn_scale(float const value) { + attn_scale_value = value; + 
return *this; + } + + SDPA_fp8_backward_attributes& + set_bias(std::shared_ptr value) { + inputs[SDPA_fp8_backward_attributes::input_names::Bias] = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_padding_mask(bool const value) { + padding_mask = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_seq_len_q(std::shared_ptr value) { + inputs[SDPA_fp8_backward_attributes::input_names::SEQ_LEN_Q] = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_seq_len_kv(std::shared_ptr value) { + inputs[SDPA_fp8_backward_attributes::input_names::SEQ_LEN_KV] = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_causal_mask(bool const value) { + causal_mask = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_causal_mask_bottom_right(bool const value) { + causal_mask_bottom_right = value; + return *this; + } + + SDPA_fp8_backward_attributes& + set_dropout(float const probability, + std::shared_ptr seed, + std::shared_ptr offset) { + dropout_probability = probability; + inputs[SDPA_fp8_backward_attributes::input_names::Seed] = seed; + inputs[SDPA_fp8_backward_attributes::input_names::Offset] = offset; + return *this; + } + + SDPA_fp8_backward_attributes& + set_dropout(std::shared_ptr mask, + std::shared_ptr scale, + std::shared_ptr scale_inv) { + inputs[SDPA_fp8_backward_attributes::input_names::Dropout_mask] = mask; + inputs[SDPA_fp8_backward_attributes::input_names::Dropout_scale] = scale; + inputs[SDPA_fp8_backward_attributes::input_names::Dropout_scale_inv] = scale_inv; + return *this; + } + + SDPA_fp8_backward_attributes& + set_deterministic_algorithm(bool const value) { + is_deterministic_algorithm = value; + return *this; + } +}; + +using Scaled_dot_product_flash_attention_attributes [[deprecated]] = SDPA_attributes; +using Scaled_dot_product_flash_attention_backward_attributes [[deprecated]] = SDPA_backward_attributes; + +class Softmax_attributes : public Attributes { + friend class Attributes; 
+ friend class SoftmaxNode; + friend class INode; + + public: + enum class input_names { P, SINK }; + std::unordered_map> inputs; + enum class output_names { S, Stats, Max, Sum_exp }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Softmax_attributes, name, compute_data_type, inputs, outputs) + + Softmax_attributes& + set_sink(std::shared_ptr value) { + inputs[Softmax_attributes::input_names::SINK] = value; + return *this; + } +}; + +class Conv_wgrad_attributes : public Attributes { + friend class Attributes; + friend class WgradNode; + friend class Graph; + + std::vector pre_padding; + std::vector post_padding; + std::vector stride; + std::vector dilation; + ConvolutionMode_t math_mode = ConvolutionMode_t::CROSS_CORRELATION; + + public: + enum class input_names { DY, X }; + std::unordered_map> inputs; + + enum class output_names { DW }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Conv_wgrad_attributes, + name, + compute_data_type, + inputs, + outputs, + pre_padding, + post_padding, + stride, + dilation, + math_mode) + + ConvolutionMode_t + get_convolution_mode() const { + return math_mode; + } + + std::vector + get_pre_padding() const { + return pre_padding; + } + + std::vector + get_post_padding() const { + return post_padding; + } + + Conv_wgrad_attributes& + set_convolution_mode(ConvolutionMode_t mode_) { + math_mode = mode_; + ; + return *this; + } + + Conv_wgrad_attributes& + set_padding(std::vector value) { + pre_padding = value; + post_padding = value; + return *this; + } + + Conv_wgrad_attributes& + set_pre_padding(std::vector value) { + pre_padding = value; + return *this; + } + + Conv_wgrad_attributes& + set_post_padding(std::vector value) { + post_padding = value; + return *this; + } + + std::vector + get_stride() const { + return stride; + } + + Conv_wgrad_attributes& + set_stride(std::vector value) { + stride = value; + return *this; + } + + std::vector + get_dilation() const { + return dilation; + } + + 
Conv_wgrad_attributes& + set_dilation(std::vector value) { + dilation = value; + return *this; + } +}; + +class Slice_attributes : public Attributes { + friend class Attributes; + friend class SliceNode; + friend class INode; + + std::vector> slices; + + public: + enum class input_names { X }; + std::unordered_map> inputs; + enum class output_names { Y }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Slice_attributes, name, compute_data_type, inputs, outputs, slices) + + Slice_attributes& + set_slices(std::vector> const value) { + slices = value; + return *this; + } + + int64_t + get_offset() const { + auto& input = inputs.at(input_names::X); + auto const input_stride = input->get_stride(); + + int64_t offset = 0; + + // Get number of elements to skip + for (size_t i = 0; i < slices.size(); ++i) { + offset += slices[i].first * input_stride[i]; + } + + // multiply by element size to get offset in bytes + offset *= detail::get_data_type_size(input->get_data_type()); + return offset; + } +}; + +class PagedCacheLoad_attributes : public Attributes { + friend class Attributes; + friend class PagedCacheLoadNode; + friend class INode; + + public: + enum class input_names { container, seqLen, pageTable }; + std::unordered_map> inputs; + enum class output_names { yOut }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(PagedCacheLoad_attributes, name, compute_data_type, inputs, outputs) +}; + +class Block_scale_quantize_attributes : public Attributes { + friend class Attributes; + friend class BlockScaleQuantizeNode; + friend class Graph; + + std::optional block_size; + std::optional axis; + bool transpose = false; + + public: + enum class input_names { X }; + std::unordered_map> inputs; + enum class output_names { Y, scale }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Block_scale_quantize_attributes, + name, + compute_data_type, + inputs, + outputs, + block_size, + axis) + + Block_scale_quantize_attributes& + 
set_block_size(int32_t const value) { + block_size = value; + return *this; + } + + Block_scale_quantize_attributes& + set_axis(int64_t const value) { + axis = value; + return *this; + } + + Block_scale_quantize_attributes& + set_transpose(bool const value) { + transpose = value; + return *this; + } +}; + +class Block_scale_dequantize_attributes : public Attributes { + friend class Attributes; + friend class BlockScaleDequantizeNode; + friend class Graph; + + std::vector block_size; + bool is_negative_scale; + + public: + enum class input_names { X, scale }; + std::unordered_map> inputs; + enum class output_names { Y }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Block_scale_dequantize_attributes, + name, + compute_data_type, + inputs, + outputs, + block_size, + is_negative_scale) + + Block_scale_dequantize_attributes& + set_block_size(int32_t const value, int32_t idx = 0) { + if (idx < 0) { + return *this; + } + if (static_cast(block_size.size()) < idx + 1) { + block_size.resize(idx + 1, 1); + } + block_size[idx] = value; + return *this; + } + + Block_scale_dequantize_attributes& + set_block_size(const int32_t* values, int32_t len = 1) { + if (len < 1) { + return *this; + } + if (static_cast(block_size.size()) < len) { + block_size.resize(len); + } + std::copy(values, values + len, block_size.begin()); + return *this; + } + + Block_scale_dequantize_attributes& + set_block_size(const std::vector& values) { + block_size = values; + return *this; + } + + bool + get_is_negative_scale() const { + return is_negative_scale; + } + + Block_scale_dequantize_attributes& + set_is_negative_scale(bool value) { + is_negative_scale = value; + return *this; + } +}; + +#if 0 +class Concatenate_string { + friend class Attributes; + friend class ConcatenateNode; + friend class Graph; +public: +std::string str; +NLOHMANN_DEFINE_TYPE_INTRUSIVE(Concatenate_string, str) +}; +#endif + +class Concatenate_attributes : public Attributes { + friend class Attributes; + 
friend class ConcatenateNode; + friend class Graph; + + std::optional axis; + std::optional in_place_index; + + public: + std::vector> inputs; + enum class output_names { Y }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Concatenate_attributes, name, inputs, outputs, axis, in_place_index) + + Concatenate_attributes& + set_axis(int64_t const value) { + axis = value; + return *this; + } + + Concatenate_attributes& + set_in_place_index(int64_t const value) { + in_place_index = value; + return *this; + } +}; + +class Moe_grouped_matmul_attributes : public Attributes { + friend class Attributes; + friend class MoeGroupedMatmulNode; + friend class Graph; + + MoeGroupedMatmulMode_t mode = MoeGroupedMatmulMode_t::NONE; + + int32_t top_k = 0; + + public: + enum class input_names { Token, Weight, FirstTokenOffset, TokenIndex, TokenKs }; + std::unordered_map> inputs; + enum class output_names { Output }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Moe_grouped_matmul_attributes, name, inputs, outputs, mode, top_k) + + Moe_grouped_matmul_attributes& + set_mode(MoeGroupedMatmulMode_t mode) { + this->mode = mode; + return *this; + } + + Moe_grouped_matmul_attributes& + set_top_k(int32_t top_k) { + this->top_k = top_k; + return *this; + } +}; + +} // namespace graph + +} // namespace cudnn_frontend diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/knobs.h b/third_party/cudnn-frontend/include/cudnn_frontend/knobs.h new file mode 100644 index 00000000..ee4ee4c4 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/knobs.h @@ -0,0 +1,228 @@ +#pragma once + +namespace cudnn_frontend { + +enum class KnobType_t { + NOT_SET, + + SWIZZLE, + TILE_SIZE, + EDGE, + MULTIPLY, + SPLIT_K_BUF, + TILEK, + STAGES, + REDUCTION_MODE, + SPLIT_K_SLC, + IDX_MODE, + SPECFILT, + KERNEL_CFG, + WORKSPACE, + TILE_CGA_M, + TILE_CGA_N, + BLOCK_SIZE, + OCCUPANCY, + ARRAY_SIZE_PER_THREAD, + SPLIT_COLS, + TILE_ROWS, + TILE_COLS, + LOAD_SIZE, + 
CTA_COUNT, + STREAM_K, + SPLIT_P_SLC, + TILE_M, + TILE_N, + WARP_SPEC_CFG, +}; + +class Knob { + public: + KnobType_t type = KnobType_t::NOT_SET; + int64_t maxValue = 0; + int64_t minValue = 0; + int64_t stride = 0; + + Knob(KnobType_t type, int64_t max, int64_t min, int64_t str) + : type(type), maxValue(max), minValue(min), stride(str) {} +}; + +static inline cudnnStatus_t +convert_to_backend_knob_type(KnobType_t const knob_type, cudnnBackendKnobType_t& cudnn_knob_type) { + switch (knob_type) { + case KnobType_t::SWIZZLE: + cudnn_knob_type = CUDNN_KNOB_TYPE_SWIZZLE; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::TILE_SIZE: + cudnn_knob_type = CUDNN_KNOB_TYPE_TILE_SIZE; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::EDGE: + cudnn_knob_type = CUDNN_KNOB_TYPE_EDGE; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::MULTIPLY: + cudnn_knob_type = CUDNN_KNOB_TYPE_MULTIPLY; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::SPLIT_K_BUF: + cudnn_knob_type = CUDNN_KNOB_TYPE_SPLIT_K_BUF; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::TILEK: + cudnn_knob_type = CUDNN_KNOB_TYPE_TILEK; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::STAGES: + cudnn_knob_type = CUDNN_KNOB_TYPE_STAGES; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::REDUCTION_MODE: + cudnn_knob_type = CUDNN_KNOB_TYPE_REDUCTION_MODE; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::SPLIT_K_SLC: + cudnn_knob_type = CUDNN_KNOB_TYPE_SPLIT_K_SLC; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::IDX_MODE: + cudnn_knob_type = CUDNN_KNOB_TYPE_IDX_MODE; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::SPECFILT: + cudnn_knob_type = CUDNN_KNOB_TYPE_SPECFILT; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::KERNEL_CFG: + cudnn_knob_type = CUDNN_KNOB_TYPE_KERNEL_CFG; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::WORKSPACE: + 
cudnn_knob_type = CUDNN_KNOB_TYPE_WORKSPACE; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#if (CUDNN_VERSION >= 8600) + case KnobType_t::TILE_CGA_M: + cudnn_knob_type = CUDNN_KNOB_TYPE_TILE_CGA_M; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::TILE_CGA_N: + cudnn_knob_type = CUDNN_KNOB_TYPE_TILE_CGA_N; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#endif +#if (CUDNN_VERSION >= 8800) + case KnobType_t::BLOCK_SIZE: + cudnn_knob_type = CUDNN_KNOB_TYPE_BLOCK_SIZE; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#endif +#if (CUDNN_VERSION >= 8900) + case KnobType_t::OCCUPANCY: + cudnn_knob_type = CUDNN_KNOB_TYPE_OCCUPANCY; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::ARRAY_SIZE_PER_THREAD: + cudnn_knob_type = CUDNN_KNOB_TYPE_ARRAY_SIZE_PER_THREAD; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#endif +#if (CUDNN_VERSION >= 8905) + case KnobType_t::SPLIT_COLS: + cudnn_knob_type = CUDNN_KNOB_TYPE_SPLIT_COLS; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::TILE_ROWS: + cudnn_knob_type = CUDNN_KNOB_TYPE_TILE_ROWS; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::TILE_COLS: + cudnn_knob_type = CUDNN_KNOB_TYPE_TILE_COLS; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::LOAD_SIZE: + cudnn_knob_type = CUDNN_KNOB_TYPE_LOAD_SIZE; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#endif +#if (CUDNN_VERSION >= 90700) + case KnobType_t::CTA_COUNT: + cudnn_knob_type = CUDNN_KNOB_TYPE_CTA_COUNT; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::STREAM_K: + cudnn_knob_type = CUDNN_KNOB_TYPE_STREAM_K; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::SPLIT_P_SLC: + cudnn_knob_type = CUDNN_KNOB_TYPE_SPLIT_P_SLC; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::TILE_M: + cudnn_knob_type = CUDNN_KNOB_TYPE_TILE_M; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::TILE_N: + cudnn_knob_type = CUDNN_KNOB_TYPE_TILE_N; + return 
cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case KnobType_t::WARP_SPEC_CFG: + cudnn_knob_type = CUDNN_KNOB_TYPE_WARP_SPEC_CFG; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#endif +#ifndef NO_DEFAULT_IN_SWITCH + default: + return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; +#endif + } + return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; +} + +inline KnobType_t +convert_from_backend_knob_type(cudnnBackendKnobType_t cudnn_knob_type) { + switch (cudnn_knob_type) { + case CUDNN_KNOB_TYPE_SWIZZLE: + return KnobType_t::SWIZZLE; + case CUDNN_KNOB_TYPE_TILE_SIZE: + return KnobType_t::TILE_SIZE; + case CUDNN_KNOB_TYPE_EDGE: + return KnobType_t::EDGE; + case CUDNN_KNOB_TYPE_MULTIPLY: + return KnobType_t::MULTIPLY; + case CUDNN_KNOB_TYPE_SPLIT_K_BUF: + return KnobType_t::SPLIT_K_BUF; + case CUDNN_KNOB_TYPE_TILEK: + return KnobType_t::TILEK; + case CUDNN_KNOB_TYPE_STAGES: + return KnobType_t::STAGES; + case CUDNN_KNOB_TYPE_REDUCTION_MODE: + return KnobType_t::REDUCTION_MODE; + case CUDNN_KNOB_TYPE_SPLIT_K_SLC: + return KnobType_t::SPLIT_K_SLC; + case CUDNN_KNOB_TYPE_IDX_MODE: + return KnobType_t::IDX_MODE; + case CUDNN_KNOB_TYPE_SPECFILT: + return KnobType_t::SPECFILT; + case CUDNN_KNOB_TYPE_KERNEL_CFG: + return KnobType_t::KERNEL_CFG; + case CUDNN_KNOB_TYPE_WORKSPACE: + return KnobType_t::WORKSPACE; +#if (CUDNN_VERSION >= 8600) + case CUDNN_KNOB_TYPE_TILE_CGA_M: + return KnobType_t::TILE_CGA_M; + case CUDNN_KNOB_TYPE_TILE_CGA_N: + return KnobType_t::TILE_CGA_N; +#endif +#if (CUDNN_VERSION >= 8800) + case CUDNN_KNOB_TYPE_BLOCK_SIZE: + return KnobType_t::BLOCK_SIZE; +#endif +#if (CUDNN_VERSION >= 8900) + case CUDNN_KNOB_TYPE_OCCUPANCY: + return KnobType_t::OCCUPANCY; + case CUDNN_KNOB_TYPE_ARRAY_SIZE_PER_THREAD: + return KnobType_t::ARRAY_SIZE_PER_THREAD; +#endif +#if (CUDNN_VERSION >= 8905) + case CUDNN_KNOB_TYPE_SPLIT_COLS: + return KnobType_t::SPLIT_COLS; + case CUDNN_KNOB_TYPE_TILE_ROWS: + return KnobType_t::TILE_ROWS; + case CUDNN_KNOB_TYPE_TILE_COLS: + return 
KnobType_t::TILE_COLS; + case CUDNN_KNOB_TYPE_LOAD_SIZE: + return KnobType_t::LOAD_SIZE; +#endif +#if (CUDNN_VERSION >= 90700) + case CUDNN_KNOB_TYPE_CTA_COUNT: + return KnobType_t::CTA_COUNT; + case CUDNN_KNOB_TYPE_STREAM_K: + return KnobType_t::STREAM_K; + case CUDNN_KNOB_TYPE_SPLIT_P_SLC: + return KnobType_t::SPLIT_P_SLC; + case CUDNN_KNOB_TYPE_TILE_M: + return KnobType_t::TILE_M; + case CUDNN_KNOB_TYPE_TILE_N: + return KnobType_t::TILE_N; + case CUDNN_KNOB_TYPE_WARP_SPEC_CFG: + return KnobType_t::WARP_SPEC_CFG; +#endif + default: + return KnobType_t::NOT_SET; + } +} + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/adaptive_layernorm.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/adaptive_layernorm.h new file mode 100644 index 00000000..86eaa6d1 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/adaptive_layernorm.h @@ -0,0 +1,454 @@ +#pragma once + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { +class AdaLayerNormNode : public NodeCRTP { + public: + AdaLayernorm_attributes attributes; + + AdaLayerNormNode(AdaLayernorm_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::ADALAYERNORM; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for adalayernorm node " << attributes.name); + + attributes.fill_from_context(context); + + auto X = attributes.inputs[AdaLayernorm_attributes::input_names::X]; + auto Y = attributes.outputs[AdaLayernorm_attributes::output_names::Y]; + + // Only infer dims and strides if user did not set them + if (Y->get_dim().empty()) { + Y->set_dim(X->get_dim()); + } + if (Y->get_stride().empty()) 
{ + Y->set_stride(X->get_stride()); + } + + // scale_bias dim is 1,c,h,w + auto scale_bias_dim = X->get_dim(); + scale_bias_dim[0] = 1; + + auto scale = attributes.inputs[AdaLayernorm_attributes::input_names::SCALE]; + // Only infer dims and strides if user did not set them + if (scale->get_dim().empty()) { + scale->set_dim(scale_bias_dim); + } + if (scale->get_stride().empty()) { + auto const& scale_dim = scale->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), scale_dim.size(), stride_order)); + scale->set_stride(detail::generate_stride(scale_dim, stride_order)); + } + + auto bias = attributes.inputs[AdaLayernorm_attributes::input_names::BIAS]; + // Only infer dims and strides if user did not set them + if (bias->get_dim().empty()) { + bias->set_dim(scale_bias_dim); + } + if (bias->get_stride().empty()) { + auto const& bias_dim = bias->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), bias_dim.size(), stride_order)); + bias->set_stride(detail::generate_stride(bias_dim, stride_order)); + } + + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + // stats dim is x where scale == 1 else 1 + auto stats_dim = X->get_dim(); + for (size_t i = 1; i < stats_dim.size(); i++) { + if (scale->get_dim()[i] != 1) { + stats_dim[i] = 1; + } + } + + auto mean = attributes.outputs[AdaLayernorm_attributes::output_names::MEAN]; + // Only infer dims and strides if user did not set them + if (mean->get_dim().empty()) { + mean->set_dim(stats_dim); + } + if (mean->get_stride().empty()) { + auto const& mean_dim = mean->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), mean_dim.size(), stride_order)); + mean->set_stride(detail::generate_stride(mean_dim, stride_order)); + } + + auto inv_var = 
attributes.outputs[AdaLayernorm_attributes::output_names::INV_VARIANCE]; + // Only infer dims and strides if user did not set them + if (inv_var->get_dim().empty()) { + inv_var->set_dim(stats_dim); + } + if (inv_var->get_stride().empty()) { + auto const& inv_var_dim = inv_var->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), inv_var_dim.size(), stride_order)); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); + } + } + + // Set scalar tensors + std::vector ones(X->get_dim().size(), 1); + auto infer_scalar_tensors = [&ones](std::shared_ptr& T) { + // Only infer dims and strides if user did not set them + if (T->get_dim().empty()) { + T->set_dim(ones); + } + if (T->get_stride().empty()) { + T->set_stride(ones); + } + }; + infer_scalar_tensors(attributes.inputs[AdaLayernorm_attributes::input_names::EPSILON]); + + return {error_code_t::OK, ""}; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: " << "Validating AdaLayerNormNode " << attributes.name); + // Norm forward phase should be set + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.forward_phase == NormFwdPhase_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Forward phase not set of adalayernorm node."); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + getLogger() << "[cudnn_frontend] INFO: " << "Building AdaLayernorm operations " << attributes.name << std::endl; + + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "AdaLN fwd requires cuDNN v9.9.0"}; +#if (CUDNN_VERSION >= 90900) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90900, cudnn_ver_error); + CUDNN_FRONTEND_UNUSED(operations); + auto adalayernorm_operation = + 
make_shared_backend_pointer((cudnnBackendDescriptorType_t)CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR); + + cudnnBackendNormMode_t cudnn_norm_mode; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::ADA_LAYER_NORM, cudnn_norm_mode)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.forward_phase, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + auto X = attributes.inputs.find(AdaLayernorm_attributes::input_names::X)->second; + auto backend_x = tensors[X->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_x)); + + auto Scale = attributes.inputs.find(AdaLayernorm_attributes::input_names::SCALE)->second; + auto backend_scale = tensors[Scale->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_scale)); + + auto Bias_iter = attributes.inputs.find(AdaLayernorm_attributes::input_names::BIAS); + if (Bias_iter != attributes.inputs.end() && Bias_iter->second->get_is_virtual() == false) { + auto backend_bias = tensors[Bias_iter->second->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + 
CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_bias)); + } + + auto Epsilon = attributes.inputs.find(AdaLayernorm_attributes::input_names::EPSILON)->second; + auto backend_epsilon = tensors[Epsilon->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_epsilon)); + + auto Y = attributes.outputs.find(AdaLayernorm_attributes::output_names::Y)->second; + auto backend_y = tensors[Y->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_y)); + + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + auto Mean = attributes.outputs.find(AdaLayernorm_attributes::output_names::MEAN)->second; + auto backend_mean = tensors[Mean->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_mean)); + + auto Inv_variance = attributes.outputs.find(AdaLayernorm_attributes::output_names::INV_VARIANCE)->second; + auto backend_inv_variance = tensors[Inv_variance->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_inv_variance)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(adalayernorm_operation->get_backend_descriptor())); + + raw_operations.push_back(adalayernorm_operation); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + 
return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FRONTEND_UNUSED(tensors); + return cudnn_ver_error; +#endif + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "ADA_LAYER_NORM"})"_json); + } +#endif +}; + +/*******/ + +class DAdaLayerNormNode : public NodeCRTP { + public: + AdaLayernorm_backward_attributes attributes; + + DAdaLayerNormNode(AdaLayernorm_backward_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::DADALAYERNORM; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for DAdaLayerNorm node " << attributes.name); + + attributes.fill_from_context(context); + + // TODO: Only inferencing from X works today. 
+ auto X = attributes.inputs[AdaLayernorm_backward_attributes::input_names::X]; + auto const x_tensor_dim = X->get_dim(); + + auto DY = attributes.inputs[AdaLayernorm_backward_attributes::input_names::DY]; + auto dy_tensor_dim = DY->get_dim(); + + // Only infer dims and strides if user did not set them + if (dy_tensor_dim.empty()) { + dy_tensor_dim.resize(x_tensor_dim.size()); + DY->set_dim(x_tensor_dim); + } + if (DY->get_stride().empty()) { + auto const& DY_dim = DY->get_dim(); + // Default to NCHW + auto const& stride_order = detail::generate_row_major_stride_order(DY_dim.size()); + DY->set_stride(detail::generate_stride(DY_dim, stride_order)); + } + + auto DX = attributes.outputs[AdaLayernorm_backward_attributes::output_names::DX]; + auto dx_tensor_dim = DX->get_dim(); + // Only infer dims and strides if user did not set them + if (dx_tensor_dim.empty()) { + dx_tensor_dim.resize(x_tensor_dim.size()); + DX->set_dim(x_tensor_dim); + } + if (DX->get_stride().empty()) { + auto const& DX_dim = DX->get_dim(); + // Default to NCHW + auto const& stride_order = detail::generate_row_major_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); + } + + auto SCALE = attributes.inputs[AdaLayernorm_backward_attributes::input_names::SCALE]; + auto scale_bias_dim = SCALE->get_dim(); + + // Set channel length tensors + auto infer_scale_bias_tensors = [&scale_bias_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + T->set_dim(scale_bias_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NCHW + auto const& stride_order = detail::generate_row_major_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + + infer_scale_bias_tensors(attributes.outputs[AdaLayernorm_backward_attributes::output_names::DSCALE]); + auto DBIAS = 
attributes.outputs.at(AdaLayernorm_backward_attributes::output_names::DBIAS); + if (DBIAS->get_is_virtual() == false) { + infer_scale_bias_tensors(DBIAS); + } + + return {error_code_t::OK, ""}; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating DAdaLayerNormNode node " << attributes.name); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + getLogger() << "[cudnn_frontend] INFO: " << "Building DAdaLayerNormNode operations " << attributes.name + << std::endl; + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "AdaLN bwd requires cuDNN v9.9.0"}; +#if (CUDNN_VERSION >= 90900) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90900, cudnn_ver_error); + CUDNN_FRONTEND_UNUSED(operations); + auto adalayernorm_operation = + make_shared_backend_pointer((cudnnBackendDescriptorType_t)CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR); + + cudnnBackendNormMode_t cudnn_norm_mode; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::ADA_LAYER_NORM, cudnn_norm_mode)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_BWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + auto X = attributes.inputs.find(AdaLayernorm_backward_attributes::input_names::X)->second; + auto backend_x = tensors[X->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_BWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_x)); + + auto Mean = attributes.inputs.find(AdaLayernorm_backward_attributes::input_names::MEAN)->second; + auto backend_mean = 
tensors[Mean->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_mean)); + + auto Inv_variance = attributes.inputs.find(AdaLayernorm_backward_attributes::input_names::INV_VARIANCE)->second; + auto backend_inv_variance = tensors[Inv_variance->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_inv_variance)); + + auto Dy = attributes.inputs.find(AdaLayernorm_backward_attributes::input_names::DY)->second; + auto backend_dy = tensors[Dy->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_dy)); + + auto Scale = attributes.inputs.find(AdaLayernorm_backward_attributes::input_names::SCALE)->second; + auto backend_scale = tensors[Scale->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_scale)); + + auto Dx = attributes.outputs.find(AdaLayernorm_backward_attributes::output_names::DX)->second; + auto backend_dx = tensors[Dx->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_dx)); + + auto Dscale = attributes.outputs.find(AdaLayernorm_backward_attributes::output_names::DSCALE)->second; + auto backend_dscale = 
tensors[Dscale->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_dscale)); + + auto Dbias_iter = attributes.outputs.find(AdaLayernorm_backward_attributes::output_names::DBIAS); + if (Dbias_iter != attributes.outputs.end() && Dbias_iter->second->get_is_virtual() == false) { + auto backend_dbias = tensors[Dbias_iter->second->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(adalayernorm_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_dbias)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(adalayernorm_operation->get_backend_descriptor())); + + raw_operations.push_back(adalayernorm_operation); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FRONTEND_UNUSED(tensors); + return cudnn_ver_error; +#endif + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "ADA_LAYER_NORM_BPROP"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/batchnorm.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/batchnorm.h new file mode 100644 index 00000000..69d232be --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/batchnorm.h @@ -0,0 +1,268 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace 
cudnn_frontend { + +namespace graph { +class BatchNormNode : public NodeCRTP { + public: + Batchnorm_attributes attributes; + + BatchNormNode(Batchnorm_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::BATCHNORM; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for batchnorm node " << attributes.name); + + attributes.fill_from_context(context); + + auto X = attributes.inputs[Batchnorm_attributes::input_names::X]; + auto Y = attributes.outputs[Batchnorm_attributes::output_names::Y]; + // Only infer dims and strides if user did not set them + if (Y->get_dim().empty()) { + Y->set_dim(X->get_dim()); + } + if (Y->get_stride().empty()) { + Y->set_stride(X->get_stride()); + } + + // Set channel length tensors + auto const x_tensor_dim = X->get_dim(); + auto infer_per_channel_tensors = [&x_tensor_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + tensor_dim.resize(x_tensor_dim.size(), 1); + tensor_dim[1] = x_tensor_dim[1]; + T->set_dim(tensor_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + infer_per_channel_tensors(attributes.outputs[Batchnorm_attributes::output_names::MEAN]); + infer_per_channel_tensors(attributes.outputs[Batchnorm_attributes::output_names::INV_VARIANCE]); + + auto has_running_stats = attributes.inputs[Batchnorm_attributes::input_names::PREV_RUNNING_MEAN] || + attributes.inputs[Batchnorm_attributes::input_names::PREV_RUNNING_VAR]; + + if (has_running_stats) { + infer_per_channel_tensors(attributes.outputs[Batchnorm_attributes::output_names::NEXT_RUNNING_MEAN]); + 
infer_per_channel_tensors(attributes.outputs[Batchnorm_attributes::output_names::NEXT_RUNNING_VAR]); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: Building BatchNormNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 batchnorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + batchnorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode to BATCH_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::BATCH_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set forward phase to TRAINING + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormFwdPhase_t::TRAINING, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Batchnorm_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set saved mean and inv_variance + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(MEAN, Batchnorm_attributes::output_names::MEAN); + auto mean_desc = 
tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, Batchnorm_attributes::output_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set scale and bias tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Batchnorm_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, Batchnorm_attributes::input_names::BIAS); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + + // Set epsilon tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Batchnorm_attributes::input_names::EPSILON); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + + // Check for running stats + bool has_running_stats = true; + auto it = attributes.inputs.find(Batchnorm_attributes::input_names::PREV_RUNNING_MEAN); + if (it == attributes.inputs.end() || it->second == nullptr) { + has_running_stats = 
false; + } + + if (has_running_stats) { + // Set momentum (exp decay factor) + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MOMENTUM, Batchnorm_attributes::input_names::MOMENTUM); + auto momentum_desc = tensors.at(MOMENTUM->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EXP_AVG_FACTOR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &momentum_desc)); + + // Set prev running mean and var + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(PREV_RUNNING_MEAN, + Batchnorm_attributes::input_names::PREV_RUNNING_MEAN); + auto prev_mean_desc = tensors.at(PREV_RUNNING_MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &prev_mean_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(PREV_RUNNING_VAR, + Batchnorm_attributes::input_names::PREV_RUNNING_VAR); + auto prev_var_desc = tensors.at(PREV_RUNNING_VAR->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INPUT_RUNNING_VAR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &prev_var_desc)); + + // Set next running mean and var + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(NEXT_RUNNING_MEAN, + Batchnorm_attributes::output_names::NEXT_RUNNING_MEAN); + auto next_mean_desc = tensors.at(NEXT_RUNNING_MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &next_mean_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(NEXT_RUNNING_VAR, + Batchnorm_attributes::output_names::NEXT_RUNNING_VAR); + auto next_var_desc = tensors.at(NEXT_RUNNING_VAR->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_OUTPUT_RUNNING_VAR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &next_var_desc)); + } + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Batchnorm_attributes::output_names::Y); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + // Set peer stat tensors if any + if (!attributes.peer_stats.empty()) { + std::vector peer_stat_descs; + for (auto const& peer_stat : attributes.peer_stats) { + peer_stat_descs.push_back(tensors.at(peer_stat->get_uid())->get_raw_desc()); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PEER_STAT_DESCS, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + peer_stat_descs.size(), + peer_stat_descs.data())); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(batchnorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(batchnorm_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "BATCHNORM"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/batchnorm_inference.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/batchnorm_inference.h new file mode 100644 index 00000000..9fb433f9 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/batchnorm_inference.h @@ -0,0 +1,156 @@ +#pragma once + 
+#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { +class BatchnormInferenceNode : public NodeCRTP { + public: + Batchnorm_inference_attributes attributes; + + BatchnormInferenceNode(Batchnorm_inference_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::BATCHNORM_INFERENCE; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for batchnorm inference node " << attributes.name); + + attributes.fill_from_context(context); + + auto X = attributes.inputs[Batchnorm_inference_attributes::input_names::X]; + auto Y = attributes.outputs[Batchnorm_inference_attributes::output_names::Y]; + // Only infer dims and strides if user did not set them + if (Y->get_dim().empty()) { + Y->set_dim(X->get_dim()); + } + if (Y->get_stride().empty()) { + Y->set_stride(X->get_stride()); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: Building BatchnormInferenceNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 batchnorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + batchnorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode to BATCH_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::BATCH_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + 
&cudnn_norm_mode)); + + // Set forward phase to INFERENCE + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormFwdPhase_t::INFERENCE, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Batchnorm_inference_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set mean and inv_variance (as inputs for inference) + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, Batchnorm_inference_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, + Batchnorm_inference_attributes::input_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set scale and bias tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Batchnorm_inference_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + 
CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, Batchnorm_inference_attributes::input_names::BIAS); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Batchnorm_inference_attributes::output_names::Y); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(batchnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(batchnorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(batchnorm_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "BATCHNORM_INFERENCE"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/block_scale_dequantize.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/block_scale_dequantize.h new file mode 100644 index 00000000..bbaa2bf0 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/block_scale_dequantize.h @@ -0,0 +1,169 @@ +#pragma once + +#include "../../cudnn_frontend_Logging.h" +#include "../../cudnn_frontend_shim.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { + +class BlockScaleDequantizeNode : public NodeCRTP { + public: + 
Block_scale_dequantize_attributes attributes; + + BlockScaleDequantizeNode(Block_scale_dequantize_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::BLOCK_SCALE_DEQUANTIZE; + } + + error_t + pre_validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating BlockScaleDequantizeNode " << attributes.name << std::endl; + + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.block_size.empty(), error_code_t::ATTRIBUTE_NOT_SET, "Block size not set\n"); + + auto Y = attributes.outputs.at(Block_scale_dequantize_attributes::output_names::Y); + + RETURN_CUDNN_FRONTEND_ERROR_IF(!(Y->get_is_virtual()), + error_code_t::INVALID_VALUE, + "Output tensor of dequantize node should be virtual\n"); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferencing properties for BlockScaleDequantizeNode " << attributes.name + << std::endl; + + attributes.fill_from_context(context); + + auto X = attributes.inputs[Block_scale_dequantize_attributes::input_names::X]; + auto scale = attributes.inputs[Block_scale_dequantize_attributes::input_names::scale]; + auto Y = attributes.outputs[Block_scale_dequantize_attributes::output_names::Y]; + + // Only infer dims and strides if user did not set them + if (Y->get_dim().empty()) { + Y->set_dim(X->get_dim()); + } + + if (Y->get_stride().empty()) { + Y->set_stride(X->get_stride()); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building BlockScaleDequantizeNode operations " << attributes.name << std::endl; + auto cudnn_ver_error = + 
error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Block scale dequantize requires cuDNN v9.7.0"}; + +#if (CUDNN_VERSION >= 90700) // TODO: v9.99 is new feature branch; switch to release branch when ready + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90700, cudnn_ver_error); + CUDNN_FRONTEND_UNUSED(operations); + auto block_scale_dequantize_operation = make_shared_backend_pointer( + (cudnnBackendDescriptorType_t)CUDNN_BACKEND_OPERATION_BLOCK_SCALE_DEQUANTIZE_DESCRIPTOR); + + auto X = attributes.inputs.find(Block_scale_dequantize_attributes::input_names::X)->second; + auto backend_x = tensors[X->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_dequantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_DEQUANTIZE_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_x)); + + auto scale = attributes.inputs.find(Block_scale_dequantize_attributes::input_names::scale)->second; + auto backend_scale = tensors[scale->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_dequantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_DEQUANTIZE_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_scale)); + + auto Y = attributes.outputs.find(Block_scale_dequantize_attributes::output_names::Y)->second; + auto backend_y = tensors[Y->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_dequantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_DEQUANTIZE_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_y)); + + cudnnDataType_t cudnn_data_type; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_dequantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_DEQUANTIZE_MATH_PREC, + 
CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + const int32_t* block_size = attributes.block_size.data(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_dequantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_DEQUANTIZE_BLOCK_SIZE, + CUDNN_TYPE_INT32, + attributes.block_size.size(), + block_size)); + +#if (CUDNN_VERSION >= 91400) + if (detail::get_backend_version() >= 91400) { + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_dequantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_DEQUANTIZE_NEG_SCALE, + CUDNN_TYPE_BOOLEAN, + 1, + &attributes.is_negative_scale)); + } +#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(block_scale_dequantize_operation->get_backend_descriptor())); + + raw_operations.push_back(block_scale_dequantize_operation); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FRONTEND_UNUSED(tensors); + return cudnn_ver_error; +#endif + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "BLOCK_SCALE_DEQUANTIZE"})"_json); + } +#endif +}; + +inline void +INode::block_scale_dequantize(std::shared_ptr x, + std::shared_ptr scale, + Block_scale_dequantize_attributes attributes, + std::shared_ptr y) { + attributes.inputs[Block_scale_dequantize_attributes::input_names::X] = x; + attributes.inputs[Block_scale_dequantize_attributes::input_names::scale] = scale; + attributes.outputs[Block_scale_dequantize_attributes::output_names::Y] = y; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff 
--git a/third_party/cudnn-frontend/include/cudnn_frontend/node/block_scale_quantize.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/block_scale_quantize.h new file mode 100644 index 00000000..aba536fc --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/block_scale_quantize.h @@ -0,0 +1,205 @@ +#pragma once + +#include "../../cudnn_frontend_Logging.h" +#include "../../cudnn_frontend_shim.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { + +class BlockScaleQuantizeNode : public NodeCRTP { + public: + Block_scale_quantize_attributes attributes; + + BlockScaleQuantizeNode(Block_scale_quantize_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::BLOCK_SCALE_QUANTIZE; + } + + error_t + pre_validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " << "Validating BlockScaleQuantizeNode " << attributes.name + << std::endl; + + RETURN_CUDNN_FRONTEND_ERROR_IF( + !attributes.block_size.has_value(), error_code_t::ATTRIBUTE_NOT_SET, "Block size not set."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferencing properties for BlockScaleQuantizeNode " << attributes.name + << std::endl; + + attributes.fill_from_context(context); + + auto X = attributes.inputs[Block_scale_quantize_attributes::input_names::X]; + auto Y = attributes.outputs[Block_scale_quantize_attributes::output_names::Y]; + auto scale = attributes.outputs[Block_scale_quantize_attributes::output_names::scale]; + + // Block scale quantize requires the block scale axis to be packed + auto infer_strides_transposed = [&X](std::shared_ptr& T, + std::optional const& axis) { + auto const& dim = T->get_dim(); + auto const& X_dim = X->get_dim(); + auto const& X_stride = X->get_stride(); + + 
std::vector indices(X_stride.size()); + std::iota(indices.begin(), indices.end(), 0); + // Sort indices based on stride values in descending order + std::sort(indices.begin(), indices.end(), [&X_dim, &X_stride](int64_t i, int64_t j) { + // Prioritize singleton dimensions + if (X_stride[i] == X_stride[j]) { + return (X_dim[i] == 1) || (X_dim[j] != 1); + } + return X_stride[i] < X_stride[j]; + }); + if (axis) { + // Rotate left until the axis is the packed dim + std::rotate(indices.begin(), std::find(indices.begin(), indices.end(), axis.value()), indices.end()); + } + std::vector stride_order(X_stride.size()); + for (size_t i = 0; i < indices.size(); ++i) { + stride_order[indices[i]] = i; + } + T->set_stride(detail::generate_stride(dim, stride_order)); + }; + + // Only infer dims and strides if user did not set them + if (Y->get_dim().empty()) { + Y->set_dim(X->get_dim()); + } + if (Y->get_stride().empty()) { + if (attributes.transpose) { + infer_strides_transposed(Y, attributes.axis); + } else { + Y->set_stride(X->get_stride()); + } + } + + // Only infer dims and strides if user did not set them + if (scale->get_dim().empty()) { + auto scale_dim = X->get_dim(); + if (attributes.axis) { + scale_dim[attributes.axis.value()] /= attributes.block_size.value(); + } else { + scale_dim.back() /= attributes.block_size.value(); + } + scale->set_dim(scale_dim); + } + if (scale->get_stride().empty()) { + if (attributes.transpose) { + infer_strides_transposed(scale, attributes.axis); + } else { + auto const& scale_dim = scale->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), scale_dim.size(), stride_order)); + scale->set_stride(detail::generate_stride(scale_dim, stride_order)); + } + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + 
std::unordered_map>& tensors) const override final { + getLogger() << "[cudnn_frontend] INFO: " << "Building BlockScaleQuantizeNode operations " << attributes.name + << std::endl; + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Block scale quantize requires cuDNN v9.7.0"}; + +#if (CUDNN_VERSION >= 90700) // TODO: v9.99 is new feature branch; switch to release branch when ready + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90700, cudnn_ver_error); + CUDNN_FRONTEND_UNUSED(operations); + auto block_scale_quantize_operation = make_shared_backend_pointer( + (cudnnBackendDescriptorType_t)CUDNN_BACKEND_OPERATION_BLOCK_SCALE_QUANTIZE_DESCRIPTOR); + + auto X = attributes.inputs.find(Block_scale_quantize_attributes::input_names::X)->second; + auto backend_x = tensors[X->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_quantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_QUANTIZE_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_x)); + + auto Y = attributes.outputs.find(Block_scale_quantize_attributes::output_names::Y)->second; + auto backend_y = tensors[Y->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_quantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_QUANTIZE_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_y)); + + auto scale = attributes.outputs.find(Block_scale_quantize_attributes::output_names::scale)->second; + auto backend_scale = tensors[scale->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_quantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_QUANTIZE_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_scale)); + + cudnnDataType_t cudnn_data_type; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, 
cudnn_data_type)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_quantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_QUANTIZE_MATH_PREC, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + int32_t block_size = attributes.block_size.value(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(block_scale_quantize_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_BLOCK_SCALE_QUANTIZE_BLOCK_SIZE, + CUDNN_TYPE_INT32, + 1, + &block_size)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(block_scale_quantize_operation->get_backend_descriptor())); + + raw_operations.push_back(block_scale_quantize_operation); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FRONTEND_UNUSED(tensors); + return cudnn_ver_error; +#endif + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "BLOCK_SCALE_QUANTIZE"})"_json); + } +#endif +}; + +inline void +INode::block_scale_quantize(std::shared_ptr x, + Block_scale_quantize_attributes attributes, + std::shared_ptr y, + std::shared_ptr scale) { + attributes.inputs[Block_scale_quantize_attributes::input_names::X] = x; + attributes.outputs[Block_scale_quantize_attributes::output_names::Y] = y; + attributes.outputs[Block_scale_quantize_attributes::output_names::scale] = scale; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/bn_finalize.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/bn_finalize.h new file mode 100644 index 
00000000..99d3ec0c --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/bn_finalize.h @@ -0,0 +1,263 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { + +class BatchNormFinalizeNode : public NodeCRTP { + public: + BN_finalize_attributes attributes; + + BatchNormFinalizeNode(BN_finalize_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::BN_FINALIZE; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO:Inferencing properties for batchnorm finalize node " << attributes.name); + + attributes.fill_from_context(context); + + auto SUM = attributes.inputs[BN_finalize_attributes::input_names::SUM]; + auto const sum_tensor_dim = SUM->get_dim(); + + // Set channel length tensors + auto infer_per_channel_tensors = [&sum_tensor_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + tensor_dim = sum_tensor_dim; + T->set_dim(tensor_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + infer_per_channel_tensors(attributes.outputs[BN_finalize_attributes::output_names::EQ_BIAS]); + infer_per_channel_tensors(attributes.outputs[BN_finalize_attributes::output_names::EQ_SCALE]); + infer_per_channel_tensors(attributes.outputs[BN_finalize_attributes::output_names::MEAN]); + infer_per_channel_tensors(attributes.outputs[BN_finalize_attributes::output_names::INV_VARIANCE]); + infer_per_channel_tensors(attributes.outputs[BN_finalize_attributes::output_names::NEXT_RUNNING_MEAN]); + 
infer_per_channel_tensors(attributes.outputs[BN_finalize_attributes::output_names::NEXT_RUNNING_VAR]); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building BatchNormFinalizeNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 bn_finalize_operation; + + _CUDNN_CHECK_CUDNN_ERROR(bn_finalize_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR)); + + // Set BN finalize mode + cudnnBnFinalizeStatsMode_t bn_finalize_mode = CUDNN_BN_FINALIZE_STATISTICS_TRAINING; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE, + CUDNN_TYPE_BN_FINALIZE_STATS_MODE, + 1, + &bn_finalize_mode)); + + // Set compute type (math precision) + cudnnDataType_t compute_type = CUDNN_DATA_FLOAT; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC, + CUDNN_TYPE_DATA_TYPE, + 1, + &compute_type)); + + // Set SUM input tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SUM, BN_finalize_attributes::input_names::SUM); + auto sum_desc = tensors.at(SUM->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &sum_desc)); + + // Set SQ_SUM input tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SQ_SUM, BN_finalize_attributes::input_names::SQ_SUM); + auto sq_sum_desc = tensors.at(SQ_SUM->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &sq_sum_desc)); + + // Set SCALE input tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, BN_finalize_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set BIAS input tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, BN_finalize_attributes::input_names::BIAS); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + + // Set EQ_SCALE output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_SCALE, BN_finalize_attributes::output_names::EQ_SCALE); + auto eq_scale_desc = tensors.at(EQ_SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_scale_desc)); + + // Set EQ_BIAS output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_BIAS, BN_finalize_attributes::output_names::EQ_BIAS); + auto eq_bias_desc = tensors.at(EQ_BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_bias_desc)); + + // Set PREV_RUNNING_MEAN input tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(PREV_RUNNING_MEAN, + BN_finalize_attributes::input_names::PREV_RUNNING_MEAN); + auto prev_running_mean_desc = 
tensors.at(PREV_RUNNING_MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &prev_running_mean_desc)); + + // Set PREV_RUNNING_VAR input tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(PREV_RUNNING_VAR, + BN_finalize_attributes::input_names::PREV_RUNNING_VAR); + auto prev_running_var_desc = tensors.at(PREV_RUNNING_VAR->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &prev_running_var_desc)); + + // Set NEXT_RUNNING_MEAN output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(NEXT_RUNNING_MEAN, + BN_finalize_attributes::output_names::NEXT_RUNNING_MEAN); + auto next_running_mean_desc = tensors.at(NEXT_RUNNING_MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &next_running_mean_desc)); + + // Set NEXT_RUNNING_VAR output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(NEXT_RUNNING_VAR, + BN_finalize_attributes::output_names::NEXT_RUNNING_VAR); + auto next_running_var_desc = tensors.at(NEXT_RUNNING_VAR->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &next_running_var_desc)); + + // Set MEAN output tensor (saved mean) + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(MEAN, BN_finalize_attributes::output_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + // Set INV_VARIANCE output tensor (saved inv std) + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, BN_finalize_attributes::output_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set EPSILON tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, BN_finalize_attributes::input_names::EPSILON); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + + // Set MOMENTUM tensor (exp average factor) + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MOMENTUM, BN_finalize_attributes::input_names::MOMENTUM); + auto momentum_desc = tensors.at(MOMENTUM->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &momentum_desc)); + + // Set ACCUM_COUNT tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(ACCUM_COUNT, BN_finalize_attributes::input_names::ACCUM_COUNT); + auto accum_count_desc = tensors.at(ACCUM_COUNT->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_finalize_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &accum_count_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(bn_finalize_operation.get_raw_desc())); + + 
operations.push_back(std::make_shared(std::move(bn_finalize_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "BN_FINALIZE"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/concatenate.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/concatenate.h new file mode 100644 index 00000000..0c051186 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/concatenate.h @@ -0,0 +1,164 @@ +#pragma once + +#include "../../cudnn_frontend_Logging.h" +#include "../../cudnn_frontend_shim.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" +#include + +namespace cudnn_frontend { + +namespace graph { + +class ConcatenateNode : public NodeCRTP { + public: + Concatenate_attributes attributes; + + ConcatenateNode(Concatenate_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::CONCATENATE; + } + + error_t + pre_validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " << "Validating ConcatenateNode " << attributes.name << std::endl; + + RETURN_CUDNN_FRONTEND_ERROR_IF(!attributes.axis.has_value(), error_code_t::ATTRIBUTE_NOT_SET, "Axis not set\n"); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + !attributes.in_place_index.has_value(), error_code_t::ATTRIBUTE_NOT_SET, "In-place index not set\n"); + + auto X = attributes.inputs; + + RETURN_CUDNN_FRONTEND_ERROR_IF( + (X.size() == 0), error_code_t::INVALID_VALUE, "Input size of the concatenate node cannot be zero\n"); + + return 
{error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferring properties for ConcatenateNode " << attributes.name + << std::endl; + + attributes.fill_from_context(context); + + auto Y = attributes.outputs[Concatenate_attributes::output_names::Y]; + + // Infer dims and strides only if user did not set them + int64_t dim_sum = 0; + for (const auto& input : attributes.inputs) { + dim_sum += input->get_dim()[attributes.axis.value()]; + } + + auto X = attributes.inputs[0]; + auto dims = X->get_dim(); + dims[attributes.axis.value()] = dim_sum; + + if (Y->get_dim().empty()) { + Y->set_dim(dims); + Y->set_dim(dims); + } + + if (Y->get_stride().empty()) { + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), dims.size(), stride_order)); + Y->set_stride(detail::generate_stride(dims, stride_order)); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + getLogger() << "[cudnn_frontend] INFO: " << "Building ConcatenateNode operations " << attributes.name + << std::endl; + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Concatenate requires cuDNN v9.7.0"}; + +#if (CUDNN_VERSION >= 90700) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90700, cudnn_ver_error); + CUDNN_FRONTEND_UNUSED(operations); + auto concatenate_operation = make_shared_backend_pointer(CUDNN_BACKEND_OPERATION_CONCAT_DESCRIPTOR); + + std::vector backend_x(attributes.inputs.size()); + size_t index = 0; + for (const auto& input : attributes.inputs) { + backend_x[index] = tensors[input->get_uid()]->get_desc()->get_backend_descriptor(); + index++; + } + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(concatenate_operation->get_backend_descriptor(), 
+ CUDNN_ATTR_OPERATION_CONCAT_INPUT_DESCS, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + attributes.inputs.size(), + backend_x.data())); + + auto Y = attributes.outputs.find(Concatenate_attributes::output_names::Y)->second; + auto backend_y = tensors[Y->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(concatenate_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_CONCAT_OUTPUT_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_y)); + + int64_t axis = attributes.axis.value(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(concatenate_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_CONCAT_AXIS, + CUDNN_TYPE_INT64, + 1, + &axis)); + + int64_t in_place_index = attributes.in_place_index.value(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(concatenate_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_CONCAT_INPLACE_INDEX, + CUDNN_TYPE_INT64, + 1, + &in_place_index)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(concatenate_operation->get_backend_descriptor())); + + raw_operations.push_back(concatenate_operation); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FRONTEND_UNUSED(tensors); + return cudnn_ver_error; +#endif + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "CONCATENATE"})"_json); + } +#endif +}; + +inline void +INode::concatenate(std::vector> x, + Concatenate_attributes attributes, + std::shared_ptr y) { + for (auto& element : x) { + attributes.inputs.push_back(element); + } + attributes.outputs[Concatenate_attributes::output_names::Y] = y; + 
sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +} // namespace graph + +} // namespace cudnn_frontend diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/conv_dgrad.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/conv_dgrad.h new file mode 100644 index 00000000..a2e2e4e8 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/conv_dgrad.h @@ -0,0 +1,205 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class DgradNode : public NodeCRTP { + public: + Conv_dgrad_attributes attributes; + + DgradNode(Conv_dgrad_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::DGRAD; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating Node Type::DGRAD " << attributes.name); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_pre_padding().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Pre padding not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_post_padding().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Post padding not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_stride().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Conv strides not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_dilation().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Conv dilation not set."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for dgrad node " << attributes.name); + + attributes.fill_from_context(context); + + // TODO: Only inferrencing from (X, DY) -> DW works today. 
+ auto DX = attributes.outputs.find(Conv_dgrad_attributes::output_names::DX)->second; + auto W = attributes.inputs.find(Conv_dgrad_attributes::input_names::W)->second; + auto DY = attributes.inputs.find(Conv_dgrad_attributes::input_names::DY)->second; + + auto const w_tensor_dim = W->get_dim(); + auto const dy_tensor_dim = DY->get_dim(); + auto dx_tensor_dim = DX->get_dim(); + + RETURN_CUDNN_FRONTEND_ERROR_IF(DX->get_dim().empty(), + error_code_t::ATTRIBUTE_NOT_SET, + "For dgrad node, output dimension inferencing is not possible."); + + // No dim inferencing as inverse mapping from DY, W to DX is not unique. + // Only infer strides if user did not set them + if (DX->get_stride().empty()) { + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: Building DgradNode operations " << attributes.name << " "); + + // Create dgrad descriptor by directly calling cuDNN backend API + ConvDesc_v8 dgrad_descriptor; + int64_t const spatial_dim_count = attributes.get_pre_padding().size(); + + _CUDNN_CHECK_CUDNN_ERROR( + dgrad_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR)); + + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + cudnnConvolutionMode_t mode = detail::convert_to_cudnn_type(attributes.math_mode); + 
+ _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + dgrad_descriptor.get_raw_desc(), CUDNN_ATTR_CONVOLUTION_CONV_MODE, CUDNN_TYPE_CONVOLUTION_MODE, 1, &mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, + CUDNN_TYPE_INT64, + 1, + &spatial_dim_count)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_pre_padding().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_post_padding().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_DILATIONS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_dilation().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_stride().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(dgrad_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(dgrad_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 dgrad_operation; + + _CUDNN_CHECK_CUDNN_ERROR(dgrad_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Conv_dgrad_attributes::output_names::DX); + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(W, Conv_dgrad_attributes::input_names::W); + auto w_desc = 
tensors.at(W->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &w_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Conv_dgrad_attributes::input_names::DY); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + auto conv_desc_ptr = dgrad_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &conv_desc_ptr)); + + float alpha = 1.0f; + float beta = 0.0f; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA, + CUDNN_TYPE_FLOAT, + 1, + &alpha)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA, + CUDNN_TYPE_FLOAT, + 1, + &beta)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(dgrad_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(dgrad_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "CONV_DGRAD"})"_json); + } +#endif +}; + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/conv_fprop.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/conv_fprop.h new file mode 100644 index 
00000000..f8f22ec7 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/conv_fprop.h @@ -0,0 +1,237 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { +class ConvolutionNode : public NodeCRTP { + public: + Conv_fprop_attributes attributes; + + ConvolutionNode(Conv_fprop_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::CONVOLUTION; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating Node Type::CONVOLUTION " << attributes.name); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_pre_padding().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Pre padding not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_post_padding().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Post padding not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_stride().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Conv strides not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_dilation().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Conv dilation not set."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for conv node " << attributes.name); + + attributes.fill_from_context(context); + + // TODO: Only inferrencing from (X, W) -> Y works today. 
+ auto& X = attributes.inputs.find(Conv_fprop_attributes::input_names::X)->second; + auto& W = attributes.inputs.find(Conv_fprop_attributes::input_names::W)->second; + auto& Y = attributes.outputs.find(Conv_fprop_attributes::output_names::Y)->second; + + auto const x_tensor_dim = X->get_dim(); + auto const w_tensor_dim = W->get_dim(); + auto y_tensor_dim = Y->get_dim(); + + // Only infer dims and strides if user did not set them + if (y_tensor_dim.empty()) { + y_tensor_dim.resize(x_tensor_dim.size()); + auto const& pre_padding = attributes.get_pre_padding(); + auto const& post_padding = attributes.get_post_padding(); + auto const& stride = attributes.get_stride(); + auto const& dilation = attributes.get_dilation(); + // N + y_tensor_dim[0] = x_tensor_dim[0]; + // PQ + for (size_t dim = 2; dim < x_tensor_dim.size(); ++dim) { + y_tensor_dim[dim] = 1 + (x_tensor_dim[dim] - dilation[dim - 2] * (w_tensor_dim[dim] - 1) - 1 + + pre_padding[dim - 2] + post_padding[dim - 2]) / + stride[dim - 2]; + } + // K + y_tensor_dim[1] = w_tensor_dim[0]; + Y->set_dim(y_tensor_dim); + } + if (Y->get_stride().empty()) { + auto const& Y_dim = Y->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(Y_dim.size()); + Y->set_stride(detail::generate_stride(Y_dim, stride_order)); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: Building ConvolutionNode operations " << attributes.name << " "); + + // Create convolution descriptor by directly calling cuDNN backend API + ConvDesc_v8 convolution_descriptor; + int64_t const spatial_dim_count = attributes.get_pre_padding().size(); + + _CUDNN_CHECK_CUDNN_ERROR( + 
convolution_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR)); + + // Set compute type + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + // Set convolution mode + cudnnConvolutionMode_t mode = detail::convert_to_cudnn_type(attributes.math_mode); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_CONV_MODE, + CUDNN_TYPE_CONVOLUTION_MODE, + 1, + &mode)); + + // Set spatial dimensions + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, + CUDNN_TYPE_INT64, + 1, + &spatial_dim_count)); + + // Set pre-padding + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_pre_padding().data())); + + // Set post-padding + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_post_padding().data())); + + // Set dilation + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_DILATIONS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_dilation().data())); + + // Set strides + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_stride().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(convolution_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(convolution_descriptor); + 
+ // Create operation by directly calling cuDNN backend API + Operation_v8 convolution_operation; + + _CUDNN_CHECK_CUDNN_ERROR(convolution_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Conv_fprop_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set weight tensor W + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(W, Conv_fprop_attributes::input_names::W); + auto w_desc = tensors.at(W->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &w_desc)); + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Conv_fprop_attributes::output_names::Y); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + // Set convolution descriptor + auto conv_desc_ptr = convolution_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &conv_desc_ptr)); + + // Set alpha and beta + float alpha = 1.0f; + float beta = 0.0f; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA, + CUDNN_TYPE_FLOAT, + 1, + &alpha)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(convolution_operation.get_raw_desc(), + 
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA, + CUDNN_TYPE_FLOAT, + 1, + &beta)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(convolution_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(convolution_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"({"tag": "CONV_FPROP"})"_json); + } +#endif +}; + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/conv_wgrad.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/conv_wgrad.h new file mode 100644 index 00000000..2f9b478c --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/conv_wgrad.h @@ -0,0 +1,201 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class WgradNode : public NodeCRTP { + public: + Conv_wgrad_attributes attributes; + + WgradNode(Conv_wgrad_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::WGRAD; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating Node Type::WGRAD " << attributes.name); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_pre_padding().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Pre padding not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_post_padding().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Post padding not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.get_stride().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Conv strides not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + 
attributes.get_dilation().empty(), error_code_t::ATTRIBUTE_NOT_SET, "Conv dilation not set."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for conv node " << attributes.name); + + attributes.fill_from_context(context); + + // TODO: Only inferrencing from (X, DY) -> DW works today. + auto X = attributes.inputs[Conv_wgrad_attributes::input_names::X]; + auto DW = attributes.outputs[Conv_wgrad_attributes::output_names::DW]; + auto DY = attributes.inputs[Conv_wgrad_attributes::input_names::DY]; + + auto const x_tensor_dim = X->get_dim(); + auto const dy_tensor_dim = DY->get_dim(); + auto dw_tensor_dim = DW->get_dim(); + + // No dim inferencing as inverse mapping from DY, X to DX is not unique. + // Only infer strides if user did not set them + if (DW->get_stride().empty()) { + auto const& DW_dim = DW->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DW_dim.size()); + DW->set_stride(detail::generate_stride(DW_dim, stride_order)); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: Building WgradNode operations " << attributes.name << " "); + + // Create wgrad descriptor by directly calling cuDNN backend API + ConvDesc_v8 wgrad_descriptor; + int64_t const spatial_dim_count = attributes.get_pre_padding().size(); + + _CUDNN_CHECK_CUDNN_ERROR( + wgrad_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR)); + + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + cudnnConvolutionMode_t mode = detail::convert_to_cudnn_type(attributes.math_mode); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + wgrad_descriptor.get_raw_desc(), CUDNN_ATTR_CONVOLUTION_CONV_MODE, CUDNN_TYPE_CONVOLUTION_MODE, 1, &mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, + CUDNN_TYPE_INT64, + 1, + &spatial_dim_count)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_pre_padding().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_post_padding().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_DILATIONS, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_dilation().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_descriptor.get_raw_desc(), + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, + CUDNN_TYPE_INT64, + spatial_dim_count, + attributes.get_stride().data())); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(wgrad_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(wgrad_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 wgrad_operation; + + _CUDNN_CHECK_CUDNN_ERROR(wgrad_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Conv_wgrad_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Conv_wgrad_attributes::input_names::DY); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DW, Conv_wgrad_attributes::output_names::DW); + auto dw_desc = tensors.at(DW->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dw_desc)); + + auto conv_desc_ptr = wgrad_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &conv_desc_ptr)); + + float alpha = 1.0f; + float beta = 0.0f; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA, + CUDNN_TYPE_FLOAT, + 1, + &alpha)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(wgrad_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA, + CUDNN_TYPE_FLOAT, + 1, + &beta)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(wgrad_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(wgrad_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( 
{"tag": "CONV_WGRAD"})"_json); + } +#endif +}; + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/dbn.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/dbn.h new file mode 100644 index 00000000..9f2a23e7 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/dbn.h @@ -0,0 +1,206 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { + +class DBNNode : public NodeCRTP { + public: + Batchnorm_backward_attributes attributes; + + DBNNode(Batchnorm_backward_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::DBN; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for DBN node " << attributes.name); + + attributes.fill_from_context(context); + + // TODO: Only inferencing from X works today. 
+ auto X = attributes.inputs[Batchnorm_backward_attributes::input_names::X]; + auto const x_tensor_dim = X->get_dim(); + + auto DX = attributes.outputs[Batchnorm_backward_attributes::output_names::DX]; + auto dx_tensor_dim = DX->get_dim(); + // Only infer dims and strides if user did not set them + if (dx_tensor_dim.empty()) { + dx_tensor_dim.resize(x_tensor_dim.size()); + DX->set_dim(x_tensor_dim); + } + if (DX->get_stride().empty()) { + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); + } + + // Set channel length tensors + auto infer_per_channel_tensors = [&x_tensor_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + tensor_dim.resize(x_tensor_dim.size(), 1); + tensor_dim[1] = x_tensor_dim[1]; + T->set_dim(tensor_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + infer_per_channel_tensors(attributes.outputs[Batchnorm_backward_attributes::output_names::DSCALE]); + infer_per_channel_tensors(attributes.outputs[Batchnorm_backward_attributes::output_names::DBIAS]); + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building DBNNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 dbn_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + 
dbn_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)); + + // Set norm mode to BATCH_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::BATCH_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Batchnorm_backward_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set DY tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Batchnorm_backward_attributes::input_names::DY); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + // Set scale tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Batchnorm_backward_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set mean and inv_variance tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, Batchnorm_backward_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, 
+ Batchnorm_backward_attributes::input_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set DSCALE and DBIAS output tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, Batchnorm_backward_attributes::output_names::DSCALE); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, Batchnorm_backward_attributes::output_names::DBIAS); + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); + + // Set DX output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Batchnorm_backward_attributes::output_names::DX); + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + + // Set peer stat tensors if any + if (!attributes.peer_stats.empty()) { + std::vector peer_stat_descs; + for (auto const& peer_stat : attributes.peer_stats) { + peer_stat_descs.push_back(tensors.at(peer_stat->get_uid())->get_raw_desc()); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dbn_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_PEER_STAT_DESCS, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + peer_stat_descs.size(), + peer_stat_descs.data())); + } + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::finalize(dbn_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(dbn_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "DBN"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/dbn_weight.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/dbn_weight.h new file mode 100644 index 00000000..bfad0dfb --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/dbn_weight.h @@ -0,0 +1,215 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { + +class DBNWeightNode : public NodeCRTP { + public: + DBN_weight_attributes attributes; + + DBNWeightNode(DBN_weight_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::DBN_WEIGHT; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for batchnorm finalize node " << attributes.name); + + attributes.fill_from_context(context); + + // TODO: Only inferencing from DY works today. 
+ auto DY = attributes.inputs[DBN_weight_attributes::input_names::DY]; + auto const dy_tensor_dim = DY->get_dim(); + + auto X = attributes.inputs[DBN_weight_attributes::input_names::X]; + auto x_tensor_dim = X->get_dim(); + // Only infer dims and strides if user did not set them + if (x_tensor_dim.empty()) { + x_tensor_dim.resize(dy_tensor_dim.size()); + X->set_dim(dy_tensor_dim); + } + if (X->get_stride().empty()) { + auto const& X_dim = X->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(X_dim.size()); + X->set_stride(detail::generate_stride(X_dim, stride_order)); + } + + // Set channel length tensors + auto infer_per_channel_tensors = [&dy_tensor_dim](std::shared_ptr const& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (T->get_dim().empty()) { + tensor_dim.resize(dy_tensor_dim.size(), 1); + tensor_dim[1] = dy_tensor_dim[1]; + T->set_dim(tensor_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + infer_per_channel_tensors(attributes.outputs[DBN_weight_attributes::output_names::DBIAS]); + infer_per_channel_tensors(attributes.outputs[DBN_weight_attributes::output_names::DSCALE]); + infer_per_channel_tensors(attributes.outputs[DBN_weight_attributes::output_names::EQ_BIAS]); + infer_per_channel_tensors(attributes.outputs[DBN_weight_attributes::output_names::EQ_SCALE_DY]); + infer_per_channel_tensors(attributes.outputs[DBN_weight_attributes::output_names::EQ_SCALE_X]); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + 
CUDNN_FE_LOG_LABEL("INFO: " << "Building DBNWeightNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 bn_bwd_weight_operation; + + _CUDNN_CHECK_CUDNN_ERROR(bn_bwd_weight_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR)); + + // Set compute type (math precision) + cudnnDataType_t compute_type = CUDNN_DATA_FLOAT; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC, + CUDNN_TYPE_DATA_TYPE, + 1, + &compute_type)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, DBN_weight_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set DY tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, DBN_weight_attributes::input_names::DY); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + // Set mean tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, DBN_weight_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + // Set inv_variance tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, DBN_weight_attributes::input_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set scale tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, DBN_weight_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set DSCALE output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, DBN_weight_attributes::output_names::DSCALE); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + + // Set DBIAS output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, DBN_weight_attributes::output_names::DBIAS); + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); + + // Set EQ_SCALE_DY output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_SCALE_DY, DBN_weight_attributes::output_names::EQ_SCALE_DY); + auto eq_scale_dy_desc = tensors.at(EQ_SCALE_DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_scale_dy_desc)); + + // Set EQ_SCALE_X output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_SCALE_X, DBN_weight_attributes::output_names::EQ_SCALE_X); + auto 
eq_scale_x_desc = tensors.at(EQ_SCALE_X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_scale_x_desc)); + + // Set EQ_BIAS output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(EQ_BIAS, DBN_weight_attributes::output_names::EQ_BIAS); + auto eq_bias_desc = tensors.at(EQ_BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(bn_bwd_weight_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &eq_bias_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(bn_bwd_weight_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(bn_bwd_weight_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "DBN_WEIGHT"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/dln.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/dln.h new file mode 100644 index 00000000..4dd508a5 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/dln.h @@ -0,0 +1,227 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { + +class DLNNode : public NodeCRTP { + public: + Layernorm_backward_attributes attributes; + + DLNNode(Layernorm_backward_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override 
final { + return Type::DLN; + } + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for DLN node " << attributes.name); + + // WAR as epsilon was required in previous versions + if (detail::get_backend_version() < 8906) { + attributes.inputs[Layernorm_backward_attributes::input_names::EPSILON] = + std::make_shared(0.0f); + } + + attributes.fill_from_context(context); + + // TODO: Only inferencing from X works today. + auto X = attributes.inputs[Layernorm_backward_attributes::input_names::X]; + auto const x_tensor_dim = X->get_dim(); + + auto DY = attributes.inputs[Layernorm_backward_attributes::input_names::DY]; + auto dy_tensor_dim = DY->get_dim(); + + // Only infer dims and strides if user did not set them + if (dy_tensor_dim.empty()) { + dy_tensor_dim.resize(x_tensor_dim.size()); + DY->set_dim(x_tensor_dim); + } + if (DY->get_stride().empty()) { + auto const& DY_dim = DY->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DY_dim.size()); + DY->set_stride(detail::generate_stride(DY_dim, stride_order)); + } + + auto DX = attributes.outputs[Layernorm_backward_attributes::output_names::DX]; + auto dx_tensor_dim = DX->get_dim(); + // Only infer dims and strides if user did not set them + if (dx_tensor_dim.empty()) { + dx_tensor_dim.resize(x_tensor_dim.size()); + DX->set_dim(x_tensor_dim); + } + if (DX->get_stride().empty()) { + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); + } + + auto scale_bias_dim = X->get_dim(); + scale_bias_dim[0] = 1; + + // Set channel length tensors + auto infer_scale_bias_tensors = [&scale_bias_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + T->set_dim(scale_bias_dim); + } + if 
(T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + + infer_scale_bias_tensors(attributes.outputs[Layernorm_backward_attributes::output_names::DSCALE]); + infer_scale_bias_tensors(attributes.outputs[Layernorm_backward_attributes::output_names::DBIAS]); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building DLNNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 dln_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + dln_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)); + + // Set norm mode to LAYER_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::LAYER_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Layernorm_backward_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set DY tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Layernorm_backward_attributes::input_names::DY); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + // Set scale tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Layernorm_backward_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set mean and inv_variance tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, Layernorm_backward_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, + Layernorm_backward_attributes::input_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set DSCALE and DBIAS output tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, Layernorm_backward_attributes::output_names::DSCALE); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, Layernorm_backward_attributes::output_names::DBIAS); + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); + + // Set DX output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Layernorm_backward_attributes::output_names::DX); + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + + // Set epsilon tensor for older backend versions + if (detail::get_backend_version() < 8906) { + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Layernorm_backward_attributes::input_names::EPSILON); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(dln_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(dln_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(dln_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "LAYER_NORM_BPROP"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/genstats.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/genstats.h new file mode 100644 index 00000000..8f918975 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/genstats.h @@ -0,0 +1,147 @@ +#pragma once + +#include "../graph_helpers.h" +#include 
"../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { + +class GenstatsNode : public NodeCRTP { + public: + Genstats_attributes attributes; + + GenstatsNode(Genstats_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::GENSTATS; + } + + error_t + infer_properties_node() override final { + attributes.fill_from_context(context); + + // Only inferrencing from X works today. + auto X = attributes.inputs[Genstats_attributes::input_names::X]; + auto SUM = attributes.outputs[Genstats_attributes::output_names::SUM]; + auto SQ_SUM = attributes.outputs[Genstats_attributes::output_names::SQ_SUM]; + + auto const x_tensor_dim = X->get_dim(); + auto sum_tensor_dim = SUM->get_dim(); + auto sq_sum_tensor_dim = SQ_SUM->get_dim(); + + // Only infer dims and strides if user did not set them + if (sum_tensor_dim.empty()) { + sum_tensor_dim.resize(x_tensor_dim.size(), 1); + sum_tensor_dim[1] = x_tensor_dim[1]; + SUM->set_dim(sum_tensor_dim); + } + if (SUM->get_stride().empty()) { + auto const& SUM_dim = SUM->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(SUM_dim.size()); + SUM->set_stride(detail::generate_stride(SUM_dim, stride_order)); + } + + // Only infer dims and strides if user did not set them + if (sq_sum_tensor_dim.empty()) { + sq_sum_tensor_dim.resize(x_tensor_dim.size(), 1); + sq_sum_tensor_dim[1] = x_tensor_dim[1]; + SQ_SUM->set_dim(sq_sum_tensor_dim); + } + if (SQ_SUM->get_stride().empty()) { + auto const& SQ_SUM_dim = SQ_SUM->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(SQ_SUM_dim.size()); + SQ_SUM->set_stride(detail::generate_stride(SQ_SUM_dim, stride_order)); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + 
managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building GenstatsNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 genstats_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + genstats_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Genstats_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set gen stats mode + cudnnGenStatsMode_t genstats_mode = CUDNN_GENSTATS_SUM_SQSUM; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_MODE, + CUDNN_TYPE_GENSTATS_MODE, + 1, + &genstats_mode)); + + // Set math precision based on X tensor data type + cudnnDataType_t math_prec = static_cast(tensors.at(X->second->get_uid())->getDataType()); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC, + CUDNN_TYPE_DATA_TYPE, + 1, + &math_prec)); + + // Set SUM output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(SUM, Genstats_attributes::output_names::SUM); + auto sum_desc = tensors.at(SUM->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &sum_desc)); + + // Set SQ_SUM output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(SQ_SUM, Genstats_attributes::output_names::SQ_SUM); + auto sq_sum_desc = tensors.at(SQ_SUM->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(genstats_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &sq_sum_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(genstats_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(genstats_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "GENSTATS"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/instancenorm.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/instancenorm.h new file mode 100644 index 00000000..1b71f4ab --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/instancenorm.h @@ -0,0 +1,414 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { +class InstanceNormNode : public NodeCRTP { + public: + Instancenorm_attributes attributes; + + InstanceNormNode(Instancenorm_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::INSTANCENORM; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for instancenorm node " << attributes.name); + + attributes.fill_from_context(context); + + auto X = attributes.inputs[Instancenorm_attributes::input_names::X]; + auto Y = attributes.outputs[Instancenorm_attributes::output_names::Y]; + + // Only infer dims and strides if user did not set them + if (Y->get_dim().empty()) { + 
Y->set_dim(X->get_dim()); + } + if (Y->get_stride().empty()) { + Y->set_stride(X->get_stride()); + } + + // mean inv_var dim is n,c,1,1 + auto stats_dim = X->get_dim(); + for (size_t i = 2; i < stats_dim.size(); i++) { + stats_dim[i] = 1; + } + + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + auto mean = attributes.outputs[Instancenorm_attributes::output_names::MEAN]; + if (mean->get_dim().empty()) { + mean->set_dim(stats_dim); + } + if (mean->get_stride().empty()) { + auto const& mean_dim = mean->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(mean_dim.size()); + mean->set_stride(detail::generate_stride(mean_dim, stride_order)); + } + + auto inv_var = attributes.outputs[Instancenorm_attributes::output_names::INV_VARIANCE]; + if (inv_var->get_dim().empty()) { + inv_var->set_dim(stats_dim); + } + if (inv_var->get_stride().empty()) { + auto const& inv_var_dim = inv_var->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(inv_var_dim.size()); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); + } + } + return {error_code_t::OK, ""}; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating InstanceNormNode " << attributes.name); + + // Norm forward phase should be set + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.forward_phase == NormFwdPhase_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Forward phase not set of instancenorm node."); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: Building InstanceNormNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + 
Operation_v8 instancenorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + instancenorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode to INSTANCE_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::INSTANCE_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set forward phase + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.forward_phase, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Instancenorm_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set scale and bias tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Instancenorm_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, Instancenorm_attributes::input_names::BIAS); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + 
CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + + // Set epsilon tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Instancenorm_attributes::input_names::EPSILON); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Instancenorm_attributes::output_names::Y); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + // Set mean and inv_variance for training phase + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(MEAN, Instancenorm_attributes::output_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, + Instancenorm_attributes::output_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(instancenorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(instancenorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(instancenorm_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), 
non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "INSTANCE_NORM"})"_json); + } +#endif +}; + +class DINNode : public NodeCRTP { + public: + Instancenorm_backward_attributes attributes; + + DINNode(Instancenorm_backward_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::DIN; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for DIN node " << attributes.name); + + attributes.fill_from_context(context); + + // TODO: Only inferencing from X works today. + auto X = attributes.inputs[Instancenorm_backward_attributes::input_names::X]; + auto const x_tensor_dim = X->get_dim(); + + auto DY = attributes.inputs[Instancenorm_backward_attributes::input_names::DY]; + auto dy_tensor_dim = DY->get_dim(); + + // Only infer dims and strides if user did not set them + if (dy_tensor_dim.empty()) { + dy_tensor_dim.resize(x_tensor_dim.size()); + DY->set_dim(x_tensor_dim); + } + if (DY->get_stride().empty()) { + auto const& DY_dim = DY->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DY_dim.size()); + DY->set_stride(detail::generate_stride(DY_dim, stride_order)); + } + + auto DX = attributes.outputs[Instancenorm_backward_attributes::output_names::DX]; + auto dx_tensor_dim = DX->get_dim(); + // Only infer dims and strides if user did not set them + if (dx_tensor_dim.empty()) { + dx_tensor_dim.resize(x_tensor_dim.size()); + DX->set_dim(x_tensor_dim); + } + if (DX->get_stride().empty()) { + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); + } + + // scale_bias 
dim is 1,c,1,1 + // mean inv_var dim is n,c,1,1 + auto scale_bias_dim = X->get_dim(); + for (size_t i = 0; i < scale_bias_dim.size(); i++) { + if (i != 1) { + scale_bias_dim[i] = 1; + } + } + + // Set channel length tensors + auto infer_scale_bias_tensors = [&scale_bias_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + T->set_dim(scale_bias_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + + infer_scale_bias_tensors(attributes.outputs[Instancenorm_backward_attributes::output_names::DSCALE]); + infer_scale_bias_tensors(attributes.outputs[Instancenorm_backward_attributes::output_names::DBIAS]); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: Building DINNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 din_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + din_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)); + + // Set norm mode to INSTANCE_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::INSTANCE_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Instancenorm_backward_attributes::input_names::X); + auto 
x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set DY tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Instancenorm_backward_attributes::input_names::DY); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + // Set scale tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Instancenorm_backward_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set mean and inv_variance tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(MEAN, Instancenorm_backward_attributes::input_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, + Instancenorm_backward_attributes::input_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set DSCALE and DBIAS output tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, Instancenorm_backward_attributes::output_names::DSCALE); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, Instancenorm_backward_attributes::output_names::DBIAS); + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); + + // Set DX output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Instancenorm_backward_attributes::output_names::DX); + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(din_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(din_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(din_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "INSTANCE_NORM_BPROP"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/layernorm.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/layernorm.h new file mode 100644 index 00000000..46420aa2 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/layernorm.h @@ -0,0 +1,259 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { +class LayerNormNode : public NodeCRTP { 
+ public: + Layernorm_attributes attributes; + + LayerNormNode(Layernorm_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::LAYERNORM; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for layernorm node " << attributes.name); + + attributes.fill_from_context(context); + + auto X = attributes.inputs[Layernorm_attributes::input_names::X]; + auto Y = attributes.outputs[Layernorm_attributes::output_names::Y]; + + // Only infer dims and strides if user did not set them + if (Y->get_dim().empty()) { + Y->set_dim(X->get_dim()); + } + if (Y->get_stride().empty()) { + Y->set_stride(X->get_stride()); + } + + // scale_bias dim is 1,c,h,w + auto scale_bias_dim = X->get_dim(); + scale_bias_dim[0] = 1; + + auto scale = attributes.inputs[Layernorm_attributes::input_names::SCALE]; + // Only infer dims and strides if user did not set them + if (scale->get_dim().empty()) { + scale->set_dim(scale_bias_dim); + } + if (scale->get_stride().empty()) { + auto const& scale_dim = scale->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), scale_dim.size(), stride_order)); + scale->set_stride(detail::generate_stride(scale_dim, stride_order)); + } + + auto bias = attributes.inputs[Layernorm_attributes::input_names::BIAS]; + // Only infer dims and strides if user did not set them + if (bias->get_dim().empty()) { + bias->set_dim(scale_bias_dim); + } + if (bias->get_stride().empty()) { + auto const& bias_dim = bias->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), bias_dim.size(), stride_order)); + bias->set_stride(detail::generate_stride(bias_dim, stride_order)); + } + + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + // stats dim 
is x where scale == 1 else 1 + auto stats_dim = X->get_dim(); + for (size_t i = 0; i < stats_dim.size(); i++) { + if (scale->get_dim()[i] != 1) { + stats_dim[i] = 1; + } + } + + auto mean = attributes.outputs[Layernorm_attributes::output_names::MEAN]; + // Only infer dims and strides if user did not set them + if (mean->get_dim().empty()) { + mean->set_dim(stats_dim); + } + if (mean->get_stride().empty()) { + auto const& mean_dim = mean->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), mean_dim.size(), stride_order)); + mean->set_stride(detail::generate_stride(mean_dim, stride_order)); + } + + auto inv_var = attributes.outputs[Layernorm_attributes::output_names::INV_VARIANCE]; + // Only infer dims and strides if user did not set them + if (inv_var->get_dim().empty()) { + inv_var->set_dim(stats_dim); + } + if (inv_var->get_stride().empty()) { + auto const& inv_var_dim = inv_var->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), inv_var_dim.size(), stride_order)); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); + } + } + + // Set scalar tensors + std::vector ones(X->get_dim().size(), 1); + auto infer_scalar_tensors = [&ones](std::shared_ptr& T) { + // Only infer dims and strides if user did not set them + if (T->get_dim().empty()) { + T->set_dim(ones); + } + if (T->get_stride().empty()) { + T->set_stride(ones); + } + }; + infer_scalar_tensors(attributes.inputs[Layernorm_attributes::input_names::EPSILON]); + + return {error_code_t::OK, ""}; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: " << "Validating LayerNormNode " << attributes.name); + + // Norm forward phase should be set + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.forward_phase == NormFwdPhase_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Forward phase not set of 
layernorm node."); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building LayerNormNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 layernorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + layernorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::LAYER_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set forward phase + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.forward_phase, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Layernorm_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set scale and bias tensors + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Layernorm_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(BIAS, Layernorm_attributes::input_names::BIAS); + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + + // Set epsilon tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Layernorm_attributes::input_names::EPSILON); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Layernorm_attributes::output_names::Y); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + // Set mean and inv_variance for training phase + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(MEAN, Layernorm_attributes::output_names::MEAN); + auto mean_desc = tensors.at(MEAN->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MEAN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &mean_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, Layernorm_attributes::output_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(layernorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(layernorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(layernorm_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "LAYER_NORM"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/matmul.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/matmul.h new file mode 100644 index 00000000..f09d3415 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/matmul.h @@ -0,0 +1,253 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class MatmulNode : public NodeCRTP { + public: + Matmul_attributes attributes; + + MatmulNode(Matmul_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::MATMUL; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for matmul node " << attributes.name); + + attributes.fill_from_context(context); + + // Only inferrencing from (A, B) -> C works today. 
+ auto a_tensor = attributes.inputs[Matmul_attributes::input_names::A]; + auto b_tensor = attributes.inputs[Matmul_attributes::input_names::B]; + auto c_tensor = attributes.outputs[Matmul_attributes::output_names::C]; + + auto const a_tensor_dim = a_tensor->get_dim(); + auto const b_tensor_dim = b_tensor->get_dim(); + auto c_tensor_dim = c_tensor->get_dim(); + + // Only infer dims and strides if user did not set them + if (c_tensor_dim.empty()) { + // CHECK_CUDNN_FRONTEND_ERROR(detail::generate_matmul_output_dim(a_tensor_dim, b_tensor_dim, c_tensor_dim)); + + c_tensor_dim.resize(a_tensor_dim.size()); + int64_t gemm_start_dim = a_tensor_dim.size() - 2; + c_tensor_dim[gemm_start_dim] = a_tensor_dim[gemm_start_dim]; // M + c_tensor_dim[gemm_start_dim + 1] = b_tensor_dim[gemm_start_dim + 1]; // N + + // Broadcast the batches + for (int64_t i = 0; i < gemm_start_dim; ++i) { + c_tensor_dim[i] = std::max(a_tensor_dim[i], b_tensor_dim[i]); + } + + c_tensor->set_dim(c_tensor_dim); + } + if (c_tensor->get_stride().empty()) { + auto const& c_dim = c_tensor->get_dim(); + // Default to Col major + auto const& stride_order = detail::generate_row_major_stride_order(c_dim.size()); + c_tensor->set_stride(detail::generate_stride(c_dim, stride_order)); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building MatmulNode operations " << attributes.name << " "); + + // Create matmul descriptor by directly calling cuDNN backend API + MatMulDesc_v8 matmul_descriptor; + + _CUDNN_CHECK_CUDNN_ERROR(matmul_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_MATMUL_DESCRIPTOR)); + + // Set compute type + cudnnDataType_t cudnn_data_type; + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + matmul_descriptor.get_raw_desc(), CUDNN_ATTR_MATMUL_COMP_TYPE, CUDNN_TYPE_DATA_TYPE, 1, &cudnn_data_type)); + + // Set padding value if specified +#if (CUDNN_VERSION >= 8900) + if (attributes.padding_value != 0.0) { + double padding_value = attributes.padding_value; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_descriptor.get_raw_desc(), + CUDNN_ATTR_MATMUL_PADDING_VALUE, + CUDNN_TYPE_DOUBLE, + 1, + &padding_value)); + } +#endif + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(matmul_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(matmul_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 matmul_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + matmul_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR)); + + // Set input tensor A + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(A, Matmul_attributes::input_names::A); + auto a_desc = tensors.at(A->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_ADESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &a_desc)); + + // Set input tensor B + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(B, Matmul_attributes::input_names::B); + auto b_desc = tensors.at(B->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_BDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &b_desc)); + + // Set output tensor C + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(C, Matmul_attributes::output_names::C); + auto c_desc = tensors.at(C->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_CDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &c_desc)); 
+ + // Set matmul descriptor + auto matmul_desc_ptr = matmul_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &matmul_desc_ptr)); + + // Set optional override tensors + auto M_override = attributes.inputs.find(Matmul_attributes::input_names::M_override); + if ((M_override != attributes.inputs.end()) && (M_override->second != nullptr)) { + auto m_override_desc = tensors.at(M_override->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_GEMM_M_OVERRIDE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &m_override_desc)); + } + + auto N_override = attributes.inputs.find(Matmul_attributes::input_names::N_override); + if ((N_override != attributes.inputs.end()) && (N_override->second != nullptr)) { + auto n_override_desc = tensors.at(N_override->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_GEMM_N_OVERRIDE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &n_override_desc)); + } + + auto K_override = attributes.inputs.find(Matmul_attributes::input_names::K_override); + if ((K_override != attributes.inputs.end()) && (K_override->second != nullptr)) { + auto k_override_desc = tensors.at(K_override->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(matmul_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_MATMUL_GEMM_K_OVERRIDE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &k_override_desc)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(matmul_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(matmul_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return 
{error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "MATMUL"})"_json); + } +#endif +}; + +inline void +INode::matmul(std::shared_ptr a, + std::shared_ptr b, + Matmul_attributes attributes, + std::shared_ptr c) { + attributes.inputs[Matmul_attributes::input_names::A] = a; + attributes.inputs[Matmul_attributes::input_names::B] = b; + attributes.outputs[Matmul_attributes::output_names::C] = c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +inline std::shared_ptr +INode::matmul(std::shared_ptr a, + std::shared_ptr b, + Matmul_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } + attributes.inputs[Matmul_attributes::input_names::A] = a; + attributes.inputs[Matmul_attributes::input_names::B] = b; + + if (a->get_name().empty()) { + a->set_name(attributes.name + "::A"); + }; + if (b->get_name().empty()) { + b->set_name(attributes.name + "::B"); + }; + + auto m_override = attributes.inputs.find(Matmul_attributes::input_names::M_override); + auto n_override = attributes.inputs.find(Matmul_attributes::input_names::N_override); + auto k_override = attributes.inputs.find(Matmul_attributes::input_names::K_override); + + if (m_override != attributes.inputs.end()) { + auto tensor = m_override->second; + if (tensor && tensor->get_name().empty()) { + tensor->set_name(attributes.name + "::M_override"); + } + } + if (n_override != attributes.inputs.end()) { + auto tensor = n_override->second; + if (tensor && tensor->get_name().empty()) { + tensor->set_name(attributes.name + "::N_override"); + } + } + if (k_override != attributes.inputs.end()) { + auto tensor = k_override->second; + if (tensor && tensor->get_name().empty()) { + tensor->set_name(attributes.name + "::K_override"); + } + } + + auto C = attributes.outputs[Matmul_attributes::output_names::C] = 
output_tensor(attributes.name + "::C"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return C; +} + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/matmul_fp8.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/matmul_fp8.h new file mode 100644 index 00000000..d3fe5e58 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/matmul_fp8.h @@ -0,0 +1,104 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class MatmulFP8Node : public NodeCRTP { + public: + Matmul_fp8_attributes attributes; + + MatmulFP8Node(Matmul_fp8_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::MATMUL; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for matmul fp8 node " << attributes.name); + + attributes.fill_from_context(context); + + auto const& a_dim = attributes.inputs.at(Matmul_fp8_attributes::input_names::A)->get_dim(); + auto const& b_dim = attributes.inputs.at(Matmul_fp8_attributes::input_names::B)->get_dim(); + auto const& c_dim = attributes.outputs.at(Matmul_fp8_attributes::output_names::C)->get_dim(); + + std::shared_ptr last_output; + + // Matmul + + auto matmul_attributes = Matmul_attributes(); + matmul_attributes.clone_fp8_attributes(attributes); + matmul_attributes.set_name("matmul"); + + last_output = matmul(attributes.inputs.at(Matmul_fp8_attributes::input_names::A), + attributes.inputs.at(Matmul_fp8_attributes::input_names::B), + matmul_attributes); + + // Reduction if GQA for head dimension + if (a_dim.size() == 4 && b_dim.size() == 4 && c_dim.size() == 4 && a_dim[1] == b_dim[1] && + a_dim[1] != c_dim[1] && (a_dim[1] % c_dim[1] == 0)) { + auto gqa_attributes = 
Reduction_attributes().set_name("gqa_c").set_mode(ReductionMode_t::ADD); + last_output = reduction(last_output, gqa_attributes); + last_output->set_dim(c_dim); + } + + //// Scale Descales + auto mul_attributes = Pointwise_attributes().set_mode(PointwiseMode_t::MUL); + // Descale A + mul_attributes.set_name("descale_a"); + last_output = + pointwise(last_output, attributes.inputs.at(Matmul_fp8_attributes::input_names::Descale_A), mul_attributes); + + // Descale B + mul_attributes.set_name("descale_b"); + last_output = + pointwise(last_output, attributes.inputs.at(Matmul_fp8_attributes::input_names::Descale_B), mul_attributes); + + // Scale C + mul_attributes.set_name("scale_c"); + // Special non-functional-style call. Needed because output already created and provided to user. + pointwise(last_output, + attributes.inputs.at(Matmul_fp8_attributes::input_names::Scale_C), + mul_attributes, + attributes.outputs.at(Matmul_fp8_attributes::output_names::C)); + + // Amax C + auto amax_attributes = Reduction_attributes().set_name("amax_c").set_mode(ReductionMode_t::AMAX); + // Special non-functional-style call. Needed because output already created and provided to user. 
+ reduction(last_output, amax_attributes, attributes.outputs.at(Matmul_fp8_attributes::output_names::Amax_C)); + + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "MATMUL_FP8"})"_json); + } +#endif +}; +inline void +INode::matmul_fp8(std::shared_ptr a, + std::shared_ptr b, + std::shared_ptr descale_a, + std::shared_ptr descale_b, + std::shared_ptr scale_c, + Matmul_fp8_attributes attributes, + std::shared_ptr c, + std::shared_ptr amax_c) { + attributes.inputs[Matmul_fp8_attributes::input_names::A] = a; + attributes.inputs[Matmul_fp8_attributes::input_names::B] = b; + attributes.inputs[Matmul_fp8_attributes::input_names::Descale_A] = descale_a; + attributes.inputs[Matmul_fp8_attributes::input_names::Descale_B] = descale_b; + attributes.inputs[Matmul_fp8_attributes::input_names::Scale_C] = scale_c; + attributes.outputs[Matmul_fp8_attributes::output_names::C] = c; + attributes.outputs[Matmul_fp8_attributes::output_names::Amax_C] = amax_c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/moe_grouped_matmul.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/moe_grouped_matmul.h new file mode 100644 index 00000000..e12acde0 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/moe_grouped_matmul.h @@ -0,0 +1,192 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class MoeGroupedMatmulNode : public NodeCRTP { + public: + Moe_grouped_matmul_attributes attributes; + + MoeGroupedMatmulNode(Moe_grouped_matmul_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return 
Type::MOE_GROUPED_MATMUL; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for moe grouped matmul node " << attributes.name); + + attributes.fill_from_context(context); + + auto token_tensor = attributes.inputs[Moe_grouped_matmul_attributes::input_names::Token]; + auto weight_tensor = attributes.inputs[Moe_grouped_matmul_attributes::input_names::Weight]; + auto token_index_tensor = attributes.inputs[Moe_grouped_matmul_attributes::input_names::TokenIndex]; + auto output_tensor = attributes.outputs[Moe_grouped_matmul_attributes::output_names::Output]; + + auto const token_tensor_dim = token_tensor->get_dim(); + auto const weight_tensor_dim = weight_tensor->get_dim(); + auto output_tensor_dim = output_tensor->get_dim(); + + if (output_tensor_dim.empty()) { + output_tensor_dim.resize(3); + output_tensor_dim[0] = 1; + output_tensor_dim[2] = weight_tensor_dim[2]; + if (attributes.mode == MoeGroupedMatmulMode_t::GATHER) { + output_tensor_dim[1] = token_index_tensor->get_dim()[1]; + } else { + output_tensor_dim[1] = token_tensor_dim[1]; + } + output_tensor_dim.resize(3); + + output_tensor->set_dim(output_tensor_dim); + } + + if (output_tensor->get_stride().empty()) { + auto const& output_dim = output_tensor->get_dim(); + auto const& stride_order = detail::generate_row_major_stride_order(output_dim.size()); + output_tensor->set_stride(detail::generate_stride(output_dim, stride_order)); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building MoeGroupedMatmulNode operations " << attributes.name << std::endl; + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Moe grouped matmul requires cuDNN v9.15.0"}; + +#if (CUDNN_VERSION >= 
91500) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(91500, cudnn_ver_error); + CUDNN_FRONTEND_UNUSED(operations); + + auto moe_grouped_matmul_operation = + make_shared_backend_pointer(CUDNN_BACKEND_OPERATION_MOE_GROUPED_MATMUL_DESCRIPTOR); + + cudnnMoeGroupedMatmulMode_t moe_grouped_matmul_mode; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.mode, moe_grouped_matmul_mode)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(moe_grouped_matmul_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_MOE_GROUPED_MATMUL_MODE, + CUDNN_TYPE_MOE_GROUPED_MATMUL_MODE, + 1, + &moe_grouped_matmul_mode)); + + cudnnDataType_t cudnn_data_type; + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(moe_grouped_matmul_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_MOE_GROUPED_MATMUL_MATH_PREC, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + auto token = attributes.inputs.find(Moe_grouped_matmul_attributes::input_names::Token)->second; + auto backend_token = tensors[token->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(moe_grouped_matmul_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_MOE_GROUPED_MATMUL_TOKEN_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_token)); + + auto weight = attributes.inputs.find(Moe_grouped_matmul_attributes::input_names::Weight)->second; + auto backend_weight = tensors[weight->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(moe_grouped_matmul_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_MOE_GROUPED_MATMUL_WEIGHT_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_weight)); + + auto first_token_offset = + attributes.inputs.find(Moe_grouped_matmul_attributes::input_names::FirstTokenOffset)->second; + auto backend_first_token_offset = 
tensors[first_token_offset->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(moe_grouped_matmul_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_MOE_GROUPED_MATMUL_FIRST_TOKEN_OFFSET_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_first_token_offset)); + + auto output = attributes.outputs.find(Moe_grouped_matmul_attributes::output_names::Output)->second; + auto backend_output = tensors[output->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(moe_grouped_matmul_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_MOE_GROUPED_MATMUL_OUTPUT_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_output)); + + if (attributes.mode == MoeGroupedMatmulMode_t::GATHER || attributes.mode == MoeGroupedMatmulMode_t::SCATTER) { + auto token_index = attributes.inputs.find(Moe_grouped_matmul_attributes::input_names::TokenIndex)->second; + auto backend_token_index = tensors[token_index->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(moe_grouped_matmul_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_MOE_GROUPED_MATMUL_TOKEN_INDEX_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_token_index)); + } + + if (attributes.mode == MoeGroupedMatmulMode_t::SCATTER) { + auto token_ks = attributes.inputs.find(Moe_grouped_matmul_attributes::input_names::TokenKs)->second; + auto backend_token_ks = tensors[token_ks->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(moe_grouped_matmul_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_MOE_GROUPED_MATMUL_TOKEN_KS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_token_ks)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(moe_grouped_matmul_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_MOE_GROUPED_MATMUL_TOP_K, + CUDNN_TYPE_INT32, + 1, + &(attributes.top_k))); + } + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::finalize(moe_grouped_matmul_operation->get_backend_descriptor())); + + raw_operations.push_back(moe_grouped_matmul_operation); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FRONTEND_UNUSED(tensors); + return cudnn_ver_error; +#endif + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "MOE_GROUPED_MATMUL"})"_json); + } +#endif +}; + +inline void +INode::moe_grouped_matmul(std::shared_ptr token, + std::shared_ptr weight, + std::shared_ptr first_token_offset, + std::shared_ptr token_index, + std::shared_ptr token_ks, + Moe_grouped_matmul_attributes attributes, + std::shared_ptr output) { + attributes.inputs[Moe_grouped_matmul_attributes::input_names::Token] = token; + attributes.inputs[Moe_grouped_matmul_attributes::input_names::Weight] = weight; + attributes.inputs[Moe_grouped_matmul_attributes::input_names::FirstTokenOffset] = first_token_offset; + attributes.inputs[Moe_grouped_matmul_attributes::input_names::TokenIndex] = token_index; + attributes.inputs[Moe_grouped_matmul_attributes::input_names::TokenKs] = token_ks; + attributes.outputs[Moe_grouped_matmul_attributes::output_names::Output] = output; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/paged_cache_load.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/paged_cache_load.h new file mode 100644 index 00000000..9fc86109 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/paged_cache_load.h @@ -0,0 +1,153 
@@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +#include "pointwise.h" +#include "reduction.h" + +namespace cudnn_frontend::graph { + +class PagedCacheLoadNode : public NodeCRTP { + public: + PagedCacheLoad_attributes attributes; + + PagedCacheLoadNode(PagedCacheLoad_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::PAGED_CACHE_LOAD; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building PagedCacheLoadNode operations " << attributes.name << " "); + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Paged cache load requires cuDNN v9.5.0"}; + +#if (CUDNN_VERSION >= 90500) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90500, cudnn_ver_error); + + // Create operation by directly calling cuDNN backend API + Operation_v8 paged_cache_load_operation; + + _CUDNN_CHECK_CUDNN_ERROR(paged_cache_load_operation.initialize_managed_backend_pointer( + CUDNN_BACKEND_OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR)); + + // Set container tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(container, PagedCacheLoad_attributes::input_names::container); + auto container_desc = tensors.at(container->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(paged_cache_load_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_CONTAINER_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &container_desc)); + + // Set page table tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(pageTable, PagedCacheLoad_attributes::input_names::pageTable); + auto page_table_desc = tensors.at(pageTable->second->get_uid())->get_raw_desc(); + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(paged_cache_load_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_PAGE_TABLE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &page_table_desc)); + + // Set sequence length tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(seqLen, PagedCacheLoad_attributes::input_names::seqLen); + auto seq_len_desc = tensors.at(seqLen->second->get_uid())->get_raw_desc(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(paged_cache_load_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_SEQUENCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &seq_len_desc)); + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(yOut, PagedCacheLoad_attributes::output_names::yOut); + auto y_desc = tensors.at(yOut->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(paged_cache_load_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(paged_cache_load_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(paged_cache_load_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + + return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FRONTEND_UNUSED(tensors); + return cudnn_ver_error; +#endif + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating PagedCacheLoadNode " << attributes.name); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 || detail::get_compiled_version() < 90500, + error_code_t::CUDNN_BACKEND_API_FAILED, + "The cuDNN backend version must be at least 9.5.0 at compile time and runtime " + "in order to use PagedCacheLoadNode."); + + auto const yOut_dims = 
attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_dim(); + auto const yOut_strides = attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_stride(); + auto const container_dims = attributes.inputs.at(PagedCacheLoad_attributes::input_names::container)->get_dim(); + auto const blockTable_dims = attributes.inputs.at(PagedCacheLoad_attributes::input_names::pageTable)->get_dim(); + + // In the backend, the k-cache is passed as K^T and has dims [B,H,D,S], while v-cache has dims [B,H,S,D] + // Use the strides to distinguish. + auto yIsTransposed = yOut_strides[2] == 1; + auto s_kv = !yIsTransposed ? yOut_dims[2] : yOut_dims[3]; + + auto block_size = container_dims[2]; + auto block_table_size = blockTable_dims[2]; + bool is_block_table_packed = + attributes.inputs.at(PagedCacheLoad_attributes::input_names::pageTable)->get_ragged_offset() != nullptr; + + RETURN_CUDNN_FRONTEND_ERROR_IF( + !is_block_table_packed && (s_kv + (block_size - 1)) / block_size != block_table_size, + error_code_t::INVALID_VALUE, + "Paged cache load: block table size must equal ceil(s_kv/block_size), except when using packed block " + "tables"); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + } +#endif +}; + +inline void +INode::paged_cache_load(std::shared_ptr container, + std::shared_ptr seqLen, + std::shared_ptr pageTable, + PagedCacheLoad_attributes attributes, + std::shared_ptr yOut) { + attributes.inputs[PagedCacheLoad_attributes::input_names::container] = std::move(container); + attributes.inputs[PagedCacheLoad_attributes::input_names::seqLen] = std::move(seqLen); + attributes.inputs[PagedCacheLoad_attributes::input_names::pageTable] = std::move(pageTable); + attributes.outputs[PagedCacheLoad_attributes::output_names::yOut] = std::move(yOut); + 
sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/pointwise.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/pointwise.h new file mode 100644 index 00000000..d67ab6a3 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/pointwise.h @@ -0,0 +1,377 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class PointwiseNode : public NodeCRTP { + public: + Pointwise_attributes attributes; + + PointwiseNode(Pointwise_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::POINTWISE; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for pointwise node " << attributes.name); + + attributes.fill_from_context(context); + + auto out_0_tensor = attributes.outputs.at(Pointwise_attributes::output_names::OUT_0); + + auto output_dim = out_0_tensor->get_dim(); + // Only infer dims and strides if user did not set them + if (output_dim.empty()) { + std::vector> input_shapes; + for (const auto& [input_name, input_tensor] : attributes.inputs) { + if (!input_tensor) { + continue; + } + input_shapes.push_back(input_tensor->get_dim()); + } + + CHECK_CUDNN_FRONTEND_ERROR(detail::compute_broadcast_shape(input_shapes, output_dim)); + out_0_tensor->set_dim(output_dim); + } + + if (out_0_tensor->get_stride().empty()) { + for (const auto& [input_name, input_tensor] : attributes.inputs) { + if (input_tensor == nullptr) { + continue; + } + if (input_tensor->get_dim() == out_0_tensor->get_dim()) { + CUDNN_FE_LOG_LABEL_ENDL("INFO:" << " " << out_0_tensor->get_name() + << " stride computed from " << input_tensor->get_name()); + 
out_0_tensor->set_stride(input_tensor->get_stride()); + break; + } + } + if (out_0_tensor->get_stride().empty() && out_0_tensor->get_is_virtual()) { + // If the tensor is virtual the strides are immaterial + auto input_stride = attributes.inputs.at(Pointwise_attributes::input_names::IN_0)->get_stride(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(input_stride, output_dim.size(), stride_order)); + out_0_tensor->set_stride(detail::generate_stride(output_dim, stride_order)); + } + RETURN_CUDNN_FRONTEND_ERROR_IF(out_0_tensor->get_stride().empty(), + error_code_t::SHAPE_DEDUCTION_FAILED, + "Pointwise output strides could not be computed"); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building PointwiseNode operations " << attributes.name << " "); + + // Create pointwise descriptor by directly calling cuDNN backend API + PointWiseDesc_v8 pointwise_descriptor; + + _CUDNN_CHECK_CUDNN_ERROR( + pointwise_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_POINTWISE_DESCRIPTOR)); + + // Set pointwise mode + cudnnPointwiseMode_t cudnn_pointwise_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.mode, cudnn_pointwise_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_MODE, + CUDNN_TYPE_POINTWISE_MODE, + 1, + &cudnn_pointwise_mode)); + + // Set compute type + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_MATH_PREC, + 
CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + // Set mode-specific attributes + if (attributes.mode == PointwiseMode_t::RELU_FWD || attributes.mode == PointwiseMode_t::RELU_BWD) { + cudnnNanPropagation_t nan_propagation = CUDNN_PROPAGATE_NAN; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_NAN_PROPAGATION, + CUDNN_TYPE_NAN_PROPOGATION, + 1, + &nan_propagation)); + + double lower_clip = attributes.relu_lower_clip.value_or(0.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP, + CUDNN_TYPE_DOUBLE, + 1, + &lower_clip)); + + double upper_clip = attributes.relu_upper_clip.value_or(std::numeric_limits::max()); + if (attributes.compute_data_type == DataType_t::FLOAT) { + upper_clip = std::min(upper_clip, std::numeric_limits::max()); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP, + CUDNN_TYPE_DOUBLE, + 1, + &upper_clip)); + + double lower_clip_slope = attributes.relu_lower_clip_slope.value_or(0.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE, + CUDNN_TYPE_DOUBLE, + 1, + &lower_clip_slope)); + } else if (attributes.mode == PointwiseMode_t::ELU_FWD || attributes.mode == PointwiseMode_t::ELU_BWD) { + double elu_alpha = attributes.elu_alpha.value_or(1.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + pointwise_descriptor.get_raw_desc(), CUDNN_ATTR_POINTWISE_ELU_ALPHA, CUDNN_TYPE_DOUBLE, 1, &elu_alpha)); + } else if (attributes.mode == PointwiseMode_t::SOFTPLUS_FWD || + attributes.mode == PointwiseMode_t::SOFTPLUS_BWD) { + double softplus_beta = attributes.softplus_beta.value_or(1.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA, + CUDNN_TYPE_DOUBLE, + 1, + &softplus_beta)); 
+ } else if (attributes.mode == PointwiseMode_t::SWISH_FWD || attributes.mode == PointwiseMode_t::SWISH_BWD) { + double swish_beta = attributes.swish_beta.value_or(1.0); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_descriptor.get_raw_desc(), + CUDNN_ATTR_POINTWISE_SWISH_BETA, + CUDNN_TYPE_DOUBLE, + 1, + &swish_beta)); + } else if (attributes.mode == PointwiseMode_t::GEN_INDEX) { + int64_t axis = attributes.get_axis().value_or(-1); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + pointwise_descriptor.get_raw_desc(), CUDNN_ATTR_POINTWISE_AXIS, CUDNN_TYPE_INT64, 1, &axis)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(pointwise_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(pointwise_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 pointwise_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + pointwise_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)); + + // Set the pointwise descriptor + auto pw_desc_ptr = pointwise_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &pw_desc_ptr)); + + auto const port_count = get_pointwise_mode_port_count(attributes.mode); + bool const is_activation_bwd = detail::is_activation_backward_mode(attributes.mode); + + if (is_activation_bwd) { + // Backward mode: IN_0 is dy, IN_1 is x, OUT_0 is dx + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_0, Pointwise_attributes::input_names::IN_0); + auto dy_desc = tensors.at(IN_0->second->get_uid())->get_raw_desc(); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_1, Pointwise_attributes::input_names::IN_1); + auto x_desc = tensors.at(IN_1->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + 
&x_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(OUT_0, Pointwise_attributes::output_names::OUT_0); + auto dx_desc = tensors.at(OUT_0->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + } else { + // Forward mode + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_0, Pointwise_attributes::input_names::IN_0); + auto x_desc = tensors.at(IN_0->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(OUT_0, Pointwise_attributes::output_names::OUT_0); + auto y_desc = tensors.at(OUT_0->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + if (port_count >= 3) { + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_1, Pointwise_attributes::input_names::IN_1); + auto b_desc = tensors.at(IN_1->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_BDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &b_desc)); + } + + if (port_count >= 4) { + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(IN_2, Pointwise_attributes::input_names::IN_2); + auto t_desc = tensors.at(IN_2->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(pointwise_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_POINTWISE_TDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &t_desc)); + } + } + 
+ // Set alpha scaling factors (always set to 1.0) + float alpha1 = 1.0f; + float alpha2 = 1.0f; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + pointwise_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1, CUDNN_TYPE_FLOAT, 1, &alpha1)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + pointwise_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2, CUDNN_TYPE_FLOAT, 1, &alpha2)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(pointwise_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(pointwise_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"({"tag": "POINTWISE"})"_json); + } +#endif +}; + +inline void +INode::pointwise(std::shared_ptr a, + Pointwise_attributes attributes, + std::shared_ptr c) { + attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; + attributes.outputs[Pointwise_attributes::output_names::OUT_0] = c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +inline void +INode::pointwise(std::shared_ptr a, + std::shared_ptr b, + Pointwise_attributes attributes, + std::shared_ptr c) { + attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; + attributes.inputs[Pointwise_attributes::input_names::IN_1] = b; + attributes.outputs[Pointwise_attributes::output_names::OUT_0] = c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +inline std::shared_ptr +INode::pointwise(std::shared_ptr a, Pointwise_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } + attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; + if (a->get_name().empty()) { + 
a->set_name(attributes.name + "::IN_0"); + }; + auto OUT_0 = attributes.outputs[Pointwise_attributes::output_names::OUT_0] = + output_tensor(attributes.name + "::OUT_0"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return OUT_0; +} + +inline std::shared_ptr +INode::pointwise(std::shared_ptr a, + std::shared_ptr b, + Pointwise_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } + attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; + attributes.inputs[Pointwise_attributes::input_names::IN_1] = b; + if (a->get_name().empty()) { + a->set_name(attributes.name + "::IN_0"); + }; + if (b->get_name().empty()) { + b->set_name(attributes.name + "::IN_1"); + }; + auto OUT_0 = attributes.outputs[Pointwise_attributes::output_names::OUT_0] = + output_tensor(attributes.name + "::OUT_0"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return OUT_0; +} + +inline std::shared_ptr +INode::pointwise(std::shared_ptr a, + std::shared_ptr b, + std::shared_ptr c, + Pointwise_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } + attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; + attributes.inputs[Pointwise_attributes::input_names::IN_1] = b; + attributes.inputs[Pointwise_attributes::input_names::IN_2] = c; + if (a->get_name().empty()) { + a->set_name(attributes.name + "::IN_0"); + }; + if (b->get_name().empty()) { + b->set_name(attributes.name + "::IN_1"); + }; + if (c->get_name().empty()) { + c->set_name(attributes.name + "::IN_2"); + }; + auto OUT_0 = attributes.outputs[Pointwise_attributes::output_names::OUT_0] = + output_tensor(attributes.name + "::OUT_0"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return OUT_0; +} +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git 
a/third_party/cudnn-frontend/include/cudnn_frontend/node/reduction.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/reduction.h new file mode 100644 index 00000000..193b9d37 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/reduction.h @@ -0,0 +1,189 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class ReductionNode : public NodeCRTP { + public: + Reduction_attributes attributes; + + ReductionNode(Reduction_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::REDUCTION; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating ReductionNode " << attributes.name); + + if (attributes.get_is_deterministic() && detail::get_backend_version() < 91100) { + return {error_code_t::GRAPH_NOT_SUPPORTED, "DETERMINISTIC mode is not supported in cudnn version < 9.11.0"}; + } + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for reduction node " << attributes.name); + + attributes.fill_from_context(context); + + // Only inferrencing from IN_0 to OUT_0 works today. 
+ auto x_tensor = attributes.inputs[Reduction_attributes::input_names::X]; + auto y_tensor = attributes.outputs[Reduction_attributes::output_names::Y]; + + auto const& x_tensor_dim = x_tensor->get_dim(); + auto y_tensor_dim = y_tensor->get_dim(); + // Only infer dims and strides if user did not set them + if (y_tensor_dim.empty()) { + y_tensor->set_dim(x_tensor_dim); + } + if (y_tensor->get_stride().empty()) { + auto const& y_dim = y_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(y_dim.size()); + y_tensor->set_stride(detail::generate_stride(y_dim, stride_order)); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building ReductionNode operations " << attributes.name << " "); + + // Create reduction descriptor by directly calling cuDNN backend API + ReductionDesc_v8 reduction_descriptor; + + // 1. Create the backend descriptor + + _CUDNN_CHECK_CUDNN_ERROR( + reduction_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_REDUCTION_DESCRIPTOR)); + + // 2. Set compute type attribute + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_descriptor.get_raw_desc(), + CUDNN_ATTR_REDUCTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + // 3. 
Set reduction operator attribute + cudnnReduceTensorOp_t cudnn_reduction_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.get_mode().value(), cudnn_reduction_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_descriptor.get_raw_desc(), + CUDNN_ATTR_REDUCTION_OPERATOR, + CUDNN_TYPE_REDUCTION_OPERATOR_TYPE, + 1, + &cudnn_reduction_mode)); + + // 4. Set deterministic mode if supported +#if (CUDNN_VERSION >= 91100) + if (detail::get_backend_version() >= 91100) { + bool is_deterministic = attributes.get_is_deterministic(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_descriptor.get_raw_desc(), + CUDNN_ATTR_REDUCTION_IS_DETERMINISTIC, + CUDNN_TYPE_BOOLEAN, + 1, + &is_deterministic)); + } +#endif + + // 5. Finalize the descriptor + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(reduction_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(reduction_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 reduction_operation; + + // Validate input tensors are set + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Reduction_attributes::input_names::X); + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Reduction_attributes::output_names::Y); + + // 1. Create the backend operation descriptor + + _CUDNN_CHECK_CUDNN_ERROR( + reduction_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)); + + // 2. Set the reduction descriptor attribute + auto reduction_desc_ptr = reduction_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_REDUCTION_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &reduction_desc_ptr)); + + // 3. 
Set the input tensor (X) descriptor attribute + auto x_backend_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_REDUCTION_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_backend_desc)); + + // 4. Set the output tensor (Y) descriptor attribute + auto y_backend_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reduction_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_REDUCTION_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_backend_desc)); + + // 5. Finalize the operation descriptor + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(reduction_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(reduction_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"({"tag": "REDUCTION"})"_json); + } +#endif +}; + +inline void +INode::reduction(std::shared_ptr a, + Reduction_attributes attributes, + std::shared_ptr c) { + attributes.inputs[Reduction_attributes::input_names::X] = a; + attributes.outputs[Reduction_attributes::output_names::Y] = c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +inline std::shared_ptr +INode::reduction(std::shared_ptr input, Reduction_attributes attributes) { + attributes.inputs[Reduction_attributes::input_names::X] = input; + auto Y = attributes.outputs[Reduction_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return Y; +} +} // namespace cudnn_frontend::graph diff --git 
a/third_party/cudnn-frontend/include/cudnn_frontend/node/resample.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/resample.h new file mode 100644 index 00000000..34e6031f --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/resample.h @@ -0,0 +1,291 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class ResampleNode : public NodeCRTP { + public: + Resample_attributes attributes; + + ResampleNode(Resample_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::RESAMPLE; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: " << "Validating ResampleNode " << attributes.name); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.generate_index.has_value() == false, + error_code_t::ATTRIBUTE_NOT_SET, + "generate_index attribute not set"); + + if (attributes.generate_index.value() == true && attributes.resample_mode == ResampleMode_t::MAXPOOL) { + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(Resample_attributes::output_names::Index); + } + + // Make sure that the mode can be lowered to BE + cudnnResampleMode_t dummy; + RETURN_CUDNN_FRONTEND_ERROR_IF( + detail::convert_to_cudnn_type(attributes.resample_mode, dummy) != CUDNN_STATUS_SUCCESS, + error_code_t::ATTRIBUTE_NOT_SET, + "Invalid resample mode."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for resample node " << attributes.name); + + auto y_tensor = attributes.outputs[Resample_attributes::output_names::Y]; + auto x_tensor = attributes.inputs[Resample_attributes::input_names::X]; + + attributes.fill_from_context(context); + + // If user does not set shape and layout of the output tensor, + // Get it from node attributes + if (y_tensor->get_dim().empty()) { + auto const 
x_dim = x_tensor->get_dim(); + auto y_dim = y_tensor->get_dim(); + y_dim = x_dim; + + // 2 cause first two dimensions are batch and channels + for (auto dim = 2u; dim < x_dim.size(); ++dim) { + auto spatial_dim = dim - 2u; + y_dim[dim] = + 1 + (x_dim[dim] + attributes.pre_padding[spatial_dim].numerator + + attributes.post_padding[spatial_dim].numerator - attributes.window[spatial_dim].numerator) / + attributes.stride[spatial_dim].numerator; + } + + y_tensor->set_dim(y_dim); + } + + // If layout is not set, generate the strides from layout + if (y_tensor->get_stride().empty()) { + auto const& y_dim = y_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(y_dim.size()); + y_tensor->set_stride(detail::generate_stride(y_dim, stride_order)); + } + + if (attributes.outputs[Resample_attributes::output_names::Index]) { + auto index_tensor = attributes.outputs[Resample_attributes::output_names::Index]; + index_tensor->set_dim(y_tensor->get_dim()); + + // If layout is not set, generate the strides from layout + if (index_tensor->get_stride().empty()) { + auto const& index_dim = index_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(index_dim.size()); + index_tensor->set_stride(detail::generate_stride(index_dim, stride_order)); + } + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building ResampleNode operations " << attributes.name << " "); + + auto number_of_spatial_dim = static_cast(attributes.window.size()); + + // Create resample descriptor by directly calling cuDNN backend API + ResampleDesc_v8 resample_descriptor; + + _CUDNN_CHECK_CUDNN_ERROR( + 
resample_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_RESAMPLE_DESCRIPTOR)); + + // Set resample mode + cudnnResampleMode_t cudnn_resample_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.resample_mode, cudnn_resample_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_MODE, + CUDNN_TYPE_RESAMPLE_MODE, + 1, + &cudnn_resample_mode)); + + // Set compute type + cudnnDataType_t cudnn_data_type; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.compute_data_type, cudnn_data_type)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, + 1, + &cudnn_data_type)); + + // Set nan propagation + cudnnNanPropagation_t nan_opt = CUDNN_PROPAGATE_NAN; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_NAN_PROPAGATION, + CUDNN_TYPE_NAN_PROPOGATION, + 1, + &nan_opt)); + + // Set padding mode + cudnnPaddingMode_t cudnn_padding_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.padding_mode, cudnn_padding_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_PADDING_MODE, + CUDNN_TYPE_PADDING_MODE, + 1, + &cudnn_padding_mode)); + + // Set spatial dimensions + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_SPATIAL_DIMS, + CUDNN_TYPE_INT64, + 1, + &number_of_spatial_dim)); + + // Set window dimensions + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_WINDOW_DIMS, + CUDNN_TYPE_FRACTION, + number_of_spatial_dim, + attributes.window.data())); + + // Set pre padding + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_PRE_PADDINGS, + CUDNN_TYPE_FRACTION, + 
number_of_spatial_dim, + attributes.pre_padding.data())); + + // Set post padding + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_POST_PADDINGS, + CUDNN_TYPE_FRACTION, + number_of_spatial_dim, + attributes.post_padding.data())); + + // Set strides + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_descriptor.get_raw_desc(), + CUDNN_ATTR_RESAMPLE_STRIDES, + CUDNN_TYPE_FRACTION, + number_of_spatial_dim, + attributes.stride.data())); + + // Finalize the descriptor + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(resample_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(resample_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 resample_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + resample_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_RESAMPLE_FWD_DESCRIPTOR)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Resample_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESAMPLE_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Resample_attributes::output_names::Y); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESAMPLE_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + // Set alpha and beta + double alpha = 1.0; + double beta = 0.0; + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + resample_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_RESAMPLE_FWD_ALPHA, CUDNN_TYPE_DOUBLE, 1, &alpha)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + resample_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_RESAMPLE_FWD_BETA, CUDNN_TYPE_DOUBLE, 
1, &beta)); + + // Set resample descriptor + auto resample_raw_desc = resample_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESAMPLE_FWD_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &resample_raw_desc)); + + // Set index tensor if available + auto index = attributes.outputs.find(Resample_attributes::output_names::Index); + if ((index != attributes.outputs.end()) && (index->second != nullptr)) { + auto idx_desc = tensors.at(index->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(resample_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESAMPLE_FWD_IDXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &idx_desc)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(resample_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(resample_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "RESAMPLE"})"_json); + } +#endif +}; + +inline std::array, 2> +INode::resample(std::shared_ptr input, Resample_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } + attributes.inputs[Resample_attributes::input_names::X] = input; + auto Y = attributes.outputs[Resample_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + std::shared_ptr Index = nullptr; + if (attributes.generate_index.has_value() && attributes.generate_index.value() == true && + attributes.resample_mode == ResampleMode_t::MAXPOOL) { + Index = attributes.outputs[Resample_attributes::output_names::Index] = + output_tensor(attributes.name + "::Index"); + } + + 
sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return {Y, Index}; +} + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/reshape.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/reshape.h new file mode 100644 index 00000000..39b08c79 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/reshape.h @@ -0,0 +1,116 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class ReshapeNode : public NodeCRTP { + public: + Reshape_attributes attributes; + + ReshapeNode(Reshape_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::RESHAPE; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for reshape node " << attributes.name); + + auto y_tensor = attributes.outputs[Reshape_attributes::output_names::Y]; + + attributes.fill_from_context(context); + + // If user does not set shape and layout of the output tensor, + // Get it from node attributes + // If layout is not set, generate the strides from layout + + if (y_tensor->get_dim().empty() && attributes.get_dim().size()) { + y_tensor->set_dim(attributes.dim); + } + + if (y_tensor->get_stride().empty()) { + if (attributes.get_stride().size()) { + y_tensor->set_stride(attributes.get_stride()); + } else { + auto const& y_dim = y_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(y_dim.size()); + y_tensor->set_stride(detail::generate_stride(y_dim, stride_order)); + } + } + + if (y_tensor->get_dim().empty() || y_tensor->get_stride().empty()) { + return {error_code_t::SHAPE_DEDUCTION_FAILED, "Reshape node output shape deduction failed"}; + } + + return {error_code_t::OK, ""}; + } + + error_t + 
create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building ReshapeNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 reshape_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + reshape_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Reshape_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reshape_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESHAPE_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Reshape_attributes::output_names::Y); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(reshape_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RESHAPE_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(reshape_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(reshape_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "RESHAPE"})"_json); + } +#endif +}; + +inline std::shared_ptr +INode::reshape(std::shared_ptr input, Reshape_attributes attributes) { + attributes.inputs[Reshape_attributes::input_names::X] = input; + auto Y = 
attributes.outputs[Reshape_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return Y; +} + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/rmsnorm.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/rmsnorm.h new file mode 100644 index 00000000..bc1f37d2 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/rmsnorm.h @@ -0,0 +1,406 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { +class RMSNormNode : public NodeCRTP { + public: + Rmsnorm_attributes attributes; + + RMSNormNode(Rmsnorm_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::RMSNORM; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing properties for rmsnorm node " << attributes.name); + + attributes.fill_from_context(context); + + auto X = attributes.inputs[Rmsnorm_attributes::input_names::X]; + auto Y = attributes.outputs[Rmsnorm_attributes::output_names::Y]; + + // Only infer dims and strides if user did not set them + if (Y->get_dim().empty()) { + Y->set_dim(X->get_dim()); + } + if (Y->get_stride().empty()) { + Y->set_stride(X->get_stride()); + } + + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + auto inv_var = attributes.outputs[Rmsnorm_attributes::output_names::INV_VARIANCE]; + // Only infer dims and strides if user did not set them + if (inv_var->get_dim().empty()) { + auto inv_var_dim = X->get_dim(); + auto scale = attributes.inputs[Rmsnorm_attributes::input_names::SCALE]; + if (scale->get_dim().empty()) { + // mean inv_var dim is n,1,1,1 + for (size_t i = 1; i < inv_var_dim.size(); i++) { + inv_var_dim[i] = 1; + } + } 
else { + for (size_t i = 0; i < inv_var_dim.size(); i++) { + if (scale->get_dim()[i] != 1) { + inv_var_dim[i] = 1; + } + } + } + inv_var->set_dim(inv_var_dim); + } + if (inv_var->get_stride().empty()) { + auto const& inv_var_dim = inv_var->get_dim(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(X->get_stride(), inv_var_dim.size(), stride_order)); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); + } + } + return {error_code_t::OK, ""}; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating RMSNormNode " << attributes.name); + + // Norm forward phase should be set + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.forward_phase == NormFwdPhase_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Forward phase not set of rmsnorm node."); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: Building RMSNormNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 rmsnorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + rmsnorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_FORWARD_DESCRIPTOR)); + + // Set norm mode to RMS_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::RMS_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set forward phase + cudnnBackendNormFwdPhase_t cudnn_norm_fwd_phase; + + 
_CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.forward_phase, cudnn_norm_fwd_phase)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_PHASE, + CUDNN_TYPE_NORM_FWD_PHASE, + 1, + &cudnn_norm_fwd_phase)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Rmsnorm_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set scale tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Rmsnorm_attributes::input_names::SCALE); + auto scale_desc = tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set epsilon tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(EPSILON, Rmsnorm_attributes::input_names::EPSILON); + auto epsilon_desc = tensors.at(EPSILON->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_EPSILON_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &epsilon_desc)); + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Rmsnorm_attributes::output_names::Y); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_YDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &y_desc)); + + // Set inv_variance for training phase + if (attributes.forward_phase == NormFwdPhase_t::TRAINING) { + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(INV_VARIANCE, Rmsnorm_attributes::output_names::INV_VARIANCE); + auto inv_var_desc = 
tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + } + + // Set optional bias tensor + auto BIAS = attributes.inputs.find(Rmsnorm_attributes::input_names::BIAS); + if ((BIAS != attributes.inputs.end()) && (BIAS->second != nullptr)) { + auto bias_desc = tensors.at(BIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_FWD_BIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &bias_desc)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(rmsnorm_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(rmsnorm_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "RMS_NORM"})"_json); + } +#endif +}; + +class DRMSNormNode : public NodeCRTP { + public: + Rmsnorm_backward_attributes attributes; + + DRMSNormNode(Rmsnorm_backward_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::DRMSNorm; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating DRMSNormNode node " << attributes.name); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.use_dbias.has_value() == false, + error_code_t::ATTRIBUTE_NOT_SET, + "DRMSNormNode node needs has_bias(bool) to be called."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferencing 
properties for DRMSNorm node " << attributes.name); + + attributes.fill_from_context(context); + + // TODO: Only inferencing from X works today. + auto X = attributes.inputs[Rmsnorm_backward_attributes::input_names::X]; + auto const x_tensor_dim = X->get_dim(); + + auto DY = attributes.inputs[Rmsnorm_backward_attributes::input_names::DY]; + auto dy_tensor_dim = DY->get_dim(); + + // Only infer dims and strides if user did not set them + if (dy_tensor_dim.empty()) { + dy_tensor_dim.resize(x_tensor_dim.size()); + DY->set_dim(x_tensor_dim); + } + if (DY->get_stride().empty()) { + auto const& DY_dim = DY->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DY_dim.size()); + DY->set_stride(detail::generate_stride(DY_dim, stride_order)); + } + + auto DX = attributes.outputs[Rmsnorm_backward_attributes::output_names::DX]; + auto dx_tensor_dim = DX->get_dim(); + // Only infer dims and strides if user did not set them + if (dx_tensor_dim.empty()) { + dx_tensor_dim.resize(x_tensor_dim.size()); + DX->set_dim(x_tensor_dim); + } + if (DX->get_stride().empty()) { + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); + } + + auto scale_bias_dim = X->get_dim(); + scale_bias_dim[0] = 1; + + // Set channel length tensors + auto infer_scale_bias_tensors = [&scale_bias_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + T->set_dim(scale_bias_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + + infer_scale_bias_tensors(attributes.outputs[Rmsnorm_backward_attributes::output_names::DSCALE]); + if 
(attributes.use_dbias.value()) { + infer_scale_bias_tensors(attributes.outputs[Rmsnorm_backward_attributes::output_names::DBIAS]); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: Building DRMSNormNode operations " << attributes.name << " "); + + // Create operation by directly calling cuDNN backend API + Operation_v8 drmsnorm_operation; + + _CUDNN_CHECK_CUDNN_ERROR( + drmsnorm_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_NORM_BACKWARD_DESCRIPTOR)); + + // Set norm mode to RMS_NORM + cudnnBackendNormMode_t cudnn_norm_mode; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(NormMode_t::RMS_NORM, cudnn_norm_mode)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_MODE, + CUDNN_TYPE_NORM_MODE, + 1, + &cudnn_norm_mode)); + + // Set input tensor X + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Rmsnorm_backward_attributes::input_names::X); + auto x_desc = tensors.at(X->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_XDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &x_desc)); + + // Set DY tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(DY, Rmsnorm_backward_attributes::input_names::DY); + auto dy_desc = tensors.at(DY->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DYDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dy_desc)); + + // Set scale tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(SCALE, Rmsnorm_backward_attributes::input_names::SCALE); + auto scale_desc = 
tensors.at(SCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_SCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &scale_desc)); + + // Set inv_variance tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(INV_VARIANCE, Rmsnorm_backward_attributes::input_names::INV_VARIANCE); + auto inv_var_desc = tensors.at(INV_VARIANCE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_INV_VARIANCE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &inv_var_desc)); + + // Set DSCALE output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DSCALE, Rmsnorm_backward_attributes::output_names::DSCALE); + auto dscale_desc = tensors.at(DSCALE->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DSCALE_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dscale_desc)); + + // Set optional DBIAS output tensor + if (attributes.use_dbias.value()) { + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DBIAS, Rmsnorm_backward_attributes::output_names::DBIAS); + auto dbias_desc = tensors.at(DBIAS->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DBIAS_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dbias_desc)); + } + + // Set DX output tensor + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(DX, Rmsnorm_backward_attributes::output_names::DX); + auto dx_desc = tensors.at(DX->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(drmsnorm_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_NORM_BWD_DXDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &dx_desc)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(drmsnorm_operation.get_raw_desc())); + + 
operations.push_back(std::make_shared(std::move(drmsnorm_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "RMS_NORM_BPROP"})"_json); + } +#endif +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/rng.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/rng.h new file mode 100644 index 00000000..09b762d3 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/rng.h @@ -0,0 +1,187 @@ +#pragma once + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class RngNode : public NodeCRTP { + public: + Rng_attributes attributes; + + RngNode(Rng_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::RNG; + } + + error_t + infer_properties_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for rng node " << attributes.name); + + auto y_tensor = attributes.outputs[Rng_attributes::output_names::Y]; + + attributes.fill_from_context(context); + + // If user does not set shape and layout of the generated tensor, + // Get it from node attributes + // If layout is not set, generate the strides from layout + + if (y_tensor->get_dim().empty() && attributes.get_dim().size()) { + y_tensor->set_dim(attributes.dim); + } + + if (y_tensor->get_stride().empty()) { + if (attributes.get_stride().size()) { + y_tensor->set_stride(attributes.get_stride()); + } else { + auto const& y_dim = y_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = 
detail::generate_NHWC_stride_order(y_dim.size()); + y_tensor->set_stride(detail::generate_stride(y_dim, stride_order)); + } + } + + if (y_tensor->get_dim().empty() || y_tensor->get_stride().empty()) { + return {error_code_t::SHAPE_DEDUCTION_FAILED, "RNG node output shape deduction failed"}; + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building RngNode operations " << attributes.name << " "); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.get_distribution() != RngDistribution_t::BERNOULLI, + error_code_t::ATTRIBUTE_NOT_SET, + "no other distribution except bernoulli supported."); + + // Create RNG descriptor by directly calling cuDNN backend API + RngDesc_v8 rng_descriptor; + + _CUDNN_CHECK_CUDNN_ERROR(rng_descriptor.initialize_managed_backend_pointer(CUDNN_BACKEND_RNG_DESCRIPTOR)); + + // Set distribution type + cudnnRngDistribution_t cudnn_rng_distribution; + + _CUDNN_CHECK_CUDNN_ERROR(detail::convert_to_cudnn_type(attributes.get_distribution(), cudnn_rng_distribution)); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_descriptor.get_raw_desc(), + CUDNN_ATTR_RNG_DISTRIBUTION, + CUDNN_TYPE_RNG_DISTRIBUTION, + 1, + &cudnn_rng_distribution)); + + // Set Bernoulli distribution probability + double bernoulli_prob = attributes.get_bernoulli_probability().value(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_descriptor.get_raw_desc(), + CUDNN_ATTR_RNG_BERNOULLI_DIST_PROBABILITY, + CUDNN_TYPE_DOUBLE, + 1, + &bernoulli_prob)); + + // Finalize the descriptor + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(rng_descriptor.get_raw_desc())); + CUDNN_FE_LOG_LABEL_ENDL(rng_descriptor); + + // Create operation by directly calling cuDNN backend API + Operation_v8 rng_operation; + + 
_CUDNN_CHECK_CUDNN_ERROR( + rng_operation.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)); + + // Set output tensor Y + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Rng_attributes::output_names::Y); + auto y_desc = tensors.at(Y->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + rng_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_RNG_YDESC, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &y_desc)); + + // Set RNG descriptor + auto rng_raw_desc = rng_descriptor.get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RNG_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &rng_raw_desc)); + + if (attributes.seed.has_value()) { + // Set seed as int64_t value + int64_t seed_value = attributes.get_seed().value(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute( + rng_operation.get_raw_desc(), CUDNN_ATTR_OPERATION_RNG_SEED, CUDNN_TYPE_INT64, 1, &seed_value)); + } else { + // Set seed tensor descriptor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(Seed, Rng_attributes::input_names::Seed); + auto seed_desc = tensors.at(Seed->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RNG_SEED, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &seed_desc)); + + // Set offset tensor descriptor + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(Offset, Rng_attributes::input_names::Offset); + auto offset_desc = tensors.at(Offset->second->get_uid())->get_raw_desc(); + + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(rng_operation.get_raw_desc(), + CUDNN_ATTR_OPERATION_RNG_OFFSET_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &offset_desc)); + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(rng_operation.get_raw_desc())); + + operations.push_back(std::make_shared(std::move(rng_operation))); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + 
uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "RNG"})"_json); + } +#endif +}; + +inline void +INode::rng(std::shared_ptr seed, + std::shared_ptr offset, + Rng_attributes attributes, + std::shared_ptr y) { + attributes.inputs[Rng_attributes::input_names::Seed] = seed; + attributes.inputs[Rng_attributes::input_names::Offset] = offset; + attributes.outputs[Rng_attributes::output_names::Y] = y; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +inline std::shared_ptr +INode::rng(std::shared_ptr seed, + std::shared_ptr offset, + Rng_attributes attributes) { + attributes.inputs[Rng_attributes::input_names::Seed] = seed; + attributes.inputs[Rng_attributes::input_names::Offset] = offset; + auto Y = attributes.outputs[Rng_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return Y; +} +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h new file mode 100644 index 00000000..55d635a4 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -0,0 +1,1964 @@ +#pragma once + +#include + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +#include "matmul.h" +#include "pointwise.h" +#include "rng.h" +#include "softmax.h" +#include "paged_cache_load.h" +#include "sdpa_support_surface.h" + +namespace cudnn_frontend::graph { + +namespace attn::score_modifiers { + +// 
clang-format off +inline float get_negative_inf_value(); + +inline std::shared_ptr causal_mask( + std::shared_ptr graph, + std::shared_ptr attention_score +); + +inline std::shared_ptr bias( + std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr bias_tensor +); + +inline std::shared_ptr causal_mask_bottom_right( + std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr seq_len_q, + std::shared_ptr seq_len_kv +); + +inline std::shared_ptr padding_mask( + std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr seq_len_kv, + std::shared_ptr seq_len_q +); + +inline std::shared_ptr sliding_window_mask( + std::shared_ptr graph, + std::shared_ptr attention_score, + DiagonalAlignment_t diagonal_alignment, + std::optional left_window, + std::optional right_window, + int64_t s_q, + int64_t s_kv, + std::shared_ptr s_q_ptr, + std::shared_ptr s_kv_ptr +); + +inline std::shared_ptr alibi_mask( + std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr& alibi_slopes, + int64_t h_q, + int64_t& alibi_slopes_size +); +// clang-format on + +} // namespace attn::score_modifiers + +template +class SDPANodeBase : public NodeCRTP { + protected: + using input_names = SDPA_attributes::input_names; + using output_names = SDPA_attributes::output_names; + + std::shared_ptr rng_output; + std::shared_ptr alibi_slopes; + int64_t alibi_slopes_size = 0; + + public: + SDPA_attributes attributes; + + SDPANodeBase(SDPA_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + bool + is_paged_v() const { + auto page_table_v_it = attributes.inputs.find(input_names::Page_table_V); + return ((page_table_v_it) != attributes.inputs.end() && page_table_v_it->second != nullptr); + } + + bool + is_paged_k() const { + auto page_table_k_it = attributes.inputs.find(input_names::Page_table_K); + return ((page_table_k_it) != attributes.inputs.end() && 
page_table_k_it->second != nullptr); + } + + bool + has_seq_len_q() const { + auto seq_len_Q_it = attributes.inputs.find(SDPA_attributes::input_names::SEQ_LEN_Q); + return ((seq_len_Q_it) != attributes.inputs.end() && seq_len_Q_it->second != nullptr); + } + + bool + has_seq_len_kv() const { + auto seq_len_KV_it = attributes.inputs.find(SDPA_attributes::input_names::SEQ_LEN_KV); + return ((seq_len_KV_it) != attributes.inputs.end() && seq_len_KV_it->second != nullptr); + } + + // Helper function to infer KV sequence length + // Note that it cannot be run as part of infer_properties_node as + // this is being used in pre_validate_node + int64_t + infer_s_kv() const { + int64_t s_kv = -1; + + auto get_input_dim = [this](const SDPA_attributes::input_names& input_name) { + auto const input_it = attributes.inputs.find(input_name); + if (input_it != attributes.inputs.end()) { + return input_it->second->get_dim(); + } else { + return std::vector({-1, -1, -1, -1}); + } + }; + + auto const& k_dim = get_input_dim(input_names::K); + auto const& v_dim = get_input_dim(input_names::V); + + // If s_kv was set explicitly, use that + if (attributes.max_seq_len_kv.has_value()) { + s_kv = attributes.max_seq_len_kv.value(); + } + // When one of K or V cache are paged, s_kv can be extracted directly + else if (!is_paged_k()) { + s_kv = k_dim[2]; + + } else if (!is_paged_v()) { + s_kv = v_dim[2]; + } else { + CUDNN_FE_LOG_LABEL_ENDL( + "WARNING: maximum kv sequence length is being inferred. 
To set it explicitly, please use " + "\"set_paged_attention_max_seq_len_kv\""); + + auto bias_it = attributes.inputs.find(input_names::Bias); + auto rng_it = attributes.outputs.find(output_names::RNG_DUMP); + + // If there is a bias, extract it from there + if (bias_it != attributes.inputs.end() && bias_it->second != nullptr) { + s_kv = get_input_dim(input_names::Bias)[3]; + // If there is an rng_dump output, extract it from there + } else if (rng_it != attributes.outputs.end() && rng_it->second != nullptr) { + s_kv = rng_it->second->get_dim()[3]; + // When both caches are paged, and the above failed, we need to infer s_kv from the page table and + // container + } else { + // [b, 1, ceil(s_kv/block_size), 1] + auto page_table_dim_k = get_input_dim(input_names::Page_table_K); + // [b, h_k, block_size, d_k] + auto const container_dim_k = get_input_dim(input_names::K); + int64_t s_k = page_table_dim_k[2] * container_dim_k[2]; + + // [b, 1, ceil(s_kv/block_size), 1] + auto page_table_dim_v = get_input_dim(input_names::Page_table_V); + // [b, h_v, block_size, d_v] + auto const container_dim_v = get_input_dim(input_names::V); + int64_t s_v = page_table_dim_v[2] * container_dim_v[2]; + + s_kv = std::min(s_k, s_v); + } + } + + return s_kv; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating SDPANode " << attributes.name); + + // check that Q, K, V, O tensors has been assigned + // check that dim and strides has been assigned and last stride is 1 +#define CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(port, port_map) \ + { \ + std::shared_ptr tensor_ptr = port_map.at(port); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_dim().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The dim for " + std::string(#port) + " is invalid"); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_stride().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The stride for " + std::string(#port) + " is invalid"); \ + 
RETURN_CUDNN_FRONTEND_ERROR_IF( \ + tensor_ptr->get_stride()[3] != 1, \ + error_code_t::GRAPH_NOT_SUPPORTED, \ + "The stride for the last dimension corresponding to the embedding size per head should be 1 for " + \ + std::string(#port)); \ + } + + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::Q, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::K, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::V, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::O, attributes.outputs); + + if (attributes.generate_stats.value_or(false) == true) { + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(output_names::Stats); + } + + // If max is requested, validate that the output tensor is present + if (attributes.outputs.find(output_names::Max) != attributes.outputs.end() && + attributes.outputs.at(output_names::Max) != nullptr) { + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(output_names::Max); + } + + // If sum_exp is requested, validate that the output tensor is present + if (attributes.outputs.find(output_names::Sum_exp) != attributes.outputs.end() && + attributes.outputs.at(output_names::Sum_exp) != nullptr) { + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(output_names::Sum_exp); + } + +#undef CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE + + // validate backend limitations for the operation + auto validation_result = + attributes.validate_sdpa_support_surface(this->context, infer_s_kv(), is_paged_k(), is_paged_v()); + if (validation_result.is_good() == false) { + return validation_result; + } + + // return NOT_SET if sink_token present with 9.12 and below + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 91300 && + attributes.inputs.find(input_names::SINK_TOKEN) != attributes.inputs.end(), + error_code_t::ATTRIBUTE_NOT_SET, + "SDPA with sink_token is not supported before 9.13."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + if (attributes.generate_stats.value_or(false)) { + auto stats = 
attributes.outputs.at(output_names::Stats); + auto stats_dim = stats->get_dim(); + + if (stats_dim.empty()) { + // Fill properties of virtual tensors + auto const& p_dim = attributes.inputs[input_names::Q]->get_dim(); + auto b = p_dim[0]; + auto h = p_dim[1]; + auto s_q = p_dim[2]; + stats->set_dim({b, h, s_q, 1}).set_stride({h * s_q, s_q, 1, 1}); + } + } + + if (attributes.outputs[output_names::Max] != nullptr) { + auto max = attributes.outputs.at(output_names::Max); + + if (max->get_dim().empty()) { + // Fill properties of virtual tensors + auto const& p_dim = attributes.inputs[input_names::Q]->get_dim(); + auto b = p_dim[0]; + auto h = p_dim[1]; + auto s_q = p_dim[2]; + max->set_dim({b, h, s_q, 1}).set_stride({h * s_q, s_q, 1, 1}); + } + } + + if (attributes.outputs[output_names::Sum_exp] != nullptr) { + auto sum_exp = attributes.outputs.at(output_names::Sum_exp); + + if (sum_exp->get_dim().empty()) { + // Fill properties of virtual tensors + auto const& p_dim = attributes.inputs[input_names::Q]->get_dim(); + auto b = p_dim[0]; + auto h = p_dim[1]; + auto s_q = p_dim[2]; + sum_exp->set_dim({b, h, s_q, 1}).set_stride({h * s_q, s_q, 1, 1}); + } + } + return {error_code_t::OK, ""}; + } + + error_t + post_validate_node() const override final { +#define CUDNN_FE_VALIDATE_STRIDE(port, port_map) \ + { \ + auto const& t = port_map.find(port); \ + RETURN_CUDNN_FRONTEND_ERROR_IF( \ + t->second->get_stride().back() != 1, \ + error_code_t::GRAPH_NOT_SUPPORTED, \ + "The stride for the last dimension corresponding to the embedding size per head should be 1 for " + \ + std::string(#port)); \ + } + + CUDNN_FE_VALIDATE_STRIDE(output_names::O, attributes.outputs); + +#undef CUDNN_FE_VALIDATE_STRIDE + + return {error_code_t::OK, ""}; + } + + virtual int64_t + get_fe_workspace_size_node() const override final { + int64_t size = 0; + + // align alibi slopes memory to 16 bytes + size += ((alibi_slopes_size + 15) / 16 * 16); + + return size; + } + + virtual error_t + 
collect_tensors_in_workspace_node( + std::unordered_map>>& + workspace_modifications, + int64_t& offset) const override final { + if (attributes.alibi_mask) { + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(Q, input_names::Q); + int64_t const h_q = Q->second->get_dim()[1]; + auto alibi_slopes_vec = detail::get_alibi_slope(h_q); + workspace_modifications.emplace(alibi_slopes->get_uid(), std::make_tuple(0, offset, alibi_slopes_vec)); + int64_t alibi_slopes_size_padded = ((alibi_slopes_size + 15) / 16 * 16); + offset = offset + alibi_slopes_size_padded; + } + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"({"tag": "SDPA_FWD"})"_json); + } +#endif +}; + +class CompositeSDPANode : public SDPANodeBase { + public: + CompositeSDPANode(SDPA_attributes&& attributes_, detail::Context const& context) + : SDPANodeBase(std::move(attributes_), context) {} + + Type + getType() override final { + return Type::COMPOSITE; + } + + error_t + expand_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for CompositeSDPANode node " << attributes.name); + + // DO NOT REMOVE + // input data type is needed for: + // - aType of bmm2 + // - dropout scale in pre 8.9.3 + attributes.fill_from_context(this->context); + + // Gather dim to fill properties of virtual tensors + auto const& q_dim = attributes.inputs[input_names::Q]->get_dim(); + auto b = q_dim[0]; + auto h_q = q_dim[1]; + auto s_q = q_dim[2]; + auto d_qk = q_dim[3]; + auto const& k_dim = attributes.inputs[input_names::K]->get_dim(); + auto h_k = k_dim[1]; + auto const& v_dim = attributes.inputs[input_names::V]->get_dim(); + auto h_v = v_dim[1]; + auto d_v = v_dim[3]; + // Infer s_kv + int64_t s_kv = infer_s_kv(); + + std::shared_ptr k_cache; + if (!is_paged_k()) { + // 1. 
map K->KT + // cuDNN frontend API attention requires Q, K, V where + // Q = {b, h_q, s_q, d_qk} + // K = {b, h_k, s_kv, d_qk} + // V = {b, h_v, s_kv, d_v} + // but cuDNN backend API attention requires Q, KT, V + // Q = {b, h_q, s_q, d_qk} + // KT = {b, h_k, d_qk, s_kv} + // V = {b, h_v, s_kv, d_v} + // So the code below maps the K->KT + std::vector temp_vec; + + temp_vec = attributes.inputs[input_names::K]->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_dim(temp_vec); + + temp_vec = attributes.inputs[input_names::K]->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_stride(temp_vec); + + // 2. Set k_cache + k_cache = attributes.inputs[input_names::K]; + } else { + // Create a paged cache load operation + auto paged_cache_load_attributes_k = PagedCacheLoad_attributes().set_name("paged_k_cache_operation"); + // Need to create virtual tensor descriptor for yOut here as it cannot be inferred + // K-cache has BHDS layout + k_cache = std::make_shared(); + k_cache->set_is_virtual(true); + k_cache->set_dim({b, h_k, d_qk, s_kv}); + k_cache->set_stride({d_qk * s_kv * h_k, d_qk * s_kv, 1, d_qk}); + k_cache->set_data_type(attributes.inputs[input_names::K]->get_data_type()); + paged_cache_load(attributes.inputs[input_names::K], + attributes.inputs[input_names::SEQ_LEN_KV], + attributes.inputs[input_names::Page_table_K], + paged_cache_load_attributes_k, + k_cache); + } + + // This tensor tracks the main chain of data flow + std::shared_ptr last_output; + + //// Q * K + auto bmm1_attributes = Matmul_attributes() + .set_name("bmm1") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_n_override(attributes.inputs[input_names::SEQ_LEN_KV]); + + if (attributes.padding_mask) { + bmm1_attributes.set_padding(0.0); + } + + auto const& bmm1_output = matmul(attributes.inputs[input_names::Q], k_cache, bmm1_attributes); + // Setting dim and strides as pointwise op wont have knowledge of 
how to do it for mha. + bmm1_output->set_dim({b, h_q, s_q, s_kv}).set_stride({h_q * s_q * s_kv, s_q * s_kv, s_kv, 1}); + last_output = bmm1_output; + + //// Optional Attn scale + // In case user provided a scalar value, do a fused scalar. + if (attributes.attn_scale_value.has_value()) { + attributes.inputs[input_names::Attn_scale] = + std::make_shared(attributes.attn_scale_value.value()); + } + + // If attn scale present, add a pointwise mul node + if (attributes.inputs[input_names::Attn_scale]) { + Pointwise_attributes scale_attributes; + scale_attributes.set_name("attn_scale").set_mode(PointwiseMode_t::MUL); + auto const& attn_scale_output = + pointwise(last_output, attributes.inputs[input_names::Attn_scale], scale_attributes); + last_output = attn_scale_output; + } + + // Descale Q + if (attributes.inputs.find(input_names::Descale_Q) != attributes.inputs.end() && + attributes.inputs.at(input_names::Descale_Q) != nullptr) { + auto descale_q_attributes = Pointwise_attributes().set_mode(PointwiseMode_t::MUL).set_name("descale_q"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_Q), descale_q_attributes); + } + + // Descale K + if (attributes.inputs.find(input_names::Descale_K) != attributes.inputs.end() && + attributes.inputs.at(input_names::Descale_K) != nullptr) { + auto descale_k_attributes = Pointwise_attributes().set_mode(PointwiseMode_t::MUL).set_name("descale_k"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_K), descale_k_attributes); + } + + if (attributes.attention_score_modifier != nullptr) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = this->context; + last_output = attributes.attention_score_modifier(graph_, last_output); + sub_nodes.emplace_back(node_); + } + + // Optional bias + if (attributes.inputs.find(input_names::Bias) != attributes.inputs.end() && + attributes.inputs[input_names::Bias]) { + auto graph_ = 
std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = this->context; + last_output = attn::score_modifiers::bias(graph_, last_output, attributes.inputs[input_names::Bias]); + sub_nodes.emplace_back(node_); + } + + if (attributes.alibi_mask) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = this->context; + last_output = attn::score_modifiers::alibi_mask(graph_, last_output, alibi_slopes, h_q, alibi_slopes_size); + sub_nodes.emplace_back(node_); + } + + // There are two cases of applying padding mask + // 1. when actual seq_len is less than or equal to max_seq_len + if (attributes.padding_mask) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = this->context; + last_output = attn::score_modifiers::padding_mask(graph_, + last_output, + attributes.inputs[input_names::SEQ_LEN_KV], + attributes.inputs[input_names::SEQ_LEN_Q]); + sub_nodes.emplace_back(node_); + } + + // 2. (bug in cudnn backend) no padding with max_seq_len%64!=0 + if ((s_kv % 64 != 0) && (!(attributes.padding_mask)) && (detail::get_backend_version() < 90000)) { + auto col_index_attributes = + Pointwise_attributes().set_name("gen_col_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); + auto col_index_output = pointwise(last_output, col_index_attributes); + // scalar seq_kv only needs to be passed in case there in no padding mask and seq_kv is not multiple of 64. + // Also future versions of cudnn will not need it, hence tensor is pre-fixed with WAR. 
+ auto WAR_scalar_max_seq_kv = std::make_shared(static_cast(s_kv)); + + auto col_less_seq_kv_attributes = + Pointwise_attributes().set_name("col_less_seq_kv").set_mode(PointwiseMode_t::CMP_LT); + auto col_less_seq_kv_output = + pointwise(col_index_output, WAR_scalar_max_seq_kv, col_less_seq_kv_attributes); + + // Lower attributes to binary select attributes + auto negative_inf_padding = + std::make_shared(attn::score_modifiers::get_negative_inf_value()); + auto binary_select_attributes = + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); + auto padding_mask_output = + pointwise(last_output, negative_inf_padding, col_less_seq_kv_output, binary_select_attributes); + last_output = padding_mask_output; + } + + // Apply (bottom-right) causal masking (with right bound) and/or set the left bound + if (attributes.left_bound.has_value() || attributes.right_bound.has_value()) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = this->context; + + auto s_kv_ptr = attributes.inputs.find(input_names::SEQ_LEN_KV) != attributes.inputs.end() + ? attributes.inputs[input_names::SEQ_LEN_KV] + : nullptr; + auto s_q_ptr = attributes.inputs.find(input_names::SEQ_LEN_Q) != attributes.inputs.end() + ? 
attributes.inputs[input_names::SEQ_LEN_Q] + : nullptr; + + last_output = attn::score_modifiers::sliding_window_mask(graph_, + last_output, + attributes.diagonal_alignment, + attributes.left_bound, + attributes.right_bound, + s_q, + s_kv, + s_q_ptr, + s_kv_ptr); + sub_nodes.emplace_back(node_); + } + + // Lower attributes to softmax attributes + auto softmax_output = std::make_shared(); + softmax_output->set_is_virtual(true); + + auto softmax_attributes = Softmax_attributes().set_name("softmax"); + // Set sink for softmax if user has provided a sink tensor + if (attributes.inputs.find(input_names::SINK_TOKEN) != attributes.inputs.end()) { + softmax_attributes.set_sink(attributes.inputs[input_names::SINK_TOKEN]); + } + // Special non-functional-style call. Needed because output already created and provided to user. + softmax(last_output, + softmax_attributes, + softmax_output, + attributes.outputs[output_names::Stats], + attributes.outputs[output_names::Max], + attributes.outputs[output_names::Sum_exp]); + last_output = softmax_output; + + // Two cases for training: dropout present or not + bool dropout_present = false; + auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); + bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); + if (attributes.dropout_probability.has_value()) { + dropout_present = true; + // Special case: Skip dropout when 0.0 probability. Only do for 8.9.3 and up as rng isn't optional earlier. 
+ if (detail::get_backend_version() > 8902 && attributes.dropout_probability.value() == 0.0) { + dropout_present = false; + } + } else if (is_dropout_custom) { + dropout_present = true; + } + + if (dropout_present) { + if (is_dropout_custom) { + auto dropout_scale_attributes = + Pointwise_attributes().set_name("dropout_scale_mul").set_mode(PointwiseMode_t::MUL); + auto const& dropout_scale_output = + pointwise(last_output, attributes.inputs[input_names::Dropout_scale], dropout_scale_attributes); + + auto mask_attributes = + Pointwise_attributes().set_name("dropout_mask_mul").set_mode(PointwiseMode_t::MUL); + auto const& dropout_mask_output = + pointwise(dropout_scale_output, dropout_mask->second, mask_attributes); + last_output = dropout_mask_output; + } else { + if (attributes.outputs[output_names::RNG_DUMP] != nullptr) { + rng_output = attributes.outputs[output_names::RNG_DUMP]; + rng(attributes.inputs[input_names::Seed], + attributes.inputs[input_names::Offset], + Rng_attributes() + .set_name("rng") + .set_distribution(RngDistribution_t::BERNOULLI) + .set_bernoulli_probability(1.0 - attributes.dropout_probability.value()), + rng_output); + } else { + rng_output = rng(attributes.inputs[input_names::Seed], + attributes.inputs[input_names::Offset], + Rng_attributes() + .set_name("rng") + .set_distribution(RngDistribution_t::BERNOULLI) + .set_bernoulli_probability(1.0 - attributes.dropout_probability.value())); + rng_output + // Hard coding dim and strides as rng output can no inputs to infer it from. 
+ ->set_dim({b, h_q, s_q, s_kv}) + .set_stride({h_q * s_q * s_kv, s_q * s_kv, s_kv, 1}); + } + + auto mask_attributes = + Pointwise_attributes().set_name("dropout_mask_mul").set_mode(PointwiseMode_t::MUL); + auto const& dropout_mask_output = pointwise(last_output, rng_output, mask_attributes); + last_output = dropout_mask_output; + + std::shared_ptr dropout_scale = nullptr; + + if (detail::get_backend_version() < 8903) { + half dropout_scale_value = __float2half(1.0f / (1.0f - attributes.dropout_probability.value())); + dropout_scale = std::make_shared(dropout_scale_value); + } else { + float dropout_scale_value = (1.0f / (1.0f - attributes.dropout_probability.value())); + dropout_scale = std::make_shared(dropout_scale_value); + } + + auto dropout_scale_attributes = + Pointwise_attributes().set_name("dropout_scale").set_mode(PointwiseMode_t::MUL); + auto const& dropout_scale_output = pointwise(last_output, dropout_scale, dropout_scale_attributes); + last_output = dropout_scale_output; + } + } + + // Amax S + if (attributes.outputs.find(output_names::Amax_S) != attributes.outputs.end() && + attributes.outputs.at(output_names::Amax_S) != nullptr) { + auto amax_attributes = Reduction_attributes().set_name("amax_s").set_mode(ReductionMode_t::AMAX); + // Special non-functional-style call. Needed because output already created and provided to user. + reduction(last_output, amax_attributes, attributes.outputs.at(output_names::Amax_S)); + } + + // Scale S + if (attributes.inputs.find(input_names::Scale_S) != attributes.inputs.end() && + attributes.inputs.at(input_names::Scale_S) != nullptr) { + auto scale_s_attributes = Pointwise_attributes().set_name("scale_s").set_mode(PointwiseMode_t::MUL); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Scale_S), scale_s_attributes); + } + + // Lower attributes to bmm2 attributes + // Requirement by cudnn backend to take in bmm2 aType as i/o type. 
+ last_output->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + + auto const& seq_len_q = attributes.inputs[input_names::SEQ_LEN_Q]; + auto const& seq_len_kv = attributes.inputs[input_names::SEQ_LEN_KV]; + // auto const& V = attributes.inputs[input_names::V]; + auto const& O = attributes.outputs[output_names::O]; + + std::shared_ptr v_cache; + + if (!is_paged_v()) { + v_cache = attributes.inputs[input_names::V]; + } else { + auto paged_cache_load_attributes_v = PagedCacheLoad_attributes().set_name("paged_v_cache_operation"); + v_cache = std::make_shared(); + v_cache->set_dim({b, h_v, s_kv, d_v}) + .set_stride({d_v * s_kv * h_v, d_v * s_kv, d_v, 1}) + .set_data_type(attributes.inputs[input_names::V]->get_data_type()); + v_cache->set_is_virtual(true); + paged_cache_load(attributes.inputs[input_names::V], + attributes.inputs[input_names::SEQ_LEN_KV], + attributes.inputs[input_names::Page_table_V], + paged_cache_load_attributes_v, + v_cache); + } + + //// S * V + if (attributes.mma_core_mode == DataType_t::HALF) { + auto bmm2_attributes = + Matmul_attributes().set_name("bmm2").set_m_override(seq_len_q).set_k_override(seq_len_kv); + // Special non-functional-style call. Needed because output already created and provided to user. + matmul(last_output, v_cache, bmm2_attributes, O); + } else if (attributes.mma_core_mode == DataType_t::FP8_E4M3 || + attributes.mma_core_mode == DataType_t::FP8_E5M2) { + auto const& descale_s = attributes.inputs.at(input_names::Descale_S); + auto const& descale_v = attributes.inputs.at(input_names::Descale_V); + auto const& scale_o = attributes.inputs.at(input_names::Scale_O); + auto const& amax_o = attributes.outputs.at(output_names::Amax_O); + + auto bmm2_attributes = + Matmul_fp8_attributes().set_name("bmm2").set_m_override(seq_len_q).set_k_override(seq_len_kv); + // Special non-functional-style call. Needed because output already created and provided to user. 
+ matmul_fp8(last_output, v_cache, descale_s, descale_v, scale_o, bmm2_attributes, O, amax_o); + } else { + RETURN_CUDNN_FRONTEND_ERROR_IF(true, error_code_t::GRAPH_NOT_SUPPORTED, "Unsupported MMA core mode"); + } + + return {error_code_t::OK, ""}; + } +}; + +class CompositeSDPABackwardNode : public NodeCRTP { + using input_names = SDPA_backward_attributes::input_names; + using output_names = SDPA_backward_attributes::output_names; + + private: + // non-virtual node gpu tensors + std::shared_ptr dQ_accum; + int64_t dQ_accum_size = 0; + std::shared_ptr dK_fullhead; + int64_t dK_fullhead_size = 0; + std::shared_ptr dV_fullhead; + int64_t dV_fullhead_size = 0; + std::shared_ptr softmax_sum; + int64_t softmax_sum_size = 0; + std::shared_ptr alibi_slopes; + int64_t alibi_slopes_size = 0; + + mutable bool has_workaround_padding_mask = false; // Will be edited in pre_validate_node() + mutable int32_t s_q_for_workaround_padding_mask = 0; // Will be edited in pre_validate_node() + mutable int32_t s_kv_for_workaround_padding_mask = 0; // Will be edited in pre_validate_node() + mutable std::shared_ptr + workaround_padding_mask_seq_len_q; // Will be edited in pre_validate_node() + mutable std::shared_ptr + workaround_padding_mask_seq_len_kv; // Will be edited in pre_validate_node() + mutable int64_t batch_size_for_workaround_padding_mask = 0; // Will be edited in pre_validate_node() + mutable bool is_deterministic_algorithm_supported_on_blackwell = false; // Will be edited in pre_validate_node() + + public: + mutable SDPA_backward_attributes attributes; // Will be edited in pre_validate_node() for workaround padding mask + + CompositeSDPABackwardNode(SDPA_backward_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::COMPOSITE; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating CompositeSDPABackwardNode" << 
attributes.name); + + // check that Q, K, V, O, stats, dO, dQ, dK, dV tensors has been assigned + // check that dim and strides has been assigned and last stride is 1 +#define CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(port, port_map) \ + { \ + std::shared_ptr tensor_ptr = port_map.at(port); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_dim().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The dim for " + std::string(#port) + " is invalid"); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_stride().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The stride for " + std::string(#port) + " is invalid"); \ + RETURN_CUDNN_FRONTEND_ERROR_IF( \ + tensor_ptr->get_stride()[3] != 1, \ + error_code_t::GRAPH_NOT_SUPPORTED, \ + "The stride for the last dimension corresponding to the embedding size per head should be 1 for " + \ + std::string(#port)); \ + } + + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::Q, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::K, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::V, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::O, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::Stats, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::dO, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::dQ, attributes.outputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::dK, attributes.outputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::dV, attributes.outputs); + +#undef CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE + + // validate backend limitations for the operation + // clang-format off + int64_t s_q = attributes.inputs.at(input_names::Q)->get_dim()[2]; + int64_t s_kv = attributes.inputs.at(input_names::V)->get_dim()[2]; + int64_t h_q = attributes.inputs.at(input_names::Q)->get_dim()[1]; + int64_t h_k = attributes.inputs.at(input_names::K)->get_dim()[1]; + int64_t h_v = attributes.inputs.at(input_names::V)->get_dim()[1]; + 
int64_t d_qk = attributes.inputs.at(input_names::Q)->get_dim()[3]; + int64_t d_v = attributes.inputs.at(input_names::V)->get_dim()[3]; + + bool const is_ragged = attributes.inputs.at(input_names::Q)->get_ragged_offset() || + attributes.inputs.at(input_names::K)->get_ragged_offset() || + attributes.inputs.at(input_names::V)->get_ragged_offset() || + attributes.inputs.at(input_names::O)->get_ragged_offset(); + + auto const& bias_mask = attributes.inputs.find(input_names::Bias); + bool const is_bias = (bias_mask != attributes.inputs.end() && bias_mask->second != nullptr); + auto const& dbias_mask = attributes.outputs.find(output_names::dBias); + bool const is_dbias = (dbias_mask != attributes.outputs.end() && dbias_mask->second != nullptr); + + auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); + bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); + bool const is_dropout = attributes.dropout_probability.has_value() || is_dropout_custom; + + auto const& rng_tensor = attributes.outputs.find(output_names::RNG_DUMP); + bool const is_rng = (rng_tensor != attributes.outputs.end() && rng_tensor->second != nullptr); + + // validation TODO: + // - validate stats has valid dims + // - validate Q and dQ have the same dims + + // Stop s_q = S_kv = 1 from running + RETURN_CUDNN_FRONTEND_ERROR_IF(s_q == 1 && s_kv == 1, + error_code_t::GRAPH_NOT_SUPPORTED, + "s_q = s_kv = 1 is not supported."); + + cudaDeviceProp prop; + int device; + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device(&device)); + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device_properties(&prop, device)); + + if (prop.major == 9) { + // validate basic dimension requirements + + if ((detail::get_backend_version() >= 91100) && (detail::get_backend_version() < 91300)) { + + if ((128 < d_qk) && (d_qk <= 192) && (64 < d_v) && (d_v <= 128)) { + + // DeepSeek case, 9.11 only supports 192 hidden dim + RETURN_CUDNN_FRONTEND_ERROR_IF( (d_v != 
128) && (d_qk != 192), + error_code_t::GRAPH_NOT_SUPPORTED, + "Num hidden_dim d_v should be equal to 128 if d_qk is 192"); + } + } + + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 256) || (d_qk % 8 != 0) || (d_v > 256) || (d_v % 8 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "Num hidden_dim should be less than or equal to 256 and hidden_dim should be multiple of 8"); + + } else if (prop.major == 10 && detail::get_backend_version() >= 91100) { + // validate basic dimension requirements + if (d_qk == 192) { // special case for 192 hidden dim + RETURN_CUDNN_FRONTEND_ERROR_IF( (d_v != 128), + error_code_t::GRAPH_NOT_SUPPORTED, + "Num hidden_dim d_v should be equal to 128 if d_qk is 192"); + } else { + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 128) || (d_qk % 8 != 0) || (d_v > 128) || (d_v % 8 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "Num hidden_dim should be less than or equal to 128 and hidden_dim should be multiple of 8 when d_qk != d_v"); + } + } else { + // validate basic dimension requirements + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 128) || (d_qk % 8 != 0) || (d_v > 128) || (d_v % 8 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "Num hidden_dim should be less than or equal to 128 and hidden_dim should be multiple of 8"); + } + + RETURN_CUDNN_FRONTEND_ERROR_IF((attributes.attention_score_modifier != nullptr) && + (attributes.alibi_mask || attributes.padding_mask || attributes.has_causal_like_masking() || + attributes.left_bound.has_value()), error_code_t::GRAPH_NOT_SUPPORTED,"Attention score mod enabled and hence other subgraphs are disabled."); + + RETURN_CUDNN_FRONTEND_ERROR_IF((h_q % h_k != 0) || (h_q % h_v != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "For group-query attention, number of heads for key and query must be a factor of number of heads for query"); + + // validate options for attn_scale + auto const& attn_scale = attributes.inputs.find(input_names::Attn_scale); + bool const has_attn_scale = (attn_scale != attributes.inputs.end()) && (attn_scale->second 
!= nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF(has_attn_scale && attributes.attn_scale_value.has_value(), + error_code_t::ATTRIBUTE_NOT_SET, + "attn_scale with tensor and value cannot be set at the same time."); + + // validate alibi requirements + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.alibi_mask && !(attributes.right_bound.has_value() && attributes.right_bound.value() == 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "When alibi mask is used, diagonal_band_right_bound needs to be set to 0."); + + // validate options for bias mask + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && (bias_mask->second->get_data_type() == DataType_t::BOOLEAN), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bias mask data type cannot be boolean"); + + if (s_kv % 128 != 0 && attributes.padding_mask == false && is_ragged == false && detail::get_backend_version() <= 91500) { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Workaround padding mask is enabled for s_q % 128 != 0 and use_padding_mask == false and is_ragged == false"); + has_workaround_padding_mask = true; + batch_size_for_workaround_padding_mask = attributes.inputs.at(input_names::Q)->get_dim()[0]; + s_q_for_workaround_padding_mask = s_q; + s_kv_for_workaround_padding_mask = s_kv; + workaround_padding_mask_seq_len_q = std::make_shared(); + workaround_padding_mask_seq_len_q->set_name("workaround_padding_mask_seq_len_q").set_dim({batch_size_for_workaround_padding_mask,1,1,1}).set_stride({1,1,1,1}).set_data_type(DataType_t::INT32); + workaround_padding_mask_seq_len_kv = std::make_shared(); + workaround_padding_mask_seq_len_kv->set_name("workaround_padding_mask_seq_len_kv").set_dim({batch_size_for_workaround_padding_mask,1,1,1}).set_stride({1,1,1,1}).set_data_type(DataType_t::INT32); + attributes.set_padding_mask(true); + attributes.set_seq_len_q(workaround_padding_mask_seq_len_q).set_seq_len_kv(workaround_padding_mask_seq_len_kv); + } + + // validate options for padding mask + auto const& seq_len_q = attributes.inputs.find(input_names::SEQ_LEN_Q); + bool 
const has_seq_len_q = (seq_len_q != attributes.inputs.end()) && (seq_len_q->second != nullptr); + auto const& seq_len_kv = attributes.inputs.find(input_names::SEQ_LEN_KV); + bool const has_seq_len_kv = (seq_len_kv != attributes.inputs.end()) && (seq_len_kv->second != nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.padding_mask && (!has_seq_len_q || !has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "Padding mask requires seq_len_q and seq_len_kv to be set."); + RETURN_CUDNN_FRONTEND_ERROR_IF((!attributes.padding_mask && !attributes.attention_score_modifier) && (has_seq_len_q || has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); + + // validate options for max_total_seq_len + RETURN_CUDNN_FRONTEND_ERROR_IF((attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value()) && !is_ragged, + error_code_t::GRAPH_NOT_SUPPORTED, + "max_total_seq_len_q is only supported with packed layout"); + + // validate options for bottom right causal mask + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.has_causal_mask_bottom_right() && (!attributes.padding_mask) && s_q > s_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask does not support max_s_q > max_s_kv. Please virtually slice the Q tensor and pass it as max_s_q == max_s_kv"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.has_causal_mask_bottom_right() && (is_bias || attributes.alibi_mask || (is_ragged && !attributes.padding_mask) || is_dropout), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with is_bias=False, is_alibi=False, is_dropout=False. 
Further is_ragged==True is only allowed when padding_mask=True."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.has_causal_mask_bottom_right() && (detail::get_backend_version() < 90600) && ((s_q % 64 != 0) || (s_kv % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with s_q multiple of 64, and s_kv multiple of 64, for cudnn version below 9.6.0"); + + // validate options for sliding window length + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.left_bound.has_value() && attributes.left_bound.value() <= 0, + error_code_t::INVALID_VALUE, + "Left bound (Sliding window length) should be greater than or equals to zero when set."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.left_bound.has_value() && (s_q * attributes.left_bound.value() == s_kv * attributes.left_bound.value()) && (detail::get_backend_version() <= 90900) && (prop.major == 9) && attributes.has_causal_mask_bottom_right(), + error_code_t::GRAPH_NOT_SUPPORTED, + "On Hopper architecture, this specific combination of s_q, s_kv, and left_bound + right_bound + bottom right diagonal alignment is not supported for backend version 9.9 or below"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.left_bound.has_value() && (!attributes.padding_mask) && s_q > s_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Sliding window attention is only supported with max_s_q <= max_s_kv."); + + if ((detail::get_backend_version() >= 91002)) { + RETURN_CUDNN_FRONTEND_ERROR_IF((attributes.left_bound.has_value() || attributes.right_bound.has_value()) && ((is_ragged && !attributes.padding_mask)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Left and right bounds with is_ragged==True is only allowed when padding_mask=True. And the diagonal alignment must be set."); + } else { + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.left_bound.has_value() && (! 
attributes.has_causal_like_masking() || is_dropout || is_bias || (is_ragged && !attributes.padding_mask)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Left and right bounds are only supported with is_dropout=False, is_bias=False. Further is_ragged==True is only allowed when padding_mask=True. Lastly the diagonal alignment must be set."); + } + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.right_bound.has_value() && attributes.right_bound.value() < 0, + error_code_t::INVALID_VALUE, + "Right bound needs to be larger than or equal to zero"); + + // validate options for dropout mask + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && is_dropout_custom, + error_code_t::ATTRIBUTE_NOT_SET, + "Using both, custom dropout mask and internal-mask generation using dropout probability, is ill-formed."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && attributes.dropout_probability.value() == 1.0, + error_code_t::ATTRIBUTE_NOT_SET, + "Dropout probability cannot be 1 as corresponding scale wont be well formed."); + + // validate options for deterministic algorithm + if(attributes.is_deterministic_algorithm && (prop.major == 10)) { + RETURN_CUDNN_FRONTEND_ERROR_IF( (detail::get_backend_version() < 91800), + error_code_t::GRAPH_NOT_SUPPORTED, + "Deterministic algorithm is not supported on blackwell architecture with cudnn version below 9.18.0"); + + // dbias bias rng/dropout alibi + RETURN_CUDNN_FRONTEND_ERROR_IF(is_dbias || is_rng || is_dropout || attributes.alibi_mask, + error_code_t::GRAPH_NOT_SUPPORTED, + "Deterministic algorithm is not supported on blackwell architecture when dbias, rng/dropout, alibi is enabled"); + + is_deterministic_algorithm_supported_on_blackwell = true; + } + + if(detail::get_backend_version() >= 91801) { + RETURN_CUDNN_FRONTEND_ERROR_IF(is_ragged && (8 == prop.major || 12 == prop.major) && attributes.is_deterministic_algorithm, + error_code_t::GRAPH_NOT_SUPPORTED, + "Deterministic algorithm is not 
supported for bprop thd on SM8X and SM12X GPUs"); + } + + // version specific validation + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8906 && ((s_kv % 64 != 0) || (d_qk % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 8.9.6, s_kv not a multiple of 64 or d not a multiple of 64 is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8907 && (s_kv % 64 != 0) && (!(attributes.padding_mask)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 8.9.7, s_kv not a multiple of 64 is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90000 && ((s_q % 64 != 0) || (s_kv % 64 != 0)) && (attributes.padding_mask || is_dropout), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.0.0, s_q/s_kv not a multiple of 64 with padding/dropout mask is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90000 && (s_q < 64), + error_code_t::GRAPH_NOT_SUPPORTED, + " Sequence length must be greater than or equal to 64 for cudnn version prior to v9.0.0"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90200 && attributes.left_bound.has_value(), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.2.0, sliding window attention is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_dbias && attributes.padding_mask, + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.5.0, dBias with variable sequence lengths is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_dbias && ((s_q % 64 != 0) || (s_kv % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.5.0, dBias not support s_q/s_kv which aren't multiple of 64"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90600 && is_ragged && ((h_q != h_k) || (h_q != h_v)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For 
cuDNN version below 9.6.0, group-query attention with raggged offset is not supported"); + + // TODO add version check once fixed + RETURN_CUDNN_FRONTEND_ERROR_IF(prop.major == 10 && is_rng, + error_code_t::GRAPH_NOT_SUPPORTED, + "Dropout RNG dump is not supported for SM Major version 10"); + + // TODO add version check once fixed + RETURN_CUDNN_FRONTEND_ERROR_IF(prop.major == 10 && is_ragged && is_dbias, + error_code_t::GRAPH_NOT_SUPPORTED, + "dbias with ragged is not supported for SM Major version 10"); + + // validate that datatype is set for the graph + RETURN_CUDNN_FRONTEND_ERROR_IF(this->context.get_intermediate_data_type() == DataType_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Intermediate tensor data type needs to be set as internal tensors require it."); + // If dsink is set, sink also needs to be set + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.outputs.find(output_names::DSINK_TOKEN) != attributes.outputs.end() && attributes.inputs.find(input_names::SINK_TOKEN) == attributes.inputs.end(), + error_code_t::ATTRIBUTE_NOT_SET, + "If dsink is set, sink also needs to be set."); + // clang-format on + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + // clang-format off + if (detail::get_backend_version() < 90600 && (attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value())) { + CUDNN_FE_LOG_LABEL_ENDL("WARNING: sdpa_backward.attributes.max_total_seq_len has been set, but cuDNN version is below 9.6.0 does not support max_total_seq_len_q. 
The workspace memory size required to execute this graph may be unexpectedly large"); + attributes.max_total_seq_len_q.reset(); + attributes.max_total_seq_len_kv.reset(); + } + + // TODO add version check once fixed + int64_t d_qk = attributes.inputs.at(input_names::Q)->get_dim()[3]; + int64_t d_v = attributes.inputs.at(input_names::V)->get_dim()[3]; + if ((attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value()) && (d_qk % 16 != 0 || d_v % 16 != 0)) { + CUDNN_FE_LOG_LABEL_ENDL("WARNING: sdpa_backward.attributes.max_total_seq_len has been set, but d is not a multiple of 16 has a known functional issue. The workspace memory size required to execute this graph may be unexpectedly large"); + attributes.max_total_seq_len_q.reset(); + attributes.max_total_seq_len_kv.reset(); + } + + + if(detail::get_backend_version() >= 91801) { + cudaDeviceProp prop; + int device; + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device(&device)); + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device_properties(&prop, device)); + if((8 == prop.major || 12 == prop.major) && (attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value())) { + attributes.max_total_seq_len_q.reset(); + attributes.max_total_seq_len_kv.reset(); + } + } + // clang-format on + + return {error_code_t::OK, ""}; + } + + error_t + expand_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for CompositeSDPABackwardNode " << attributes.name); + + attributes.fill_from_context(context); + + // Gather dim to fill properties of virtual tensors + auto const& q_dim = attributes.inputs[input_names::Q]->get_dim(); + auto b = q_dim[0]; + auto h_q = q_dim[1]; + auto s_q = q_dim[2]; + auto d_qk = q_dim[3]; + auto const& k_dim = attributes.inputs[input_names::K]->get_dim(); + auto h_k = k_dim[1]; + auto s_kv = k_dim[2]; + auto const& v_dim = attributes.inputs[input_names::V]->get_dim(); + auto h_v = v_dim[1]; + auto d_v = v_dim[3]; + + // cuDNN 
frontend API attention requires Q, K, V where + // Q = {b, h_q, s_q, d_qk} + // K = {b, h_k, s_kv, d_qk} + // V = {b, h_v, s_kv, d_v} + // but cuDNN backend API attention requires Q, KT, VT + // Q = {b, h_q, s_q, d_qk} + // KT = {b, h_k, d_qk, s_kv} + // VT = {b, h_v, d_v, s_kv} + // So the code below maps the K->KT and V->VT + std::vector temp_vec; + + temp_vec = attributes.inputs[input_names::K]->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_dim(temp_vec); + + temp_vec = attributes.inputs[input_names::K]->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_stride(temp_vec); + + temp_vec = attributes.inputs[input_names::V]->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::V]->set_dim(temp_vec); + + temp_vec = attributes.inputs[input_names::V]->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::V]->set_stride(temp_vec); + + std::shared_ptr last_output, exp_s_output, dS_output, rng_output; + + // --------------Initialize and create tensors before creating nodes-------------------- + // one_tensor is needed for non-dropout graphs + // one_tensor is passed by the node + auto one_tensor = std::make_shared(1.0f); + + if (attributes.attn_scale_value.has_value()) { + attributes.inputs[input_names::Attn_scale] = + std::make_shared(attributes.attn_scale_value.value()); + } + + // if dropout_mask is used, then the user passes scale and scale_inverse + bool is_dropout_prob = (attributes.dropout_probability.has_value()); + bool is_dropout_mask = (attributes.inputs[input_names::Dropout_mask] != nullptr); + if (is_dropout_prob) { + float dropout_scale_value = 1.0f / (1.0f - attributes.dropout_probability.value()); + float dropout_scale_inv_value = (1.0f - attributes.dropout_probability.value()); + + attributes.inputs[input_names::Dropout_scale] = std::make_shared(dropout_scale_value); + 
attributes.inputs[input_names::Dropout_scale_inv] = + std::make_shared(dropout_scale_inv_value); + } + + // ---------------------input tensor workarounds--------------------------- + + bool use_dp_workspace = false; + + cudaDeviceProp prop; + if (context.get_sm_version() > 0) { + prop.major = context.get_sm_version() / 10; + } else { + int device; + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device(&device)); + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device_properties(&prop, device)); + } + + if (detail::get_backend_version() >= 8905 && detail::get_backend_version() < 90000) { + // workspace optimization is enabled by default when: + // 8.9.5 <= cudnn version < 9.0.0 + // device >= hopper + // batch * num_heads * seq_len_q * seq_len_kv * 2 <= dP workspace limit + // + // This following environment variable allows you to control the dP workspace limit. + // From cuDNN version 9.0.0, this option is obsolete will be ignored. + // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=unset - enable workspace opt. until the default 256MB limit. + // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=-1 - always enable workspace opt. + // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=0 - always disable workspace opt. + // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=n - enable workspace opt. 
until the n byte limit + + // hopper or above + if (prop.major >= 9) { + // default upper limit for workspace 256MB + int64_t max_dp_workspace_bytes = 256 * 1024 * 1024; + + // allow setting the upper limit with envvars + char* env_dp_workspace_limit_char = std::getenv("CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"); + if (env_dp_workspace_limit_char) { + char* end_ptr = nullptr; + max_dp_workspace_bytes = std::strtoll(env_dp_workspace_limit_char, &end_ptr, 10); + + if (*end_ptr != '\0') { + RETURN_CUDNN_FRONTEND_ERROR_IF(true, + error_code_t::ATTRIBUTE_NOT_SET, + "Invalid argument for CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT " + "(int64_t; in bytes)"); + } + } + + int64_t workspace_s_q = ((s_q + 64 - 1) / 64) * 64; + int64_t workspace_s_kv = ((s_kv + 64 - 1) / 64) * 64; + int64_t required_dp_workspace_bytes = b * h_q * workspace_s_q * workspace_s_kv * 2; + + if (max_dp_workspace_bytes == -1) { + use_dp_workspace = true; + } else if (max_dp_workspace_bytes == 0) { + use_dp_workspace = false; + } else { + use_dp_workspace = (required_dp_workspace_bytes <= max_dp_workspace_bytes); + } + } + } + + // Force dP workspace implementation if: + // - dBias is enabled (dBias is only supported on workspace implementation) + // - the user force requests deterministic algorithm on hopper + if (attributes.outputs[output_names::dBias] || attributes.is_deterministic_algorithm) { + use_dp_workspace = true; + } + + // --------------RNG node-------------------- + + if (is_dropout_prob) { + if (attributes.outputs[output_names::RNG_DUMP] != nullptr) { + rng_output = attributes.outputs[output_names::RNG_DUMP]; + rng(attributes.inputs[input_names::Seed], + attributes.inputs[input_names::Offset], + Rng_attributes() + .set_name("rng") + .set_distribution(RngDistribution_t::BERNOULLI) + .set_bernoulli_probability(1.0f - attributes.dropout_probability.value()), + rng_output); + } else { + rng_output = rng(attributes.inputs[input_names::Seed], + attributes.inputs[input_names::Offset], + 
Rng_attributes() + .set_name("rng") + .set_distribution(RngDistribution_t::BERNOULLI) + .set_bernoulli_probability(1.0f - attributes.dropout_probability.value())); + rng_output->set_dim({b, h_q, s_q, s_kv}).set_stride({h_q * s_q * s_kv, s_q * s_kv, s_kv, 1}); + } + } else if (is_dropout_mask) { + rng_output = attributes.inputs[input_names::Dropout_mask]; + } + + // --------------"dO * o => softmax_sum" chain-------------------- + + // last_output = dO * O + last_output = pointwise(attributes.inputs[input_names::dO], + attributes.inputs[input_names::O], + Pointwise_attributes().set_name("mul_dO_O").set_mode(PointwiseMode_t::MUL)); + last_output->set_dim({b, h_q, s_q, d_v}).set_stride({h_q * s_q * d_v, s_q * d_v, h_q * d_v, 1}); + + // last_output = reduce(last_output, "b hq sq dv -> b hq sq 1") + last_output = + reduction(last_output, Reduction_attributes().set_name("reduce_dO_o").set_mode(ReductionMode_t::ADD)); + last_output->set_dim({b, h_q, s_q, 1}).set_stride({h_q * s_q, s_q, 1, 1}); + + if (attributes.outputs.find(output_names::DSINK_TOKEN) != attributes.outputs.end()) { + // sub_sink = sink - stats + auto sub_sink = pointwise(attributes.inputs[input_names::SINK_TOKEN], + attributes.inputs[input_names::Stats], + Pointwise_attributes().set_name("sub_sink").set_mode(PointwiseMode_t::SUB)); + + // exp_sink = exp(sub_sink) + auto exp_sink = + pointwise(sub_sink, Pointwise_attributes().set_name("exp_sink").set_mode(PointwiseMode_t::EXP)); + + // per_token_grad = exp_sink * last_output + auto per_token_grad = + pointwise(exp_sink, + last_output, + Pointwise_attributes().set_name("mul_exp_sink_last_output").set_mode(PointwiseMode_t::MUL)); + + // dSink = redduce(per_token_grad) + reduction(per_token_grad, + Reduction_attributes().set_name("reduce_per_token_grad").set_mode(ReductionMode_t::ADD), + attributes.outputs[output_names::DSINK_TOKEN]); + } + + // softmax_sum = last_output * dropout_scale + last_output = pointwise(last_output, + 
attributes.inputs[input_names::Dropout_scale_inv] + ? attributes.inputs[input_names::Dropout_scale_inv] + : one_tensor, + Pointwise_attributes().set_name("scale_dropout_inv").set_mode(PointwiseMode_t::MUL)); + last_output->set_dim({b, h_q, s_q, 1}).set_stride({h_q * s_q, s_q, 1, 1}); + + softmax_sum = last_output; + softmax_sum->set_is_virtual(false); + softmax_sum->set_dim({b, h_q, s_q, 1}); + softmax_sum->set_data_type(DataType_t::FLOAT); + + if (attributes.inputs[input_names::Stats]->get_ragged_offset() && attributes.max_total_seq_len_q.has_value()) { + // sized TH1 softmax_sum + softmax_sum->set_stride(attributes.inputs[input_names::Stats]->get_stride()); + softmax_sum->set_ragged_offset(attributes.inputs[input_names::Stats]->get_ragged_offset()); + softmax_sum_size = attributes.max_total_seq_len_q.value() * + (attributes.inputs[input_names::Stats]->get_stride())[2] * sizeof(float); + } else { + // sized BHS1 softmax_sum + softmax_sum->set_stride({h_q * s_q, s_q, 1, 1}); + softmax_sum_size = b * h_q * s_q * 1 * sizeof(float); + } + + // --------------"Q @ KT => exp_softmax => dV" chain-------------------- + + // s = einsum(q, k, "b hq sq dqk, b (hk g) skv dqk -> b hq sq skv", g=hq//hk) + last_output = matmul(attributes.inputs[input_names::Q], + attributes.inputs[input_names::K], + Matmul_attributes() + .set_name("matmul_Q_KT") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_n_override(attributes.inputs[input_names::SEQ_LEN_KV])); + last_output->set_dim({b, h_q, s_q, s_kv}).set_stride({h_q * s_q * s_kv, s_q * s_kv, s_kv, 1}); + + // last_output = last_output * attention_scale + if (attributes.inputs[input_names::Attn_scale]) { + last_output = pointwise(last_output, + attributes.inputs[input_names::Attn_scale], + Pointwise_attributes().set_name("mul_s_attn_scale").set_mode(PointwiseMode_t::MUL)); + } + + if (attributes.attention_score_modifier != nullptr) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = 
std::static_pointer_cast(graph_); + node_->context = context; + last_output = attributes.attention_score_modifier(graph_, last_output); + sub_nodes.emplace_back(node_); + } + + // (optional) last_output = last_output + bias + if (attributes.inputs.find(input_names::Bias) != attributes.inputs.end() && + attributes.inputs[input_names::Bias]) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attn::score_modifiers::bias(graph_, last_output, attributes.inputs[input_names::Bias]); + sub_nodes.emplace_back(node_); + } + + // (optional) last_output = last_output + alibi_mask + if (attributes.alibi_mask) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attn::score_modifiers::alibi_mask(graph_, last_output, alibi_slopes, h_q, alibi_slopes_size); + sub_nodes.emplace_back(node_); + } + + // (optional) Apply padding mask + if (attributes.padding_mask) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attn::score_modifiers::padding_mask(graph_, + last_output, + attributes.inputs[input_names::SEQ_LEN_KV], + attributes.inputs[input_names::SEQ_LEN_Q]); + sub_nodes.emplace_back(node_); + } + + // last_output = last_output - stats + last_output = pointwise(last_output, + attributes.inputs[input_names::Stats], + Pointwise_attributes().set_name("sub_s_m").set_mode(PointwiseMode_t::SUB)); + + // WAR for bug 4475073 by explicitly putting the padding value again after the stats have been loaded + if (attributes.padding_mask && detail::get_backend_version() >= 90000 && + detail::get_backend_version() < 91000) { + auto row_idx_output = pointwise(last_output, + Pointwise_attributes() + .set_name("gen_row_idx_2nd_padding") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(2) + 
.set_compute_data_type(DataType_t::INT32)); + row_idx_output->set_data_type(DataType_t::INT32); + + auto col_idx_output = pointwise(last_output, + Pointwise_attributes() + .set_name("gen_col_idx_2nd_padding") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32)); + col_idx_output->set_data_type(DataType_t::INT32); + + auto row_mask_output = pointwise(row_idx_output, + attributes.inputs[input_names::SEQ_LEN_Q], + Pointwise_attributes() + .set_name("lt_row_sq_2nd_padding") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::BOOLEAN)); + row_mask_output->set_data_type(DataType_t::BOOLEAN); + + auto col_mask_output = pointwise(col_idx_output, + attributes.inputs[input_names::SEQ_LEN_KV], + Pointwise_attributes() + .set_name("lt_col_skv_2nd_padding") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::BOOLEAN)); + col_mask_output->set_data_type(DataType_t::BOOLEAN); + + auto padding_mask_output = pointwise(row_mask_output, + col_mask_output, + Pointwise_attributes() + .set_name("and_row_col_2nd_padding") + .set_mode(PointwiseMode_t::LOGICAL_AND) + .set_compute_data_type(DataType_t::BOOLEAN)); + padding_mask_output->set_data_type(DataType_t::BOOLEAN); + auto negative_inf_padding = + std::make_shared(attn::score_modifiers::get_negative_inf_value()); + + last_output = pointwise( + last_output, + negative_inf_padding, + padding_mask_output, + Pointwise_attributes().set_name("select_2nd_padding").set_mode(PointwiseMode_t::BINARY_SELECT)); + } + + // Apply (bottom-right) causal masking (with right bound) and/or set the left bound + if (attributes.left_bound.has_value() || attributes.right_bound.has_value()) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + + auto s_kv_ptr = attributes.inputs.find(input_names::SEQ_LEN_KV) != attributes.inputs.end() + ? 
attributes.inputs[input_names::SEQ_LEN_KV] + : nullptr; + auto s_q_ptr = attributes.inputs.find(input_names::SEQ_LEN_Q) != attributes.inputs.end() + ? attributes.inputs[input_names::SEQ_LEN_Q] + : nullptr; + + last_output = attn::score_modifiers::sliding_window_mask(graph_, + last_output, + attributes.diagonal_alignment, + attributes.left_bound, + attributes.right_bound, + s_q, + s_kv, + s_q_ptr, + s_kv_ptr); + sub_nodes.emplace_back(std::move(node_)); + } + + // last_output = exp(last_output) + last_output = pointwise(last_output, Pointwise_attributes().set_name("exp_s").set_mode(PointwiseMode_t::EXP)); + + exp_s_output = last_output; + + // (optional) last_output = last_output * dropout rng_output + if (is_dropout_prob || is_dropout_mask) { + last_output = + pointwise(last_output, + rng_output, + Pointwise_attributes().set_name("mul_p_dropout_mask").set_mode(PointwiseMode_t::MUL)); + } + + // (optional) last_output = last_output * dropout_scale + if (attributes.inputs[input_names::Dropout_scale]) { + last_output = + pointwise(last_output, + attributes.inputs[input_names::Dropout_scale], + Pointwise_attributes().set_name("mul_p_dropout_scale").set_mode(PointwiseMode_t::MUL)); + } + + // dV = einsum(p, dO, "b hq sq skv", "b hq sq dv -> b hq skv dv") + // if GQA, then dV = reduce(dV, "b (hv g) skv dv -> b hv skv dv", g=hq//hv) + // as reshape + matmul + last_output = reshape(last_output, Reshape_attributes().set_name("reshape_p")); + last_output->set_dim({b, h_q, s_kv, s_q}).set_stride({h_q * s_q * s_kv, s_q * s_kv, 1, s_kv}); + last_output->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + + if (h_q == h_v) { + // for MHA + matmul(last_output, + attributes.inputs[input_names::dO], + Matmul_attributes() + .set_name("matmul_pT_dO") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q]), + attributes.outputs[output_names::dV]); + } else { + // for GQA and MQA + dV_fullhead = 
matmul(last_output, + attributes.inputs[input_names::dO], + Matmul_attributes() + .set_name("matmul_pT_dO") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q])); + + dV_fullhead->set_dim({b, h_q, s_kv, d_v}); + dV_fullhead->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + + if (attributes.outputs[output_names::dV]->get_ragged_offset() && + attributes.max_total_seq_len_kv.has_value()) { + // hack 1 - map dV strides to dV_fullhead strides + std::vector dV_fullhead_stride = attributes.outputs[output_names::dV]->get_stride(); + dV_fullhead_stride[2] = dV_fullhead_stride[2] * (h_q / h_v); // sequence stride + dV_fullhead_stride[0] = dV_fullhead_stride[0] * (h_q / h_v); // batch stride + dV_fullhead->set_stride(dV_fullhead_stride); + // hack 2 - map dV ragged offset to dV_fullhead ragged offset with implicit multiplier + // implicit multiplier = h_q / h_v + dV_fullhead->set_ragged_offset(attributes.outputs[output_names::dV]->get_ragged_offset()); + // hack 3 - non virtual dV full head + dV_fullhead->set_is_virtual(false); + dV_fullhead_size = attributes.max_total_seq_len_kv.value() * dV_fullhead_stride[2] * sizeof(float); + } else { + // sized BHSD dQ_accum + dV_fullhead->set_stride({h_q * s_kv * d_v, s_kv * d_v, d_v, 1}); + } + + reduction(dV_fullhead, + Reduction_attributes().set_name("red_dV_head").set_mode(ReductionMode_t::ADD), + attributes.outputs[output_names::dV]); + } + + // --------------"dO @ VT => dS_output => dK" chain-------------------- + + // dP = einsum(dO, v, "b hq sq dv, b (hv g) skv dv -> b hq sq skv", g=hq//hv) + last_output = matmul(attributes.inputs[input_names::dO], + attributes.inputs[input_names::V], + Matmul_attributes() + .set_name("matmul_dO_VT") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_n_override(attributes.inputs[input_names::SEQ_LEN_KV])); + last_output->set_dim({b, h_q, s_q, s_kv}).set_stride({h_q * s_q * s_kv, s_q * s_kv, 
s_kv, 1}); + + // last_output = last_output(dP) * mask + if (is_dropout_prob || is_dropout_mask) { + last_output = pointwise(last_output, + rng_output, + Pointwise_attributes().set_name("dP_dropout_mask").set_mode(PointwiseMode_t::MUL)); + } + + // last_output = last_output - softmax_sum + last_output = pointwise(last_output, + softmax_sum, + Pointwise_attributes().set_name("sub_dP_softmax_sum").set_mode(PointwiseMode_t::SUB)); + + // last_output = last_output * exp_s_output + last_output = pointwise( + last_output, exp_s_output, Pointwise_attributes().set_name("mul_dP_exp_s").set_mode(PointwiseMode_t::MUL)); + + // (optional) last_output = last_output * dropout_scale + if (attributes.inputs[input_names::Dropout_scale]) { + last_output = + pointwise(last_output, + attributes.inputs[input_names::Dropout_scale], + Pointwise_attributes().set_name("mul_dS_dropout_scale").set_mode(PointwiseMode_t::MUL)); + } + + if (attributes.outputs[output_names::dBias]) { + reduction(last_output, + Reduction_attributes().set_name("red_dP_dBias").set_mode(ReductionMode_t::ADD), + attributes.outputs[output_names::dBias]); + } + + // apply the bprop of attention score modifier + if (attributes.attention_score_modifier_bprop != nullptr) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attributes.attention_score_modifier_bprop(graph_, last_output); + sub_nodes.emplace_back(node_); + } + + // (optional) last_output = last_output * bmm_scale + if (attributes.inputs[input_names::Attn_scale]) { + last_output = + pointwise(last_output, + attributes.inputs[input_names::Attn_scale], + Pointwise_attributes().set_name("mul_dS_attn_scale").set_mode(PointwiseMode_t::MUL)); + } + + dS_output = last_output; + + // dK = einsum(dS, Q, "b hq sq skv", "b hq sq dqk -> b hq skv dqk") + // if GQA, then dK = reduce(dK, "b (hk g) skv dqk -> b hk skv dqk", hq//hk) + // as reshape + matmul + last_output = 
reshape(last_output, Reshape_attributes().set_name("reshape_dS")); + last_output->set_dim({b, h_q, s_kv, s_q}).set_stride({h_q * s_q * s_kv, s_q * s_kv, 1, s_kv}); + last_output->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + + if (h_q == h_k) { + // for MHA + matmul(last_output, + attributes.inputs[input_names::Q], + Matmul_attributes() + .set_name("matmul_dST_Q") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q]), + attributes.outputs[output_names::dK]); + } else { + // for GQA and MQA + dK_fullhead = matmul(last_output, + attributes.inputs[input_names::Q], + Matmul_attributes() + .set_name("matmul_dST_Q") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q])); + + dK_fullhead->set_dim({b, h_q, s_kv, d_qk}); + dK_fullhead->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + + if (attributes.outputs[output_names::dK]->get_ragged_offset() && + attributes.max_total_seq_len_kv.has_value()) { + // sized THD dK_full_heads + // hack 1 - map dK strides to dK_fullhead strides + std::vector dK_fullhead_stride = attributes.outputs[output_names::dK]->get_stride(); + dK_fullhead_stride[0] = dK_fullhead_stride[0] * (h_q / h_k); // batch stride + dK_fullhead_stride[2] = dK_fullhead_stride[2] * (h_q / h_k); // sequence stride + dK_fullhead->set_stride(dK_fullhead_stride); + // hack 2 - map dK ragged offset to dK_fullhead ragged offset with implicit multiplier + // implicit multiplier = h_q / h_k + dK_fullhead->set_ragged_offset(attributes.outputs[output_names::dK]->get_ragged_offset()); + // hack 3 - non virtual dK full head + dK_fullhead->set_is_virtual(false); + dK_fullhead_size = attributes.max_total_seq_len_kv.value() * dK_fullhead_stride[2] * sizeof(float); + } else { + // sized BHSD dQ_accum + dK_fullhead->set_stride({h_q * s_kv * d_qk, s_kv * d_qk, d_qk, 1}); + } + + reduction(dK_fullhead, + 
Reduction_attributes().set_name("red_dK_head").set_mode(ReductionMode_t::ADD), + attributes.outputs[output_names::dK]); + } + + // --------------"dp_scaled @ K => dQ" chain-------------------- + + auto const& kt_dim = attributes.inputs[input_names::K]->get_dim(); + auto const& kt_stride = attributes.inputs[input_names::K]->get_stride(); + + // dQ = einsum(dS, K, "b hq sq skv, b (hk g) skv dqk -> b hq sq dqk", g=hq//hk) + // as reshape + matmul + last_output = reshape(attributes.inputs[input_names::K], Reshape_attributes().set_name("reshape_k")); + last_output->set_dim({kt_dim[0], kt_dim[1], kt_dim[3], kt_dim[2]}) + .set_stride({kt_stride[0], kt_stride[1], kt_stride[3], kt_stride[2]}); + + if (attributes.inputs[input_names::K]->get_ragged_offset() != nullptr) { + last_output->set_ragged_offset(attributes.inputs[input_names::K]->get_ragged_offset()); + } + + if (!use_dp_workspace) { + dQ_accum = std::make_shared(); + dQ_accum->set_is_virtual(false); + dQ_accum->set_dim({b, h_q, s_q, d_qk}); + dQ_accum->set_data_type(DataType_t::FLOAT); + + if (attributes.outputs[output_names::dQ]->get_ragged_offset() && + attributes.max_total_seq_len_q.has_value()) { + // sized THD dQ_accum + dQ_accum->set_stride(attributes.outputs[output_names::dQ]->get_stride()); + dQ_accum->set_ragged_offset(attributes.outputs[output_names::dQ]->get_ragged_offset()); + dQ_accum_size = attributes.max_total_seq_len_q.value() * + (attributes.outputs[output_names::dQ]->get_stride())[2] * sizeof(float); + } else { + // sized BHSD dQ_accum + dQ_accum->set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1}); + dQ_accum_size = b * h_q * s_q * d_qk * sizeof(float); + } + + matmul(dS_output, + last_output, + Matmul_attributes() + .set_name("matmul_dS_K") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_KV]), + dQ_accum); + + pointwise(dQ_accum, + Pointwise_attributes().set_name("identity_dQ").set_mode(PointwiseMode_t::IDENTITY), + 
attributes.outputs[output_names::dQ]); + } else { + matmul(dS_output, + last_output, + Matmul_attributes() + .set_name("matmul_dS_K") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_KV]), + attributes.outputs[output_names::dQ]); + } + + return {error_code_t::OK, ""}; + } + + std::pair> + override_heuristics_query() const { + if (is_deterministic_algorithm_supported_on_blackwell) { + return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + } else { + return {-1, {}}; + } + } + + virtual int64_t + get_fe_workspace_size_node() const override final { + int64_t size = 0; + + size += ((alibi_slopes_size + 15) / 16 * 16); // align alibi slopes memory to 16 bytes + size += dQ_accum_size; + size += dK_fullhead_size; + size += dV_fullhead_size; + size += softmax_sum_size; + + if (has_workaround_padding_mask) { + size += batch_size_for_workaround_padding_mask * sizeof(int32_t) * 2; + } + + return size; + } + + virtual error_t + collect_tensors_in_workspace_node( + std::unordered_map>>& + workspace_modifications, + int64_t& offset) const override final { + if (attributes.alibi_mask) { + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(Q, input_names::Q); + int64_t const h_q = Q->second->get_dim()[1]; + auto alibi_slopes_vec = detail::get_alibi_slope(h_q); + workspace_modifications.emplace(alibi_slopes->get_uid(), std::make_tuple(0, offset, alibi_slopes_vec)); + int64_t alibi_slopes_size_padded = ((alibi_slopes_size + 15) / 16 * 16); + offset = offset + alibi_slopes_size_padded; + } + + if (dQ_accum && !dQ_accum->get_is_virtual()) { + if (detail::get_backend_version() < 90600) { + // prior to cuDNN 9.6.0, dQ_accum needed to be memset by frontend + workspace_modifications.emplace(dQ_accum->get_uid(), + std::make_tuple(1, offset, std::vector{(float)dQ_accum_size})); + } else { + workspace_modifications.emplace(dQ_accum->get_uid(), std::make_tuple(2, offset, std::vector())); + } + offset = offset + 
dQ_accum_size; + } + + if (dK_fullhead && !dK_fullhead->get_is_virtual()) { + workspace_modifications.emplace(dK_fullhead->get_uid(), std::make_tuple(2, offset, std::vector())); + offset = offset + dK_fullhead_size; + } + + if (dV_fullhead && !dV_fullhead->get_is_virtual()) { + workspace_modifications.emplace(dV_fullhead->get_uid(), std::make_tuple(2, offset, std::vector())); + offset = offset + dV_fullhead_size; + } + + if (softmax_sum && !softmax_sum->get_is_virtual()) { + workspace_modifications.emplace(softmax_sum->get_uid(), std::make_tuple(2, offset, std::vector())); + offset = offset + softmax_sum_size; + } + + if (has_workaround_padding_mask) { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Collecting workaround padding mask tensors with batch size " + << batch_size_for_workaround_padding_mask << " with UIDs " + << workaround_padding_mask_seq_len_q->get_uid() << " and " + << workaround_padding_mask_seq_len_kv->get_uid()); + std::vector workaround_padding_mask_seq_len_q_vec(batch_size_for_workaround_padding_mask, + s_q_for_workaround_padding_mask); + std::vector workaround_padding_mask_seq_len_kv_vec(batch_size_for_workaround_padding_mask, + s_kv_for_workaround_padding_mask); + + // reinterpret_cast the int32_t vector data to float vector for workspace_modifications + std::vector workaround_padding_mask_seq_len_q_vec_float( + reinterpret_cast(workaround_padding_mask_seq_len_q_vec.data()), + reinterpret_cast(workaround_padding_mask_seq_len_q_vec.data()) + + batch_size_for_workaround_padding_mask); + std::vector workaround_padding_mask_seq_len_kv_vec_float( + reinterpret_cast(workaround_padding_mask_seq_len_kv_vec.data()), + reinterpret_cast(workaround_padding_mask_seq_len_kv_vec.data()) + + batch_size_for_workaround_padding_mask); + + workspace_modifications.emplace(workaround_padding_mask_seq_len_q->get_uid(), + std::make_tuple(0, offset, workaround_padding_mask_seq_len_q_vec_float)); + offset = offset + batch_size_for_workaround_padding_mask * sizeof(float); + 
workspace_modifications.emplace(workaround_padding_mask_seq_len_kv->get_uid(), + std::make_tuple(0, offset, workaround_padding_mask_seq_len_kv_vec_float)); + offset = offset + batch_size_for_workaround_padding_mask * sizeof(float); + } + + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"({"tag": "SDPA_BWD"})"_json); + } +#endif +}; + +class UnifiedSDPANode : public SDPANodeBase { + public: + UnifiedSDPANode(SDPA_attributes&& attributes_, detail::Context const& context) + : SDPANodeBase(std::move(attributes_), context) {} + + Type + getType() override final { + return Type::UNIFIED_SDPA; + } + + error_t + expand_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for UnifiedSDPANode node " << attributes.name); + + // DO NOT REMOVE + // input data type is needed for: + // - aType of bmm2 + // - dropout scale in pre 8.9.3 + attributes.fill_from_context(this->context); + + //// Optional Attn scale + // In case user provided a scalar value, do a fused scalar. 
+ if (attributes.attn_scale_value.has_value()) { + attributes.inputs[input_names::Attn_scale] = + std::make_shared(attributes.attn_scale_value.value()); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(operations); + CUDNN_FE_LOG_LABEL("INFO: " << "Building UnifiedSDPANode operations " << attributes.name << " "); + auto cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Unified SDPA node requires cuDNN 9.13.1"}; + +#if (CUDNN_VERSION >= 91301) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(91301, cudnn_ver_error); + auto unified_sdpa_operation = + make_shared_backend_pointer((cudnnBackendDescriptorType_t)CUDNN_BACKEND_OPERATION_SDPA_FWD_DESCRIPTOR); + + auto Q = attributes.inputs.find(SDPA_attributes::input_names::Q)->second; + auto backend_q = tensors[Q->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_QDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_q)); + + auto K = attributes.inputs.find(SDPA_attributes::input_names::K)->second; + auto backend_k = tensors[K->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_KDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_k)); + + auto V = attributes.inputs.find(SDPA_attributes::input_names::V)->second; + auto backend_v = tensors[V->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_VDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_v)); + + auto O = 
attributes.outputs.find(SDPA_attributes::output_names::O)->second; + auto backend_o = tensors[O->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_ODESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_o)); + + auto stats_it = attributes.outputs.find(SDPA_attributes::output_names::Stats); + if (stats_it != attributes.outputs.end()) { + auto backend_stats = tensors[stats_it->second->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_STATSDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_stats)); + } + + auto attn_scale_it = attributes.inputs.find(SDPA_attributes::input_names::Attn_scale); + if (attn_scale_it != attributes.inputs.end()) { + auto backend_scale = tensors[attn_scale_it->second->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_SCALEDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_scale)); + } + + auto block_mask_it = attributes.inputs.find(SDPA_attributes::input_names::Block_mask); + if (block_mask_it != attributes.inputs.end() && block_mask_it->second != nullptr) { + auto block_mask_cudnn_ver_error = + error_t{error_code_t::GRAPH_NOT_SUPPORTED, "Block mask in unified SDPA node requires cuDNN 9.14.0"}; +#if CUDNN_VERSION >= 91400 + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(91400, block_mask_cudnn_ver_error); + auto backend_block_mask = tensors[block_mask_it->second->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_BLOCK_MASK_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_block_mask)); +#else + return 
block_mask_cudnn_ver_error; +#endif + } + + // Paged attention attributes + if (is_paged_k() || is_paged_v() || has_seq_len_q() || has_seq_len_kv()) { + auto paged_cudnn_ver_error = error_t{error_code_t::GRAPH_NOT_SUPPORTED, + "Paged attention in unified SDPA node requires cuDNN 9.15.0"}; +#if (CUDNN_VERSION >= 91500) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(91500, paged_cudnn_ver_error); + + if (is_paged_k()) { + auto page_table_K = attributes.inputs.find(SDPA_attributes::input_names::Page_table_K)->second; + auto backend_page_table_K = tensors[page_table_K->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_PAGE_TABLE_KDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_page_table_K)); + } + + if (is_paged_v()) { + auto page_table_V = attributes.inputs.find(SDPA_attributes::input_names::Page_table_V)->second; + auto backend_page_table_V = tensors[page_table_V->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_PAGE_TABLE_VDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_page_table_V)); + } + + if (has_seq_len_q()) { + auto seq_len_Q = attributes.inputs.find(SDPA_attributes::input_names::SEQ_LEN_Q)->second; + auto backend_seq_len_Q = tensors[seq_len_Q->get_uid()]->get_desc()->get_backend_descriptor(); + _CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_SEQ_LEN_QDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_seq_len_Q)); + } + + if (has_seq_len_kv()) { + auto seq_len_KV = attributes.inputs.find(SDPA_attributes::input_names::SEQ_LEN_KV)->second; + auto backend_seq_len_KV = tensors[seq_len_KV->get_uid()]->get_desc()->get_backend_descriptor(); + 
_CUDNN_CHECK_CUDNN_ERROR(detail::set_attribute(unified_sdpa_operation->get_backend_descriptor(), + CUDNN_ATTR_OPERATION_SDPA_FWD_SEQ_LEN_KVDESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &backend_seq_len_KV)); + } + + // Ignore attributes.max_seq_len_kv, because unified engine doesn't need it (it's harmless if set). + + // Ignore attributes.padding_mask, because unified engine already applies an implicit padding mask + // if seq_len_Q and seq_len_KV are both provided. We already checked in + // `SDPA_attributes::validate_sdpa_support_surface()` that padding_mask must be true if and + // only if seq_len_Q and seq_len_KV are both set, so we don't need to check it here. +#else + return paged_cudnn_ver_error; +#endif + } + + _CUDNN_CHECK_CUDNN_ERROR(detail::finalize(unified_sdpa_operation->get_backend_descriptor())); + + raw_operations.push_back(unified_sdpa_operation); + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(uids_involved_in_operations); + CUDNN_FRONTEND_UNUSED(raw_operations); + CUDNN_FRONTEND_UNUSED(tensors); + return cudnn_ver_error; +#endif // CUDNN_VERSION >= 91301 + } +}; + +} // namespace cudnn_frontend::graph diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/sdpa_fp8_bwd.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/sdpa_fp8_bwd.h new file mode 100644 index 00000000..225ecf1f --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/sdpa_fp8_bwd.h @@ -0,0 +1,649 @@ +#pragma once + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +#include "matmul_fp8.h" +#include "pointwise.h" +#include "reduction.h" +#include "softmax.h" + +namespace cudnn_frontend::graph { + +class SDPAFP8BackwardNode : public NodeCRTP { + using input_names = 
SDPA_fp8_backward_attributes::input_names; + using output_names = SDPA_fp8_backward_attributes::output_names; + + private: + mutable bool is_deterministic_algorithm_supported_on_blackwell = false; // Will be edited in pre_validate_node() + + public: + SDPA_fp8_backward_attributes attributes; + + SDPAFP8BackwardNode(SDPA_fp8_backward_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::COMPOSITE; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating SDPAFP8BackwardNode " << attributes.name); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90100, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 backward operation is only supported starting cudnn 9.1.0. Please " + "consider upgrading your current version."); + + cudaDeviceProp prop; + int device; + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device(&device)); + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device_properties(&prop, device)); + RETURN_CUDNN_FRONTEND_ERROR_IF( + prop.major < 9, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported on Hopper architecture and newer. 
Please " + "consider using a newer architecture."); + + // check that Q, K, V, O, stats, dO, dQ, dK, dV tensors has been assigned + // check that dim and strides has been assigned and last stride is 1 +#define CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(port, port_map) \ + { \ + std::shared_ptr tensor_ptr = port_map.at(port); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_dim().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The dim for " + std::string(#port) + " is invalid"); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_stride().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The stride for " + std::string(#port) + " is invalid"); \ + RETURN_CUDNN_FRONTEND_ERROR_IF( \ + tensor_ptr->get_stride()[3] != 1, \ + error_code_t::GRAPH_NOT_SUPPORTED, \ + "The stride for the last dimension corresponding to the embedding size per head should be 1 for " + \ + std::string(#port)); \ + } + + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::Q, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::K, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::V, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::O, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::Stats, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::dO, attributes.inputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::dQ, attributes.outputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::dK, attributes.outputs); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::dV, attributes.outputs); + +#undef CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE + + // validate backend limitations for the operation + // clang-format off + int64_t s_q = attributes.inputs.at(input_names::Q)->get_dim()[2]; + int64_t s_kv = attributes.inputs.at(input_names::K)->get_dim()[2]; + int64_t h_q = attributes.inputs.at(input_names::Q)->get_dim()[1]; + int64_t h_k = attributes.inputs.at(input_names::K)->get_dim()[1]; + int64_t h_v = 
attributes.inputs.at(input_names::V)->get_dim()[1]; + int64_t d_qk = attributes.inputs.at(input_names::Q)->get_dim()[3]; + int64_t d_v = attributes.inputs.at(input_names::V)->get_dim()[3]; + + auto const& dq_tensor = attributes.outputs.at(output_names::dQ); + auto const& dq_data_type = dq_tensor->get_data_type(); + auto const& dk_tensor = attributes.outputs.at(output_names::dK); + auto const& dk_data_type = dk_tensor->get_data_type(); + auto const& dv_tensor = attributes.outputs.at(output_names::dV); + auto const& dv_data_type = dv_tensor->get_data_type(); + + auto const& bias_mask = attributes.inputs.find(input_names::Bias); + bool const is_bias = (bias_mask != attributes.inputs.end() && bias_mask->second != nullptr); + + auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); + bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); + bool const is_dropout = attributes.dropout_probability.has_value(); + + // validation TODO: + // - validate stats has valid dims + + // validate basic dimension requirements + if(prop.major >= 10) { + RETURN_CUDNN_FRONTEND_ERROR_IF(((d_qk > 128) || (d_qk % 16 != 0)) && !(d_qk == 192 && d_v == 128), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim d_qk shoud be less than or equal to 128 and hidden_dim d_qk should be multiple of 16 unless d_qk == 192 and d_v == 128"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(((d_v > 128) || (d_v % 16 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim d_v shoud be less than or equal to 128 and hidden_dim d_v should be multiple of 16"); + } + else { + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk != 128) || (d_qk % 16 != 0) || (d_v != 128) || (d_v % 16 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim shoud be equal to 128 and hidden_dim should be multiple of 16"); + } + RETURN_CUDNN_FRONTEND_ERROR_IF((h_q % h_k != 0) || (h_q % h_v != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "For group-query attention, number of heads for 
key and query must be a factor of number of heads for query"); + + // validate options for attn_scale + auto const& attn_scale = attributes.inputs.find(input_names::Attn_scale); + bool const has_attn_scale = (attn_scale != attributes.inputs.end()) && (attn_scale->second != nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF(has_attn_scale && attributes.attn_scale_value.has_value(), + error_code_t::ATTRIBUTE_NOT_SET, + "attn_scale with tensor and value cannot be set at the same time."); + + // validate options for bias mask + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && (bias_mask->second->get_data_type() == DataType_t::BOOLEAN), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bias mask data type cannot be boolean"); + + // validate options for padding mask + auto const& seq_len_q = attributes.inputs.find(input_names::SEQ_LEN_Q); + bool const has_seq_len_q = (seq_len_q != attributes.inputs.end()) && (seq_len_q->second != nullptr); + auto const& seq_len_kv = attributes.inputs.find(input_names::SEQ_LEN_KV); + bool const has_seq_len_kv = (seq_len_kv != attributes.inputs.end()) && (seq_len_kv->second != nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.padding_mask && (!has_seq_len_q || !has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "Padding mask requires seq_len_q and seq_len_kv to be set."); + RETURN_CUDNN_FRONTEND_ERROR_IF((!attributes.padding_mask) && (has_seq_len_q || has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); + + // validate options for dropout mask + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.dropout_probability.has_value() && is_dropout_custom, + error_code_t::ATTRIBUTE_NOT_SET, + "Using both, custom dropout mask and internal-mask generation using dropout probability, is ill-formed."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + attributes.dropout_probability.has_value() && attributes.dropout_probability.value() == 1.0, + error_code_t::ATTRIBUTE_NOT_SET, + "Dropout probability cannot be 
1 as corresponding scale wont be well formed."); + + + // Validate options for causal_mask_bottom_right + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && detail::get_backend_version() < 90700, + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.7.0, bottom right causal masking is not supported."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && prop.major < 10, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported on Blackwell architecture and newer. Please " + "consider using a newer architecture."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask && attributes.causal_mask_bottom_right, + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask and causal mask cannot be both enabled"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && s_q > s_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask does not support s_q > s_kv. Please virtually slice the Q tensor and pass it as s_q == s_kv"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && (is_bias || is_dropout), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with is_bias=False, is_dropout=False."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && ((s_q % 64 != 0) || (s_kv % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with s_q multiple of 64, and s_kv multiple of 64"); + + // validate that datatype is set for the graph + RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Intermediate tensor data type needs to be set as internal tensors require it."); + + // validate options for deterministic algorithm + if (attributes.is_deterministic_algorithm && (prop.major == 10)) { + RETURN_CUDNN_FRONTEND_ERROR_IF((detail::get_backend_version() < 91900), + 
error_code_t::GRAPH_NOT_SUPPORTED, + "FP8 deterministic algorithm is not supported on blackwell architecture with cudnn version below 9.19.0"); + + // dbias bias rng/dropout alibi + RETURN_CUDNN_FRONTEND_ERROR_IF(is_dropout, + error_code_t::GRAPH_NOT_SUPPORTED, + "FP8 deterministic algorithm is not supported on blackwell architecture when dropout is enabled"); + + is_deterministic_algorithm_supported_on_blackwell = true; + } + + // if output data type is half or bfloat16 for any of dq, dk, dv, and version is below 9.13 or is not blackwell, return NOT_SUPPORTED + RETURN_CUDNN_FRONTEND_ERROR_IF( + (dq_data_type == DataType_t::HALF || dq_data_type == DataType_t::BFLOAT16 || + dk_data_type == DataType_t::HALF || dk_data_type == DataType_t::BFLOAT16 || + dv_data_type == DataType_t::HALF || dv_data_type == DataType_t::BFLOAT16) && + (detail::get_backend_version() < 91300 || prop.major < 10), + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported on cuDNN version 9.13.0 and newer. 
Please " + "consider upgrading your current version."); + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + return {error_code_t::OK, ""}; + } + + error_t + expand_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for Scaled_dot_product_flash_attention node " + << attributes.name); + + attributes.fill_from_context(context); + + // Gather dim to fill properties of virtual tensors + auto const& q_dim = attributes.inputs[input_names::Q]->get_dim(); + auto b = q_dim[0]; + auto h_q = q_dim[1]; + auto s_q = q_dim[2]; + // auto d_qk = q_dim[3]; + auto const& k_dim = attributes.inputs[input_names::K]->get_dim(); + // auto h_k = k_dim[1]; + auto s_kv = k_dim[2]; + // auto const& v_dim = attributes.inputs[input_names::V]->get_dim(); + // auto h_v = v_dim[1]; + // auto d_v = v_dim[3]; + + // cuDNN frontend API attention requires Q, K, V where + // Q = {b, h_q, s_q, d_qk} + // K = {b, h_k, s_kv, d_qk} + // V = {b, h_v, s_kv, d_v} + // but cuDNN backend API attention requires Q, KT, VT + // Q = {b, h_q, s_q, d_qk} + // KT = {b, h_k, d_qk, s_kv} + // VT = {b, h_v, d_v, s_kv} + // So the code below maps the K->KT and V->VT + std::vector temp_vec; + + temp_vec = attributes.inputs[input_names::K]->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_dim(temp_vec); + + temp_vec = attributes.inputs[input_names::K]->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_stride(temp_vec); + + temp_vec = attributes.inputs[input_names::V]->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::V]->set_dim(temp_vec); + + temp_vec = attributes.inputs[input_names::V]->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::V]->set_stride(temp_vec); + + std::shared_ptr rng_output; + + auto mul_attributes = Pointwise_attributes().set_mode(PointwiseMode_t::MUL); + + // if dropout_prob is 
used, then the node passes scale and scale inverse + // if dropout_mask is used, then the user passes scale and scale_inverse + bool is_dropout_prob = (attributes.dropout_probability.has_value()); + bool is_dropout_mask = (attributes.inputs[input_names::Dropout_mask] != nullptr); + if (is_dropout_prob) { + float dropout_scale_value = 1.0f / (1.0f - attributes.dropout_probability.value()); + float dropout_scale_inv_value = (1.0f - attributes.dropout_probability.value()); + + attributes.inputs[input_names::Dropout_scale] = std::make_shared(dropout_scale_value); + attributes.inputs[input_names::Dropout_scale_inv] = + std::make_shared(dropout_scale_inv_value); + } + + // --------------RNG node-------------------- + + if (is_dropout_prob) { + rng_output = rng(attributes.inputs[input_names::Seed], + attributes.inputs[input_names::Offset], + Rng_attributes() + .set_name("rng") + .set_distribution(RngDistribution_t::BERNOULLI) + .set_bernoulli_probability(1.0f - attributes.dropout_probability.value())); + rng_output->set_dim({b, h_q, s_q, s_kv}).set_stride({h_q * s_q * s_kv, s_q * s_kv, s_kv, 1}); + } else if (is_dropout_mask) { + rng_output = attributes.inputs[input_names::Dropout_mask]; + } + + //// dO * O + mul_attributes.set_name("mul_dO_O"); + auto last_output = + pointwise(attributes.inputs[input_names::dO], attributes.inputs[input_names::O], mul_attributes); + + // reduce(dO) + last_output = + reduction(last_output, Reduction_attributes().set_name("reduce_dO").set_mode(ReductionMode_t::ADD)); + last_output->set_dim({b, h_q, s_q, 1}).set_stride({h_q * s_q, s_q, 1, 1}); + + // Descale dO + mul_attributes.set_name("descale_dO"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_dO), mul_attributes); + last_output->set_dim({b, h_q, s_q, 1}).set_stride({h_q * s_q, s_q, 1, 1}); + + // Descale O + mul_attributes.set_name("descale_O"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_O), mul_attributes); + + // 
softmax_sum = last_output * dropout_scale + if(attributes.inputs[input_names::Dropout_scale_inv]) { + last_output = pointwise(last_output, + attributes.inputs[input_names::Dropout_scale_inv], + Pointwise_attributes().set_name("scale_dropout_inv").set_mode(PointwiseMode_t::MUL)); + } + auto softmax_sum = last_output; + + //// Q * K + auto bmm_Q_K_attributes = Matmul_attributes().set_name("bmm_Q_K") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_n_override(attributes.inputs[input_names::SEQ_LEN_KV]); + auto last_dV = matmul(attributes.inputs[input_names::Q], attributes.inputs[input_names::K], bmm_Q_K_attributes); + + //// Optional Attn scale + // In case user provided a scalar value, do a fused scalar. + if (attributes.attn_scale_value.has_value()) { + attributes.inputs[input_names::Attn_scale] = + std::make_shared(attributes.attn_scale_value.value()); + } + + // If attn scale present, add a pointwise mul node + if (auto attn_scale_it = attributes.inputs.find(input_names::Attn_scale); attn_scale_it != attributes.inputs.end()) { + mul_attributes.set_name("attn_scale"); + last_dV = pointwise(last_dV, attn_scale_it->second, mul_attributes); + } + + //// Descales + // Descale Q + mul_attributes.set_name("descale_q"); + last_dV = pointwise(last_dV, attributes.inputs.at(input_names::Descale_Q), mul_attributes); + + // Descale K + mul_attributes.set_name("descale_k"); + last_dV = pointwise(last_dV, attributes.inputs.at(input_names::Descale_K), mul_attributes); + + // (optional) last_dV = last_dV + bias + if (auto bias_it = attributes.inputs.find(input_names::Bias); bias_it != attributes.inputs.end()) { + last_dV = pointwise(last_dV, + bias_it->second, + Pointwise_attributes().set_name("add_bias").set_mode(PointwiseMode_t::ADD)); + } + + // (optional) Apply padding mask + if (attributes.padding_mask) { + auto row_idx_output = pointwise(last_dV, + Pointwise_attributes() + .set_name("gen_row_idx_padding") + .set_mode(PointwiseMode_t::GEN_INDEX) + 
.set_axis(2) + .set_compute_data_type(DataType_t::INT32)); + row_idx_output->set_data_type(DataType_t::INT32); + + auto col_idx_output = pointwise(last_dV, + Pointwise_attributes() + .set_name("gen_col_idx_padding") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32)); + col_idx_output->set_data_type(DataType_t::INT32); + + auto row_mask_output = pointwise(row_idx_output, + attributes.inputs[input_names::SEQ_LEN_Q], + Pointwise_attributes() + .set_name("lt_row_sq_padding") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::BOOLEAN)); + row_mask_output->set_data_type(DataType_t::BOOLEAN); + + auto col_mask_output = pointwise(col_idx_output, + attributes.inputs[input_names::SEQ_LEN_KV], + Pointwise_attributes() + .set_name("lt_col_skv_padding") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::BOOLEAN)); + col_mask_output->set_data_type(DataType_t::BOOLEAN); + + auto padding_mask_output = pointwise(row_mask_output, + col_mask_output, + Pointwise_attributes() + .set_name("and_row_col_padding") + .set_mode(PointwiseMode_t::LOGICAL_AND) + .set_compute_data_type(DataType_t::BOOLEAN)); + padding_mask_output->set_data_type(DataType_t::BOOLEAN); + + // Use a smaller value of neg infinity so that the softmax stats for rows that are fully padded dont + // go towards NaNs/Infs when multipled by the numerous scale/descale + auto negative_inf_padding = std::make_shared(attn::score_modifiers::get_negative_inf_value()); + + last_dV = + pointwise(last_dV, + negative_inf_padding, + padding_mask_output, + Pointwise_attributes().set_name("select_padding").set_mode(PointwiseMode_t::BINARY_SELECT)); + } + + //// Optional causal masking + if (attributes.causal_mask) { + auto row_index_attributes = + Pointwise_attributes().set_name("gen_row_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); + std::shared_ptr row_index_output = pointwise(last_dV, row_index_attributes); + 
row_index_output->set_data_type(DataType_t::INT32); + + auto col_index_attributes = + Pointwise_attributes().set_name("gen_col_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); + auto const& col_index_output = pointwise(last_dV, col_index_attributes); + col_index_output->set_data_type(DataType_t::INT32); + + if (attributes.causal_mask_bottom_right) { + if (attributes.inputs[input_names::SEQ_LEN_KV]) { + row_index_output = pointwise(row_index_output, + attributes.inputs[input_names::SEQ_LEN_KV], + Pointwise_attributes() + .set_name("row_idx_add_skv") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::INT32)); + } else { + row_index_output = pointwise(row_index_output, + std::make_shared(static_cast(s_kv)), + Pointwise_attributes() + .set_name("row_idx_add_skv") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::INT32)); + } + row_index_output->set_data_type(DataType_t::INT32); + + if (attributes.inputs[input_names::SEQ_LEN_Q]) { + row_index_output = pointwise(row_index_output, + attributes.inputs[input_names::SEQ_LEN_Q], + Pointwise_attributes() + .set_name("row_idx_add_sq_sub_sq") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + } else { + row_index_output = pointwise(row_index_output, + std::make_shared(static_cast(s_q)), + Pointwise_attributes() + .set_name("row_idx_add_sq_sub_sq") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + } + row_index_output->set_data_type(DataType_t::INT32); + } + + auto greater_than_attributes = Pointwise_attributes() + .set_name("row_greater_than_col") + .set_mode(PointwiseMode_t::CMP_GE) + .set_compute_data_type(DataType_t::BOOLEAN); + auto const& row_greater_than_col_output = + pointwise(row_index_output, col_index_output, greater_than_attributes); + row_greater_than_col_output->set_data_type(DataType_t::BOOLEAN); + + // Lower attributes to binary select attributes + auto negative_inf_causal = 
std::make_shared(attn::score_modifiers::get_negative_inf_value()); + + auto binary_select_attributes = + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); + last_dV = pointwise(last_dV, negative_inf_causal, row_greater_than_col_output, binary_select_attributes); + } + + //// Apply Softmax + // last_dV = last_dV - stats + last_dV = pointwise(last_dV, + attributes.inputs[input_names::Stats], + Pointwise_attributes().set_name("sub_dV_Stats").set_mode(PointwiseMode_t::SUB)); + + // last_dV = exp(last_dV) + last_dV = pointwise(last_dV, Pointwise_attributes().set_name("exp_dV").set_mode(PointwiseMode_t::EXP)); + auto exp_S = last_dV; + + // (optional) last_dV = last_dV * dropout rng_output + if (is_dropout_prob || is_dropout_mask) { + last_dV = + pointwise(last_dV, + rng_output, + Pointwise_attributes().set_name("mul_p_dropout_mask").set_mode(PointwiseMode_t::MUL)); + } + + // (optional) last_dV = last_dV * dropout_scale + if (attributes.inputs[input_names::Dropout_scale]) { + last_dV = + pointwise(last_dV, + attributes.inputs[input_names::Dropout_scale], + Pointwise_attributes().set_name("mul_dS_dropout_scale").set_mode(PointwiseMode_t::MUL)); + } + + // Scale S + mul_attributes.set_name("scale_S"); + last_dV = pointwise(last_dV, attributes.inputs.at(input_names::Scale_S), mul_attributes); + last_dV->set_data_type(attributes.inputs.at(input_names::Q)->get_data_type()); + + // Reshape S + last_dV = reshape(last_dV, Reshape_attributes().set_name("S_transpose")); + last_dV->set_name("S_T").set_dim({b, h_q, s_kv, s_q}).set_stride({h_q * s_q * s_kv, s_q * s_kv, 1, s_kv}); + last_dV->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + + //// S_T * dO + // Special non-functional-style call. Needed because output already created and provided to user. 
+ matmul_fp8(last_dV, + attributes.inputs[input_names::dO], + attributes.inputs[input_names::Descale_S], + attributes.inputs[input_names::Descale_dO], + attributes.inputs[input_names::Scale_dV], + Matmul_fp8_attributes().set_name("bmm_S_T_dO") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q]), + attributes.outputs[output_names::dV], + attributes.outputs[output_names::Amax_dV]); + + //// dO * V_T + auto bmm_dO_V_T_attributes = Matmul_attributes().set_name("bmm_dO_V_T") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_n_override(attributes.inputs[input_names::SEQ_LEN_KV]); + last_output = + matmul(attributes.inputs[input_names::dO], attributes.inputs[input_names::V], bmm_dO_V_T_attributes); + + //// Descales + // Descale dO + mul_attributes.set_name("descale_dO"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_dO), mul_attributes); + + // Descale V + mul_attributes.set_name("descale_V"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_V), mul_attributes); + + // dP = last_output - softmax_sum + auto dP = pointwise(last_output, + softmax_sum, + Pointwise_attributes().set_name("sub_dP_softmax_sum").set_mode(PointwiseMode_t::SUB)); + + // dP = dP * exp_S + mul_attributes.set_name("mul_dP_exp_S"); + dP = pointwise(dP, exp_S, mul_attributes); + + // (optional) dP = dP * dropout_scale + if (attributes.inputs[input_names::Dropout_scale]) { + dP = + pointwise(dP, + attributes.inputs[input_names::Dropout_scale], + Pointwise_attributes().set_name("mul_dS_dropout_scale").set_mode(PointwiseMode_t::MUL)); + } + + // if (attributes.outputs[output_names::dBias]) { + // reduction(dP, + // Reduction_attributes().set_name("red_dP_dBias").set_mode(ReductionMode_t::ADD), + // attributes.outputs[output_names::dBias]); + // } + + // (optional) dP = dP * attn_scale + if (auto attn_scale_it = 
attributes.inputs.find(input_names::Attn_scale); attn_scale_it != attributes.inputs.end()) { + mul_attributes.set_name("mul_dS_attn_scale"); + dP = pointwise(dP, attn_scale_it->second, mul_attributes); + } + + // Amax dP + auto amax_attributes = Reduction_attributes().set_name("amax_dP").set_mode(ReductionMode_t::AMAX); + // Special non-functional-style call. Needed because output already created and provided to user. + reduction(dP, amax_attributes, attributes.outputs.at(output_names::Amax_dP)); + + // Scale dP + mul_attributes.set_name("scale_dP"); + dP = pointwise(dP, attributes.inputs.at(input_names::Scale_dP), mul_attributes); + dP->set_data_type(attributes.inputs.at(input_names::dO)->get_data_type()); + + //// dP * K + auto const& kt_dim = attributes.inputs[input_names::K]->get_dim(); + auto const& kt_stride = attributes.inputs[input_names::K]->get_stride(); + + auto K = reshape(attributes.inputs[input_names::K], Reshape_attributes().set_name("reshape_K")); + K->set_dim({kt_dim[0], kt_dim[1], kt_dim[3], kt_dim[2]}) + .set_stride({kt_stride[0], kt_stride[1], kt_stride[3], kt_stride[2]}); + + auto bmm_dP_K_attributes = Matmul_fp8_attributes().set_name("bmm_dP_K") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_KV]); + // Special non-functional-style call. Needed because output already created and provided to user. 
+ matmul_fp8(dP, + K, + attributes.inputs[input_names::Descale_dP], + attributes.inputs[input_names::Descale_K], + attributes.inputs[input_names::Scale_dQ], + bmm_dP_K_attributes, + attributes.outputs[output_names::dQ], + attributes.outputs[output_names::Amax_dQ]); + + //// dP.T * Q + auto dP_T_attributes = Reshape_attributes().set_name("dP_T"); + auto dP_T = reshape(dP, dP_T_attributes); + dP_T->set_data_type(attributes.inputs.at(input_names::dO)->get_data_type()); + dP_T->set_name("dP_T").set_dim({b, h_q, s_kv, s_q}).set_stride({h_q * s_q * s_kv, s_q * s_kv, 1, s_kv}); + + auto bmm_dP_T_Q_attributes = Matmul_fp8_attributes().set_name("bmm_dP_T_Q") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q]); + // Special non-functional-style call. Needed because output already created and provided to user. + matmul_fp8(dP_T, + attributes.inputs[input_names::Q], + attributes.inputs[input_names::Descale_dP], + attributes.inputs[input_names::Descale_Q], + attributes.inputs[input_names::Scale_dK], + bmm_dP_T_Q_attributes, + attributes.outputs[output_names::dK], + attributes.outputs[output_names::Amax_dK]); + + return {error_code_t::OK, ""}; + } + + std::pair> + override_heuristics_query() const { + if (is_deterministic_algorithm_supported_on_blackwell) { + return {5, {{KnobType_t::KERNEL_CFG, 31}, {KnobType_t::STAGES, 2}}}; + } else { + return {-1, {}}; + } + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"({"tag": "SDPA_FP8_BWD"})"_json); + } +#endif +}; + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h new file mode 100644 index 00000000..6486943a --- /dev/null +++ 
b/third_party/cudnn-frontend/include/cudnn_frontend/node/sdpa_support_surface.h @@ -0,0 +1,504 @@ +#pragma once + +#include +#include + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +inline error_t +SDPA_attributes::validate_sdpa_support_surface(const detail::Context& context, + int64_t s_kv, + bool is_paged_k, + bool is_paged_v) const { + // Extract dimensions from tensors + int64_t s_q = inputs.at(SDPA_attributes::input_names::Q)->get_dim()[2]; + // s_kv is passed in from the caller + int64_t h_q = inputs.at(SDPA_attributes::input_names::Q)->get_dim()[1]; + int64_t h_k = inputs.at(SDPA_attributes::input_names::K)->get_dim()[1]; + int64_t h_v = inputs.at(SDPA_attributes::input_names::V)->get_dim()[1]; + int64_t d_qk = inputs.at(SDPA_attributes::input_names::Q)->get_dim()[3]; + int64_t d_v = inputs.at(SDPA_attributes::input_names::V)->get_dim()[3]; + + bool const is_ragged = inputs.at(SDPA_attributes::input_names::Q)->get_ragged_offset() || + inputs.at(SDPA_attributes::input_names::K)->get_ragged_offset() || + inputs.at(SDPA_attributes::input_names::V)->get_ragged_offset() || + outputs.at(SDPA_attributes::output_names::O)->get_ragged_offset(); + + auto const& output_tensor = outputs.at(SDPA_attributes::output_names::O); + auto const& output_data_type = output_tensor->get_data_type(); + + auto const& bias_mask = inputs.find(SDPA_attributes::input_names::Bias); + bool const is_bias = (bias_mask != inputs.end() && bias_mask->second != nullptr); + + auto const& dropout_mask = inputs.find(SDPA_attributes::input_names::Dropout_mask); + bool const is_dropout_custom = (dropout_mask != inputs.end()) && (dropout_mask->second != nullptr); + bool const is_dropout = dropout_probability.has_value() || is_dropout_custom; + + bool const is_paged = is_paged_k || is_paged_v; + + auto const& rng_tensor = 
outputs.find(SDPA_attributes::output_names::RNG_DUMP); + bool const is_rng = (rng_tensor != outputs.end() && rng_tensor->second != nullptr); + + bool const max_seq_kv_explicit = max_seq_len_kv.has_value(); + + auto const& attn_scale = inputs.find(SDPA_attributes::input_names::Attn_scale); + bool const has_attn_scale = (attn_scale != inputs.end()) && (attn_scale->second != nullptr); + + auto const& seq_len_q = inputs.find(SDPA_attributes::input_names::SEQ_LEN_Q); + bool const has_seq_len_q = (seq_len_q != inputs.end()) && (seq_len_q->second != nullptr); + auto const& seq_len_kv = inputs.find(SDPA_attributes::input_names::SEQ_LEN_KV); + bool const has_seq_len_kv = (seq_len_kv != inputs.end()) && (seq_len_kv->second != nullptr); + + // validation TODO: + // - validate stats has valid dims + + // Get device properties + cudaDeviceProp prop; + int device; + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device(&device)); + _CUDNN_CHECK_CUDA_ERROR(detail::cuda_get_device_properties(&prop, device)); + + // Common FP16 and FP8 validation + // validate basic dimension requirements + RETURN_CUDNN_FRONTEND_ERROR_IF( + (h_q % h_k != 0) || (h_q % h_v != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "For group-query attention, number of heads for key and query must be a factor of number of heads for query"); + + // validate options for attn_scale + RETURN_CUDNN_FRONTEND_ERROR_IF(has_attn_scale && attn_scale_value.has_value(), + error_code_t::ATTRIBUTE_NOT_SET, + "attn_scale with tensor and value cannot be set at the same time."); + + // validate options for bias mask + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && (bias_mask->second->get_data_type() == DataType_t::BOOLEAN), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bias mask data type cannot be boolean"); + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && detail::get_backend_version() < 8906, + error_code_t::GRAPH_NOT_SUPPORTED, + "Bias mask is not supported below cudnn version 8.9.6"); + + 
RETURN_CUDNN_FRONTEND_ERROR_IF((detail::get_backend_version() >= 8906 && detail::get_backend_version() < 90000) && + (context.get_sm_version() > 0 && context.get_sm_version() < 90), + error_code_t::GRAPH_NOT_SUPPORTED, + "Post scale Bias mask is not supported below Hopper for cudnn version" + + std::to_string(detail::get_backend_version())); + + // validate options for padding mask + RETURN_CUDNN_FRONTEND_ERROR_IF(padding_mask && (!has_seq_len_q || !has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "Padding mask requires seq_len_q and seq_len_kv to be set."); + RETURN_CUDNN_FRONTEND_ERROR_IF((!padding_mask && !attention_score_modifier) && (has_seq_len_q || has_seq_len_kv), + error_code_t::ATTRIBUTE_NOT_SET, + "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(is_ragged && ((padding_mask == false) && (attention_score_modifier == nullptr)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Ragged offsets are only supported with padding mask."); + + // validate options for dropout mask + RETURN_CUDNN_FRONTEND_ERROR_IF( + dropout_probability.has_value() && is_dropout_custom, + error_code_t::ATTRIBUTE_NOT_SET, + "Using both, custom dropout mask and internal-mask generation using dropout probability, is ill-formed."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(dropout_probability.has_value() && dropout_probability.value() == 1.0, + error_code_t::ATTRIBUTE_NOT_SET, + "Dropout probability cannot be 1 as corresponding scale wont be well formed."); + + // validate options for causal mask and bottom right causal mask + RETURN_CUDNN_FRONTEND_ERROR_IF( + (padding_mask || alibi_mask || has_causal_mask_bottom_right()) && (detail::get_backend_version() < 8906), + error_code_t::GRAPH_NOT_SUPPORTED, + "Only causal mask is supported in cudnn versions below 8.9.6"); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + has_causal_mask_bottom_right() && (!padding_mask) && s_q > s_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask does 
not support max_s_q > max_s_kv. Please virtually slice the Q tensor and pass it " + "as max_s_q == max_s_kv"); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + has_causal_mask_bottom_right() && (is_bias || alibi_mask || is_dropout), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with is_bias=False, is_alibi=False, is_dropout=False."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(has_causal_mask_bottom_right() && (detail::get_backend_version() < 90600) && + ((s_q % 64 != 0) || (s_kv % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with s_q multiple of 64, and s_kv " + "multiple of 64, for cudnn version below 9.6.0"); + + // validate that datatype is set for the graph + RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Intermediate tensor data type needs to be set as internal tensors require it."); + + if (mma_core_mode == DataType_t::FP8_E4M3 || mma_core_mode == DataType_t::FP8_E5M2) { + // FP8 specific validation + + // version specific validation + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90100, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported starting cudnn 9.1.0. Please " + "consider upgrading your current version."); + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() == 91000, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is not supported on cudnn 9.10.0. Please " + "consider upgrading your current version."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + prop.major < 9, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported on Hopper architecture and newer. 
Please " + "consider using a newer architecture."); + + // validate basic dimension requirements + // d_qk=192 with d_v=128 is only supported starting from cuDNN 9.19 + bool const d192_v128_supported = (detail::get_backend_version() >= 91900); + if (prop.major >= 10) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + ((d_qk > 128) || (d_qk % 16 != 0)) && !(d192_v128_supported && d_qk == 192 && d_v == 128), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim d_qk should be less than or equal to 128 and hidden_dim d_qk " + "should be multiple of 16 unless d_qk == 192 and d_v == 128 (requires cuDNN 9.19+)"); + RETURN_CUDNN_FRONTEND_ERROR_IF( + ((d_v > 128) || (d_v % 16 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim d_v should be less than or equal to 128 and hidden_dim d_v should be multiple of 16"); + } else { + RETURN_CUDNN_FRONTEND_ERROR_IF( + (d_qk > 256) || (d_qk % 16 != 0) || (d_v > 256) || (d_v % 16 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim shoud be less than or equal to 256 and hidden_dim should be multiple of 16"); + } + + // Validate options for causal_mask_bottom_right + RETURN_CUDNN_FRONTEND_ERROR_IF(has_causal_mask_bottom_right() && detail::get_backend_version() < 90700, + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.7.0, bottom right causal masking is not supported."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + has_causal_mask_bottom_right() && prop.major < 10, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported on Blackwell architecture and newer. 
Please " + "consider using a newer architecture."); + + // if output data type is half or bfloat16, and version is below 9.13 or is not blackwell, return NOT_SUPPORTED + RETURN_CUDNN_FRONTEND_ERROR_IF( + (output_data_type == DataType_t::HALF || output_data_type == DataType_t::BFLOAT16) && + (detail::get_backend_version() < 91300 || prop.major < 10), + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported on cuDNN version 9.13.0 and newer. Please " + "consider upgrading your current version."); + } else if (mma_core_mode == DataType_t::HALF) { + // FP16 specific validation + + RETURN_CUDNN_FRONTEND_ERROR_IF( + (attention_score_modifier != nullptr) && + (alibi_mask || has_causal_like_masking() || padding_mask || left_bound.has_value()), + error_code_t::GRAPH_NOT_SUPPORTED, + "Attention score mod enabled and hence other subgraphs are disabled."); + + // validate basic dimension requirements + RETURN_CUDNN_FRONTEND_ERROR_IF( + (d_qk % 8 != 0) || (d_v % 8 != 0), error_code_t::GRAPH_NOT_SUPPORTED, "hidden_dim should be multiple of 8"); + + // validate alibi requirements + RETURN_CUDNN_FRONTEND_ERROR_IF(alibi_mask && !(right_bound.has_value() && right_bound.value() == 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "When alibi mask is used, diagonal_band_right_bound needs to be set to 0."); + + // validate options for bottom right causal mask + RETURN_CUDNN_FRONTEND_ERROR_IF(has_causal_mask_bottom_right() && (detail::get_backend_version() < 90300), + error_code_t::GRAPH_NOT_SUPPORTED, + "Causal bottom right masking requires cudnn 9.3.0 and above"); + + // Combination of mask and bias + RETURN_CUDNN_FRONTEND_ERROR_IF( + (is_bias && (has_causal_like_masking() || padding_mask) && (detail::get_backend_version() < 8906)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bias + padding or causal mask is only supported in 8.9.6 and above"); + + // validate options for sliding window length + RETURN_CUDNN_FRONTEND_ERROR_IF((left_bound.has_value() && 
detail::get_backend_version() < 90200), + error_code_t::GRAPH_NOT_SUPPORTED, + "sliding window is only supported 9.2.0 and above"); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + left_bound.has_value() && left_bound.value() <= 0 && detail::get_backend_version() < 91000, + error_code_t::INVALID_VALUE, + "Left bound (Sliding window length) should be greater than zero when set."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(left_bound.has_value() && (!padding_mask) && s_q > s_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Sliding window attention is only supported with max_s_q <= max_s_kv."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + left_bound.has_value() && (s_q * left_bound.value() == s_kv * left_bound.value()) && + (detail::get_backend_version() <= 90900) && (prop.major == 9) && has_causal_mask_bottom_right(), + error_code_t::GRAPH_NOT_SUPPORTED, + "On Hopper architecture, this specific combination of s_q, s_kv, and left_bound + right_bound + bottom " + "right diagonal alignment is not supported for backend version 9.9 or below"); + + if ((detail::get_backend_version() < 91002)) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + left_bound.has_value() && (!has_causal_like_masking() || is_dropout || is_bias), + error_code_t::GRAPH_NOT_SUPPORTED, + "Left and right bounds are only supported with is_dropout=False, is_bias=False. And the diagonal " + "alignment must be set."); + } + + RETURN_CUDNN_FRONTEND_ERROR_IF(right_bound.has_value() && right_bound.value() < 0, + error_code_t::INVALID_VALUE, + "Right bound needs to be larger than or equal to zero"); + + // Validate options for s_q == 1 + const bool is_decode_only = (s_q == 1); + RETURN_CUDNN_FRONTEND_ERROR_IF(is_decode_only && (prop.major == 10) && (d_qk > 128 || d_v > 128) && + (detail::get_backend_version() <= 90900), + error_code_t::GRAPH_NOT_SUPPORTED, + "decode only mode, i.e. 
s_q == 1 not supported for blackwell architecture with " + "d_qk or d_v > 128 for backend version 9.9 or below"); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + is_decode_only && (detail::get_backend_version() <= 90900) && (right_bound.has_value()), + error_code_t::GRAPH_NOT_SUPPORTED, + "decode only mode, i.e. s_q == 1, not supported with masking (right_bound is set) for backend version 9.9 " + "or below"); + + // validate options for paged attention + RETURN_CUDNN_FRONTEND_ERROR_IF( + is_paged && (d_qk > 128 || d_v > 128) && detail::get_backend_version() <= 90900, + error_code_t::GRAPH_NOT_SUPPORTED, + "Paged attention only supported with d_qk and d_v <= 128 for backend version 9.9 or below"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(is_paged && is_ragged && detail::get_backend_version() < 90700, + error_code_t::GRAPH_NOT_SUPPORTED, + "Paged caches are not supported in combination with ragged offsets."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(is_paged && (!has_seq_len_q || !has_seq_len_kv), + error_code_t::GRAPH_NOT_SUPPORTED, + "Paged caches can only be used in combination with padding mask and variable " + "sequence lengths for both Q and KV."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + !is_paged && max_seq_kv_explicit, + error_code_t::GRAPH_NOT_SUPPORTED, + "When not using paged attention, there is no need to explicitly set max kv sequence length."); + + if (max_seq_kv_explicit) { + auto max_seq_kv = max_seq_len_kv.value(); + + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && (bias_mask->second->get_dim()[3] != max_seq_kv), + error_code_t::GRAPH_NOT_SUPPORTED, + "Value set through set_paged_attention_max_seq_len_kv is incompatible with " + "the sequence length of the bias"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(is_rng && rng_tensor->second->get_dim()[3] != max_seq_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Value set through set_paged_attention_max_seq_len_kv is incompatible with " + "the sequence length of the RNG_DUMP"); + } + + // Additional validation for paged attention with packed page 
tables + RETURN_CUDNN_FRONTEND_ERROR_IF( + ((is_paged_k && inputs.at(SDPA_attributes::input_names::Page_table_K)->get_ragged_offset()) || + (is_paged_v && inputs.at(SDPA_attributes::input_names::Page_table_V)->get_ragged_offset())) && + detail::get_backend_version() < 91002, + error_code_t::GRAPH_NOT_SUPPORTED, + "Paged attention with packed page tables only supported with cudnn version 9.10.2 and above"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8903, + error_code_t::GRAPH_NOT_SUPPORTED, + "SDPA OP requires cudnn version 8.9.3 and above"); + + // If user has set sm_version allow SM specific checks + if (context.get_sm_version() > 0) { + RETURN_CUDNN_FRONTEND_ERROR_IF(80 > context.get_sm_version(), + error_code_t::GRAPH_NOT_SUPPORTED, + "cudnn SDPA operation requires Ampere and above"); + } + + // (cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups FIXME + + // version specific validation + if (prop.major == 8) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + detail::get_backend_version() <= 90900 && ((d_qk > 128) || (d_v > 128)), + error_code_t::GRAPH_NOT_SUPPORTED, + "head_dim should be less than or equal to 128 for backend version 9.9 or below on ampere architecture"); + } + if (prop.major == 9) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + detail::get_backend_version() <= 90900 && ((d_qk > 256) || (d_v > 256)), + error_code_t::GRAPH_NOT_SUPPORTED, + "head_dim should be less than or equal to 256 for backend version 9.9 or below on hopper architecture"); + } + if (prop.major == 10) { + RETURN_CUDNN_FRONTEND_ERROR_IF((detail::get_backend_version() < 90900) && ((d_qk > 128) || (d_v > 128)), + error_code_t::GRAPH_NOT_SUPPORTED, + "head_dim should be less than or equal to 128 for backend version 9.8 or " + "below on blackwell architecture"); + } + + RETURN_CUDNN_FRONTEND_ERROR_IF( + detail::get_backend_version() < 8906 && ((s_kv % 64 != 0) || (d_qk % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 8.9.6, s_kv not a 
multiple of 64 or d not a multiple of 64 is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8907 && (s_kv % 64 != 0) && (!(padding_mask)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 8.9.7, s_kv not a multiple of 64 is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + detail::get_backend_version() < 90000 && ((s_q % 64 != 0) || (s_kv % 64 != 0)) && + (padding_mask || is_dropout), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.0.0, s_q/s_kv not a multiple of 64 with padding/dropout mask is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90200 && left_bound.has_value(), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.2.0, sliding window attention is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_paged, + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.5.0, paged caches are not supported"); + + if (is_ragged) { + RETURN_CUDNN_FRONTEND_ERROR_IF((context.get_sm_version() > 0 && context.get_sm_version() < 90), + error_code_t::GRAPH_NOT_SUPPORTED, + "THD (ragged offset) is only supported in Hopper and above"); + } + // TODO add version check once fixed + RETURN_CUDNN_FRONTEND_ERROR_IF(prop.major == 10 && is_rng, + error_code_t::GRAPH_NOT_SUPPORTED, + "dropout RNG dump is not supported for Blackwell architecture"); + } else { + RETURN_CUDNN_FRONTEND_ERROR_IF(true, error_code_t::GRAPH_NOT_SUPPORTED, "Unsupported mma core mode"); + } + + // Check whether the selected implementation supports the requested features. + CHECK_CUDNN_FRONTEND_ERROR(verify_sdpa_support_surface_for_implementation(context, implementation)); + + return {error_code_t::OK, ""}; +} + +// Verify that the underlying implementation supports all the features in these attributes. 
+// Unlike `validate_sdpa_support_surface()`, this may be called before validation, so: +// * don't assume any particular keys already exist in `inputs` or `outputs` +// * don't assume any tensor dims or strides are already set +// We return error codes directly instead of using `RETURN_CUDNN_FRONTEND_ERROR_IF` +// to avoid unneeded logging when this function is being called in a non-error-generating +// situation (e.g. during auto-select of SDPA implementation). +inline error_t +SDPA_attributes::verify_sdpa_support_surface_for_implementation(const detail::Context& context, + AttentionImplementation_t impl) const { + switch (impl) { + case AttentionImplementation_t::AUTO: + // This function should not be called with AUTO. + return {error_code_t::INVALID_VALUE, + "Can't call verify_sdpa_support_surface_for_implementation with impl=AUTO"}; + case AttentionImplementation_t::COMPOSITE: + for (const auto& [key, value] : inputs) { + RETURN_CUDNN_FRONTEND_ERROR_IF(key == input_names::Block_mask && value != nullptr, + error_code_t::GRAPH_NOT_SUPPORTED, + "Composite SDPA node doesn't support Block_mask input"); + } + break; + case AttentionImplementation_t::UNIFIED: { + auto effective_cudnn_ver = std::min(detail::get_backend_version(), detail::get_compiled_version()); + RETURN_CUDNN_FRONTEND_ERROR_IF(effective_cudnn_ver < 91301, + error_code_t::GRAPH_NOT_SUPPORTED, + "Unified SDPA node requires cuDNN 9.13.1"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_dynamic_shape_enabled(), + error_code_t::GRAPH_NOT_SUPPORTED, + "Unified SDPA node doesn't yet support dynamic shape"); + + // TODO: Provide smarter error messages that provide the required cuDNN version for each input. 
+ std::unordered_set allowed_input_names{ + input_names::Q, input_names::K, input_names::V, input_names::Attn_scale}; + std::string allowed_input_msg = + "Unified SDPA node doesn't yet support inputs other than Q, K, V, Attn_scale"; + + if (effective_cudnn_ver >= 91400) { + allowed_input_names.insert({input_names::Block_mask}); + allowed_input_msg += ", Block_mask"; + } + + if (effective_cudnn_ver >= 91500) { + allowed_input_names.insert({input_names::Page_table_K, + input_names::Page_table_V, + input_names::SEQ_LEN_Q, + input_names::SEQ_LEN_KV}); + allowed_input_msg += ", Page_table_K, Page_table_V, SEQ_LEN_Q, SEQ_LEN_KV"; + } + + for (const auto& [key, value] : inputs) { + if (allowed_input_names.find(key) == allowed_input_names.end() && value != nullptr) { + return {error_code_t::GRAPH_NOT_SUPPORTED, allowed_input_msg}; + } + } + + for (const auto& [key, value] : outputs) { + if (key != output_names::O && key != output_names::Stats && value != nullptr) { + return {error_code_t::GRAPH_NOT_SUPPORTED, + "Unified SDPA node doesn't yet support outputs other than O and Stats"}; + } + } + + if (alibi_mask) { + return {error_code_t::GRAPH_NOT_SUPPORTED, "Unified SDPA node doesn't yet support alibi mask"}; + } + + if (padding_mask && effective_cudnn_ver < 91500) { + return {error_code_t::GRAPH_NOT_SUPPORTED, "Padding mask for unified SDPA node requires cuDNN 9.15.0"}; + } + + if (left_bound.has_value() || right_bound.has_value()) { + return {error_code_t::GRAPH_NOT_SUPPORTED, + "Unified SDPA node doesn't yet support left bound or right bound"}; + } + + if (diagonal_alignment != DiagonalAlignment_t::TOP_LEFT) { + return {error_code_t::GRAPH_NOT_SUPPORTED, "Unified SDPA node doesn't yet support diagonal alignment"}; + } + + if (dropout_probability.has_value()) { + return {error_code_t::GRAPH_NOT_SUPPORTED, "Unified SDPA node doesn't yet support dropout"}; + } + + // Unified engine in cuDNN < 9.15 can't meaningfully support max sequence length, + // while versions >= 9.15 
"support" it by ignoring it (unified engine doesn't need it). + if (max_seq_len_kv.has_value() && effective_cudnn_ver < 91500) { + return {error_code_t::GRAPH_NOT_SUPPORTED, + "Max sequence length for unified SDPA node cannot be set in cuDNN < 9.15.0"}; + } + + if (attention_score_modifier != nullptr) { + return {error_code_t::GRAPH_NOT_SUPPORTED, + "Unified SDPA node doesn't yet support attention score modifier"}; + } + + if (mma_core_mode != DataType_t::HALF) { + return {error_code_t::GRAPH_NOT_SUPPORTED, + "Unified SDPA node doesn't yet support a data type other than fp16"}; + } + + if ((compute_data_type != DataType_t::NOT_SET && compute_data_type != DataType_t::FLOAT) || + context.get_compute_data_type() != DataType_t::FLOAT) { + return {error_code_t::GRAPH_NOT_SUPPORTED, + "Unified SDPA node doesn't yet support compute data type other than float"}; + } + } break; + } + + return {error_code_t::OK, ""}; +} + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/slice.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/slice.h new file mode 100644 index 00000000..e40f5c53 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/slice.h @@ -0,0 +1,115 @@ +#pragma once + +namespace cudnn_frontend::graph { + +class SliceNode : public NodeCRTP { + public: + Slice_attributes attributes; + + SliceNode(Slice_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::SLICE; + } + + error_t + infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferrencing properties for slice node " << attributes.name + << std::endl; + + attributes.fill_from_context(context); + + auto output = attributes.outputs.at(Slice_attributes::output_names::Y); + auto output_dim = output->get_dim(); + + if (output_dim.empty()) { + for (size_t i = 0; i < 
attributes.slices.size(); ++i) { + output_dim.push_back(attributes.slices[i].second - attributes.slices[i].first); + } + output->set_dim(output_dim); + } + + auto const input = attributes.inputs.at(Slice_attributes::input_names::X); + auto const input_data_type = input->get_data_type(); + auto const output_data_type = output->get_data_type(); + if (output_data_type == DataType_t::NOT_SET) { + output->set_data_type(input_data_type); + } else { + RETURN_CUDNN_FRONTEND_ERROR_IF(output_data_type != input_data_type, + error_code_t::INVALID_VALUE, + "output and input tensor data types should match for slice operation."); + } + + auto const input_stride = input->get_stride(); + if (output->get_stride().empty()) { + // For simple slicing without changing the step, the stride remains the same + // std::vector stride_order = + // detail::generate_stride_order_preserving_format(input_stride, output_dim.size()); + output->set_stride(input_stride); + } + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_tensors_node(std::unordered_map>& tensors, + int64_t& potential_uid, + std::unordered_set const& used_uids) const override final { + // Do not make input tensor for backend. 
+ // But assign it a uid + auto const input = attributes.inputs.at(Slice_attributes::input_names::X); + if (input->has_uid() == false) { + detail::assign_uid(input.get(), potential_uid, used_uids); + } + + auto const output = attributes.outputs.at(Slice_attributes::output_names::Y); + output->set_is_virtual(false); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors, potential_uid, used_uids)); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>&, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>&) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + // No corresponding backend operation + + auto const virutal_output = attributes.outputs.at(Slice_attributes::output_names::Y); + if (virutal_output && virutal_output->get_is_virtual() == false) { + uids_involved_in_operations.insert(virutal_output->get_uid()); + if (auto ragged_offset = virutal_output->get_ragged_offset()) { + uids_involved_in_operations.insert(ragged_offset->get_uid()); + } + } + + return {error_code_t::OK, ""}; + } + + error_t + collect_variant_pack_replacements_node( + std::unordered_map>& + variant_pack_replacements) const override final { + auto const input = attributes.inputs.at(Slice_attributes::input_names::X); + auto const output = attributes.outputs.at(Slice_attributes::output_names::Y); + + variant_pack_replacements[input->get_uid()] = {output->get_uid(), attributes.get_offset()}; + + return {error_code_t::OK, ""}; + }; + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "SLICE"})"_json); + } +#endif +}; + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node/softmax.h b/third_party/cudnn-frontend/include/cudnn_frontend/node/softmax.h new file mode 100644 index 00000000..2263b490 
--- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node/softmax.h @@ -0,0 +1,171 @@ +#pragma once + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +#include "pointwise.h" +#include "reduction.h" + +namespace cudnn_frontend::graph { + +class SoftmaxNode : public NodeCRTP { + public: + Softmax_attributes attributes; + + SoftmaxNode(Softmax_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::COMPOSITE; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating SoftmaxNode " << attributes.name); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + return {error_code_t::OK, ""}; + } + + error_t + expand_node() override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Inferrencing properties for Softmax node " << attributes.name); + + attributes.fill_from_context(context); + + // Fill properties of virtual tensors + auto const p_dim = attributes.inputs[Softmax_attributes::input_names::P]->get_dim(); + auto b = p_dim[0]; + auto h = p_dim[1]; + auto s_q = p_dim[2]; + + auto max_output = attributes.outputs[Softmax_attributes::output_names::Max]; + if (max_output == nullptr) { + max_output = std::make_shared(); + max_output->set_is_virtual(true); + } + //////////////// TODO ////////////////////////// + // Check Stride (Before setting dimension?) 
+ if (max_output->get_dim().empty()) { + max_output->set_dim({b, h, s_q, 1}); + } + if (max_output->get_stride().empty()) { + max_output->set_stride({h * s_q, s_q, 1, 1}); + } + + auto max_attributes = Reduction_attributes().set_name("Max").set_mode(ReductionMode_t::MAX); + // If sink tensor is present, we also need to take a pointwise max with sink + if (attributes.inputs.find(Softmax_attributes::input_names::SINK) != attributes.inputs.end()) { + auto s_max = reduction(attributes.inputs[Softmax_attributes::input_names::P], max_attributes); + s_max->set_name("s_max"); + + auto sink_tensor = attributes.inputs[Softmax_attributes::input_names::SINK]; + auto sink_attributes = Pointwise_attributes().set_name("max_sink").set_mode(PointwiseMode_t::MAX); + pointwise(s_max, sink_tensor, sink_attributes, max_output); + } else { + // Special non-functional-style call. Needed because output already created and provided to user. + reduction(attributes.inputs[Softmax_attributes::input_names::P], max_attributes, max_output); + } + + auto sub_attributes = Pointwise_attributes().set_name("sub").set_mode(PointwiseMode_t::SUB); + auto const& sub_output = + pointwise(attributes.inputs[Softmax_attributes::input_names::P], max_output, sub_attributes); + sub_output->set_name("sub_M"); + + auto exp_attributes = Pointwise_attributes().set_name("exp").set_mode(PointwiseMode_t::EXP); + auto const& exp_output = pointwise(sub_output, exp_attributes); + exp_output->set_name("exp_sub_M"); + + auto sum_output = attributes.outputs[Softmax_attributes::output_names::Sum_exp]; + if (sum_output == nullptr) { + sum_output = std::make_shared(); + sum_output->set_is_virtual(true); + } + sum_output->set_name("SumExp"); + if (sum_output->get_dim().empty()) { + sum_output->set_dim({b, h, s_q, 1}); + } + if (sum_output->get_stride().empty()) { + sum_output->set_stride({h * s_q, s_q, 1, 1}); + } + auto sum_attributes = Reduction_attributes().set_name("sum").set_mode(ReductionMode_t::ADD); + // If sink tensor 
is present, also subtract it and take its exp + if (attributes.inputs.find(Softmax_attributes::input_names::SINK) != attributes.inputs.end()) { + auto sink_tensor = attributes.inputs[Softmax_attributes::input_names::SINK]; + auto sub_sink = pointwise(sink_tensor, max_output, sub_attributes); + sub_sink->set_name("sub_sink"); + + auto exp_sink = pointwise(sub_sink, exp_attributes); + exp_sink->set_name("exp_sink"); + + auto temp_sum = reduction(exp_output, sum_attributes); + temp_sum->set_name("SumExp_elements").set_dim({b, h, s_q, 1}).set_stride({h * s_q, s_q, 1, 1}); + + auto add_attributes = Pointwise_attributes().set_name("add_sink").set_mode(PointwiseMode_t::ADD); + pointwise(temp_sum, exp_sink, add_attributes, sum_output); + } else { + reduction(exp_output, sum_attributes, sum_output); + } + + // WAR when: + // - softmax stats in not requested + // - max and sum_exp are not requested + if (attributes.outputs[Softmax_attributes::output_names::Stats] == nullptr && + attributes.outputs[Softmax_attributes::output_names::Max] == nullptr && + attributes.outputs[Softmax_attributes::output_names::Sum_exp] == nullptr) { + auto softmax_stats = std::make_shared(); + softmax_stats->set_is_virtual(true); + attributes.outputs[Softmax_attributes::output_names::Stats] = softmax_stats; + } + + if (attributes.outputs.find(Softmax_attributes::output_names::Stats) != attributes.outputs.end() && + attributes.outputs[Softmax_attributes::output_names::Stats] != nullptr) { + auto log_attributes = Pointwise_attributes().set_name("log").set_mode(PointwiseMode_t::LOG); + auto const& log_output = pointwise(sum_output, log_attributes); + log_output->set_dim({b, h, s_q, 1}).set_stride({h * s_q, s_q, 1, 1}); + + auto add_attributes = Pointwise_attributes().set_name("add").set_mode(PointwiseMode_t::ADD); + // Special non-functional-style call. Needed because output already created and provided to user. 
+ pointwise( + max_output, log_output, add_attributes, attributes.outputs[Softmax_attributes::output_names::Stats]); + } + + auto div_attributes = Pointwise_attributes().set_name("div").set_mode(PointwiseMode_t::DIV); + // Special non-functional-style call. Needed because output already created and provided to user. + pointwise(exp_output, sum_output, div_attributes, attributes.outputs[Softmax_attributes::output_names::S]); + + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + } +#endif +}; + +inline void +INode::softmax(std::shared_ptr p, + Softmax_attributes attributes, + std::shared_ptr s, + std::shared_ptr stats, + std::shared_ptr max, + std::shared_ptr sum_exp) { + attributes.inputs[Softmax_attributes::input_names::P] = p; + attributes.outputs[Softmax_attributes::output_names::S] = s; + attributes.outputs[Softmax_attributes::output_names::Stats] = stats; + attributes.outputs[Softmax_attributes::output_names::Max] = max; + attributes.outputs[Softmax_attributes::output_names::Sum_exp] = sum_exp; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/node_interface.h b/third_party/cudnn-frontend/include/cudnn_frontend/node_interface.h new file mode 100644 index 00000000..7019d12c --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/node_interface.h @@ -0,0 +1,487 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include "../cudnn_frontend_Tensor.h" +#include "../cudnn_frontend_Operation.h" +#include "../cudnn_frontend_OperationGraph.h" +#include "../cudnn_frontend_ExecutionPlan.h" +#include "../cudnn_frontend_VariantPack.h" +#include "../cudnn_frontend_shim.h" + +#include "cudnn_interface.h" + +#include "graph_properties.h" + +namespace 
cudnn_frontend { + +namespace graph { + +class BatchNormNode; +class DBNNode; +class ConcatenateNode; +class MatmulNode; +class MatmulFP8Node; +class PointwiseNode; +class ReductionNode; +class ResampleNode; +class ReshapeNode; +class RngNode; +class SoftmaxNode; +class MoeGroupedMatmulNode; + +// Interface for all nodes to follow. +class INode { + public: + // A closed set of types that are allowed to be passed by value today + using pass_by_values_t = Tensor_attributes::pass_by_values_t; + + detail::Context context; + + protected: + // Will eventually be moved to Graph class + std::unordered_set> full_graph_outputs; + std::shared_ptr + output_tensor(std::string const& name) { + auto tensor = std::make_shared(); + tensor->set_name(name).set_is_virtual(true); + full_graph_outputs.insert(tensor); + return tensor; + } + + private: + virtual error_t + pre_validate_node() const { + return {error_code_t::OK, ""}; + }; + + virtual error_t + infer_properties_node() = 0; + + virtual error_t + expand_node() { + return {error_code_t::OK, ""}; + }; + + virtual error_t + post_validate_node() const { + return {error_code_t::OK, ""}; + }; + + virtual int64_t + get_fe_workspace_size_node() const { + return 0; + } + + virtual error_t + collect_pass_by_value_tensors_node(std::unordered_map&) const { + return {error_code_t::OK, ""}; + }; + + virtual error_t + collect_variant_pack_replacements_node( + std::unordered_map>&) const { + return {error_code_t::OK, ""}; + }; + + virtual error_t + create_cudnn_tensors_node( + std::unordered_map>& uid_to_backend_tensors, + int64_t& potential_uid, + std::unordered_set const& used_uids) const = 0; + + virtual error_t + collect_tensors_in_workspace_node( + std::unordered_map>>&, + int64_t&) const { + return {error_code_t::OK, ""}; + } + + protected: + // Type of each node. Nodes can either be a composite (value COMPOSITE) or + // one of the other primitive types. Primitives types are nothing but + // cudnn operations. 
+ enum class Type { + COMPOSITE, + BATCHNORM, + BATCHNORM_INFERENCE, + BN_FINALIZE, + CONVOLUTION, + DBN, + DBN_WEIGHT, + DLN, + DIN, + DGRAD, + DRMSNorm, + GENSTATS, + LAYERNORM, + INSTANCENORM, + MATMUL, + POINTWISE, + REDUCTION, + RESAMPLE, + RESHAPE, + RMSNORM, + RNG, + SLICE, + WGRAD, + PAGED_CACHE_LOAD, + BLOCK_SCALE_QUANTIZE, + BLOCK_SCALE_DEQUANTIZE, + CONCATENATE, + ADALAYERNORM, + DADALAYERNORM, + UNIFIED_SDPA, + MOE_GROUPED_MATMUL, + }; + Type tag; + + inline void + matmul(std::shared_ptr a, + std::shared_ptr b, + Matmul_attributes attributes, + std::shared_ptr c); + + void + matmul_fp8(std::shared_ptr a, + std::shared_ptr b, + std::shared_ptr descale_a, + std::shared_ptr descale_b, + std::shared_ptr scale_c, + Matmul_fp8_attributes attributes, + std::shared_ptr c, + std::shared_ptr amax_c); + + void + softmax(std::shared_ptr p, + Softmax_attributes attributes, + std::shared_ptr s, + std::shared_ptr stats, + std::shared_ptr max, + std::shared_ptr sum_exp); + + void + softmax(std::shared_ptr p, + Softmax_attributes attributes, + std::shared_ptr s, + std::shared_ptr m, + std::shared_ptr zinv); + + void + pointwise(std::shared_ptr a, + Pointwise_attributes attributes, + std::shared_ptr c); + + void + pointwise(std::shared_ptr a, + std::shared_ptr b, + Pointwise_attributes attributes, + std::shared_ptr c); + + void + reduction(std::shared_ptr a, + Reduction_attributes attributes, + std::shared_ptr c); + + void + rng(std::shared_ptr seed, + std::shared_ptr offset, + Rng_attributes attributes, + std::shared_ptr y); + + void + paged_cache_load(std::shared_ptr container, + std::shared_ptr seqLen, + std::shared_ptr pageTable, + PagedCacheLoad_attributes attributes, + std::shared_ptr yOut); + + void + block_scale_quantize(std::shared_ptr x, + Block_scale_quantize_attributes attributes, + std::shared_ptr y, + std::shared_ptr scale); + + void + block_scale_dequantize(std::shared_ptr x, + std::shared_ptr scale, + Block_scale_dequantize_attributes attributes, + 
std::shared_ptr y); + + void + concatenate(std::vector> x, + Concatenate_attributes attributes, + std::shared_ptr y); + + void + moe_grouped_matmul(std::shared_ptr token, + std::shared_ptr weight, + std::shared_ptr first_token_offset, + std::shared_ptr token_index, + std::shared_ptr token_ks, + Moe_grouped_matmul_attributes attributes, + std::shared_ptr output); + + error_t + validate_subtree() { + // pre validate to catch errors early + // Otherwise code reability decreases in expand_and_infer + CHECK_CUDNN_FRONTEND_ERROR(pre_validate_node()); + CHECK_CUDNN_FRONTEND_ERROR(infer_properties_node()); + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR(sub_node->validate_subtree()); + } + CHECK_CUDNN_FRONTEND_ERROR(post_validate_node()); + return {error_code_t::OK, ""}; + } + + error_t + expand_subtree() { + // Validate self + CHECK_CUDNN_FRONTEND_ERROR(pre_validate_node()); + CHECK_CUDNN_FRONTEND_ERROR(infer_properties_node()); + CHECK_CUDNN_FRONTEND_ERROR(expand_node()); + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR(sub_node->expand_subtree()); + } + CHECK_CUDNN_FRONTEND_ERROR(post_validate_node()); + return {error_code_t::OK, ""}; + } + + // Creates cudnn tensors for each node (and its sub nodes) + error_t + create_cudnn_tensors_subtree( + std::unordered_map>& uid_to_backend_tensors, + int64_t& potential_uid, + std::unordered_set const& used_uids) const { + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensors_node(uid_to_backend_tensors, potential_uid, used_uids)); + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR( + sub_node->create_cudnn_tensors_subtree(uid_to_backend_tensors, potential_uid, used_uids)); + } + return {error_code_t::OK, ""}; + } + + error_t + collect_pass_by_value_tensors_subtree( + std::unordered_map& tensor_to_pass_by_value) const { + CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_node(tensor_to_pass_by_value)); + for (auto const& sub_node : sub_nodes) { + 
CHECK_CUDNN_FRONTEND_ERROR(sub_node->collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); + } + return {error_code_t::OK, ""}; + } + + error_t + collect_tensors_in_workspace_subtree( + std::unordered_map>>& + worskspace_modifications, + int64_t& offset) const { + CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_node(worskspace_modifications, offset)); + offset = get_fe_workspace_size_node(); + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR( + sub_node->collect_tensors_in_workspace_subtree(worskspace_modifications, offset)); + offset += sub_node->get_fe_workspace_size_node(); + } + return {error_code_t::OK, ""}; + } + + error_t + collect_variant_pack_replacements_subtree( + std::unordered_map>& replacements) + const { + CHECK_CUDNN_FRONTEND_ERROR(collect_variant_pack_replacements_node(replacements)); + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR(sub_node->collect_variant_pack_replacements_subtree(replacements)); + } + return {error_code_t::OK, ""}; + } + + int64_t + get_fe_workspace_size_subtree() const { + int64_t fe_workspace_size = get_fe_workspace_size_node(); + for (auto const& sub_node : sub_nodes) { + fe_workspace_size += sub_node->get_fe_workspace_size_subtree(); + } + return fe_workspace_size; + } + + // Creates cudnn operation for each node (and its sub nodes) + // Only INode that map to a primitive cudnn operation need to specialize. + virtual error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operation, + std::vector>& backend_operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& uid_to_backend_tensors) const { + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR(sub_node->create_cudnn_operations( + uids_involved_in_operation, backend_operations, raw_operations, uid_to_backend_tensors)); + } + return {error_code_t::OK, ""}; + } + + // An implicitly topological-sorted vector of sub nodes. 
+ // The sorted order is a side effect of functional API. + std::vector> sub_nodes; + + public: + virtual Type + getType() = 0; + + virtual std::pair> + override_heuristics_query() const { + return {-1, {}}; + } + + std::shared_ptr matmul(std::shared_ptr, + std::shared_ptr, + Matmul_attributes); + + std::shared_ptr pointwise(std::shared_ptr, Pointwise_attributes); + std::shared_ptr pointwise(std::shared_ptr, + std::shared_ptr, + Pointwise_attributes); + std::shared_ptr pointwise(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Pointwise_attributes); + + std::shared_ptr reduction(std::shared_ptr, Reduction_attributes); + std::array, 2> resample(std::shared_ptr, Resample_attributes); + std::shared_ptr reshape(std::shared_ptr, Reshape_attributes); + + std::shared_ptr rng(std::shared_ptr, + std::shared_ptr, + Rng_attributes); + + INode(detail::Context const& context) : context(context) {} + + // Make sure each node implements a public serialize function +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const = 0; +#endif + + virtual size_t + key() { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + json j; + serialize(j); + return std::hash{}(j); +#else + return 1; +#endif + } + + virtual ~INode() = default; +}; + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB +[[maybe_unused]] static void +to_json(json& j, const INode& p) { + p.serialize(j); +} +#endif + +template +class NodeCRTP : public INode { + DerivedT& + self() { + return *static_cast(this); + } + DerivedT const& + self() const { + return *static_cast(this); + } + + error_t + collect_pass_by_value_tensors_node( + std::unordered_map& tensor_to_pass_by_value) const override final { + CHECK_CUDNN_FRONTEND_ERROR(self().attributes.fill_pass_by_value(tensor_to_pass_by_value)); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_tensors_node(std::unordered_map>& tensors, + int64_t& potential_uid, + std::unordered_set const& used_uids) const override { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Creating 
cudnn tensors for node named '" << self().attributes.name << "':"); + + if constexpr (std::is_same_v) { + for (auto const& tensor : self().attributes.inputs) { + if (tensor) { + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); + } + } + } else { + for (auto const& [name, tensor] : self().attributes.inputs) { + (void)name; + if (tensor) { + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); + } + } + } + + for (auto const& [name, tensor] : self().attributes.outputs) { + (void)name; + if (tensor) { + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); + } + } + + // Handle special case of BN where peer_stats is also an input + if constexpr (std::is_same_v || std::is_same_v) { + // Special case in BN where peer stats is also an input but is not present in inputs map + for (auto const& tensor : self().attributes.peer_stats) { + if (tensor) { + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); + } + } + } + + return {error_code_t::OK, ""}; + } + + protected: + using INode::INode; +}; + +#define CUDNN_FE_VALIDATE_TENSOR_(port, map_) \ + { \ + auto t = map_.find(port); \ + bool const has_t = (t != map_.end()) && (t->second != nullptr); \ + RETURN_CUDNN_FRONTEND_ERROR_IF( \ + !has_t, error_code_t::ATTRIBUTE_NOT_SET, std::string("Tensor ") + #port + " not set"); \ + } + +#define CUDNN_FE_VALIDATE_AND_ASSIGN_TENSOR_(tensor, port, map_) \ + auto tensor = map_.find(port); \ + { \ + bool const has_t = (tensor != map_.end()) && (tensor->second != nullptr); \ + RETURN_CUDNN_FRONTEND_ERROR_IF( \ + !has_t, error_code_t::ATTRIBUTE_NOT_SET, std::string("Tensor ") + #port + " not set"); \ + } + +#define CUDNN_FE_VALIDATE_INPUT_TENSOR(port) CUDNN_FE_VALIDATE_TENSOR_(port, attributes.inputs) + +#define CUDNN_FE_VALIDATE_OUTPUT_TENSOR(port) CUDNN_FE_VALIDATE_TENSOR_(port, attributes.outputs) + 
+#define CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(tensor, port) \ + CUDNN_FE_VALIDATE_AND_ASSIGN_TENSOR_(tensor, port, attributes.inputs) + +#define CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(tensor, port) \ + CUDNN_FE_VALIDATE_AND_ASSIGN_TENSOR_(tensor, port, attributes.outputs) + +} // namespace graph + +} // namespace cudnn_frontend diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/plans.h b/third_party/cudnn-frontend/include/cudnn_frontend/plans.h new file mode 100644 index 00000000..c30812f2 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/plans.h @@ -0,0 +1,694 @@ +#pragma once + +#include +#include +#include + +#include "../cudnn_frontend_EngineConfig.h" +#include "../cudnn_frontend_Logging.h" +#include "graph_helpers.h" + +#include "backend/execution_helpers.h" +#include "backend/plan_helpers.h" + +namespace cudnn_frontend { + +namespace detail { + +inline error_t +execute(cudnnHandle_t handle, + ExecutionPlan* plan, + std::vector& device_ptrs, + std::vector const& uids, + void* workspace_ptr, + std::vector const& override_uids, + std::vector> const& override_shapes, + std::vector> const& override_strides) { + // TODO: below line fails with MSVC. 
warning C4127: conditional expression is constant + // RETURN_CUDNN_FRONTEND_ERROR_IF(!plan, error_code_t::GRAPH_EXECUTION_FAILED, "No plan found to execute!!"); + CUDNN_FE_LOG_LABEL_ENDL("INFO: Executing " << plan->getTag() << "..."); + + backend_descriptor variant_pack_descriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(variant_pack_descriptor.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create variant pack's backend descriptor."); + + CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack( + variant_pack_descriptor, device_ptrs, uids, workspace_ptr, override_uids, override_shapes, override_strides)); + _CUDNN_CHECK_CUDNN_ERROR(execute(handle, plan->get_raw_desc(), variant_pack_descriptor.get_ptr())); + + CUDNN_FE_LOG_LABEL_ENDL("INFO: Executed " << plan->getTag() << "."); + + return {error_code_t::OK, ""}; +} + +inline error_t +execute(cudnnHandle_t handle, + ExecutionPlan* plan, + std::vector& device_ptrs, + std::vector const& uids, + void* workspace_ptr) { + // TODO: below line fails with MSVC. 
warning C4127: conditional expression is constant + // RETURN_CUDNN_FRONTEND_ERROR_IF(!plan, error_code_t::GRAPH_EXECUTION_FAILED, "No plan found to execute!!"); + CUDNN_FE_LOG_LABEL_ENDL("INFO: Executing " << plan->getTag() << "..."); + + backend_descriptor variant_pack_descriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(variant_pack_descriptor.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create variant pack's backend descriptor."); + + CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack(variant_pack_descriptor, device_ptrs, uids, workspace_ptr)); + _CUDNN_CHECK_CUDNN_ERROR(execute(handle, plan->get_raw_desc(), variant_pack_descriptor.get_ptr())); + + CUDNN_FE_LOG_LABEL_ENDL("INFO: Executed " << plan->getTag() << "."); + + return {error_code_t::OK, ""}; +} + +inline error_t +query_cudnn_heuristics_impl(std::shared_ptr const& operation_graph, + cudnn_frontend::EngineConfigList& configs, + std::vector const& modes, + int32_t sm_count, + std::shared_ptr device_properties = nullptr) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + operation_graph == nullptr, + error_code_t::HEURISTIC_QUERY_FAILED, + "Empty operation graph provided. Did you forget to call graph.build_operation_graph()?"); + + auto const& operation_graph_tag = operation_graph->getTag(); + CUDNN_FE_LOG_LABEL_ENDL("INFO: " << " Getting plan from heuristics for " << operation_graph_tag << " ..."); + + std::vector statuses; +#ifdef NV_CUDNN_DISABLE_EXCEPTION + // disable exception macro is defined. Calling build will not throw. + // Check status of desc and return error. 
+ statuses = cudnn_frontend::get_heuristics_list( + modes, *operation_graph, allowAllConfig, configs, true, sm_count, device_properties); +#else + // build() can throw + // wrap in try catch + try { + statuses = cudnn_frontend::get_heuristics_list( + modes, *operation_graph, allowAllConfig, configs, true, sm_count, device_properties); + } catch (cudnn_frontend::cudnnException& e) { + // Silly MSVC error that thinks below condition is constexpr + // RETURN_CUDNN_FRONTEND_ERROR_IF( + // e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::HEURISTIC_QUERY_FAILED, e.what()); + CUDNN_FE_LOG_LABEL("ERROR: " << e.what() << ". "); + CUDNN_FE_LOG(error_code_t::HEURISTIC_QUERY_FAILED << " because querying heuristics failed at " << __FILE__ + << ":" << __LINE__ << "\n"); + return {error_code_t::HEURISTIC_QUERY_FAILED, e.what()}; + } +#endif + + CUDNN_FE_LOG_LABEL("INFO: get_heuristics_list statuses: "); + for (size_t i = 0; i < statuses.size(); i++) { + CUDNN_FE_LOG(cudnn_frontend::to_string(statuses[i]) << " "); + } + CUDNN_FE_LOG(std::endl); + + CUDNN_FE_LOG_LABEL_ENDL("INFO: config list has " << configs.size() << " configurations."); + + if (configs.empty()) { + std::string err_msg = detail::get_last_error_string_(); + CUDNN_FE_LOG_LABEL_ENDL("ERROR: No valid engine configs returned from heuristics.\n" << err_msg); + return {error_code_t::HEURISTIC_QUERY_FAILED, + "No valid engine configs for " + operation_graph_tag + "\n" + err_msg}; + } + return {error_code_t::OK, ""}; +} + +inline error_t +create_cudnn_execution_plan(std::shared_ptr& plan, + std::string const& serialized_data, + cudnnHandle_t handle) { + auto&& plan_builder = cudnn_frontend::ExecutionPlanBuilder(); + + plan_builder.setHandle(handle); + +#ifdef NV_CUDNN_DISABLE_EXCEPTION + // disable exception macro is defined. Calling build will not throw. + // Check status of desc and return error. 
+ auto built_plan = plan_builder.loadFromJson(serialized_data); + RETURN_CUDNN_FRONTEND_ERROR_IF(built_plan.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + built_plan.get_error()); + plan = std::make_shared(std::move(built_plan)); +#else + // build() can throw + // wrap in try catch + try { + auto built_plan = plan_builder.loadFromJson(serialized_data); + plan = std::make_shared(std::move(built_plan)); + } catch (cudnn_frontend::cudnnException& e) { + // Silly MSVC error that thinks below condition is constexpr + // RETURN_CUDNN_FRONTEND_ERROR_IF( + // e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + // e.what()); + CUDNN_FE_LOG_LABEL(" ERROR: " << e.what() << ". "); + CUDNN_FE_LOG(error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED << " because plan building failed at " + << __FILE__ << ":" << __LINE__ << "\n"); + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, e.what()}; + } +#endif + + return {error_code_t::OK, ""}; +} + +inline error_t +create_cudnn_execution_plan(std::shared_ptr& plan, + ManagedOpaqueDescriptor const& config, + std::string const& operation_graph_tag, + std::shared_ptr kernel_cache) { + auto&& plan_builder = cudnn_frontend::ExecutionPlanBuilder(); + + plan_builder.setEngineConfig(config, operation_graph_tag).setKernelCache(kernel_cache); + +#ifdef NV_CUDNN_DISABLE_EXCEPTION + // disable exception macro is defined. Calling build will not throw. + // Check status of desc and return error. 
+ auto built_plan = plan_builder.build(); + RETURN_CUDNN_FRONTEND_ERROR_IF(built_plan.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + built_plan.get_error()); + plan = std::make_shared(std::move(built_plan)); +#else + // build() can throw + // wrap in try catch + try { + auto built_plan = plan_builder.build(); + plan = std::make_shared(std::move(built_plan)); + } catch (cudnn_frontend::cudnnException& e) { + // Silly MSVC error that thinks below condition is constexpr + // RETURN_CUDNN_FRONTEND_ERROR_IF( + // e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + // e.what()); + CUDNN_FE_LOG_LABEL("ERROR: " << e.what() << ". "); + CUDNN_FE_LOG(error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED << " because plan building failed at " + << __FILE__ << ":" << __LINE__ << "\n"); + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, e.what()}; + } +#endif + + return {error_code_t::OK, ""}; +} + +} // namespace detail + +namespace graph { +class Execution_plan_list { + std::string operation_tag; + + std::vector barred_indices; + std::shared_ptr kernel_cache = nullptr; + + int64_t max_workspace_allowed = std::numeric_limits::max(); + int64_t max_shared_mem_allowed = 1024 * 1024 * 1024; // Crazy high number (2GB) which will never be hit + + std::vector barred_engine_names = {}; + EngineConfigList engine_configs; + + error_t + _build_plan_at_index_impl(int64_t index) { + if (execution_plans[index] == nullptr) { + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_execution_plan( + execution_plans[index], engine_configs[index], operation_tag, kernel_cache)); + } + + auto is_blocked = [](std::string const& full_name, std::vector const& blocked_names) -> bool { + for (auto const& blocked_name : blocked_names) { + if (full_name.find(blocked_name) != std::string::npos) { + return true; + } + } + return false; + }; + auto const& plan_tag = execution_plans[index]->getTag(); + if 
(is_blocked(plan_tag, barred_engine_names)) { + barred_indices[index] = true; + + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Deselecting execution plan with name " + plan_tag + " at position " + + std::to_string(index)}; + } + + // workspace check for 9.2+ is already done at engine config level + if (detail::get_backend_version() < 90200 || detail::get_compiled_version() < 90200) { + if (execution_plans[index]->getWorkspaceSize() > max_workspace_allowed) { + barred_indices[index] = true; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Workspace size is too large."}; + } + } + + // Sets candidate in case user does not call execute with plan_index later. + candidate = index; + + return {error_code_t::OK, ""}; + } + + public: + std::vector> numeric_notes; + std::vector> behavior_notes; + + std::vector> + execution_plans; // a built plan corresponding to each engine config, irrespective of whether config is + // selected or deselected. 
+ + // Stores position of best plan in above vector of execution plan + int64_t candidate = -1; + + void + set_tag(std::string const& tag) { + operation_tag = tag; + } + void + enqueue_engine_configs(EngineConfigList list) { + std::move(list.begin(), list.end(), back_inserter(engine_configs)); + } + void + set_kernel_cache(std::shared_ptr kernel_cache_) { + kernel_cache = kernel_cache_; + } + + std::vector>& + get_execution_plans() { + return execution_plans; + } + + error_t + query_properties() { + numeric_notes.reserve(engine_configs.size()); + behavior_notes.reserve(engine_configs.size()); + + barred_indices.resize(engine_configs.size(), 0); + execution_plans.resize(engine_configs.size()); + + for (auto& engine_config : engine_configs) { + int64_t elem_count = 0; + std::vector numeric; + std::vector behavior; + + ManagedOpaqueDescriptor extractedEngine = make_shared_backend_pointer(CUDNN_BACKEND_ENGINE_DESCRIPTOR); + cudnnBackendDescriptor_t extractedEngine_ = extractedEngine->get_backend_descriptor(); + auto status = detail::get_attribute(engine_config->get_backend_descriptor(), + CUDNN_ATTR_ENGINECFG_ENGINE, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &elem_count, + &extractedEngine_); + RETURN_CUDNN_FRONTEND_ERROR_IF((status != CUDNN_STATUS_SUCCESS), + error_code_t::HEURISTIC_QUERY_FAILED, + "Heuristic query Engine failed."); + + status = detail::get_attribute(extractedEngine_, + CUDNN_ATTR_ENGINE_NUMERICAL_NOTE, + CUDNN_TYPE_NUMERICAL_NOTE, + CUDNN_NUMERICAL_NOTE_TYPE_COUNT, + &elem_count, + nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF((status != CUDNN_STATUS_SUCCESS), + error_code_t::HEURISTIC_QUERY_FAILED, + "Heuristic query Numerical Note failed"); + + numeric.resize(static_cast(elem_count)); + status = detail::get_attribute(extractedEngine_, + CUDNN_ATTR_ENGINE_NUMERICAL_NOTE, + CUDNN_TYPE_NUMERICAL_NOTE, + CUDNN_NUMERICAL_NOTE_TYPE_COUNT, + &elem_count, + numeric.data()); + RETURN_CUDNN_FRONTEND_ERROR_IF((status != CUDNN_STATUS_SUCCESS), + 
error_code_t::HEURISTIC_QUERY_FAILED, + "Heuristic query Numerical Note failed"); + status = detail::get_attribute(extractedEngine_, + CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE, + CUDNN_TYPE_BEHAVIOR_NOTE, + CUDNN_BEHAVIOR_NOTE_TYPE_COUNT, + &elem_count, + nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF((status != CUDNN_STATUS_SUCCESS), + error_code_t::HEURISTIC_QUERY_FAILED, + "Heuristic query Behavior Note failed"); + + behavior.resize(static_cast(elem_count)); + status = detail::get_attribute(extractedEngine_, + CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE, + CUDNN_TYPE_BEHAVIOR_NOTE, + CUDNN_BEHAVIOR_NOTE_TYPE_COUNT, + &elem_count, + behavior.data()); + RETURN_CUDNN_FRONTEND_ERROR_IF((status != CUDNN_STATUS_SUCCESS), + error_code_t::HEURISTIC_QUERY_FAILED, + "Heuristic query Behavior Note failed"); + + std::vector numerics; + numerics.resize(numeric.size()); + for (auto& note : numeric) { + numerics.push_back(detail::convert_from_cudnn_type(note)); + } + numeric_notes.emplace_back(std::move(numerics)); + + std::vector behaviors; + behaviors.reserve(behaviors.size()); + for (auto& note : behavior) { + behaviors.push_back(detail::convert_from_cudnn_type(note)); + } + behavior_notes.emplace_back(std::move(behaviors)); + } + return {error_code_t::OK, ""}; + } + + error_t + filter_numeric_notes(std::vector const& notes, bool const keep) { + for (auto& note : notes) { + for (auto i = 0u; i < engine_configs.size(); i++) { + bool has_barred_note = + std::find(numeric_notes[i].begin(), numeric_notes[i].end(), note) != numeric_notes[i].end(); + + barred_indices[i] = barred_indices[i] || (has_barred_note ? 
!keep : keep); + } + } + return {error_code_t::OK, ""}; + } + + error_t + filter_behavior_notes(std::vector const& notes, bool const keep) { + for (auto& note : notes) { + for (auto i = 0u; i < engine_configs.size(); i++) { + bool has_barred_note = + std::find(behavior_notes[i].begin(), behavior_notes[i].end(), note) != behavior_notes[i].end(); + + barred_indices[i] = barred_indices[i] || (has_barred_note ? !keep : keep); + } + } + return {error_code_t::OK, ""}; + } + + void + set_max_workspace_allowed(int64_t const workspace_allowed) { + max_workspace_allowed = workspace_allowed; + } + + void + set_max_shared_mem_allowed(int64_t const smem_allowed) { + max_shared_mem_allowed = smem_allowed; + } + + void + set_barred_names(std::vector const& engine_names) { + barred_engine_names = engine_names; + } + + EngineConfigList + get_barred_engine_configs() { + EngineConfigList barred_engine_configs; + CUDNN_FE_LOG_LABEL_ENDL("INFO: " << " Filtering engine_configs ..." << engine_configs.size()); + for (auto i = 0u; i < engine_configs.size(); i++) { + if (barred_indices[i] == false) { + barred_engine_configs.push_back(engine_configs[i]); + } + } + CUDNN_FE_LOG_LABEL_ENDL("INFO: " << " barred engine_configs ..." << barred_engine_configs.size()); + return barred_engine_configs; + } + + error_t + get_name_at_index(int64_t index, std::string& name) const { + name = detail::get_engine_tag(engine_configs[index]); + return {error_code_t::OK, ""}; + } + + error_t + check_support_at_index(int64_t index) { + // Ignore if the engine config was deselected. + // This usually happens when user deselects by numerical and behavioural notes. 
+ + RETURN_CUDNN_FRONTEND_ERROR_IF((index < 0) || (static_cast(barred_indices.size()) <= index), + error_code_t::GRAPH_EXECUTION_FAILED, + "Plan index " + std::to_string(index) + " is invalid."); + + if (barred_indices[index] == true) { + CUDNN_FE_LOG_LABEL_ENDL("Deselecting execution plan at position " << index); + } + + RETURN_CUDNN_FRONTEND_ERROR_IF(barred_indices[index] == true, + error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "Deselecting execution plan"); + + // Ignore if engine name was specified to be ignored by the user. + auto is_blocked = [](std::string const& full_name, std::vector const& blocked_names) -> bool { + for (auto const& blocked_name : blocked_names) { + if (full_name.find(blocked_name) != std::string::npos) { + return true; + } + } + return false; + }; + auto cfg_tag = detail::get_engine_tag(engine_configs[index]); + if (is_blocked(cfg_tag, barred_engine_names)) { + barred_indices[index] = true; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Deselecting execution plan with name " + cfg_tag + " at position " + + std::to_string(index)}; + } + + if (detail::get_backend_version() >= 90200 && detail::get_compiled_version() >= 90200) { + // Ignore kernels that require larger than tolerable shared memory. + int32_t shared_memory_size = INT32_MAX; + auto status = detail::get_shared_memory_size(engine_configs[index], shared_memory_size); + if (status.is_bad()) { + CUDNN_FE_LOG_LABEL_ENDL("WARN: Unknown Shared memory size, so not deselecting plan at position " + << index); + } else if (shared_memory_size > max_shared_mem_allowed) { + barred_indices[index] = true; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Skipping plan since shared memory violation. Requires " + + std::to_string(shared_memory_size)}; + } + + // Filter by workspace can happen at this engine config stage itself. 
+ int64_t workspace_size = INT64_MAX; + CHECK_CUDNN_FRONTEND_ERROR(detail::get_workspace_size(engine_configs[index], workspace_size)); + if (workspace_size > max_workspace_allowed) { + barred_indices[index] = true; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Skipping plan since workspace violation. Requires " + + std::to_string(workspace_size)}; + } + } + // Else we need to build the config. A successful execution plan build means that check_support succeeded. + else { + CHECK_CUDNN_FRONTEND_ERROR(_build_plan_at_index_impl(index)); + } + + CUDNN_FE_LOG_LABEL_ENDL("Check support for index " << index << " passed with cfg " << cfg_tag); + // All checks passed for this config, so return success. + return {error_code_t::OK, ""}; + } + + error_t + check_support() { + // Go over each engine config and return true when you find the first one that is supported. + for (auto i = 0u; i < engine_configs.size(); i++) { + auto status = check_support_at_index(i); + if (status.is_good()) { + return {error_code_t::OK, ""}; + } + } + + std::string err_msg = detail::get_last_error_string_(); + CUDNN_FE_LOG_LABEL_ENDL("ERROR: No valid engine configs returned from heuristics.\n" << err_msg); + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: No execution plans support the graph." 
+ err_msg}; + } + + error_t + get_behavior_notes_at_index(int64_t const index, std::vector& notes) const { + RETURN_CUDNN_FRONTEND_ERROR_IF((index < 0) || (static_cast(behavior_notes.size()) <= index), + error_code_t::GRAPH_EXECUTION_FAILED, + "Plan index " + std::to_string(index) + " is invalid."); + + notes = behavior_notes[index]; + + return {error_code_t::OK, ""}; + } + + error_t + build_plans(cudnnHandle_t handle, std::string const& json) { + execution_plans.resize(1); + auto const& fe_status = detail::create_cudnn_execution_plan(execution_plans[0], json, handle); + + if (fe_status.is_good()) { + candidate = 0; + } + + return fe_status; + } + + error_t + build_plan_at_index(int64_t index) { + CHECK_CUDNN_FRONTEND_ERROR(check_support_at_index(index)); + CHECK_CUDNN_FRONTEND_ERROR(_build_plan_at_index_impl(index)); + + return {error_code_t::OK, ""}; + } + + error_t + build_plans(BuildPlanPolicy_t const policy, bool const do_multithreaded_builds) { + RETURN_CUDNN_FRONTEND_ERROR_IF(do_multithreaded_builds, + error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "Doing multithreaded builds is not yet supported."); + + // short circuit in case a plan was already created. + // This happens as check_support for v8 builds a plan. + if (policy == BuildPlanPolicy_t::HEURISTICS_CHOICE && candidate != -1) { + return {error_code_t::OK, ""}; + } + + for (auto i = 0u; i < engine_configs.size(); i++) { + auto status = build_plan_at_index(i); + if (status.is_bad()) { + CUDNN_FE_LOG_LABEL_ENDL("WARN: Failed to build plan at " << i); + continue; + } + + // Only set the candidate the first time, as the order of iteration is from highest to lowest priority + if (candidate == -1) { + candidate = static_cast(i); + CUDNN_FE_LOG_LABEL_ENDL("INFO: Candidate set as " << i); + } + + // Return from this function as first successfully built plan is found. 
+ if (policy == BuildPlanPolicy_t::HEURISTICS_CHOICE) { + return {error_code_t::OK, ""}; + } + } + + // Return an error if no execution plans could be built + RETURN_CUDNN_FRONTEND_ERROR_IF(candidate == -1, + error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: No valid execution plans built."); + + return {error_code_t::OK, ""}; + } + + int64_t + get_autotune_workspace() const { + int64_t max_size = 0; + for (auto& plan : execution_plans) { + max_size = std::max(max_size, plan->getWorkspaceSize()); + } + return max_size; + } + + static error_t + autotune_default_impl(std::vector>& execution_plans, + cudnnHandle_t handle, + std::unordered_map const& tensor_to_pointer_map, + void* workspace_ptr, + void*) { + // Create the variant pack for all the plans to use. + std::vector uids; + std::vector ptrs; + for (auto it : tensor_to_pointer_map) { + uids.push_back(it.first); + ptrs.push_back(it.second); + } + + std::vector> time_sorted_plans; + + auto plan_cmp = [](std::shared_ptr a, std::shared_ptr b) { + return a->getExecutionTime() < b->getExecutionTime(); + }; + + std::multiset, decltype(plan_cmp)> timed_execution_plans(plan_cmp); + + const int maxIterCount = 100; + const float threshhold = 0.95f; + uint64_t successful_plan_count = 0; + cudaEvent_t start, stop; + detail::cuda_event_create(&start); + detail::cuda_event_create(&stop); + detail::cuda_device_synchronize(); + + cudaStream_t stream = nullptr; + detail::get_stream(handle, &stream); + + for (auto plan : execution_plans) { + float time_ms = 0.0f; + float final_time_ms = 0.0f; + float min_time_ms = std::numeric_limits::max(); + + // Warm-up run + CHECK_CUDNN_FRONTEND_ERROR(detail::execute(handle, plan.get(), ptrs, uids, workspace_ptr)); + successful_plan_count++; + detail::cuda_device_synchronize(); + + for (int i = 0; i < maxIterCount; i++) { + detail::cuda_event_record(start, stream); + + auto status = detail::execute(handle, plan.get(), ptrs, uids, workspace_ptr); + + 
detail::cuda_event_record(stop, stream); + detail::cuda_event_synchronize(stop); + detail::cuda_event_elapsed_time(&time_ms, start, stop); + + final_time_ms = std::min(min_time_ms, time_ms); + if (time_ms / min_time_ms < threshhold) { + min_time_ms = final_time_ms; + } else { + break; + } + } + + CUDNN_FE_LOG_LABEL_ENDL("Plan " << plan->getTag() << " took " << std::setw(10) << final_time_ms); + plan->setExecutionTime(final_time_ms); + timed_execution_plans.insert(plan); + } + + execution_plans.clear(); + for (auto sorted_plan : timed_execution_plans) { + execution_plans.push_back(sorted_plan); + } + + detail::cuda_event_destroy(start); + detail::cuda_event_destroy(stop); + + CUDNN_FE_LOG_LABEL_ENDL("Autotuned " << successful_plan_count << " plans."); + return {error_code_t::OK, ""}; + } + + std::function>&, + cudnnHandle_t, + std::unordered_map const&, + void*, + void*)> + autotune_impl = &Execution_plan_list::autotune_default_impl; + + error_t + autotune(cudnnHandle_t handle, + std::unordered_map const& tensor_to_pointer_map, + void* workspace, + void* user_impl = nullptr) { + auto error = autotune_impl(execution_plans, handle, tensor_to_pointer_map, workspace, user_impl); + return error; + } + + error_t + is_plan_index_executable(int64_t const index) const { + RETURN_CUDNN_FRONTEND_ERROR_IF((index < 0) || (static_cast(execution_plans.size()) <= index), + error_code_t::GRAPH_EXECUTION_FAILED, + "Plan index " + std::to_string(index) + " is invalid."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(execution_plans[index] == nullptr, + error_code_t::GRAPH_EXECUTION_FAILED, + "Plan index " + std::to_string(index) + " did not build."); + + return {error_code_t::OK, ""}; + } +}; + +} // namespace graph +} // namespace cudnn_frontend diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/thirdparty/nlohmann/LICENSE.MIT b/third_party/cudnn-frontend/include/cudnn_frontend/thirdparty/nlohmann/LICENSE.MIT new file mode 100644 index 00000000..1c1f7a69 --- /dev/null +++ 
b/third_party/cudnn-frontend/include/cudnn_frontend/thirdparty/nlohmann/LICENSE.MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2013-2022 Niels Lohmann + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/third_party/cudnn-frontend/include/cudnn_frontend/thirdparty/nlohmann/json.hpp b/third_party/cudnn-frontend/include/cudnn_frontend/thirdparty/nlohmann/json.hpp new file mode 100644 index 00000000..85b4cdd3 --- /dev/null +++ b/third_party/cudnn-frontend/include/cudnn_frontend/thirdparty/nlohmann/json.hpp @@ -0,0 +1,26710 @@ +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +/****************************************************************************\ + * Note on documentation: The source files contain links to the online * + * documentation of the public API at https://json.nlohmann.me. This URL * + * contains the most recent documentation and should also be applicable to * + * previous versions; documentation for deprecated functions is not * + * removed, but marked deprecated. See "Generate documentation" section in * + * file docs/README.md. 
* +\****************************************************************************/ + +#ifndef INCLUDE_NLOHMANN_JSON_HPP_ +#define INCLUDE_NLOHMANN_JSON_HPP_ + +#include // all_of, find, for_each +#include // nullptr_t, ptrdiff_t, size_t +#include // hash, less +#include // initializer_list +#ifndef JSON_NO_IO +#include // istream, ostream +#endif // JSON_NO_IO +#include // random_access_iterator_tag +#include // unique_ptr +#include // string, stoi, to_string +#include // declval, forward, move, pair, swap +#include // vector + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +#include + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +// This file contains all macro definitions affecting or depending on the ABI + +#ifndef JSON_SKIP_LIBRARY_VERSION_CHECK +#if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH) +#if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 3 +#warning "Already included a different version of the library!" 
+#endif +#endif +#endif + +#define NLOHMANN_JSON_VERSION_MAJOR 3 // NOLINT(modernize-macro-to-enum) +#define NLOHMANN_JSON_VERSION_MINOR 11 // NOLINT(modernize-macro-to-enum) +#define NLOHMANN_JSON_VERSION_PATCH 3 // NOLINT(modernize-macro-to-enum) + +#ifndef JSON_DIAGNOSTICS +#define JSON_DIAGNOSTICS 0 +#endif + +#ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON +#define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0 +#endif + +#if JSON_DIAGNOSTICS +#define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag +#else +#define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS +#endif + +#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON +#define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp +#else +#define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION +#define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0 +#endif + +// Construct the namespace ABI tags component +#define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi##a##b +#define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) + +#define NLOHMANN_JSON_ABI_TAGS \ + NLOHMANN_JSON_ABI_TAGS_CONCAT(NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS, \ + NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON) + +// Construct the namespace version component +#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) _v##major##_##minor##_##patch +#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \ + NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) + +#if NLOHMANN_JSON_NAMESPACE_NO_VERSION +#define NLOHMANN_JSON_NAMESPACE_VERSION +#else +#define NLOHMANN_JSON_NAMESPACE_VERSION \ + NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT( \ + NLOHMANN_JSON_VERSION_MAJOR, NLOHMANN_JSON_VERSION_MINOR, NLOHMANN_JSON_VERSION_PATCH) +#endif + +// Combine namespace components +#define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a##b +#define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) + +#ifndef NLOHMANN_JSON_NAMESPACE 
+#define NLOHMANN_JSON_NAMESPACE \ + nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT(NLOHMANN_JSON_ABI_TAGS, NLOHMANN_JSON_NAMESPACE_VERSION) +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN \ + namespace nlohmann { \ + inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT(NLOHMANN_JSON_ABI_TAGS, NLOHMANN_JSON_NAMESPACE_VERSION) { +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END \ + } /* namespace (inline namespace) NOLINT(readability/namespace) */ \ + } // namespace nlohmann +#endif + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +#include // transform +#include // array +#include // forward_list +#include // inserter, front_inserter, end +#include // map +#include // string +#include // tuple, make_tuple +#include // is_arithmetic, is_same, is_enum, underlying_type, is_convertible +#include // unordered_map +#include // pair, declval +#include // valarray + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +#include // nullptr_t +#include // exception +#if JSON_DIAGNOSTICS +#include // accumulate +#endif +#include // runtime_error +#include // to_string +#include // vector + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +#include // array +#include // size_t +#include // uint8_t +#include // string + +// #include +// __ _____ _____ _____ +// __| 
| __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +#include // declval, pair +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +#include + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + +// #include + +NLOHMANN_JSON_NAMESPACE_BEGIN +namespace detail { + +template +struct make_void { + using type = void; +}; +template +using void_t = typename make_void::type; + +} // namespace detail +NLOHMANN_JSON_NAMESPACE_END + +NLOHMANN_JSON_NAMESPACE_BEGIN +namespace detail { + +// https://en.cppreference.com/w/cpp/experimental/is_detected +struct nonesuch { + nonesuch() = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + nonesuch(nonesuch const&&) = delete; + void + operator=(nonesuch const&) = delete; + void + operator=(nonesuch&&) = delete; +}; + +template class Op, class... Args> +struct detector { + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> { + using value_t = std::true_type; + using type = Op; +}; + +template