diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 8976fe4c2b35..3cacd78a2309 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6422,51 +6422,52 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
   }];
 }
 
-def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
+def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
     AllowsTypeRefinement,
+    HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
+  let summary = "Generated op for `aten::pixel_shuffle : (Tensor, int) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchListOfTorchIntType:$dims
+    Torch_IntType:$upscale_factor
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenPixelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 2, 1);
     }
-    void AtenPermuteOp::print(OpAsmPrinter &printer) {
+    void AtenPixelShuffleOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
 }
 
-def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
+def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
     AllowsTypeRefinement,
-    HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::pixel_shuffle : (Tensor, int) -> (Tensor)`";
+  let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    Torch_IntType:$upscale_factor
+    AnyTorchListOfTorchIntType:$dims
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenPixelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 2, 1);
     }
-    void AtenPixelShuffleOp::print(OpAsmPrinter &printer) {
+    void AtenPermuteOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasVerifier = 1;
 }
 
 def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 86342f279548..c7dc571b6d9d 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2859,6 +2859,96 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
   return success();
 }
 
+LogicalResult AtenPermuteOp::verify() {
+
+  // Verification of the permute op for input & output dimensions with
+  // statically known sizes.
+
+  SmallVector<Value> permutation;
+  auto permutationObtained = getListConstructElements(getDims(), permutation);
+  if (!permutationObtained) {
+    return success();
+  }
+
+  auto outType = getResult().getType().cast<BaseTensorType>();
+  auto inType = getSelf().getType().cast<BaseTensorType>();
+
+  if (!outType.hasSizes() || !inType.hasSizes()) {
+    return success();
+  }
+
+  auto outShape = outType.getSizes();
+  auto inShape = inType.getSizes();
+
+  auto outRank = outShape.size();
+
+  if (outRank != inShape.size()) {
+    return emitOpError(
+               "expected input and output tensors to have same rank, but ")
+           << inShape.size() << " != " << outRank << '.';
+  }
+
+  if (outRank != permutation.size()) {
+    return emitOpError() << "expected permutation to have size equal to the "
+                            "result tensor rank. The permutation has "
+                         << permutation.size()
+                         << " elements, the output has rank " << outRank << '.';
+  }
+
+
+  // Initialization of the reverse permutation. -1 denotes an unknown
+  // permutation index.
+  SmallVector<int64_t> reversePermutation(outRank, -1);
+
+  // In this loop:
+  // (1) check that the permutation indices are in bounds, and not duplicated.
+  // (2) populate reversePermutation (to check for duplicates).
+  // (3) check that the input and output shapes agree with the permutation. For
+  //     example, if the permutation is (1,2,0) and the input shape is (2,3,5),
+  //     then the output shape must be (3,5,2).
+
+  for (uint64_t to = 0; to < outRank; ++to) {
+    int64_t from;
+
+    auto fromIsSet = matchPattern(permutation[to], m_TorchConstantInt(&from));
+
+    if (!fromIsSet) {
+      continue;
+    }
+
+    // If 'from' is the unknown index, continue.
+    if (from == -1) {
+      continue;
+    }
+
+    if (!isValidDim(from, outRank)) {
+      return emitError("observed invalid index in permutation (")
+             << from << ") for input tensor of rank " << outRank << '.';
+    }
+
+    if (reversePermutation[from] != -1) {
+      return emitOpError("has a duplicate dimension (")
+             << from << ") in its permutation " << getDims() << '.';
+    }
+    reversePermutation[from] = to;
+
+    auto dimSizesDefined =
+        inShape[from] != kUnknownSize && outShape[to] != kUnknownSize;
+    auto dimSizesDifferent = inShape[from] != outShape[to];
+
+    if (dimSizesDefined && dimSizesDifferent) {
+      return emitOpError("has a permutation which is not compatible with the "
+                         "input and output shapes. ")
+             << "The input shape in dimension " << from << " is "
+             << inShape[from] << ", and the output shape in dimension " << to
+             << " is " << outShape[to]
+             << " : they should be the same with this permutation. ";
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DtypeCalculateYieldDtypesOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 4c114eda8e08..ff78d463a6e0 100644
--- a/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -114,7 +114,7 @@ def _get_main_module_name() -> str:
 def raw_emit_op(operator: JitOperator,
                 emitter_td: TextEmitter, *, traits: List[str],
-                has_folder: bool, has_canonicalizer: bool):
+                has_folder: bool, has_canonicalizer: bool, has_verifier: bool):
     """Emit the ODS for a JitOperator to a textual file.
 
     This is the lowest level of emission and is responsible for low-level
@@ -199,6 +199,8 @@ def generic_result_name(i):
         p_td("let hasFolder = 1;")
     if has_canonicalizer:
         p_td("let hasCanonicalizer = 1;")
+    if has_verifier:
+        p_td("let hasVerifier = 1;")
     p_td("}")
     p_td("\n")
@@ -208,7 +210,8 @@ def emit_op(operator: JitOperator,
             *,
             traits: Optional[List[str]] = None,
             has_folder: bool = False,
-            has_canonicalizer: bool = False):
+            has_canonicalizer: bool = False,
+            has_verifier: bool = False):
     """Main entry point for op emission.
 
     Besides emitting the op, it deduces / adds traits based on the operator
@@ -228,7 +231,8 @@ def emit_op(operator: JitOperator,
                 emitter_td,
                 traits=traits,
                 has_folder=has_folder,
-                has_canonicalizer=has_canonicalizer)
+                has_canonicalizer=has_canonicalizer,
+                has_verifier=has_verifier)
 
 
 def emit_ops(emitter_td: TextEmitter, registry: Registry):
@@ -481,8 +485,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)")
     emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
     emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
-    emit("aten::permute : (Tensor, int[]) -> (Tensor)")
     emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
+    emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
     emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
     emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir
index f22d5b785746..067c1a9b67f4 100644
--- a/test/Dialect/Torch/invalid.mlir
+++ b/test/Dialect/Torch/invalid.mlir
@@ -281,3 +281,84 @@ func.func @torch.tensor_static_info_cast$dtype_mismatch(%arg0: !torch.vtensor<*,
   %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<*,f64>
   return %0 : !torch.vtensor<*,f64>
 }
+
+
+// -----
+
+func.func @torch.permute$test_changing_rank (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
+
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+
+  %perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+
+  // expected-error@+1 {{expected input and output tensors to have same rank, but 3 != 4}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
+
+  return %3 : !torch.vtensor<[1,2,3,4],f32>
+}
+
+// -----
+
+func.func @torch.permute$test_permutation_too_short (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
+
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+
+  %perm = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+
+  // expected-error@+1 {{The permutation has 2 elements, the output has rank 3}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
+
+  return %3 : !torch.vtensor<[1,2,3],f32>
+}
+
+// -----
+
+func.func @torch.permute$duplicate_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[2,3,1],f32> {
+
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %perm = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+
+  // expected-error@+1 {{'torch.aten.permute' op has a duplicate dimension (1) in its permutation}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1],f32>
+
+  return %3 : !torch.vtensor<[2,3,1],f32>
+}
+
+// -----
+
+func.func @torch.permute$incorrect_output_shape (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[3,1,2],f32> {
+
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %none = torch.constant.none
+
+  %perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+
+  // expected-error@+1 {{'torch.aten.permute' op has a permutation which is not compatible with the input and output shapes. The input shape in dimension 1 is 2, and the output shape in dimension 0 is 3 : they should be the same with this permutation.}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[3,1,2],f32>
+
+  return %3 : !torch.vtensor<[3,1,2],f32>
+}
+
+
+// -----
+
+func.func @torch.permute$invalid_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
+
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int7 = torch.constant.int 7
+  %perm = torch.prim.ListConstruct %int0, %int1, %int7 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+
+
+  // expected-error@+1 {{observed invalid index in permutation (7) for input tensor of rank 3.}}
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
+
+  return %3 : !torch.vtensor<[1,2,3],f32>
+}
+
diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir
index 178db4fa1da6..d8d0fc33098b 100644
--- a/test/Dialect/Torch/ops.mlir
+++ b/test/Dialect/Torch/ops.mlir
@@ -170,3 +170,14 @@ func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,5
   %arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor>
   return %arg2 : !torch.list<vtensor>
 }
+
+// Check that verification passes with '-1' as a permutation index.
+func.func @torch.permute$negative_index_valid (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
+  %intm1 = torch.constant.int -1
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %perm = torch.prim.ListConstruct %int0, %int1, %intm1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
+  return %3 : !torch.vtensor<[1,2,3],f32>
+}
+
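
Reviewer note: below is a minimal Python sketch of the rules the new AtenPermuteOp::verify enforces, for reference when reading the C++ above. It is not part of the patch; check_permute, its signature, and the use of -1 for unknown sizes are illustrative assumptions of this sketch, not torch-mlir API.

from typing import List, Optional

UNKNOWN = -1  # stands in for kUnknownSize and the 'unknown' permutation index


def check_permute(in_shape: List[int], out_shape: List[int],
                  perm: List[int]) -> Optional[str]:
    """Return an error message if the permutation is invalid, else None."""
    rank = len(out_shape)
    if len(in_shape) != rank:
        return (f"expected input and output tensors to have same rank, "
                f"but {len(in_shape)} != {rank}")
    if len(perm) != rank:
        return (f"the permutation has {len(perm)} elements, "
                f"the output has rank {rank}")
    seen = [False] * rank  # plays the role of reversePermutation
    for to, frm in enumerate(perm):
        if frm == UNKNOWN:  # unknown index: nothing to check
            continue
        if not 0 <= frm < rank:
            return f"observed invalid index in permutation ({frm})"
        if seen[frm]:
            return f"duplicate dimension ({frm}) in permutation"
        seen[frm] = True
        # Statically known sizes must satisfy out_shape[to] == in_shape[frm].
        if (UNKNOWN not in (in_shape[frm], out_shape[to])
                and in_shape[frm] != out_shape[to]):
            return (f"input shape in dimension {frm} is {in_shape[frm]}, but "
                    f"output shape in dimension {to} is {out_shape[to]}")
    return None


# e.g. permutation (1,2,0) maps input shape (2,3,5) to output shape (3,5,2).
assert check_permute([2, 3, 5], [3, 5, 2], [1, 2, 0]) is None
assert check_permute([1, 2, 3], [2, 3, 1], [1, 2, 1]) is not None  # duplicate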