36 changes: 36 additions & 0 deletions include/NeuraDialect/NeuraOps.td
@@ -22,6 +22,35 @@ def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
let traits = [SameOperandsAndResultElementType];
}

// Defines a floating-point multiplication operation.
def Neura_FMulOp : Op<NeuraDialect, "fmul"> {
let summary = "Floating-point multiplication operation";
let opName = "fmul";
let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs);
let results = (outs AnyFloat:$result);
// let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
let traits = [SameOperandsAndResultElementType];
}

def VectorOfAnyFloat :
TypeConstraint<
CPred<
"mlir::isa<::mlir::VectorType>($_self) && "
"mlir::isa<::mlir::FloatType>(mlir::cast<::mlir::VectorType>($_self).getElementType())"
>,
"vector of floats"
>;
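// The constraint above accepts any vector whose element type is a float
// (e.g., vector<4xf32>), regardless of shape; vectors with integer
// elements are rejected.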

// Defines a vector floating-point multiplication operation.
def Neura_VFMulOp : Op<NeuraDialect, "vfmul"> {
let summary = "Vector floating-point multiplication operation";
let opName = "vfmul";
let arguments = (ins VectorOfAnyFloat:$lhs, VectorOfAnyFloat:$rhs);
let results = (outs VectorOfAnyFloat:$result);
// let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
let traits = [SameOperandsAndResultElementType];
}
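// For reference, a sketch of the op above in generic syntax (with no
// custom assembly format defined, the generic form is what prints):
//   %r = "neura.vfmul"(%x, %y) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>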

def Neura_FAddFAddOp : Op<NeuraDialect, "fadd_fadd"> {
let summary = "Fused fadd(fadd(a, b), c)";
let arguments = (ins AnyFloat:$a, AnyFloat:$b, AnyFloat:$c);
@@ -30,6 +59,13 @@ def Neura_FAddFAddOp : Op<NeuraDialect, "fadd_fadd"> {
let traits = [SameOperandsAndResultElementType];
}

def Neura_FMulFAddOp : Op<NeuraDialect, "fmul_fadd"> {
let summary = "Fused fadd(fmul(a, b), c)";
let arguments = (ins AnyFloat:$a, AnyFloat:$b, AnyFloat:$c);
let results = (outs AnyFloat:$result);
// let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)";
let traits = [SameOperandsAndResultElementType];
}
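// For reference, this op computes a * b + c in a single node; a sketch in
// generic syntax:
//   %r = "neura.fmul_fadd"(%a, %b, %c) : (f32, f32, f32) -> f32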

// Defines a move operation for data communication.
def Neura_MovOp : Op<NeuraDialect, "mov"> {
75 changes: 71 additions & 4 deletions lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp
@@ -24,8 +24,59 @@ using namespace mlir;

namespace {

// Lowers integer add from llvm.add to neura.add. We provide this lowering
// in C++ instead of tablegen because llvm.add uses an EnumProperty
// (IntegerOverflowFlags) defined via MLIR interfaces, which DRR cannot
// match on or extract from.
struct LlvmAddToNeuraAdd : public OpRewritePattern<mlir::LLVM::AddOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::LLVM::AddOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<neura::AddOp>(op, op.getType(), op.getLhs(), op.getRhs());
return success();
}
};

struct LlvmFMulToNeuraFMul : public OpRewritePattern<mlir::LLVM::FMulOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::LLVM::FMulOp op,
PatternRewriter &rewriter) const override {
Value lhs = op->getOperand(0);
Value rhs = op->getOperand(1);
Type result_type = op->getResult(0).getType();

// Only matches scalar floating-point types; vectors are handled by
// LlvmVFMulToNeuraVFMul below.
if (!mlir::isa<FloatType>(result_type))
return failure();

rewriter.replaceOpWithNewOp<neura::FMulOp>(op, result_type, lhs, rhs);
return success();
}
};

struct LlvmVFMulToNeuraVFMul : public OpRewritePattern<mlir::LLVM::FMulOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::LLVM::FMulOp op,
PatternRewriter &rewriter) const override {
Value lhs = op->getOperand(0);
Value rhs = op->getOperand(1);
Type result_type = op->getResult(0).getType();

// Only matches vectors with a floating-point element type (e.g.,
// vector<4xf32>).
auto vecTy = mlir::dyn_cast<VectorType>(result_type);
if (!vecTy || !mlir::isa<FloatType>(vecTy.getElementType()))
return failure();

rewriter.replaceOpWithNewOp<neura::VFMulOp>(op, result_type, lhs, rhs);
return success();
}
};
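// Note: LlvmFMulToNeuraFMul and LlvmVFMulToNeuraVFMul both match llvm.fmul,
// but their type checks are disjoint (scalar float vs. vector of float), so
// at most one of them rewrites any given op.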

struct LowerLlvmToNeuraPass
: public PassWrapper<LowerLlvmToNeuraPass, OperationPass<func::FuncOp>> {
: public PassWrapper<LowerLlvmToNeuraPass, OperationPass<ModuleOp>> {

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerLlvmToNeuraPass)

@@ -40,10 +91,26 @@ struct LowerLlvmToNeuraPass

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
// Adds DRR patterns.
mlir::neura::llvm2neura::populateWithGenerated(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
patterns.add<LlvmAddToNeuraAdd>(&getContext());
patterns.add<LlvmFMulToNeuraFMul>(&getContext());
patterns.add<LlvmVFMulToNeuraVFMul>(&getContext());
FrozenRewritePatternSet frozen(std::move(patterns));

ModuleOp module_op = getOperation();

// Applies the patterns to every region inside the module, regardless of
// the enclosing function type (e.g., func.func or llvm.func).
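// Running on ModuleOp and walking regions (instead of running an
// OperationPass<func::FuncOp>) lets the patterns fire inside llvm.func
// bodies as well, since llvm.func is not a func.func.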
module_op.walk([&](Operation *op) {
if (!op->getRegions().empty()) {
for (Region &region : op->getRegions()) {
if (failed(applyPatternsAndFoldGreedily(region, frozen))) {
signalPassFailure();
}
}
}
});
}
};
} // namespace
1 change: 1 addition & 0 deletions lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td
@@ -7,3 +7,4 @@ def : Pat<
(LLVM_FAddOp $lhs, $rhs, $_fastmath),
(Neura_FAddOp $lhs, $rhs)
>;

1 change: 1 addition & 0 deletions lib/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_library(NeuraTransforms
LINK_LIBS PUBLIC
MLIRIR
MLIRFuncDialect
MLIRLLVMDialect
MLIRPass
MLIRSupport
MLIRTransformUtils
110 changes: 93 additions & 17 deletions lib/Transforms/FusePatternsPass.cpp
@@ -8,40 +8,116 @@ using namespace mlir;

namespace {

struct FuseFAddFAddPattern : public RewritePattern {
FuseFAddFAddPattern(MLIRContext *ctx)
: RewritePattern("neura.fadd", /*benefit=*/1, ctx) {}
struct FuseFAddFAddPattern : public OpRewritePattern<neura::FAddOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
auto first = dyn_cast<neura::FAddOp>(op);
if (!first || !first->hasOneUse()) return failure();
LogicalResult matchAndRewrite(neura::FAddOp second,
PatternRewriter &rewriter) const override {
Value lhs = second.getLhs();
Value rhs = second.getRhs();

auto user = dyn_cast<neura::FAddOp>(*first->getUsers().begin());
if (!user) return failure();
auto lhs_op = lhs.getDefiningOp<neura::FAddOp>();
auto rhs_op = rhs.getDefiningOp<neura::FAddOp>();

Location loc = user.getLoc();
Type type = user.getType();
neura::FAddOp first = nullptr;
Value tail;

auto fused = rewriter.create<neura::FAddFAddOp>(loc, type,
first.getLhs(), first.getRhs(), user.getRhs());
// Case 1: LHS is another fadd.
if (lhs_op && second.getRhs()) {
first = lhs_op;
tail = second.getRhs();
}
// Case 2: RHS is another fadd.
else if (rhs_op && second.getLhs()) {
first = rhs_op;
tail = second.getLhs();
}

rewriter.replaceOp(user, fused.getResult());
if (!first)
return failure();

// Only fuses if the first fadd is not reused elsewhere.
if (!first->hasOneUse())
return failure();

Location loc = second.getLoc();
Type type = second.getType();

auto fused = rewriter.create<neura::FAddFAddOp>(
loc, type, first.getLhs(), first.getRhs(), tail);

rewriter.replaceOp(second, fused.getResult());
rewriter.eraseOp(first);
return success();
}
};
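// For reference, the rewrite above turns (a sketch in generic syntax):
//   %t = "neura.fadd"(%a, %b) : (f32, f32) -> f32
//   %r = "neura.fadd"(%t, %c) : (f32, f32) -> f32
// into a single fused node, provided %t has no other users:
//   %r = "neura.fadd_fadd"(%a, %b, %c) : (f32, f32, f32) -> f32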

struct FusePatternsPass : public PassWrapper<FusePatternsPass, OperationPass<func::FuncOp>> {
struct FuseFMulFAddPattern : public OpRewritePattern<neura::FAddOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(neura::FAddOp add,
PatternRewriter &rewriter) const override {
auto lhs_op = add.getLhs().getDefiningOp<neura::FMulOp>();
auto rhs_op = add.getRhs().getDefiningOp<neura::FMulOp>();

neura::FMulOp fmul = nullptr;
Value other;

// Case 1: fmul is on the LHS.
if (lhs_op && add.getRhs()) {
fmul = lhs_op;
other = add.getRhs();
}
// Case 2: fmul is on the RHS.
else if (rhs_op && add.getLhs()) {
fmul = rhs_op;
other = add.getLhs();
}

if (!fmul)
return failure();

// Optional: only fuses if fmul has a single use.
if (!fmul->hasOneUse())
return failure();

Location loc = add.getLoc();
Type type = add.getType();

auto fused = rewriter.create<neura::FMulFAddOp>(
loc, type, fmul.getLhs(), fmul.getRhs(), other);

rewriter.replaceOp(add, fused.getResult());
rewriter.eraseOp(fmul);
return success();
}
};
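// For reference, the rewrite above turns (a sketch in generic syntax):
//   %t = "neura.fmul"(%a, %b) : (f32, f32) -> f32
//   %r = "neura.fadd"(%t, %c) : (f32, f32) -> f32
// into:
//   %r = "neura.fmul_fadd"(%a, %b, %c) : (f32, f32, f32) -> f32
// The registration below gives this pattern a higher benefit (3 vs. 2), so
// the greedy driver tries multiply-add fusion before add-add fusion when
// both could fire on the same neura.fadd.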

struct FusePatternsPass : public PassWrapper<FusePatternsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FusePatternsPass)

StringRef getArgument() const override { return "fuse-patterns"; }
StringRef getDescription() const override { return "Apply Neura fusion patterns."; }

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<FuseFAddFAddPattern>(&getContext());
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
patterns.add<FuseFAddFAddPattern>(&getContext(), 2);
patterns.add<FuseFMulFAddPattern>(&getContext(), 3);
FrozenRewritePatternSet frozen(std::move(patterns));

ModuleOp module_op = getOperation();

// Applies the patterns to every region inside the module, regardless of
// the enclosing function type (e.g., func.func or llvm.func).
module_op.walk([&](Operation *op) {
if (!op->getRegions().empty()) {
for (Region &region : op->getRegions()) {
if (failed(applyPatternsAndFoldGreedily(region, frozen))) {
signalPassFailure();
}
}
}
});
}
};

22 changes: 18 additions & 4 deletions lib/Transforms/InsertMovPass.cpp
@@ -1,9 +1,10 @@
#include "NeuraDialect/NeuraDialect.h"
#include "NeuraDialect/NeuraOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;

@@ -54,7 +55,7 @@ struct InsertMovForNeuraOps : public RewritePattern {
};

struct InsertMovPass
: public PassWrapper<InsertMovPass, OperationPass<func::FuncOp>> {
: public PassWrapper<InsertMovPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMovPass)

StringRef getArgument() const override { return "insert-mov"; }
@@ -69,8 +70,21 @@ struct InsertMovPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<InsertMovForNeuraOps>(&getContext());
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
FrozenRewritePatternSet frozen(std::move(patterns));

ModuleOp module_op = getOperation();

// Applies the patterns to every region inside the module, regardless of
// the enclosing function type (e.g., func.func or llvm.func).
module_op.walk([&](Operation *op) {
if (!op->getRegions().empty()) {
for (Region &region : op->getRegions()) {
if (failed(applyPatternsAndFoldGreedily(region, frozen))) {
signalPassFailure();
}
}
}
});
}
};
} // namespace
1 change: 1 addition & 0 deletions test/neura/fadd_fadd.mlir
@@ -1,3 +1,4 @@
// Applies pattern fusion before mov insertion.
// RUN: mlir-neura-opt --lower-arith-to-neura --fuse-patterns --insert-mov %s | FileCheck %s

func.func @test(%a: f32, %b: f32) -> f32 {
46 changes: 46 additions & 0 deletions test/neura/for_loop/kernel.cpp
@@ -0,0 +1,46 @@
// RUN: mlir-neura-opt %s | FileCheck %s

#include <stdio.h>

#define NTAPS 32

float input[NTAPS] = {
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
};
float output[NTAPS];
float coefficients[NTAPS] = {0.25, 1.50, 3.75, -2.25, 0.50, 0.75, -3.00, 1.25,
0.25, 1.50, 3.75, -2.25, 0.50, 0.75, -3.00, 1.25,
0.25, 1.50, 3.75, -2.25, 0.50, 0.75, -3.00, 1.25,
0.25, 1.50, 3.75, -2.25, 0.50, 0.75, -3.00, 1.25};

void kernel(float input[], float output[], float coefficient[]);

int main()
{

// input_dsp (input, NTAPS, 0);

kernel(input, output, coefficients);

// output_dsp (input, NTAPS, 0);
// output_dsp (coefficients, NTAPS, 0);
// output_dsp (output, NTAPS, 0);
printf("output: %f\n", output[0]);
return 0;
}

/* input : input sample array */
/* output: output sample array */
/* coefficient: coefficient array */
void kernel(float input[], float output[], float coefficient[]) {
int i;
int j = 0;

for (i = 0; i < NTAPS; ++i) {
float tmp = input[i] * coefficient[i];
output[j] += tmp;
}
}
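// Note: j stays 0 throughout, so the loop accumulates the full dot product
// of input and coefficient into output[0]. The multiply-accumulate in the
// body is the fmul + fadd shape the fmul_fadd fusion targets.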