Skip to content

Commit e649e06

Browse files
zezhangZe Zhang
andauthored
Add aten.unflatten.int support and its torch-to-tosa lowering (#2509)
Add aten.unflatten.int op Add its torch-to-tosa lowering Update the TorchToTosa/basic.mlir tests To test e2e tosa lowering: `python -m e2e_testing.main -v -c=tosa` --------- Co-authored-by: Ze Zhang <[email protected]>
1 parent 9b5a4af commit e649e06

File tree

9 files changed

+152
-4
lines changed

9 files changed

+152
-4
lines changed

e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
1818
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
1919
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
20+
"UnflattenStaticModule_basic",
2021
}
2122

2223
TORCHDYNAMO_XFAIL_SET = {
@@ -1056,6 +1057,7 @@
10561057
"BatchNorm3DModule_basic",
10571058
"BatchNorm1DStaticShapeModule_basic",
10581059
"FlattenStaticModule_basic",
1060+
"UnflattenStaticModule_basic",
10591061
"FlattenRank0Module_basic",
10601062
"ElementwiseFlattenBroadcastModule_basic",
10611063
"SquareModule_basic",

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7537,6 +7537,30 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
75377537
}];
75387538
}
75397539

7540+
def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
7541+
AllowsTypeRefinement,
7542+
ReadOnly
7543+
]> {
7544+
let summary = "Generated op for `aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)`";
7545+
let arguments = (ins
7546+
AnyTorchTensorType:$self,
7547+
Torch_IntType:$dim,
7548+
AnyTorchListOfTorchIntType:$sizes
7549+
);
7550+
let results = (outs
7551+
AnyTorchTensorType:$result
7552+
);
7553+
let hasCustomAssemblyFormat = 1;
7554+
let extraClassDefinition = [{
7555+
ParseResult AtenUnflattenIntOp::parse(OpAsmParser &parser, OperationState &result) {
7556+
return parseDefaultTorchOp(parser, result, 3, 1);
7557+
}
7558+
void AtenUnflattenIntOp::print(OpAsmPrinter &printer) {
7559+
printDefaultTorchOp(printer, *this, 3, 1);
7560+
}
7561+
}];
7562+
}
7563+
75407564
def Torch_AtenDimOp : Torch_Op<"aten.dim", [
75417565
AllowsTypeRefinement,
75427566
HasValueSemantics,

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2525,6 +2525,60 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
25252525
return success();
25262526
}
25272527

2528+
template <>
2529+
LogicalResult ConvertAtenOp<AtenUnflattenIntOp>::matchAndRewrite(
2530+
AtenUnflattenIntOp op, OpAdaptor adaptor,
2531+
ConversionPatternRewriter &rewriter) const {
2532+
2533+
// Not a ranked tensor type
2534+
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
2535+
if (!selfType || !selfType.hasStaticShape())
2536+
return rewriter.notifyMatchFailure(
2537+
op,
2538+
"Only ranked tensor types with static shapes are currently supported");
2539+
2540+
int64_t selfRank = selfType.getRank();
2541+
int64_t dim;
2542+
2543+
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
2544+
return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
2545+
2546+
SmallVector<int64_t> sizes;
2547+
if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizes)))
2548+
return rewriter.notifyMatchFailure(
2549+
op, "Only constant sizes are currently supported");
2550+
2551+
if (selfRank > 0 && !isValidDim(dim, selfRank))
2552+
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
2553+
2554+
SmallVector<int64_t> newShape;
2555+
for (auto s :
2556+
llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
2557+
int64_t idx = s.index();
2558+
if (idx < dim || idx > dim) {
2559+
newShape.push_back(s.value());
2560+
} else {
2561+
auto sum = 1;
2562+
for (auto newDims : sizes) {
2563+
newShape.push_back(newDims);
2564+
sum *= newDims;
2565+
}
2566+
if (sum != s.value())
2567+
return rewriter.notifyMatchFailure(op,
2568+
"sizes mismatch with original dim");
2569+
}
2570+
}
2571+
2572+
auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape),
2573+
selfType.getElementType());
2574+
2575+
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
2576+
op, getTypeConverter()->convertType(newType), adaptor.getSelf(),
2577+
rewriter.getDenseI64ArrayAttr(newShape));
2578+
2579+
return success();
2580+
}
2581+
25282582
template <>
25292583
LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
25302584
AtenPermuteOp op, OpAdaptor adaptor,
@@ -5050,6 +5104,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
50505104
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
50515105
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
50525106
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
5107+
INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
50535108
INSERT_ATENOP_PATTERN(AtenPermuteOp);
50545109
INSERT_ATENOP_PATTERN(AtenLog2Op);
50555110
INSERT_ATENOP_PATTERN(AtenThresholdOp);

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7205,6 +7205,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
72057205
" %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
72067206
" return %0 : !torch.list<int>\n"
72077207
" }\n"
7208+
" func.func @\"__torch_mlir_shape_fn.aten.unflatten.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
7209+
" %none = torch.constant.none\n"
7210+
" %int1 = torch.constant.int 1\n"
7211+
" %0 = torch.aten.slice.t %arg0, %none, %arg1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
7212+
" %1 = torch.aten.add.t %0, %arg2 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
7213+
" %2 = torch.aten.add.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n"
7214+
" %3 = torch.aten.slice.t %arg0, %2, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
7215+
" %4 = torch.aten.add.t %1, %3 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
7216+
" return %4 : !torch.list<int>\n"
7217+
" }\n"
72087218
" func.func @\"__torch_mlir_shape_fn.aten.linear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>) -> !torch.list<int> {\n"
72097219
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
72107220
" return %0 : !torch.list<int>\n"
@@ -8580,6 +8590,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
85808590
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
85818591
" return %0#1 : !torch.int\n"
85828592
" }\n"
8593+
" func.func @\"__torch_mlir_dtype_fn.aten.unflatten.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
8594+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
8595+
" return %0#1 : !torch.int\n"
8596+
" }\n"
85838597
" func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
85848598
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
85858599
" return %0#1 : !torch.int\n"

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,10 @@ bool Torch::isViewLikeOp(Operation *op) {
199199
// that it does not return a view and treat those as having value
200200
// semantics.
201201
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
202-
AtenExpandOp, AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
203-
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
204-
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
205-
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
202+
AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
203+
AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
204+
AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
205+
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
206206
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
207207
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
208208
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,9 @@ def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int])
623623
def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]:
624624
return upstream_shape_functions.flatten(self, start_dim, end_dim)
625625

