
Commit f0cb337

Andrew Grebenisan authored and facebook-github-bot committed
Backend-agnostic quantized_relu (pytorch#13932)
Summary: Continued support of custom Cadence ops

Reviewed By: hsharma35

Differential Revision: D81646110
1 parent 68e9c5a commit f0cb337

2 files changed: +116 -0 lines changed


backends/cadence/aot/ref_implementations.py

Lines changed: 34 additions & 0 deletions
@@ -487,6 +487,40 @@ def quantized_conv_nhwc(
     )
 
 
+@impl(m, "quantized_relu")
+def quantized_relu(
+    X: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    out_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Quantized ReLU operation followed by requantization.
+
+    Args:
+        - X (Tensor): The input tensor
+        - X_zero_point (Tensor): The quantized mapping of zero for the input
+        - out_zero_point (int): The quantized mapping of zero for the output
+        - out_multiplier (Tensor): The multiplier used to scale the output
+        - out_shift (Tensor): The shift used to scale the output
+    """
+    supported_dtypes = [torch.int8, torch.int16, torch.uint8, torch.uint16]
+    if X.dtype not in supported_dtypes:
+        raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")
+
+    out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
+    return quantize_per_tensor(
+        dequantized_X,
+        out_scale,
+        out_zero_point,
+        torch.iinfo(X.dtype).min,
+        torch.iinfo(X.dtype).max,
+        X.dtype,
+    )
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
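For intuition about what the new op computes: the (out_multiplier, out_shift) pair encodes the requantization scale as -out_multiplier / 2^31 * 2^out_shift, ReLU is applied against X_zero_point in the integer domain, and the result is requantized into the input dtype. The snippet below is a minimal, self-contained re-derivation of the "basic_int8" test case using plain torch ops; it assumes round-to-nearest-even and clamping to the dtype range, and it does not call the Cadence reference implementation itself.

import torch

# Inputs taken from the "basic_int8" test case in the new test file
X = torch.tensor([-1, 0, 1, 3], dtype=torch.int8)
X_zero_point = torch.tensor([0], dtype=torch.int8)
out_zero_point = 0
out_multiplier = torch.tensor([1073741824])  # 0.5 * 2^31
out_shift = torch.tensor([0])

# Decode the fixed-point scale the same way quantized_relu does: -0.5 here
out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])

# ReLU relative to the input zero point, still in the integer domain
relu_x = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))

# Requantize: scale, add the output zero point, round to nearest even, clamp to int8
info = torch.iinfo(torch.int8)
out = torch.clamp(
    torch.round(relu_x.float() * out_scale + out_zero_point), info.min, info.max
).to(torch.int8)

print(out)  # expected: tensor([ 0,  0,  0, -2], dtype=torch.int8)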

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 82 additions & 0 deletions
@@ -19,6 +19,7 @@
     quantized_conv_nhwc,
     quantized_layer_norm_per_tensor,
     quantized_linear,
+    quantized_relu,
 )
 from executorch.backends.cadence.aot.typing_stubs import expand
 
@@ -744,3 +745,84 @@ def test_quantized_conv(
             torch.equal(output, expected_output),
             f"Output values don't match expected. Got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Test case 1: Basic int8 case with negative scale
+            (
+                "basic_int8",
+                torch.tensor([-1, 0, 1, 3], dtype=torch.int8),  # input
+                torch.tensor([0], dtype=torch.int8),  # X_zero_point (scalar broadcast)
+                0,  # out_zero_point
+                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0]),  # out_shift
+                torch.int8,  # dtype
+                torch.tensor(
+                    [0, 0, 0, -2], dtype=torch.int8
+                ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
+            ),
+            # Test case 2: uint8 with non-zero zero point
+            (
+                "uint8_with_zp",
+                torch.tensor([126, 128, 130, 132], dtype=torch.uint8),  # input
+                torch.tensor([128], dtype=torch.uint8),  # X_zero_point
+                64,  # out_zero_point
+                torch.tensor([536870912]),  # out_multiplier (0.25 * 2^31)
+                torch.tensor([0]),  # out_shift
+                torch.uint8,  # dtype
+                torch.tensor(
+                    [64, 64, 64, 63], dtype=torch.uint8
+                ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
+            ),
+            # Test case 3: All negative values (should all become zero after ReLU)
+            (
+                "all_negative_int8",
+                torch.tensor([-5, -3, -1], dtype=torch.int8),  # input
+                torch.tensor([0], dtype=torch.int8),  # X_zero_point
+                10,  # out_zero_point
+                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0]),  # out_shift
+                torch.int8,  # dtype
+                torch.tensor(
+                    [10, 10, 10], dtype=torch.int8
+                ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
+            ),
+            # Test case 4: All positive values with shift (scale becomes -1.0)
+            (
+                "positive_with_shift",
+                torch.tensor([2, 4, 6, 8], dtype=torch.int8),  # input
+                torch.tensor([1], dtype=torch.int8),  # X_zero_point
+                5,  # out_zero_point
+                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
+                torch.int8,  # dtype
+                torch.tensor(
+                    [4, 2, 0, -2], dtype=torch.int8
+                ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
+            ),
+        ]
+    )
+    def test_quantized_relu(
+        self,
+        name: str,
+        X: torch.Tensor,
+        X_zero_point: torch.Tensor,
+        out_zero_point: int,
+        out_multiplier: torch.Tensor,
+        out_shift: torch.Tensor,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = quantized_relu(
+            X, X_zero_point, out_zero_point, out_multiplier, out_shift
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
+        self.assertEqual(output.shape, X.shape, "Output shape should match input shape")
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
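The multiplier constants in these cases (1073741824 = 0.5 * 2^31, 536870912 = 0.25 * 2^31) follow the common fixed-point convention of splitting a real requantization scale into a Q31 mantissa plus a power-of-two shift, which is what quantized_relu decodes. A rough sketch of that encoding is below; encode_scale_q31 is a hypothetical helper written for this page, not part of the Cadence AoT code, and real quantizers may normalize the mantissa differently.

import math

def encode_scale_q31(scale: float) -> tuple[int, int]:
    # Hypothetical helper for illustration only.
    if scale == 0.0:
        return 0, 0
    # Choose the shift so the Q31 mantissa stays below 2^31
    shift = max(0, math.floor(math.log2(abs(scale))) + 1)
    multiplier = round(abs(scale) / (2 ** shift) * (1 << 31))
    return multiplier, shift

# 0.5 -> (1073741824, 0) and 0.25 -> (536870912, 0), matching the test inputs above;
# quantized_relu decodes these back as -multiplier / 2^31 * 2^shift.
print(encode_scale_q31(0.5), encode_scale_q31(0.25), encode_scale_q31(1.0))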
