@@ -1041,13 +1041,13 @@ def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "requantize")
-def requantize(
+@impl(m, "requantize.per_tensor")
+def requantize_per_tensor(
     input: torch.Tensor,
-    in_scale: torch.Tensor,
-    in_zero_point: torch.Tensor,
-    out_scale: torch.Tensor,
-    out_zero_point: torch.Tensor,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
     dtype: ScalarType,
 ) -> torch.Tensor:
     if dtype in qdtype_map:
@@ -1056,11 +1056,6 @@ def requantize(
             torch.dequantize(input), out_scale, out_zero_point, qdtype_map[dtype]
         )
 
-    # For in_scale or out_scale other than scalar, it requires quant/dequant
-    # per channel, but the channel dimension value is missing
-    if in_scale.numel() > 1 or out_scale.numel() > 1:
-        raise NotImplementedError("Only scalar scales are supported")
-
     quant_min = torch.iinfo(input.dtype).min
     quant_max = torch.iinfo(input.dtype).max
     # pyre-fixme[6]: This dtype is actually the right one.
@@ -1070,14 +1065,14 @@ def requantize(
     return torch.ops.quantized_decomposed.quantize_per_tensor(
         torch.ops.quantized_decomposed.dequantize_per_tensor(
             input,
-            in_scale.flatten()[0],
-            in_zero_point.flatten()[0],
+            in_scale,
+            in_zero_point,
             quant_min,
             quant_max,
             input.dtype,
         ),
-        out_scale.flatten()[0],
-        out_zero_point.flatten()[0],
+        out_scale,
+        out_zero_point,
         out_quant_min,
         out_quant_max,
         dtype,
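The diff narrows `requantize.per_tensor` to scalar `in_scale`/`in_zero_point`/`out_scale`/`out_zero_point`, so the `flatten()[0]` extraction and the non-scalar guard are no longer needed; the body still dequantizes with the input qparams and re-quantizes with the output qparams via `quantize_per_tensor(dequantize_per_tensor(...))`. The sketch below spells out that per-tensor round trip in plain PyTorch; the helper name `requantize_per_tensor_reference` and the sample values are illustrative, not part of this change.

```python
import torch


def requantize_per_tensor_reference(
    q_in: torch.Tensor,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize with the input qparams, then re-quantize with the output qparams."""
    # Dequantize: map integer codes back to real values.
    x = (q_in.to(torch.float32) - in_zero_point) * in_scale
    # Re-quantize: round into the output domain and clamp to the output dtype's range.
    out_min = torch.iinfo(out_dtype).min
    out_max = torch.iinfo(out_dtype).max
    q_out = torch.round(x / out_scale) + out_zero_point
    return q_out.clamp(out_min, out_max).to(out_dtype)


# Illustrative values only: move an int8 activation into a uint8 domain.
q_in = torch.tensor([-128, 0, 127], dtype=torch.int8)
print(requantize_per_tensor_reference(q_in, 0.05, 0, 0.05, 128, torch.uint8))
# tensor([  0, 128, 255], dtype=torch.uint8)
```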