2222
2323
2424class Split (torch .nn .Module ):
25-
2625 test_data = {
2726 "split_1d_2_size_0_dim" : lambda : (torch .rand (10 ), 2 , 0 ),
2827 "split_2d_3_size_1_dim" : lambda : (torch .rand (10 , 10 ), 3 , 1 ),
@@ -60,12 +59,24 @@ def forward(
6059 return x .split (split_size = split_size_or_sections , dim = dim )[1 :3 ]
6160
6261
62+ class SplitCopy (torch .nn .Module ):
63+ aten_op = "torch.ops.aten.split_copy.Tensor"
64+ exir_op = "executorch_exir_dialects_edge__ops_aten_split_copy_Tensor"
65+
66+ def forward (
67+ self ,
68+ x : torch .Tensor ,
69+ split_size : int ,
70+ dim : int ,
71+ ):
72+ return torch .split_copy (x , split_size = split_size , dim = dim )
73+
74+
6375@common .parametrize (
6476 "test_data" ,
6577 (Split .test_data | Split .test_data_list ),
6678)
6779def test_split_with_sizes_tosa_FP (test_data : input_t1 ):
68-
6980 pipeline = TosaPipelineFP [input_t1 ](
7081 Split (),
7182 test_data (),
@@ -77,7 +88,6 @@ def test_split_with_sizes_tosa_FP(test_data: input_t1):
7788
7889@common .parametrize ("test_data" , Split .test_data_list )
7990def test_split_with_sizes_tosa_FP_2 (test_data : input_t1 ):
80-
8191 pipeline = TosaPipelineFP [input_t1 ](
8292 SplitWithSizes (),
8393 test_data (),
@@ -92,7 +102,6 @@ def test_split_with_sizes_tosa_FP_2(test_data: input_t1):
92102 (Split .test_data | Split .test_data_list ),
93103)
94104def test_split_with_sizes_tosa_FP_one_out (test_data : input_t1 ):
95-
96105 pipeline = TosaPipelineFP [input_t1 ](
97106 SplitSingleOut (),
98107 test_data (),
@@ -107,7 +116,6 @@ def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1):
107116 (Split .test_data | Split .test_data_list ),
108117)
109118def test_split_with_sizes_tosa_FP_two_out (test_data : input_t1 ):
110-
111119 pipeline = TosaPipelineFP [input_t1 ](
112120 SplitTwoOut (),
113121 test_data (),
@@ -122,7 +130,6 @@ def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1):
122130 (Split .test_data | Split .test_data_list ),
123131)
124132def test_split_with_sizes_tosa_INT (test_data : input_t1 ):
125-
126133 pipeline = TosaPipelineINT [input_t1 ](
127134 Split (),
128135 test_data (),
@@ -161,7 +168,6 @@ def test_split_with_sizes_u55_INT(test_data: input_t1):
161168)
162169@common .XfailIfNoCorstone320
163170def test_split_with_sizes_u85_INT (test_data : input_t1 ):
164-
165171 pipeline = EthosU85PipelineINT [input_t1 ](
166172 Split (),
167173 test_data (),
@@ -190,7 +196,6 @@ def test_split_with_sizes_vgf_FP(test_data: input_t1):
190196@common .parametrize ("test_data" , Split .test_data_list )
191197@common .SkipIfNoModelConverter
192198def test_split_with_sizes_vgf_FP_2 (test_data : input_t1 ):
193-
194199 pipeline = VgfPipeline [input_t1 ](
195200 SplitWithSizes (),
196201 test_data (),
@@ -207,7 +212,6 @@ def test_split_with_sizes_vgf_FP_2(test_data: input_t1):
207212)
208213@common .SkipIfNoModelConverter
209214def test_split_with_sizes_vgf_FP_one_out (test_data : input_t1 ):
210-
211215 pipeline = VgfPipeline [input_t1 ](
212216 SplitSingleOut (),
213217 test_data (),
@@ -224,7 +228,6 @@ def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1):
224228)
225229@common .SkipIfNoModelConverter
226230def test_split_with_sizes_vgf_FP_two_out (test_data : input_t1 ):
227-
228231 pipeline = VgfPipeline [input_t1 ](
229232 SplitTwoOut (),
230233 test_data (),
@@ -241,7 +244,6 @@ def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1):
241244)
242245@common .SkipIfNoModelConverter
243246def test_split_with_sizes_vgf_INT (test_data : input_t1 ):
244-
245247 pipeline = VgfPipeline [input_t1 ](
246248 Split (),
247249 test_data (),
@@ -250,3 +252,75 @@ def test_split_with_sizes_vgf_INT(test_data: input_t1):
250252 tosa_version = "TOSA-1.0+INT" ,
251253 )
252254 pipeline .run ()
255+
256+
257+ @common .parametrize ("test_data" , Split .test_data )
258+ def test_split_tensor_tosa_FP (test_data : Tuple ):
259+ pipeline = TosaPipelineFP [input_t1 ](
260+ SplitCopy (),
261+ test_data (),
262+ aten_op = SplitCopy .aten_op ,
263+ exir_op = SplitCopy .exir_op ,
264+ )
265+ pipeline .run ()
266+
267+
268+ @common .parametrize ("test_data" , Split .test_data )
269+ def test_split_tensor_tosa_INT (test_data : Tuple ):
270+ pipeline = TosaPipelineINT [input_t1 ](
271+ SplitCopy (),
272+ test_data (),
273+ aten_op = SplitCopy .aten_op ,
274+ exir_op = SplitCopy .exir_op ,
275+ )
276+ pipeline .run ()
277+
278+
279+ @common .XfailIfNoCorstone300
280+ @common .parametrize ("test_data" , Split .test_data )
281+ def test_split_tensor_u55_INT (test_data : Tuple ):
282+ pipeline = EthosU55PipelineINT [input_t1 ](
283+ SplitCopy (),
284+ test_data (),
285+ aten_ops = SplitCopy .aten_op ,
286+ exir_ops = SplitCopy .exir_op ,
287+ )
288+ pipeline .run ()
289+
290+
291+ @common .XfailIfNoCorstone320
292+ @common .parametrize ("test_data" , Split .test_data )
293+ def test_split_tensor_u85_INT (test_data : Tuple ):
294+ pipeline = EthosU85PipelineINT [input_t1 ](
295+ SplitCopy (),
296+ test_data (),
297+ aten_ops = SplitCopy .aten_op ,
298+ exir_ops = SplitCopy .exir_op ,
299+ )
300+ pipeline .run ()
301+
302+
303+ @common .parametrize ("test_data" , Split .test_data )
304+ @common .SkipIfNoModelConverter
305+ def test_split_tensor_vgf_FP (test_data : Tuple ):
306+ pipeline = VgfPipeline [input_t1 ](
307+ SplitCopy (),
308+ test_data (),
309+ aten_op = SplitCopy .aten_op ,
310+ exir_op = SplitCopy .exir_op ,
311+ tosa_version = "TOSA-1.0+FP" ,
312+ )
313+ pipeline .run ()
314+
315+
316+ @common .parametrize ("test_data" , Split .test_data )
317+ @common .SkipIfNoModelConverter
318+ def test_split_tensor_vgf_INT (test_data : Tuple ):
319+ pipeline = VgfPipeline [input_t1 ](
320+ SplitCopy (),
321+ test_data (),
322+ aten_op = SplitCopy .aten_op ,
323+ exir_op = SplitCopy .exir_op ,
324+ tosa_version = "TOSA-1.0+INT" ,
325+ )
326+ pipeline .run ()
0 commit comments