diff --git a/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp b/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp index 1ef6ab9d..03dc818e 100644 --- a/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp +++ b/lib/NeuraDialect/Transforms/Optimizations/HwAgnosticOpt/FoldConstantPass.cpp @@ -3,6 +3,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" @@ -33,7 +34,6 @@ struct FuseConstantAndGrantPattern // Checks if the constant operation is used by a grant_once or grant_always // operation. for (auto user : constant_op->getUsers()) { - llvm::errs() << "Checking use: " << *user << "\n"; if (isa(user) || isa(user)) { if (neura::GrantOnceOp grant_once_op = dyn_cast(user)) { @@ -101,46 +101,72 @@ template struct FuseRhsConstantPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + // By default, we assume the operation is not commutative. + // If the operation is commutative, we can extend this pattern to support + // constant folding on the left-hand side operand as well. + virtual bool isCommutative() const { return false; } + virtual Operation * - createOpWithFusedRhsConstant(OpType op, Attribute rhs_const_value, + createOpWithFusedRhsConstant(OpType op, Value non_const_operand, + Attribute rhs_const_value, PatternRewriter &rewriter) const = 0; LogicalResult matchAndRewrite(OpType op, PatternRewriter &rewriter) const override { + if (op->hasAttr("rhs_const_value")) { + // Already fused with a constant on the right-hand side. + return failure(); + } + Value lhs = op.getLhs(); Value rhs = op.getRhs(); - if (isOriginConstantOp(lhs)) { - llvm::errs() << "LHS constant folding not supported yet.\n"; - return failure(); - } + bool lhs_is_const = isOriginConstantOp(lhs); + bool rhs_is_const = rhs && isOriginConstantOp(rhs); - if (!rhs || !isOriginConstantOp(rhs)) { - return failure(); + if (rhs_is_const) { + auto constant_op = dyn_cast(rhs.getDefiningOp()); + + Attribute rhs_const_value = getOriginConstantValue(rhs); + Operation *fused_op = + createOpWithFusedRhsConstant(op, lhs, rhs_const_value, rewriter); + + rewriter.replaceOp(op, fused_op->getResults()); + if (constant_op->use_empty()) { + rewriter.eraseOp(constant_op); + } + return success(); } - auto constant_op = dyn_cast(rhs.getDefiningOp()); + if (lhs_is_const && !rhs_is_const && isCommutative()) { + auto constant_op = dyn_cast(lhs.getDefiningOp()); - Attribute rhs_const_value = getOriginConstantValue(rhs); - Operation *fused_op = - createOpWithFusedRhsConstant(op, rhs_const_value, rewriter); + Attribute lhs_const_value = getOriginConstantValue(lhs); + Operation *fused_op = + createOpWithFusedRhsConstant(op, rhs, lhs_const_value, rewriter); - rewriter.replaceOp(op, fused_op->getResults()); - if (constant_op->use_empty()) { - rewriter.eraseOp(constant_op); + rewriter.replaceOp(op, fused_op->getResults()); + if (constant_op->use_empty()) { + rewriter.eraseOp(constant_op); + } + return success(); } - return success(); + + return failure(); } }; struct FuseAddRhsConstantPattern : public FuseRhsConstantPattern { using FuseRhsConstantPattern::FuseRhsConstantPattern; + bool isCommutative() const override { return true; } + Operation * - createOpWithFusedRhsConstant(neura::AddOp op, Attribute rhs_const_value, + createOpWithFusedRhsConstant(neura::AddOp op, Value non_const_operand, + Attribute rhs_const_value, PatternRewriter &rewriter) const override { auto fused_op = rewriter.create( - op.getLoc(), op.getResult().getType(), op.getLhs(), + op.getLoc(), op.getResult().getType(), non_const_operand, /*rhs=*/nullptr); addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value); return fused_op; @@ -151,10 +177,11 @@ struct FuseSubRhsConstantPattern : public FuseRhsConstantPattern { using FuseRhsConstantPattern::FuseRhsConstantPattern; Operation * - createOpWithFusedRhsConstant(neura::SubOp op, Attribute rhs_const_value, + createOpWithFusedRhsConstant(neura::SubOp op, Value non_const_operand, + Attribute rhs_const_value, PatternRewriter &rewriter) const override { auto fused_op = rewriter.create( - op.getLoc(), op.getResult().getType(), op.getLhs(), + op.getLoc(), op.getResult().getType(), non_const_operand, /*rhs=*/nullptr); addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value); return fused_op; @@ -164,11 +191,14 @@ struct FuseSubRhsConstantPattern : public FuseRhsConstantPattern { struct FuseMulRhsConstantPattern : public FuseRhsConstantPattern { using FuseRhsConstantPattern::FuseRhsConstantPattern; + bool isCommutative() const override { return true; } + Operation * - createOpWithFusedRhsConstant(neura::MulOp op, Attribute rhs_const_value, + createOpWithFusedRhsConstant(neura::MulOp op, Value non_const_operand, + Attribute rhs_const_value, PatternRewriter &rewriter) const override { auto fused_op = rewriter.create( - op.getLoc(), op.getResult().getType(), op.getLhs(), + op.getLoc(), op.getResult().getType(), non_const_operand, /*rhs=*/nullptr); addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value); return fused_op; @@ -180,10 +210,11 @@ struct FuseICmpRhsConstantPattern using FuseRhsConstantPattern::FuseRhsConstantPattern; Operation * - createOpWithFusedRhsConstant(neura::ICmpOp op, Attribute rhs_const_value, + createOpWithFusedRhsConstant(neura::ICmpOp op, Value non_const_operand, + Attribute rhs_const_value, PatternRewriter &rewriter) const override { auto fused_op = rewriter.create( - op.getLoc(), op.getResult().getType(), op.getLhs(), + op.getLoc(), op.getResult().getType(), non_const_operand, /*rhs=*/nullptr, op.getCmpType()); addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value); return fused_op; @@ -194,11 +225,14 @@ struct FuseFAddRhsConstantPattern : public FuseRhsConstantPattern { using FuseRhsConstantPattern::FuseRhsConstantPattern; + bool isCommutative() const override { return true; } + Operation * - createOpWithFusedRhsConstant(neura::FAddOp op, Attribute rhs_const_value, + createOpWithFusedRhsConstant(neura::FAddOp op, Value non_const_operand, + Attribute rhs_const_value, PatternRewriter &rewriter) const override { auto fused_op = rewriter.create( - op.getLoc(), op.getResult().getType(), op.getLhs(), + op.getLoc(), op.getResult().getType(), non_const_operand, /*rhs=*/nullptr); addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value); return fused_op; @@ -209,10 +243,11 @@ struct FuseDivRhsConstantPattern : public FuseRhsConstantPattern { using FuseRhsConstantPattern::FuseRhsConstantPattern; Operation * - createOpWithFusedRhsConstant(neura::DivOp op, Attribute rhs_const_value, + createOpWithFusedRhsConstant(neura::DivOp op, Value non_const_operand, + Attribute rhs_const_value, PatternRewriter &rewriter) const override { auto fused_op = rewriter.create( - op.getLoc(), op.getResult().getType(), op.getLhs(), + op.getLoc(), op.getResult().getType(), non_const_operand, /*rhs=*/nullptr); addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value); return fused_op; @@ -223,10 +258,11 @@ struct FuseRemRhsConstantPattern : public FuseRhsConstantPattern { using FuseRhsConstantPattern::FuseRhsConstantPattern; Operation * - createOpWithFusedRhsConstant(neura::RemOp op, Attribute rhs_const_value, + createOpWithFusedRhsConstant(neura::RemOp op, Value non_const_operand, + Attribute rhs_const_value, PatternRewriter &rewriter) const override { auto fused_op = rewriter.create( - op.getLoc(), op.getResult().getType(), op.getLhs(), + op.getLoc(), op.getResult().getType(), non_const_operand, /*rhs=*/nullptr); addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value); return fused_op; diff --git a/test/optimization/constant_folding/simple_loop.mlir b/test/optimization/constant_folding/simple_loop.mlir new file mode 100644 index 00000000..d8027b28 --- /dev/null +++ b/test/optimization/constant_folding/simple_loop.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-neura-opt %s \ +// RUN: --fold-constant \ +// RUN: | FileCheck %s -check-prefix=FOLD + +module { + func.func @_Z11simple_loopPiS_(%arg0: memref, %arg1: memref) attributes {accelerator = "neura", llvm.linkage = #llvm.linkage} { + %0 = "neura.constant"() <{value = "%arg0"}> : () -> memref + %1 = "neura.constant"() <{value = "%arg1"}> : () -> memref + %2 = "neura.constant"() <{value = 1 : i64}> : () -> i64 + %3 = "neura.constant"() <{value = 128 : i64}> : () -> i64 + %4 = "neura.constant"() <{value = 1 : i32}> : () -> i32 + %5 = "neura.constant"() <{value = 2 : i32}> : () -> i32 + %6 = "neura.constant"() <{value = 0 : i64}> : () -> i64 + neura.br %6 : i64 to ^bb1 + ^bb1(%7: i64): // 2 preds: ^bb0, ^bb2 + %8 = "neura.icmp"(%7, %3) <{cmpType = "slt"}> : (i64, i64) -> i1 + neura.cond_br %8 : i1 then to ^bb2 else to ^bb3 + ^bb2: // pred: ^bb1 + %9 = neura.load_indexed %0[%7 : i64] memref : i32 + %10 = "neura.mul"(%5, %9) : (i32, i32) -> i32 + %11 = "neura.add"(%4, %9) : (i32, i32) -> i32 + neura.store_indexed %11 to %1[%7 : i64] memref : i32 + %12 = "neura.add"(%7, %2) : (i64, i64) -> i64 + neura.br %12 : i64 to ^bb1 + ^bb3: // pred: ^bb1 + "neura.return"() : () -> () + } +} + +// FOLD: func.func @_Z11simple_loopPiS_(%arg0: memref, %arg1: memref) attributes {accelerator = "neura", llvm.linkage = #llvm.linkage} { +// FOLD-NEXT: %0 = "neura.constant"() <{value = "%arg0"}> : () -> memref +// FOLD-NEXT: %1 = "neura.constant"() <{value = "%arg1"}> : () -> memref +// FOLD-NEXT: %2 = "neura.constant"() <{value = 0 : i64}> : () -> i64 +// FOLD-NEXT: neura.br %2 : i64 to ^bb1 +// FOLD-NEXT: ^bb1(%3: i64): // 2 preds: ^bb0, ^bb2 +// FOLD-NEXT: %4 = "neura.icmp"(%3) <{cmpType = "slt"}> {rhs_const_value = 128 : i64} : (i64) -> i1 +// FOLD-NEXT: neura.cond_br %4 : i1 then to ^bb2 else to ^bb3 +// FOLD-NEXT: ^bb2: // pred: ^bb1 +// FOLD-NEXT: %5 = neura.load_indexed %0[%3 : i64] memref : i32 +// FOLD-NEXT: %6 = "neura.mul"(%5) {rhs_const_value = 2 : i32} : (i32) -> i32 +// FOLD-NEXT: %7 = "neura.add"(%5) {rhs_const_value = 1 : i32} : (i32) -> i32 +// FOLD-NEXT: neura.store_indexed %7 to %1[%3 : i64] memref : i32 +// FOLD-NEXT: %8 = "neura.add"(%3) {rhs_const_value = 1 : i64} : (i64) -> i64 +// FOLD-NEXT: neura.br %8 : i64 to ^bb1 +// FOLD-NEXT: ^bb3: // pred: ^bb1 +// FOLD-NEXT: "neura.return"() : () -> () diff --git a/test/visualize/test2.mlir b/test/visualize/test2.mlir index 1c6b3db6..7a152997 100644 --- a/test/visualize/test2.mlir +++ b/test/visualize/test2.mlir @@ -16,10 +16,10 @@ func.func @test_print_op_graph(%a: f32, %b: f32) -> f32 { // CHECK-GRAPH: label = "neura.return : ()\n" // CHECK-GRAPH: digraph G // CHECK-GRAPH: label = "func.func : ()\n\naccelerator: \"neura\"\nfunction_type: (f32, f32) -> f32\nsym_name: \"test_print_op_graph..."; -// CHECK-GRAPH: label = "neura.constant : (!neura.data) -// CHECK-GRAPH: label = "neura.fadd : (!neura.data)\n" +// CHECK-GRAPH: label = "neura.constant : (!neura.data)\n\nvalue: \"%arg0\"", shape = ellipse, style = filled]; +// CHECK-GRAPH: label = "neura.fadd : (!neura.data)\n\nrhs_const_value: \"%arg1\"", shape = ellipse, style = filled]; // CHECK-GRAPH: digraph G // CHECK-GRAPH: label = "func.func : ()\n\naccelerator: \"neura\"\ndataflow_mode: \"predicate\"\nfunction_type: (f32, f32) -> f32\nsym_name: \"test_print_op_graph..."; -// CHECK-GRAPH: label = "neura.constant : (!neura.data) +// CHECK-GRAPH: label = "neura.constant : (!neura.data)\n\nvalue: \"%arg0\"", shape = ellipse, style = filled]; // CHECK-GRAPH: label = "neura.data_mov : (!neura.data) -// CHECK-GRAPH: label = "neura.fadd : (!neura.data)\n" +// CHECK-GRAPH: label = "neura.fadd : (!neura.data)\n\nrhs_const_value: \"%arg1\"", shape = ellipse, style = filled];