diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 4ae43dcfa505..647aab32fc37 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1875,23 +1875,89 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "Expected input type having sizes"); } ArrayRef inputShape = inputTensorType.getSizes(); + if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {})) + return failure(); + if (!outputShape.empty() && outputShape.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "output_shape list size does not match the number of axes"); + } + + auto inferOutputPaddingFromOutputShape = + [&]() -> FailureOr> { + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, + "padding list size does not match the number of axes"); + } + bool isPerAxisPadded = padding.size() == rank - 2; + SmallVector inferredOutputPadding; + inferredOutputPadding.reserve(rank - 2); + for (unsigned i = 0; i < rank - 2; i++) { + // ONNX pads are laid out as [x1_begin, ..., xN_begin, x1_end, + // ..., xN_end] when fully specified, or as a per-axis symmetric + // value when half-sized. + int64_t totalPadding = isPerAxisPadded + ? 2 * padding[i] + : padding[i] + padding[i + rank - 2]; + int64_t inferredDim = strides[i] * (inputShape[2 + i] - 1) - + totalPadding + + ((kernelShape[i] - 1) * dilations[i] + 1); + int64_t inferredOutputPaddingValue = outputShape[i] - inferredDim; + if (inferredOutputPaddingValue < 0) { + return rewriter.notifyMatchFailure( + binder.op, + "output_shape would require a negative output_padding"); + } + if (inferredOutputPaddingValue >= strides[i]) { + return rewriter.notifyMatchFailure( + binder.op, + "output_shape would require output_padding >= stride, " + "which violates the ONNX ConvTranspose specification"); + } + inferredOutputPadding.push_back(inferredOutputPaddingValue); + } + return inferredOutputPadding; + }; + + auto applyOutputPaddingFromOutputShape = [&]() -> LogicalResult { + FailureOr> inferredOutputPadding = + inferOutputPaddingFromOutputShape(); + if (failed(inferredOutputPadding)) + return failure(); + if (outputPadding != defaultOutputPadding && + outputPadding != *inferredOutputPadding) { + return rewriter.notifyMatchFailure( + binder.op, "output_shape and output_padding imply different " + "output_padding values"); + } + outputPadding = *inferredOutputPadding; + return success(); + }; if (autoPad == "VALID") { // Zero padding. padding = defaultPadding; + if (!outputShape.empty()) { + if (failed(applyOutputPaddingFromOutputShape())) + return failure(); + } } else if (autoPad == "NOTSET") { // Explicit padding; read pads with defaults. if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) return failure(); + if (!outputShape.empty()) { + if (failed(applyOutputPaddingFromOutputShape())) + return failure(); + } } else { // autopad == SAME_UPPER or SAME_LOWER - // Auto-padding; output_shape defaults to input_shape * strides. - SmallVector defaultOutputShape; - for (unsigned i = 0; i < rank - 2; i++) { - defaultOutputShape.push_back(inputShape[2 + i] * strides[i]); + // Auto-padding. When output_shape is not specified, default it to + // input_shape * strides. + if (outputShape.empty()) { + for (unsigned i = 0; i < rank - 2; i++) { + outputShape.push_back(inputShape[2 + i] * strides[i]); + } } - if (binder.s64IntegerArrayAttr(outputShape, "output_shape", - defaultOutputShape)) - return failure(); SmallVector paddingEnd; for (unsigned i = 0; i < rank - 2; i++) { int64_t totalPadding = diff --git a/test/Conversion/TorchOnnxToTorch/convtranspose_diagnostics.mlir b/test/Conversion/TorchOnnxToTorch/convtranspose_diagnostics.mlir new file mode 100644 index 000000000000..c376e74e9e15 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/convtranspose_diagnostics.mlir @@ -0,0 +1,17 @@ +// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch + +func.func @test_convtranspose_output_shape_with_conflicting_output_padding( + %arg0: !torch.vtensor<[1,1,3,3],f32>, + %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> + attributes {torch.onnx_meta.ir_version = 10 : si64, + torch.onnx_meta.opset_version = 22 : si64, + torch.onnx_meta.producer_name = "backend-test", + torch.onnx_meta.producer_version = ""} { + // expected-error @below {{failed to legalize operation 'torch.operator' that was explicitly marked illegal}} + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) { + torch.onnx.output_padding = [0 : si64, 1 : si64], + torch.onnx.output_shape = [10 : si64, 8 : si64], + torch.onnx.strides = [3 : si64, 2 : si64] + } : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> + return %0 : !torch.vtensor<[1,2,10,8],f32> +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index ce5bb23ae466..f748e44937ef 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1296,6 +1296,93 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc // ----- +// CHECK-LABEL: @test_convtranspose_output_shape_autopad_valid + func.func @test_convtranspose_output_shape_autopad_valid(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,9,9],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.convolution %arg0, %arg1, {{.*}}, %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], {{.*}}, %[[OUTPUT_PADDING]], {{.*}} -> !torch.vtensor<[1,2,9,9],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="VALID", torch.onnx.output_shape = [9 : si64, 9 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,9,9],f32> + return %0 : !torch.vtensor<[1,2,9,9],f32> + } + +// ----- + +// CHECK-LABEL: @test_convtranspose_output_shape + func.func @test_convtranspose_output_shape(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,10,8],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.output_shape = [10 : si64, 8 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> + return %0 : !torch.vtensor<[1,2,10,8],f32> + } + +// ----- + +// CHECK-LABEL: @test_convtranspose_output_shape_with_output_padding + func.func @test_convtranspose_output_shape_with_output_padding(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[PADV0:.*]] = torch.constant.int 0 + // CHECK: %[[PADV1:.*]] = torch.constant.int 0 + // CHECK: %[[DILV0:.*]] = torch.constant.int 1 + // CHECK: %[[DILV1:.*]] = torch.constant.int 1 + // CHECK: %[[STRV0:.*]] = torch.constant.int 3 + // CHECK: %[[STRV1:.*]] = torch.constant.int 2 + // CHECK: %[[OPADV0:.*]] = torch.constant.int 1 + // CHECK: %[[OPADV1:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[PADV0]], %[[PADV1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[DILV0]], %[[DILV1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[STRV0]], %[[STRV1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[OPADV0]], %[[OPADV1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.convolution %arg0, %arg1, {{.*}}, %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], {{.*}}, %[[OUTPUT_PADDING]], {{.*}} -> !torch.vtensor<[1,2,10,8],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.output_padding = [1 : si64, 1 : si64], torch.onnx.output_shape = [10 : si64, 8 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> + return %0 : !torch.vtensor<[1,2,10,8],f32> + } + +// ----- + +// CHECK-LABEL: @test_convtranspose_output_shape_with_pads + func.func @test_convtranspose_output_shape_with_pads(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,8,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[PADV0:.*]] = torch.constant.int 1 + // CHECK: %[[PADV1:.*]] = torch.constant.int 1 + // CHECK: %[[DILV0:.*]] = torch.constant.int 1 + // CHECK: %[[DILV1:.*]] = torch.constant.int 1 + // CHECK: %[[STRV0:.*]] = torch.constant.int 3 + // CHECK: %[[STRV1:.*]] = torch.constant.int 2 + // CHECK: %[[OPADV0:.*]] = torch.constant.int 1 + // CHECK: %[[OPADV1:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[PADV0]], %[[PADV1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[DILV0]], %[[DILV1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[STRV0]], %[[STRV1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[OPADV0]], %[[OPADV1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.convolution %arg0, %arg1, {{.*}}, %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], {{.*}}, %[[OUTPUT_PADDING]], {{.*}} -> !torch.vtensor<[1,2,8,5],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.output_shape = [8 : si64, 5 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,8,5],f32> + return %0 : !torch.vtensor<[1,2,8,5],f32> + } + +// ----- + // CHECK-LABEL: @test_convtranspose_pad func.func @test_convtranspose_pad(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0