Skip to content

Commit

Permalink
[core] Optimize ScatterElementsUpdate reference implementation binary…
Browse files Browse the repository at this point in the history
… size (openvinotoolkit#23146)

### Details:
- Adds tests into `ov_template_func_tests` for ScatterElementsUpdate
version 12.
- Removes Indices type from template parameters for internal template
function - all works on `int64_t`.
 - Uses `std::memcpy` instead of type dependent assignment.

### Tickets:
 - CVS-119213
  • Loading branch information
t-jankowski authored Mar 5, 2024
1 parent 4f5c2a7 commit d093c74
Show file tree
Hide file tree
Showing 2 changed files with 284 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "openvino/core/except.hpp"
#include "openvino/core/shape.hpp"
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/reference/utils/coordinate_index.hpp"
#include "openvino/reference/utils/coordinate_transform.hpp"

namespace ov {
Expand All @@ -26,43 +27,16 @@ size_t normalize_index(const T idx, const size_t dim_value) {
}
}

template <typename DataType, typename IndicesType>
void scatter_elem_update_with_reduction(const DataType* input_data,
const IndicesType* indices,
const DataType* updates,
const int64_t axis,
DataType* out_buf,
const Shape& data_shape,
const Shape& indices_shape,
const ov::op::v12::ScatterElementsUpdate::Reduction reduction_type,
const bool use_init_val);

template <typename DataType, typename IndicesType>
void scatter_elem_update(const DataType* input_data,
const IndicesType* indices,
const DataType* updates,
const int64_t axis,
DataType* out_buf,
const Shape& data_shape,
const Shape& indices_shape,
const Reduction reduction_type = Reduction::NONE,
const bool use_init_val = true) {
// Copy inputs to out
std::memcpy(out_buf, input_data, sizeof(DataType) * shape_size(data_shape));

if (reduction_type != Reduction::NONE) {
scatter_elem_update_with_reduction(input_data,
indices,
updates,
axis,
out_buf,
data_shape,
indices_shape,
reduction_type,
use_init_val);
return;
}

namespace {
void scatter_elem_update_no_reduction(const size_t data_elem_size,
const int64_t* indices,
const char* updates,
const int64_t axis,
char* out_buf,
const Shape& data_shape,
const Shape& indices_shape,
const Reduction reduction_type,
const bool use_init_val) {
// 3D example
// output[indices[i][j][k]][j][k] = updates[i][j][k] if axis = 0,
// output[i][indices[i][j][k]][k] = updates[i][j][k] if axis = 1,
Expand All @@ -78,10 +52,11 @@ void scatter_elem_update(const DataType* input_data,
std::inner_product(indices_cord.begin(), indices_cord.end(), indices_strides.begin(), uint64_t(0));
Coordinate out_cord(indices_cord);
out_cord.at(axis) = normalize_index(indices[indices_idx], data_shape[axis]);
const auto out_idx = std::inner_product(out_cord.begin(), out_cord.end(), data_strides.begin(), uint64_t(0));
out_buf[out_idx] = updates[indices_idx];
const size_t out_idx = ov::coordinate_offset(out_cord, data_strides);
std::memcpy(out_buf + out_idx * data_elem_size, updates + indices_idx * data_elem_size, data_elem_size);
}
}
} // namespace

template <typename T>
T reduction_neutral_value(const Reduction reduction_type) {
Expand All @@ -97,7 +72,6 @@ T reduction_neutral_value(const Reduction reduction_type) {
return T{0};
default:
OPENVINO_THROW("Neutral value not available for this type of reduction");
return 0;
}
}

Expand All @@ -119,7 +93,6 @@ std::function<T(const T, const T)> reduction_functor_for(const Reduction reducti
return std::plus<T>{};
default:
OPENVINO_THROW("No functor available for this type of reduction");
return 0;
}
}

Expand All @@ -144,7 +117,6 @@ std::function<char(const char, const char)> reduction_functor_for<char>(const Re
};
default:
OPENVINO_THROW("No functor available for this type of reduction");
return 0;
}
}

Expand Down Expand Up @@ -180,9 +152,8 @@ struct RoundingDirectionGuard {
decltype(std::fegetround()) m_original_mode;
};

template <typename DataType, typename IndicesType>
void scatter_elem_update_with_reduction(const DataType* input_data,
const IndicesType* indices,
template <typename DataType>
void scatter_elem_update_with_reduction(const int64_t* indices,
const DataType* updates,
const int64_t axis,
DataType* out_buf,
Expand Down Expand Up @@ -247,5 +218,53 @@ void scatter_elem_update_with_reduction(const DataType* input_data,
}
}
}

template <typename InType, typename OutType>
const OutType* convert_indices(const InType* indices, const size_t indices_count, std::vector<OutType>& buffer) {
if (std::is_same<typename std::decay<InType>::type, OutType>::value)
return reinterpret_cast<const OutType*>(indices);

buffer.resize(indices_count);
for (auto i = indices_count; i-- > 0;)
buffer[i] = indices[i];
return buffer.data();
}

template <typename DataType, typename IndicesType>
void scatter_elem_update(const DataType* input_data,
const IndicesType* indices,
const DataType* updates,
const int64_t axis,
DataType* out_buf,
const Shape& data_shape,
const Shape& indices_shape,
const Reduction reduction_type = Reduction::NONE,
const bool use_init_val = true) {
std::memcpy(out_buf, input_data, sizeof(DataType) * shape_size(data_shape));

std::vector<int64_t> buffer;
const auto indices_i64 = convert_indices(indices, shape_size(indices_shape), buffer);

if (reduction_type != Reduction::NONE) {
scatter_elem_update_with_reduction(indices_i64,
updates,
axis,
out_buf,
data_shape,
indices_shape,
reduction_type,
use_init_val);
} else {
scatter_elem_update_no_reduction(sizeof(DataType),
indices_i64,
reinterpret_cast<const char*>(updates),
axis,
reinterpret_cast<char*>(out_buf),
data_shape,
indices_shape,
reduction_type,
use_init_val);
}
}
} // namespace reference
} // namespace ov
Loading

0 comments on commit d093c74

Please sign in to comment.