From 6da9e02d1bbdcfae7bb4a0d1abde457139598c7a Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Thu, 4 Sep 2025 14:40:14 +0000 Subject: [PATCH 1/5] add denormal flags Signed-off-by: Fabian Mora --- compiler/plugins/target/ROCM/ROCMTarget.cpp | 20 ++++- compiler/plugins/target/ROCM/test/BUILD.bazel | 1 + .../plugins/target/ROCM/test/CMakeLists.txt | 1 + .../plugins/target/ROCM/test/func_attrs.mlir | 44 +++++++++++ .../Dialect/Codegen/IR/IREECodegenAttrs.td | 27 +++++++ .../ROCDLAnnotateKernelForTranslation.cpp | 6 ++ .../annotate_kernel_for_translation.mlir | 76 +++++++++++++++++++ .../iree/compiler/Codegen/Utils/GPUUtils.cpp | 26 +++++++ .../iree/compiler/Codegen/Utils/GPUUtils.h | 9 +++ 9 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 compiler/plugins/target/ROCM/test/func_attrs.mlir diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 0827d4e28eda..758d39ac0f61 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -84,9 +84,10 @@ struct ROCMOptions { std::string encodingLayoutResolver = GPU::kNoEncodingLayoutResolverName; bool slpVectorization = true; bool globalISel = false; - bool specializeDispatches = false; bool enableTensorUKernels = false; + IREE::Codegen::DenormalFpMath denormalFpMathF32 = + IREE::Codegen::DenormalFpMath::None; void bindOptions(OptionsBinder &binder) { using namespace llvm; @@ -161,6 +162,19 @@ struct ROCMOptions { binder.opt("iree-hip-enable-tensor-ukernels", enableTensorUKernels, cl::cat(category), cl::desc("Enable MLIR-based ukernels.")); + binder.opt( + "iree-hip-denormal-fp-math-f32", denormalFpMathF32, cl::cat(category), + cl::desc("Denormal floating point math mode for f32"), + cl::values( + clEnumValN(IREE::Codegen::DenormalFpMath::IEEE, "ieee", + "Use IEEE 754-2008 denormal behavior"), + clEnumValN(IREE::Codegen::DenormalFpMath::PreserveSign, + "preserve-sign", + "Convert denormals to zero while preserving sign"), + clEnumValN(IREE::Codegen::DenormalFpMath::PositiveZero, + "positive-zero", "Convert denormals to positive zero"), + clEnumValN(IREE::Codegen::DenormalFpMath::Dynamic, "dynamic", + "Let runtime decide denormal behavior"))); } LogicalResult verify(mlir::Builder &builder) const { @@ -337,6 +351,10 @@ class ROCMTargetBackend final : public TargetBackend { if (options.wavesPerEu > 0) { addConfigWavesPerEu(b.getContext(), options.wavesPerEu, configItems); } + if (options.denormalFpMathF32 != IREE::Codegen::DenormalFpMath::None) { + addConfigDenormalFpMathF32(b.getContext(), options.denormalFpMathF32, + configItems); + } if (options.enableTensorUKernels) { addConfig(kUKernelProviderName, diff --git a/compiler/plugins/target/ROCM/test/BUILD.bazel b/compiler/plugins/target/ROCM/test/BUILD.bazel index e00154a12f74..c56a9654e948 100644 --- a/compiler/plugins/target/ROCM/test/BUILD.bazel +++ b/compiler/plugins/target/ROCM/test/BUILD.bazel @@ -19,6 +19,7 @@ iree_lit_test_suite( "config_ukernel_data_tiled_mma_gfx942.mlir", "default_tuning_specs_amdgpu.mlir", "enable_tensor_ukernels.mlir", + "func_attrs.mlir", "gpu_encoding_attrs.mlir", "lower_rocm_ukernel_descriptor.mlir", "lowering_strategy_from_tuning_spec.mlir", diff --git a/compiler/plugins/target/ROCM/test/CMakeLists.txt b/compiler/plugins/target/ROCM/test/CMakeLists.txt index cfe872b9f57d..874aabd303c4 100644 --- a/compiler/plugins/target/ROCM/test/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/test/CMakeLists.txt @@ -19,6 +19,7 @@ iree_lit_test_suite( "config_ukernel_data_tiled_mma_gfx942.mlir" "default_tuning_specs_amdgpu.mlir" "enable_tensor_ukernels.mlir" + "func_attrs.mlir" "gpu_encoding_attrs.mlir" "lower_rocm_ukernel_descriptor.mlir" "lowering_strategy_from_tuning_spec.mlir" diff --git a/compiler/plugins/target/ROCM/test/func_attrs.mlir b/compiler/plugins/target/ROCM/test/func_attrs.mlir new file mode 100644 index 000000000000..23a69947acc4 --- /dev/null +++ b/compiler/plugins/target/ROCM/test/func_attrs.mlir @@ -0,0 +1,44 @@ +// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ +// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a %s | FileCheck --check-prefix=CHECK-NONE %s + +// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ +// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a \ +// RUN: --iree-hip-denormal-fp-math-f32=ieee --iree-hip-waves-per-eu=1 %s | FileCheck --check-prefix=CHECK-IEEE %s + +// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ +// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a \ +// RUN: --iree-hip-denormal-fp-math-f32=positive-zero --iree-hip-waves-per-eu=2 %s | FileCheck --check-prefix=CHECK-POSITIVE-ZERO %s + +// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ +// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a \ +// RUN: --iree-hip-denormal-fp-math-f32=preserve-sign --iree-hip-waves-per-eu=3 %s | FileCheck --check-prefix=CHECK-PRESERVE-SIGN %s + +// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ +// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a \ +// RUN: --iree-hip-denormal-fp-math-f32=dynamic --iree-hip-waves-per-eu=4 %s | FileCheck --check-prefix=CHECK-DYNAMIC %s + +module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} { + // CHECK-NONE-LABEL: #hal.executable.target< + // CHECK-NONE-NOT: denormal_fp_math_f32 + // CHECK-NONE-NOT: waves_per_eu + + // CHECK-IEEE-LABEL: #hal.executable.target< + // CHECK-IEEE-SAME: denormal_fp_math_f32 = #iree_codegen.denormal_fp_math + // CHECK-IEEE-SAME: waves_per_eu = 1 + + // CHECK-POSITIVE-ZERO-LABEL: #hal.executable.target< + // CHECK-POSITIVE-ZERO-SAME: denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"positive-zero"> + // CHECK-POSITIVE-ZERO-SAME: waves_per_eu = 2 + + // CHECK-PRESERVE-SIGN-LABEL: #hal.executable.target< + // CHECK-PRESERVE-SIGN-SAME: denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign"> + // CHECK-PRESERVE-SIGN-SAME: waves_per_eu = 3 + + // CHECK-DYNAMIC-LABEL: #hal.executable.target< + // CHECK-DYNAMIC-SAME: denormal_fp_math_f32 = #iree_codegen.denormal_fp_math + // CHECK-DYNAMIC-SAME: waves_per_eu = 4 + util.global private @__device_0 = #hal.device.alias<"hip"> : !hal.device + util.func public @softmax_static_10x256x256xf32() { + util.return + } +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td index 17e44e5afe3e..53407cd26c77 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td @@ -610,4 +610,31 @@ def IREECodegen_XORShuffleAttr : let genVerifyDecl = 1; } +//===---------------------------------------------------------------------===// +// iree_codegen.denormal_mode +//===---------------------------------------------------------------------===// + +def DenormalMode_None : I32EnumCase<"None", 0, "none">; +def DenormalMode_IEEE : I32EnumCase<"IEEE", 1, "ieee">; +def DenormalMode_PreserveSign : I32EnumCase<"PreserveSign", 2, "preserve-sign">; +def DenormalMode_PositiveZero : I32EnumCase<"PositiveZero", 3, "positive-zero">; +def DenormalMode_Dynamic : I32EnumCase<"Dynamic", 4, "dynamic">; + +def DenormalFpMathEnum : I32Enum< + "DenormalFpMath", + "Denormal mode for fp math", [ + DenormalMode_None, + DenormalMode_IEEE, + DenormalMode_PreserveSign, + DenormalMode_PositiveZero, + DenormalMode_Dynamic + ]> { + let cppNamespace = "::mlir::iree_compiler::IREE::Codegen"; +} + +def DenormalFpMathAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + #endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp index 35ea3db6fc10..c135f740bce9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp @@ -74,6 +74,12 @@ annotateKernelForTranslation(LLVM::LLVMFuncOp funcOp, getConfigWavesPerEuAttr(targetAttr.getConfiguration())) { rocdlDialect->getWavesPerEuAttrHelper().setAttr(funcOp, attr); } + if (IREE::Codegen::DenormalFpMathAttr attr = + getConfigDenormalFpMathF32Attr(targetAttr.getConfiguration()); + attr && attr.getValue() != IREE::Codegen::DenormalFpMath::None) { + funcOp.setDenormalFpMathF32( + IREE::Codegen::stringifyDenormalFpMath(attr.getValue())); + } // Kernel argument preloading is only supported on gfx942 and newer targets // from the CDNA family. This is enabled using the `inreg` function argument diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir index 87909ba2c86c..d279d1bb3087 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir @@ -114,3 +114,79 @@ builtin.module { // CHECK-LABEL: llvm.func @test_no_kern_arg // CHECK-SAME: (%{{.+}}: i32) + +// ----- + +// Check that denormal and waves_per_eu are set + +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">, + iree_codegen.target_info = #iree_gpu.target>, + ukernels = "none", + waves_per_eu = 2 : i64 + }> +#pipeline_layout = #hal.pipeline.layout], + flags = Indirect> +builtin.module { + hal.executable public @test_rocdl_attrs { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) { + hal.executable.export public @test_rocdl_attrs ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) { + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]} + builtin.module { + // CHECK-LABEL: llvm.func @test_rocdl_attrs + // CHECK: denormal_fp_math_f32 = "preserve-sign" + // CHECK: rocdl.waves_per_eu = 2 : i64 + llvm.func @test_rocdl_attrs(%arg0: i32) { + llvm.return + } + } + } + } +} + +// ----- + +// Check that denormal and waves_per_eu are set + +#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + denormal_fp_math_f32 = #iree_codegen.denormal_fp_math, + iree_codegen.target_info = #iree_gpu.target>, + ukernels = "none", + waves_per_eu = 1 : i64 + }> +#pipeline_layout = #hal.pipeline.layout], + flags = Indirect> +builtin.module { + hal.executable public @test_rocdl_attrs { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) { + hal.executable.export public @test_rocdl_attrs ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) { + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]} + builtin.module { + // CHECK-LABEL: llvm.func @test_rocdl_attrs + // CHECK: denormal_fp_math_f32 = "ieee" + // CHECK: rocdl.waves_per_eu = 1 : i64 + llvm.func @test_rocdl_attrs(%arg0: i32) { + llvm.return + } + } + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index d542d065db34..9171251ad543 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -35,6 +35,7 @@ #define DBGSNL() (llvm::dbgs() << "\n") constexpr unsigned kShuffleBitWidth = 32; +constexpr char kDenormalFpMathF32AttrName[] = "denormal_fp_math_f32"; // TODO: These are AMD GPU specific. These need to find a better home. constexpr char kWavesPerEuAttrName[] = "waves_per_eu"; @@ -1043,6 +1044,31 @@ void addConfigWavesPerEu(MLIRContext *context, int64_t wavesPerEu, IntegerAttr::get(IntegerType::get(context, 64), wavesPerEu)); } +IREE::Codegen::DenormalFpMathAttr +getConfigDenormalFpMathF32Attr(DictionaryAttr targetConfig) { + if (!targetConfig) { + return {}; + } + + return targetConfig.getAs( + kDenormalFpMathF32AttrName); +} +std::optional +getConfigDenormalFpMathF32(DictionaryAttr targetConfig) { + IREE::Codegen::DenormalFpMathAttr attr = + getConfigDenormalFpMathF32Attr(targetConfig); + if (!attr) { + return std::nullopt; + } + return attr.getValue(); +} +void addConfigDenormalFpMathF32(MLIRContext *context, + IREE::Codegen::DenormalFpMath mode, + SmallVectorImpl &config) { + config.emplace_back(StringAttr::get(context, kDenormalFpMathF32AttrName), + IREE::Codegen::DenormalFpMathAttr::get(context, mode)); +} + std::optional getGPUSubgroupSize(mlir::FunctionOpInterface func) { // First try to see if there is a subgroup size chosen in the CodeGen pipeline // configuration. diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h index 5d7df379fc5f..e41b432fe7f9 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_ #define IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_ +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" @@ -207,11 +208,19 @@ IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op); std::optional getConfigWavesPerEu(DictionaryAttr targetAttr); IntegerAttr getConfigWavesPerEuAttr(DictionaryAttr targetAttr); +IREE::Codegen::DenormalFpMathAttr +getConfigDenormalFpMathF32Attr(DictionaryAttr targetConfig); +std::optional +getConfigDenormalFpMathF32(DictionaryAttr targetConfig); + /// Methods to add attributes to the `config` list. void addConfigGPUTarget(MLIRContext *context, IREE::GPU::TargetAttr, SmallVectorImpl &config); void addConfigWavesPerEu(MLIRContext *context, int64_t wavesPerEu, SmallVectorImpl &config); +void addConfigDenormalFpMathF32(MLIRContext *context, + IREE::Codegen::DenormalFpMath mode, + SmallVectorImpl &config); /// Returns the GPU subgroup size chosen for the current CodeGen pipeline if /// exists; otherwise returns the subgroup size from the GPU target description. From d62a5a0b9009675ada1afe9597cee01d0a96e89b Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Mon, 8 Sep 2025 12:11:27 +0000 Subject: [PATCH 2/5] address comments Signed-off-by: Fabian Mora --- compiler/plugins/target/ROCM/test/BUILD.bazel | 1 - .../plugins/target/ROCM/test/CMakeLists.txt | 1 - .../plugins/target/ROCM/test/func_attrs.mlir | 44 ------------------- 3 files changed, 46 deletions(-) delete mode 100644 compiler/plugins/target/ROCM/test/func_attrs.mlir diff --git a/compiler/plugins/target/ROCM/test/BUILD.bazel b/compiler/plugins/target/ROCM/test/BUILD.bazel index c56a9654e948..e00154a12f74 100644 --- a/compiler/plugins/target/ROCM/test/BUILD.bazel +++ b/compiler/plugins/target/ROCM/test/BUILD.bazel @@ -19,7 +19,6 @@ iree_lit_test_suite( "config_ukernel_data_tiled_mma_gfx942.mlir", "default_tuning_specs_amdgpu.mlir", "enable_tensor_ukernels.mlir", - "func_attrs.mlir", "gpu_encoding_attrs.mlir", "lower_rocm_ukernel_descriptor.mlir", "lowering_strategy_from_tuning_spec.mlir", diff --git a/compiler/plugins/target/ROCM/test/CMakeLists.txt b/compiler/plugins/target/ROCM/test/CMakeLists.txt index 874aabd303c4..cfe872b9f57d 100644 --- a/compiler/plugins/target/ROCM/test/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/test/CMakeLists.txt @@ -19,7 +19,6 @@ iree_lit_test_suite( "config_ukernel_data_tiled_mma_gfx942.mlir" "default_tuning_specs_amdgpu.mlir" "enable_tensor_ukernels.mlir" - "func_attrs.mlir" "gpu_encoding_attrs.mlir" "lower_rocm_ukernel_descriptor.mlir" "lowering_strategy_from_tuning_spec.mlir" diff --git a/compiler/plugins/target/ROCM/test/func_attrs.mlir b/compiler/plugins/target/ROCM/test/func_attrs.mlir deleted file mode 100644 index 23a69947acc4..000000000000 --- a/compiler/plugins/target/ROCM/test/func_attrs.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ -// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a %s | FileCheck --check-prefix=CHECK-NONE %s - -// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ -// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a \ -// RUN: --iree-hip-denormal-fp-math-f32=ieee --iree-hip-waves-per-eu=1 %s | FileCheck --check-prefix=CHECK-IEEE %s - -// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ -// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a \ -// RUN: --iree-hip-denormal-fp-math-f32=positive-zero --iree-hip-waves-per-eu=2 %s | FileCheck --check-prefix=CHECK-POSITIVE-ZERO %s - -// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ -// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a \ -// RUN: --iree-hip-denormal-fp-math-f32=preserve-sign --iree-hip-waves-per-eu=3 %s | FileCheck --check-prefix=CHECK-PRESERVE-SIGN %s - -// RUN: iree-opt --split-input-file --iree-hal-resolve-device-aliases \ -// RUN: --iree-hal-target-device=hip --iree-hip-target=gfx90a \ -// RUN: --iree-hip-denormal-fp-math-f32=dynamic --iree-hip-waves-per-eu=4 %s | FileCheck --check-prefix=CHECK-DYNAMIC %s - -module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} { - // CHECK-NONE-LABEL: #hal.executable.target< - // CHECK-NONE-NOT: denormal_fp_math_f32 - // CHECK-NONE-NOT: waves_per_eu - - // CHECK-IEEE-LABEL: #hal.executable.target< - // CHECK-IEEE-SAME: denormal_fp_math_f32 = #iree_codegen.denormal_fp_math - // CHECK-IEEE-SAME: waves_per_eu = 1 - - // CHECK-POSITIVE-ZERO-LABEL: #hal.executable.target< - // CHECK-POSITIVE-ZERO-SAME: denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"positive-zero"> - // CHECK-POSITIVE-ZERO-SAME: waves_per_eu = 2 - - // CHECK-PRESERVE-SIGN-LABEL: #hal.executable.target< - // CHECK-PRESERVE-SIGN-SAME: denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign"> - // CHECK-PRESERVE-SIGN-SAME: waves_per_eu = 3 - - // CHECK-DYNAMIC-LABEL: #hal.executable.target< - // CHECK-DYNAMIC-SAME: denormal_fp_math_f32 = #iree_codegen.denormal_fp_math - // CHECK-DYNAMIC-SAME: waves_per_eu = 4 - util.global private @__device_0 = #hal.device.alias<"hip"> : !hal.device - util.func public @softmax_static_10x256x256xf32() { - util.return - } -} From 453945f1f9f72036eff1d61de334cad1c019c429 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Mon, 8 Sep 2025 14:28:07 +0000 Subject: [PATCH 3/5] address comments Signed-off-by: Fabian Mora --- compiler/plugins/target/ROCM/ROCMTarget.cpp | 6 +----- .../Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td | 14 +++++++------- .../ROCDL/annotate_kernel_for_translation.mlir | 8 ++++---- .../src/iree/compiler/Codegen/Utils/GPUUtils.cpp | 2 +- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 758d39ac0f61..eb5c470efa62 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -166,15 +166,11 @@ struct ROCMOptions { "iree-hip-denormal-fp-math-f32", denormalFpMathF32, cl::cat(category), cl::desc("Denormal floating point math mode for f32"), cl::values( - clEnumValN(IREE::Codegen::DenormalFpMath::IEEE, "ieee", - "Use IEEE 754-2008 denormal behavior"), clEnumValN(IREE::Codegen::DenormalFpMath::PreserveSign, "preserve-sign", "Convert denormals to zero while preserving sign"), clEnumValN(IREE::Codegen::DenormalFpMath::PositiveZero, - "positive-zero", "Convert denormals to positive zero"), - clEnumValN(IREE::Codegen::DenormalFpMath::Dynamic, "dynamic", - "Let runtime decide denormal behavior"))); + "positive-zero", "Convert denormals to positive zero"))); } LogicalResult verify(mlir::Builder &builder) const { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td index 53407cd26c77..3d6450b90e7e 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td @@ -614,20 +614,20 @@ def IREECodegen_XORShuffleAttr : // iree_codegen.denormal_mode //===---------------------------------------------------------------------===// +// Do not add any denormal mode. def DenormalMode_None : I32EnumCase<"None", 0, "none">; -def DenormalMode_IEEE : I32EnumCase<"IEEE", 1, "ieee">; -def DenormalMode_PreserveSign : I32EnumCase<"PreserveSign", 2, "preserve-sign">; -def DenormalMode_PositiveZero : I32EnumCase<"PositiveZero", 3, "positive-zero">; -def DenormalMode_Dynamic : I32EnumCase<"Dynamic", 4, "dynamic">; +// Convert denormals to zero while preserving sign. +def DenormalMode_PreserveSign : I32EnumCase<"PreserveSign", 1, "preserve-sign">; +// Convert denormals to positive zero. +def DenormalMode_PositiveZero : I32EnumCase<"PositiveZero", 2, "positive-zero">; +// Denormal fp mode. def DenormalFpMathEnum : I32Enum< "DenormalFpMath", "Denormal mode for fp math", [ DenormalMode_None, - DenormalMode_IEEE, DenormalMode_PreserveSign, - DenormalMode_PositiveZero, - DenormalMode_Dynamic + DenormalMode_PositiveZero ]> { let cppNamespace = "::mlir::iree_compiler::IREE::Codegen"; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir index d279d1bb3087..09c69ff98b30 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir @@ -117,7 +117,7 @@ builtin.module { // ----- -// Check that denormal and waves_per_eu are set +// Check that denormal and waves_per_eu are set. #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", { denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">, @@ -155,10 +155,10 @@ builtin.module { // ----- -// Check that denormal and waves_per_eu are set +// Check that denormal and waves_per_eu are set. #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", { - denormal_fp_math_f32 = #iree_codegen.denormal_fp_math, + denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"positive-zero">, iree_codegen.target_info = #iree_gpu.target &config) { - config.emplace_back(StringAttr::get(context, kDenormalFpMathF32AttrName), + config.emplace_back(kDenormalFpMathF32AttrName, IREE::Codegen::DenormalFpMathAttr::get(context, mode)); } From 6f8a96136b9b98655b8e4cc1133c5bc2a4c1ef8c Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Mon, 8 Sep 2025 16:52:53 +0000 Subject: [PATCH 4/5] address comments Signed-off-by: Fabian Mora --- .../test/ROCDL/annotate_kernel_for_translation.mlir | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir index 09c69ff98b30..4d9c7a97cdc2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir @@ -117,7 +117,7 @@ builtin.module { // ----- -// Check that denormal and waves_per_eu are set. +// Check that denormal is set. #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", { denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">, @@ -129,8 +129,7 @@ builtin.module { max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, - ukernels = "none", - waves_per_eu = 2 : i64 + ukernels = "none" }> #pipeline_layout = #hal.pipeline.layout], flags = Indirect> @@ -144,7 +143,6 @@ builtin.module { builtin.module { // CHECK-LABEL: llvm.func @test_rocdl_attrs // CHECK: denormal_fp_math_f32 = "preserve-sign" - // CHECK: rocdl.waves_per_eu = 2 : i64 llvm.func @test_rocdl_attrs(%arg0: i32) { llvm.return } @@ -155,7 +153,7 @@ builtin.module { // ----- -// Check that denormal and waves_per_eu are set. +// Check that denormal is set. #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", { denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"positive-zero">, @@ -167,8 +165,7 @@ builtin.module { max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, - ukernels = "none", - waves_per_eu = 1 : i64 + ukernels = "none" }> #pipeline_layout = #hal.pipeline.layout], flags = Indirect> @@ -182,7 +179,6 @@ builtin.module { builtin.module { // CHECK-LABEL: llvm.func @test_rocdl_attrs // CHECK: denormal_fp_math_f32 = "positive-zero" - // CHECK: rocdl.waves_per_eu = 1 : i64 llvm.func @test_rocdl_attrs(%arg0: i32) { llvm.return } From 6b1629d7ec1034eee0c6646c5f616d50ed664371 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Wed, 10 Sep 2025 19:57:50 +0000 Subject: [PATCH 5/5] address reviewer comments Signed-off-by: Fabian Mora --- .../Dialect/Codegen/IR/IREECodegenAttrs.td | 7 +++++ .../annotate_kernel_for_translation.mlir | 4 +-- .../iree/compiler/Codegen/Utils/GPUUtils.cpp | 26 ------------------- .../iree/compiler/Codegen/Utils/GPUUtils.h | 8 ------ .../src/iree/compiler/Codegen/Utils/Utils.cpp | 25 ++++++++++++++++++ .../src/iree/compiler/Codegen/Utils/Utils.h | 16 ++++++++++++ 6 files changed, 50 insertions(+), 36 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td index 3d6450b90e7e..d70bd2d805ef 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td @@ -635,6 +635,13 @@ def DenormalFpMathEnum : I32Enum< def DenormalFpMathAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; + + let extraClassDeclaration = [{ + /// Returns the key name for fp32 DenormalFpMathAttr in a DictionaryAttr. + static StringRef getFP32DictKeyName() { + return "iree_codegen.denormal_fp_math_f32"; + } + }]; } #endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir index 4d9c7a97cdc2..cfa6e4907cda 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir @@ -120,7 +120,7 @@ builtin.module { // Check that denormal is set. #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", { - denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">, + iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">, iree_codegen.target_info = #iree_gpu.target, + iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"positive-zero">, iree_codegen.target_info = #iree_gpu.target( - kDenormalFpMathF32AttrName); -} -std::optional -getConfigDenormalFpMathF32(DictionaryAttr targetConfig) { - IREE::Codegen::DenormalFpMathAttr attr = - getConfigDenormalFpMathF32Attr(targetConfig); - if (!attr) { - return std::nullopt; - } - return attr.getValue(); -} -void addConfigDenormalFpMathF32(MLIRContext *context, - IREE::Codegen::DenormalFpMath mode, - SmallVectorImpl &config) { - config.emplace_back(kDenormalFpMathF32AttrName, - IREE::Codegen::DenormalFpMathAttr::get(context, mode)); -} - std::optional getGPUSubgroupSize(mlir::FunctionOpInterface func) { // First try to see if there is a subgroup size chosen in the CodeGen pipeline // configuration. diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h index e41b432fe7f9..fd698a1dd406 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h @@ -208,19 +208,11 @@ IREE::GPU::TargetAttr getGPUTargetAttr(Operation *op); std::optional getConfigWavesPerEu(DictionaryAttr targetAttr); IntegerAttr getConfigWavesPerEuAttr(DictionaryAttr targetAttr); -IREE::Codegen::DenormalFpMathAttr -getConfigDenormalFpMathF32Attr(DictionaryAttr targetConfig); -std::optional -getConfigDenormalFpMathF32(DictionaryAttr targetConfig); - /// Methods to add attributes to the `config` list. void addConfigGPUTarget(MLIRContext *context, IREE::GPU::TargetAttr, SmallVectorImpl &config); void addConfigWavesPerEu(MLIRContext *context, int64_t wavesPerEu, SmallVectorImpl &config); -void addConfigDenormalFpMathF32(MLIRContext *context, - IREE::Codegen::DenormalFpMath mode, - SmallVectorImpl &config); /// Returns the GPU subgroup size chosen for the current CodeGen pipeline if /// exists; otherwise returns the subgroup size from the GPU target description. diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp index 7c39950bc364..5f8aee9e6655 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp @@ -1355,6 +1355,31 @@ Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op, operands); } +IREE::Codegen::DenormalFpMathAttr +getConfigDenormalFpMathF32Attr(DictionaryAttr targetConfig) { + if (!targetConfig) { + return {}; + } + + return targetConfig.getAs( + IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName()); +} +std::optional +getConfigDenormalFpMathF32(DictionaryAttr targetConfig) { + IREE::Codegen::DenormalFpMathAttr attr = + getConfigDenormalFpMathF32Attr(targetConfig); + if (!attr) { + return std::nullopt; + } + return attr.getValue(); +} +void addConfigDenormalFpMathF32(MLIRContext *context, + IREE::Codegen::DenormalFpMath mode, + SmallVectorImpl &config) { + config.emplace_back(IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName(), + IREE::Codegen::DenormalFpMathAttr::get(context, mode)); +} + //===---------------------------------------------------------------------===// // Replace Memref users (transitively) //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Utils/Utils.h index d97f0ba1d7aa..e425a58b1f1b 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_CODEGEN_UTILS_UTILS_H_ #define IREE_COMPILER_CODEGEN_UTILS_UTILS_H_ +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" @@ -299,6 +300,21 @@ bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp, IREE::TensorExt::DispatchTensorType tensorType, ValueRange dynamicDims); +/// Retrieves the DenormalFpMathAttr for F32 values from the given target +/// configuration. This attribute specifies how denormal floating-point values +/// are handled in F32 operations. +IREE::Codegen::DenormalFpMathAttr +getConfigDenormalFpMathF32Attr(DictionaryAttr targetConfig); +std::optional +getConfigDenormalFpMathF32(DictionaryAttr targetConfig); + +/// Adds a denormal floating-point math configuration for F32 values to the +/// configuration list. This configures how denormal floating-point values are +/// handled in F32 operations. +void addConfigDenormalFpMathF32(MLIRContext *context, + IREE::Codegen::DenormalFpMath mode, + SmallVectorImpl &config); + //===----------------------------------------------------------------------===// // Utility functions for vector size inference for dynamic shapes //===----------------------------------------------------------------------===//