Skip to content

Commit

Permalink
Updated lib.cpp.general.bincount to use reduction rather than atomics.
Browse files Browse the repository at this point in the history
  • Loading branch information
carljohnsen committed Oct 7, 2024
1 parent a0bfc2b commit 0f04b3c
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions src/lib/cpp/include/general.hh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "datatypes.hh"
#include "boilerplate.hh"

#include <omp.h>

namespace NS {

/**
Expand All @@ -21,15 +23,37 @@ namespace NS {
*
* @param src The input array containing the non-negative integers.
* @param dst The output array containing the counts.
* @tparam T The datatype of the input array.
* @tparam U The datatype of the output array (the counts).
*/
inline void bincount(const input_ndarray<uint64_t> &src, output_ndarray<uint64_t> &dst) {
template <typename T, typename U>
inline void bincount(const input_ndarray<T> &src, output_ndarray<U> &dst) {
UNPACK_NUMPY(src);
UNPACK_NUMPY(dst);

PRAGMA(PARALLEL_TERM)
for (int64_t flat_index = 0; flat_index < src_length; flat_index++) {
ATOMIC()
dst.data[src.data[flat_index]]++;
const T *src_data = src.data;
U *dst_data = dst.data;

U *local_dsts[omp_get_max_threads()];

#pragma omp parallel
{
U *local_dst = (U *) calloc(dst_length, sizeof(U));
local_dsts[omp_get_thread_num()] = local_dst;

#pragma omp for schedule(static)
for (int64_t flat_index = 0; flat_index < src_length; flat_index++) {
local_dst[src_data[flat_index]]++;
}
}

for (int64_t i = 0; i < dst_length; i++) {
for (int64_t j = 0; j < omp_get_max_threads(); j++) {
dst_data[i] += local_dsts[j][i];
}
}

for (int64_t i = 0; i < omp_get_max_threads(); i++) {
free(local_dsts[i]);
}
}

Expand Down

0 comments on commit 0f04b3c

Please sign in to comment.