Skip to content

Commit

Permalink
Eliminates Nop Broadcast/Tile and Slice Before GatherElements (openvi…
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgenya Stepyreva authored Jul 19, 2023
1 parent 53fe969 commit b51069d
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class TRANSFORMATIONS_API EliminateSplit;
class TRANSFORMATIONS_API EliminateSplitConcat;
class TRANSFORMATIONS_API EliminateSqueeze;
class TRANSFORMATIONS_API EliminateTranspose;
class TRANSFORMATIONS_API EliminateNopBroadcast;
class TRANSFORMATIONS_API NopSliceBeforeGatherElements;
class TRANSFORMATIONS_API NopElimination;

} // namespace pass
Expand Down Expand Up @@ -130,3 +132,25 @@ class ov::pass::EliminateSplitConcat : public ov::pass::MatcherPass {
OPENVINO_RTTI("EliminateSplitConcat", "0");
EliminateSplitConcat();
};

/**
* @ingroup ie_transformation_comm on_api
* @brief EliminateNopBroadcast eliminates broadcast or tile with all ones on the second input
*/
class ov::pass::EliminateNopBroadcast : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("EliminateNopBroadcast", "0");
EliminateNopBroadcast();
};

/**
* @ingroup ie_transformation_comm on_api
* @brief NopSliceBeforeGatherElements eliminates slice before GElements if slicing from 0
* It is valid since GatherElements doesn't support negative indices and Slice won't affect
* indexing of elements in the original tensor that GatherElements would like to take
*/
class ov::pass::NopSliceBeforeGatherElements : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("NopSliceBeforeGatherElements", "0");
NopSliceBeforeGatherElements();
};
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ TRANSFORMATIONS_API bool is_dequantization_subgraph(const Output<Node>& node);
TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr<Node>& eltwise,
const Output<Node>& constant,
const Output<Node>& non_constant_input);
TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int64_t& v);
} // namespace util
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,47 @@ pass::EliminateScatterUpdate::EliminateScatterUpdate() {
this->register_matcher(m, callback);
}

ov::pass::EliminateNopBroadcast::EliminateNopBroadcast() {
MATCHER_SCOPE(EliminateNopBroadcast);
auto root = pattern::wrap_type<op::v1::Broadcast, op::v3::Broadcast, op::v0::Tile>(
pattern::op::as_value_predicate([](std::shared_ptr<Node> node) {
auto input_rank = node->get_input_partial_shape(0).rank();
auto output_rank = node->get_output_partial_shape(0).rank();
return input_rank.is_static() && output_rank.is_static() && input_rank == output_rank;
}));

ov::matcher_pass_callback matcher_pass_callback = [](pattern::Matcher& m) {
const auto& op = m.get_match_root();
if (op::util::is_constant_and_all_values_equal_int(op->input_value(1), 1))
return replace_output_update_name(op->output(0), op->input_value(0));
return false;
};

auto m = std::make_shared<pattern::Matcher>(root, matcher_name);
register_matcher(m, matcher_pass_callback);
}

ov::pass::NopSliceBeforeGatherElements::NopSliceBeforeGatherElements() {
MATCHER_SCOPE(NopSliceBeforeGatherElements);
auto slice = pattern::wrap_type<op::v8::Slice>();
auto gather = pattern::wrap_type<op::v6::GatherElements>({slice, pattern::any_input()});

ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_node = m.get_pattern_map();
const auto& slice_node = pattern_to_node.at(slice);
bool start_from_zero = op::util::is_constant_and_all_values_equal_int(slice_node->input_value(1), 0);
bool step_is_one = op::util::is_constant_and_all_values_equal_int(slice_node->input_value(3), 1);
if (!start_from_zero || !step_is_one)
return false;
const auto& gather_node = pattern_to_node.at(gather);
gather_node->input(0).replace_source_output(slice_node->input_value(0));
return true;
};

