Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float&
return true;
}

bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Clip", {1, 6, 11, 12, 13}) ||
!graph_utils::IsSupportedProvider(node, {kCpuExecutionProvider}) ||
!optimizer_utils::CheckOutputEdges(graph, node, 1)) {
Expand All @@ -95,6 +95,10 @@ bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, con
return false;
}

if (!graph_utils::CanRemoveNode(graph, node, logger)) {
return false;
}

return true;
}

Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3221,6 +3221,37 @@ TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) {
test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128
}

// Test that ClipQuantFusion is skipped when Clip's min/max come from DequantizeLinear nodes instead of initializers.
TEST(QDQTransformerTests, ClipQuantFusion_MultipleInputEdges) {
  // Graph under test:
  //   input -> DQ -> Clip(data, min) -> Q -> DQ -> output
  //   scalar initializer -> DQ -> Clip's `min` input
  // Clip therefore has two input edges (its data and its min both come from DQ nodes).
  auto build_test_case = [&](ModelTestBuilder& builder) {
    // Every Q/DQ node in this graph uses the same quantization parameters.
    constexpr float scale = 0.04f;
    constexpr uint8_t zero_point = 0;

    // Data path: uint8 graph input spanning the full uint8 range, then dequantized.
    auto* graph_input = builder.MakeInput<uint8_t>({1, 2, 2, 2}, std::numeric_limits<uint8_t>::min(),
                                                   std::numeric_limits<uint8_t>::max());
    auto* dequantized_data = builder.MakeIntermediate();
    builder.AddDequantizeLinearNode<uint8_t>(graph_input, scale, zero_point, dequantized_data);

    // Min path: Clip's min is the output of another DQ node fed by a scalar initializer.
    auto* quantized_min = builder.MakeScalarInitializer<uint8_t>(0);
    auto* dequantized_min = builder.MakeIntermediate();
    builder.AddDequantizeLinearNode<uint8_t>(quantized_min, scale, zero_point, dequantized_min);

    // Clip consuming both DQ outputs.
    auto* clipped = builder.MakeIntermediate();
    builder.AddNode("Clip", {dequantized_data, dequantized_min}, {clipped});

    // Trailing Q -> DQ pair producing the graph output.
    auto* requantized = builder.MakeIntermediate();
    builder.AddQuantizeLinearNode<uint8_t>(clipped, scale, zero_point, requantized);
    auto* graph_output = builder.MakeOutput();
    builder.AddDequantizeLinearNode<uint8_t>(requantized, scale, zero_point, graph_output);
  };

  auto check_graph = [&](InferenceSessionWrapper& session) {
    auto op_counts = CountOpsInGraph(session.GetGraph());
    // The Clip node must still be present: ClipQuantFusion should skip it due to the CanRemoveNode check.
    EXPECT_EQ(op_counts["Clip"], 1);
  };

  TransformerTester(build_test_case, check_graph,
                    TransformerLevel::Default,
                    TransformerLevel::Level2,
                    /*opset_version*/ 18);
}

template <typename ScaleType, typename ZpType>
void TestWhereWithDqInput(bool is_dq_1,
bool is_dq_2,
Expand Down
Loading