diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 5f0eae8146bb..966817de923e 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -603,9 +603,13 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, // fillK: range of each index, total number of fillInput(could be scatter) // after flattened k = 1*1*3 = 3 - for (int i = 0; i < ND; i++) { - fillK *= fillValuesType.getShape()[i]; + int64_t fillNumElements = 1; + for (int64_t dim : fillValuesType.getShape()) { + fillNumElements *= dim; } + if (fillNumElements % C != 0) + return std::nullopt; + fillK = fillNumElements / C; SmallVector tosaFillValuesShape({N, fillK, C}); // {1,3,1} // Reshape/Flatten fillValues to 3d tensor diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c59e83ebbd6b..ad831793b6de 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3825,7 +3825,6 @@ "IndexPutImpl1DFloatAccumulateModule_basic", "IndexPutImpl1DIntAccumulateModule_basic", "IndexPutImpl2DFloatAccumulateModule_basic", - "IndexPutImpl2DImplicitModule_basic", "IndexPutImpl2DIndexModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index bd1aae7913fb..99a446d3065d 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2746,6 +2746,26 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t // ----- +// CHECK-LABEL: func.func @torch.aten.index_put_hacked_twin_flattened_updates( +// CHECK: %[[SCATTER:.*]] = tosa.scatter +// CHECK-SAME: (tensor<1x6x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x6x1xf32> +// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[SCATTER]] +// CHECK-SAME: (tensor<1x6x1xf32>, !tosa.shape<3>) -> tensor<1x2x3xf32> +// CHECK: torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<1x2x3xf32> -> !torch.vtensor<[1,2,3],f32> +func.func @torch.aten.index_put_hacked_twin_flattened_updates( + %arg0: !torch.vtensor<[1,2,3],f32>, + %arg1: !torch.vtensor<[6],si64>, + %arg2: !torch.vtensor<[6],si64>, + %arg3: !torch.vtensor<[6],si64>, + %arg4: !torch.vtensor<[6],f32>) -> !torch.vtensor<[1,2,3],f32> { + %indices = torch.prim.ListConstruct %arg1, %arg2, %arg3 : (!torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>) -> !torch.list + %false = torch.constant.bool false + %0 = torch.aten.index_put.hacked_twin %arg0, %indices, %arg4, %false : !torch.vtensor<[1,2,3],f32>, !torch.list, !torch.vtensor<[6],f32>, !torch.bool -> !torch.vtensor<[1,2,3],f32> + return %0 : !torch.vtensor<[1,2,3],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4,2],si64>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {