From b51069dd79bd1ca289355df43aa3df56d156d4be Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Wed, 19 Jul 2023 14:12:13 +0400 Subject: [PATCH] Eliminates Nop Broadcast/Tile and Slice Before GatherElements (#18614) --- .../common_optimizations/nop_elimination.hpp | 24 ++++++ .../include/transformations/utils/utils.hpp | 1 + .../common_optimizations/nop_elimination.cpp | 43 ++++++++++ .../src/transformations/utils/utils.cpp | 13 +++ .../common_optimizations/nop_elimination.cpp | 84 +++++++++++++++++++ 5 files changed, 165 insertions(+) diff --git a/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp b/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp index 387d3eb9016368..3cca0736f5bd54 100644 --- a/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp @@ -20,6 +20,8 @@ class TRANSFORMATIONS_API EliminateSplit; class TRANSFORMATIONS_API EliminateSplitConcat; class TRANSFORMATIONS_API EliminateSqueeze; class TRANSFORMATIONS_API EliminateTranspose; +class TRANSFORMATIONS_API EliminateNopBroadcast; +class TRANSFORMATIONS_API NopSliceBeforeGatherElements; class TRANSFORMATIONS_API NopElimination; } // namespace pass @@ -130,3 +132,25 @@ class ov::pass::EliminateSplitConcat : public ov::pass::MatcherPass { OPENVINO_RTTI("EliminateSplitConcat", "0"); EliminateSplitConcat(); }; + +/** + * @ingroup ie_transformation_common_api + * @brief EliminateNopBroadcast eliminates broadcast or tile with all ones on the second input + */ +class ov::pass::EliminateNopBroadcast : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("EliminateNopBroadcast", "0"); + EliminateNopBroadcast(); +}; + +/** + * @ingroup ie_transformation_common_api + * @brief NopSliceBeforeGatherElements eliminates slice before GatherElements if slicing from 0 + * It is valid since 
GatherElements doesn't support negative indices and Slice won't affect + * indexing of elements in the original tensor that GatherElements would like to take + */ +class ov::pass::NopSliceBeforeGatherElements : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("NopSliceBeforeGatherElements", "0"); + NopSliceBeforeGatherElements(); +}; diff --git a/src/common/transformations/include/transformations/utils/utils.hpp b/src/common/transformations/include/transformations/utils/utils.hpp index 9f4137d063a350..072f965816c336 100644 --- a/src/common/transformations/include/transformations/utils/utils.hpp +++ b/src/common/transformations/include/transformations/utils/utils.hpp @@ -221,6 +221,7 @@ TRANSFORMATIONS_API bool is_dequantization_subgraph(const Output& node); TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr& eltwise, const Output& constant, const Output& non_constant_input); +TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output& output, const int64_t& v); } // namespace util } // namespace op } // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index 126a9c996aab1d..33d7decc8eb1ac 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -788,6 +788,47 @@ pass::EliminateScatterUpdate::EliminateScatterUpdate() { this->register_matcher(m, callback); } +ov::pass::EliminateNopBroadcast::EliminateNopBroadcast() { + MATCHER_SCOPE(EliminateNopBroadcast); + auto root = pattern::wrap_type( + pattern::op::as_value_predicate([](std::shared_ptr node) { + auto input_rank = node->get_input_partial_shape(0).rank(); + auto output_rank = node->get_output_partial_shape(0).rank(); + return input_rank.is_static() && output_rank.is_static() && 
input_rank == output_rank; + })); + + ov::matcher_pass_callback matcher_pass_callback = [](pattern::Matcher& m) { + const auto& op = m.get_match_root(); + if (op::util::is_constant_and_all_values_equal_int(op->input_value(1), 1)) + return replace_output_update_name(op->output(0), op->input_value(0)); + return false; + }; + + auto m = std::make_shared(root, matcher_name); + register_matcher(m, matcher_pass_callback); +} + +ov::pass::NopSliceBeforeGatherElements::NopSliceBeforeGatherElements() { + MATCHER_SCOPE(NopSliceBeforeGatherElements); + auto slice = pattern::wrap_type(); + auto gather = pattern::wrap_type({slice, pattern::any_input()}); + + ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) { + const auto& pattern_to_node = m.get_pattern_map(); + const auto& slice_node = pattern_to_node.at(slice); + bool start_from_zero = op::util::is_constant_and_all_values_equal_int(slice_node->input_value(1), 0); + bool step_is_one = op::util::is_constant_and_all_values_equal_int(slice_node->input_value(3), 1); + if (!start_from_zero || !step_is_one) + return false; + const auto& gather_node = pattern_to_node.at(gather); + gather_node->input(0).replace_source_output(slice_node->input_value(0)); + return true; + }; + + auto m = std::make_shared(gather, matcher_name); + register_matcher(m, matcher_pass_callback); +} + ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) { // shape-agnostic transformations ADD_MATCHER_FOR_THIS(EliminatePad) @@ -807,6 +848,8 @@ ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) { ADD_MATCHER_FOR_THIS(EliminateSqueeze) ADD_MATCHER_FOR_THIS(EliminateUnsqueeze) ADD_MATCHER_FOR_THIS(EliminateBroadcast) + ADD_MATCHER_FOR_THIS(EliminateNopBroadcast) + ADD_MATCHER_FOR_THIS(NopSliceBeforeGatherElements) ADD_MATCHER_FOR_THIS(EliminateGather) } } diff --git a/src/common/transformations/src/transformations/utils/utils.cpp b/src/common/transformations/src/transformations/utils/utils.cpp 
index 6b0b8018b4095e..be3e39b79b4dda 100644 --- a/src/common/transformations/src/transformations/utils/utils.cpp +++ b/src/common/transformations/src/transformations/utils/utils.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -353,6 +354,18 @@ float cast_eps_to_float(double eps_d) { return eps_f; } +bool is_constant_and_all_values_equal_int(const Output& output, const int64_t& v) { + OPENVINO_SUPPRESS_DEPRECATED_START + if (const auto& constant = ov::get_constant_from_source(output)) { + OPENVINO_SUPPRESS_DEPRECATED_END + const auto& values = constant->cast_vector(); + return std::all_of(values.begin(), values.end(), [&](const int64_t& i) { + return i == v; + }); + } + return false; +} + } // namespace util } // namespace op } // namespace ov diff --git a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp index d320037b97b129..0e1adebc1d6fb7 100644 --- a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp @@ -1338,3 +1338,87 @@ TEST(nop_elimination, gather_to_squeeze) { run_and_check(func_axis_2); run_and_check(func_axis_3); } + +TEST_F(TransformationTestsF, Nopv1Broadcast) { + { + auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + auto broadcast_shape = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1}); + auto broadcast = std::make_shared(data, broadcast_shape); + auto relu = std::make_shared(broadcast); + auto result = std::make_shared(relu); + model = std::make_shared(ResultVector{result}, ParameterVector{data}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + auto relu = std::make_shared(data); + auto result = std::make_shared(relu); + model_ref = std::make_shared(ResultVector{result}, ParameterVector{data}); + } +} + 
+TEST_F(TransformationTestsF, Nopv3Broadcast) { + { + auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + auto broadcast_shape = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1}); + auto broadcast = std::make_shared(data, broadcast_shape); + auto relu = std::make_shared(broadcast); + auto result = std::make_shared(relu); + model = std::make_shared(ResultVector{result}, ParameterVector{data}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + auto relu = std::make_shared(data); + auto result = std::make_shared(relu); + model_ref = std::make_shared(ResultVector{result}, ParameterVector{data}); + } +} + +TEST_F(TransformationTestsF, NopTile) { + { + auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + auto repeats = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1}); + auto tile = std::make_shared(data, repeats); + auto relu = std::make_shared(tile); + auto result = std::make_shared(relu); + model = std::make_shared(ResultVector{result}, ParameterVector{data}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + auto relu = std::make_shared(data); + auto result = std::make_shared(relu); + model_ref = std::make_shared(ResultVector{result}, ParameterVector{data}); + } +} + +TEST_F(TransformationTestsF, NopSliceBeforeGatherElements) { + { + auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + + auto start = opset10::Constant::create(element::i32, Shape{1}, {0}); + auto stop = opset10::Constant::create(element::i32, Shape{1}, {2}); + auto step = opset10::Constant::create(element::i32, Shape{1}, {1}); + auto axis = opset10::Constant::create(element::i32, Shape{1}, {-1}); + auto slice = std::make_shared(data, start, stop, step, axis); + + auto indices = std::make_shared(element::i64, PartialShape{-1, -1, -1, -1}); + auto gather_elements = 
std::make_shared(slice, indices, 2); + + auto relu = std::make_shared(gather_elements); + auto result = std::make_shared(relu); + model = std::make_shared(ResultVector{result}, ParameterVector{data, indices}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + auto indices = std::make_shared(element::i64, PartialShape{-1, -1, -1, -1}); + + auto gather_elements = std::make_shared(data, indices, 2); + + auto relu = std::make_shared(gather_elements); + auto result = std::make_shared(relu); + model_ref = std::make_shared(ResultVector{result}, ParameterVector{data, indices}); + } +}