Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ struct ROCMOptions {
std::string encodingLayoutResolver = GPU::kNoEncodingLayoutResolverName;
bool slpVectorization = true;
bool globalISel = false;

bool specializeDispatches = false;
bool enableTensorUKernels = false;
IREE::Codegen::DenormalFpMath denormalFpMathF32 =
IREE::Codegen::DenormalFpMath::None;

void bindOptions(OptionsBinder &binder) {
using namespace llvm;
Expand Down Expand Up @@ -161,6 +162,15 @@ struct ROCMOptions {
binder.opt<bool>("iree-hip-enable-tensor-ukernels", enableTensorUKernels,
cl::cat(category),
cl::desc("Enable MLIR-based ukernels."));
binder.opt<IREE::Codegen::DenormalFpMath>(
"iree-hip-denormal-fp-math-f32", denormalFpMathF32, cl::cat(category),
cl::desc("Denormal floating point math mode for f32"),
cl::values(
clEnumValN(IREE::Codegen::DenormalFpMath::PreserveSign,
"preserve-sign",
"Convert denormals to zero while preserving sign"),
clEnumValN(IREE::Codegen::DenormalFpMath::PositiveZero,
"positive-zero", "Convert denormals to positive zero")));
}

LogicalResult verify(mlir::Builder &builder) const {
Expand Down Expand Up @@ -337,6 +347,10 @@ class ROCMTargetBackend final : public TargetBackend {
if (options.wavesPerEu > 0) {
addConfigWavesPerEu(b.getContext(), options.wavesPerEu, configItems);
}
if (options.denormalFpMathF32 != IREE::Codegen::DenormalFpMath::None) {
addConfigDenormalFpMathF32(b.getContext(), options.denormalFpMathF32,
configItems);
}

if (options.enableTensorUKernels) {
addConfig(kUKernelProviderName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,4 +610,38 @@ def IREECodegen_XORShuffleAttr :
let genVerifyDecl = 1;
}

//===---------------------------------------------------------------------===//
// iree_codegen.denormal_mode
//===---------------------------------------------------------------------===//

// Do not add any denormal mode.
def DenormalMode_None : I32EnumCase<"None", 0, "none">;
// Convert denormals to zero while preserving sign.
def DenormalMode_PreserveSign : I32EnumCase<"PreserveSign", 1, "preserve-sign">;
// Convert denormals to positive zero.
def DenormalMode_PositiveZero : I32EnumCase<"PositiveZero", 2, "positive-zero">;

// Denormal fp mode.
def DenormalFpMathEnum : I32Enum<
"DenormalFpMath",
"Denormal mode for fp math", [
DenormalMode_None,
DenormalMode_PreserveSign,
DenormalMode_PositiveZero
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";
}

def DenormalFpMathAttr :
EnumAttr<IREECodegen_Dialect, DenormalFpMathEnum, "denormal_fp_math"> {
let assemblyFormat = "`<` $value `>`";

let extraClassDeclaration = [{
/// Returns the key name for fp32 DenormalFpMathAttr in a DictionaryAttr.
static StringRef getFP32DictKeyName() {
return "iree_codegen.denormal_fp_math_f32";
}
}];
}

#endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ annotateKernelForTranslation(LLVM::LLVMFuncOp funcOp,
getConfigWavesPerEuAttr(targetAttr.getConfiguration())) {
rocdlDialect->getWavesPerEuAttrHelper().setAttr(funcOp, attr);
}
if (IREE::Codegen::DenormalFpMathAttr attr =
getConfigDenormalFpMathF32Attr(targetAttr.getConfiguration());
attr && attr.getValue() != IREE::Codegen::DenormalFpMath::None) {
funcOp.setDenormalFpMathF32(
IREE::Codegen::stringifyDenormalFpMath(attr.getValue()));
}

// Kernel argument preloading is only supported on gfx942 and newer targets
// from the CDNA family. This is enabled using the `inreg` function argument
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,75 @@ builtin.module {

// CHECK-LABEL: llvm.func @test_no_kern_arg
// CHECK-SAME: (%{{.+}}: i32)

// -----

// Check that denormal is set.

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">,
iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
wgp = <compute = int32, storage = b32,
subgroup = none,
subgroup_size_choices = [64],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
ukernels = "none"
}>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, Indirect>],
flags = Indirect>
builtin.module {
hal.executable public @test_rocdl_attrs {
hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
hal.executable.export public @test_rocdl_attrs ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
%c1 = arith.constant 1 : index
hal.return %c1, %c1, %c1 : index, index, index
} attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]}
builtin.module {
// CHECK-LABEL: llvm.func @test_rocdl_attrs
// CHECK: denormal_fp_math_f32 = "preserve-sign"
llvm.func @test_rocdl_attrs(%arg0: i32) {
llvm.return
}
}
}
}
}

