99from typing import Tuple
1010
1111import torch
12-
13- from executorch .backends .arm .test import common
12+ from executorch .backends .arm .quantizer .arm_quantizer import (
13+ get_symmetric_a16w8_quantization_config ,
14+ TOSAQuantizer ,
15+ )
16+ from executorch .backends .arm .test import common , conftest
1417
1518from executorch .backends .arm .test .tester .test_pipeline import (
1619 EthosU55PipelineINT ,
1922 TosaPipelineINT ,
2023 VgfPipeline ,
2124)
22- from torchvision .ops import Permute
25+ from executorch .backends .arm .tosa import TosaSpecification
26+ from executorch .backends .xnnpack .test .tester import Quantize
2327
2428input_t1 = Tuple [torch .Tensor ] # Input x
2529
@@ -42,10 +46,10 @@ class SimplePermute(torch.nn.Module):
4246 def __init__ (self , dims : list [int ]):
4347 super ().__init__ ()
4448
45- self .permute = Permute ( dims = dims )
49+ self .dims = dims
4650
4751 def forward (self , x ):
48- return self .permute (x )
52+ return torch .permute (x , self . dims )
4953
5054
5155@common .parametrize ("test_data" , test_data_suite )
@@ -128,3 +132,96 @@ def test_permute_vgf_INT(test_data):
128132 tosa_version = "TOSA-1.0+INT" ,
129133 )
130134 pipeline .run ()
135+
136+
137+ def get_symmetric_a16w8_permute_quantizer (
138+ u55_config = False , per_channel_quantization = False
139+ ):
140+ tosa_version = conftest .get_option ("tosa_version" )
141+ tosa_profiles = {
142+ "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT+int16" ),
143+ }
144+
145+ quantizer = TOSAQuantizer (tosa_profiles [tosa_version ])
146+ quantizer .set_global (
147+ get_symmetric_a16w8_quantization_config (is_per_channel = per_channel_quantization )
148+ )
149+
150+ return Quantize (
151+ quantizer ,
152+ get_symmetric_a16w8_quantization_config (
153+ is_per_channel = per_channel_quantization
154+ ),
155+ )
156+
157+
158+ @common .parametrize ("test_data" , test_data_suite )
159+ def test_permute_int16_tosa_INT (test_data : torch .Tensor ):
160+ """Test permute operation with int16 quantization"""
161+ test_data , dims = test_data ()
162+ pipeline = TosaPipelineINT [input_t1 ](
163+ SimplePermute (dims = dims ),
164+ (test_data ,),
165+ aten_op ,
166+ exir_op = [],
167+ per_channel_quantization = False ,
168+ use_to_edge_transform_and_lower = True ,
169+ tosa_extensions = ["int16" ],
170+ )
171+
172+ pipeline .change_args (
173+ "quantize" ,
174+ get_symmetric_a16w8_permute_quantizer (per_channel_quantization = False ),
175+ )
176+ # Run the pipeline
177+ pipeline .run ()
178+
179+
180+ test_data_suite_exact = {
181+ x : test_data_suite [x ] for x in test_data_suite if x != "rank_4_3"
182+ }
183+ @common .parametrize ("test_data" , test_data_suite_exact )
184+ @common .XfailIfNoCorstone300
185+ def test_permute_int16_u55_INT16 (test_data : torch .Tensor ):
186+ """Test permute operation with int16 quantization on U55"""
187+ test_data , dims = test_data ()
188+ pipeline = EthosU55PipelineINT [input_t1 ](
189+ SimplePermute (dims = dims ),
190+ (test_data ,),
191+ aten_op ,
192+ exir_ops = [],
193+ per_channel_quantization = True ,
194+ use_to_edge_transform_and_lower = True ,
195+ atol = 1e-02 ,
196+ rtol = 1e-02 ,
197+ run_on_fvp = True ,
198+ )
199+
200+ pipeline .change_args (
201+ "quantize" ,
202+ get_symmetric_a16w8_permute_quantizer (per_channel_quantization = False ),
203+ )
204+ pipeline .run ()
205+
206+
207+ @common .parametrize ("test_data" , test_data_suite )
208+ @common .XfailIfNoCorstone320
209+ def test_permute_int16_u85_INT16 (test_data : torch .Tensor ):
210+ """Test permute operation with int16 quantization on U85"""
211+ test_data , dims = test_data ()
212+ pipeline = EthosU85PipelineINT [input_t1 ](
213+ SimplePermute (dims = dims ),
214+ (test_data ,),
215+ aten_op ,
216+ exir_ops = [],
217+ use_to_edge_transform_and_lower = True ,
218+ atol = 1e-03 ,
219+ rtol = 1e-03 ,
220+ run_on_fvp = True ,
221+ )
222+
223+ pipeline .change_args (
224+ "quantize" ,
225+ get_symmetric_a16w8_permute_quantizer (per_channel_quantization = False ),
226+ )
227+ pipeline .run ()
0 commit comments