626+
def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) -> List[int]:
627+
return self[:dim] + sizes + self[dim + 1:]
628+
626629
def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
627630
return upstream_shape_functions.linear(input, weight, bias)
628631

@@ -1656,6 +1659,11 @@ def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_
16561659
self_rank, self_dtype = self_rank_dtype
16571660
return self_dtype
16581661

1662+
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, sizes=[1]))
1663+
def aten〇unflatten〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, sizes: List[int]) -> int:
1664+
self_rank, self_dtype = self_rank_dtype
1665+
return self_dtype
1666+
16591667
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0]))
16601668
def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
16611669
self_rank, self_dtype = self_rank_dtype

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ def emit_with_mutating_variants(key, **kwargs):
516516
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
517517
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
518518
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
519+
emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)")
519520
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
520521
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
521522
emit("aten::Bool.Tensor : (Tensor) -> (bool)")

python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,28 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils):
304304
# ==============================================================================
305305

306306

307+
class UnflattenStaticModule(torch.nn.Module):
308+
309+
def __init__(self):
310+
super().__init__()
311+
312+
@export
313+
@annotate_args([
314+
None,
315+
([1, 6, 4], torch.float32, True),
316+
])
317+
def forward(self, x):
318+
return torch.ops.aten.unflatten(x, 1, (2, 3))
319+
320+
321+
@register_test_case(module_factory=lambda: UnflattenStaticModule())
322+
def UnflattenStaticModule_basic(module, tu: TestUtils):
323+
module.forward(tu.rand(1, 6, 4))
324+
325+
326+
# ==============================================================================
327+
328+
307329
class FlattenStaticModule(torch.nn.Module):
308330

309331
def __init__(self):

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,28 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor
556556

557557
// -----
558558

559+
// CHECK-LABEL: func.func @forward(
560+
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,6,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
561+
// CHECK: %[[VAL:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,6,4],f32> -> tensor<1x6x4xf32>
562+
// CHECK: %[[VAL_1:.*]] = torch.constant.int 1
563+
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
564+
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
565+
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
566+
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL]] {new_shape = array<i64: 1, 2, 3, 4>} : (tensor<1x6x4xf32>) -> tensor<1x2x3x4xf32>
567+
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
568+
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,3,4],f32>
569+
// CHECK: }
570+
func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3,4],f32> {
571+
%int1 = torch.constant.int 1
572+
%int2 = torch.constant.int 2
573+
%int3 = torch.constant.int 3
574+
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
575+
%1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[1,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
576+
return %1 : !torch.vtensor<[1,2,3,4],f32>
577+
}
578+
579+
// -----
580+
559581
// CHECK-LABEL: func.func @forward(
560582
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
561583
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,

0 commit comments

Comments
 (0)