Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/NeuraDialect/NeuraPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ std::unique_ptr<mlir::Pass> createGenerateCodePass();
std::unique_ptr<mlir::Pass> createFuseControlFlowPass();
std::unique_ptr<mlir::Pass> createCanonicalizeLiveInPass();
std::unique_ptr<mlir::Pass> createCanonicalizeCastPass();
std::unique_ptr<mlir::Pass> createFoldConstantsPass();

#define GEN_PASS_REGISTRATION
#include "NeuraDialect/NeuraPasses.h.inc"
Expand Down
11 changes: 11 additions & 0 deletions include/NeuraDialect/NeuraPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,15 @@ def CanonicalizeCast : Pass<"canonicalize-cast", "ModuleOp"> {
let constructor = "neura::createCanonicalizeCastPass()";
}

def FoldConstants : Pass<"fold-constants", "ModuleOp"> {
  let summary = "Folds constant operations in the Neura dialect";
  let description = [{
    This pass applies constant folding transformations to Neura dialect operations.
    The folding includes:
    1. Fusing constant values into grant_once/grant_always operations.
    2. Removing constant operations that are left without uses.
  }];
  let constructor = "neura::createFoldConstantsPass()";
}

#endif // NEURA_PASSES_TD
1 change: 1 addition & 0 deletions lib/NeuraDialect/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_library(
FuseControlFlowPass.cpp
CanonicalizeLiveInPass.cpp
CanonicalizeCastPass.cpp
FoldConstantsPass.cpp

DEPENDS
MLIRNeuraTransformsIncGen
Expand Down
91 changes: 91 additions & 0 deletions lib/NeuraDialect/Transforms/FoldConstantsPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "NeuraDialect/NeuraOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Pulls in the tablegen-generated pass definition. The guard macro is
// GEN_PASS_DEF_<PASSNAME-IN-UPPERCASE>; the pass is named "FoldConstants"
// in NeuraPasses.td, so the correct macro is ...FOLDCONSTANTS (the earlier
// singular spelling expanded nothing from the .inc).
#define GEN_PASS_DEF_FOLDCONSTANTS
#include "NeuraDialect/NeuraPasses.h.inc"

namespace {
// Fuses a neura.constant into the grant_once/grant_always ops that consume
// it: each grant is rebuilt with no operand, carrying the constant's "value"
// attribute directly, and the constant is erased once it has no users left.
struct FuseConstantAndGrantPattern
    : public OpRewritePattern<neura::ConstantOp> {
  using OpRewritePattern<neura::ConstantOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(neura::ConstantOp constant_op,
                                PatternRewriter &rewriter) const override {
    bool made_change = false;

    // Snapshots the user list first: replaceOp below rewrites the
    // constant's use list, which would invalidate a live user iterator.
    SmallVector<Operation *> users(constant_op->getUsers().begin(),
                                   constant_op->getUsers().end());
    for (Operation *user : users) {
      if (auto grant_once_op = dyn_cast<neura::GrantOnceOp>(user)) {
        // Rebuilds grant_once without an operand, embedding the constant
        // value as an attribute instead.
        auto new_grant_once_op = rewriter.create<neura::GrantOnceOp>(
            grant_once_op.getLoc(), grant_once_op.getResult().getType(),
            /*value=*/nullptr, constant_op->getAttr("value"));
        rewriter.replaceOp(grant_once_op, new_grant_once_op);
        made_change = true;
      } else if (auto grant_always_op = dyn_cast<neura::GrantAlwaysOp>(user)) {
        // Same fusion for grant_always.
        auto new_grant_always_op = rewriter.create<neura::GrantAlwaysOp>(
            grant_always_op.getLoc(), grant_always_op.getResult().getType(),
            /*value=*/nullptr, constant_op->getAttr("value"));
        rewriter.replaceOp(grant_always_op, new_grant_always_op);
        made_change = true;
      }
    }

    // Erases the constant if fusing removed its last use.
    if (constant_op->use_empty()) {
      rewriter.eraseOp(constant_op);
      made_change = true;
    }

    return success(made_change);
  }
};

// Module pass driving FuseConstantAndGrantPattern over every region in the
// module, independent of the function flavor (mlir func, llvm func, ...).
struct FoldConstantsPass
    : public PassWrapper<FoldConstantsPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FoldConstantsPass)

  StringRef getArgument() const override { return "fold-constants"; }
  StringRef getDescription() const override {
    return "Fold constant operations.";
  }

  void runOnOperation() override {
    // Builds the pattern set once and freezes it for repeated application.
    RewritePatternSet pattern_set(&getContext());
    pattern_set.add<FuseConstantAndGrantPattern>(&getContext());
    FrozenRewritePatternSet frozen_patterns(std::move(pattern_set));

    // Visits each op that owns regions and runs the greedy driver on every
    // region, so pattern application is not tied to a specific func type.
    getOperation().walk([&](Operation *op) {
      for (Region &region : op->getRegions()) {
        if (failed(applyPatternsGreedily(region, frozen_patterns))) {
          signalPassFailure();
        }
      }
    });
  }
};

} // namespace

namespace mlir::neura {
/// Creates the pass that folds Neura constants into grant operations.
std::unique_ptr<Pass> createFoldConstantsPass() {
  auto pass = std::make_unique<FoldConstantsPass>();
  return pass;
}
} // namespace mlir::neura
3 changes: 1 addition & 2 deletions lib/NeuraDialect/Transforms/FusePatternsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,12 @@ struct FusePatternsPass
}

void runOnOperation() override {
ModuleOp module_op = getOperation();
RewritePatternSet patterns(&getContext());
patterns.add<FuseFAddFAddPattern>(&getContext(), 2);
patterns.add<FuseFMulFAddPattern>(&getContext(), 3);
FrozenRewritePatternSet frozen(std::move(patterns));

ModuleOp module_op = getOperation();

// Applies to every region inside the module (regardless of func type,
// e.g., mlir func or llvm func).
module_op.walk([&](Operation *op) {
Expand Down
Loading