
Commit f801c4f

pytorchbot and Ninja91 authored
Add U55 and U85 16A8W linear tests (#14453)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #14368 by @Ninja91 ^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/19/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/19/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/19/orig
@diff-train-skip-merge

Co-authored-by: Nitin Jain <[email protected]>
1 parent 685e795 commit f801c4f
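
For context: "16A8W" means symmetric integer quantization with 16-bit activations and 8-bit weights. The sketch below is illustrative only; the constants and helper are hypothetical and are not the config returned by get_symmetric_a16w8_quantization_config, but they show the integer ranges such a scheme implies.

import torch

# Hypothetical constants for illustration: the ranges implied by symmetric
# 16A8W quantization (int16 activations, int8 weights).
ACT_QMIN, ACT_QMAX = -(2**15), 2**15 - 1   # activations: [-32768, 32767]
WGT_QMIN, WGT_QMAX = -(2**7), 2**7 - 1     # weights: [-128, 127]

def symmetric_scale(t: torch.Tensor, qmax: int) -> torch.Tensor:
    # Symmetric per-tensor scale: map max(|t|) onto the positive bound.
    return t.abs().max().clamp(min=1e-12) / qmax

x = torch.randn(4, 32)
x_q = torch.clamp(torch.round(x / symmetric_scale(x, ACT_QMAX)), ACT_QMIN, ACT_QMAX)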

File tree

2 files changed: +71 −0 lines


backends/arm/test/ops/test_linear.py

Lines changed: 69 additions & 0 deletions
@@ -9,6 +9,7 @@
 from typing import Tuple

 import pytest
+
 import torch
 from executorch.backends.arm.quantizer.arm_quantizer import (
     get_symmetric_a16w8_quantization_config,
@@ -308,3 +309,71 @@ def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
     )
     # Run the pipeline
     pipeline.run()
+
+
+@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Ethos-U55 A16W8 linear: int16 matmul not yet supported; pending backend support or linear->conv1x1 lowering. See: https://github.com/pytorch/executorch/issues/13947",
+    strict=False,
+)
+def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
+    """Test linear operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    test_data, out_features, has_bias, per_channel_quantization = test_data()
+    in_features = test_data.shape[-1]
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=has_bias,
+        ),
+        (test_data,),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_linear_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Ethos-U85 A16W8 linear: int16 matmul not yet supported; pending backend support or linear->conv1x1 lowering. See: https://github.com/pytorch/executorch/issues/13947",
+    strict=False,
+)
+def test_linear_16a8w_u85_INT16(test_data: torch.Tensor):
+    """Test linear operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    test_data, out_features, has_bias, per_channel_quantization = test_data()
+    in_features = test_data.shape[-1]
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=has_bias,
+        ),
+        (test_data,),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_linear_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
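
Note on the shape handling above: assuming the repo's Linear test module wraps torch.nn.Linear, the pipelines exercise an ordinary float linear layer whose in_features is inferred from the last input dimension, exactly as the test bodies do. A minimal float sketch under that assumption (the Linear wrapper, pipelines, and FVP checks are not shown):

import torch

test_data = torch.randn(4, 32)        # rank-2 stand-in for one test_data() sample
in_features = test_data.shape[-1]     # 32, mirroring in_features above
model = torch.nn.Linear(in_features, out_features=16, bias=True)
reference = model(test_data)          # the float behavior the quantized run approximates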

backends/arm/tosa/quant_utils.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@

 from executorch.backends.arm.tosa.mapping import TosaArg
 from torch.fx import Node
+
 from tosa.RoundingMode import RoundingMode  # type: ignore


@@ -318,6 +319,7 @@ def build_rescale(
     per_channel=False,
 ):
     import serializer.tosa_serializer as ts  # type: ignore
+
     import tosa.Op as TosaOp  # type: ignore

     scaleWidth = 32
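
The scaleWidth = 32 in build_rescale reflects the standard fixed-point requantization trick: a floating-point rescale factor is approximated as multiplier / 2**shift with a signed 32-bit multiplier. A minimal sketch of that general technique, with illustrative names that are not the actual helpers in quant_utils.py:

import math

def scale_to_multiplier_shift(scale: float, scale_width: int = 32):
    # Approximate scale ~= multiplier / 2**shift, where multiplier fits in
    # scale_width - 1 bits (a signed 32-bit value when scale_width == 32).
    if scale == 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(scale)           # scale = mantissa * 2**exponent
    multiplier = round(mantissa * (1 << (scale_width - 1)))
    shift = (scale_width - 1) - exponent
    if multiplier == (1 << (scale_width - 1)):       # rounding spilled over
        multiplier //= 2
        shift -= 1
    return multiplier, shift

def rescale(value: int, multiplier: int, shift: int) -> int:
    # Integer requantization with round-to-nearest; assumes shift > 0,
    # which holds for the small scales typical of quantization.
    return (value * multiplier + (1 << (shift - 1))) >> shift

m, s = scale_to_multiplier_shift(0.0123)
assert abs(m / (1 << s) - 0.0123) < 1e-9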
