Skip to content

Commit

Permalink
alpaka::AtomicCas add floating point support
Browse files Browse the repository at this point in the history
  • Loading branch information
psychocoderHPC committed Jul 27, 2022
1 parent 7df99cd commit 883e629
Showing 1 changed file with 39 additions and 4 deletions.
43 changes: 39 additions & 4 deletions include/alpaka/atomic/Op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include <alpaka/core/Common.hpp>

#include <algorithm>
#include <type_traits>
#include <variant>


namespace alpaka
{
Expand Down Expand Up @@ -112,7 +115,7 @@ namespace alpaka
{
auto const old = *addr;
auto& ref = *addr;
ref = ((old >= value) ? 0 : static_cast<T>(old + 1));
ref = ((old >= value) ? static_cast<T>(0) : static_cast<T>(old + static_cast<T>(1)));
return old;
}
};
Expand All @@ -128,7 +131,7 @@ namespace alpaka
{
auto const old = *addr;
auto& ref = *addr;
ref = (((old == 0) || (old > value)) ? value : static_cast<T>(old - 1));
ref = (((old == static_cast<T>(0)) || (old > value)) ? value : static_cast<T>(old - static_cast<T>(1)));
return old;
}
};
Expand Down Expand Up @@ -177,9 +180,14 @@ namespace alpaka
//! The compare and swap function object.
struct AtomicCas
{
//! \return The old value of addr.
ALPAKA_NO_HOST_ACC_WARNING
// allow reinterpreting any type up to 64bit in size for bitwise interpretation.
template<typename T>
using reinterpret = std::variant<T, std::conditional_t<sizeof(T) == 4u, unsigned int, unsigned long long>>;

//! AtomicCas for non floating point values
// \return The old value of addr.
ALPAKA_NO_HOST_ACC_WARNING
template<typename T, std::enable_if_t<!std::is_floating_point_v<T>, bool> = true>
ALPAKA_FN_HOST_ACC auto operator()(T* addr, T const& compare, T const& value) const -> T
{
auto const old = *addr;
Expand All @@ -191,9 +199,36 @@ namespace alpaka
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wstrict-overflow"
#endif
// check if values are bit-wise equal
ref = ((old == compare) ? value : old);
#if BOOST_COMP_GNUC && (BOOST_COMP_GNUC == BOOST_VERSION_NUMBER(7, 4, 0))
# pragma GCC diagnostic pop
#endif
return old;
}
//! AtomicCas for floating point values
// \return The old value of addr.
ALPAKA_NO_HOST_ACC_WARNING
template<typename T, std::enable_if_t<std::is_floating_point_v<T>, bool> = true>
ALPAKA_FN_HOST_ACC auto operator()(T* addr, T const& compare, T const& value) const -> T
{
static_assert(sizeof(T) == 4u || sizeof(T) == 8u, "AtomicCas is supporting only 32bit and 64bit values!");

auto const old = *addr;
auto& ref = *addr;

// gcc-7.4.0 assumes for an optimization that a signed overflow does not occur here.
// That's fine, so ignore that warning.
#if BOOST_COMP_GNUC && (BOOST_COMP_GNUC == BOOST_VERSION_NUMBER(7, 4, 0))
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wstrict-overflow"
#endif
// check if values are bit-wise equal
auto const o = std::get<1>(reinterpret<T>(std::in_place_index<1>, old));
auto const c = std::get<1>(reinterpret<T>(std::in_place_index<1>, compare));
ref = ((o == c) ? value : old);
#if BOOST_COMP_GNUC && (BOOST_COMP_GNUC == BOOST_VERSION_NUMBER(7, 4, 0))
# pragma GCC diagnostic pop
#endif
return old;
}
Expand Down

0 comments on commit 883e629

Please sign in to comment.