[Torch Dialect] improve argmax/argmin's decomposition to support keepdim=True when dim=None (#3514)
qingyunqu committed Jul 2, 2024
1 parent 2f231f3 commit e2fbded
Showing 5 changed files with 126 additions and 36 deletions.
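For reference, the eager PyTorch semantics this commit teaches the decomposition and shape library: with `dim=None` and `keepdim=True`, `aten.argmax`/`aten.argmin` reduce over the flattened input but keep every dimension as size 1. A minimal illustration (the tensor values are arbitrary):

```python
import torch

x = torch.tensor([[1.0, 5.0, 2.0],
                  [4.0, 0.0, 3.0]])

# dim=None, keepdim=False: a 0-d tensor holding the index into the flattened input.
print(torch.argmax(x))                # tensor(1)
# dim=None, keepdim=True: same flattened index, but rank is preserved as an all-ones shape.
print(torch.argmax(x, keepdim=True))  # tensor([[1]]), shape [1, 1]
# Explicit dim with keepdim=True keeps the reduced dimension as size 1.
print(torch.argmin(x, dim=1, keepdim=True).shape)  # torch.Size([2, 1])
```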
46 changes: 31 additions & 15 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7313,11 +7313,38 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.patched_argmax_shape_func(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %arg2 : !torch.bool\n"
" } else {\n"
" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %4 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %4, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %5 = torch.aten.append.t %3, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" } else {\n"
" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
@@ -7372,19 +7399,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.tuple<list<int>, list<int>>) {\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
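AbstractInterpLibrary.cpp embeds the serialized shape library that is regenerated from the Python shape functions (see the abstract_interp_lib_gen.py hunk below). A minimal pure-Python sketch of the rule the new `patched_argmax_shape_func` IR encodes, assuming shapes are plain `List[int]` like the `!torch.list<int>` values above (the standalone function below is illustrative, not part of the commit):

```python
from typing import List, Optional

def patched_argmax_shape(sizes: List[int], dim: Optional[int] = None,
                         keepdim: bool = False) -> List[int]:
    # dim=None: full reduction. keepdim=True keeps every dimension as size 1;
    # otherwise the result is a 0-d tensor (empty shape).
    if dim is None:
        return [1] * len(sizes) if keepdim else []
    # Explicit dim: ordinary single-dimension reduction rule (no bounds checking
    # here; the real shape library raises on an out-of-range `dim`).
    d = dim % len(sizes)
    return [1 if i == d else s for i, s in enumerate(sizes) if keepdim or i != d]

print(patched_argmax_shape([2, 3, 4], dim=None, keepdim=True))   # [1, 1, 1]
print(patched_argmax_shape([2, 3, 4], dim=None, keepdim=False))  # []
print(patched_argmax_shape([2, 3, 4], dim=0, keepdim=False))     # [3, 4]
```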
60 changes: 46 additions & 14 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1920,15 +1920,19 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
Location loc = op.getLoc();
Value input = op.getSelf();
Value dim = op.getDim();
Value keepDim = op.getKeepdim();
Value result = op.getResult();

bool keepDim;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(
op, "expected keepdim to be a constant bool");
}
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
BaseTensorType indicesTensorType = cast<BaseTensorType>(result.getType());
std::optional<unsigned> maybeInputRank = getTensorRank(input);
if (!maybeInputRank) {
if (!maybeInputRank || *maybeInputRank == 0) {
return rewriter.notifyMatchFailure(
op, "expected input tensor to have a rank");
op, "expected input tensor to have a rank > 0");
}
unsigned inputRank = *maybeInputRank;
if (!indicesTensorType.hasSizes())
@@ -1945,21 +1949,49 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
BaseTensorType flattenType =
cast<BaseTensorType>(inputType.getWithSizesAndDtype(
{kUnknownSize}, inputType.getOptionalDtype()));
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank - 1));
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
dim, end);
zero, end);
Value resultIndices =
rewriter
.create<DecompOpTy>(
loc,
valueTensorType.getWithSizesAndDtype(
ArrayRef<int64_t>{}, valueTensorType.getOptionalDtype()),
indicesTensorType.getWithSizesAndDtype(
ArrayRef<int64_t>{},
indicesTensorType.getOptionalDtype()),
input, /*dim=*/zero, /*keepdim=*/falseValue)
.getIndices();
if (keepDim) {
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value dimList = rewriter.create<PrimListConstructOp>(
loc,
Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
SmallVector<Value>(inputRank, one));
resultIndices = rewriter.create<AtenReshapeOp>(
loc,
indicesTensorType.getWithSizesAndDtype(
SmallVector<int64_t>(inputRank, 1),
indicesTensorType.getOptionalDtype()),
resultIndices, dimList);
}
rewriter.replaceOp(op, resultIndices);
return success();
} else {
Value resultIndices =
rewriter
.create<DecompOpTy>(loc, valueTensorType, indicesTensorType,
input, dim, op.getKeepdim())
.getIndices();
rewriter.replaceOp(op, resultIndices);
return success();
}

Value resultArg =
rewriter
.create<DecompOpTy>(loc, valueTensorType, indicesTensorType, input,
dim, keepDim)
.getIndices();

rewriter.replaceOp(op, resultArg);
return success();
}
};
} // namespace
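At the tensor level, the rewrite above corresponds roughly to this PyTorch sketch for the `dim=None` case: flatten the input, take the index along dim 0 of the 1-D tensor, and, when `keepdim` is true, reshape the 0-d result to an all-ones shape of the original rank. The helper name and the eager `torch.*` calls are illustrative only; the pattern emits Torch-dialect ops (`AtenFlattenUsingIntsOp`, `AtenReshapeOp`, and the max/min reduction op), not eager calls:

```python
import torch

def decompose_argmax_full_reduce(x: torch.Tensor, keepdim: bool) -> torch.Tensor:
    # Flatten dims [0, rank-1] (aten.flatten.using_ints), then reduce along dim 0
    # with keepdim=False, mirroring the ops created by the rewrite pattern.
    flat = torch.flatten(x, 0, x.dim() - 1)
    indices = torch.argmax(flat, dim=0, keepdim=False)  # 0-d index tensor
    if keepdim:
        # aten.reshape to [1] * rank so every original dimension survives as size 1.
        indices = indices.reshape([1] * x.dim())
    return indices

x = torch.rand(3, 4)
assert torch.equal(decompose_argmax_full_reduce(x, True),
                   torch.argmax(x, keepdim=True))
```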
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1505,6 +1505,7 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"ArgmaxKeepdimModule_basic",
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
32 changes: 25 additions & 7 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -680,8 +680,19 @@ def aten〇trace〡shape(self: List[int]) -> List[int]:
assert len(self) == 2, "input must have rank 2"
return []

# TODO: replace this patched function with `upstream_shape_functions.argmax` when upstream fixes it
# see https://github.com/pytorch/pytorch/pull/129838
def patched_argmax_shape_func(self: List[int], dim: Optional[int] = None, keepdim: bool = False):
if dim is None and keepdim:
out: List[int] = []
for i in self:
out.append(1)
return out
return upstream_shape_functions.argmax(self, dim, keepdim)

@check_shape_function([
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`.
Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`.
Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`.
Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`.
@@ -690,11 +701,11 @@ def aten〇trace〡shape(self: List[int]) -> List[int]:
ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds.
])
def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
return upstream_shape_functions.argmax(self, dim, keepdim)
return patched_argmax_shape_func(self, dim, keepdim)

def aten〇argmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
# There is no shape function for argmin in pytorch, but the one for argmax does exactly what is needed here.
return upstream_shape_functions.argmax(self, dim, keepdim)
return patched_argmax_shape_func(self, dim, keepdim)

# TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor,
# making it impossible to add support for it using the current design of the shape library.
@@ -722,12 +733,19 @@ def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = Fa
def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)

@check_shape_function([
Invocation(TensorOfShape(2, 3, 4)), # Basic case.
Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`.
Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`.
Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`.
Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`.
Invocation(TensorOfShape(2, 3, 4), dim=2), # Maximum valid `dim`.
ErrorInvocation(TensorOfShape(2, 3, 4), dim=-4), # `dim` out of bounds.
ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds.
])
def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]:
if dim is None:
return [], []
else:
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
return reduced_shape, reduced_shape
reduced_shape = patched_argmax_shape_func(self, dim, keepdim)
return reduced_shape, reduced_shape

def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -1533,6 +1533,29 @@ def ArgmaxModule_basic(module, tu: TestUtils):
# ==============================================================================


class ArgmaxKeepdimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.argmax(a, keepdim=True)


@register_test_case(module_factory=lambda: ArgmaxKeepdimModule())
def ArgmaxKeepdimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))


# ==============================================================================


class ArgmaxIntModule(torch.nn.Module):
def __init__(self):
super().__init__()