Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -33,7 +34,6 @@ struct FuseConstantAndGrantPattern
// Checks if the constant operation is used by a grant_once or grant_always
// operation.
for (auto user : constant_op->getUsers()) {
llvm::errs() << "Checking use: " << *user << "\n";
if (isa<neura::GrantOnceOp>(user) || isa<neura::GrantAlwaysOp>(user)) {
if (neura::GrantOnceOp grant_once_op =
dyn_cast<neura::GrantOnceOp>(user)) {
Expand Down Expand Up @@ -101,46 +101,72 @@ template <typename OpType>
struct FuseRhsConstantPattern : public OpRewritePattern<OpType> {
using OpRewritePattern<OpType>::OpRewritePattern;

// By default, we assume the operation is not commutative.
// If the operation is commutative, we can extend this pattern to support
// constant folding on the left-hand side operand as well.
virtual bool isCommutative() const { return false; }

virtual Operation *
createOpWithFusedRhsConstant(OpType op, Attribute rhs_const_value,
createOpWithFusedRhsConstant(OpType op, Value non_const_operand,
Attribute rhs_const_value,
PatternRewriter &rewriter) const = 0;

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
if (op->hasAttr("rhs_const_value")) {
// Already fused with a constant on the right-hand side.
return failure();
}

Value lhs = op.getLhs();
Value rhs = op.getRhs();

if (isOriginConstantOp(lhs)) {
llvm::errs() << "LHS constant folding not supported yet.\n";
return failure();
}
bool lhs_is_const = isOriginConstantOp(lhs);
bool rhs_is_const = rhs && isOriginConstantOp(rhs);

if (!rhs || !isOriginConstantOp(rhs)) {
return failure();
if (rhs_is_const) {
auto constant_op = dyn_cast<neura::ConstantOp>(rhs.getDefiningOp());

Attribute rhs_const_value = getOriginConstantValue(rhs);
Operation *fused_op =
createOpWithFusedRhsConstant(op, lhs, rhs_const_value, rewriter);

rewriter.replaceOp(op, fused_op->getResults());
if (constant_op->use_empty()) {
rewriter.eraseOp(constant_op);
}
return success();
}

auto constant_op = dyn_cast<neura::ConstantOp>(rhs.getDefiningOp());
if (lhs_is_const && !rhs_is_const && isCommutative()) {
auto constant_op = dyn_cast<neura::ConstantOp>(lhs.getDefiningOp());

Attribute rhs_const_value = getOriginConstantValue(rhs);
Operation *fused_op =
createOpWithFusedRhsConstant(op, rhs_const_value, rewriter);
Attribute lhs_const_value = getOriginConstantValue(lhs);
Operation *fused_op =
createOpWithFusedRhsConstant(op, rhs, lhs_const_value, rewriter);

rewriter.replaceOp(op, fused_op->getResults());
if (constant_op->use_empty()) {
rewriter.eraseOp(constant_op);
rewriter.replaceOp(op, fused_op->getResults());
if (constant_op->use_empty()) {
rewriter.eraseOp(constant_op);
}
return success();
}
return success();

return failure();
}
};

struct FuseAddRhsConstantPattern : public FuseRhsConstantPattern<neura::AddOp> {
using FuseRhsConstantPattern<neura::AddOp>::FuseRhsConstantPattern;

bool isCommutative() const override { return true; }

Operation *
createOpWithFusedRhsConstant(neura::AddOp op, Attribute rhs_const_value,
createOpWithFusedRhsConstant(neura::AddOp op, Value non_const_operand,
Attribute rhs_const_value,
PatternRewriter &rewriter) const override {
auto fused_op = rewriter.create<neura::AddOp>(
op.getLoc(), op.getResult().getType(), op.getLhs(),
op.getLoc(), op.getResult().getType(), non_const_operand,
/*rhs=*/nullptr);
addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value);
return fused_op;
Expand All @@ -151,10 +177,11 @@ struct FuseSubRhsConstantPattern : public FuseRhsConstantPattern<neura::SubOp> {
using FuseRhsConstantPattern<neura::SubOp>::FuseRhsConstantPattern;

Operation *
createOpWithFusedRhsConstant(neura::SubOp op, Attribute rhs_const_value,
createOpWithFusedRhsConstant(neura::SubOp op, Value non_const_operand,
Attribute rhs_const_value,
PatternRewriter &rewriter) const override {
auto fused_op = rewriter.create<neura::SubOp>(
op.getLoc(), op.getResult().getType(), op.getLhs(),
op.getLoc(), op.getResult().getType(), non_const_operand,
/*rhs=*/nullptr);
addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value);
return fused_op;
Expand All @@ -164,11 +191,14 @@ struct FuseSubRhsConstantPattern : public FuseRhsConstantPattern<neura::SubOp> {
struct FuseMulRhsConstantPattern : public FuseRhsConstantPattern<neura::MulOp> {
using FuseRhsConstantPattern<neura::MulOp>::FuseRhsConstantPattern;

bool isCommutative() const override { return true; }

Operation *
createOpWithFusedRhsConstant(neura::MulOp op, Attribute rhs_const_value,
createOpWithFusedRhsConstant(neura::MulOp op, Value non_const_operand,
Attribute rhs_const_value,
PatternRewriter &rewriter) const override {
auto fused_op = rewriter.create<neura::MulOp>(
op.getLoc(), op.getResult().getType(), op.getLhs(),
op.getLoc(), op.getResult().getType(), non_const_operand,
/*rhs=*/nullptr);
addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value);
return fused_op;
Expand All @@ -180,10 +210,11 @@ struct FuseICmpRhsConstantPattern
using FuseRhsConstantPattern<neura::ICmpOp>::FuseRhsConstantPattern;

Operation *
createOpWithFusedRhsConstant(neura::ICmpOp op, Attribute rhs_const_value,
createOpWithFusedRhsConstant(neura::ICmpOp op, Value non_const_operand,
Attribute rhs_const_value,
PatternRewriter &rewriter) const override {
auto fused_op = rewriter.create<neura::ICmpOp>(
op.getLoc(), op.getResult().getType(), op.getLhs(),
op.getLoc(), op.getResult().getType(), non_const_operand,
/*rhs=*/nullptr, op.getCmpType());
addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value);
return fused_op;
Expand All @@ -194,11 +225,14 @@ struct FuseFAddRhsConstantPattern
: public FuseRhsConstantPattern<neura::FAddOp> {
using FuseRhsConstantPattern<neura::FAddOp>::FuseRhsConstantPattern;

bool isCommutative() const override { return true; }

Operation *
createOpWithFusedRhsConstant(neura::FAddOp op, Attribute rhs_const_value,
createOpWithFusedRhsConstant(neura::FAddOp op, Value non_const_operand,
Attribute rhs_const_value,
PatternRewriter &rewriter) const override {
auto fused_op = rewriter.create<neura::FAddOp>(
op.getLoc(), op.getResult().getType(), op.getLhs(),
op.getLoc(), op.getResult().getType(), non_const_operand,
/*rhs=*/nullptr);
addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value);
return fused_op;
Expand All @@ -209,10 +243,11 @@ struct FuseDivRhsConstantPattern : public FuseRhsConstantPattern<neura::DivOp> {
using FuseRhsConstantPattern<neura::DivOp>::FuseRhsConstantPattern;

Operation *
createOpWithFusedRhsConstant(neura::DivOp op, Attribute rhs_const_value,
createOpWithFusedRhsConstant(neura::DivOp op, Value non_const_operand,
Attribute rhs_const_value,
PatternRewriter &rewriter) const override {
auto fused_op = rewriter.create<neura::DivOp>(
op.getLoc(), op.getResult().getType(), op.getLhs(),
op.getLoc(), op.getResult().getType(), non_const_operand,
/*rhs=*/nullptr);
addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value);
return fused_op;
Expand All @@ -223,10 +258,11 @@ struct FuseRemRhsConstantPattern : public FuseRhsConstantPattern<neura::RemOp> {
using FuseRhsConstantPattern<neura::RemOp>::FuseRhsConstantPattern;

Operation *
createOpWithFusedRhsConstant(neura::RemOp op, Attribute rhs_const_value,
createOpWithFusedRhsConstant(neura::RemOp op, Value non_const_operand,
Attribute rhs_const_value,
PatternRewriter &rewriter) const override {
auto fused_op = rewriter.create<neura::RemOp>(
op.getLoc(), op.getResult().getType(), op.getLhs(),
op.getLoc(), op.getResult().getType(), non_const_operand,
/*rhs=*/nullptr);
addConstantAttribute(fused_op, "rhs_const_value", rhs_const_value);
return fused_op;
Expand Down
46 changes: 46 additions & 0 deletions test/optimization/constant_folding/simple_loop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: mlir-neura-opt %s \
// RUN: --fold-constant \
// RUN: | FileCheck %s -check-prefix=FOLD

// Test input: a single counted loop over two memrefs. The multiply and the
// first add take their constant as the first (left-hand) operand, while the
// compare and the index increment take it as the second (right-hand) operand.
module {
func.func @_Z11simple_loopPiS_(%arg0: memref<?xi32>, %arg1: memref<?xi32>) attributes {accelerator = "neura", llvm.linkage = #llvm.linkage<external>} {
// Constants whose value is the textual name of a function argument.
%0 = "neura.constant"() <{value = "%arg0"}> : () -> memref<?xi32>
%1 = "neura.constant"() <{value = "%arg1"}> : () -> memref<?xi32>
// Scalar constants: index step (%2), loop bound (%3), i32 addend (%4),
// i32 multiplier (%5), and initial index (%6).
%2 = "neura.constant"() <{value = 1 : i64}> : () -> i64
%3 = "neura.constant"() <{value = 128 : i64}> : () -> i64
%4 = "neura.constant"() <{value = 1 : i32}> : () -> i32
%5 = "neura.constant"() <{value = 2 : i32}> : () -> i32
%6 = "neura.constant"() <{value = 0 : i64}> : () -> i64
// Enter the loop header with the index starting at %6 (0).
neura.br %6 : i64 to ^bb1
^bb1(%7: i64): // 2 preds: ^bb0, ^bb2
// Loop header: "slt" comparison of the index %7 against %3 (128); the
// constant is the right-hand operand here.
%8 = "neura.icmp"(%7, %3) <{cmpType = "slt"}> : (i64, i64) -> i1
neura.cond_br %8 : i1 then to ^bb2 else to ^bb3
^bb2: // pred: ^bb1
// Loop body: load the element of %0 at index %7.
%9 = neura.load_indexed %0[%7 : i64] memref<?xi32> : i32
// Constant (%5 = 2) is the left-hand operand. The result %10 has no users in
// this function.
%10 = "neura.mul"(%5, %9) : (i32, i32) -> i32
// Constant (%4 = 1) is the left-hand operand; the sum is what gets stored.
%11 = "neura.add"(%4, %9) : (i32, i32) -> i32
neura.store_indexed %11 to %1[%7 : i64] memref<?xi32> : i32
// Advance the index by %2 (1); the constant is the right-hand operand here.
%12 = "neura.add"(%7, %2) : (i64, i64) -> i64
neura.br %12 : i64 to ^bb1
^bb3: // pred: ^bb1
// Loop exit: return with no results.
"neura.return"() : () -> ()
}
}

// FOLD: func.func @_Z11simple_loopPiS_(%arg0: memref<?xi32>, %arg1: memref<?xi32>) attributes {accelerator = "neura", llvm.linkage = #llvm.linkage<external>} {
// FOLD-NEXT: %0 = "neura.constant"() <{value = "%arg0"}> : () -> memref<?xi32>
// FOLD-NEXT: %1 = "neura.constant"() <{value = "%arg1"}> : () -> memref<?xi32>
// FOLD-NEXT: %2 = "neura.constant"() <{value = 0 : i64}> : () -> i64
// FOLD-NEXT: neura.br %2 : i64 to ^bb1
// FOLD-NEXT: ^bb1(%3: i64): // 2 preds: ^bb0, ^bb2
// FOLD-NEXT: %4 = "neura.icmp"(%3) <{cmpType = "slt"}> {rhs_const_value = 128 : i64} : (i64) -> i1
// FOLD-NEXT: neura.cond_br %4 : i1 then to ^bb2 else to ^bb3
// FOLD-NEXT: ^bb2: // pred: ^bb1
// FOLD-NEXT: %5 = neura.load_indexed %0[%3 : i64] memref<?xi32> : i32
// FOLD-NEXT: %6 = "neura.mul"(%5) {rhs_const_value = 2 : i32} : (i32) -> i32
// FOLD-NEXT: %7 = "neura.add"(%5) {rhs_const_value = 1 : i32} : (i32) -> i32
// FOLD-NEXT: neura.store_indexed %7 to %1[%3 : i64] memref<?xi32> : i32
// FOLD-NEXT: %8 = "neura.add"(%3) {rhs_const_value = 1 : i64} : (i64) -> i64
// FOLD-NEXT: neura.br %8 : i64 to ^bb1
// FOLD-NEXT: ^bb3: // pred: ^bb1
// FOLD-NEXT: "neura.return"() : () -> ()
8 changes: 4 additions & 4 deletions test/visualize/test2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ func.func @test_print_op_graph(%a: f32, %b: f32) -> f32 {
// CHECK-GRAPH: label = "neura.return : ()\n"
// CHECK-GRAPH: digraph G
// CHECK-GRAPH: label = "func.func : ()\n\naccelerator: \"neura\"\nfunction_type: (f32, f32) -> f32\nsym_name: \"test_print_op_graph...";
// CHECK-GRAPH: label = "neura.constant : (!neura.data<f32, i1>)
// CHECK-GRAPH: label = "neura.fadd : (!neura.data<f32, i1>)\n"
// CHECK-GRAPH: label = "neura.constant : (!neura.data<f32, i1>)\n\nvalue: \"%arg0\"", shape = ellipse, style = filled];
// CHECK-GRAPH: label = "neura.fadd : (!neura.data<f32, i1>)\n\nrhs_const_value: \"%arg1\"", shape = ellipse, style = filled];
// CHECK-GRAPH: digraph G
// CHECK-GRAPH: label = "func.func : ()\n\naccelerator: \"neura\"\ndataflow_mode: \"predicate\"\nfunction_type: (f32, f32) -> f32\nsym_name: \"test_print_op_graph...";
// CHECK-GRAPH: label = "neura.constant : (!neura.data<f32, i1>)
// CHECK-GRAPH: label = "neura.constant : (!neura.data<f32, i1>)\n\nvalue: \"%arg0\"", shape = ellipse, style = filled];
// CHECK-GRAPH: label = "neura.data_mov : (!neura.data<f32, i1>)
// CHECK-GRAPH: label = "neura.fadd : (!neura.data<f32, i1>)\n"
// CHECK-GRAPH: label = "neura.fadd : (!neura.data<f32, i1>)\n\nrhs_const_value: \"%arg1\"", shape = ellipse, style = filled];