auto m = std::make_shared<pattern::Matcher>(gather, matcher_name);
register_matcher(m, matcher_pass_callback);
}

ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
// shape-agnostic transformations
ADD_MATCHER_FOR_THIS(EliminatePad)
Expand All @@ -807,6 +848,8 @@ ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
ADD_MATCHER_FOR_THIS(EliminateSqueeze)
ADD_MATCHER_FOR_THIS(EliminateUnsqueeze)
ADD_MATCHER_FOR_THIS(EliminateBroadcast)
ADD_MATCHER_FOR_THIS(EliminateNopBroadcast)
ADD_MATCHER_FOR_THIS(NopSliceBeforeGatherElements)
ADD_MATCHER_FOR_THIS(EliminateGather)
}
}
13 changes: 13 additions & 0 deletions src/common/transformations/src/transformations/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <functional>
#include <memory>
#include <ngraph/op/util/op_annotations.hpp>
#include <openvino/core/validation_util.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/gather.hpp>
Expand Down Expand Up @@ -353,6 +354,18 @@ float cast_eps_to_float(double eps_d) {
return eps_f;
}

bool is_constant_and_all_values_equal_int(const Output<Node>& output, const int64_t& v) {
OPENVINO_SUPPRESS_DEPRECATED_START
if (const auto& constant = ov::get_constant_from_source(output)) {
OPENVINO_SUPPRESS_DEPRECATED_END
const auto& values = constant->cast_vector<int64_t>();
return std::all_of(values.begin(), values.end(), [&](const int64_t& i) {
return i == v;
});
}
return false;
}

} // namespace util
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -1338,3 +1338,87 @@ TEST(nop_elimination, gather_to_squeeze) {
run_and_check(func_axis_2);
run_and_check(func_axis_3);
}

TEST_F(TransformationTestsF, Nopv1Broadcast) {
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto broadcast_shape = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1});
auto broadcast = std::make_shared<op::v1::Broadcast>(data, broadcast_shape);
auto relu = std::make_shared<op::v0::Relu>(broadcast);
auto result = std::make_shared<opset10::Result>(relu);
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
manager.register_pass<ov::pass::EliminateNopBroadcast>();
}
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto relu = std::make_shared<op::v0::Relu>(data);
auto result = std::make_shared<opset10::Result>(relu);
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
}
}

TEST_F(TransformationTestsF, Nopv3Broadcast) {
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto broadcast_shape = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1});
auto broadcast = std::make_shared<op::v3::Broadcast>(data, broadcast_shape);
auto relu = std::make_shared<op::v0::Relu>(broadcast);
auto result = std::make_shared<opset10::Result>(relu);
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
manager.register_pass<ov::pass::EliminateNopBroadcast>();
}
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto relu = std::make_shared<op::v0::Relu>(data);
auto result = std::make_shared<opset10::Result>(relu);
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
}
}

TEST_F(TransformationTestsF, NopTile) {
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto repeats = opset10::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1});
auto tile = std::make_shared<op::v0::Tile>(data, repeats);
auto relu = std::make_shared<op::v0::Relu>(tile);
auto result = std::make_shared<opset10::Result>(relu);
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
manager.register_pass<ov::pass::EliminateNopBroadcast>();
}
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto relu = std::make_shared<op::v0::Relu>(data);
auto result = std::make_shared<opset10::Result>(relu);
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data});
}
}

TEST_F(TransformationTestsF, NopSliceBeforeGatherElements) {
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});

auto start = opset10::Constant::create(element::i32, Shape{1}, {0});
auto stop = opset10::Constant::create(element::i32, Shape{1}, {2});
auto step = opset10::Constant::create(element::i32, Shape{1}, {1});
auto axis = opset10::Constant::create(element::i32, Shape{1}, {-1});
auto slice = std::make_shared<op::v8::Slice>(data, start, stop, step, axis);

auto indices = std::make_shared<opset10::Parameter>(element::i64, PartialShape{-1, -1, -1, -1});
auto gather_elements = std::make_shared<op::v6::GatherElements>(slice, indices, 2);

auto relu = std::make_shared<op::v0::Relu>(gather_elements);
auto result = std::make_shared<opset10::Result>(relu);
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data, indices});
manager.register_pass<ov::pass::NopSliceBeforeGatherElements>();
}
{
auto data = std::make_shared<opset10::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto indices = std::make_shared<opset10::Parameter>(element::i64, PartialShape{-1, -1, -1, -1});

auto gather_elements = std::make_shared<op::v6::GatherElements>(data, indices, 2);

auto relu = std::make_shared<op::v0::Relu>(gather_elements);
auto result = std::make_shared<opset10::Result>(relu);
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{data, indices});
}
}

0 comments on commit b51069d

Please sign in to comment.