77 changes: 70 additions & 7 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1875,23 +1875,86 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, "Expected input type having sizes");
}
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {}))
return failure();
if (!outputShape.empty() && outputShape.size() != rank - 2) {
return rewriter.notifyMatchFailure(
binder.op,
"output_shape list size does not match the number of axes");
}

auto inferOutputPaddingFromOutputShape =
[&]() -> FailureOr<SmallVector<int64_t>> {
if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
return rewriter.notifyMatchFailure(
binder.op,
"padding list size does not match the number of axes");
}
SmallVector<int64_t> inferredOutputPadding;
inferredOutputPadding.reserve(rank - 2);
for (unsigned i = 0; i < rank - 2; i++) {
// ONNX pads are laid out as [x1_begin, ..., xN_begin, x1_end,
// ..., xN_end] when fully specified, or as a per-axis symmetric
// value when half-sized.
int64_t totalPadding = padding.size() == 2 * (rank - 2)
? padding[i] + padding[i + rank - 2]
: 2 * padding[i];
int64_t inferredDim = strides[i] * (inputShape[2 + i] - 1) -
totalPadding +
((kernelShape[i] - 1) * dilations[i] + 1);
int64_t inferredOutputPaddingValue = outputShape[i] - inferredDim;
if (inferredOutputPaddingValue < 0) {
return rewriter.notifyMatchFailure(
binder.op,
"output_shape would require a negative output_padding, "
"which is not supported");
}
if (inferredOutputPaddingValue >= strides[i]) {
return rewriter.notifyMatchFailure(
binder.op,
"output_shape would require output_padding >= stride, "
"which violates the ONNX ConvTranspose specification");
}
inferredOutputPadding.push_back(inferredOutputPaddingValue);
}
return inferredOutputPadding;
};
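
            // Worked example (illustrative, using the values from the tests
            // below): with a 3x3 input, 3x3 kernel, strides [3, 2], zero
            // padding, and output_shape [10, 8], axis 0 infers
            // 3*(3-1) + 3 = 9, so output_padding[0] = 10 - 9 = 1; axis 1
            // infers 2*(3-1) + 3 = 7, so output_padding[1] = 8 - 7 = 1.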

if (!outputShape.empty() && outputPadding != defaultOutputPadding) {
return rewriter.notifyMatchFailure(
binder.op,
"output_shape with explicit output_padding is not supported");
}

if (autoPad == "VALID") {
// Zero padding.
padding = defaultPadding;
if (!outputShape.empty()) {
FailureOr<SmallVector<int64_t>> inferredOutputPadding =
inferOutputPaddingFromOutputShape();
if (failed(inferredOutputPadding))
return failure();
outputPadding = *inferredOutputPadding;
}
} else if (autoPad == "NOTSET") {
// Explicit padding; read pads with defaults.
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding))
return failure();
if (!outputShape.empty()) {
FailureOr<SmallVector<int64_t>> inferredOutputPadding =
inferOutputPaddingFromOutputShape();
if (failed(inferredOutputPadding))
return failure();
outputPadding = *inferredOutputPadding;
}
} else { // autopad == SAME_UPPER or SAME_LOWER
  // Auto-padding. When output_shape is not specified, default it to
  // input_shape * strides.
  if (outputShape.empty()) {
    for (unsigned i = 0; i < rank - 2; i++) {
      outputShape.push_back(inputShape[2 + i] * strides[i]);
    }
  }
SmallVector<int64_t> paddingEnd;
for (unsigned i = 0; i < rank - 2; i++) {
int64_t totalPadding =
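
For intuition, the inference above can be exercised in isolation. Below is a minimal standalone sketch of the same ONNX ConvTranspose shape relation; the function name inferOutputPadding, the std::optional failure signaling, and the driver values are illustrative assumptions, not part of this patch.

// Standalone sketch (illustrative): infer ONNX ConvTranspose output_padding
// from a requested output_shape, mirroring the lambda in the hunk above.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Returns std::nullopt where the lowering above would emit a match failure.
std::optional<std::vector<int64_t>>
inferOutputPadding(const std::vector<int64_t> &inputSpatial,
                   const std::vector<int64_t> &kernelShape,
                   const std::vector<int64_t> &strides,
                   const std::vector<int64_t> &dilations,
                   // pads: [x1_begin, ..., xN_begin, x1_end, ..., xN_end],
                   // or one symmetric value per axis when half-sized.
                   const std::vector<int64_t> &pads,
                   const std::vector<int64_t> &outputShape) {
  size_t numAxes = inputSpatial.size();
  std::vector<int64_t> outputPadding;
  for (size_t i = 0; i < numAxes; ++i) {
    int64_t totalPadding = pads.size() == 2 * numAxes
                               ? pads[i] + pads[i + numAxes]
                               : 2 * pads[i];
    // ONNX: out = stride * (in - 1) - totalPad + (k - 1) * d + 1 + outPad
    int64_t inferredDim = strides[i] * (inputSpatial[i] - 1) - totalPadding +
                          ((kernelShape[i] - 1) * dilations[i] + 1);
    int64_t outPad = outputShape[i] - inferredDim;
    // Reject negative or spec-violating (>= stride) output_padding.
    if (outPad < 0 || outPad >= strides[i])
      return std::nullopt;
    outputPadding.push_back(outPad);
  }
  return outputPadding;
}

int main() {
  // Values from test_convtranspose_output_shape below: 3x3 input, 3x3 kernel,
  // strides [3, 2], zero pads, unit dilations, requested output_shape [10, 8].
  auto result = inferOutputPadding({3, 3}, {3, 3}, {3, 2}, {1, 1},
                                   {0, 0, 0, 0}, {10, 8});
  if (result)
    std::cout << (*result)[0] << ", " << (*result)[1] << "\n"; // prints: 1, 1
}
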
17 changes: 17 additions & 0 deletions test/Conversion/TorchOnnxToTorch/convtranspose_diagnostics.mlir
@@ -0,0 +1,17 @@
// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch

func.func @test_convtranspose_output_shape_with_explicit_output_padding(
%arg0: !torch.vtensor<[1,1,3,3],f32>,
%arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32>
attributes {torch.onnx_meta.ir_version = 10 : si64,
torch.onnx_meta.opset_version = 22 : si64,
torch.onnx_meta.producer_name = "backend-test",
torch.onnx_meta.producer_version = ""} {
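  // Combining output_shape with an explicit output_padding is rejected by
  // the lowering above, so the conversion fails to legalize.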
// expected-error @below {{failed to legalize operation 'torch.operator' that was explicitly marked illegal}}
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {
torch.onnx.output_padding = [1 : si64, 1 : si64],
torch.onnx.output_shape = [10 : si64, 8 : si64],
torch.onnx.strides = [3 : si64, 2 : si64]
} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32>
return %0 : !torch.vtensor<[1,2,10,8],f32>
}
66 changes: 66 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -1296,6 +1296,72 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc

// -----

// CHECK-LABEL: @test_convtranspose_output_shape_autopad_valid
func.func @test_convtranspose_output_shape_autopad_valid(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,9,9],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
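  // Inferred output_padding arithmetic (illustrative): with VALID auto_pad,
  // stride 2, kernel 4, dilation 1, each inferred dim is 2*(3-1) + 4 = 8,
  // so output_shape 9 implies output_padding 1 per axis.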
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.convolution %arg0, %arg1, {{.*}}, %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], {{.*}}, %[[OUTPUT_PADDING]], {{.*}} -> !torch.vtensor<[1,2,9,9],f32>
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="VALID", torch.onnx.output_shape = [9 : si64, 9 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,9,9],f32>
return %0 : !torch.vtensor<[1,2,9,9],f32>
}

// -----

// CHECK-LABEL: @test_convtranspose_output_shape
func.func @test_convtranspose_output_shape(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
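  // Inferred output_padding arithmetic (illustrative): axis 0 infers
  // 3*(3-1) + 3 = 9, so 10 - 9 = 1; axis 1 infers 2*(3-1) + 3 = 7, so
  // 8 - 7 = 1.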
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
// CHECK: %[[BIAS:.*]] = torch.constant.none
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
// CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,10,8],f32>
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.output_shape = [10 : si64, 8 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32>
return %0 : !torch.vtensor<[1,2,10,8],f32>
}

// -----

// CHECK-LABEL: @test_convtranspose_output_shape_with_pads
func.func @test_convtranspose_output_shape_with_pads(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,8,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
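  // Inferred output_padding arithmetic (illustrative), with pads
  // [1, 1, 1, 1]: axis 0 infers 3*(3-1) - 2 + 3 = 7, so 8 - 7 = 1; axis 1
  // infers 2*(3-1) - 2 + 3 = 5, so 5 - 5 = 0.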
// CHECK: %[[PADV0:.*]] = torch.constant.int 1
// CHECK: %[[PADV1:.*]] = torch.constant.int 1
// CHECK: %[[DILV0:.*]] = torch.constant.int 1
// CHECK: %[[DILV1:.*]] = torch.constant.int 1
// CHECK: %[[STRV0:.*]] = torch.constant.int 3
// CHECK: %[[STRV1:.*]] = torch.constant.int 2
// CHECK: %[[OPADV0:.*]] = torch.constant.int 1
// CHECK: %[[OPADV1:.*]] = torch.constant.int 0
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[PADV0]], %[[PADV1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[DILV0]], %[[DILV1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[STRV0]], %[[STRV1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[OPADV0]], %[[OPADV1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.convolution %arg0, %arg1, {{.*}}, %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], {{.*}}, %[[OUTPUT_PADDING]], {{.*}} -> !torch.vtensor<[1,2,8,5],f32>
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.output_shape = [8 : si64, 5 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,8,5],f32>
return %0 : !torch.vtensor<[1,2,8,5],f32>
}

// -----

// CHECK-LABEL: @test_convtranspose_pad
func.func @test_convtranspose_pad(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0