Skip to content

Commit

Permalink
[mlir][GPUToNVVM] Add benefit to populate functions (#128484)
Browse files Browse the repository at this point in the history
Certain GPU->NVVM patterns compete with Arith->LLVM patterns. (The ones
that lower to libdevice.) Add an optional `benefit` parameter to all
`populate` functions so that users can give preference to GPU->NVVM
patterns.
  • Loading branch information
matthias-springer authored Feb 24, 2025
1 parent 5bddadf commit 4defac9
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 122 deletions.
16 changes: 13 additions & 3 deletions mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>

namespace mlir {
Expand All @@ -35,18 +36,27 @@ void configureGpuToNVVMConversionLegality(ConversionTarget &target);
/// GPU dialect to NVVM.
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);

/// Populate patterns that lower certain arith and math dialect ops to
/// libdevice calls.
void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Collect a set of patterns to convert from the GPU dialect to NVVM.
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populate GpuSubgroupReduce pattern to NVVM. It generates a specific nvvm
/// op that is not available on every GPU.
void populateGpuSubgroupReduceOpLoweringPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
RewritePatternSet &patterns,
PatternBenefit benefit = 1);
} // namespace mlir

#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def ApplyGPUToNVVMConversionPatternsOp : Op<Transform_Dialect,
Collects patterns that convert GPU dialect ops to NVVM dialect ops. These
patterns require an "LLVMTypeConverter".
}];
let arguments = (ins DefaultValuedAttr<I16Attr, "1">:$benefit);
let assemblyFormat = "attr-dict";
}

Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ struct GPUDynamicSharedMemoryOpLowering
using ConvertOpToLLVMPattern<
gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
unsigned alignmentBit = 0)
: ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
unsigned alignmentBit = 0,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter, benefit),
alignmentBit(alignmentBit) {}

LogicalResult
Expand Down Expand Up @@ -81,8 +82,9 @@ struct GPUFuncOpLoweringOptions {

struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
GPUFuncOpLowering(const LLVMTypeConverter &converter,
const GPUFuncOpLoweringOptions &options)
: ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter),
const GPUFuncOpLoweringOptions &options,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter, benefit),
allocaAddrSpace(options.allocaAddrSpace),
workgroupAddrSpace(options.workgroupAddrSpace),
kernelAttributeName(options.kernelAttributeName),
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ struct OpLowering : public ConvertOpToLLVMPattern<Op> {
IntrType intrType;

public:
explicit OpLowering(const LLVMTypeConverter &typeConverter)
: ConvertOpToLLVMPattern<Op>(typeConverter),
explicit OpLowering(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
indexBitwidth(typeConverter.getIndexTypeBitwidth()),
indexKind(IndexKind::Other), intrType(IntrType::None) {}

explicit OpLowering(const LLVMTypeConverter &typeConverter,
IndexKind indexKind, IntrType intrType)
: ConvertOpToLLVMPattern<Op>(typeConverter),
IndexKind indexKind, IntrType intrType,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<Op>(typeConverter, benefit),
indexBitwidth(typeConverter.getIndexTypeBitwidth()),
indexKind(indexKind), intrType(intrType) {}

Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
StringRef f32Func, StringRef f64Func,
StringRef f32ApproxFunc, StringRef f16Func,
StringRef i32Func = "")
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
StringRef i32Func = "",
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
i32Func(i32Func) {}

Expand Down
Loading

0 comments on commit 4defac9

Please sign in to comment.