Commit fbf966a
add: detect Q/DQ with int16/uint16 initializers
1 parent: ed9e425

1 file changed: +60 -1 lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 60 additions & 1 deletion
@@ -387,6 +387,61 @@ static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) {
   return false;
 }
 
+// Check to see if the graph has Q/DQ nodes with int16 or uint16 quantization
+static bool IsQDQGraphWithUint16OrInt16(const onnxruntime::GraphViewer& graph_viewer) {
+  std::unordered_set<std::string> qdq_ops = {"QuantizeLinear", "DequantizeLinear"};
+  const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
+
+  // Check if a NodeArg tensor is 16-bit quantized (UINT16 or INT16)
+  auto is_16bit_tensor = [](const onnxruntime::NodeArg* node_arg) -> bool {
+    if (!node_arg) return false;
+    const auto* type_proto = node_arg->TypeAsProto();
+    if (type_proto && type_proto->has_tensor_type()) {
+      auto elem_type = type_proto->tensor_type().elem_type();
+      return (elem_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16 ||
+              elem_type == ONNX_NAMESPACE::TensorProto_DataType_INT16);
+    }
+    return false;
+  };
+
+  for (size_t i = 0; i < node_indices.size(); i++) {
+    gsl::not_null<const onnxruntime::Node*> node(graph_viewer.GetNode(node_indices[i]));
+
+    if (qdq_ops.find(node->OpType()) != qdq_ops.end()) {
+      const auto& input_defs = node->InputDefs();
+
+      if (node->OpType() == "DequantizeLinear") {
+        // DequantizeLinear: [quantized_input, scale, zero_point] -> [float_output]
+        // The quantized input tensor (index 0) determines the quantization type
+        if (is_16bit_tensor(input_defs.empty() ? nullptr : input_defs[0])) {
+          return true;
+        }
+
+        // The zero point (index 2) must match the quantized tensor type per the ONNX spec;
+        // it is optional - absent for INT32 and some float8 types
+        if (input_defs.size() >= 3 && is_16bit_tensor(input_defs[2])) {
+          return true;
+        }
+      } else if (node->OpType() == "QuantizeLinear") {
+        // QuantizeLinear: [float_input, scale, zero_point] -> [quantized_output]
+        // The quantized output tensor (index 0) determines the quantization type
+        const auto& output_defs = node->OutputDefs();
+        if (is_16bit_tensor(output_defs.empty() ? nullptr : output_defs[0])) {
+          return true;
+        }
+
+        // The zero point (index 2) must match the quantized tensor type per the ONNX spec;
+        // it is optional - absent for INT32 and some float8 types
+        if (input_defs.size() >= 3 && is_16bit_tensor(input_defs[2])) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
 static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name,
                                 [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto,
                                 [[maybe_unused]] const onnxruntime::Node& fused_node) {
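
For intuition, here is a minimal self-contained sketch of the same decision procedure, written against mock node records rather than the real onnxruntime::GraphViewer/NodeArg API. The MockNode struct, the helper names, and the hard-coded element-type codes (UINT16 = 4, INT16 = 5, per onnx.proto) are illustrative assumptions, not part of this commit:

// Sketch only: mirrors the detection rules of IsQDQGraphWithUint16OrInt16
// on mock data instead of the real ONNX Runtime graph types.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// ONNX TensorProto element-type codes (onnx.proto): UINT16 = 4, INT16 = 5.
constexpr int32_t kUint16 = 4;
constexpr int32_t kInt16 = 5;

// Hypothetical stand-in for a node: its op type plus the element types
// of its input and output defs.
struct MockNode {
  std::string op_type;
  std::vector<int32_t> input_types;
  std::vector<int32_t> output_types;
};

static bool Is16Bit(int32_t elem_type) {
  return elem_type == kUint16 || elem_type == kInt16;
}

// Same rules as the detector above: for DequantizeLinear inspect the
// quantized input (index 0) and the optional zero point (index 2); for
// QuantizeLinear inspect the quantized output (index 0) and the optional
// zero point (index 2).
static bool HasUint16OrInt16QDQ(const std::vector<MockNode>& nodes) {
  for (const auto& n : nodes) {
    if (n.op_type == "DequantizeLinear") {
      if (!n.input_types.empty() && Is16Bit(n.input_types[0])) return true;
      if (n.input_types.size() >= 3 && Is16Bit(n.input_types[2])) return true;
    } else if (n.op_type == "QuantizeLinear") {
      if (!n.output_types.empty() && Is16Bit(n.output_types[0])) return true;
      if (n.input_types.size() >= 3 && Is16Bit(n.input_types[2])) return true;
    }
  }
  return false;
}

int main() {
  // A uint16 QuantizeLinear: float input (1), float scale (1), uint16 zero point.
  std::vector<MockNode> graph = {{"QuantizeLinear", {1, 1, kUint16}, {kUint16}}};
  std::cout << HasUint16OrInt16QDQ(graph) << "\n";  // prints 1
}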
@@ -445,6 +500,10 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
   }
 #endif
 
+  // Check if the graph is QDQ and has int16 or uint16 quantization
+  // If so, we will apply the QDQ scales fix transformation (for GPU device only)
+  bool is_qdq_graph_uint16_or_int16 = IsQDQGraphWithUint16OrInt16(subgraph);
+
   const auto& onnx_model_path_name = subgraph.ModelPath();
   // QDQ stripping enabled only for the NPU and experimentally on the GPU
   if ((session_context_.device_type.find("NPU") != std::string::npos) &&
@@ -458,7 +517,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
     ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
     return model_proto;
   } else if ((session_context_.device_type.find("GPU") != std::string::npos) &&
-             enable_ovep_qdq_optimizer) {
+             is_qdq_graph_uint16_or_int16) {
     // Create a copy of the model
     std::unique_ptr<onnxruntime::Model> model;
     Status status = qdq_scales_fix::Transform(subgraph, logger, model);
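
As far as these hunks show, the net effect on the GPU path is that qdq_scales_fix::Transform now runs whenever 16-bit Q/DQ is detected in the fused subgraph (is_qdq_graph_uint16_or_int16), instead of being gated on the enable_ovep_qdq_optimizer option; the NPU stripping branch above it is unchanged.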
