From e406522f7039cac779b7241bb847750fb9439c48 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 19 Nov 2024 17:06:40 +0800 Subject: [PATCH] add activation to node unit --- onnxruntime/core/framework/node_unit.cc | 35 +- onnxruntime/core/framework/node_unit.h | 6 +- .../optimizer/qdq_transformer/qdq_util.cc | 137 +++++ .../core/optimizer/qdq_transformer/qdq_util.h | 5 + .../selectors_actions/qdq_selectors.cc | 136 ++++- .../selectors_actions/qdq_selectors.h | 36 +- .../selectors_actions/shared/utils.cc | 3 + .../core/providers/partitioning_utils.cc | 5 + .../qnn_node_group/conv_activation_fusion.cc | 480 ------------------ .../qnn_node_group/conv_activation_fusion.h | 63 --- .../builder/qnn_node_group/qnn_node_group.cc | 3 - 11 files changed, 317 insertions(+), 592 deletions(-) delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index 850cb167a3ece..d97ecb0e74772 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -156,6 +156,7 @@ std::vector GetQDQIODefs(const Node& target_node, const QDQ::Node Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer, const Node& target_node, + const Node* p_activation_node, gsl::span dq_nodes, gsl::span q_nodes) { // Within a QDQ node group, a target node input is the only consumer of each DQ. @@ -176,6 +177,22 @@ Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer, dq_node->Name(), ", target node: ", target_node.Name()); } + // If activation node is present, currently we require target node has only one output edge, which is connected to + // the activation node. The activation node's output is consumed by the Q node that can be fused with itself. + if (p_activation_node) { + ORT_RETURN_IF_NOT(target_node.GetOutputEdgesCount() == 1 && + target_node.OutputEdgesBegin()->GetNode().Index() == p_activation_node->Index(), + "QDQ node group cannot have target node with more than one output edge if there is activation " + "node. target node: ", + target_node.Name()); + ORT_RETURN_IF_NOT(q_nodes.size() == 1 && p_activation_node->GetOutputEdgesCount() == 1 && + p_activation_node->OutputEdgesBegin()->GetNode().Index() == q_nodes[0]->Index(), + "QDQ node group cannot have activation node that doesn't have a single output edge to a Q node. " + "activation node: ", + p_activation_node->Name()); + return Status::OK(); + } + // an output from the target node can have either Q consumers or direct consumers. it cannot have both. // this must be checked on a per output basis. // e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ @@ -228,6 +245,7 @@ Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer, return Status::OK(); } + NodeUnit::NodeUnit(const Node& node) : target_node_(node), type_(Type::SingleNode), @@ -238,11 +256,15 @@ NodeUnit::NodeUnit(const Node& node) NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group) : dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, target_node_(*graph_viewer.GetNode(node_group.target_node)), + p_activation_node_( + node_group.activation_node.has_value() ? graph_viewer.GetNode(node_group.activation_node.value()) : nullptr), q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, type_(Type::QDQGroup), inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, - outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} { - ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_)); + outputs_{ + GetQDQIODefs((p_activation_node_ ? *p_activation_node_ : target_node_), node_group, false /* is_input */)} { + ORT_THROW_IF_ERROR( + QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, p_activation_node_, dq_nodes_, q_nodes_)); input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0), [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); }); @@ -253,8 +275,10 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g // create output edges. each target node output either goes to Q node/s or non-Q node/s. // ValidateNodeGroupQDQNodes ensures this. - auto cur_edge = target_node_.OutputEdgesBegin(); - auto end_edge = target_node_.OutputEdgesEnd(); + // If activation node is present, the target node has only one output edge, which is connected to the activation node. + const Node& output_producer = p_activation_node_ ? *p_activation_node_ : target_node_; + auto cur_edge = output_producer.OutputEdgesBegin(); + auto end_edge = output_producer.OutputEdgesEnd(); for (; cur_edge != end_edge; ++cur_edge) { const Node& node = cur_edge->GetNode(); @@ -273,12 +297,13 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } -NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, +NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, const Node* p_activation_node, gsl::span q_nodes, Type unit_type, gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges) : dq_nodes_(dq_nodes.begin(), dq_nodes.end()), target_node_(target_node), + p_activation_node_(p_activation_node), q_nodes_(q_nodes.begin(), q_nodes.end()), type_(unit_type), inputs_(inputs.begin(), inputs.end()), diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index 50bd423d2f547..1aad282422bb7 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -27,12 +27,14 @@ struct NodeGroup { std::vector dq_nodes; std::vector q_nodes; NodeIndex target_node; + std::optional activation_node; // Validator to check if the set of nodes can form a valid QDQ NodeGroup. // Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to // be converted into a single node with a quantized operator. static Status CanCreateNodeGroup(const GraphViewer& graph_viewer, const Node& target_node, + const Node* p_activation_node, gsl::span dq_nodes, gsl::span q_nodes); }; @@ -68,7 +70,7 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); - NodeUnit(gsl::span dq_nodes, const Node& target_node, + NodeUnit(gsl::span dq_nodes, const Node& target_node, const Node* p_activation_node, gsl::span q_nodes, Type unit_type, gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges); @@ -87,6 +89,7 @@ class NodeUnit { ProviderType GetExecutionProviderType() const noexcept; const Node& GetNode() const noexcept { return target_node_; } + const Node* GetActivationNode() const noexcept { return p_activation_node_; } const std::vector& GetDQNodes() const noexcept { return dq_nodes_; } const std::vector& GetQNodes() const noexcept { return q_nodes_; } std::vector GetAllNodesInGroup() const noexcept; @@ -106,6 +109,7 @@ class NodeUnit { const std::vector dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs const Node& target_node_; + const Node* p_activation_node_; // Optional activation node for the QDQ group, nullptr if not present. const std::vector q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs const Type type_; diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index 7ef4ced1835f0..fe2d7c13406c5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -215,4 +215,141 @@ bool MatchDQNode(const Node& node) { #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +namespace { + +bool GetDataTypeMinMax(int32_t data_type, int32_t& min, int32_t& max) { + switch (data_type) { + case ONNX_NAMESPACE::TensorProto::INT8: + min = static_cast(std::numeric_limits::min()); + max = static_cast(std::numeric_limits::max()); + break; + case ONNX_NAMESPACE::TensorProto::UINT8: + min = static_cast(std::numeric_limits::min()); + max = static_cast(std::numeric_limits::max()); + break; + case ONNX_NAMESPACE::TensorProto::INT16: + min = static_cast(std::numeric_limits::min()); + max = static_cast(std::numeric_limits::max()); + break; + case ONNX_NAMESPACE::TensorProto::UINT16: + min = static_cast(std::numeric_limits::min()); + max = static_cast(std::numeric_limits::max()); + break; + default: + return false; + } + return true; +} +bool GetQSalarScaleZp(const GraphViewer& graph_viewer, const Node& q_node, float& scale, int32_t& zp, + int32_t& data_type) { + assert(q_node.OpType() == QOpName); + const auto& q_input_defs = q_node.InputDefs(); + if (q_input_defs.size() != 3 || !q_input_defs[2]->Exists()) { + return false; + } + + const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = + graph_viewer.GetConstantInitializer(q_input_defs[1]->Name(), true); + if (!scale_tensor_proto) { + return false; + } + + // Support scalar float scale only for now. Need to extend to other float types if needed. + Initializer scale_initializer(*scale_tensor_proto, graph_viewer.ModelPath()); + if (scale_initializer.dims().size() != 0 || scale_initializer.data_type() != ONNX_NAMESPACE::TensorProto::FLOAT) { + return false; + } + scale = *scale_initializer.data(); + + const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = + graph_viewer.GetConstantInitializer(q_input_defs[2]->Name(), true); + if (!zp_tensor_proto) { + return false; + } + + Initializer zp_initializer(*zp_tensor_proto, graph_viewer.ModelPath()); + if (zp_initializer.dims().size() != 0) { + return false; + } + + data_type = zp_initializer.data_type(); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto::INT8: + zp = static_cast(*zp_initializer.data()); + break; + case ONNX_NAMESPACE::TensorProto::UINT8: + zp = static_cast(*zp_initializer.data()); + break; + case ONNX_NAMESPACE::TensorProto::INT16: + zp = static_cast(*zp_initializer.data()); + break; + case ONNX_NAMESPACE::TensorProto::UINT16: + zp = static_cast(*zp_initializer.data()); + break; + default: + return false; + } + + return true; +} + +bool CanRemoveRelu(const GraphViewer& graph_viewer, const Node& q_node) { + float scale = 0.0f; + int32_t zp = 0; + int32_t data_type = 0; + if (!GetQSalarScaleZp(graph_viewer, q_node, scale, zp, data_type)) { + return false; + } + + int32_t data_type_min = 0; + int32_t data_type_max = 0; + if (!GetDataTypeMinMax(data_type, data_type_min, data_type_max)) { + return false; + } + + // Relu can be removed if the zero-point is set to the smallest quantized value. + return zp == data_type_min; +} + +bool CanRemoveClip(const GraphViewer& graph_viewer, const Node& clip_node, const Node& q_node) { + float scale = 0.0f; + int32_t zp = 0; + int32_t data_type = 0; + if (!GetQSalarScaleZp(graph_viewer, q_node, scale, zp, data_type)) { + return false; + } + + float min = 0.0f; + float max = 0.0f; + if (!optimizer_utils::GetClipConstantMinMax(graph_viewer.GetGraph(), clip_node, min, max)) { + return false; + } + + int32_t q_clip_min = static_cast(::rint(min / scale)) + zp; + int32_t q_clip_max = static_cast(::rint(max / scale)) + zp; + + int32_t data_type_min = 0; + int32_t data_type_max = 0; + if (!GetDataTypeMinMax(data_type, data_type_min, data_type_max)) { + return false; + } + + // The Clip can be removed if its range entirely overlaps the quantization range. + // QClip range: [------------------] + // Quant range: [-------------] + return q_clip_min <= data_type_min && q_clip_max >= data_type_max; +} + +} // namespace + +bool CanFuseActivationQ(const GraphViewer& graph_viewer, const Node& activation_node, const Node& q_node) { + const std::string& activation_op_type = activation_node.OpType(); + if (activation_op_type == "Relu") { + return CanRemoveRelu(graph_viewer, q_node); + } else if (activation_op_type == "Clip") { + return CanRemoveClip(graph_viewer, activation_node, q_node); + } + return false; +} + } // namespace onnxruntime::QDQ diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index 008f9972a143b..caf207b813c02 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -15,6 +15,7 @@ namespace onnxruntime { class Node; class Path; +class GraphViewer; namespace QDQ { @@ -76,5 +77,9 @@ bool MatchQNode(const Node& node); // Check DQ node op type, version, and domain. bool MatchDQNode(const Node& node); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +// Check if an activation node can be fused with a Q node. +bool CanFuseActivationQ(const GraphViewer& graph_viewer, const Node& activation_node, const Node& q_node); + } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 203aba2c3dd91..ad7f58fc07b4b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -49,7 +49,7 @@ std::vector FindQDQNodes(const GraphViewer& graph_viewer, const Nod } } // namespace -bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node, +bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes, int num_dq_inputs, @@ -63,7 +63,8 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod return false; } - if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + if (const auto qdq_validation_status = + NodeGroup::CanCreateNodeGroup(graph_viewer, node, p_activation_node, dq_nodes, q_nodes); !qdq_validation_status.IsOK()) { return false; } @@ -80,8 +81,21 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod std::optional NodeGroupSelector::GetQDQSelection(const GraphViewer& graph_viewer, const Node& node) const { std::vector dq_nodes = FindQDQNodes(graph_viewer, node, true); - std::vector q_nodes = FindQDQNodes(graph_viewer, node, false); - if (!Check(graph_viewer, node, dq_nodes, q_nodes)) { + const Node* p_activation_node = nullptr; + if (node.GetOutputEdgesCount() == 1) { + const Node& activation_node = *node.OutputNodesBegin(); + if (activation_node.OpType() == "Relu" || activation_node.OpType() == "Clip") { + p_activation_node = &activation_node; + } + } + std::vector q_nodes = FindQDQNodes(graph_viewer, (p_activation_node ? *p_activation_node : node), false); + + if (p_activation_node && + (q_nodes.size() != 1 || !CanFuseActivationQ(graph_viewer, *p_activation_node, *q_nodes[0]))) { + return std::nullopt; + } + + if (!Check(graph_viewer, node, p_activation_node, dq_nodes, q_nodes)) { return std::nullopt; } @@ -89,6 +103,9 @@ std::optional NodeGroupSelector::GetQDQSelection(const GraphViewer& g node_group.dq_nodes.reserve(dq_nodes.size()); node_group.q_nodes.reserve(q_nodes.size()); node_group.target_node = node.Index(); + if (p_activation_node) { + node_group.activation_node = p_activation_node->Index(); + } auto get_node_idx = [&](const Node* n) { return n->Index(); }; std::transform(dq_nodes.begin(), dq_nodes.end(), std::back_inserter(node_group.dq_nodes), get_node_idx); std::transform(q_nodes.begin(), q_nodes.end(), std::back_inserter(node_group.q_nodes), get_node_idx); @@ -122,9 +139,14 @@ std::optional BaseSelector::Select(const GraphViewer& gr bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) { + if (p_activation_node) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes, 1)) { return false; } @@ -162,14 +184,19 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { + if (p_activation_node) { + return false; + } + constexpr int num_dq_inputs = 1; if (num_dq_inputs != gsl::narrow_cast(dq_nodes.size())) { return false; } - if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + if (const auto qdq_validation_status = NodeGroup::CanCreateNodeGroup(graph_viewer, node, nullptr, dq_nodes, q_nodes); !qdq_validation_status.IsOK()) { return false; } @@ -194,10 +221,14 @@ bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, return IsDQSupported(dq_node, get_const_initializer); } -bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, +bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) { + if (p_activation_node) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes, 1)) { return false; } @@ -222,9 +253,15 @@ bool UnaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) { + // Support Add+Activation only for now. Can add more binary ops if needed. + if (p_activation_node && node.OpType() != "Add") { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, p_activation_node, dq_nodes, q_nodes)) { return false; } @@ -251,9 +288,14 @@ bool BinaryNodeGroupSelector::Check(const GraphViewer& graph_viewer, bool VariadicNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) { + if (p_activation_node) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes)) { return false; } @@ -294,9 +336,14 @@ void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) { + if (p_activation_node) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes, 1)) { return false; } @@ -334,9 +381,10 @@ void SplitSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const bool ConvNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) { + if (!CheckQDQNodes(graph_viewer, node, p_activation_node, dq_nodes, q_nodes)) { return false; } @@ -379,8 +427,13 @@ void ConvSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { + if (p_activation_node) { + return false; + } + if (dq_nodes.size() != 2) { return false; } @@ -409,7 +462,7 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, if (qlinear) { // QLinearMatMul - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) { + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes)) { return false; } @@ -423,8 +476,13 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { + if (p_activation_node) { + return false; + } + // Should not have any Q nodes if (!q_nodes.empty()) { return false; @@ -508,9 +566,14 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, + if (p_activation_node) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes, -1 /*num_dq_inputs*/, true /*is_empty_q_nodes_allowed*/)) { return false; } @@ -557,11 +620,15 @@ void GemmSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex); } -bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, +bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { + if (p_activation_node) { + return false; + } + // Where has 1 boolean input and 2 dq inputs - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 2)) { + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes, 2)) { return false; } @@ -586,9 +653,13 @@ bool WhereNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& return true; } -bool PadNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, +bool PadNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { + if (p_activation_node) { + return false; + } + // Pad can have 1 or 2 dq input, the optional input constant_value can be quantized or non-quantized. // QNN supports data input quantized with constant_value input non-quantized. int num_dq_inputs = static_cast(dq_nodes.size()); @@ -596,7 +667,7 @@ bool PadNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& no return false; } - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, num_dq_inputs)) { + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes, num_dq_inputs)) { return false; } @@ -613,9 +684,14 @@ bool PadNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& no bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) { + if (p_activation_node) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes)) { return false; } @@ -637,9 +713,14 @@ bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& gr bool BatchNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 3)) { + if (p_activation_node) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes, 3)) { return false; } @@ -661,9 +742,14 @@ bool BatchNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer, bool LogicalComparisonNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, -1, true)) { + if (p_activation_node) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, nullptr, dq_nodes, q_nodes, -1, true)) { return false; } @@ -674,15 +760,21 @@ bool LogicalComparisonNodeGroupSelector::Check(const GraphViewer& graph_viewer, bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, + const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { + if (p_activation_node) { + return false; + } + constexpr int num_dq_inputs = 1; constexpr int num_q_outputs = 1; if (num_dq_inputs != gsl::narrow_cast(dq_nodes.size())) { return false; } - if (const auto qdq_validation_status = QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, dq_nodes, q_nodes); + if (const auto qdq_validation_status = + QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, node, nullptr, dq_nodes, q_nodes); !qdq_validation_status.IsOK()) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 0ba5436e69e81..bb371b8cc69a8 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -26,7 +26,7 @@ class NodeGroupSelector { protected: // base check that we have the expected number of QDQ inputs/outputs, and `node` isn't producing a graph output. // num_dq_inputs defaults to the number of inputs `node` has if not explicitly specified - bool CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node, + bool CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes, int num_dq_inputs = -1, @@ -34,7 +34,7 @@ class NodeGroupSelector { private: // derived classes should implement this check - bool virtual Check(const GraphViewer& graph_viewer, const Node& node, + bool virtual Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const = 0; }; @@ -53,7 +53,7 @@ class DropQDQNodeGroupSelector : public NodeGroupSelector { : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit), allow_nonpositive_scale_(allow_nonpositive_scale) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -69,7 +69,7 @@ class DropDQNodeGroupSelector : public NodeGroupSelector { : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -84,7 +84,7 @@ class UnaryNodeGroupSelector : public NodeGroupSelector { : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -99,7 +99,7 @@ class BinaryNodeGroupSelector : public NodeGroupSelector { : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -114,7 +114,7 @@ class VariadicNodeGroupSelector : public NodeGroupSelector { : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -131,7 +131,7 @@ class SplitNodeGroupSelector : public NodeGroupSelector { : req_equal_quant_params_(req_equal_quant_params), allow_4bit_(allow_4bit) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -149,7 +149,7 @@ class ConvNodeGroupSelector : public NodeGroupSelector { : int8_allowed_(int8_allowed), allow_16bit_(allow_16bit), allow_4bit_weight_(allow_4bit_weight) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -164,7 +164,7 @@ class WhereNodeGroupSelector : public NodeGroupSelector { : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -177,7 +177,7 @@ class PadNodeGroupSelector : public NodeGroupSelector { PadNodeGroupSelector() = default; private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; }; @@ -197,7 +197,7 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { } private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; bool int8_allowed_; @@ -209,7 +209,7 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { // Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" class DQMatMulNodeGroupSelector : public NodeGroupSelector { private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; }; @@ -222,7 +222,7 @@ class GemmNodeGroupSelector : public NodeGroupSelector { : allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -234,7 +234,7 @@ class GemmNodeGroupSelector : public NodeGroupSelector { // Output: Q node for output class InstanceAndLayerNormalizationNodeGroupSelector : public NodeGroupSelector { private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; }; @@ -246,7 +246,7 @@ class BatchNormalizationNodeGroupSelector : public NodeGroupSelector { BatchNormalizationNodeGroupSelector(bool int8_allowed = true) : int8_allowed_(int8_allowed) {} private: - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; @@ -256,7 +256,7 @@ class BatchNormalizationNodeGroupSelector : public NodeGroupSelector { // 2 DQ nodes providing input -> node with bool output tensor. // Example: Equal, Less, Greater. class LogicalComparisonNodeGroupSelector : public NodeGroupSelector { - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; }; @@ -264,7 +264,7 @@ class LogicalComparisonNodeGroupSelector : public NodeGroupSelector { // TopK has 1 DQ input node and 1 Q output node. // Zero point and scale are constant scalars and must match class TopKNodeGroupSelector : public NodeGroupSelector { - bool Check(const GraphViewer& graph_viewer, const Node& node, + bool Check(const GraphViewer& graph_viewer, const Node& node, const Node* p_activation_node, const std::vector& dq_nodes, const std::vector& q_nodes) const override; }; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index d2240b5d50194..5fd95d95a891c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -351,6 +351,9 @@ GetAllNodeUnits(const GraphViewer& graph_viewer) { add_node_unit_to_map(qdq_selection.dq_nodes, qdq_unit.get()); add_node_unit_to_map(qdq_selection.q_nodes, qdq_unit.get()); add_node_unit_to_map({qdq_selection.target_node}, qdq_unit.get()); + if (qdq_selection.activation_node.has_value()) { + add_node_unit_to_map({qdq_selection.activation_node.value()}, qdq_unit.get()); + } node_unit_holder.push_back(std::move(qdq_unit)); } diff --git a/onnxruntime/core/providers/partitioning_utils.cc b/onnxruntime/core/providers/partitioning_utils.cc index 83c08f3dbd25e..20f51558e7a1e 100644 --- a/onnxruntime/core/providers/partitioning_utils.cc +++ b/onnxruntime/core/providers/partitioning_utils.cc @@ -198,6 +198,11 @@ std::vector> CreateSupportedPartitionNodeGroups( } supported_group.push_back(&node); + const Node* p_activation_node = node_unit->GetActivationNode(); + if (p_activation_node) { + supported_group.push_back(p_activation_node); + supported_group_border.erase(p_activation_node); + } for (const auto& q : node_unit->GetQNodes()) { supported_group.push_back(q); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc deleted file mode 100644 index 813bba8a5952b..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ /dev/null @@ -1,480 +0,0 @@ -#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" - -#include -#include -#include -#include -#include -#include -#include "core/graph/graph_utils.h" -#include "core/framework/node_unit.h" -#include "core/providers/shared/utils/utils.h" -#include "core/providers/qnn/builder/qnn_utils.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_node_group/utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" - -namespace onnxruntime { -namespace qnn { - -// Gets the scale, zero-point, and zero-point type for a QuantizeLinear node that uses per-tensor quantization. -static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& q_node_unit, - /*out*/ float& scale, - /*out*/ int32_t& zero_point, - /*out*/ int32_t& zp_data_type) { - assert(q_node_unit.OpType() == QUANTIZE_LINEAR); - const auto& q_inputs = q_node_unit.GetNode().InputDefs(); - - // Require an explicit zero-point input for now. - if (q_inputs.size() != 3 || !q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Exists()) { - return false; - } - - std::vector zero_points; - Status status = qnn_model_wrapper.UnpackZeroPoints(q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Name(), - zero_points, zp_data_type); - - // Should only have one zero-point (per-tensor). - if (!status.IsOK() || zero_points.size() != 1) { - return false; - } - zero_point = -zero_points[0]; // QNN zero-points are negated. - - std::vector scales; - status = qnn_model_wrapper.UnpackScales(q_inputs[QDQ_SCALE_INPUT_IDX]->Name(), scales); - - // Should only have one scale (per-tensor). - if (!status.IsOK() || scales.size() != 1) { - return false; - } - - scale = scales[0]; - return true; -} - -// Computes the floating point range (rmin, rmax) from a QuantizeLinear node's scale/zero-point. -static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& q_node_unit, - /*out*/ float& rmin, - /*out*/ float& rmax) { - int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; - int32_t zero_point = 0; - float scale = 0.0f; - - if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) { - return false; - } - - switch (zp_data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_INT8: { - rmin = scale * (std::numeric_limits::lowest() - zero_point); - rmax = scale * (std::numeric_limits::max() - zero_point); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { - rmin = scale * (std::numeric_limits::lowest() - zero_point); - rmax = scale * (std::numeric_limits::max() - zero_point); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT16: { - rmin = scale * (std::numeric_limits::lowest() - zero_point); - rmax = scale * (std::numeric_limits::max() - zero_point); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { - rmin = scale * (std::numeric_limits::lowest() - zero_point); - rmax = scale * (std::numeric_limits::max() - zero_point); - break; - } - default: - return false; - } - - return true; -} - -// Returns true if the Clip in the sequence (Clip -> Q) can be removed because it is made redundant by the Q. -static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& clip_node_unit, - const NodeUnit& q_node_unit, - const logging::Logger& logger) { - assert(clip_node_unit.OpType() == "Clip" && q_node_unit.OpType() == QUANTIZE_LINEAR); - float rmin = 0.0f; - float rmax = 0.0f; - - if (!GetQRminRmax(qnn_model_wrapper, q_node_unit, rmin, rmax)) { - return false; - } - - float clip_min = std::numeric_limits::lowest(); - float clip_max = std::numeric_limits::max(); - - if (!onnxruntime::GetClipMinMax(qnn_model_wrapper.GetGraphViewer(), clip_node_unit.GetNode(), - clip_min, clip_max, logger)) { - return false; - } - - // The clip range must entirely overlap the quantization range (quantization can be smaller). - // Clip range: [------------------] - // Quant range: [-------------] - constexpr float epsilon = std::numeric_limits::epsilon(); - if ((epsilon < clip_min - rmin) || (epsilon < rmax - clip_max)) { - return false; - } - - return true; -} - -// Returns true if the Relu in the sequence (Relu -> Q) can be removed because it is made redundant by the Q. -static bool CanQRelaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit) { - assert(q_node_unit.OpType() == QUANTIZE_LINEAR); - int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; - int32_t zero_point = 0; - float scale = 0.0f; - - if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) { - return false; - } - - // Relu is redundant if the zero-point is set to the smallest quantized value. - switch (zp_data_type) { - case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT8: - return zero_point == static_cast(std::numeric_limits::lowest()); - case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT8: - return zero_point == static_cast(std::numeric_limits::lowest()); - case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT16: - return zero_point == static_cast(std::numeric_limits::lowest()); - case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT16: - return zero_point == static_cast(std::numeric_limits::lowest()); - default: - return false; - } -} - -// Returns true if the Clip/Relu in the sequence (Clip/Relu -> Q) can be removed because it is made redundant by the Q. -static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& activation_node_unit, - const NodeUnit& q_node_unit, - const logging::Logger& logger) { - const std::string& activation_type = activation_node_unit.OpType(); - - if (activation_type == "Relu") { - return CanQRelaceRelu(qnn_model_wrapper, q_node_unit); - } - - if (activation_type == "Clip") { - return CanClipBeRemoved(qnn_model_wrapper, activation_node_unit, q_node_unit, logger); - } - - return false; -} - -// Returns the parent DQ nodes for a given node. -static std::vector FindParentDQNodes(const GraphViewer& graph_viewer, const Node& node) { - // Get all parent DQ nodes sorted by destination argument index. - std::vector parents(node.InputDefs().size(), nullptr); - for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); it++) { - if (it->GetNode().OpType().compare(DEQUANTIZE_LINEAR) == 0) { - parents[it->GetDstArgIndex()] = &(it->GetNode()); - } - } - - // Remove all the nodes which are not in the graph_viewer - parents.erase(std::remove_if(parents.begin(), parents.end(), - [&graph_viewer](const Node* _node) { - return _node == nullptr || graph_viewer.GetNode(_node->Index()) == nullptr; - }), - parents.end()); - - return parents; -} - -// Gets the parent DQ nodes for the given Conv node. This fuction checks that the DQs are not a part of -// any other NodeUnit and that every Conv input comes from a parent DQ. -static bool GetConvDQs( - const GraphViewer& graph_viewer, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const Node& conv_node, - /*out*/ std::array& dq_node_units) { - if (conv_node.OpType() != "Conv" && conv_node.OpType() != "ConvTranspose") { - return false; - } - - // Count number of inputs to Conv node. - const auto& conv_inputs = conv_node.InputDefs(); - const size_t num_conv_inputs = std::count_if(conv_inputs.cbegin(), conv_inputs.cend(), - [](const NodeArg* input) { return input && input->Exists(); }); - - // Get the Conv's parent DQ nodes. - std::vector dq_nodes = FindParentDQNodes(graph_viewer, conv_node); - const size_t num_dqs = dq_nodes.size(); - - // Within a QDQ node group, a target node input is the only consumer of each DQ. - if ((num_conv_inputs != num_dqs) || (num_dqs > dq_node_units.size())) { - return false; - } - - dq_node_units.fill(nullptr); - for (size_t i = 0; i < num_dqs; i++) { - const Node* dq_node = dq_nodes[i]; - - // DQ must not produce a graph output. - if (!dq_node || graph_viewer.NodeProducesGraphOutput(*dq_node)) { - return false; - } - - // Conv should be the only consumer of a parent DQ. - const bool dq_has_single_output_edge_to_target = - dq_node->GetOutputEdgesCount() == 1 && - dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index(); - if (!dq_has_single_output_edge_to_target) { - return false; - } - - // DQ node must be part of a "standalone" NodeUnit. - const auto it = node_to_node_unit.find(dq_node); - if (it == node_to_node_unit.end()) { - return false; - } - const NodeUnit* dq_node_unit = it->second; - if (!dq_node_unit || node_unit_to_qnn_node_group.count(dq_node_unit) != 0) { - return false; - } - if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return false; - } - - dq_node_units[i] = dq_node_unit; - } - - return true; -} - -// Checks that the input and output data types are valid for a QDQ Conv. -static bool CheckQDQConvDataTypes(std::array& dq_node_units, - gsl::not_null q_node_unit) { - assert(q_node_unit->OpType() == QUANTIZE_LINEAR); - // input and output types need to be same - int32_t dt_input = dq_node_units[0]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - int32_t dt_weight = dq_node_units[1]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - int32_t dt_output = q_node_unit->GetNode().OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - if (dt_input != dt_output) { - return false; - } - - if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { - if (dt_weight != dt_input) { - return false; - } - } - - if (dq_node_units[2] != nullptr) { // has bias - int32_t dt_bias = dq_node_units[2]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { - return false; - } - } - - return true; -} - -// Utility function to either validate or create a quantized QNN Conv node. The function creates a temporary -// custom NodeUnit that excludes the Clip/Relu because it is redundant. This custom NodeUnit is passed to our -// existing Conv OpBuilder for creation or validation via QNN APIs. -#define ValidateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ - CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), true) -#define CreateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ - CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), false) -static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, - gsl::span dq_node_units, - const NodeUnit* conv_node_unit, - const NodeUnit* q_node_unit, - const logging::Logger& logger, - bool validate) { - const size_t num_dqs = dq_node_units.size(); - constexpr size_t max_num_dqs = 3; - ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); - ORT_RETURN_IF_NOT(conv_node_unit->OpType() == "Conv" && q_node_unit->OpType() == QUANTIZE_LINEAR, - "Expected Conv/ConvTranspose and QuantizeLinear but got ", conv_node_unit->OpType(), " and ", - q_node_unit->OpType()); - - std::array dq_nodes_buf = {}; - for (size_t i = 0; i < num_dqs; i++) { - dq_nodes_buf[i] = &dq_node_units[i]->GetNode(); - } - gsl::span dq_nodes(dq_nodes_buf.data(), num_dqs); - - std::array q_nodes = {&q_node_unit->GetNode()}; - const Node& target_node = conv_node_unit->GetNode(); - - // Populate NodeUnit inputs - std::vector inputs; - inputs.reserve(num_dqs); - for (const Node* dq_node : dq_nodes) { - const auto dq_inputs = dq_node->InputDefs(); - const auto& dq_attrs = dq_node->GetAttributes(); - - std::optional axis; - if (auto entry = dq_attrs.find("axis"); entry != dq_attrs.end()) { - axis = entry->second.i(); - } - - // quantization scale and zp are always the input[1, 2] - NodeUnitIODef::QuantParam quant_param{*dq_inputs[1], dq_inputs.size() == 3 ? dq_inputs[2] : nullptr, axis}; - inputs.push_back(NodeUnitIODef{*dq_inputs[0], quant_param}); - } - - // Populate NodeUnit outputs and output edges - std::vector outputs; - Node::EdgeSet output_edges; - for (const Node* q_node : q_nodes) { - const auto q_inputs = q_node->InputDefs(); - const auto& q_attrs = q_node->GetAttributes(); - const auto q_outputs = q_node->OutputDefs(); - - std::optional axis; - if (auto entry = q_attrs.find("axis"); entry != q_attrs.end()) { - axis = entry->second.i(); - } - - // quantization scale and zp are always the input[1, 2] - NodeUnitIODef::QuantParam quant_param{*q_inputs[1], q_inputs.size() == 3 ? q_inputs[2] : nullptr, axis}; - outputs.push_back(NodeUnitIODef{*q_outputs[0], quant_param}); - - // Gather output edges out of the Q node. - auto q_cur_edge = q_node->OutputEdgesBegin(); - auto q_end_edge = q_node->OutputEdgesEnd(); - for (; q_cur_edge != q_end_edge; ++q_cur_edge) { - output_edges.insert(Node::EdgeEnd{q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()}); - } - } - - NodeUnit custom_node_unit(dq_nodes, target_node, q_nodes, NodeUnit::Type::QDQGroup, - inputs, outputs, num_dqs, output_edges); - const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit.OpType()); - if (conv_op_builder == nullptr) { - return Status::OK(); - } - - if (validate) { - return conv_op_builder->IsOpSupported(qnn_model_wrapper, custom_node_unit, logger); - } - - return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); -} - -// Traverses graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence. -// If so, returns a IQnnNodeGroup that contains the constituent NodeUnits. -std::unique_ptr ConvActivationFusion::TryFusion( - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& conv_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { - // Expect that this function is called with a standalone Conv or ConvTranspose. - const auto& conv_type = conv_node_unit.OpType(); - - if ((conv_type != "Conv" && conv_type != "ConvTranspose") || - (conv_node_unit.UnitType() != NodeUnit::Type::SingleNode)) { - return nullptr; - } - - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - - // Conv must have a single Relu or Clip child. - const std::array activation_op_types = {"Relu", "Clip"}; - const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types, - node_to_node_unit, node_unit_to_qnn_node_group); - if (activation_node_unit == nullptr) { - return nullptr; - } - - // Relu/Clip must have a single Q child. - const std::array q_op_types = {QUANTIZE_LINEAR}; - const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, *activation_node_unit, q_op_types, - node_to_node_unit, node_unit_to_qnn_node_group); - - if (q_node_unit == nullptr) { - return nullptr; - } - - // Check if Clip/Relu can be removed because the Q node provides an equivalent effect. - if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit, logger)) { - return nullptr; - } - - // Create a QDQ node group with DQ* -> Conv -> Q - const Node& conv_node = conv_node_unit.GetNode(); - std::array dq_node_units = {}; - if (!GetConvDQs(graph_viewer, - node_to_node_unit, - node_unit_to_qnn_node_group, - conv_node, dq_node_units)) { - return nullptr; - } - - if (!CheckQDQConvDataTypes(dq_node_units, q_node_unit)) { - return nullptr; - } - - return std::make_unique(*dq_node_units[0], - *dq_node_units[1], - dq_node_units[2], - conv_node_unit, - *activation_node_unit, - *q_node_unit); -} - -ConvActivationFusion::ConvActivationFusion(const NodeUnit& dq_node_unit_0, - const NodeUnit& dq_node_unit_1, - const NodeUnit* dq_node_unit_2, - const NodeUnit& conv_node_unit, - const NodeUnit& activation_node_unit, - const NodeUnit& q_node_unit) - : node_units_{} { - size_t i = 0; - node_units_[i++] = &dq_node_unit_0; - node_units_[i++] = &dq_node_unit_1; - if (dq_node_unit_2 != nullptr) { - node_units_[i++] = dq_node_unit_2; - } - node_units_[i++] = &conv_node_unit; - node_units_[i++] = &activation_node_unit; - node_units_[i++] = &q_node_unit; - assert((!dq_node_unit_2 && i == 5) || (dq_node_unit_2 && i == 6)); -} - -Status ConvActivationFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; - gsl::span dq_node_units(node_units_.data(), num_dqs); - - return ValidateOnQnn(qmw, dq_node_units, - node_units_[num_dqs], // Conv - node_units_[num_dqs + 2], // Q - logger); -} - -Status ConvActivationFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; - gsl::span dq_node_units(node_units_.data(), num_dqs); - - return CreateOnQnn(qmw, dq_node_units, - node_units_[num_dqs], // Conv - node_units_[num_dqs + 2], // Q - logger); -} - -gsl::span ConvActivationFusion::GetNodeUnits() const { - const size_t num_node_units = node_units_.back() != nullptr ? 6 : 5; - return gsl::make_span(node_units_.data(), num_node_units); -} - -const NodeUnit* ConvActivationFusion::GetTargetNodeUnit() const { - const size_t conv_index = node_units_.back() != nullptr ? 3 : 2; - return node_units_[conv_index]; -} - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h deleted file mode 100644 index b604b25e943e6..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include -#include - -#include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_node_group.h" - -namespace onnxruntime { -namespace qnn { - -class QnnModelWrapper; - -/// -/// Represents a fusion of a DQ* -> Conv -> Relu/Clip -> Q sequence where the Relu (or Clip) is redundant -/// due to the quantization effects of the Q. This sequence is translated to a quantized QNN Conv. -/// All contained NodeUnits are of type SingleNode since they are not a part of an existing QDQ node unit. -/// -class ConvActivationFusion : public IQnnNodeGroup { - public: - ConvActivationFusion(const NodeUnit& dq_node_unit_0, - const NodeUnit& dq_node_unit_1, - const NodeUnit* dq_node_unit_2, - const NodeUnit& conv_node_unit, - const NodeUnit& activation_node_unit, - const NodeUnit& q_node_unit); - ORT_DISALLOW_COPY_AND_ASSIGNMENT(ConvActivationFusion); - - Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; - Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; - gsl::span GetNodeUnits() const override; - const NodeUnit* GetTargetNodeUnit() const override; - std::string_view Type() const override { return "ConvActivationFusion"; } - - /// - /// Traverses graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence. - /// If so, returns a IQnnNodeGroup that contains the constituent NodeUnits. - /// - /// Used for validation and to traverse/query the graph - /// Conv node unit (type SingleNode) that be part of the sequence. - /// Maps a Node to a NodeUnit. - /// Maps a NodeUnit to a IQnnNodeGroup. - /// - /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise - static std::unique_ptr TryFusion( - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& conv_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger); - - private: - std::array node_units_; // Last elem is nullptr if the optional bias DQ is missing. -}; - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 9fb9e815321c0..c398d1fae5097 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -15,7 +15,6 @@ #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" @@ -90,8 +89,6 @@ static std::unique_ptr TryQnnFusions( static std::unordered_map fusions = { {"DequantizeLinear", DQQFusion::TryFusion}, {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, - {"Conv", ConvActivationFusion::TryFusion}, - {"ConvTranspose", ConvActivationFusion::TryFusion}, }; // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes).