From d093c7465f8ebd8ab4f4d586bb9f4ffd3aaeef08 Mon Sep 17 00:00:00 2001 From: Tomasz Jankowski Date: Tue, 5 Mar 2024 13:24:47 +0100 Subject: [PATCH] [core] Optimize ScatterElementsUpdate reference implementation binary size (#23146) ### Details: - Adds tests into `ov_template_func_tests` for ScatterElementsUpdate version 12. - Removes Indices type from template parameters for internal template function - all works on `int64_t`. - Uses `std::memcpy` instead of type dependent assignment. ### Tickets: - CVS-119213 --- .../reference/scatter_elements_update.hpp | 109 ++++---- .../op_reference/scatter_elements_update.cpp | 246 ++++++++++++++++-- 2 files changed, 284 insertions(+), 71 deletions(-) diff --git a/src/core/reference/include/openvino/reference/scatter_elements_update.hpp b/src/core/reference/include/openvino/reference/scatter_elements_update.hpp index 6e10914a5f5311..1cf8cd62a97196 100644 --- a/src/core/reference/include/openvino/reference/scatter_elements_update.hpp +++ b/src/core/reference/include/openvino/reference/scatter_elements_update.hpp @@ -11,6 +11,7 @@ #include "openvino/core/except.hpp" #include "openvino/core/shape.hpp" #include "openvino/op/scatter_elements_update.hpp" +#include "openvino/reference/utils/coordinate_index.hpp" #include "openvino/reference/utils/coordinate_transform.hpp" namespace ov { @@ -26,43 +27,16 @@ size_t normalize_index(const T idx, const size_t dim_value) { } } -template -void scatter_elem_update_with_reduction(const DataType* input_data, - const IndicesType* indices, - const DataType* updates, - const int64_t axis, - DataType* out_buf, - const Shape& data_shape, - const Shape& indices_shape, - const ov::op::v12::ScatterElementsUpdate::Reduction reduction_type, - const bool use_init_val); - -template -void scatter_elem_update(const DataType* input_data, - const IndicesType* indices, - const DataType* updates, - const int64_t axis, - DataType* out_buf, - const Shape& data_shape, - const Shape& indices_shape, - const Reduction reduction_type = Reduction::NONE, - const bool use_init_val = true) { - // Copy inputs to out - std::memcpy(out_buf, input_data, sizeof(DataType) * shape_size(data_shape)); - - if (reduction_type != Reduction::NONE) { - scatter_elem_update_with_reduction(input_data, - indices, - updates, - axis, - out_buf, - data_shape, - indices_shape, - reduction_type, - use_init_val); - return; - } - +namespace { +void scatter_elem_update_no_reduction(const size_t data_elem_size, + const int64_t* indices, + const char* updates, + const int64_t axis, + char* out_buf, + const Shape& data_shape, + const Shape& indices_shape, + const Reduction reduction_type, + const bool use_init_val) { // 3D example // output[indices[i][j][k]][j][k] = updates[i][j][k] if axis = 0, // output[i][indices[i][j][k]][k] = updates[i][j][k] if axis = 1, @@ -78,10 +52,11 @@ void scatter_elem_update(const DataType* input_data, std::inner_product(indices_cord.begin(), indices_cord.end(), indices_strides.begin(), uint64_t(0)); Coordinate out_cord(indices_cord); out_cord.at(axis) = normalize_index(indices[indices_idx], data_shape[axis]); - const auto out_idx = std::inner_product(out_cord.begin(), out_cord.end(), data_strides.begin(), uint64_t(0)); - out_buf[out_idx] = updates[indices_idx]; + const size_t out_idx = ov::coordinate_offset(out_cord, data_strides); + std::memcpy(out_buf + out_idx * data_elem_size, updates + indices_idx * data_elem_size, data_elem_size); } } +} // namespace template T reduction_neutral_value(const Reduction reduction_type) { @@ -97,7 +72,6 @@ T reduction_neutral_value(const Reduction reduction_type) { return T{0}; default: OPENVINO_THROW("Neutral value not available for this type of reduction"); - return 0; } } @@ -119,7 +93,6 @@ std::function reduction_functor_for(const Reduction reducti return std::plus{}; default: OPENVINO_THROW("No functor available for this type of reduction"); - return 0; } } @@ -144,7 +117,6 @@ std::function reduction_functor_for(const Re }; default: OPENVINO_THROW("No functor available for this type of reduction"); - return 0; } } @@ -180,9 +152,8 @@ struct RoundingDirectionGuard { decltype(std::fegetround()) m_original_mode; }; -template -void scatter_elem_update_with_reduction(const DataType* input_data, - const IndicesType* indices, +template +void scatter_elem_update_with_reduction(const int64_t* indices, const DataType* updates, const int64_t axis, DataType* out_buf, @@ -247,5 +218,53 @@ void scatter_elem_update_with_reduction(const DataType* input_data, } } } + +template +const OutType* convert_indices(const InType* indices, const size_t indices_count, std::vector& buffer) { + if (std::is_same::type, OutType>::value) + return reinterpret_cast(indices); + + buffer.resize(indices_count); + for (auto i = indices_count; i-- > 0;) + buffer[i] = indices[i]; + return buffer.data(); +} + +template +void scatter_elem_update(const DataType* input_data, + const IndicesType* indices, + const DataType* updates, + const int64_t axis, + DataType* out_buf, + const Shape& data_shape, + const Shape& indices_shape, + const Reduction reduction_type = Reduction::NONE, + const bool use_init_val = true) { + std::memcpy(out_buf, input_data, sizeof(DataType) * shape_size(data_shape)); + + std::vector buffer; + const auto indices_i64 = convert_indices(indices, shape_size(indices_shape), buffer); + + if (reduction_type != Reduction::NONE) { + scatter_elem_update_with_reduction(indices_i64, + updates, + axis, + out_buf, + data_shape, + indices_shape, + reduction_type, + use_init_val); + } else { + scatter_elem_update_no_reduction(sizeof(DataType), + indices_i64, + reinterpret_cast(updates), + axis, + reinterpret_cast(out_buf), + data_shape, + indices_shape, + reduction_type, + use_init_val); + } +} } // namespace reference } // namespace ov diff --git a/src/plugins/template/tests/functional/op_reference/scatter_elements_update.cpp b/src/plugins/template/tests/functional/op_reference/scatter_elements_update.cpp index 48db3024d0b45f..ca908f55f5f7a1 100644 --- a/src/plugins/template/tests/functional/op_reference/scatter_elements_update.cpp +++ b/src/plugins/template/tests/functional/op_reference/scatter_elements_update.cpp @@ -12,36 +12,44 @@ using namespace reference_tests; using namespace ov; namespace { +using Reduction = ov::op::v12::ScatterElementsUpdate::Reduction; + struct ScatterElementsUpdateParams { - ScatterElementsUpdateParams(const reference_tests::Tensor& paramData, - const reference_tests::Tensor& paramIndices, - const reference_tests::Tensor& paramUpdates, - const reference_tests::Tensor& paramAxis, - const reference_tests::Tensor& paramExpected) - : input(paramData), - indices(paramIndices), - updates(paramUpdates), - axis(paramAxis), - expected(paramExpected) {} - - reference_tests::Tensor input; - reference_tests::Tensor indices; - reference_tests::Tensor updates; - reference_tests::Tensor axis; - reference_tests::Tensor expected; + ScatterElementsUpdateParams(reference_tests::Tensor paramData, + reference_tests::Tensor paramIndices, + reference_tests::Tensor paramUpdates, + reference_tests::Tensor paramAxis, + reference_tests::Tensor paramExpected, + const Reduction paramReduction = Reduction::NONE, + const bool paramUseInitValue = true) + : input{std::move(paramData)}, + indices{std::move(paramIndices)}, + updates{std::move(paramUpdates)}, + axis{std::move(paramAxis)}, + expected{std::move(paramExpected)}, + reduction{paramReduction}, + use_init_value{paramUseInitValue} {} + + const reference_tests::Tensor input; + const reference_tests::Tensor indices; + const reference_tests::Tensor updates; + const reference_tests::Tensor axis; + const reference_tests::Tensor expected; + const Reduction reduction; + const bool use_init_value; }; -class ReferenceScatterElementsUpdateLayerTest : public testing::TestWithParam, - public CommonReferenceTest { +class ReferenceScatterElementsUpdateV3LayerTest : public testing::TestWithParam, + public CommonReferenceTest { public: void SetUp() override { - auto params = GetParam(); + const auto& params = GetParam(); function = CreateFunction(params); inputData = {params.input.data, params.indices.data, params.updates.data, params.axis.data}; refOutData = {params.expected.data}; } static std::string getTestCaseName(const testing::TestParamInfo& obj) { - auto param = obj.param; + const auto& param = obj.param; std::ostringstream result; result << "data_sh=" << param.input.shape; result << "_data_pr=" << param.input.type; @@ -65,7 +73,54 @@ class ReferenceScatterElementsUpdateLayerTest : public testing::TestWithParam, + public CommonReferenceTest { +public: + void SetUp() override { + const auto& params = GetParam(); + function = CreateFunction(params); + inputData = {params.input.data, params.indices.data, params.updates.data, params.axis.data}; + refOutData = {params.expected.data}; + } + + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + static std::map reduction_as_string = { + {Reduction::NONE, "none"}, + {Reduction::SUM, "sum"}, + {Reduction::PROD, "prod"}, + {Reduction::MIN, "min"}, + {Reduction::MAX, "max"}, + {Reduction::MEAN, "mean"}, + }; + const auto& param = obj.param; + std::ostringstream result; + result << ReferenceScatterElementsUpdateV3LayerTest::getTestCaseName(obj); + result << "_reduction=" << reduction_as_string[param.reduction]; + result << "_use_init_value=" << std::boolalpha << param.use_init_value; + return result.str(); + } + +private: + static std::shared_ptr CreateFunction(const ScatterElementsUpdateParams& params) { + const auto data = std::make_shared(params.input.type, params.input.shape); + const auto indices = std::make_shared(params.indices.type, params.indices.shape); + const auto updates = std::make_shared(params.updates.type, params.updates.shape); + const auto axis = std::make_shared(params.axis.type, params.axis.shape); + auto scatter_eu = std::make_shared(data, + indices, + updates, + axis, + params.reduction, + params.use_init_value); + return std::make_shared(NodeVector{scatter_eu}, ParameterVector{data, indices, updates, axis}); + } +}; + +TEST_P(ReferenceScatterElementsUpdateV3LayerTest, CompareWithHardcodedRefs) { + Exec(); +} + +TEST_P(ReferenceScatterElementsUpdateV12LayerTest, CompareWithHardcodedRefs) { Exec(); } @@ -159,13 +214,152 @@ std::vector generateScatterCombinedParams() { generateScatterParams(), }; std::vector combinedParams; - for (const auto& params : scatterTypeParams) { - combinedParams.insert(combinedParams.end(), params.begin(), params.end()); + for (const auto& param : scatterTypeParams) { + std::move(param.begin(), param.end(), std::back_inserter(combinedParams)); } return combinedParams; } -INSTANTIATE_TEST_SUITE_P(smoke_ScatterEltsUpdate_With_Hardcoded_Refs, - ReferenceScatterElementsUpdateLayerTest, + +template ::value>::type* = nullptr> +Indices_t norm(int i, int d) { + return static_cast(i); +} +template ::value>::type* = nullptr> +Indices_t norm(int i, int d) { + return static_cast(i < 0 ? i + d : i); +} + +template +std::vector generate_scatter_eu_v12_params() { + using Data_t = typename element_type_traits::value_type; + using Indices_t = typename element_type_traits::value_type; + return { + {{Shape{3, 2}, element::Type(DATA_ET), std::vector{11, 12, 13, 14, 15, 16}}, // data + {Shape{1, 2}, element::Type(INDICES_ET), std::vector{norm(-1, 3), 1}}, // indices + {Shape{1, 2}, element::Type(DATA_ET), std::vector{5, 24}}, // updates + {Shape{1}, element::Type(INDICES_ET), std::vector{0}}, // axis + {Shape{3, 2}, element::Type(DATA_ET), std::vector{11, 12, 13, 24, 15, 16}}, // expected + Reduction::MAX, + true}, + {{Shape{2, 3}, element::Type(DATA_ET), std::vector{11, 12, 13, 14, 15, 16}}, + {Shape{2, 2}, element::Type(INDICES_ET), std::vector{norm(-3, 3), 1, 0, 2}}, + {Shape{2, 2}, element::Type(DATA_ET), std::vector{1, 22, 24, 6}}, + {Shape{1}, element::Type(INDICES_ET), std::vector{1}}, + {Shape{2, 3}, element::Type(DATA_ET), std::vector{1, 22, 13, 24, 15, 6}}, + Reduction::MIN, + false}, + {{Shape{1, 2, 3}, element::Type(DATA_ET), std::vector{11, 12, 13, 14, 15, 16}}, + {Shape{1, 1, 4}, element::Type(INDICES_ET), std::vector{0, 1, 0, 2}}, + {Shape{1, 1, 4}, element::Type(DATA_ET), std::vector{23, 38, 32, 7}}, + {Shape{1}, element::Type(INDICES_ET), std::vector{2}}, + {Shape{1, 2, 3}, element::Type(DATA_ET), std::vector{22, 25, 10, 14, 15, 16}}, + Reduction::MEAN, + true}, + {{Shape{1, 2, 3}, element::Type(DATA_ET), std::vector{11, 12, 13, 14, 15, 16}}, + {Shape{1, 1, 4}, element::Type(INDICES_ET), std::vector{0, 1, 0, 0}}, + {Shape{1, 1, 4}, element::Type(DATA_ET), std::vector{20, 33, 26, 29}}, + {Shape{1}, element::Type(INDICES_ET), std::vector{2}}, + {Shape{1, 2, 3}, element::Type(DATA_ET), std::vector{25, 33, 13, 14, 15, 16}}, + Reduction::MEAN, + false}, + {{Shape{2, 2, 1}, element::Type(DATA_ET), std::vector{1, 2, 3, 4}}, + {Shape{1, 5, 1}, element::Type(INDICES_ET), std::vector{0, 0, 1, 1, 1}}, + {Shape{1, 5, 1}, element::Type(DATA_ET), std::vector{50, 51, 10, 20, 30}}, + {Shape{1}, element::Type(INDICES_ET), std::vector{1}}, + {Shape{2, 2, 1}, element::Type(DATA_ET), std::vector{101, 60, 3, 4}}, + Reduction::SUM, + false}, + {{Shape{3, 2}, element::Type(DATA_ET), std::vector{1, 2, 3, 4, 5, 6}}, + {Shape{4, 1}, element::Type(INDICES_ET), std::vector{0, 0, 1, 2}}, + {Shape{4, 1}, element::Type(DATA_ET), std::vector{7, 7, 10, 5}}, + {Shape{1}, element::Type(INDICES_ET), std::vector{0}}, + {Shape{3, 2}, element::Type(DATA_ET), std::vector{49, 2, 30, 4, 25, 6}}, + Reduction::PROD, + true}, + }; +} + +std::vector collect_scatter_eu_v12_params() { + const std::vector> params{ + // i16 + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + // i32 + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + // i64 + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + // u32 + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + // u64 + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + // f16 + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + // f32 + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + generate_scatter_eu_v12_params(), + }; + + auto combined_params = generateScatterCombinedParams(); + for (const auto& param : params) { + std::move(param.begin(), param.end(), std::back_inserter(combined_params)); + } + return combined_params; +} + +INSTANTIATE_TEST_SUITE_P(smoke_ScatterElementsUpdate, + ReferenceScatterElementsUpdateV3LayerTest, ::testing::ValuesIn(generateScatterCombinedParams()), - ReferenceScatterElementsUpdateLayerTest::getTestCaseName); + ReferenceScatterElementsUpdateV3LayerTest::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_ScatterElementsUpdate, + ReferenceScatterElementsUpdateV12LayerTest, + ::testing::ValuesIn(collect_scatter_eu_v12_params()), + ReferenceScatterElementsUpdateV12LayerTest::getTestCaseName); } // namespace