OpenXLA-specific changes
chsigg committed Jan 21, 2025
1 parent b01bb25 commit 202c92f
Showing 41 changed files with 3,692 additions and 1,011 deletions.
934 changes: 934 additions & 0 deletions BUILD

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions include/triton/Conversion/MLIRTypes.h
@@ -21,10 +21,10 @@ inline Type u1Ty(MLIRContext *ctx) {
}

// Float types
inline Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
inline Type f16Ty(MLIRContext *ctx) { return Float16Type::get(ctx); }
inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }

inline bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
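For reference, a minimal standalone sketch (not part of the commit, and only assuming the MLIR builtin-type headers) showing the per-type ::get constructors used above in place of the older FloatType::getF16/getBF16/getF32/getF64 helpers:

#include <cassert>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// Builds the builtin float types with the per-type constructors.
static void buildFloatTypes(mlir::MLIRContext *ctx) {
  mlir::Type f16 = mlir::Float16Type::get(ctx);   // was FloatType::getF16(ctx)
  mlir::Type bf16 = mlir::BFloat16Type::get(ctx); // was FloatType::getBF16(ctx)
  mlir::Type f32 = mlir::Float32Type::get(ctx);   // was FloatType::getF32(ctx)
  mlir::Type f64 = mlir::Float64Type::get(ctx);   // was FloatType::getF64(ctx)
  // Each constructor returns the usual uniqued builtin type.
  assert(f16.isF16() && bf16.isBF16() && f32.isF32() && f64.isF64());
  (void)f16; (void)bf16; (void)f32; (void)f64;
}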
2 changes: 1 addition & 1 deletion lib/Analysis/AxisInfo.cpp
@@ -937,7 +937,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
// Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
lhsDivisibility = 1;
}
return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}

int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
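A small self-contained sketch (not part of the commit) of why the literal is widened here: with a plain int literal the shift is evaluated in 32 bits, so shift amounts of 31 or more overflow before the division, while int64_t(1) << shift stays well defined up to 63.

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  int shift = 40;                                         // >= 31 breaks 32-bit shifts
  std::int64_t lhsDivisibility = std::int64_t(1) << 50;   // 2^50
  // (1 << shift) would shift a 32-bit int by 40 bits: undefined behavior.
  std::int64_t result =
      std::max<std::int64_t>(1, lhsDivisibility / (std::int64_t(1) << shift));
  std::cout << result << "\n";  // prints 1024 (2^50 / 2^40)
  return 0;
}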
12 changes: 12 additions & 0 deletions lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> Value {
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
// remaining arguments that have been converted to a new type.
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
// 'convert-triton-to-tritongpu'.
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
inputs);
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) -> Value {
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
// remaining uses of values that have been converted to a new type.
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
// 'convert-triton-to-tritongpu'.
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
inputs);
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return {};
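For context, a sketch of how such a hook hangs off an MLIR TypeConverter. The helper name is hypothetical and the lambda simply mirrors the one added above: instead of asserting, it bridges a value that still has its pre-conversion type with an explicit layout conversion, which is what makes a partial TTIR to TTGIR conversion possible.

#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"  // assumed to declare ConvertLayoutOp

// Hypothetical helper: registers a source materialization that inserts a
// convert_layout instead of failing, so values whose producers were not
// converted yet can still feed already-converted ops.
static void addPartialConversionMaterialization(mlir::TypeConverter &converter) {
  converter.addSourceMaterialization(
      [](mlir::OpBuilder &builder, mlir::RankedTensorType tensorType,
         mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
        return builder.create<mlir::triton::gpu::ConvertLayoutOp>(
            loc, tensorType, inputs);
      });
}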
121 changes: 51 additions & 70 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -12,7 +12,8 @@ namespace mlir::triton::gpu {

namespace {

template <typename T> bool hasEncoding(Value value) {
template <typename T>
bool hasEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = dyn_cast<TensorOrMemDesc>(type)) {
auto encoding = tensorType.getEncoding();
@@ -25,7 +26,7 @@ bool hasDotOperandEncoding(Value value) {
return hasEncoding<triton::gpu::DotOperandEncodingAttr>(value);
}

} // namespace
} // namespace

//===----------------------------------------------------------------------===//
// Canonicalizer
@@ -36,16 +37,13 @@ struct CanonicalizeConvertFromReshape
: public mlir::OpRewritePattern<triton::ReshapeOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(triton::ReshapeOp op,
PatternRewriter &rewriter) const override {
mlir::LogicalResult matchAndRewrite(
triton::ReshapeOp op, PatternRewriter &rewriter) const override {
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
if (!convert) return failure();
if (isExpensiveView(convert.getSrc().getType(), op.getType()))
return failure();
if (!op.getAllowReorder() || op.getEfficientLayout())
return failure();
if (!op.getAllowReorder() || op.getEfficientLayout()) return failure();

rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
op, op.getType(), convert.getSrc(), op.getAllowReorder());
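As a usage note, a sketch of how a pattern such as CanonicalizeConvertFromReshape would be exercised through MLIR's greedy rewriter; the driver function below is hypothetical and assumes the pattern class is visible at the call site.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: collects the pattern and lets the greedy rewriter drop
// the layout conversion feeding a reorderable reshape wherever it applies.
static mlir::LogicalResult runReshapeCanonicalization(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  patterns.add<CanonicalizeConvertFromReshape>(root->getContext());
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}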
@@ -58,12 +56,10 @@ struct CanonicalizeConvertFromHistogram
: public mlir::OpRewritePattern<triton::HistogramOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(triton::HistogramOp op,
PatternRewriter &rewriter) const override {
mlir::LogicalResult matchAndRewrite(
triton::HistogramOp op, PatternRewriter &rewriter) const override {
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
if (!convert) return failure();
rewriter.replaceOpWithNewOp<triton::HistogramOp>(
op, op->getResult(0).getType(), convert.getSrc());
return mlir::success();
@@ -79,15 +75,13 @@ struct CanonicalizeConvertFromHistogram
struct CanonicalizeConvertFromGatherSource : public OpRewritePattern<GatherOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override {
mlir::LogicalResult matchAndRewrite(
GatherOp op, PatternRewriter &rewriter) const override {
// Don't do this if the compiler picked an optimized layout.
if (op.getEfficientLayout())
return failure();
if (op.getEfficientLayout()) return failure();

auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
if (!convert) return failure();

rewriter.replaceOpWithNewOp<GatherOp>(op, convert.getSrc(), op.getIndices(),
op.getAxis());
@@ -100,13 +94,15 @@ struct CanonicalizeConvertFromAlloc
: public mlir::OpRewritePattern<triton::gpu::LocalAllocOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(triton::gpu::LocalAllocOp op,
PatternRewriter &rewriter) const override {
if (!op.getSrc())
return failure();
mlir::LogicalResult matchAndRewrite(
triton::gpu::LocalAllocOp op, PatternRewriter &rewriter) const override {
if (!op.getSrc()) return failure();
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
if (!convert) return failure();
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding, so we want to keep this layout conversion.
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
convert.getSrc().getType().getEncoding()))
return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
@@ -119,12 +115,10 @@ struct CanonicalizeConvertFromLocalStore
: public mlir::OpRewritePattern<triton::gpu::LocalStoreOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(triton::gpu::LocalStoreOp op,
PatternRewriter &rewriter) const override {
mlir::LogicalResult matchAndRewrite(
triton::gpu::LocalStoreOp op, PatternRewriter &rewriter) const override {
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
if (!convert) return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalStoreOp>(op, convert.getSrc(),
op.getDst());
return mlir::success();
@@ -135,19 +129,16 @@ struct CanonicalizeConvertFromSplit
: public mlir::OpRewritePattern<triton::SplitOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(triton::SplitOp op,
PatternRewriter &rewriter) const override {
mlir::LogicalResult matchAndRewrite(
triton::SplitOp op, PatternRewriter &rewriter) const override {
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
if (!convert) return failure();
auto srcEncoding = convert.getSrc().getType().getEncoding();
// Multiple source layout can give the same output layout, if the source
// layout of the convert gives the same destination layout we can skip the
// convert.
auto dstEncoding = inferDstEncoding(op, srcEncoding);
if (dstEncoding != op.getOutLHS().getType().getEncoding())
return failure();
if (dstEncoding != op.getOutLHS().getType().getEncoding()) return failure();
rewriter.replaceOpWithNewOp<triton::SplitOp>(op, convert.getSrc());
return mlir::success();
}
@@ -157,9 +148,8 @@ struct CanonicalizeConvertFromConvert
: public OpRewritePattern<ConvertLayoutOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(ConvertLayoutOp op,
PatternRewriter &rewriter) const override {
mlir::LogicalResult matchAndRewrite(
ConvertLayoutOp op, PatternRewriter &rewriter) const override {
// Convert to the same layout is redundant.
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
@@ -170,22 +160,21 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
return failure();

// for hopper MMAv3
if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
return dot->hasTrait<OpTrait::DotLike>();
})) {
return failure();
}

Operation *arg = op.getSrc().getDefiningOp();
if (!arg)
return failure();
if (!arg) return failure();

// cvt(reshape) -> reshape
if (auto reshape = dyn_cast<ReshapeOp>(arg)) {
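A short sketch of what the switch to isa_and_nonnull buys (the helper below is hypothetical, and the TritonGPU dialect header is assumed to declare the attribute): mlir::isa asserts when handed a null attribute, while mlir::isa_and_nonnull simply returns false, and tensors that carry no layout yet have a null encoding.

#include "triton/Dialect/TritonGPU/IR/Dialect.h"  // assumed to declare the attrs

// Hypothetical helper: safe to call with encodings that may be null.
static bool isNvidiaMmaEncoding(mlir::Attribute encoding) {
  // mlir::isa<...>(encoding) would assert if encoding were null.
  return mlir::isa_and_nonnull<mlir::triton::gpu::NvidiaMmaEncodingAttr>(
      encoding);
}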
@@ -233,8 +222,7 @@ struct CanonicalizeConvertFromConvert

// cvt(cat) -> cat
if (auto cat = dyn_cast<CatOp>(arg)) {
if (isExpensiveCat(cat, op.getType().getEncoding()))
return failure();
if (isExpensiveCat(cat, op.getType().getEncoding())) return failure();

rewriter.replaceOpWithNewOp<CatOp>(op, op->getResult(0).getType(),
cat.getOperands());
@@ -291,15 +279,14 @@ LogicalResult UpcastMXFPOp::verify() {

auto xTy = getSrc().getType();
auto scaleTy = getScale().getType();
Builder b(getContext());
if (xTy.getElementType() != b.getBF16Type() &&
xTy.getElementType() != b.getF16Type() &&
xTy.getElementType() != b.getI8Type()) {
return emitOpError(
"element type of the first operand must be bf16/fp16 or i8");

if (xTy.getElementType() != BFloat16Type::get(getContext()) &&
xTy.getElementType() != Float16Type::get(getContext()) &&
xTy.getElementType() != IntegerType::get(getContext(), 8)) {
return emitOpError("element type of the first operand must be bf16 or i8");
}

if (scaleTy.getElementType() != b.getI8Type()) {
if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
return emitOpError("element type of the second operand must be uint8");
}
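For comparison only (not part of the commit), the same element-type checks can be written with MLIR's Type predicate helpers, which avoids constructing the expected types explicitly.

#include "mlir/IR/Types.h"

// Same checks as above, written with Type predicates.
static bool isValidUpcastInputElemType(mlir::Type elemTy) {
  return elemTy.isBF16() || elemTy.isF16() || elemTy.isInteger(8);
}

static bool isValidScaleElemType(mlir::Type elemTy) {
  return elemTy.isInteger(8);
}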

@@ -373,14 +360,12 @@ LogicalResult UpcastMXFPOp::verify() {
return success();
}

RankedTensorType
UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
ScaleDotElemType inputElemType,
Type outputElemType) {
RankedTensorType UpcastMXFPOp::deduceOutputType(
TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType,
Type outputElemType) {
MLIRContext *ctx = inputTensor.getContext();
auto xTy = inputTensor.getType();
if (inputElemType != ScaleDotElemType::E2M1)
return xTy;
if (inputElemType != ScaleDotElemType::E2M1) return xTy;

auto xShape = xTy.getShape();
auto newShape = llvm::to_vector(xShape);
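A standalone sketch of the shape arithmetic this function performs for packed fp4 (not part of the commit; it assumes the common convention that the last dimension is the packed one): e2m1 stores two 4-bit values per i8 element, so the unpacked output has twice as many elements along that dimension.

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the shape logic: packed e2m1 holds two 4-bit
// values per i8 element, so the unpacked tensor doubles the packed dimension.
static std::vector<std::int64_t> unpackedFp4Shape(std::vector<std::int64_t> shape) {
  if (!shape.empty())
    shape.back() *= 2;
  return shape;
}
// Example: a 128x64xi8 operand of packed e2m1 holds 128x128 fp4 values.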
@@ -466,17 +451,13 @@ void LocalAllocOp::getEffects(
}

OpFoldResult LocalAllocOp::fold(FoldAdaptor adaptor) {
if (getType().getMutableMemory())
return {};
if (getType().getMutableMemory()) return {};
auto src = getSrc();
if (!src)
return {};
if (!src) return {};
auto localLoadOp = src.getDefiningOp<LocalLoadOp>();
if (!localLoadOp)
return {};
if (!localLoadOp) return {};
auto loadSrc = localLoadOp.getSrc();
if (loadSrc.getType() != getType())
return {};
if (loadSrc.getType() != getType()) return {};
return loadSrc;
}
