Skip to content

Commit

Permalink
Optimize Transpose around QLinearSoftmax (#22849)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

- Improved Transpose around QLinearSoftmax in Level 3 NHWC Transformer.
- Removed the redundant handlers HandleQLinearConcat and HandleQLinearBinaryOp.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

By merging and eliminating redundant transposes, the Image Segmentation
i8 model (MobileNetV2 + DeepLabv3) achieves a 2.34X speedup.
  • Loading branch information
yihonglyu authored Nov 18, 2024
1 parent 135d8b2 commit 02a0be3
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1654,14 +1654,14 @@ static bool HandleSplit(HandlerArgs& args) {

constexpr HandlerInfo split_handler = {&FirstInput, &HandleSplit};

static bool HandleConcat(HandlerArgs& args) {
bool HandleConcat(HandlerArgs& args) {
return HandleSimpleNodeWithAxis(args);
}

constexpr HandlerInfo concat_handler = {&AllInputs, &HandleConcat};

// Handles Softmax, Hardmax, and LogSoftmax
static bool HandleSoftHardMax(HandlerArgs& args) {
bool HandleSoftHardMax(HandlerArgs& args) {
if (args.ctx.opset >= 13) {
return HandleSimpleNodeWithAxis(args, /*default_axis*/ -1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ bool HandleSimpleNodeBroadcast(HandlerArgs& args);
// Transposes all inputs and all outputs. Updates axis attribute.
bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional<int64_t> default_axis = std::nullopt);

bool HandleConcat(HandlerArgs& args);
bool HandleSoftHardMax(HandlerArgs& args);

// base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed.
bool HandleReduceOps(HandlerArgs& args);
bool HandleResize([[maybe_unused]] HandlerArgs& args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ static bool EPAwareHandleResize(HandlerArgs& args) {

constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize};

// For transpose optimization, QLinearConcat is treated like plain Concat:
// delegate to the generic handler that transposes inputs/outputs and fixes
// up the axis attribute.
static bool HandleQLinearConcat(HandlerArgs& args) {
  const bool handled = HandleSimpleNodeWithAxis(args);
  return handled;
}

std::vector<size_t> QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) {
(void)ctx;
std::vector<size_t> indices;
Expand All @@ -48,19 +44,15 @@ std::vector<size_t> QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) {
return indices;
}

constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleQLinearConcat};

// QLinear binary ops follow the same broadcasting rules as their float
// counterparts, so the generic broadcast-aware handler applies unchanged.
static bool HandleQLinearBinaryOp(HandlerArgs& args) {
  const bool handled = HandleSimpleNodeBroadcast(args);
  return handled;
}
constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleConcat};

// Selects which inputs of a QLinear binary op the optimizer should transpose.
// The op's input list is
// [A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point];
// only the data tensors A (index 0) and B (index 3) are wanted.
std::vector<size_t> QLinearBinaryOpInputs(OptimizerCtx&, api::NodeRef&) {
  std::vector<size_t> data_input_indices{0, 3};
  return data_input_indices;
}

constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleQLinearBinaryOp};
constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleSimpleNodeBroadcast};

static bool HandleQLinearPoolOp(HandlerArgs& args) {
// Swap between channel first/last variants. Only works for applicable values of perm.
Expand Down Expand Up @@ -129,6 +121,7 @@ constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool};

constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode};
constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps};
constexpr HandlerInfo soft_hard_max_handler = {&FirstInput, &HandleSoftHardMax};
constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput,
&HandleContribQuantizeDequantizeLinear};

Expand All @@ -148,6 +141,7 @@ const HandlerMap& OrtExtendedHandlers() {
{"com.microsoft.QLinearMul", q_linear_binary_op_handler},
{"com.microsoft.QLinearReduceMean", reduce_op_handler},
{"com.microsoft.QLinearSigmoid", node_1_inp_handler},
{"com.microsoft.QLinearSoftmax", soft_hard_max_handler},
};

return map;
Expand Down
41 changes: 41 additions & 0 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/providers/internal_testing/internal_testing_execution_provider.h"
#include "test/util/include/asserts.h"
#include "test/util/include/default_providers.h"
#include "test/util/include/inference_session_wrapper.h"
#include "test/util/include/test_utils.h"

Expand Down Expand Up @@ -3800,6 +3801,46 @@ TEST(TransposeOptimizerTests, TestCast) {
/*opset_version*/ {15, 18});
}

// Checks that the optimizer removes the Transpose pair wrapped around a
// com.microsoft.QLinearSoftmax node, leaving zero estimated transpose cost.
TEST(TransposeOptimizerTests, TestQLinearSoftmax) {
  auto model_build_fn = [&](ModelTestBuilder& builder) {
    // Quantized NHWC activation feeding the pattern.
    auto* input0_arg = MakeInput<uint8_t>(builder, std::nullopt, {1, 384, 384, 21}, 0, 255);
    auto* nchw_data = builder.MakeIntermediate();
    auto* x_scale = builder.MakeScalarInitializer<float>(0.5086354613304138);
    auto* x_zero_point = builder.MakeScalarInitializer<uint8_t>(74);
    auto* y_scale = builder.MakeScalarInitializer<float>(0.003921568859368563);
    auto* y_zero_point = builder.MakeScalarInitializer<uint8_t>(0);
    auto* softmax_out = builder.MakeIntermediate();
    auto* model_out = builder.MakeOutput();

    // Transpose NHWC -> NCHW before the quantized softmax.
    auto& pre_transpose = builder.AddNode("Transpose", {input0_arg}, {nchw_data});
    pre_transpose.AddAttribute("perm", std::vector<int64_t>{0, 3, 1, 2});

    auto& qlinear_softmax = builder.AddNode("QLinearSoftmax",
                                            {nchw_data, x_scale, x_zero_point, y_scale, y_zero_point},
                                            {softmax_out}, kMSDomain);
    qlinear_softmax.AddAttribute("axis", static_cast<int64_t>(1));
    qlinear_softmax.AddAttribute("opset", static_cast<int64_t>(13));

    // Transpose NCHW -> NHWC after the softmax.
    auto& post_transpose = builder.AddNode("Transpose", {softmax_out}, {model_out});
    post_transpose.AddAttribute("perm", std::vector<int64_t>{0, 2, 3, 1});
  };

  auto graph_check_fn = [&](InferenceSessionWrapper& session) {
    // Both Transpose nodes should have been merged away by the optimizer.
    EXPECT_EQ(EstimateTransposeCost(session.GetGraph()), 0);
  };

  TransformerTester(model_build_fn,
                    graph_check_fn,
                    TransformerLevel::Level2,
                    TransformerLevel::Level3,
                    /*opset_version*/ 13,
                    /*per_sample_tolerance*/ 0.0,
                    /*relative_per_sample_tolerance*/ 0.0,
                    /*transformer*/ nullptr,
                    /*add_session_options*/ {},
                    /*disabled_optimizers*/ {},
                    /*ep*/ DefaultCpuExecutionProvider());
}

TEST(TransposeOptimizerTests, TestBroadcastReusedInputs) {
auto build_test_case_1 = [&](ModelTestBuilder& builder) {
auto* input0_arg = MakeInput<float>(builder, {{-1, -1, 3, 4}}, {1, 2, 3, 4}, 0.0, 1.0);
Expand Down

0 comments on commit 02a0be3

Please sign in to comment.