diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 81a842eb87db1..10cb6eb97bdd6 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -1654,14 +1654,14 @@ static bool HandleSplit(HandlerArgs& args) { constexpr HandlerInfo split_handler = {&FirstInput, &HandleSplit}; -static bool HandleConcat(HandlerArgs& args) { +bool HandleConcat(HandlerArgs& args) { return HandleSimpleNodeWithAxis(args); } constexpr HandlerInfo concat_handler = {&AllInputs, &HandleConcat}; // Handles Softmax, Hardmax, and LogSoftmax -static bool HandleSoftHardMax(HandlerArgs& args) { +bool HandleSoftHardMax(HandlerArgs& args) { if (args.ctx.opset >= 13) { return HandleSimpleNodeWithAxis(args, /*default_axis*/ -1); } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index 0095ead75f0c8..f65bd6aa82fbb 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -71,6 +71,9 @@ bool HandleSimpleNodeBroadcast(HandlerArgs& args); // Transposes all inputs and all outputs. Updates axis attribute. bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_axis = std::nullopt); +bool HandleConcat(HandlerArgs& args); +bool HandleSoftHardMax(HandlerArgs& args); + // base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed. bool HandleReduceOps(HandlerArgs& args); bool HandleResize([[maybe_unused]] HandlerArgs& args); diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index 8eaac3d34c3af..824ab20a84668 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -34,10 +34,6 @@ static bool EPAwareHandleResize(HandlerArgs& args) { constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; -static bool HandleQLinearConcat(HandlerArgs& args) { - return HandleSimpleNodeWithAxis(args); -} - std::vector QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) { (void)ctx; std::vector indices; @@ -48,11 +44,7 @@ std::vector QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) { return indices; } -constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleQLinearConcat}; - -static bool HandleQLinearBinaryOp(HandlerArgs& args) { - return HandleSimpleNodeBroadcast(args); -} +constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleConcat}; std::vector QLinearBinaryOpInputs(OptimizerCtx&, api::NodeRef&) { // Inputs are: [A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point], @@ -60,7 +52,7 @@ std::vector QLinearBinaryOpInputs(OptimizerCtx&, api::NodeRef&) { return {0, 3}; } -constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleQLinearBinaryOp}; +constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleSimpleNodeBroadcast}; static bool HandleQLinearPoolOp(HandlerArgs& args) { // Swap between channel first/last variants. Only works for applicable values of perm. @@ -129,6 +121,7 @@ constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode}; constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; +constexpr HandlerInfo soft_hard_max_handler = {&FirstInput, &HandleSoftHardMax}; constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput, &HandleContribQuantizeDequantizeLinear}; @@ -148,6 +141,7 @@ const HandlerMap& OrtExtendedHandlers() { {"com.microsoft.QLinearMul", q_linear_binary_op_handler}, {"com.microsoft.QLinearReduceMean", reduce_op_handler}, {"com.microsoft.QLinearSigmoid", node_1_inp_handler}, + {"com.microsoft.QLinearSoftmax", soft_hard_max_handler}, }; return map; diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 35ba1a3369597..f6fce37322c10 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -22,6 +22,7 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/internal_testing/internal_testing_execution_provider.h" #include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/test_utils.h" @@ -3800,6 +3801,46 @@ TEST(TransposeOptimizerTests, TestCast) { /*opset_version*/ {15, 18}); } +TEST(TransposeOptimizerTests, TestQLinearSoftmax) { + auto build_test_case_1 = [&](ModelTestBuilder& builder) { + auto* input0_arg = MakeInput(builder, std::nullopt, {1, 384, 384, 21}, 0, 255); + auto* transpose_1_out_0 = builder.MakeIntermediate(); + auto* input_x_scale = builder.MakeScalarInitializer(0.5086354613304138); + auto* input_x_zero_point = builder.MakeScalarInitializer(74); + auto* input_y_scale = builder.MakeScalarInitializer(0.003921568859368563); + auto* input_y_zero_point = builder.MakeScalarInitializer(0); + auto* qlinearsoftmax_1_out_0 = builder.MakeIntermediate(); + auto* transpose_2_out_0 = builder.MakeOutput(); + + auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0}); + transpose_1.AddAttribute("perm", std::vector{0, 3, 1, 2}); + auto& qlinearsoftmax_1 = builder.AddNode("QLinearSoftmax", + {transpose_1_out_0, input_x_scale, input_x_zero_point, input_y_scale, input_y_zero_point}, + {qlinearsoftmax_1_out_0}, kMSDomain); + qlinearsoftmax_1.AddAttribute("axis", static_cast(1)); + qlinearsoftmax_1.AddAttribute("opset", static_cast(13)); + auto& transpose_2 = builder.AddNode("Transpose", {qlinearsoftmax_1_out_0}, {transpose_2_out_0}); + transpose_2.AddAttribute("perm", std::vector{0, 2, 3, 1}); + }; + + auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) { + int transpose_cost = EstimateTransposeCost(session.GetGraph()); + EXPECT_EQ(transpose_cost, 0); + }; + + TransformerTester(build_test_case_1, + check_optimized_graph_1, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 13, + /*per_sample_tolerance*/ 0.0, + /*relative_per_sample_tolerance*/ 0.0, + /*transformer*/ nullptr, + /*add_session_options*/ {}, + /*disabled_optimizers*/ {}, + /*ep*/ DefaultCpuExecutionProvider()); +} + TEST(TransposeOptimizerTests, TestBroadcastReusedInputs) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = MakeInput(builder, {{-1, -1, 3, 4}}, {1, 2, 3, 4}, 0.0, 1.0);