Skip to content

Commit 8aae19e

Browse files
Review comments: clean up names, use correct param type, use gsl::as_bytes, etc.
1 parent 2afc958 commit 8aae19e

File tree

2 files changed

+23 -21 lines changed

2 files changed

+23 -21 lines changed

onnxruntime/core/optimizer/double_qdq_pairs_remover.cc

+18 -16 lines changed
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@ namespace onnxruntime {
1919
/// <param name="q_node">QuantizeLinear node</param>
2020
/// <param name="zp_data_type">Output parameter to store the zero-point data type</param>
2121
/// <returns>True if successfully extracted the zero-point data type</returns>
22-
static bool GetQNodeZeroPointType(const Graph& graph, const Node& q_node, /*out*/ int64_t& zp_data_type) {
22+
static bool GetQNodeZeroPointType(const Graph& graph, const Node& q_node,
23+
/*out*/ ONNX_NAMESPACE::TensorProto_DataType& zp_data_type) {
2324
assert(q_node.OpType() == "QuantizeLinear");
2425
const auto input_defs = q_node.InputDefs();
2526

2627
if (QDQ::InputIndex::ZERO_POINT_ID >= input_defs.size() || !input_defs[QDQ::InputIndex::ZERO_POINT_ID]->Exists()) {
2728
// If a zero_point input is absent, get the type from the "output_dtype" attribute or default to uint8.
2829
// The "output_dtype" attribute was added in ONNX opset 21.
2930
const auto* attr = graph_utils::GetNodeAttribute(q_node, "output_dtype");
30-
zp_data_type = attr != nullptr ? attr->i() : static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_UINT8);
31+
zp_data_type = attr != nullptr ? static_cast<ONNX_NAMESPACE::TensorProto_DataType>(attr->i())
32+
: ONNX_NAMESPACE::TensorProto_DataType_UINT8;
3133
return true;
3234
}
3335

@@ -36,7 +38,7 @@ static bool GetQNodeZeroPointType(const Graph& graph, const Node& q_node, /*out*
3638
return false;
3739
}
3840

39-
zp_data_type = zp_proto->data_type();
41+
zp_data_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(zp_proto->data_type());
4042
return true;
4143
}
4244

@@ -119,7 +121,7 @@ static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, cons
119121
// for correctness.
120122
template <typename ZeroPointType>
121123
static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2,
122-
gsl::span<Node*> dq2s) {
124+
gsl::span<gsl::not_null<Node*>> dq2s) {
123125
if (dq2s.empty()) {
124126
return false;
125127
}
@@ -137,7 +139,7 @@ static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Nod
137139
ApplyNewInputValue(graph, q1, QDQ::InputIndex::SCALE_ID, new_scale);
138140
ApplyNewInputValue(graph, q1, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
139141

140-
for (auto* dq2 : dq2s) {
142+
for (gsl::not_null<Node*> dq2 : dq2s) {
141143
ApplyNewInputValue(graph, *dq2, QDQ::InputIndex::SCALE_ID, new_scale);
142144
ApplyNewInputValue(graph, *dq2, QDQ::InputIndex::ZERO_POINT_ID, new_zero_point);
143145
}
@@ -183,27 +185,27 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) {
183185
return false;
184186
}
185187

186-
int64_t quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
187-
if (!GetQNodeZeroPointType(graph, *q1, quant_type)) {
188+
auto q1_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
189+
if (!GetQNodeZeroPointType(graph, *q1, q1_quant_type)) {
188190
return false;
189191
}
190192

191193
// Ensure that q2 is a Q operator, its output is not a graph output, and that its zero-point quantization type
192194
// is equal to q1's.
193195
NodeIndex q2_index = dq1->OutputEdgesBegin()->GetNode().Index();
194196
const Node* q2 = graph.GetNode(q2_index);
195-
int64_t quant_type_2 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
197+
auto q2_quant_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
196198

197199
if (q2 == nullptr ||
198200
q2->OpType() != "QuantizeLinear" ||
199201
graph.NodeProducesGraphOutput(*q2) ||
200-
!GetQNodeZeroPointType(graph, *q2, quant_type_2) ||
201-
quant_type != quant_type_2) {
202+
!GetQNodeZeroPointType(graph, *q2, q2_quant_type) ||
203+
q1_quant_type != q2_quant_type) {
202204
return false;
203205
}
204206

205207
// All of q2's children should be DQ nodes with zero-point and scale values equal to those of q2.
206-
InlinedVector<Node*> dq2_nodes;
208+
InlinedVector<gsl::not_null<Node*>> dq2_nodes;
207209
dq2_nodes.reserve(q2->GetOutputEdgesCount());
208210

209211
for (auto it = q2->OutputEdgesBegin(); it != q2->OutputEdgesEnd(); it++) {
@@ -224,13 +226,13 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) {
224226
}
225227

226228
bool can_recompute = false;
227-
if (quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
229+
if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
228230
can_recompute = RecomputeOuterQDQZeroPointAndScale<uint8_t>(graph, *q1, *dq1, *q2, dq2_nodes);
229-
} else if (quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
231+
} else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
230232
can_recompute = RecomputeOuterQDQZeroPointAndScale<int8_t>(graph, *q1, *dq1, *q2, dq2_nodes);
231-
} else if (quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
233+
} else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
232234
can_recompute = RecomputeOuterQDQZeroPointAndScale<uint16_t>(graph, *q1, *dq1, *q2, dq2_nodes);
233-
} else if (quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
235+
} else if (q1_quant_type == ONNX_NAMESPACE::TensorProto_DataType_INT16) {
234236
can_recompute = RecomputeOuterQDQZeroPointAndScale<int16_t>(graph, *q1, *dq1, *q2, dq2_nodes);
235237
}
236238

@@ -243,7 +245,7 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) {
243245

244246
// Disconnect Q2 --> DQ2(s)
245247
// Connect Q1 -> DQ2(s)
246-
for (auto* dq2 : dq2_nodes) {
248+
for (gsl::not_null<Node*> dq2 : dq2_nodes) {
247249
graph.RemoveEdge(q2_index, dq2->Index(), 0, 0);
248250
graph.AddEdge(q1_index, dq2->Index(), 0, 0);
249251
}

onnxruntime/test/optimizer/graph_transform_test_builder.cc

+5 -5 lines changed
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,27 @@ static InlinedVector<std::byte> GetZeroPointBytes(int64_t zero_point, ONNX_NAMES
2626
switch (type) {
2727
case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
2828
int8_t val = static_cast<int8_t>(zero_point);
29-
auto span = ReinterpretAsSpan<const std::byte, const int8_t>(gsl::make_span(&val, 1));
29+
auto span = gsl::as_bytes(gsl::make_span(&val, 1));
3030
return InlinedVector<std::byte>(span.begin(), span.end());
3131
}
3232
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
3333
uint8_t val = static_cast<uint8_t>(zero_point);
34-
auto span = ReinterpretAsSpan<const std::byte, const uint8_t>(gsl::make_span(&val, 1));
34+
auto span = gsl::as_bytes(gsl::make_span(&val, 1));
3535
return InlinedVector<std::byte>(span.begin(), span.end());
3636
}
3737
case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
3838
int16_t val = static_cast<int16_t>(zero_point);
39-
auto span = ReinterpretAsSpan<const std::byte, const int16_t>(gsl::make_span(&val, 1));
39+
auto span = gsl::as_bytes(gsl::make_span(&val, 1));
4040
return InlinedVector<std::byte>(span.begin(), span.end());
4141
}
4242
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
4343
uint16_t val = static_cast<uint16_t>(zero_point);
44-
auto span = ReinterpretAsSpan<const std::byte, const uint16_t>(gsl::make_span(&val, 1));
44+
auto span = gsl::as_bytes(gsl::make_span(&val, 1));
4545
return InlinedVector<std::byte>(span.begin(), span.end());
4646
}
4747
case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
4848
int32_t val = static_cast<int32_t>(zero_point);
49-
auto span = ReinterpretAsSpan<const std::byte, const int32_t>(gsl::make_span(&val, 1));
49+
auto span = gsl::as_bytes(gsl::make_span(&val, 1));
5050
return InlinedVector<std::byte>(span.begin(), span.end());
5151
}
5252
default:

0 commit comments

Comments
 (0)