Commit 4c064e3

Arm Backend: Add support for split_copy.default (#14717)
Add support for the split_copy operator, which was already decomposed into multiple slice operators but was only partly supported in the Arm backend.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

Signed-off-by: Agrima Khare <[email protected]>
Co-authored-by: Sebastian Larsson <[email protected]>
1 parent 546c680 commit 4c064e3
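
For reference, torch.split_copy with an integer split size returns equal-sized chunks plus a smaller trailing chunk when the dimension is not evenly divisible; this is the aten-level behavior the Arm backend now lowers. A minimal, runnable illustration (the tensor and sizes below are arbitrary examples, not taken from this commit):

    import torch

    x = torch.arange(10)
    # An int split size of 4 on a dimension of size 10 yields chunks of 4, 4 and 2.
    chunks = torch.split_copy(x, split_size=4, dim=0)
    print([c.numel() for c in chunks])  # [4, 4, 2]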

4 files changed (+105 −17 lines)


backends/arm/_passes/convert_split_to_slice.py

Lines changed: 17 additions & 6 deletions
@@ -46,13 +46,24 @@ def call(self, graph_module: torch.fx.GraphModule):
             dim = (dim + rank) % rank
 
             # Validate that split lengths cover the entire dimension
-            length_sum = sum(split_lengths)
+
             dim_size = shape[dim]
-            if length_sum != dim_size:
-                raise ValueError(
-                    f"Split sizes {split_lengths} sum to {length_sum}, "
-                    f"but dimension {dim} has size {dim_size}"
-                )
+            if isinstance(split_lengths, int):
+                if split_lengths <= 0:
+                    raise ValueError(
+                        f"Split size must be positive, got {split_lengths}"
+                    )
+                full_chunks, remainder = divmod(dim_size, split_lengths)
+                split_lengths = [split_lengths] * full_chunks
+                if remainder:
+                    split_lengths.append(remainder)
+            else:
+                length_sum = sum(split_lengths)
+                if length_sum != dim_size:
+                    raise ValueError(
+                        f"Split sizes {split_lengths} sum to {length_sum}, "
+                        f"but dimension {dim} has size {dim_size}"
+                    )
 
             # Convert split argument 'split_lengths' to slice arguments start and end.
             starts = [0] * len(split_lengths)
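
The int-size branch above normalizes torch.split_copy's single split size into a list of per-chunk lengths, which the rest of the pass then converts into slice start/end arguments. A standalone sketch of that combined logic (hypothetical helper, not code from the pass itself):

    def split_to_slices(dim_size: int, split_size: int):
        # Normalize an int split size into per-chunk lengths (last chunk may be smaller).
        full_chunks, remainder = divmod(dim_size, split_size)
        lengths = [split_size] * full_chunks
        if remainder:
            lengths.append(remainder)
        # Convert the lengths into (start, end) pairs, one per slice.
        slices, offset = [], 0
        for length in lengths:
            slices.append((offset, offset + length))
            offset += length
        return slices

    print(split_to_slices(10, 4))  # [(0, 4), (4, 8), (8, 10)]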

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,7 @@
     exir_ops.edge.aten.log.default,
     exir_ops.edge.aten.linear.default,
     exir_ops.edge.aten.split_with_sizes_copy.default,
+    exir_ops.edge.aten.split_copy.Tensor,
     exir_ops.edge.aten.floor.default,
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.full_like.default,
@@ -152,6 +153,7 @@
     exir_ops.edge.aten.log.default,
     exir_ops.edge.aten.linear.default,
     exir_ops.edge.aten.split_with_sizes_copy.default,
+    exir_ops.edge.aten.split_copy.Tensor,
     exir_ops.edge.aten.floor.default,
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.full_like.default,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
@@ -330,6 +330,7 @@ def _match_pattern(
         torch.ops.aten.slice_copy.Tensor,
         torch.ops.aten.split.Tensor,
         torch.ops.aten.split_with_sizes.default,
+        torch.ops.aten.split_copy.Tensor,
         torch.ops.aten.transpose.Dimname,
         torch.ops.aten.transpose.int,
         torch.ops.aten.transpose_copy.int,

backends/arm/test/ops/test_split.py

Lines changed: 85 additions & 11 deletions
@@ -22,7 +22,6 @@
 
 
 class Split(torch.nn.Module):
-
     test_data = {
         "split_1d_2_size_0_dim": lambda: (torch.rand(10), 2, 0),
         "split_2d_3_size_1_dim": lambda: (torch.rand(10, 10), 3, 1),
@@ -60,12 +59,24 @@ def forward(
         return x.split(split_size=split_size_or_sections, dim=dim)[1:3]
 
 
+class SplitCopy(torch.nn.Module):
+    aten_op = "torch.ops.aten.split_copy.Tensor"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_split_copy_Tensor"
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        split_size: int,
+        dim: int,
+    ):
+        return torch.split_copy(x, split_size=split_size, dim=dim)
+
+
 @common.parametrize(
     "test_data",
     (Split.test_data | Split.test_data_list),
 )
 def test_split_with_sizes_tosa_FP(test_data: input_t1):
-
     pipeline = TosaPipelineFP[input_t1](
         Split(),
         test_data(),
@@ -77,7 +88,6 @@ def test_split_with_sizes_tosa_FP(test_data: input_t1):
 
 @common.parametrize("test_data", Split.test_data_list)
 def test_split_with_sizes_tosa_FP_2(test_data: input_t1):
-
     pipeline = TosaPipelineFP[input_t1](
         SplitWithSizes(),
         test_data(),
@@ -92,7 +102,6 @@ def test_split_with_sizes_tosa_FP_2(test_data: input_t1):
     (Split.test_data | Split.test_data_list),
 )
 def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1):
-
     pipeline = TosaPipelineFP[input_t1](
         SplitSingleOut(),
         test_data(),
@@ -107,7 +116,6 @@ def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1):
     (Split.test_data | Split.test_data_list),
 )
 def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1):
-
     pipeline = TosaPipelineFP[input_t1](
         SplitTwoOut(),
         test_data(),
@@ -122,7 +130,6 @@ def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1):
     (Split.test_data | Split.test_data_list),
 )
 def test_split_with_sizes_tosa_INT(test_data: input_t1):
-
     pipeline = TosaPipelineINT[input_t1](
         Split(),
         test_data(),
@@ -161,7 +168,6 @@ def test_split_with_sizes_u55_INT(test_data: input_t1):
 )
 @common.XfailIfNoCorstone320
 def test_split_with_sizes_u85_INT(test_data: input_t1):
-
     pipeline = EthosU85PipelineINT[input_t1](
         Split(),
         test_data(),
@@ -190,7 +196,6 @@ def test_split_with_sizes_vgf_FP(test_data: input_t1):
 @common.parametrize("test_data", Split.test_data_list)
 @common.SkipIfNoModelConverter
 def test_split_with_sizes_vgf_FP_2(test_data: input_t1):
-
     pipeline = VgfPipeline[input_t1](
         SplitWithSizes(),
         test_data(),
@@ -207,7 +212,6 @@ def test_split_with_sizes_vgf_FP_2(test_data: input_t1):
 )
 @common.SkipIfNoModelConverter
 def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1):
-
     pipeline = VgfPipeline[input_t1](
         SplitSingleOut(),
         test_data(),
@@ -224,7 +228,6 @@ def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1):
 )
 @common.SkipIfNoModelConverter
 def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1):
-
     pipeline = VgfPipeline[input_t1](
         SplitTwoOut(),
         test_data(),
@@ -241,7 +244,6 @@ def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1):
 )
 @common.SkipIfNoModelConverter
 def test_split_with_sizes_vgf_INT(test_data: input_t1):
-
     pipeline = VgfPipeline[input_t1](
         Split(),
         test_data(),
@@ -250,3 +252,75 @@ def test_split_with_sizes_vgf_INT(test_data: input_t1):
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+@common.parametrize("test_data", Split.test_data)
+def test_split_tensor_tosa_FP(test_data: Tuple):
+    pipeline = TosaPipelineFP[input_t1](
+        SplitCopy(),
+        test_data(),
+        aten_op=SplitCopy.aten_op,
+        exir_op=SplitCopy.exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Split.test_data)
+def test_split_tensor_tosa_INT(test_data: Tuple):
+    pipeline = TosaPipelineINT[input_t1](
+        SplitCopy(),
+        test_data(),
+        aten_op=SplitCopy.aten_op,
+        exir_op=SplitCopy.exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("test_data", Split.test_data)
+def test_split_tensor_u55_INT(test_data: Tuple):
+    pipeline = EthosU55PipelineINT[input_t1](
+        SplitCopy(),
+        test_data(),
+        aten_ops=SplitCopy.aten_op,
+        exir_ops=SplitCopy.exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_data", Split.test_data)
+def test_split_tensor_u85_INT(test_data: Tuple):
+    pipeline = EthosU85PipelineINT[input_t1](
+        SplitCopy(),
+        test_data(),
+        aten_ops=SplitCopy.aten_op,
+        exir_ops=SplitCopy.exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Split.test_data)
+@common.SkipIfNoModelConverter
+def test_split_tensor_vgf_FP(test_data: Tuple):
+    pipeline = VgfPipeline[input_t1](
+        SplitCopy(),
+        test_data(),
+        aten_op=SplitCopy.aten_op,
+        exir_op=SplitCopy.exir_op,
+        tosa_version="TOSA-1.0+FP",
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Split.test_data)
+@common.SkipIfNoModelConverter
+def test_split_tensor_vgf_INT(test_data: Tuple):
+    pipeline = VgfPipeline[input_t1](
+        SplitCopy(),
+        test_data(),
+        aten_op=SplitCopy.aten_op,
+        exir_op=SplitCopy.exir_op,
+        tosa_version="TOSA-1.0+INT",
+    )
+    pipeline.run()
