Skip to content

Commit 25c44ab

Browse files
Andrew Grebenisan authored and facebook-github-bot committed
Update requantize to requantize_per_tensor since we don't have a non-per-tensor variant (pytorch#14482)
Summary: As titled Reviewed By: hsharma35 Differential Revision: D82995376
1 parent 3862537 commit 25c44ab

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,13 +1041,13 @@ def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
10411041
def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
10421042

10431043

1044-
@impl(m, "requantize")
1045-
def requantize(
1044+
@impl(m, "requantize.per_tensor")
1045+
def requantize_per_tensor(
10461046
input: torch.Tensor,
1047-
in_scale: torch.Tensor,
1048-
in_zero_point: torch.Tensor,
1049-
out_scale: torch.Tensor,
1050-
out_zero_point: torch.Tensor,
1047+
in_scale: float,
1048+
in_zero_point: int,
1049+
out_scale: float,
1050+
out_zero_point: int,
10511051
dtype: ScalarType,
10521052
) -> torch.Tensor:
10531053
if dtype in qdtype_map:
@@ -1056,11 +1056,6 @@ def requantize(
10561056
torch.dequantize(input), out_scale, out_zero_point, qdtype_map[dtype]
10571057
)
10581058

1059-
# For in_scale or out_scale other than scalar, it requires quant/dequant
1060-
# per channel, but the channel dimension value is missing
1061-
if in_scale.numel() > 1 or out_scale.numel() > 1:
1062-
raise NotImplementedError("Only scalar scales are supported")
1063-
10641059
quant_min = torch.iinfo(input.dtype).min
10651060
quant_max = torch.iinfo(input.dtype).max
10661061
# pyre-fixme[6]: This dtype is actually the right one.
@@ -1070,14 +1065,14 @@ def requantize(
10701065
return torch.ops.quantized_decomposed.quantize_per_tensor(
10711066
torch.ops.quantized_decomposed.dequantize_per_tensor(
10721067
input,
1073-
in_scale.flatten()[0],
1074-
in_zero_point.flatten()[0],
1068+
in_scale,
1069+
in_zero_point,
10751070
quant_min,
10761071
quant_max,
10771072
input.dtype,
10781073
),
1079-
out_scale.flatten()[0],
1080-
out_zero_point.flatten()[0],
1074+
out_scale,
1075+
out_zero_point,
10811076
out_quant_min,
10821077
out_quant_max,
10831078
dtype,

0 commit comments

Comments (0)