diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 0827d4e28eda..eb5c470efa62 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -84,9 +84,12 @@ struct ROCMOptions {
   std::string encodingLayoutResolver = GPU::kNoEncodingLayoutResolverName;
   bool slpVectorization = true;
   bool globalISel = false;
-
   bool specializeDispatches = false;
   bool enableTensorUKernels = false;
+  // Denormal handling mode for f32 math. `None` (the default) adds no entry
+  // to the target configuration.
+  IREE::Codegen::DenormalFpMath denormalFpMathF32 =
+      IREE::Codegen::DenormalFpMath::None;
 
   void bindOptions(OptionsBinder &binder) {
     using namespace llvm;
@@ -161,6 +164,15 @@ struct ROCMOptions {
     binder.opt<bool>("iree-hip-enable-tensor-ukernels", enableTensorUKernels,
                      cl::cat(category),
                      cl::desc("Enable MLIR-based ukernels."));
+    binder.opt<IREE::Codegen::DenormalFpMath>(
+        "iree-hip-denormal-fp-math-f32", denormalFpMathF32, cl::cat(category),
+        cl::desc("Denormal floating point math mode for f32"),
+        cl::values(
+            clEnumValN(IREE::Codegen::DenormalFpMath::PreserveSign,
+                       "preserve-sign",
+                       "Convert denormals to zero while preserving sign"),
+            clEnumValN(IREE::Codegen::DenormalFpMath::PositiveZero,
+                       "positive-zero", "Convert denormals to positive zero")));
   }
 
   LogicalResult verify(mlir::Builder &builder) const {
@@ -337,6 +349,11 @@ class ROCMTargetBackend final : public TargetBackend {
     if (options.wavesPerEu > 0) {
       addConfigWavesPerEu(b.getContext(), options.wavesPerEu, configItems);
     }
+    // Record the f32 denormal mode only when one was explicitly requested.
+    if (options.denormalFpMathF32 != IREE::Codegen::DenormalFpMath::None) {
+      addConfigDenormalFpMathF32(b.getContext(), options.denormalFpMathF32,
+                                 configItems);
+    }
 
     if (options.enableTensorUKernels) {
       addConfig(kUKernelProviderName,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 17e44e5afe3e..d70bd2d805ef 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ 
b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -610,4 +610,38 @@ def IREECodegen_XORShuffleAttr :
   let genVerifyDecl = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// iree_codegen.denormal_fp_math
+//===---------------------------------------------------------------------===//
+
+// Do not add any denormal mode.
+def DenormalMode_None : I32EnumCase<"None", 0, "none">;
+// Convert denormals to zero while preserving sign.
+def DenormalMode_PreserveSign : I32EnumCase<"PreserveSign", 1, "preserve-sign">;
+// Convert denormals to positive zero.
+def DenormalMode_PositiveZero : I32EnumCase<"PositiveZero", 2, "positive-zero">;
+
+// Denormal handling modes selectable for floating-point math.
+def DenormalFpMathEnum : I32Enum<
+    "DenormalFpMath",
+    "Denormal mode for fp math", [
+      DenormalMode_None,
+      DenormalMode_PreserveSign,
+      DenormalMode_PositiveZero
+    ]> {
+  let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";
+}
+
+def DenormalFpMathAttr :
+    EnumAttr<IREECodegen_Dialect, DenormalFpMathEnum, "denormal_fp_math"> {
+  let assemblyFormat = "`<` $value `>`";
+
+  let extraClassDeclaration = [{
+    /// Returns the key name for fp32 DenormalFpMathAttr in a DictionaryAttr.
+    static StringRef getFP32DictKeyName() {
+      return "iree_codegen.denormal_fp_math_f32";
+    }
+  }];
+}
+
 #endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp
index 35ea3db6fc10..c135f740bce9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp
@@ -74,6 +74,13 @@ annotateKernelForTranslation(LLVM::LLVMFuncOp funcOp,
           getConfigWavesPerEuAttr(targetAttr.getConfiguration())) {
     rocdlDialect->getWavesPerEuAttrHelper().setAttr(funcOp, attr);
   }
+  // `None` means no mode was requested, so the function is left unannotated.
+  if (IREE::Codegen::DenormalFpMathAttr attr =
+          getConfigDenormalFpMathF32Attr(targetAttr.getConfiguration());
+      attr && attr.getValue() != IREE::Codegen::DenormalFpMath::None) {
+    funcOp.setDenormalFpMathF32(
+        IREE::Codegen::stringifyDenormalFpMath(attr.getValue()));
+  }
 
   // Kernel argument preloading is only supported on gfx942 and newer targets
   // from the CDNA family. This is enabled using the `inreg` function argument
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir
index 87909ba2c86c..cfa6e4907cda 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir
@@ -114,3 +114,75 @@ builtin.module {
 
 // CHECK-LABEL: llvm.func @test_no_kern_arg
 // CHECK-SAME: (%{{.+}}: i32)
+
+// -----
+
+// Check that the preserve-sign denormal mode is set.
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
+  iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">,
+  iree_codegen.target_info = #iree_gpu.target>,
+  ukernels = "none"
+  }>
+#pipeline_layout = #hal.pipeline.layout],
+  flags = Indirect>
+builtin.module {
+  hal.executable public @test_rocdl_attrs {
+    hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
+      hal.executable.export public @test_rocdl_attrs ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
+        %c1 = arith.constant 1 : index
+        hal.return %c1, %c1, %c1 : index, index, index
+      } attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]}
+      builtin.module {
+        // CHECK-LABEL: llvm.func @test_rocdl_attrs
+        // CHECK: denormal_fp_math_f32 = "preserve-sign"
+        llvm.func @test_rocdl_attrs(%arg0: i32) {
+          llvm.return
+        }
+      }
+    }
+  }
+}
+
+// -----
+
+// Check that the positive-zero denormal mode is set.
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
+  iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"positive-zero">,
+  iree_codegen.target_info = #iree_gpu.target>,
+  ukernels = "none"
+  }>
+#pipeline_layout = #hal.pipeline.layout],
+  flags = Indirect>
+builtin.module {
+  hal.executable public @test_rocdl_attrs {
+    hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
+      hal.executable.export public @test_rocdl_attrs ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
+        %c1 = arith.constant 1 : index
+        hal.return %c1, %c1, %c1 : index, index, index
+      } attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]}
+      builtin.module {
+        // CHECK-LABEL: llvm.func @test_rocdl_attrs
+        // CHECK: denormal_fp_math_f32 = "positive-zero"
+        llvm.func @test_rocdl_attrs(%arg0: i32) {
+          llvm.return
+        }
+      }
+    }
+  }
+}
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index 5d7df379fc5f..fd698a1dd406 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
 #define IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
 
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 7c39950bc364..5f8aee9e6655 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -1355,6 +1355,31 @@ Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
                            operands);
 }
 
