From e2fbded49cdfa37185e8dbfbef0164e23d005c08 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 2 Jul 2024 09:08:57 +0800 Subject: [PATCH] =?UTF-8?q?[Torch=20Dialect]=20improve=20argmax/argmin's?= =?UTF-8?q?=20decomposition=20to=20support=20keep=E2=80=A6=20(#3514)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …dim=True when dim=None --- .../Transforms/AbstractInterpLibrary.cpp | 46 +++++++++----- .../Torch/Transforms/DecomposeComplexOps.cpp | 60 ++++++++++++++----- projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 32 +++++++--- .../test_suite/reduction.py | 23 +++++++ 5 files changed, 126 insertions(+), 36 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 8bf50fd21cc..0e244e51a96 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7313,11 +7313,38 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @__torch__.patched_argmax_shape_func(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" torch.prim.If.yield %arg2 : !torch.bool\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %5 = torch.aten.append.t %3, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" @@ -7372,19 +7399,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.list\n" 
" }\n" " func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.tuple, list> {\n" -" %none = torch.constant.none\n" -" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.tuple, list>) {\n" -" %2 = torch.prim.ListConstruct : () -> !torch.list\n" -" %3 = torch.prim.ListConstruct : () -> !torch.list\n" -" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" torch.prim.If.yield %4 : !torch.tuple, list>\n" -" } else {\n" -" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" -" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" torch.prim.If.yield %4 : !torch.tuple, list>\n" -" }\n" +" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list>\n" " return %1 : !torch.tuple, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list {\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 36e79736381..f966b320c13 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1920,15 +1920,19 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getSelf(); Value dim = op.getDim(); - Value keepDim = op.getKeepdim(); Value result = op.getResult(); + bool keepDim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure( + op, "expected keepdim to be a constant bool"); + } BaseTensorType inputType = cast(input.getType()); BaseTensorType indicesTensorType = cast(result.getType()); std::optional maybeInputRank = getTensorRank(input); - if (!maybeInputRank) { + if (!maybeInputRank || *maybeInputRank == 0) { return rewriter.notifyMatchFailure( - op, "expected input tensor to have a rank"); + op, "expected input tensor to have a rank > 0"); } unsigned inputRank = *maybeInputRank; if (!indicesTensorType.hasSizes()) @@ -1945,21 +1949,49 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { BaseTensorType flattenType = cast(inputType.getWithSizesAndDtype( {kUnknownSize}, inputType.getOptionalDtype())); - dim = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value end = rewriter.create( loc, rewriter.getI64IntegerAttr(inputRank - 1)); + Value falseValue = rewriter.create(loc, false); input = rewriter.create(loc, flattenType, input, - dim, end); + zero, end); + Value resultIndices = + rewriter + .create( + loc, + valueTensorType.getWithSizesAndDtype( + ArrayRef{}, valueTensorType.getOptionalDtype()), + indicesTensorType.getWithSizesAndDtype( + ArrayRef{}, + indicesTensorType.getOptionalDtype()), + input, /*dim=*/zero, /*keepdim=*/falseValue) + .getIndices(); + if (keepDim) { + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value dimList = rewriter.create( + loc, + 
Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), + SmallVector(inputRank, one)); + resultIndices = rewriter.create( + loc, + indicesTensorType.getWithSizesAndDtype( + SmallVector(inputRank, 1), + indicesTensorType.getOptionalDtype()), + resultIndices, dimList); + } + rewriter.replaceOp(op, resultIndices); + return success(); + } else { + Value resultIndices = + rewriter + .create(loc, valueTensorType, indicesTensorType, + input, dim, op.getKeepdim()) + .getIndices(); + rewriter.replaceOp(op, resultIndices); + return success(); } - - Value resultArg = - rewriter - .create(loc, valueTensorType, indicesTensorType, input, - dim, keepDim) - .getIndices(); - - rewriter.replaceOp(op, resultArg); - return success(); } }; } // namespace diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index adfb68b94be..7bbd82a0d7c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1505,6 +1505,7 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "ArgmaxKeepdimModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 6e4957e5889..1dbadd6897b 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -680,8 +680,19 @@ def aten〇trace〡shape(self: List[int]) -> List[int]: assert len(self) == 2, "input must have rank 2" return [] +# TODO: replace this patched function with `upstream_shape_functions.argmax` when upstream fix it +# see https://github.com/pytorch/pytorch/pull/129838 +def patched_argmax_shape_func(self: List[int], dim: Optional[int] = None, keepdim: bool = False): + if dim is None and keepdim: + out: List[int] = [] + for i in self: + out.append(1) + return out + return upstream_shape_functions.argmax(self, dim, keepdim) + @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`. Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`. Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`. @@ -690,11 +701,11 @@ def aten〇trace〡shape(self: List[int]) -> List[int]: ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds. ]) def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: - return upstream_shape_functions.argmax(self, dim, keepdim) + return patched_argmax_shape_func(self, dim, keepdim) def aten〇argmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: # There is no shape function for argmin in pytorch, but the one for argmax does exactly what is needed here. - return upstream_shape_functions.argmax(self, dim, keepdim) + return patched_argmax_shape_func(self, dim, keepdim) # TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor, # making it impossible to add support for it using the current design of the shape library. 
@@ -722,12 +733,19 @@ def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = Fa def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) +@check_shape_function([ + Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`. + Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. + Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`. + Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`. + Invocation(TensorOfShape(2, 3, 4), dim=2), # Maximum valid `dim`. + ErrorInvocation(TensorOfShape(2, 3, 4), dim=-4), # `dim` out of bounds. + ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds. +]) def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]: - if dim is None: - return [], [] - else: - reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) - return reduced_shape, reduced_shape + reduced_shape = patched_argmax_shape_func(self, dim, keepdim) + return reduced_shape, reduced_shape def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 7cf6dd69445..9a683e3c621 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -1533,6 +1533,29 @@ def ArgmaxModule_basic(module, tu: TestUtils): # ============================================================================== +class ArgmaxKeepdimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.argmax(a, keepdim=True) + + +@register_test_case(module_factory=lambda: ArgmaxKeepdimModule()) +def ArgmaxKeepdimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + class ArgmaxIntModule(torch.nn.Module): def __init__(self): super().__init__()
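
Reviewer note (illustration only, not part of the patch): the new `ArgmaxKeepdimModule_basic` test relies on eager PyTorch producing an all-ones-shaped result for `aten.argmax` when `dim=None` and `keepdim=True`, which is exactly the case `patched_argmax_shape_func` adds on top of the upstream shape function. A minimal eager-mode check of that behavior (printed shapes assume current PyTorch semantics):

    import torch

    x = torch.rand(3, 4)
    # dim=None, keepdim=False: a 0-d index into the flattened input.
    print(torch.ops.aten.argmax(x).shape)                        # torch.Size([])
    # dim=None, keepdim=True: every dimension is kept with size 1 (the case this patch adds).
    print(torch.ops.aten.argmax(x, keepdim=True).shape)          # torch.Size([1, 1])
    # Explicit dim with keepdim=True: only the reduced dimension becomes 1.
    print(torch.ops.aten.argmax(x, dim=0, keepdim=True).shape)   # torch.Size([1, 4])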
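
For the `dim=None` path, the rewritten `DecomposeAtenArgMinMaxOp` flattens the input, reduces with `max.dim`/`min.dim` over dim 0 with `keepdim=false`, and, when `keepdim` is true, views the resulting 0-d index tensor to an all-ones shape of the original rank. A short eager-mode sketch of what that lowering computes for argmax (the helper name below is illustrative, not code from the patch):

    import torch

    def argmax_dim_none(x: torch.Tensor, keepdim: bool) -> torch.Tensor:
        # aten.flatten.using_ints over dims [0, rank-1] -> 1-D tensor.
        flat = torch.flatten(x, 0, x.dim() - 1)
        # aten.max.dim with dim=0, keepdim=False; argmin uses aten.min.dim instead.
        _, indices = torch.max(flat, dim=0, keepdim=False)
        if keepdim:
            # aten.view to a shape of all ones with the input's rank.
            indices = indices.view([1] * x.dim())
        return indices

    x = torch.rand(3, 4)
    assert torch.equal(argmax_dim_none(x, True), torch.ops.aten.argmax(x, keepdim=True))
    assert torch.equal(argmax_dim_none(x, False), torch.ops.aten.argmax(x))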