Skip to content

Commit

Permalink
Change lowering of onnx.IF to Krnl (#2932)
Browse files Browse the repository at this point in the history
* implementation

Signed-off-by: chentong319 <[email protected]>

* test case change

Signed-off-by: chentong319 <[email protected]>

* format

Signed-off-by: chentong319 <[email protected]>

* add test for If back

Signed-off-by: chentong319 <[email protected]>

* format

Signed-off-by: chentong319 <[email protected]>

---------

Signed-off-by: chentong319 <[email protected]>
Co-authored-by: Tung D. Le <[email protected]>
  • Loading branch information
chentong319 and tungld authored Sep 6, 2024
1 parent ce4e041 commit c5d3e72
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_onnx_mlir_library(OMONNXToKrnl
ControlFlow/If.cpp
ControlFlow/Loop.cpp
ControlFlow/Scan.cpp
ControlFlow/Yield.cpp
ConvertONNXToKrnl.cpp
ML/CategoryMapper.cpp
Math/CumSum.cpp
Expand Down
8 changes: 0 additions & 8 deletions src/Conversion/ONNXToKrnl/ControlFlow/If.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,6 @@ struct ONNXIfOpLowering : public OpConversionPattern<ONNXIfOp> {

rewriter.eraseBlock(&scfBranch.back());
scfBranch.takeBody(graph);
rewriter.setInsertionPointToEnd(&scfBranch.back());

Operation *yieldOp = scfBranch.back().getTerminator();
llvm::SmallVector<Value> outputs;
if (failed(rewriter.getRemappedValues(yieldOp->getOperands(), outputs))) {
llvm_unreachable("failed to convert branch return values");
}
rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, outputs);
}
};

Expand Down
57 changes: 57 additions & 0 deletions src/Conversion/ONNXToKrnl/ControlFlow/Yield.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===--------------------- Yield.cpp - Lowering Yield Op ------------------===//
//
// Copyright 2019-2023 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Yield Operator to Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/SCF.h"

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

struct ONNXYieldOpLowering : public OpConversionPattern<ONNXYieldOp> {
ONNXYieldOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: OpConversionPattern(typeConverter, ctx) {}

LogicalResult matchAndRewrite(ONNXYieldOp yieldOp, ONNXYieldOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// Gather info.
Operation *op = yieldOp.getOperation();
Location loc = ONNXLoc<ONNXYieldOp>(op);

MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder> create(
rewriter, loc);

ValueRange inputs = yieldOp.getOperands();
llvm::SmallVector<Value> outputs;
for (Value input : inputs) {
Type inputType = input.getType();
Type outputType = typeConverter->convertType(inputType);
outputs.emplace_back(typeConverter->materializeTargetConversion(
rewriter, loc, outputType, input));
}

rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, outputs);

onnxToKrnlSimdReport(op);
return success();
}
};

void populateLoweringONNXYieldOpPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx) {
patterns.insert<ONNXYieldOpLowering>(typeConverter, ctx);
}

} // namespace onnx_mlir
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
populateLoweringONNXIfOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXLoopOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXScanOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXYieldOpPattern(patterns, typeConverter, ctx);
// Math
populateLoweringONNXCumSumOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXDFTOpPattern(patterns, typeConverter, ctx);
Expand Down
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ void populateLoweringONNXLoopOpPattern(
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXScanOpPattern(
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
void populateLoweringONNXYieldOpPattern(
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);

// `Math` directory methods:
void populateLoweringONNXClipOpPattern(
Expand Down
17 changes: 11 additions & 6 deletions test/mlir/conversion/onnx_to_krnl/ControlFlow/If.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@ func.func @test_if_simple(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i
onnx.Yield %arg2 : tensor<i64>
}) : (tensor<i1>) -> tensor<i64>
return %0 : tensor<i64>
// CHECK-LABEL: @test_if_simple
// CHECK-LABEL: func.func @test_if_simple
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<i1>, [[PARAM_1_:%.+]]: memref<i64>, [[PARAM_2_:%.+]]: memref<i64>) -> memref<i64> {
// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]][] : memref<i1>
// CHECK: [[VAR_1_:%.+]] = scf.if [[LOAD_PARAM_0_MEM_]] -> (memref<i64>) {
// CHECK: scf.yield [[PARAM_1_]] : memref<i64>
// CHECK-DAG: [[VAR_0_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_2_]] : memref<i64> to tensor<i64>
// CHECK-DAG: [[VAR_1_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : memref<i64> to tensor<i64>
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]][] : memref<i1>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_3_:%.+]] = scf.if [[LOAD_PARAM_0_MEM_]] -> (memref<i64>) {
// CHECK-DAG: [[VAR_4_:%.+]] = builtin.unrealized_conversion_cast [[VAR_1_]] : tensor<i64> to memref<i64>
// CHECK: scf.yield [[VAR_4_]] : memref<i64>
// CHECK: } else {
// CHECK: scf.yield [[PARAM_2_]] : memref<i64>
// CHECK: [[VAR_4_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_0_]] : tensor<i64> to memref<i64>
// CHECK: scf.yield [[VAR_4_1_]] : memref<i64>
// CHECK: }
// CHECK: return [[VAR_1_]] : memref<i64>
// CHECK: return [[VAR_3_]] : memref<i64>
// CHECK: }
}

2 changes: 0 additions & 2 deletions test/mlir/conversion/onnx_to_krnl/ControlFlow/lit.local.cfg

This file was deleted.

0 comments on commit c5d3e72

Please sign in to comment.