Skip to content

Commit

Permalink
Apply review feedback
Browse files Browse the repository at this point in the history
Cleaned up order of annotations and constexpr, removed __device__.
Moved to absolute includes.
Added _CCCL_NODISCARD to all functions in the detail namespace.
Added a C++17 guard (#if _CCCL_STD_VER >= 2017).
Changed header guards to a new format applicable after I move some files
in a future change.
A couple of _LIBCUDACXX_UNREACHABLE and other fixes
  • Loading branch information
pciolkosz committed May 17, 2024
1 parent 84e9a83 commit f6e09b0
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 105 deletions.
26 changes: 14 additions & 12 deletions cudax/include/cuda/experimental/detail/dimensions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX_DETAIL_DIMENSIONS
#define _CUDAX_DETAIL_DIMENSIONS
#ifndef _CUDAX__HIERARCHY_DIMENSIONS
#define _CUDAX__HIERARCHY_DIMENSIONS

#include <cuda/std/mdspan>

#if _CCCL_STD_VER >= 2017
namespace cuda::experimental
{

Expand Down Expand Up @@ -66,7 +67,7 @@ struct hierarchy_query_result : public dimensions<T, Extents...>
const T y = Dims::rank() > 1 ? Dims::extent(1) : 1;
const T z = Dims::rank() > 2 ? Dims::extent(2) : 1;

// Conversion operator to dim3, usable from host and device code.
// Builds a dim3 from the x, y and z members, casting each one to uint32_t.
constexpr _CCCL_HOST_DEVICE operator dim3() const
_CCCL_HOST_DEVICE constexpr operator dim3() const
{
// Each member is explicitly narrowed/converted to uint32_t before being handed to dim3.
return dim3(static_cast<uint32_t>(x), static_cast<uint32_t>(y), static_cast<uint32_t>(z));
}
Expand All @@ -75,7 +76,7 @@ struct hierarchy_query_result : public dimensions<T, Extents...>
namespace detail
{
template <typename OpType>
_CCCL_HOST_DEVICE constexpr size_t merge_extents(size_t e1, size_t e2)
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr size_t merge_extents(size_t e1, size_t e2)
{
if (e1 == ::cuda::std::dynamic_extent || e2 == ::cuda::std::dynamic_extent)
{
Expand All @@ -89,7 +90,7 @@ _CCCL_HOST_DEVICE constexpr size_t merge_extents(size_t e1, size_t e2)
}

template <typename DstType, typename OpType, typename T1, size_t... Extents1, typename T2, size_t... Extents2>
_CCCL_HOST_DEVICE constexpr auto
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto
dims_op(const OpType& op, const dimensions<T1, Extents1...>& h1, const dimensions<T2, Extents2...>& h2) noexcept
{
// For now target only 3 dim extents
Expand All @@ -101,26 +102,26 @@ dims_op(const OpType& op, const dimensions<T1, Extents1...>& h1, const dimension
}

// Combines two dimensions objects using multiplication.
//
// Forwards h1 and h2 to dims_op<DstType> together with a ::cuda::std::multiplies
// functor and returns whatever dims_op produces.
// NOTE(review): presumably the functor is applied to corresponding extents of h1
// and h2 and DstType is the index type of the result — confirm against dims_op.
template <typename DstType, typename T1, size_t... Extents1, typename T2, size_t... Extents2>
_CCCL_HOST_DEVICE constexpr auto
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto
dims_product(const dimensions<T1, Extents1...>& h1, const dimensions<T2, Extents2...>& h2) noexcept
{
return dims_op<DstType>(::cuda::std::multiplies(), h1, h2);
}

// Combines two dimensions objects using addition.
//
// Forwards h1 and h2 to dims_op<DstType> together with a ::cuda::std::plus
// functor and returns whatever dims_op produces.
// NOTE(review): presumably the functor is applied to corresponding extents of h1
// and h2 and DstType is the index type of the result — confirm against dims_op.
template <typename DstType, typename T1, size_t... Extents1, typename T2, size_t... Extents2>
_CCCL_HOST_DEVICE constexpr auto
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto
dims_sum(const dimensions<T1, Extents1...>& h1, const dimensions<T2, Extents2...>& h2) noexcept
{
return dims_op<DstType>(::cuda::std::plus(), h1, h2);
}

// Wraps a dimensions<T, Extents...> value in a hierarchy_query_result with the
// same index type and extents, by constructing the result type from `result`.
template <typename T, size_t... Extents>
_CCCL_HOST_DEVICE constexpr auto convert_to_query_result(const dimensions<T, Extents...>& result)
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto convert_to_query_result(const dimensions<T, Extents...>& result)
{
return hierarchy_query_result<T, Extents...>(result);
}

_CCCL_HOST_DEVICE constexpr auto dim3_to_dims(const dim3& dims)
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto dim3_to_dims(const dim3& dims)
{
return dimensions<dimensions_index_type,
::cuda::std::dynamic_extent,
Expand All @@ -129,13 +130,14 @@ _CCCL_HOST_DEVICE constexpr auto dim3_to_dims(const dim3& dims)
}

// Flattens a 3-component index into a single linear offset within `dims`.
//
// The result is
//   (index.extent(2) * dims.extent(1) + index.extent(1)) * dims.extent(0) + index.extent(0)
// so component 0 varies fastest and component 2 slowest; dims.extent(2) is not
// needed by the formula. Only the rank of Dims is checked at compile time.
template <typename TyTrunc, typename Index, typename Dims>
__device__ constexpr auto index_to_linear(const Index& index, const Dims& dims)
_CCCL_NODISCARD _CCCL_HOST_DEVICE constexpr auto index_to_linear(const Index& index, const Dims& dims)
{
// The formula below is written for exactly three dimensions.
static_assert(Dims::rank() == 3);

// The second form of the return converts index.extent(2) to TyTrunc before the
// first multiplication.
// NOTE(review): presumably this is meant to steer the type the arithmetic is
// carried out in — confirm against the callers' choice of TyTrunc.
return (index.extent(2) * dims.extent(1) + index.extent(1)) * dims.extent(0) + index.extent(0);
return (static_cast<TyTrunc>(index.extent(2)) * dims.extent(1) + index.extent(1)) * dims.extent(0) + index.extent(0);
}

} // namespace detail
} // namespace cuda::experimental
#endif
#endif // _CCCL_STD_VER >= 2017
#endif // _CUDAX__HIERARCHY_DIMENSIONS
Loading

0 comments on commit f6e09b0

Please sign in to comment.