diff --git a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc index a1859b9d7071b..faeee1abd07fc 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc @@ -81,7 +81,7 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& return true; } -bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const { +bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Clip", {1, 6, 11, 12, 13}) || !graph_utils::IsSupportedProvider(node, {kCpuExecutionProvider}) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { @@ -95,6 +95,10 @@ bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, con return false; } + if (!graph_utils::CanRemoveNode(graph, node, logger)) { + return false; + } + return true; } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1cda0c7965430..fe84f27a9105c 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3221,6 +3221,37 @@ TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) { test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128 } +// Test skip removing edge when min/max come from DequantizeLinear nodes instead of initializers). +TEST(QDQTransformerTests, ClipQuantFusion_MultipleInputEdges) { + auto build_test_case = [&](ModelTestBuilder& builder) { + // Clip's min coming from another DQ node (creating 2 input edges to Clip) + auto* input_arg = builder.MakeInput({1, 2, 2, 2}, std::numeric_limits::min(), + std::numeric_limits::max()); + auto* data_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, 0.04f, static_cast(0), data_dq); + auto* min_q = builder.MakeScalarInitializer(0); + auto* min_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(min_q, 0.04f, static_cast(0), min_dq); + auto* clip_output = builder.MakeIntermediate(); + builder.AddNode("Clip", {data_dq, min_dq}, {clip_output}); + auto* output_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(clip_output, 0.04f, static_cast(0), output_q); + auto* output_arg = builder.MakeOutput(); + builder.AddDequantizeLinearNode(output_q, 0.04f, static_cast(0), output_arg); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + // ClipQuantFusion should skip it due to CanRemoveNode check + EXPECT_EQ(op_to_count["Clip"], 1); + }; + + TransformerTester(build_test_case, check_graph, + TransformerLevel::Default, + TransformerLevel::Level2, + 18); // opset +} + template void TestWhereWithDqInput(bool is_dq_1, bool is_dq_2,