Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantize Bias for Conv/Gemm on Quantized Model #22889

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
#ifdef MLAS_TARGET_AMD64_IX86
#include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h"
#endif
#include "core/optimizer/qdq_transformer/bias_quantization.h"
#include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
#include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h"
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
Expand Down Expand Up @@ -243,6 +244,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(

if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());
transformers.emplace_back(std::make_unique<BiasQuantization>());

// EnsureUniqueDQForNodeUnit is actually a required graph transformation. The unique DQ per QDQ node unit input
// condition that it ensures is important for the partitioning that happens after Level1 optimizers are run.
Expand Down
127 changes: 127 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/optimizer/qdq_transformer/bias_quantization.h"

#include "core/common/common.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/qdq_transformer/qdq_util.h"

namespace onnxruntime {

Status BiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const GraphViewer graph_viewer{graph};
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : node_indices) {
auto* node_ptr = graph.GetNode(node_idx);
if (!node_ptr) {
continue;
}

Node& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

const auto& input_defs = node.InputDefs();

// It's Conv/Gemm node with an initializer bias.
if ((node.OpType() != "Conv" && node.OpType() != "Gemm") || input_defs.size() < 3 ||
!graph_utils::IsInitializer(graph, input_defs[2]->Name(), true)) {
continue;
}

auto bias_shape = input_defs[2]->Shape();
if (!bias_shape || bias_shape->dim_size() != 1) {
continue;
}
int64_t bias_size = bias_shape->dim(0).dim_value();

// input_0 and input_1 are outputs of DequantizeLinear nodes.
const Node* p_dq_0 = graph.GetProducerNode(input_defs[0]->Name());
const Node* p_dq_1 = graph.GetProducerNode(input_defs[1]->Name());
if (!p_dq_0 || !p_dq_1 || p_dq_0->OpType() != QDQ::DQOpName || p_dq_1->OpType() != QDQ::DQOpName) {
continue;
}

Node& dq_0 = *graph.GetNode(p_dq_0->Index());
Node& dq_1 = *graph.GetNode(p_dq_1->Index());
// Currently we require input_0 is per-tensor scale.
if (!optimizer_utils::IsScalar(*dq_0.InputDefs()[1])) {
continue;
}

// For input_1, it's either per-tensor scale or per-channel scale on specific axis.
bool is_per_tensor_scale = true;
if (!optimizer_utils::IsScalar(*dq_1.InputDefs()[1])) {
is_per_tensor_scale = false;
auto weight_scale_shape = dq_1.InputDefs()[1]->Shape();
if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() ||
weight_scale_shape->dim(0).dim_value() != bias_size) {
continue;
}
const auto& dq_attrs = dq_1.GetAttributes();
if (dq_attrs.find("block_size") != dq_attrs.end()) {
continue;
}
int64_t axis = 1;
if (dq_attrs.find("axis") != dq_attrs.end()) {
axis = dq_attrs.at("axis").i();
}
int64_t expected_axis = 0;
if (node.OpType() == "Gemm") {
int64_t transB = 0;
if (const auto& attr = node.GetAttributes().find("transB"); attr != node.GetAttributes().end()) {
transB = attr->second.i();
}
expected_axis = transB == 0 ? 1 : 0;
}
if (axis != expected_axis) {
continue;
}
}

// Bias is quantized to int32.
ONNX_NAMESPACE::TypeProto int32_type_proto;
int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
auto dq_type = dq_0.InputDefs()[1]->TypeAsProto();
NodeArg& scale_node_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), dq_type);
Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Scale node",
{dq_0.MutableInputDefs()[1], dq_1.MutableInputDefs()[1]}, {&scale_node_arg}, nullptr,
node.Domain());
NodeArg& bias_div_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), dq_type);
Node& div_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node",
{node.MutableInputDefs()[2], &scale_node_arg}, {&bias_div_node_arg}, nullptr, node.Domain());
graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1);
NodeArg& bias_div_round_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), dq_type);
Node& round_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node",
{&bias_div_node_arg}, {&bias_div_round_node_arg}, nullptr, node.Domain());
graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0);
NodeArg& bias_int32_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &int32_type_proto);
Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias int32 node",
{&bias_div_round_node_arg}, {&bias_int32_node_arg}, nullptr, node.Domain());
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0);
NodeArg& bias_dq_node_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), dq_type);
// Bias DQ node has scale = input_scale_0 * input_scale_1, zp = 0.
Node& dp_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node",
{&bias_int32_node_arg, &scale_node_arg}, {&bias_dq_node_arg}, nullptr, node.Domain());
if (!is_per_tensor_scale) {
dp_node.AddAttribute("axis", static_cast<int64_t>(0));
}
graph.AddEdge(cast_node.Index(), dp_node.Index(), 0, 0);
graph.AddEdge(mul_node.Index(), dp_node.Index(), 0, 1);
node.MutableInputDefs()[2] = &bias_dq_node_arg;
graph.AddEdge(dp_node.Index(), node.Index(), 0, 2);
modified = true;
}

return Status::OK();
}

} // namespace onnxruntime
27 changes: 27 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

/**
* @class BiasQuantization
*
* Some quantized models do not have Gemm/Conv's bias quantized. This optimization adds a subgraph to quantize the bias
* with scale = scale_input_0 * scale_input_1 and zero_point = 0.
*
* Normally the ConstantFolding optimizer would fold the bias initializer into an int32_t initializer, which is consumed
* by a DequantizeLinear node.
*/
class BiasQuantization : public GraphTransformer {
public:
BiasQuantization() noexcept : GraphTransformer("BiasQuantization") {}

private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
Loading