Skip to content

Commit

Permalink
Add verification for torch permute op (#2551)
Browse files Browse the repository at this point in the history
- adds support for an optional verifier to the generated torch op
tablegen (GeneratedTorchOps.td)
- uses the above to add a verifier for the torch permute op. 

Motivation: I hit an unclear error from linalg while developing a
decomposition pass for pixel_shuffle. The error would have been clearer
if the problem had been detected earlier in the invalid aten.permute op.

Testing: new tests added. To run added tests, from the base directory
run

```
 ./build/bin/llvm-lit  test/Dialect/Torch/invalid.mlir
 ```
  • Loading branch information
newling committed Nov 15, 2023
1 parent e81282a commit dad1f01
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 15 deletions.
23 changes: 12 additions & 11 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6422,51 +6422,52 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
}];
}

def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
let summary = "Generated op for `aten::pixel_shuffle : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dims
Torch_IntType:$upscale_factor
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
ParseResult AtenPixelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenPermuteOp::print(OpAsmPrinter &printer) {
void AtenPixelShuffleOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::pixel_shuffle : (Tensor, int) -> (Tensor)`";
let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$upscale_factor
AnyTorchListOfTorchIntType:$dims
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPixelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
ParseResult AtenPermuteOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenPixelShuffleOp::print(OpAsmPrinter &printer) {
void AtenPermuteOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}

def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [
Expand Down
90 changes: 90 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2859,6 +2859,96 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
return success();
}

LogicalResult AtenPermuteOp::verify() {

  // Verifier for aten.permute. Checks run only on what is statically known:
  // a visible prim.ListConstruct for the permutation, and tensor types that
  // carry size information. Anything dynamic is conservatively accepted.

  // If the individual elements of the permutation are not visible, there is
  // nothing to check.
  SmallVector<Value> permutation;
  if (!getListConstructElements(getDims(), permutation))
    return success();

  auto outType = getResult().getType().cast<BaseTensorType>();
  auto inType = getSelf().getType().cast<BaseTensorType>();

  // Both sides must have size information for any cross-checking.
  if (!outType.hasSizes() || !inType.hasSizes())
    return success();

  auto outShape = outType.getSizes();
  auto inShape = inType.getSizes();

  const uint64_t outRank = outShape.size();

  // A permutation never changes rank.
  if (outRank != inShape.size()) {
    return emitOpError(
               "expected input and output tensors to have same rank, but ")
           << inShape.size() << " != " << outRank << '.';
  }

  // The permutation must contain exactly one entry per dimension.
  if (outRank != permutation.size()) {
    return emitOpError() << "expected permutation to have size equal result "
                            "tensor rank. The permutation has "
                         << permutation.size()
                         << " elements, the output has rank " << outRank << '.';
  }

  // Initialization of the reverse permutation. -1 denotes an unknown
  // permutation index.
  SmallVector<int64_t> reversePermutation(outRank, -1);

  // In this loop:
  // (1) check that the permutation indices are in bounds, and not duplicated.
  // (2) populate reversePermutation (to check for duplicates).
  // (3) check that the input and output shapes agree with the permutation. For
  //     example, if the permutation is (1,2,0) and the input shape is (2,3,5),
  //     then the output shape must be (3,5,2).
  for (uint64_t to = 0; to < outRank; ++to) {
    int64_t from;

    // Non-constant permutation entries cannot be checked; skip them.
    if (!matchPattern(permutation[to], m_TorchConstantInt(&from)))
      continue;

    // If 'from' is the unknown index (-1), there is nothing to check.
    if (from == -1)
      continue;

    // Use emitOpError (not emitError) so this diagnostic carries the same
    // op prefix as the other messages produced by this verifier.
    if (!isValidDim(from, outRank)) {
      return emitOpError("observed invalid index in permutation (")
             << from << ") for input tensor of rank " << outRank << '.';
    }

    if (reversePermutation[from] != -1) {
      return emitOpError("has a duplicate dimension (")
             << from << ") in its permutation " << getDims() << '.';
    }
    reversePermutation[from] = to;

    // Sizes are only comparable when both are statically known.
    const bool dimSizesDefined =
        inShape[from] != kUnknownSize && outShape[to] != kUnknownSize;
    const bool dimSizesDifferent = inShape[from] != outShape[to];

    if (dimSizesDefined && dimSizesDifferent) {
      return emitOpError("has a permutation which is not compatible with the "
                         "input and output shapes. ")
             << "The input shape in dimension " << from << " is "
             << inShape[from] << ", and the output shape in dimension " << to
             << " is " << outShape[to]
             << " : they should be the same with this permutation. ";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// DtypeCalculateYieldDtypesOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _get_main_module_name() -> str:
def raw_emit_op(operator: JitOperator,
emitter_td: TextEmitter,
*, traits: List[str],
has_folder: bool, has_canonicalizer: bool):
has_folder: bool, has_canonicalizer: bool, has_verifier: bool):
"""Emit the ODS for a JitOperator to a textual file.
This is the lowest level of emission and is responsible for low-level
Expand Down Expand Up @@ -199,6 +199,8 @@ def generic_result_name(i):
p_td("let hasFolder = 1;")
if has_canonicalizer:
p_td("let hasCanonicalizer = 1;")
if has_verifier:
p_td("let hasVerifier = 1;")
p_td("}")
p_td("\n")

Expand All @@ -208,7 +210,8 @@ def emit_op(operator: JitOperator,
*,
traits: Optional[List[str]] = None,
has_folder: bool = False,
has_canonicalizer: bool = False):
has_canonicalizer: bool = False,
has_verifier: bool = False):
"""Main entry point for op emission.
Besides emitting the op, it deduces / adds traits based on the operator
Expand All @@ -228,7 +231,8 @@ def emit_op(operator: JitOperator,
emitter_td,
traits=traits,
has_folder=has_folder,
has_canonicalizer=has_canonicalizer)
has_canonicalizer=has_canonicalizer,
has_verifier=has_verifier)


def emit_ops(emitter_td: TextEmitter, registry: Registry):
Expand Down Expand Up @@ -481,8 +485,8 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
emit("aten::permute : (Tensor, int[]) -> (Tensor)")
emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
Expand Down
81 changes: 81 additions & 0 deletions test/Dialect/Torch/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,84 @@ func.func @torch.tensor_static_info_cast$dtype_mismatch(%arg0: !torch.vtensor<*,
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<*,f64>
return %0 : !torch.vtensor<*,f64>
}


// -----

// Negative test: the result rank (4) differs from the input rank (3);
// the aten.permute verifier must reject this.
func.func @torch.permute$test_changing_rank (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3,4],f32> {

%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2

%perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>

// expected-error@+1 {{expected input and output tensors to have same rank, but 3 != 4}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>

return %3 : !torch.vtensor<[1,2,3,4],f32>
}

// -----

// Negative test: the permutation list has only 2 entries for rank-3 tensors;
// the aten.permute verifier must reject this.
func.func @torch.permute$test_permutation_too_short (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {

%int0 = torch.constant.int 0
%int1 = torch.constant.int 1

%perm = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>

// expected-error@+1 {{The permutation has 2 elements, the output has rank 3}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>

return %3 : !torch.vtensor<[1,2,3],f32>
}

// -----

// Negative test: dimension 1 appears twice in the permutation (1,2,1);
// the aten.permute verifier must reject the duplicate.
func.func @torch.permute$duplicate_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[2,3,1],f32> {

%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%perm = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>

// expected-error@+1 {{'torch.aten.permute' op has a duplicate dimension (1) in its permutation}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[2,3,1],f32>

return %3 : !torch.vtensor<[2,3,1],f32>
}

// -----

// Negative test: with permutation (1,2,0) and input shape (1,2,3) the output
// shape must be (2,3,1), not (3,1,2); the verifier must flag the mismatch.
func.func @torch.permute$incorrect_output_shape (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[3,1,2],f32> {

%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%none = torch.constant.none

%perm = torch.prim.ListConstruct %int1, %int2, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>

// expected-error@+1 {{'torch.aten.permute' op has a permutation which is not compatible with the input and output shapes. The input shape in dimension 1 is 2, and the output shape in dimension 0 is 3 : they should be the same with this permutation.}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[3,1,2],f32>

return %3 : !torch.vtensor<[3,1,2],f32>
}


// -----

// Negative test: permutation index 7 is out of bounds for rank-3 tensors;
// the aten.permute verifier must reject it.
func.func @torch.permute$invalid_index_in_permutation (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {

%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int7 = torch.constant.int 7
%perm = torch.prim.ListConstruct %int0, %int1, %int7 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>


// expected-error@+1 {{observed invalid index in permutation (7) for input tensor of rank 3.}}
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>

return %3 : !torch.vtensor<[1,2,3],f32>
}

11 changes: 11 additions & 0 deletions test/Dialect/Torch/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,14 @@ func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,5
%arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list<vtensor<[1,?,56,96],f16>>
return %arg2 : !torch.list<vtensor<[1,?,56,96],f16>>
}

// Check that verification passes with '-1' as a permutation index: the
// verifier treats -1 as an unknown entry and skips all checks for it.
func.func @torch.permute$negative_index_valid (%arg0: !torch.vtensor<[1,2,3],f32>) -> !torch.vtensor<[1,2,3],f32> {
%intm1 = torch.constant.int -1
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%perm = torch.prim.ListConstruct %int0, %int1, %intm1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,2,3],f32>
return %3 : !torch.vtensor<[1,2,3],f32>
}

0 comments on commit dad1f01

Please sign in to comment.