diff --git a/src/lib/cpp/include/general.hh b/src/lib/cpp/include/general.hh index a4990e4..ee1accd 100644 --- a/src/lib/cpp/include/general.hh +++ b/src/lib/cpp/include/general.hh @@ -13,6 +13,8 @@ #include "datatypes.hh" #include "boilerplate.hh" +#include + namespace NS { /** @@ -21,15 +23,37 @@ namespace NS { * * @param src The input array containing the non-negative integers. * @param dst The output array containing the counts. + * @tparam T The datatype of the input array. */ - inline void bincount(const input_ndarray &src, output_ndarray &dst) { + template + inline void bincount(const input_ndarray &src, output_ndarray &dst) { UNPACK_NUMPY(src); UNPACK_NUMPY(dst); - PRAGMA(PARALLEL_TERM) - for (int64_t flat_index = 0; flat_index < src_length; flat_index++) { - ATOMIC() - dst.data[src.data[flat_index]]++; + const T *src_data = src.data; + U *dst_data = dst.data; + + U *local_dsts[omp_get_max_threads()]; + + #pragma omp parallel + { + U *local_dst = (U *) calloc(dst_length, sizeof(U)); + local_dsts[omp_get_thread_num()] = local_dst; + + #pragma omp for schedule(static) + for (int64_t flat_index = 0; flat_index < src_length; flat_index++) { + local_dst[src_data[flat_index]]++; + } + } + + for (int64_t i = 0; i < dst_length; i++) { + for (int64_t j = 0; j < omp_get_max_threads(); j++) { + dst_data[i] += local_dsts[j][i]; + } + } + + for (int64_t i = 0; i < omp_get_max_threads(); i++) { + free(local_dsts[i]); } }