Skip to content

Commit fd4f5a5

Browse files
Add unit tests for DoubleQDQPairsRemover
1 parent ecd171c commit fd4f5a5

File tree

5 files changed

+241
-5
lines changed

5 files changed

+241
-5
lines changed

onnxruntime/core/optimizer/double_qdq_pairs_remover.cc

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// Licensed under the MIT License.
33
#include "core/optimizer/double_qdq_pairs_remover.h"
44
#include <cassert>
5+
#include <string>
6+
#include <vector>
57

68
#include "core/common/gsl.h"
79
#include "core/graph/graph_utils.h"
@@ -88,7 +90,8 @@ static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, cons
8890
// After removing the middle two nodes, the zero point and scale of the final (outer) ops must be recomputed
8991
// for correctness.
9092
template <typename ZeroPointType>
91-
static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2, gsl::span<Node*> dq2s) {
93+
static bool RecomputeOuterQDQZeroPointAndScale(Graph& graph, Node& q1, const Node& dq1, const Node& q2,
94+
gsl::span<Node*> dq2s) {
9295
if (dq2s.empty()) {
9396
return false;
9497
}

onnxruntime/test/optimizer/graph_transform_test_builder.h

+82-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <type_traits>
77
#include <vector>
88

9+
#include "core/common/gsl.h"
910
#include "core/common/type_utils.h"
1011
#include "core/graph/graph.h"
1112
#include "core/framework/framework_common.h"
@@ -195,13 +196,14 @@ class ModelTestBuilder {
195196
return &graph_.GetOrCreateNodeArg(name, &type_proto);
196197
}
197198

198-
template <typename T>
199-
NodeArg* MakeInitializer(const std::vector<int64_t>& shape, const std::vector<T>& data) {
199+
// Makes an initializer from the provided shape, element type, and raw_data bytes.
200+
NodeArg* MakeInitializer(gsl::span<const int64_t> shape, ONNX_NAMESPACE::TensorProto_DataType elem_type,
201+
gsl::span<const std::byte> raw_data) {
200202
std::string name = graph_.GenerateNodeArgName("constant");
201203
ONNX_NAMESPACE::TensorProto tensor_proto;
202204
tensor_proto.set_name(name);
203-
tensor_proto.set_data_type(utils::ToTensorProtoElementType<T>());
204-
tensor_proto.set_raw_data(data.data(), data.size() * sizeof(T));
205+
tensor_proto.set_data_type(elem_type);
206+
tensor_proto.set_raw_data(raw_data.data(), raw_data.size());
205207

206208
for (auto& dim : shape) {
207209
tensor_proto.add_dims(dim);
@@ -212,6 +214,12 @@ class ModelTestBuilder {
212214
return &graph_.GetOrCreateNodeArg(name, nullptr);
213215
}
214216

217+
template <typename T>
218+
NodeArg* MakeInitializer(const std::vector<int64_t>& shape, const std::vector<T>& data) {
219+
gsl::span<const std::byte> raw_data = ReinterpretAsSpan<const std::byte, const T>(data);
220+
return MakeInitializer(shape, utils::ToTensorProtoElementType<T>(), raw_data);
221+
}
222+
215223
// Special handle for std::vector<bool>.
216224
NodeArg* MakeInitializerBool(const std::vector<int64_t>& shape, const std::vector<bool>& data) {
217225
std::string name = graph_.GenerateNodeArgName("constant");
@@ -342,6 +350,57 @@ class ModelTestBuilder {
342350
return AddNode("QuantizeLinear", input_args, {output_arg}, domain, attributes);
343351
}
344352

353+
static std::vector<std::byte> GetZeroPointBytes(int64_t zero_point, ONNX_NAMESPACE::TensorProto_DataType type) {
354+
switch (type) {
355+
case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
356+
int8_t val = static_cast<int8_t>(zero_point);
357+
auto span = ReinterpretAsSpan<const std::byte, const int8_t>(gsl::make_span(&val, 1));
358+
return std::vector<std::byte>(span.begin(), span.end());
359+
}
360+
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
361+
uint8_t val = static_cast<uint8_t>(zero_point);
362+
auto span = ReinterpretAsSpan<const std::byte, const uint8_t>(gsl::make_span(&val, 1));
363+
return std::vector<std::byte>(span.begin(), span.end());
364+
}
365+
case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
366+
int16_t val = static_cast<int16_t>(zero_point);
367+
auto span = ReinterpretAsSpan<const std::byte, const int16_t>(gsl::make_span(&val, 1));
368+
return std::vector<std::byte>(span.begin(), span.end());
369+
}
370+
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
371+
uint16_t val = static_cast<uint16_t>(zero_point);
372+
auto span = ReinterpretAsSpan<const std::byte, const uint16_t>(gsl::make_span(&val, 1));
373+
return std::vector<std::byte>(span.begin(), span.end());
374+
}
375+
case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
376+
int32_t val = static_cast<int32_t>(zero_point);
377+
auto span = ReinterpretAsSpan<const std::byte, const int32_t>(gsl::make_span(&val, 1));
378+
return std::vector<std::byte>(span.begin(), span.end());
379+
}
380+
default:
381+
ORT_THROW("Unhandled zero-point type ", type, ".");
382+
}
383+
}
384+
385+
// Adds a Q node with a runtime configurable zero-point type.
386+
// Takes in an int64_t zero_point value, which is large enough to represent all ONNX zero-point types.
387+
Node& AddQuantizeLinearNode(NodeArg* input_arg,
388+
float input_scale,
389+
int64_t input_zero_point,
390+
ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
391+
NodeArg* output_arg,
392+
bool use_ms_domain = false) {
393+
std::vector<NodeArg*> input_args;
394+
input_args.push_back(input_arg);
395+
input_args.push_back(MakeScalarInitializer<float>(input_scale));
396+
397+
std::vector<std::byte> zp_bytes = GetZeroPointBytes(input_zero_point, zero_point_type);
398+
input_args.push_back(MakeInitializer({}, zero_point_type, zp_bytes));
399+
400+
std::string domain = use_ms_domain ? kMSDomain : "";
401+
return AddNode("QuantizeLinear", input_args, {output_arg}, domain);
402+
}
403+
345404
template <typename T>
346405
typename std::enable_if<IsTypeDequantLinearCompatible<T>::value, Node&>::type
347406
AddDequantizeLinearNode(NodeArg* input_arg,
@@ -400,6 +459,25 @@ class ModelTestBuilder {
400459
return AddNode("DequantizeLinear", input_args, {output_arg}, domain, attributes);
401460
}
402461

462+
// Adds a DQ node with a runtime configurable zero-point type.
463+
// Takes in an int64_t zero_point value, which is large enough to represent all ONNX zero-point types.
464+
Node& AddDequantizeLinearNode(NodeArg* input_arg,
465+
float input_scale,
466+
int64_t input_zero_point,
467+
ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
468+
NodeArg* output_arg,
469+
bool use_ms_domain = false) {
470+
std::vector<NodeArg*> input_args;
471+
input_args.push_back(input_arg);
472+
input_args.push_back(MakeScalarInitializer<float>(input_scale));
473+
474+
std::vector<std::byte> zp_bytes = GetZeroPointBytes(input_zero_point, zero_point_type);
475+
input_args.push_back(MakeInitializer({}, zero_point_type, zp_bytes));
476+
477+
std::string domain = use_ms_domain ? kMSDomain : "";
478+
return AddNode("DequantizeLinear", input_args, {output_arg}, domain);
479+
}
480+
403481
template <typename TWeight>
404482
Node& AddQLinearConvNode(NodeArg* input_arg,
405483
float input_scale,

onnxruntime/test/optimizer/qdq_test_utils.cc

+43
Original file line numberDiff line numberDiff line change
@@ -164,5 +164,48 @@ std::vector<std::string> GetNodeOpTypesInTopologicalOrder(const Graph& graph, bo
164164
return op_types;
165165
}
166166

167+
// Returns a model-building function for the sequence Q1 -> DQ1 -> Q2 -> DQ2*, where every
// node from index 3 onward is a DQ2 that consumes Q2's output.
//
// zero_points, zero_point_types and scales hold one entry per node and must all have the same
// size (at least 4). graph_output_index must be in [0, 4) and selects one node whose output is
// made a graph output; outputs of nodes at index >= 3 are always graph outputs.
// Throws (via ORT_THROW) if these constraints are not met.
GetQDQTestCaseFn BuildDoubleQDQTestCaseWithDuplicateLastDQs(
    const std::vector<int64_t>& input_shape,
    const std::vector<float>& input_data,
    const std::vector<int64_t>& zero_points,
    const std::vector<ONNX_NAMESPACE::TensorProto_DataType>& zero_point_types,
    const std::vector<float>& scales,
    int graph_output_index,
    bool use_contrib_qdq) {
  const size_t num_nodes = zero_points.size();
  const bool valid_inputs = (num_nodes >= 4) &&
                            (zero_point_types.size() == num_nodes) &&
                            (scales.size() == num_nodes) &&
                            (graph_output_index >= 0 && graph_output_index < 4);
  if (!valid_inputs) {
    ORT_THROW("Invalid inputs for call to BuildDoubleQDQTestCaseWithDuplicateLastDQs()");
  }

  // The index was validated as non-negative above; converting it once keeps the loop below
  // free of a signed/unsigned comparison.
  const size_t output_node_index = static_cast<size_t>(graph_output_index);

  // Capture by value: the returned function outlives the arguments of this call.
  return [=](ModelTestBuilder& builder) {
    auto* input_arg = builder.MakeInput<float>(input_shape, input_data);
    std::vector<NodeArg*> node_outputs(num_nodes);

    for (size_t i = 0; i < num_nodes; i++) {
      if (i == output_node_index || i >= 3) {
        node_outputs[i] = builder.MakeOutput();
      } else {
        node_outputs[i] = builder.MakeIntermediate();
      }
    }

    builder.AddQuantizeLinearNode(input_arg, scales[0], zero_points[0], zero_point_types[0], node_outputs[0],
                                  use_contrib_qdq);
    builder.AddDequantizeLinearNode(node_outputs[0], scales[1], zero_points[1], zero_point_types[1], node_outputs[1],
                                    use_contrib_qdq);
    builder.AddQuantizeLinearNode(node_outputs[1], scales[2], zero_points[2], zero_point_types[2], node_outputs[2],
                                  use_contrib_qdq);

    // Every remaining node is a DQ2 reading the output of Q2.
    for (size_t i = 3; i < num_nodes; i++) {
      builder.AddDequantizeLinearNode(node_outputs[2], scales[i], zero_points[i], zero_point_types[i],
                                      node_outputs[i], use_contrib_qdq);
    }
  };
}
209+
167210
} // namespace test
168211
} // namespace onnxruntime

onnxruntime/test/optimizer/qdq_test_utils.h

+21
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,27 @@ GetQDQTestCaseFn BuildDoubleQDQTestCases(Type1 zp_1, Type2 zp_2, Type3 zp_3, Typ
460460
};
461461
}
462462

