@@ -3392,12 +3392,62 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
     raise NotImplementedError()


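+# trace_only: the quant_min/quant_max branches below are Python conditionals
+# evaluated at export time, so this op is traced rather than compiled to an
+# ONNX function.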
+@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
 def aten_fake_quantize_per_tensor_affine(
-    self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int
-) -> TensorType:
-    """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""
+    self: TFloat,
+    scale: TReal | float,
+    zero_point: TReal | int,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
+    """fake_quantize_per_tensor_affine(Tensor self, Tensor | float scale, Tensor | int zero_point, int quant_min, int quant_max) -> Tensor"""

-    raise NotImplementedError()
+    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations
+    # to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+        raise NotImplementedError(
+            "For (quant_min, quant_max), ONNX allows only "
+            "(0, 127), (0, 255) and (-128, 127). "
+            f"Got ({quant_min}, {quant_max})",
+        )
+
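+    # An unsigned range (quant_min == 0) maps to uint8; a signed range maps to int8.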
+    if quant_min == 0:
+        int_dtype = ir.DataType.UINT8
+    else:
+        int_dtype = ir.DataType.INT8
+
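+    # Before opset 19, QuantizeLinear only accepts float32 (and int32) inputs, so
+    # other float types are upcast here and the original dtype is restored below.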
+    # TODO: When opset >= 19, remove this cast
+    orig_dtype = self.type.dtype
+    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
+        self = op.Cast(self, to=ir.DataType.FLOAT)
+
+    # TODO: When opset >= 19, relax the condition for this cast
+    if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT:
+        scale = op.Cast(scale, to=ir.DataType.FLOAT)
+
+    if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype:
+        zero_point = op.Cast(zero_point, to=int_dtype)
+
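+    # Fake quantization is a quantize/dequantize round trip: the float output
+    # carries the rounding and saturation error of the simulated quantization.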
+    quantized = op.QuantizeLinear(self, scale, zero_point)
+
+    # See the NOTE above: ONNX has no (0, 127) range, so emulate PyTorch's
+    # reduced activation range by clamping the uint8 result at 127.
+    if (quant_min, quant_max) == (0, 127):
+        const_127 = op.Cast(127, to=int_dtype)
+        quantized = op.Clip(quantized, max=const_127)
+
+    output = op.DequantizeLinear(quantized, scale, zero_point)
+
+    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
+    if orig_dtype != ir.DataType.FLOAT:
+        output = op.Cast(output, to=orig_dtype)
+
+    return output


 def aten_fake_quantize_per_tensor_affine_cachemask(