diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 846470202c15..090aa4cecb91 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -76,6 +76,7 @@ void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline( if (options.decompose) { pm.addNestedPass( Torch::createDecomposeComplexOpsPass(options.backendLegalOps)); + pm.addNestedPass(Torch::createRecomposeComplexOpsPass()); pm.addNestedPass(createCanonicalizerPass()); } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 81071c6ab058..01c0a867b89f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -922,17 +922,6 @@ "UpSampleNearest2dVecNoneShape_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", - # Error: `aten.as_strided` op is not supported - "ChunkListUnpackDynamic_Module_basic", - "ChunkListUnpackUnevenDynamic_Module_basic", - "ChunkListUnpackUneven_Module_basic", - "ChunkListUnpack_Module_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorLastSmallerModule_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitWithSizesListUnpackModule_basic", - "SplitWithSizes_Module_basic", "Unfold_Module_basic", "Unfold_Module_Rank_4", "Unfold_Module_Rank_Zero_basic", @@ -4018,17 +4007,7 @@ "AtenAsStridedModule_basic", "AtenAsStridedNoStorageOffsetModule_basic", "AtenAsStridedUnknownSizeModule_basic", - "ChunkListUnpackDynamic_Module_basic", - "ChunkListUnpackUnevenDynamic_Module_basic", - "ChunkListUnpackUneven_Module_basic", - "ChunkListUnpack_Module_basic", "NativeGroupNormModule_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorLastSmallerModule_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitWithSizesListUnpackModule_basic", - "SplitWithSizes_Module_basic", # error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>' "ElementwiseClampMaxModule_bfloat16", "ElementwiseClampMinModule_bfloat16", diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py index 1c202ed3a382..ac522c517ff5 100644 --- a/projects/pt1/python/torch_mlir/dynamo.py +++ b/projects/pt1/python/torch_mlir/dynamo.py @@ -52,8 +52,6 @@ def _get_decomposition_table(): # support for aten.native_batch_norm_backward. aten._native_batch_norm_legit_functional, aten.native_group_norm, - aten.split.Tensor, - aten.split_with_sizes, aten.norm.ScalarOpt_dim, aten.embedding_dense_backward, aten.native_layer_norm_backward, diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py index 0b3da8ad2155..918bf7128b65 100644 --- a/python/torch_mlir/extras/fx_decomp_util.py +++ b/python/torch_mlir/extras/fx_decomp_util.py @@ -10,8 +10,6 @@ torch.ops.aten.norm.ScalarOpt_dim, torch.ops.aten.native_group_norm, torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes, torch.ops.aten.native_layer_norm, torch.ops.aten.masked_fill.Tensor, torch.ops.aten.masked_fill.Scalar,