463+
/// <summary>
464+
/// Returns a function that builds a model with a double QDQ sequence (Q1 -> DQ1 -> Q2 -> DQ2*),
465+
/// where DQ2 can be repeated. Must provide at least 4 zero-point and scale values.
466+
/// </summary>
467+
/// <param name="input_shape">Shape of input float data.</param>
468+
/// <param name="input_data">Input float data.</param>
469+
/// <param name="zero_points">Ordered list of zero-point values for each node in the sequence.</param>
470+
/// <param name="zero_point_types">Ordered list of zero-point types for each node in the sequence.</param>
471+
/// <param name="scales">Ordered list of scale values for each node in the sequence.</param>
472+
/// <param name="graph_output_index">Index of the node that provides a graph output.</param>
473+
/// <param name="use_contrib_qdq">Set to true to use the 'com.microsoft' domain for Q and DQ ops.</param>
474+
/// <returns>A function for building the model</returns>
475+
GetQDQTestCaseFn BuildDoubleQDQTestCaseWithDuplicateLastDQs(
476+
const std::vector<int64_t>& input_shape,
477+
const std::vector<float>& input_data,
478+
const std::vector<int64_t>& zero_points,
479+
const std::vector<ONNX_NAMESPACE::TensorProto_DataType>& zero_point_types,
480+
const std::vector<float>& scales,
481+
int graph_output_index,
482+
bool use_contrib_qdq = false);
483+
463484
template <typename T>
464485
GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index, bool use_contrib_qdq = false) {
465486
return [=](ModelTestBuilder& builder) {

onnxruntime/test/optimizer/qdq_transformer_test.cc

+91
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "core/graph/model.h"
88
#include "core/graph/onnx_protobuf.h"
99
#include "core/mlas/inc/mlas.h"
10+
#include "core/optimizer/double_qdq_pairs_remover.h"
1011
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
1112
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
1213
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
@@ -1233,6 +1234,96 @@ TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) {
12331234
RunDoubleQDQWithoutLastNodeBeingOutput<uint16_t>(3, 1, 1, !use_ms_qdq, 21);
12341235
}
12351236

1237+
static void RunDoubleQDQWithDuplicateLastDQs(int expected_Q_count, int expected_DQ_count,
1238+
const std::vector<int64_t>& input_shape,
1239+
const std::vector<float>& input_data,
1240+
const std::vector<int64_t>& zero_points,
1241+
const std::vector<ONNX_NAMESPACE::TensorProto_DataType>& zero_point_types,
1242+
const std::vector<float>& scales,
1243+
int graph_output_index,
1244+
bool use_contrib_qdq = false,
1245+
int opset = 19) {
1246+
auto graph_checker = [&](InferenceSessionWrapper& session) {
1247+
auto op_to_count = CountOpsInGraph(session.GetGraph());
1248+
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
1249+
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_Q_count);
1250+
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_DQ_count);
1251+
};
1252+
1253+
auto model_build_fn = BuildDoubleQDQTestCaseWithDuplicateLastDQs(input_shape, input_data, zero_points,
1254+
zero_point_types, scales, graph_output_index,
1255+
use_contrib_qdq);
1256+
TransformerTester(model_build_fn,
1257+
graph_checker,
1258+
TransformerLevel::Default,
1259+
TransformerLevel::Level1,
1260+
opset,
1261+
/*per_sample_tolerance*/ 0.0,
1262+
/*relative_per_sample_tolerance*/ 0.0,
1263+
std::make_unique<DoubleQDQPairsRemover>());
1264+
}
1265+
1266+
// Test QDQDoublePairsRemover when the sequence ends with duplicate DQs.
1267+
TEST(QDQTransformerTests, DoubleQDQPairsRemover_DuplicateLastDQs) {
1268+
std::vector<int64_t> shape = {1, 2, 2, 2};
1269+
std::vector<float> input_data = {-3.0f, -2.0f, -1.0f, 0.0f, 0.5f, 1.0f, 2.0f, 3.0f};
1270+
1271+
constexpr auto int8_type = ONNX_NAMESPACE::TensorProto_DataType_INT8;
1272+
constexpr auto uint8_type = ONNX_NAMESPACE::TensorProto_DataType_UINT8;
1273+
constexpr auto int16_type = ONNX_NAMESPACE::TensorProto_DataType_INT16;
1274+
constexpr auto uint16_type = ONNX_NAMESPACE::TensorProto_DataType_UINT16;
1275+
std::vector<ONNX_NAMESPACE::TensorProto_DataType> quant_types = {int8_type, uint8_type, int16_type, uint16_type};
1276+
1277+
// Input graph:
1278+
// input -> Q1 -> DQ1 -> Q2 --+--> DQ2 -> output0
1279+
// |
1280+
// ...
1281+
// |
1282+
// +--> DQ2'' -> outputN
1283+
// Expected graph after DoubleQDQPairsRemover:
1284+
// input -> Q1 --+--> DQ2 -> output0
1285+
// |
1286+
// ...
1287+
// |
1288+
// +--> DQ2'' -> outputN
1289+
for (auto quant_type : quant_types) {
1290+
for (size_t num_dq2s = 1; num_dq2s <= 1; num_dq2s++) {
1291+
const size_t num_nodes = 3 + num_dq2s;
1292+
std::vector<int64_t> zp_vals(num_nodes, 1);
1293+
std::vector<ONNX_NAMESPACE::TensorProto_DataType> zp_types(num_nodes, quant_type);
1294+
std::vector<float> scale_vals(num_nodes, 0.1f);
1295+
1296+
const int expected_q_nodes = 1;
1297+
const int expected_dq_nodes = static_cast<int>(num_dq2s);
1298+
RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types,
1299+
scale_vals, 3, false, 21);
1300+
RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types,
1301+
scale_vals, 3, quant_type == int16_type || quant_type == uint16_type, 19);
1302+
}
1303+
}
1304+
1305+
// Should not remove QDQ pair because the middle nodes produce a graph output.
1306+
for (auto quant_type : quant_types) {
1307+
for (int output_index = 0; output_index < 3; output_index++) {
1308+
for (size_t num_dq2s = 1; num_dq2s <= 1; num_dq2s++) {
1309+
const size_t num_nodes = 3 + num_dq2s;
1310+
std::vector<int64_t> zp_vals(num_nodes, 1);
1311+
std::vector<ONNX_NAMESPACE::TensorProto_DataType> zp_types(num_nodes, quant_type);
1312+
std::vector<float> scale_vals(num_nodes, 0.1f);
1313+
1314+
const int expected_q_nodes = 2;
1315+
int expected_dq_nodes = 1 + static_cast<int>(num_dq2s);
1316+
if (output_index == 1) {
1317+
// EnsureUniqueDQ pass will create a duplicate DQ if it produces a graph output.
1318+
expected_dq_nodes += 1;
1319+
}
1320+
RunDoubleQDQWithDuplicateLastDQs(expected_q_nodes, expected_dq_nodes, shape, input_data, zp_vals, zp_types,
1321+
scale_vals, output_index, false, 21);
1322+
}
1323+
}
1324+
}
1325+
}
1326+
12361327
// Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split.
12371328
template <typename QuantType>
12381329
static void RunDropSplitQDQTestCase(const std::vector<int64_t>& input_shape, int64_t axis,

0 commit comments

Comments
 (0)