diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 875bbf471c688..fe7f8f546c15b 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -723,7 +723,14 @@ bool BatchNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer,
                                                 const Node& node,
                                                 const Node* redundant_clip_node,
                                                 const std::vector<const Node*>& dq_nodes,
                                                 const std::vector<const Node*>& q_nodes) const {
-  if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, 3)) {
+  // BatchNormalization has 5 inputs: x, scale, bias, mean, var.
+  // Require DQ on x and scale (indices 0, 1); mean and var may optionally have DQ.
+  const int num_dq_nodes = gsl::narrow_cast<int>(dq_nodes.size());
+  if (num_dq_nodes < 3 || num_dq_nodes > 5) {
+    return false;
+  }
+
+  if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, num_dq_nodes)) {
     return false;
   }
 
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc
index 51f6523559987..0fe7f6d6d7c20 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc
@@ -12,6 +12,7 @@
 namespace onnxruntime {
 namespace qnn {
+
 class BatchNormOpBuilder : public BaseOpBuilder {
  public:
   BatchNormOpBuilder() : BaseOpBuilder("BatchNormOpBuilder") {}
@@ -262,30 +263,57 @@ class BatchNormOpBuilder : public BaseOpBuilder {
     return Status::OK();
   }
 
+  // De-quantizes (if quantized) a 1D BatchNorm parameter tensor to double values.
+  Status MaybeDequantizeParamTensor(const TensorInfo& info,
+                                    const uint8_t* raw_ptr,
+                                    const size_t raw_ptr_length,
+                                    const char* tensor_name,
+                                    std::vector<double>& out) const {
+    uint32_t channel = info.shape[0];
+    out.resize(channel);
+    ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(info.qnn_data_type, channel, raw_ptr_length));
+
+    const bool is_quantized = info.quant_param.IsQuantized();
+    const bool is_per_channel = info.quant_param.IsPerChannel();
+    const Qnn_QuantizeParams_t& quant_param = info.quant_param.Get();
+    if (is_per_channel) {
+      // Validate per-channel quantization parameters for 1D BatchNorm tensors.
+      // For 1D tensors, the axis must be 0 and numScaleOffsets must be >= the channel count.
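+      // For example (illustrative, not part of the original change): a 2-channel scale tensor carries
+      // axis == 0, numScaleOffsets == 2, and one {scale, offset} pair per channel in scaleOffset[];
+      // each channel i is then de-quantized in the loop below with its own scaleOffset[i] pair
+      // via utils::Dequantize.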
+      ORT_RETURN_IF_NOT(quant_param.axisScaleOffsetEncoding.axis == 0,
+                        "Per-channel quantization axis must be 0 for 1D ", tensor_name, " tensor, got ",
+                        quant_param.axisScaleOffsetEncoding.axis);
+      ORT_RETURN_IF_NOT(quant_param.axisScaleOffsetEncoding.numScaleOffsets >= channel,
+                        "Per-channel quantization scale/offset count (",
+                        quant_param.axisScaleOffsetEncoding.numScaleOffsets,
+                        ") is less than channel count (", channel, ") for ", tensor_name, " tensor.");
+    }
+
+    int offset = 0;
+    for (uint32_t i = 0; i < channel; ++i) {
+      double value = 0.0;
+      ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(info.qnn_data_type, raw_ptr + offset, value, offset));
+      // Dequantize if needed
+      if (is_quantized) {
+        if (is_per_channel) {
+          value = utils::Dequantize(quant_param.axisScaleOffsetEncoding.scaleOffset[i].offset,
+                                    quant_param.axisScaleOffsetEncoding.scaleOffset[i].scale,
+                                    value);
+        } else {
+          value = utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
+                                    quant_param.scaleOffsetEncoding.scale,
+                                    value);
+        }
+      }
+      out[i] = value;
+    }
+    return Status::OK();
+  }
+
   Status PreprocessMean(const TensorInfo& mean_info,
                         const uint8_t* mean_raw_ptr,
                         const size_t mean_raw_ptr_length,
                         std::vector<double>& mean_out) const {
-    // tensor length (channel)
-    uint32_t channel = mean_info.shape[0];
-    mean_out.resize(channel);
-    ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(mean_info.qnn_data_type, channel, mean_raw_ptr_length));
-
-    const bool is_quantized = mean_info.quant_param.IsQuantized();
-    ORT_RETURN_IF_NOT(!is_quantized || mean_info.quant_param.IsPerTensor(),
-                      "BatchNormalization's input_mean does not support per-channel quantization");
-    int i = 0;
-    int offset = 0;
-    const Qnn_QuantizeParams_t& quant_param = mean_info.quant_param.Get();
-    for (; i < static_cast<int>(channel); ++i) {
-      double mean_value = 0.0;
-      ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset));
-      mean_out[i] = (is_quantized) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
-                                                       quant_param.scaleOffsetEncoding.scale,
-                                                       mean_value)
-                                   : mean_value;
-    }
-    return Status::OK();
+    return MaybeDequantizeParamTensor(mean_info, mean_raw_ptr, mean_raw_ptr_length, "mean", mean_out);
   }
 
   Status PreprocessStd(const TensorInfo& var_info,
@@ -293,25 +321,12 @@ class BatchNormOpBuilder : public BaseOpBuilder {
                        const uint8_t* var_raw_ptr,
                        const size_t var_raw_ptr_length,
                        const float epsilon,
                        std::vector<double>& std_out) const {
-    // tensor length (channel)
-    uint32_t channel = var_info.shape[0];
-    std_out.resize(channel);
-    ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(var_info.qnn_data_type, channel, var_raw_ptr_length));
-
-    const bool is_quantized = var_info.quant_param.IsQuantized();
-    ORT_RETURN_IF_NOT(!is_quantized || var_info.quant_param.IsPerTensor(),
-                      "BatchNormalization's input_var does not support per-channel quantization");
-    int i = 0;
-    int offset = 0;
-    const Qnn_QuantizeParams_t& quant_param = var_info.quant_param.Get();
-    for (; i < static_cast<int>(channel); ++i) {
-      double var_value = 0.0;
-      ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset));
-      std_out[i] = (is_quantized) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
-                                                      quant_param.scaleOffsetEncoding.scale,
-                                                      var_value)
-                                  : var_value;
-      std_out[i] = std::sqrt(std_out[i] + static_cast<double>(epsilon));
+    std::vector<double> var_dequantized;
+    ORT_RETURN_IF_ERROR(MaybeDequantizeParamTensor(var_info, var_raw_ptr, var_raw_ptr_length, "variance", var_dequantized));
+
+    std_out.resize(var_dequantized.size());
+    for (size_t i = 0; i < var_dequantized.size(); ++i) {
+      std_out[i] = std::sqrt(var_dequantized[i] + static_cast<double>(epsilon));
     }
     return Status::OK();
   }
@@ -323,25 +338,10 @@ class BatchNormOpBuilder : public BaseOpBuilder {
                          double& rmax,
                          double& rmin,
                          std::vector<double>& scale_out) const {
-    // tensor length (channel)
-    uint32_t channel = scale_info.shape[0];
-    scale_out.resize(channel);
-    ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(scale_info.qnn_data_type, channel, scale_raw_ptr_length));
-
-    const bool is_quantized = scale_info.quant_param.IsQuantized();
-    ORT_RETURN_IF_NOT(!is_quantized || scale_info.quant_param.IsPerTensor(),
-                      "BatchNormalization's scale input does not support per-channel quantization");
-    int i = 0;
-    int offset = 0;
-    const Qnn_QuantizeParams_t& quant_param = scale_info.quant_param.Get();
-    for (; i < static_cast<int>(channel); ++i) {
-      double scale_value = 0.0;
-      ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset));
-      scale_out[i] = (is_quantized) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
-                                                        quant_param.scaleOffsetEncoding.scale,
-                                                        scale_value)
-                                    : scale_value;
-      scale_out[i] = scale_out[i] / std_double_tensor[i];
+    ORT_RETURN_IF_ERROR(MaybeDequantizeParamTensor(scale_info, scale_raw_ptr, scale_raw_ptr_length, "scale", scale_out));
+
+    for (size_t i = 0; i < scale_out.size(); ++i) {
+      scale_out[i] /= std_double_tensor[i];
       rmax = std::max(rmax, scale_out[i]);
       rmin = std::min(rmin, scale_out[i]);
     }
@@ -356,25 +356,10 @@ class BatchNormOpBuilder : public BaseOpBuilder {
                         double& rmax,
                         double& rmin,
                         std::vector<double>& bias_out) const {
-    // tensor length (channel)
-    uint32_t channel = bias_info.shape[0];
-    bias_out.resize(channel);
-    ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(bias_info.qnn_data_type, channel, bias_raw_ptr_length));
-
-    const bool is_quantized = bias_info.quant_param.IsQuantized();
-    ORT_RETURN_IF_NOT(!is_quantized || bias_info.quant_param.IsPerTensor(),
-                      "BatchNormalization's bias input does not support per-channel quantization");
-    int i = 0;
-    int offset = 0;
-    const Qnn_QuantizeParams_t& quant_param = bias_info.quant_param.Get();
-    for (; i < static_cast<int>(channel); ++i) {
-      double bias_value = 0.0;
-      ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset));
-      bias_out[i] = (is_quantized) ? utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
-                                                       quant_param.scaleOffsetEncoding.scale,
-                                                       bias_value)
-                                   : bias_value;
-      bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]);
+    ORT_RETURN_IF_ERROR(MaybeDequantizeParamTensor(bias_info, bias_raw_ptr, bias_raw_ptr_length, "bias", bias_out));
+
+    for (size_t i = 0; i < bias_out.size(); ++i) {
+      bias_out[i] -= mean_double_tensor[i] * scale_double_tensor[i];
       rmax = std::max(rmax, bias_out[i]);
       rmin = std::min(rmin, bias_out[i]);
     }
@@ -390,10 +375,15 @@ class BatchNormOpBuilder : public BaseOpBuilder {
     bool symmetric = false;
     if (info.quant_param.IsQuantized()) {
       size_t data_size = double_tensor.size();
-      // QNN BatchNorm int32 bias requires symmetric quantizated
+      // QNN BatchNorm requires symmetric quantization (zero_point=0) for signed params
       if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) {
         data_size *= sizeof(int32_t);
         symmetric = true;
+      } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) {
+        data_size *= sizeof(int16_t);
+        symmetric = true;
+      } else if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) {
+        data_size *= sizeof(uint16_t);
       }
       raw_tensor.resize(data_size);
       float scale = 0.0f;
@@ -406,7 +396,6 @@ class BatchNormOpBuilder : public BaseOpBuilder {
                                             symmetric));
       quant_param = QnnQuantParamsWrapper(scale, zero_point);
       for (size_t i = 0; i < double_tensor.size(); ++i) {
-        // onnx only supports 8 bits quantization
         int quant_value_int = 0;
         ORT_RETURN_IF_ERROR(utils::Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int));
         if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) {
@@ -414,12 +403,19 @@ class BatchNormOpBuilder : public BaseOpBuilder {
         } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) {
           int8_t quant_value = static_cast<int8_t>(quant_value_int);
           raw_tensor[i] = *reinterpret_cast<uint8_t*>(&quant_value);
+        } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) {
+          int16_t quant_value = static_cast<int16_t>(quant_value_int);
+          size_t pos = i * sizeof(int16_t);
+          std::memcpy(&raw_tensor[pos], reinterpret_cast<uint8_t*>(&quant_value), sizeof(int16_t));
+        } else if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) {
+          uint16_t quant_value = static_cast<uint16_t>(quant_value_int);
+          size_t pos = i * sizeof(uint16_t);
+          std::memcpy(&raw_tensor[pos], reinterpret_cast<uint8_t*>(&quant_value), sizeof(uint16_t));
         } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) {
           int32_t quant_value = static_cast<int32_t>(quant_value_int);
           size_t pos = i * sizeof(int32_t);
           std::memcpy(&raw_tensor[pos], reinterpret_cast<uint8_t*>(&quant_value), sizeof(int32_t));
         } else {
-          // TODO(adrianlizarraga): Should support 16-bit quantization as well.
           ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type);
         }
       }
@@ -437,6 +433,67 @@ class BatchNormOpBuilder : public BaseOpBuilder {
                            const std::vector<Qnn_DataType_t> out_dtypes) const override ORT_MUST_USE_RESULT;
 };
 
+namespace {
+
+// Helper to check if a BatchNorm param is constant - either a direct initializer or fed through a DQ node.
+bool IsParamConstant(const QnnModelWrapper& qnn_model_wrapper,
+                     const NodeUnit& node_unit,
+                     const std::string& name) {
+  if (qnn_model_wrapper.IsConstantInput(name)) {
+    return true;
+  }
+  // Check if the param comes through a DQ node with a constant input.
+  for (const Node* dq_node : node_unit.GetDQNodes()) {
+    if (dq_node->OutputDefs()[0]->Name() == name) {
+      return qnn_model_wrapper.IsConstantInput(dq_node->InputDefs()[0]->Name());
+    }
+  }
+  return false;
+}
+
+// Helper to resolve a param initializer that may come from a DQ node.
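+// In a QDQ node unit the param NodeArg is typically the DQ node's output, so the actual weight
+// initializer is the DQ node's input; this helper walks the unit's DQ nodes to find that initializer
+// and patches the TensorInfo accordingly.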
+Status ResolveParamInitializer(const QnnModelWrapper& qnn_model_wrapper,
+                               const NodeUnit& node_unit,
+                               const NodeUnitIODef& param,
+                               TensorInfo& tensor_info) {
+  const std::string& param_name = param.node_arg.Name();
+  if (tensor_info.is_initializer) {
+    return Status::OK();
+  }
+  for (const Node* dq_node : node_unit.GetDQNodes()) {
+    if (dq_node->OutputDefs()[0]->Name() == param_name) {
+      const std::string& init_name = dq_node->InputDefs()[0]->Name();
+      tensor_info.initializer_tensor = qnn_model_wrapper.GetConstantTensor(init_name);
+      if (tensor_info.initializer_tensor != nullptr) {
+        tensor_info.is_initializer = true;
+        return Status::OK();
+      }
+    }
+  }
+  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cannot resolve initializer for BatchNorm param: ", param_name);
+}
+
+// Adjust BatchNorm param types for QNN HTP compatibility.
+// Modifies scale/bias types in-place; quantization happens in Postprocess.
+void MaybeQuantizeAndOverrideParamType(Qnn_DataType_t x_dtype,
+                                       Qnn_DataType_t& scale_dtype,
+                                       Qnn_DataType_t& bias_dtype,
+                                       bool is_scale_has_negative_values = true) {
+  // QNN HTP with UFIXED_POINT_16 input doesn't support SFIXED_POINT_8 scale
+  if (x_dtype == QNN_DATATYPE_UFIXED_POINT_16 && scale_dtype == QNN_DATATYPE_SFIXED_POINT_8) {
+    scale_dtype = is_scale_has_negative_values ? QNN_DATATYPE_SFIXED_POINT_16 : QNN_DATATYPE_UFIXED_POINT_8;
+  }
+
+  // QNN HTP requires quantized bias for quantized ops
+  bool is_quantized = (x_dtype == QNN_DATATYPE_UFIXED_POINT_8 || x_dtype == QNN_DATATYPE_SFIXED_POINT_8 ||
+                       x_dtype == QNN_DATATYPE_UFIXED_POINT_16 || x_dtype == QNN_DATATYPE_SFIXED_POINT_16);
+  if (is_quantized && (bias_dtype == QNN_DATATYPE_FLOAT_32 || bias_dtype == QNN_DATATYPE_FLOAT_16)) {
+    bias_dtype = QNN_DATATYPE_SFIXED_POINT_32;
+  }
+}
+
+}  // namespace
+
 // BatchNorm is sensitive with data layout, no special validation so far
 // The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW
 // The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC
@@ -464,14 +521,14 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
   std::vector<uint32_t> scale_shape;
   ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape),
                     "Cannot get shape of input 1 (scale).");
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[1].node_arg.Name()),
+  ORT_RETURN_IF_NOT(IsParamConstant(qnn_model_wrapper, node_unit, inputs[1].node_arg.Name()),
                     "QNN BatchNorm doesn't support dynamic scale.");
   ORT_RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels,
                 "QNN BatchNorm input 1 (scale) must have 1D shape [channel].");
 
   std::vector<uint32_t> bias_shape;
   ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape),
                     "Cannot get shape of input 2 (bias).");
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[2].node_arg.Name()),
+  ORT_RETURN_IF_NOT(IsParamConstant(qnn_model_wrapper, node_unit, inputs[2].node_arg.Name()),
                     "QNN BatchNorm doesn't support dynamic bias.");
 
   ORT_RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels,
@@ -481,14 +538,14 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
   ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[3].node_arg, mean_shape),
                     "Cannot get shape of input 3 (mean).");
   ORT_RETURN_IF(mean_shape.size() != 1 || mean_shape[0] != num_channels,
                 "QNN BatchNorm input 3 (mean) must have 1D shape [channel].");
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[3].node_arg.Name()),
+  ORT_RETURN_IF_NOT(IsParamConstant(qnn_model_wrapper, node_unit, inputs[3].node_arg.Name()),
                     "QNN BatchNorm doesn't support dynamic mean.");
 
   std::vector<uint32_t> var_shape;
   ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[4].node_arg, var_shape),
                     "Cannot get shape of input 4 (var).");
   ORT_RETURN_IF(var_shape.size() != 1 || var_shape[0] != num_channels,
                 "QNN BatchNorm input 4 (var) must have 1D shape [channel].");
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[4].node_arg.Name()),
+  ORT_RETURN_IF_NOT(IsParamConstant(qnn_model_wrapper, node_unit, inputs[4].node_arg.Name()),
                     "QNN BatchNorm doesn't support dynamic var.");
 
   ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN BatchNorm only support 1 output.");
@@ -528,11 +585,21 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
   ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[3], mean_info));
   ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[4], var_info));
 
-  // scale, bias, mean, and var must be initializers
-  ORT_RETURN_IF_NOT(scale_info.is_initializer, "scale must be initializers");
-  ORT_RETURN_IF_NOT(bias_info.is_initializer, "bias must be initializers");
-  ORT_RETURN_IF_NOT(mean_info.is_initializer, "mean must be initializers");
-  ORT_RETURN_IF_NOT(var_info.is_initializer, "var must be initializers");
+  // Resolve initializers that may come through DQ nodes
+  ORT_RETURN_IF_ERROR(ResolveParamInitializer(qnn_model_wrapper, node_unit, inputs[1], scale_info));
+  ORT_RETURN_IF_ERROR(ResolveParamInitializer(qnn_model_wrapper, node_unit, inputs[2], bias_info));
+  ORT_RETURN_IF_ERROR(ResolveParamInitializer(qnn_model_wrapper, node_unit, inputs[3], mean_info));
+  ORT_RETURN_IF_ERROR(ResolveParamInitializer(qnn_model_wrapper, node_unit, inputs[4], var_info));
+
+  // Get input tensor info to determine if this is a quantized op
+  TensorInfo input_info = {};
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info));
+  const bool is_quantized_op = input_info.quant_param.IsQuantized();
+
+  // Check if bias needs conversion (will be done after preprocessing)
+  const bool bias_is_float = !bias_info.quant_param.IsQuantized() &&
+                             (bias_info.qnn_data_type == QNN_DATATYPE_FLOAT_32 ||
+                              bias_info.qnn_data_type == QNN_DATATYPE_FLOAT_16);
 
   std::vector<uint8_t> scale_unpacked_tensor;
   std::vector<uint8_t> bias_unpacked_tensor;
@@ -582,6 +649,15 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                      bias_rmin,
                                      bias_double_tensor));
 
+  // Apply QNN HTP type conversions
+  MaybeQuantizeAndOverrideParamType(input_info.qnn_data_type,
+                                    scale_info.qnn_data_type,
+                                    bias_info.qnn_data_type,
+                                    scale_rmin < 0.0);
+  if (is_quantized_op && bias_is_float) {
+    bias_info.quant_param = QnnQuantParamsWrapper(1.0f, 0);  // Placeholder, computed in Postprocess
+  }
+
   if (!qnn_model_wrapper.IsQnnTensorWrapperExist(scale_name)) {
     std::vector<uint8_t> scale_raw_tensor;
     QnnQuantParamsWrapper scale_quant_param = scale_info.quant_param;
@@ -650,10 +726,17 @@ Status BatchNormOpBuilder::CheckHtpDataTypes(const std::vector<Qnn_DataType_t> in_dtypes,
                                              const std::vector<Qnn_DataType_t> out_dtypes) const {
   bool is_supported_dtype = false;
   // in_dtypes: [X, scale, B, input_mean, input_var]
-  std::vector<Qnn_DataType_t> all_dtypes(in_dtypes.begin(), in_dtypes.begin() + 3);
   // out_dtypes: [Y, running_mean, running_var]
-  all_dtypes.insert(all_dtypes.end(), out_dtypes.begin(), out_dtypes.begin() + 1);
-  // FP16
+  Qnn_DataType_t x_dtype = in_dtypes[0];
+  Qnn_DataType_t scale_dtype = in_dtypes[1];
+  Qnn_DataType_t bias_dtype = in_dtypes[2];
+  Qnn_DataType_t y_dtype = out_dtypes[0];
+
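+  // The dtypes above are the raw graph dtypes. As an illustrative example: with a UFIXED_POINT_16 x,
+  // an SFIXED_POINT_8 scale is validated as a 16-bit scale and a float bias as SFIXED_POINT_32,
+  // matching the overrides ProcessInputs applies when building the op.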
+  // We likely need to re-quantize scale/bias for HTP compatibility, override dtypes before checking.
+  // Note: We conservatively assume scale may have negative values during validation.
+  MaybeQuantizeAndOverrideParamType(x_dtype, scale_dtype, bias_dtype);
+  std::vector<Qnn_DataType_t> all_dtypes{x_dtype, scale_dtype, bias_dtype, y_dtype};
+  // FP16/FP32
   if (
       (all_dtypes == std::vector<Qnn_DataType_t>{QNN_DATATYPE_FLOAT_16, QNN_DATATYPE_FLOAT_16, QNN_DATATYPE_FLOAT_16, QNN_DATATYPE_FLOAT_16}) ||
       (all_dtypes == std::vector<Qnn_DataType_t>{QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32})) {
@@ -678,7 +761,7 @@ Status BatchNormOpBuilder::CheckHtpDataTypes(const std::vector<Qnn_DataType_t> in_dtypes,
   }
   ORT_RETURN_IF_NOT(is_supported_dtype, "QNN Batchnorm unsupported datatype on HTP.");
   return Status::OK();
-};
+}
 
 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/qnn/batch_norm_test.cc b/onnxruntime/test/providers/qnn/batch_norm_test.cc
index dc29b541d6dd6..c786e2f58d1be 100644
--- a/onnxruntime/test/providers/qnn/batch_norm_test.cc
+++ b/onnxruntime/test/providers/qnn/batch_norm_test.cc
@@ -3,8 +3,10 @@
 
 #if !defined(ORT_MINIMAL_BUILD)
 
+#include
 #include
 #include "core/graph/graph.h"
+#include "core/graph/node_attr_utils.h"
 #include "core/common/float16.h"
 
 #include "test/providers/qnn/qnn_test_utils.h"
@@ -411,6 +413,104 @@ TEST_F(QnnHTPBackendTests, BatchNorm3D) {
                        ExpectedEPNodeAssignment::None);
 }
 
+// Tests BatchNorm with the Q->DQ structure commonly seen in quantized models.
+template
+GetTestQDQModelFn BuildBatchNormQdqParamsTestCase(const TestInputDef<float>& input_def,
                                                   const TestInputDef<float>& scale_def,
                                                   const TestInputDef<float>& bias_def) {
+  ORT_ENFORCE(input_def.IsRawData());
+  ORT_ENFORCE(scale_def.IsRawData());
+
+  return [input_def, scale_def, bias_def](ModelTestBuilder& builder,
+                                          std::vector<QuantParams<InputQType>>& output_qparams) {
+    const auto& input_shape = input_def.GetShape();
+    const auto& input_data = input_def.GetRawData();
+    const int64_t num_channels = input_shape[1];
+
+    // Input: float -> Q -> DQ (asymmetric for uint16 to match real models)
+    bool symmetric = sizeof(InputQType) == sizeof(uint16_t);
+    NodeArg* input = MakeTestInput(builder, input_def);
+    QuantParams<InputQType> input_qparams = GetTestInputQuantParams<InputQType>(input_def, symmetric);
+    NodeArg* input_qdq = AddQDQNodePair<InputQType>(builder, input, input_qparams.scale, input_qparams.zero_point);
+
+    NodeAttributes axis_0_attrs;
+    utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast<int64_t>(0)), axis_0_attrs);
+
+    // Scale: float_init -> Q -> DQ (per-channel with axis=0, symmetric)
+    const auto& scale_data = scale_def.GetRawData();
+    std::vector<float> scale_scales(num_channels);
+    std::vector scale_zero_points(num_channels, static_cast(0));
+    for (int64_t c = 0; c < num_channels; ++c) {
+      float abs_max = std::abs(scale_data[c]);
+      if (abs_max == 0.0f) abs_max = 1.0f;
+      scale_scales[c] = abs_max / static_cast<float>(std::numeric_limits::max());
+    }
+    std::vector<int64_t> param_shape = {num_channels};
+    NodeArg* scale_float_init = builder.MakeInitializer(param_shape, scale_data);
+    NodeArg* scale_qdq = AddQDQNodePair(builder, scale_float_init, scale_scales, scale_zero_points,
+                                        &axis_0_attrs, &axis_0_attrs);
+
+    NodeArg* bias = builder.MakeInitializer(bias_def.GetShape(), bias_def.GetRawData());
+
+    // Compute mean and var from input data
+    std::vector<float> mean_vals(num_channels);
+    std::vector<float> var_vals(num_channels);
+    ComputeChannelMeanAndVar(input_data, input_shape, mean_vals, var_vals);
+
+    // Mean: float_init -> Q -> DQ (per-channel with axis=0, symmetric)
+    std::vector<float> mean_scales(num_channels);
+    std::vector mean_zero_points(num_channels, static_cast(0));
+    for (int64_t c = 0; c < num_channels; ++c) {
+      float abs_max = std::abs(mean_vals[c]);
+      if (abs_max == 0.0f) abs_max = 1.0f;
+      mean_scales[c] = abs_max / static_cast<float>(std::numeric_limits::max());
+    }
+    NodeArg* mean_float_init = builder.MakeInitializer(param_shape, mean_vals);
+    NodeArg* mean_qdq = AddQDQNodePair(builder, mean_float_init, mean_scales, mean_zero_points,
+                                       &axis_0_attrs, &axis_0_attrs);
+
+    // Var: float_init -> Q -> DQ (per-channel with axis=0, symmetric)
+    std::vector<float> var_scales(num_channels);
+    std::vector var_zero_points(num_channels, static_cast(0));
+    for (int64_t c = 0; c < num_channels; ++c) {
+      float abs_max = std::abs(var_vals[c]);
+      if (abs_max == 0.0f) abs_max = 1.0f;
+      var_scales[c] = abs_max / static_cast<float>(std::numeric_limits::max());
+    }
+    NodeArg* var_float_init = builder.MakeInitializer(param_shape, var_vals);
+    NodeArg* var_qdq = AddQDQNodePair(builder, var_float_init, var_scales, var_zero_points,
+                                      &axis_0_attrs, &axis_0_attrs);
+
+    auto* batchnorm_output = builder.MakeIntermediate();
+    builder.AddNode("BatchNormalization", {input_qdq, scale_qdq, bias, mean_qdq, var_qdq},
+                    {batchnorm_output});
+
+    AddQDQNodePairWithOutputAsGraphOutput(builder, batchnorm_output,
+                                          output_qparams[0].scale, output_qparams[0].zero_point);
+  };
+}
+
+// Test BatchNorm with Q->DQ on input/scale/mean/var, float bias
+TEST_F(QnnHTPBackendTests, BatchNorm2dQdqParams) {
+  constexpr int64_t num_channels = 2;
+  std::vector<float> input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f,
+                                   -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f};
+
+  TestInputDef<float> input_def({2, num_channels, 2, 2}, false, input_data);
+  TestInputDef<float> scale_def({num_channels}, true, {1.0f, 2.0f});
+  TestInputDef<float> bias_def({num_channels}, true, {1.1f, 2.1f});
+
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+
+  TestQDQModelAccuracy(BuildBatchNormTestCase(input_def, scale_def, bias_def),
+                       BuildBatchNormQdqParamsTestCase(input_def, scale_def, bias_def),
+                       provider_options,
+                       21,
+                       ExpectedEPNodeAssignment::All);
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 }  // namespace test
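
For reference, a minimal standalone sketch (not part of the change; names are illustrative) of the per-channel folding that PreprocessStd, PreprocessScale, and PreprocessBias perform above, written in plain C++ on values that have already been dequantized by MaybeDequantizeParamTensor:

#include <cmath>
#include <cstddef>
#include <vector>

// Folds BatchNorm parameters into per-channel weight/bias so that y = weight[c] * x + bias[c]:
//   weight[c] = scale[c] / sqrt(var[c] + epsilon)   (PreprocessStd + PreprocessScale)
//   bias[c]   = bias[c] - mean[c] * weight[c]       (PreprocessBias)
void FoldBatchNormParams(const std::vector<double>& scale,
                         const std::vector<double>& bias,
                         const std::vector<double>& mean,
                         const std::vector<double>& var,
                         double epsilon,
                         std::vector<double>& weight_out,
                         std::vector<double>& bias_out) {
  const std::size_t channels = scale.size();
  weight_out.resize(channels);
  bias_out.resize(channels);
  for (std::size_t c = 0; c < channels; ++c) {
    const double std_dev = std::sqrt(var[c] + epsilon);  // per-channel standard deviation
    weight_out[c] = scale[c] / std_dev;                   // folded multiplicative term
    bias_out[c] = bias[c] - mean[c] * weight_out[c];      // folded additive term
  }
}

The folded weight/bias vectors are what Postprocess then quantizes (symmetrically for the signed fixed-point types handled above) before handing them to QNN.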