Skip to content

Commit

Permalink
Implementation of correlate
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKalistratov committed Nov 19, 2024
1 parent b0dc412 commit 0d6a4e0
Show file tree
Hide file tree
Showing 27 changed files with 1,732 additions and 478 deletions.
1 change: 0 additions & 1 deletion dpnp/backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ set(DPNP_SRC
kernels/dpnp_krnl_mathematical.cpp
kernels/dpnp_krnl_random.cpp
kernels/dpnp_krnl_sorting.cpp
kernels/dpnp_krnl_statistics.cpp
src/constants.cpp
src/dpnp_iface_fptr.cpp
src/memory_sycl.cpp
Expand Down
5 changes: 4 additions & 1 deletion dpnp/backend/extensions/statistics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@

set(python_module_name _statistics_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/validation_utils.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
Expand Down
98 changes: 98 additions & 0 deletions dpnp/backend/extensions/statistics/dispatch_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,32 @@ using DTypePair = std::pair<DType, DType>;
using SupportedDTypeList = std::vector<DType>;
using SupportedDTypeList2 = std::vector<DTypePair>;

template <typename FnT,
typename SupportedTypes,
template <typename>
typename Func>
struct TableBuilder
{
template <typename _FnT, typename T>
struct impl
{
static constexpr bool is_defined = one_of_v<T, SupportedTypes>;

_FnT get()
{
if constexpr (is_defined) {
return Func<T>::impl;
}
else {
return nullptr;
}
}
};

using type =
dpctl_td_ns::DispatchVectorBuilder<FnT, impl, dpctl_td_ns::num_types>;
};

template <typename FnT,
typename SupportedTypes,
template <typename, typename>
Expand Down Expand Up @@ -124,6 +150,78 @@ struct TableBuilder2
dpctl_td_ns::DispatchTableBuilder<FnT, impl, dpctl_td_ns::num_types>;
};

template <typename FnT>
class DispatchTable
{
public:
DispatchTable(std::string name) : name(name) {}

template <typename SupportedTypes, template <typename> typename Func>
void populate_dispatch_table()
{
using TBulder = typename TableBuilder<FnT, SupportedTypes, Func>::type;
TBulder builder;

builder.populate_dispatch_vector(table);
populate_supported_types();
}

FnT get_unsafe(int _typenum) const
{
auto array_types = dpctl_td_ns::usm_ndarray_types();
const int type_id = array_types.typenum_to_lookup_id(_typenum);

return table[type_id];
}

FnT get(int _typenum) const
{
auto fn = get_unsafe(_typenum);

if (fn == nullptr) {
auto array_types = dpctl_td_ns::usm_ndarray_types();
const int _type_id = array_types.typenum_to_lookup_id(_typenum);

py::dtype _dtype = dtype_from_typenum(_type_id);
auto _type_pos = std::find(supported_types.begin(),
supported_types.end(), _dtype);
if (_type_pos == supported_types.end()) {
py::str types = py::str(py::cast(supported_types));
py::str dtype = py::str(_dtype);

py::str err_msg =
py::str("'" + name + "' has unsupported type '") + dtype +
py::str("'."
" Supported types are: ") +
types;

throw py::value_error(static_cast<std::string>(err_msg));
}
}

return fn;
}

const SupportedDTypeList &get_all_supported_types() const
{
return supported_types;
}

private:
void populate_supported_types()
{
for (int i = 0; i < dpctl_td_ns::num_types; ++i) {
if (table[i] != nullptr) {
supported_types.emplace_back(dtype_from_typenum(i));
}
}
}

std::string name;
SupportedDTypeList supported_types;
Table<FnT> table;
};

template <typename FnT>
class DispatchTable2
{
Expand Down
140 changes: 40 additions & 100 deletions dpnp/backend/extensions/statistics/histogram_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@

#include "histogram_common.hpp"

#include "validation_utils.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
using dpctl::tensor::usm_ndarray;
using dpctl_td_ns::typenum_t;
Expand All @@ -46,6 +48,15 @@ namespace statistics
{
using common::CeilDiv;

using validation::array_names;
using validation::array_ptr;

using validation::check_max_dims;
using validation::check_num_dims;
using validation::check_size_at_least;
using validation::common_checks;
using validation::name_of;

namespace histogram
{

Expand All @@ -55,11 +66,9 @@ void validate(const usm_ndarray &sample,
const usm_ndarray &histogram)
{
auto exec_q = sample.get_queue();
using array_ptr = const usm_ndarray *;

std::vector<array_ptr> arrays{&sample, &histogram};
std::unordered_map<array_ptr, std::string> names = {
{arrays[0], "sample"}, {arrays[1], "histogram"}};
array_names names = {{arrays[0], "sample"}, {arrays[1], "histogram"}};

array_ptr bins_ptr = nullptr;

Expand All @@ -77,117 +86,48 @@ void validate(const usm_ndarray &sample,
names.insert({weights_ptr, "weights"});
}

auto get_name = [&](const array_ptr &arr) {
auto name_it = names.find(arr);
assert(name_it != names.end());

return "'" + name_it->second + "'";
};

dpctl::tensor::validation::CheckWritable::throw_if_not_writable(histogram);

auto unequal_queue =
std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
return arr->get_queue() != exec_q;
});

if (unequal_queue != arrays.cend()) {
throw py::value_error(
get_name(*unequal_queue) +
" parameter has incompatible queue with parameter " +
get_name(&sample));
}

auto non_contig_array =
std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
return !arr->is_c_contiguous();
});
common_checks({&sample, bins.has_value() ? &bins.value() : nullptr,
weights.has_value() ? &weights.value() : nullptr},
{&histogram}, names);

if (non_contig_array != arrays.cend()) {
throw py::value_error(get_name(*non_contig_array) +
" parameter is not c-contiguos");
}
check_size_at_least(bins_ptr, 2, names);

auto check_overlaping = [&](const array_ptr &first,
const array_ptr &second) {
if (first == nullptr || second == nullptr) {
return;
}

const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();

if (overlap(*first, *second)) {
throw py::value_error(get_name(first) +
" has overlapping memory segments with " +
get_name(second));
}
};

check_overlaping(&sample, &histogram);
check_overlaping(bins_ptr, &histogram);
check_overlaping(weights_ptr, &histogram);

if (bins_ptr && bins_ptr->get_size() < 2) {
throw py::value_error(get_name(bins_ptr) +
" parameter must have at least 2 elements");
}

if (histogram.get_size() < 1) {
throw py::value_error(get_name(&histogram) +
" parameter must have at least 1 element");
}

if (histogram.get_ndim() != 1) {
throw py::value_error(get_name(&histogram) +
" parameter must be 1d. Actual " +
std::to_string(histogram.get_ndim()) + "d");
}
check_size_at_least(&histogram, 1, names);
check_num_dims(&histogram, 1, names);

if (weights_ptr) {
if (weights_ptr->get_ndim() != 1) {
throw py::value_error(
get_name(weights_ptr) + " parameter must be 1d. Actual " +
std::to_string(weights_ptr->get_ndim()) + "d");
}
check_num_dims(weights_ptr, 1, names);

auto sample_size = sample.get_size();
auto weights_size = weights_ptr->get_size();
if (sample.get_size() != weights_ptr->get_size()) {
throw py::value_error(
get_name(&sample) + " size (" + std::to_string(sample_size) +
") and " + get_name(weights_ptr) + " size (" +
std::to_string(weights_size) + ")" + " must match");
throw py::value_error(name_of(&sample, names) + " size (" +
std::to_string(sample_size) + ") and " +
name_of(weights_ptr, names) + " size (" +
std::to_string(weights_size) + ")" +
" must match");
}
}

if (sample.get_ndim() > 2) {
throw py::value_error(
get_name(&sample) +
" parameter must have no more than 2 dimensions. Actual " +
std::to_string(sample.get_ndim()) + "d");
}
check_max_dims(&sample, 2, names);

if (sample.get_ndim() == 1) {
if (bins_ptr != nullptr && bins_ptr->get_ndim() != 1) {
throw py::value_error(get_name(&sample) + " parameter is 1d, but " +
get_name(bins_ptr) + " is " +
std::to_string(bins_ptr->get_ndim()) + "d");
}
check_num_dims(bins_ptr, 1, names);
}
else if (sample.get_ndim() == 2) {
auto sample_count = sample.get_shape(0);
auto expected_dims = sample.get_shape(1);

if (bins_ptr != nullptr && bins_ptr->get_ndim() != expected_dims) {
throw py::value_error(get_name(&sample) + " parameter has shape {" +
std::to_string(sample_count) + "x" +
std::to_string(expected_dims) + "}" +
", so " + get_name(bins_ptr) +
" parameter expected to be " +
std::to_string(expected_dims) +
"d. "
"Actual " +
std::to_string(bins->get_ndim()) + "d");
throw py::value_error(
name_of(&sample, names) + " parameter has shape {" +
std::to_string(sample_count) + "x" +
std::to_string(expected_dims) + "}" + ", so " +
name_of(bins_ptr, names) + " parameter expected to be " +
std::to_string(expected_dims) +
"d. "
"Actual " +
std::to_string(bins->get_ndim()) + "d");
}
}

Expand All @@ -199,17 +139,17 @@ void validate(const usm_ndarray &sample,

if (histogram.get_size() != expected_hist_size) {
throw py::value_error(
get_name(&histogram) + " and " + get_name(bins_ptr) +
" shape mismatch. " + get_name(&histogram) +
" expected to have size = " +
name_of(&histogram, names) + " and " +
name_of(bins_ptr, names) + " shape mismatch. " +
name_of(&histogram, names) + " expected to have size = " +
std::to_string(expected_hist_size) + ". Actual " +
std::to_string(histogram.get_size()));
}
}

int64_t max_hist_size = std::numeric_limits<uint32_t>::max() - 1;
if (histogram.get_size() > max_hist_size) {
throw py::value_error(get_name(&histogram) +
throw py::value_error(name_of(&histogram, names) +
" parameter size expected to be less than " +
std::to_string(max_hist_size) + ". Actual " +
std::to_string(histogram.get_size()));
Expand All @@ -225,7 +165,7 @@ void validate(const usm_ndarray &sample,
if (!_64bit_atomics) {
auto device_name = device.get_info<sycl::info::device::name>();
throw py::value_error(
get_name(&histogram) +
name_of(&histogram, names) +
" parameter has 64-bit type, but 64-bit atomics " +
" are not supported for " + device_name);
}
Expand Down
Loading

0 comments on commit 0d6a4e0

Please sign in to comment.