+IREE::Codegen::DenormalFpMathAttr
+getConfigDenormalFpMathF32Attr(DictionaryAttr targetConfig) {
+  if (!targetConfig) {
+    return {};
+  }
+  // getAs yields a null attribute if the key is absent or has another type.
+  return targetConfig.getAs<IREE::Codegen::DenormalFpMathAttr>(
+      IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName());
+}
+std::optional<IREE::Codegen::DenormalFpMath>
+getConfigDenormalFpMathF32(DictionaryAttr targetConfig) {
+  IREE::Codegen::DenormalFpMathAttr attr =
+      getConfigDenormalFpMathF32Attr(targetConfig);
+  if (!attr) {
+    return std::nullopt;
+  }
+  return attr.getValue();
+}
+void addConfigDenormalFpMathF32(MLIRContext *context,
+                                IREE::Codegen::DenormalFpMath mode,
+                                SmallVectorImpl<NamedAttribute> &config) {
+  config.emplace_back(IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName(),
+                      IREE::Codegen::DenormalFpMathAttr::get(context, mode));
+}
+
 //===---------------------------------------------------------------------===//
 // Replace Memref users (transitively)
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
index d97f0ba1d7aa..e425a58b1f1b 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.h
@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_CODEGEN_UTILS_UTILS_H_
 #define IREE_COMPILER_CODEGEN_UTILS_UTILS_H_
 
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
@@ -299,6 +300,24 @@ bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp,
                  IREE::TensorExt::DispatchTensorType tensorType,
                  ValueRange dynamicDims);
 
+/// Retrieves the DenormalFpMathAttr for F32 values stored in `targetConfig`
+/// under `DenormalFpMathAttr::getFP32DictKeyName()`. Returns a null attribute
+/// when `targetConfig` is null or holds no such attribute under that key.
+IREE::Codegen::DenormalFpMathAttr
+getConfigDenormalFpMathF32Attr(DictionaryAttr targetConfig);
+
+/// Same lookup as `getConfigDenormalFpMathF32Attr`, but returns the enum value
+/// of the attribute, or `std::nullopt` when no attribute is found.
+std::optional<IREE::Codegen::DenormalFpMath>
+getConfigDenormalFpMathF32(DictionaryAttr targetConfig);
+
+/// Appends a DenormalFpMathAttr holding `mode` to `config`, keyed by
+/// `DenormalFpMathAttr::getFP32DictKeyName()`. Does not check whether
+/// `config` already contains an entry with that key.
+void addConfigDenormalFpMathF32(MLIRContext *context,
+                                IREE::Codegen::DenormalFpMath mode,
+                                SmallVectorImpl<NamedAttribute> &config);
+
 //===----------------------------------------------------------------------===//
 // Utility functions for vector size inference for dynamic shapes
 //===----------------------------------------------------------------------===//