// -----

// Check that denormal is set.

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"positive-zero">,
iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
wgp = <compute = int32, storage = b32,
subgroup = none,
subgroup_size_choices = [64],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
ukernels = "none"
}>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, Indirect>],
flags = Indirect>
builtin.module {
hal.executable public @test_rocdl_attrs {
hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
hal.executable.export public @test_rocdl_attrs ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
%c1 = arith.constant 1 : index
hal.return %c1, %c1, %c1 : index, index, index
} attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]}
builtin.module {
// CHECK-LABEL: llvm.func @test_rocdl_attrs
// CHECK: denormal_fp_math_f32 = "positive-zero"
llvm.func @test_rocdl_attrs(%arg0: i32) {
llvm.return
}
}
}
}
}
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
#define IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,31 @@ Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
operands);
}

IREE::Codegen::DenormalFpMathAttr
getConfigDenormalFpMathF32Attr(DictionaryAttr targetConfig) {
if (!targetConfig) {
return {};
}

return targetConfig.getAs<IREE::Codegen::DenormalFpMathAttr>(
IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName());
}
std::optional<IREE::Codegen::DenormalFpMath>
getConfigDenormalFpMathF32(DictionaryAttr targetConfig) {
IREE::Codegen::DenormalFpMathAttr attr =
getConfigDenormalFpMathF32Attr(targetConfig);
if (!attr) {
return std::nullopt;
}
return attr.getValue();
}
void addConfigDenormalFpMathF32(MLIRContext *context,
IREE::Codegen::DenormalFpMath mode,
SmallVectorImpl<NamedAttribute> &config) {
config.emplace_back(IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName(),
IREE::Codegen::DenormalFpMathAttr::get(context, mode));
}

//===---------------------------------------------------------------------===//
// Replace Memref users (transitively)
//===---------------------------------------------------------------------===//
Expand Down
16 changes: 16 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_CODEGEN_UTILS_UTILS_H_
#define IREE_COMPILER_CODEGEN_UTILS_UTILS_H_

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
Expand Down Expand Up @@ -299,6 +300,21 @@ bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp,
IREE::TensorExt::DispatchTensorType tensorType,
ValueRange dynamicDims);

/// Retrieves the DenormalFpMathAttr for F32 values from the given target
/// configuration. This attribute specifies how denormal floating-point values
/// are handled in F32 operations.
IREE::Codegen::DenormalFpMathAttr
getConfigDenormalFpMathF32Attr(DictionaryAttr targetConfig);
std::optional<IREE::Codegen::DenormalFpMath>
getConfigDenormalFpMathF32(DictionaryAttr targetConfig);

/// Adds a denormal floating-point math configuration for F32 values to the
/// configuration list. This configures how denormal floating-point values are
/// handled in F32 operations.
void addConfigDenormalFpMathF32(MLIRContext *context,
IREE::Codegen::DenormalFpMath mode,
SmallVectorImpl<NamedAttribute> &config);

//===----------------------------------------------------------------------===//
// Utility functions for vector size inference for dynamic shapes
//===----------------------------------------------------------------------===//
Expand Down
Loading