@@ -19,15 +19,17 @@ namespace onnxruntime {
19
19
/// <param name="q_node">QuantizeLinear node</param>
20
20
/// <param name="zp_data_type">Output parameter to store the zero-point data type</param>
21
21
/// <returns>True if successfully extracted the zero-point data type</returns>
22
- static bool GetQNodeZeroPointType (const Graph& graph, const Node& q_node, /* out*/ int64_t & zp_data_type) {
22
+ static bool GetQNodeZeroPointType (const Graph& graph, const Node& q_node,
23
+ /* out*/ ONNX_NAMESPACE::TensorProto_DataType& zp_data_type) {
23
24
assert (q_node.OpType () == " QuantizeLinear" );
24
25
const auto input_defs = q_node.InputDefs ();
25
26
26
27
if (QDQ::InputIndex::ZERO_POINT_ID >= input_defs.size () || !input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Exists ()) {
27
28
// If a zero_point input is absent, get the type from the "output_dtype" attribute or default to uint8.
28
29
// The "output_dtype" attribute was added in ONNX opset 21.
29
30
const auto * attr = graph_utils::GetNodeAttribute (q_node, " output_dtype" );
30
- zp_data_type = attr != nullptr ? attr->i () : static_cast <int64_t >(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
31
+ zp_data_type = attr != nullptr ? static_cast <ONNX_NAMESPACE::TensorProto_DataType>(attr->i ())
32
+ : ONNX_NAMESPACE::TensorProto_DataType_UINT8;
31
33
return true ;
32
34
}
33
35
@@ -36,7 +38,7 @@ static bool GetQNodeZeroPointType(const Graph& graph, const Node& q_node, /*out*
36
38
return false ;
37
39
}
38
40
39
- zp_data_type = zp_proto->data_type ();
41
+ zp_data_type = static_cast <ONNX_NAMESPACE::TensorProto_DataType>( zp_proto->data_type () );
40
42
return true ;
41
43
}
42
44
@@ -119,7 +121,7 @@ static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, cons
119
121
// for correctness.
120
122
template <typename ZeroPointType>
121
123
static bool RecomputeOuterQDQZeroPointAndScale (Graph& graph, Node& q1, const Node& dq1, const Node& q2,
122
- gsl::span<Node*> dq2s) {
124
+ gsl::span<gsl::not_null< Node*> > dq2s) {
123
125
if (dq2s.empty ()) {
124
126
return false ;
125
127
}
@@ -137,7 +139,7 @@ static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Nod
137
139
ApplyNewInputValue (graph, q1, QDQ::InputIndex::SCALE_ID, new_scale);
138
140
ApplyNewInputValue (graph, q1, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
139
141
140
- for (auto * dq2 : dq2s) {
142
+ for (gsl::not_null<Node*> dq2 : dq2s) {
141
143
ApplyNewInputValue (graph, *dq2, QDQ::InputIndex::SCALE_ID, new_scale);
142
144
ApplyNewInputValue (graph, *dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
143
145
}
@@ -183,27 +185,27 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) {
183
185
return false ;
184
186
}
185
187
186
- int64_t quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
187
- if (!GetQNodeZeroPointType (graph, *q1, quant_type )) {
188
+ auto q1_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
189
+ if (!GetQNodeZeroPointType (graph, *q1, q1_quant_type )) {
188
190
return false ;
189
191
}
190
192
191
193
// Ensure that q2 is a Q operator, its output is not a graph output, and that its zero-point quantization type
192
194
// is equal to q1's.
193
195
NodeIndex q2_index = dq1->OutputEdgesBegin ()->GetNode ().Index ();
194
196
const Node* q2 = graph.GetNode (q2_index);
195
- int64_t quant_type_2 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
197
+ auto q2_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
196
198
197
199
if (q2 == nullptr ||
198
200
q2->OpType () != " QuantizeLinear" ||
199
201
graph.NodeProducesGraphOutput (*q2) ||
200
- !GetQNodeZeroPointType (graph, *q2, quant_type_2 ) ||
201
- quant_type != quant_type_2 ) {
202
+ !GetQNodeZeroPointType (graph, *q2, q2_quant_type ) ||
203
+ q1_quant_type != q2_quant_type ) {
202
204
return false ;
203
205
}
204
206
205
207
// All of q2's children should be DQ nodes with zero-point and scale values equal to those of q2.
206
- InlinedVector<Node*> dq2_nodes;
208
+ InlinedVector<gsl::not_null< Node*> > dq2_nodes;
207
209
dq2_nodes.reserve (q2->GetOutputEdgesCount ());
208
210
209
211
for (auto it = q2->OutputEdgesBegin (); it != q2->OutputEdgesEnd (); it++) {
@@ -224,13 +226,13 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) {
224
226
}
225
227
226
228
bool can_recompute = false ;
227
- if (quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
229
+ if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
228
230
can_recompute = RecomputeOuterQDQZeroPointAndScale<uint8_t >(graph, *q1, *dq1, *q2, dq2_nodes);
229
- } else if (quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
231
+ } else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
230
232
can_recompute = RecomputeOuterQDQZeroPointAndScale<int8_t >(graph, *q1, *dq1, *q2, dq2_nodes);
231
- } else if (quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
233
+ } else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
232
234
can_recompute = RecomputeOuterQDQZeroPointAndScale<uint16_t >(graph, *q1, *dq1, *q2, dq2_nodes);
233
- } else if (quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
235
+ } else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
234
236
can_recompute = RecomputeOuterQDQZeroPointAndScale<int16_t >(graph, *q1, *dq1, *q2, dq2_nodes);
235
237
}
236
238
@@ -243,7 +245,7 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) {
243
245
244
246
// Disconnect Q2 --> DQ2(s)
245
247
// Connect Q1 -> DQ2(s)
246
- for (auto * dq2 : dq2_nodes) {
248
+ for (gsl::not_null<Node*> dq2 : dq2_nodes) {
247
249
graph.RemoveEdge (q2_index, dq2->Index (), 0 , 0 );
248
250
graph.AddEdge (q1_index, dq2->Index (), 0 , 0 );
249
251
}
0 commit comments