8 changes: 8 additions & 0 deletions include/NeuraDialect/NeuraOps.td
@@ -300,6 +300,14 @@ def Neura_FMulFAddOp : Op<NeuraDialect, "fmul_fadd"> {
let traits = [SameOperandsAndResultElementType];
}

def Neura_MulAddOp : Op<NeuraDialect, "mul_add"> {
let summary = "Fused add(mul(a, b), c)";
let arguments = (ins AnyType:$a, AnyType:$b, AnyType:$c, Optional<AnyType>:$predicate);
let results = (outs AnyType:$result);
// let assemblyFormat = "$a `,` $b `,` $c `,` $predicate attr-dict `:` type($result)";
let traits = [SameOperandsAndResultElementType];
}
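
For reference, a minimal sketch of the generic textual form this op takes while the custom assemblyFormat above stays commented out (the SSA names are illustrative; the shape matches the CHECK-FUSED lines in test/neura/fusion/test.mlir):

%r = "neura.mul_add"(%a, %b, %c) : (i32, i32, i32) -> i32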

// ----------------------------------------------------
// Defines move operations.
def Neura_DataMovOp : Op<NeuraDialect, "data_mov"> {
128 changes: 127 additions & 1 deletion lib/NeuraDialect/Transforms/FusePatternPass.cpp
@@ -15,6 +15,11 @@ struct FuseFAddFAddPattern : public OpRewritePattern<neura::FAddOp> {

LogicalResult matchAndRewrite(neura::FAddOp second,
PatternRewriter &rewriter) const override {
// Checks if rhs exists before trying to get its defining op.
if (!second.getRhs()) {
return failure();
}

Value lhs = second.getLhs();
Value rhs = second.getRhs();

@@ -61,6 +66,11 @@ struct FuseFMulFAddPattern : public OpRewritePattern<neura::FAddOp> {

LogicalResult matchAndRewrite(neura::FAddOp add,
PatternRewriter &rewriter) const override {
// Checks if rhs exists before trying to get its defining op.
if (!add.getRhs()) {
return failure();
}

auto lhs_op = add.getLhs().getDefiningOp<neura::FMulOp>();
auto rhs_op = add.getRhs().getDefiningOp<neura::FMulOp>();

@@ -82,7 +92,7 @@ struct FuseFMulFAddPattern : public OpRewritePattern<neura::FAddOp> {
return failure();
}

// Only fuses if the fmul has a single use.
if (!fmul->hasOneUse()) {
return failure();
}
@@ -99,6 +109,119 @@ struct FuseFMulFAddPattern : public OpRewritePattern<neura::FAddOp> {
}
};

struct FuseGepLoadPattern : public OpRewritePattern<neura::LoadOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(neura::LoadOp load,
PatternRewriter &rewriter) const override {
Value addr = load.getAddr();
auto gep_op = addr.getDefiningOp<neura::GEP>();

if (!gep_op)
return failure();

// Only fuses if the gep has a single use.
if (!gep_op->hasOneUse())
return failure();

Location loc = load.getLoc();
Type type = load.getType();

// Creates the fused operation with base and indices from gep.
SmallVector<Value> indexValues;
for (auto gepIndex : gep_op.getIndicesAndPredicate()) {
indexValues.push_back(gepIndex);
}

auto fused = rewriter.create<neura::LoadIndexedOp>(
loc, type, gep_op.getBase(), indexValues, load.getPredicate());

rewriter.replaceOp(load, fused.getResult());
rewriter.eraseOp(gep_op);
return success();
}
};
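
A hedged before/after sketch of this rewrite. The neura.gep and neura.load spellings and types below are illustrative assumptions; the fused load_indexed form mirrors the CHECK-FUSED line in test/neura/fusion/test.mlir:

// Before: a single-use gep computes the address for a load.
%addr = neura.gep %base[%i : !neura.data<i64, i1>] : !neura.data<!llvm.ptr, i1>
%v = neura.load %addr : !neura.data<i32, i1>
// After: one fused access; the gep is erased.
%v = neura.load_indexed %base[%i : !neura.data<i64, i1>] !neura.data<!llvm.ptr, i1> : !neura.data<i32, i1>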

struct FuseGEPStorePattern : public OpRewritePattern<neura::StoreOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(neura::StoreOp store,
PatternRewriter &rewriter) const override {
Value addr = store.getAddr();
auto gep_op = addr.getDefiningOp<neura::GEP>();

if (!gep_op)
return failure();

// Only fuses if the gep has a single use.
if (!gep_op->hasOneUse())
return failure();

Location loc = store.getLoc();

// Creates the fused operation with base and indices from gep.
SmallVector<Value> indexValues;
for (auto gepIndex : gep_op.getIndicesAndPredicate()) {
indexValues.push_back(gepIndex);
}

rewriter.create<neura::StoreIndexedOp>(
loc, store.getValue(), gep_op.getBase(), indexValues, store.getPredicate());

rewriter.eraseOp(store);
rewriter.eraseOp(gep_op);
return success();
}
};
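
The store-side analogue, again only a sketch: the neura.store and neura.store_indexed spellings are assumed by symmetry with the load pattern and are not pinned down by a CHECK line in this PR:

// Before: a single-use gep computes the address for a store.
%addr = neura.gep %base[%i : !neura.data<i64, i1>] : !neura.data<!llvm.ptr, i1>
neura.store %v to %addr : !neura.data<i32, i1>
// After: the address computation folds into the indexed store.
neura.store_indexed %v to %base[%i : !neura.data<i64, i1>] : !neura.data<i32, i1>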

struct FuseMulAddPattern : public OpRewritePattern<neura::AddOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(neura::AddOp add,
PatternRewriter &rewriter) const override {
// Checks if rhs exists before trying to get its defining op.
if (!add.getRhs()) {
return failure();
}

auto lhs_op = add.getLhs().getDefiningOp<neura::MulOp>();
auto rhs_op = add.getRhs().getDefiningOp<neura::MulOp>();

neura::MulOp mul = nullptr;
Value other;

// Case 1: mul is on the LHS.
if (lhs_op && add.getRhs()) {
mul = lhs_op;
other = add.getRhs();
}
// Case 2: mul is on the RHS.
else if (rhs_op && add.getLhs()) {
mul = rhs_op;
other = add.getLhs();
}

if (!mul) {
return failure();
}

// Only fuses if mul has a single use.
if (!mul->hasOneUse()) {
return failure();
}

Location loc = add.getLoc();
Type type = add.getType();

auto fused = rewriter.create<neura::MulAddOp>(
loc, type, mul.getLhs(), mul.getRhs(), other, Value());

rewriter.replaceOp(add, fused.getResult());
rewriter.eraseOp(mul);
return success();
}
};
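
A minimal sketch of the integer fusion this pattern performs, in generic syntax (the neura.mul and neura.add mnemonics are inferred from MulOp/AddOp; the fused form matches the CHECK-FUSED lines in the test below):

// Before: a single-use mul feeds an add.
%m = "neura.mul"(%a, %b) : (i32, i32) -> i32
%r = "neura.add"(%m, %c) : (i32, i32) -> i32
// After: one fused multiply-accumulate; the mul is erased.
%r = "neura.mul_add"(%a, %b, %c) : (i32, i32, i32) -> i32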

struct FusePatternPass
: public PassWrapper<FusePatternPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FusePatternPass)
@@ -113,6 +236,9 @@ struct FusePatternPass
RewritePatternSet patterns(&getContext());
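// The trailing integers are pattern benefits; the greedy rewrite driver
// tries higher-benefit patterns first when several match the same op.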
patterns.add<FuseFAddFAddPattern>(&getContext(), 2);
patterns.add<FuseFMulFAddPattern>(&getContext(), 3);
patterns.add<FuseGepLoadPattern>(&getContext(), 4);
patterns.add<FuseGEPStorePattern>(&getContext(), 5);
patterns.add<FuseMulAddPattern>(&getContext(), 6);
FrozenRewritePatternSet frozen(std::move(patterns));

// Applies to every region inside the module (regardless of func type,
26 changes: 26 additions & 0 deletions test/neura/fusion/kernel.cpp
@@ -0,0 +1,26 @@
#include <stdio.h>
#include <stdlib.h>

#define NTAPS 1024

int A[NTAPS][NTAPS];
int s[NTAPS];
int q[NTAPS];
int p[NTAPS];
int r[NTAPS];

void kernel(int A[][NTAPS], int s[], int q[], int p[], int r[]) {
int i, j;

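// Each inner-loop statement below is a multiply-accumulate (x = x + y * z),
// so the fuse-pattern pass is expected to fuse each mul/add pair into one
// neura.mul_add and fold the address computations into indexed accesses.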
for (i = 0; i < NTAPS; i++) {
for (j = 0; j < NTAPS; j++) {
s[j] = s[j] + r[i] * A[i][j];
q[i] = q[i] + A[i][j] * p[j];
}
}
}

int main() {
kernel(A, s, q, p, r);
}

24 changes: 24 additions & 0 deletions test/neura/fusion/test.mlir
@@ -0,0 +1,24 @@
// RUN: clang++ -S -emit-llvm -O3 -fno-unroll-loops -fno-vectorize -o %t-kernel.ll %S/kernel.cpp
// RUN: mlir-translate --import-llvm %t-kernel.ll -o %t-kernel.mlir
// RUN: mlir-neura-opt --assign-accelerator \
// RUN: --lower-llvm-to-neura \
// RUN: --canonicalize-live-in \
// RUN: --leverage-predicated-value \
// RUN: --fold-constant \
// RUN: --transform-ctrl-to-data-flow \
// RUN: --fold-constant \
// RUN: --fuse-pattern \
// RUN: --view-op-graph \
// RUN: --insert-data-mov %t-kernel.mlir -o %t-kernel_dataflow.mlir | FileCheck %s --check-prefix=CHECK-FUSED --input-file=%t-kernel_dataflow.mlir

// RUN: mlir-neura-opt --map-to-accelerator="mapping-strategy=heuristic backtrack-config=customized" %t-kernel_dataflow.mlir | FileCheck %s --check-prefix=CHECK-MAPPING

// CHECK-FUSED: func.func
// CHECK-FUSED: accelerator = "neura"
// CHECK-FUSED: %102 = neura.load_indexed %100[%101 : !neura.data<i64, i1>] !neura.data<!llvm.ptr, i1> : !neura.data<i32, i1>
// CHECK-FUSED: %33 = "neura.mul_add"(%30, %31, %32) : (i32, i32, i32) -> i32
// CHECK-FUSED: %42 = "neura.mul_add"(%39, %40, %41) : (i32, i32, i32) -> i32

// CHECK-MAPPING: mapping_info
// CHECK-MAPPING: mapping_mode = "spatial-temporal", mapping_strategy = "heuristic", rec_mii = 9 : i32, res_mii = 5 : i32, x_tiles = 4 : i32, y_tiles = 4 : i32
// CHECK-MAPPING: mapping_locs