98 changes: 83 additions & 15 deletions lib/NeuraDialect/Transforms/FusePatternsPass.cpp
@@ -10,6 +10,49 @@ using namespace mlir;
#include "NeuraDialect/NeuraPasses.h.inc"

namespace {
struct FuseConstantAndGrantPattern
: public OpRewritePattern<neura::ConstantOp> {
using OpRewritePattern<neura::ConstantOp>::OpRewritePattern;

LogicalResult matchAndRewrite(neura::ConstantOp constant_op,
PatternRewriter &rewriter) const override {
bool made_change = false;

    // Rewrites each grant_once/grant_always user so that it carries the
    // constant's value attribute directly instead of an operand.
    // make_early_inc_range keeps the iteration valid while users are
    // replaced underneath it.
    for (Operation *user :
         llvm::make_early_inc_range(constant_op->getUsers())) {
      if (auto grant_once_op = dyn_cast<neura::GrantOnceOp>(user)) {
        auto new_grant_once_op = rewriter.create<neura::GrantOnceOp>(
            grant_once_op.getLoc(), grant_once_op.getResult().getType(),
            /*value=*/nullptr, constant_op->getAttr("value"));
        // Replaces the original grant_once with the fused variant.
        rewriter.replaceOp(grant_once_op, new_grant_once_op);
        made_change = true;
      } else if (auto grant_always_op = dyn_cast<neura::GrantAlwaysOp>(user)) {
        auto new_grant_always_op = rewriter.create<neura::GrantAlwaysOp>(
            grant_always_op.getLoc(), grant_always_op.getResult().getType(),
            /*value=*/nullptr, constant_op->getAttr("value"));
        // Replaces the original grant_always with the fused variant.
        rewriter.replaceOp(grant_always_op, new_grant_always_op);
        made_change = true;
      }
    }

    // Erases the constant once fusion has removed all of its users.
    if (constant_op->use_empty()) {
      rewriter.eraseOp(constant_op);
      made_change = true;
    }

return success(made_change);
}
};
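// Illustrative effect of FuseConstantAndGrantPattern, sketched in generic
// MLIR operation form (hypothetical; the dialect's actual textual syntax
// may differ):
//
//   before:
//     %c = "neura.constant"() {value = 42 : i32} : () -> i32
//     %g = "neura.grant_once"(%c) : (i32) -> i32
//   after:
//     %g = "neura.grant_once"() {value = 42 : i32} : () -> i32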

struct FuseFAddFAddPattern : public OpRewritePattern<neura::FAddOp> {
using OpRewritePattern::OpRewritePattern;

@@ -105,24 +148,49 @@ struct FusePatternsPass
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<FuseFAddFAddPattern>(&getContext(), 2);
patterns.add<FuseFMulFAddPattern>(&getContext(), 3);
FrozenRewritePatternSet frozen(std::move(patterns));

ModuleOp module_op = getOperation();

// Applies to every region inside the module (regardless of func type,
// e.g., mlir func or llvm func).
module_op.walk([&](Operation *op) {
if (!op->getRegions().empty()) {
for (Region &region : op->getRegions()) {
if (failed(applyPatternsGreedily(region, frozen))) {
signalPassFailure();
}
}
// Phase 1: Apply FuseConstantAndGrantPattern.
{
RewritePatternSet patterns(&getContext());
patterns.add<FuseConstantAndGrantPattern>(&getContext());
FrozenRewritePatternSet frozen(std::move(patterns));

if (failed(applyPatternsGreedily(module_op, frozen))) {
signalPassFailure();
}
});
}

// Phase 2: Apply other patterns.
{
RewritePatternSet patterns(&getContext());
patterns.add<FuseFAddFAddPattern>(&getContext());
patterns.add<FuseFMulFAddPattern>(&getContext());
FrozenRewritePatternSet frozen(std::move(patterns));

if (failed(applyPatternsGreedily(module_op, frozen))) {
signalPassFailure();
}
}
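    // Note: this two-phase split assumes the constant/grant fusion should
    // reach a fixed point before the arithmetic fusions run. Unlike the
    // earlier single-phase version, no explicit PatternBenefit is passed
    // here, so every pattern in a phase uses the default benefit of 1.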
}
};
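// Example invocation (the flag below is hypothetical; the registered name
// comes from the pass declarations generated into NeuraPasses.h.inc):
//   mlir-opt input.mlir --neura-fuse-patterns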
