Skip to content

Commit

Permalink
[onnx] Added flatten (#2760)
Browse files Browse the repository at this point in the history
  • Loading branch information
daveliddell and Dave Liddell authored Jan 20, 2024
1 parent b3a3ad4 commit 2f49240
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 0 deletions.
73 changes: 73 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,79 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, data, dimValueList);
return success();
});
patterns.onOp(
    "Flatten", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Flatten partitions the input tensor's dimensions into a
      // "left range" spanning [0, axis - 1] and a "right range"
      // spanning [axis, rank - 1]. Each range is then collapsed
      // into a single dimension, resulting in a 2-D tensor.
      // If either range is empty, it is replaced with a single
      // dimension of size 1.
      //
      // For example, for a 4-D input tensor of shape (a, b, c, d)
      // and axis==2, flatten produces a 2-D tensor of shape
      // (a*b, c*d).
      //
      // If instead axis==0, the left range is empty, and the result
      // is (1, a*b*c*d).

      Torch::ValueTensorType resultType;
      Value operand;
      int64_t axis;
      // axis defaults to 1 per the ONNX Flatten specification.
      if (binder.tensorOperand(operand) ||
          binder.s64IntegerAttr(axis, "axis", 1) ||
          binder.tensorResultType(resultType))
        return failure();

      // If axis is negative, count from the right instead of left.
      int64_t rank =
          cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
      if (axis < 0)
        axis = rank + axis;
      // The ONNX spec requires axis in [-rank, rank]. Reject anything
      // outside that range rather than emitting malformed IR (an
      // unsqueeze past the end, or a negative collapse start).
      if (axis < 0 || axis > rank)
        return rewriter.notifyMatchFailure(
            binder.op, "Flatten axis attribute out of range [-rank, rank]");

      Value collapsedRight;
      // Intermediate type: shape/dtype left unknown; only the final op
      // carries the fully-typed result.
      auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
          binder.op->getContext());

      if (axis >= rank) {
        // If the right range is empty, add a dim of size 1 to the
        // right side of the shape:
        // cr = torch.unsqueeze(x, x.ndim)
        Value rankConst = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(rank));
        collapsedRight = rewriter.create<Torch::AtenUnsqueezeOp>(
            binder.getLoc(), baseType, operand, rankConst);
      } else {
        // Otherwise, collapse the right range into a single dimension:
        // cr = torch._prims.collapse(x, axis, x.ndim - 1)
        Value axisConst = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(axis));
        Value rankLess1Const = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1));
        collapsedRight = rewriter.create<Torch::PrimsCollapseOp>(
            binder.getLoc(), baseType, operand, axisConst, rankLess1Const);
      }

      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(0));

      if (axis <= 0) {
        // If the left range is empty, add a dim of size 1 to the
        // left side of the shape:
        // torch.unsqueeze(cr, 0)
        rewriter.replaceOpWithNewOp<Torch::AtenUnsqueezeOp>(
            binder.op, resultType, collapsedRight, zero);
        return success();
      }

      // Otherwise, collapse the left range into a single dimension:
      // torch._prims.collapse(cr, 0, axis - 1)
      Value axisLess1Const = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(axis - 1));
      rewriter.replaceOpWithNewOp<Torch::PrimsCollapseOp>(
          binder.op, resultType, collapsedRight, zero, axisLess1Const);
      return success();
    });
patterns.onOp("Floor", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
113 changes: 113 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1062,3 +1062,116 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m
return %0 : !torch.vtensor<[2],si64>
}

// Interior axis: both left (dims 0-1) and right (dims 2-3) ranges are
// non-empty, so two collapse ops are expected: (2,3,4,5) -> (6,20).
// CHECK-LABEL: @test_flatten_4d_axis_2
func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
  return %0 : !torch.vtensor<[6,20],f32>
}

// axis==0: the left range is empty, so the lowering collapses all dims
// on the right and unsqueezes a leading 1: (2,3,4,5) -> (1,120).
// CHECK-LABEL: @test_flatten_4d_axis_0
func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
  // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
  return %0 : !torch.vtensor<[1,120],f32>
}

// axis==rank: the right range is empty, so a trailing 1 is unsqueezed
// and all original dims collapse on the left: (2,3,4,5) -> (120,1).
// CHECK-LABEL: @test_flatten_4d_axis_4
func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4
  // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32>
  return %0 : !torch.vtensor<[120,1],f32>
}

// Negative axis counts from the right: -2 normalizes to 2 for rank 4,
// so the expected IR is identical to the axis==2 case.
// CHECK-LABEL: @test_flatten_4d_axis_negative_2
func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32>
  return %0 : !torch.vtensor<[6,20],f32>
}

// axis==-1 normalizes to 3 for rank 4: the right range is the single
// last dim (collapse 3..3 is a no-op reshape), left collapses 0..2.
// CHECK-LABEL: @test_flatten_4d_axis_negative_1
func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32>
  return %0 : !torch.vtensor<[24,5],f32>
}

// axis==-rank normalizes to 0: empty left range, same lowering as the
// axis==0 case above: (2,3,4,5) -> (1,120).
// CHECK-LABEL: @test_flatten_4d_axis_negative_4
func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
  // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32>
  return %0 : !torch.vtensor<[1,120],f32>
}

// 2-D input with the default axis==1: both ranges are single dims, so
// both collapses are effectively identity and the shape is unchanged.
// CHECK-LABEL: @test_flatten_2d_axis_1
func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32>
  return %0 : !torch.vtensor<[2,3],f32>
}

// 1-D input, axis==0: empty left range, so a leading 1 is unsqueezed:
// (2) -> (1,2).
// CHECK-LABEL: @test_flatten_1d_axis_0
func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
  // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
  return %0 : !torch.vtensor<[1,2],f32>
}

// 1-D input, axis==-1 normalizes to 0: identical expected IR to the
// axis==0 case above.
// CHECK-LABEL: @test_flatten_1d_axis_negative_1
func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
  // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32>
  return %0 : !torch.vtensor<[1,2],f32>
}

// 1-D input, axis==rank: empty right range, so a trailing 1 is
// unsqueezed: (2) -> (2,1).
// NOTE(review): the label directive below is disabled via COM: while
// the other directives in this function remain live, so they attach to
// the previous test's scope — presumably intentional (test disabled?),
// but verify; otherwise drop the COM: prefix.
// COM: CHECK-LABEL: @test_flatten_1d_axis_1
func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1
  // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor
  // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
  // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
  // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
  %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32>
  return %0 : !torch.vtensor<[2,1],f32>
}

0 comments on commit 2f49240

Please sign in to comment.