diff --git a/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp b/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp
index 37679077..93a34387 100644
--- a/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp
+++ b/lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp
@@ -28,6 +28,20 @@ using namespace mlir::neura;
 
 namespace{
 
+struct ArithFMulToNeuraFMul : public OpRewritePattern<arith::MulFOp> {
+  using OpRewritePattern<arith::MulFOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::MulFOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+    Type resultType = op.getType();
+
+    rewriter.replaceOpWithNewOp<neura::FMulOp>(op, resultType, lhs, rhs, Value());
+    return success();
+  }
+};
+
 struct LowerArithToNeuraPass
     : public PassWrapper<LowerArithToNeuraPass, OperationPass<ModuleOp>> {
 
@@ -44,6 +58,7 @@ struct LowerArithToNeuraPass
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
+    patterns.add<ArithFMulToNeuraFMul>(&getContext());
     mlir::neura::arith2neura::populateWithGenerated(patterns);
     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
       signalPassFailure();
diff --git a/lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td b/lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td
index 6004cf96..dae54825 100644
--- a/lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td
+++ b/lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td
@@ -14,6 +14,11 @@
   (Neura_FSubOp $lhs, $rhs)
 >;
 
+def : Pat<
+  (LLVM_FMulOp $lhs, $rhs, $_fastmath),
+  (Neura_FMulOp $lhs, $rhs)
+>;
+
 def : Pat<
   (LLVM_ConstantOp $value),
   (Neura_ConstantOp $value)
diff --git a/test/neura/interpreter/interpreter.mlir b/test/neura/interpreter/interpreter.mlir
index c5e163e7..8390b766 100644
--- a/test/neura/interpreter/interpreter.mlir
+++ b/test/neura/interpreter/interpreter.mlir
@@ -20,4 +20,14 @@ module {
     return %2 : f32
     // CHECK: 7.0
   }
+
+  func.func @test_mul() -> f32 {
+    %arg0 = arith.constant 9.0 : f32
+    %cst = arith.constant 2.0 : f32
+    %0 = "neura.mov"(%arg0) : (f32) -> f32
+    %1 = "neura.mov"(%cst) : (f32) -> f32
+    %2 = "neura.fmul"(%0, %1) : (f32, f32) -> f32
+    return %2 : f32
+    // CHECK: 18.0
+  }
 }
diff --git a/test/neura/interpreter/multi_ops.mlir b/test/neura/interpreter/multi_ops.mlir
new file mode 100644
index 00000000..8c9ddaed
--- /dev/null
+++ b/test/neura/interpreter/multi_ops.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s \
+// RUN: --convert-scf-to-cf \
+// RUN: --convert-math-to-llvm \
+// RUN: --convert-arith-to-llvm \
+// RUN: --convert-func-to-llvm \
+// RUN: --convert-cf-to-llvm \
+// RUN: --reconcile-unrealized-casts \
+// RUN: -o %t-lowered-to-llvm.mlir
+
+// RUN: mlir-translate -mlir-to-llvmir \
+// RUN: %t-lowered-to-llvm.mlir \
+// RUN: -o %t-lower_and_interpreter.ll
+
+// RUN: llc %t-lower_and_interpreter.ll \
+// RUN: -filetype=obj -o %t-out.o
+
+// RUN: clang++ main.cpp %t-out.o \
+// RUN: -o %t-out.bin
+
+// RUN: %t-out.bin > %t-dumped_output.txt
+
+// RUN: mlir-neura-opt --lower-arith-to-neura --insert-mov %s \
+// RUN: -o %t-neura.mlir
+
+// RUN: neura-interpreter %t-neura.mlir >> %t-dumped_output.txt
+// RUN: FileCheck %s < %t-dumped_output.txt
+
+// RUN: FileCheck %s -check-prefix=GOLDEN < %t-dumped_output.txt
+// GOLDEN: 7.0
+
+module {
+  func.func @test() -> f32 attributes { llvm.emit_c_interface }{
+    %arg0 = arith.constant 9.0 : f32
+    %cst = arith.constant 2.0 : f32
+    %0 = arith.subf %arg0, %cst : f32
+    %1 = arith.mulf %arg0, %0 : f32
+    // CHECK: Golden output: [[OUTPUT:[0-9]+\.[0-9]+]]
+    // CHECK: [neura-interpreter] Output: [[OUTPUT]]
+    return %1 : f32
+  }
+}
+
diff --git a/tools/neura-interpreter/neura-interpreter.cpp b/tools/neura-interpreter/neura-interpreter.cpp
index c12b58d0..14c10062 100644
--- a/tools/neura-interpreter/neura-interpreter.cpp
+++ b/tools/neura-interpreter/neura-interpreter.cpp
@@ -77,7 +77,12 @@ int main(int argc, char **argv) {
         float lhs = valueMap[fsubOp.getLhs()];
         float rhs = valueMap[fsubOp.getRhs()];
         valueMap[fsubOp.getResult()] = lhs - rhs;
-      } else if (auto retOp = dyn_cast<func::ReturnOp>(op)) {
+      } else if (auto fmulOp = dyn_cast<neura::FMulOp>(op)) {
+        float lhs = valueMap[fmulOp.getLhs()];
+        float rhs = valueMap[fmulOp.getRhs()];
+        valueMap[fmulOp.getResult()] = lhs * rhs;
+      }
+      else if (auto retOp = dyn_cast<func::ReturnOp>(op)) {
         float result = valueMap[retOp.getOperand(0)];
         llvm::outs() << "[neura-interpreter] Output: " << llvm::format("%.6f", result) << "\n";
       } else {