diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py
index e3d591a0d5bb4..625cab25b9c46 100644
--- a/onnxruntime/python/tools/quantization/base_quantizer.py
+++ b/onnxruntime/python/tools/quantization/base_quantizer.py
@@ -235,7 +235,9 @@ def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1
         bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
         packed_bias_initializer = onnx.numpy_helper.from_array(bias_np_data, quantized_bias_name)
         self.model.initializer_extend([packed_bias_initializer])
-        bias_scale_data = np.asarray(bias_scale, dtype=np.float32).reshape(-1)
+
+        # Bias's scale dtype should match the original bias data's unquantized type (float32 or float16).
+        bias_scale_data = np.asarray(bias_scale, dtype=bias_data.dtype).reshape(-1)
 
         node_type = "DequantizeLinear"
         node_qtype = self.weight_qType
diff --git a/onnxruntime/test/python/quantization/test_op_conv_transpose.py b/onnxruntime/test/python/quantization/test_op_conv_transpose.py
index 118278d91c094..6f8d5f7b4dfd2 100644
--- a/onnxruntime/test/python/quantization/test_op_conv_transpose.py
+++ b/onnxruntime/test/python/quantization/test_op_conv_transpose.py
@@ -149,7 +149,6 @@ def quantize_conv_transpose_u8u8(self, onnx_type, opset, ir_version):
     def test_quantize_conv_transpose_u8u8(self):
         self.quantize_conv_transpose_u8u8(TensorProto.FLOAT, 13, 7)
 
-    @unittest.skip(reason="Shape inference bug, see onnx PR #5709")
     def test_quantize_conv_transpose_u8u8_fp16(self):
         self.quantize_conv_transpose_u8u8(TensorProto.FLOAT16, 19, 9)
 
@@ -160,7 +159,7 @@ def quantize_conv_transpose_s8s8(self, onnx_type, opset, ir_version):
         np.random.seed(1)
 
         model_fp32_path = "conv_transpose_fp32.onnx"
-        self.construct_model(model_fp32_path)
+        self.construct_model(model_fp32_path, onnx_type, opset, ir_version)
 
         dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_type)
         data_reader = self.input_feeds(1, {"input": [1, 1, 7, 7]}, dtype)
@@ -175,7 +174,6 @@ def quantize_conv_transpose_s8s8(self, onnx_type, opset, ir_version):
     def test_quantize_conv_transpose_s8s8(self):
         self.quantize_conv_transpose_s8s8(TensorProto.FLOAT, 13, 7)
 
-    @unittest.skip(reason="Shape inference bug, see onnx PR #5709")
     def test_quantize_conv_transpose_s8s8_fp16(self):
         self.quantize_conv_transpose_s8s8(TensorProto.FLOAT16, 19, 9)
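
Note on the base_quantizer.py change: in recent ONNX opsets, DequantizeLinear derives its output element type from the x_scale input, so a scale hard-coded to float32 forces a float32 bias into an otherwise float16 graph and causes a type mismatch at the bias-add. A minimal sketch of the fixed behavior, using a hypothetical helper make_bias_scale that is not part of the PR:

import numpy as np

def make_bias_scale(bias_data: np.ndarray, bias_scale) -> np.ndarray:
    # Hypothetical helper mirroring the fixed line: the scale inherits the
    # unquantized bias dtype instead of being pinned to np.float32.
    return np.asarray(bias_scale, dtype=bias_data.dtype).reshape(-1)

fp16_bias = np.ones(4, dtype=np.float16)
assert make_bias_scale(fp16_bias, 0.02).dtype == np.float16  # was float32 before the fix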