diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
index 7481e57a35..6d2f1dc5d6 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -1630,6 +1630,49 @@ struct RecomposeConcatPattern : public OpRewritePattern<ONNXConcatOp> {
   }
 };
 
+struct RemoveDimZeroInputInConcatPattern
+    : public OpRewritePattern<ONNXConcatOp> {
+  using OpRewritePattern<ONNXConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXConcatOp concatOp, PatternRewriter &rewriter) const final {
+    ValueRange inputs = concatOp.getOperands();
+    int64_t axis = concatOp.getAxis();
+
+    // Collect the indices of the inputs whose dim size at axis is zero.
+    SmallVector<int64_t> indices;
+    for (unsigned int i = 0; i < inputs.size(); ++i) {
+      Value inp = inputs[i];
+      if (!hasShapeAndRank(inp))
+        continue;
+      ArrayRef<int64_t> shape = getShape(inp.getType());
+      // A scalar has rank 0; its dim size is one (not zero).
+      if (shape.size() == 0)
+        continue;
+      // A negative axis counts from the back; normalize it before indexing.
+      int64_t ax = (axis < 0) ? axis + (int64_t)shape.size() : axis;
+      if (shape[ax] == 0)
+        indices.emplace_back(i);
+    }
+    if (indices.empty())
+      return rewriter.notifyMatchFailure(
+          concatOp, "No operand whose dim at axis is zero");
+    // Keep the op as is if every operand would be erased: a Concat
+    // without operands is invalid.
+    if (indices.size() == inputs.size())
+      return rewriter.notifyMatchFailure(
+          concatOp, "All operands have dim zero at axis");
+
+    // Rewrite: remove the operands whose dim at axis is zero. Erase from
+    // the back so the not-yet-erased indices remain valid.
+    rewriter.modifyOpInPlace(concatOp, [&]() {
+      for (int64_t idx : llvm::reverse(indices))
+        concatOp.getOperation()->eraseOperand(idx);
+    });
+    return success();
+  }
+};
+
 // =============================================================================
 // Rewrite pattern LayerNormalization
 // =============================================================================
@@ -2193,6 +2236,7 @@ void ONNXCastOp::getCanonicalizationPatterns(
 void ONNXConcatOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.insert<RecomposeConcatPattern>(context);
+  results.insert<RemoveDimZeroInputInConcatPattern>(context);
 }
 
 /// on the ONNXClipOp.
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 1714056fd7..ff31ad8950 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -2130,6 +2130,12 @@ func.func @test_remove_where_equal_4(%arg0: tensor<?x2xi64>) -> tensor<2xi64> {
 
 // -----
 
+// COM: Canonicalize ConcatOp.
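+// COM: RemoveDimZeroInputInConcatPattern drops inputs whose static dim at the
+// COM: concat axis is zero, since they contribute no elements to the result.
+// COM: Illustrative example: "onnx.Concat"(%a, %b) {axis = 1 : si64} with
+// COM: %b : tensor<1x0x4xf32> canonicalizes to just %a.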
+
 func.func @test_recompose_concat(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32> ) -> tensor<1x12x4xf32> {
   %0 = "onnx.Concat"(%arg0, %arg1) {axis = 1 : si64, onnx_node_name = "onnx.Concat_0"} : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x6x4xf32>
   %1 = "onnx.Concat"(%0, %arg0) {axis = 1 : si64, onnx_node_name = "onnx.Concat_1"} : (tensor<1x6x4xf32>, tensor<1x3x4xf32>) -> tensor<1x9x4xf32>
@@ -2145,6 +2151,45 @@ return %2 : tensor<1x12x4xf32>
 }
 
 // -----
+
+func.func @test_concat_remove_dim_0_operand_2_args(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x0x4xf32>) -> tensor<1x3x4xf32> {
+  %0 = "onnx.Concat"(%arg0, %arg1) {axis = 1 : si64} : (tensor<1x3x4xf32>, tensor<1x0x4xf32>) -> tensor<1x3x4xf32>
+  onnx.Return %0 : tensor<1x3x4xf32>
+
+// CHECK-LABEL: func.func @test_concat_remove_dim_0_operand_2_args
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4xf32>, [[PARAM_1_:%.+]]: tensor<1x0x4xf32>) -> tensor<1x3x4xf32> {
+// CHECK: onnx.Return [[PARAM_0_]] : tensor<1x3x4xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_concat_remove_dim_0_operand_3_args(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x0x4xf32>, %arg2: tensor<1x3x4xf32>) -> tensor<1x6x4xf32> {
+  %0 = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<1x3x4xf32>, tensor<1x0x4xf32>, tensor<1x3x4xf32>) -> tensor<1x6x4xf32>
+  onnx.Return %0 : tensor<1x6x4xf32>
+
+// CHECK-LABEL: func.func @test_concat_remove_dim_0_operand_3_args
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4xf32>, [[PARAM_1_:%.+]]: tensor<1x0x4xf32>, [[PARAM_2_:%.+]]: tensor<1x3x4xf32>) -> tensor<1x6x4xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Concat"([[PARAM_0_]], [[PARAM_2_]]) {axis = 1 : si64} : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x6x4xf32>
+// CHECK: onnx.Return [[VAR_0_]] : tensor<1x6x4xf32>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_concat_donot_remove_operand(%arg0: tensor<1x?x4xf32>, %arg1: tensor<1x?x4xf32>, %arg2: tensor<1x?x4xf32>) -> tensor<1x?x4xf32> {
+  %0 = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<1x?x4xf32>, tensor<1x?x4xf32>, tensor<1x?x4xf32>) -> tensor<1x?x4xf32>
+  onnx.Return %0 : tensor<1x?x4xf32>
+
+// CHECK-LABEL: func.func @test_concat_donot_remove_operand
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x?x4xf32>, [[PARAM_1_:%.+]]: tensor<1x?x4xf32>, [[PARAM_2_:%.+]]: tensor<1x?x4xf32>) -> tensor<1x?x4xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Concat"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64} : (tensor<1x?x4xf32>, tensor<1x?x4xf32>, tensor<1x?x4xf32>) -> tensor<1x?x4xf32>
+// CHECK: onnx.Return [[VAR_0_]] : tensor<1x?x4xf32>
+// CHECK: }
+}
+
+// -----
+
 func.func @test_split_relu_movement(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
   %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
   %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
@@ -2162,6 +2207,7 @@ func.func @test_split_relu_movement(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf
 // CHECK: }
 
 // -----
+
 func.func @test_split_relu_movement_not_all_equal(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
   %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
   %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
@@ -2181,6 +2227,7 @@ func.func @test_split_relu_movement_not_all_equal(%arg0: tensor<1x8x2xf32>) -> (
 // CHECK: }
 
 // -----
+
 func.func @test_split_leakyrelu_movement(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
   %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
   %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)
@@ -2198,6 +2245,7 @@ func.func @test_split_leakyrelu_movement(%arg0: tensor<1x
 // CHECK: }
 
 // -----
+
 func.func @test_split_leakyrelu_movement_different_alpha(%arg0: tensor<1x8x2xf32>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>) {
   %cst = onnx.Constant dense<[2, 3, 3]> : tensor<3xi64>
   %0:3 = "onnx.Split"(%arg0, %cst) {axis = 1 : si64} : (tensor<1x8x2xf32>, tensor<3xi64>) -> (tensor<1x2x2xf32>, tensor<1x3x2xf32>, tensor<1x3x2xf32>)