Fix issues that came up when building cuDF against main (#1643)
* Move `vsmem` helpers into their own file

* Add missing includes for `cuda::std::min` and `cuda::std::max` to `functional` until we get `<algorithm>`

* Add missing include of `<cuda_runtime_api.h>` to `cuda_pinned_memory_resource`

* Mark `_CCCL_FORCEINLINE` as inline on host

* Avoid copying output iterators in `thrust::copy_if`

* Try to ensure that `thrust::tuple` and `thrust::pair` work with CTAD (see the sketch after this list)

* Add workaround for MSVC2017
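
For context, a minimal sketch of the kind of usage the CTAD bullet above aims to keep working. It is not part of this commit; it only illustrates class template argument deduction on `thrust::pair` and `thrust::tuple`, assuming a CCCL checkout that includes this fix and a C++17-or-newer compiler:

```cuda
#include <thrust/pair.h>
#include <thrust/tuple.h>

int main()
{
  // With CTAD, the template arguments are deduced from the initializers.
  thrust::pair  p{1, 2.5};        // thrust::pair<int, double>
  thrust::tuple t{1, 2.5f, 'c'};  // thrust::tuple<int, float, char>

  return thrust::get<0>(t) + p.first; // returns 2
}
```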
miscco committed May 6, 2024
1 parent 831f0e9 commit e485ff5
Showing 21 changed files with 529 additions and 276 deletions.
1 change: 1 addition & 0 deletions cub/cub/device/dispatch/dispatch_merge_sort.cuh
@@ -42,6 +42,7 @@
#include <cub/util_device.cuh>
#include <cub/util_math.cuh>
#include <cub/util_namespace.cuh>
#include <cub/util_vsmem.cuh>

#include <thrust/detail/integer_math.h>
#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>
1 change: 1 addition & 0 deletions cub/cub/device/dispatch/dispatch_select_if.cuh
@@ -52,6 +52,7 @@
#include <cub/util_deprecated.cuh>
#include <cub/util_device.cuh>
#include <cub/util_math.cuh>
#include <cub/util_vsmem.cuh>

#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

1 change: 1 addition & 0 deletions cub/cub/device/dispatch/dispatch_unique_by_key.cuh
@@ -48,6 +48,7 @@
#include <cub/util_deprecated.cuh>
#include <cub/util_device.cuh>
#include <cub/util_math.cuh>
#include <cub/util_vsmem.cuh>

#include <iterator>

196 changes: 0 additions & 196 deletions cub/cub/util_device.cuh
@@ -49,12 +49,10 @@

#include <cub/detail/device_synchronize.cuh>
#include <cub/util_debug.cuh>
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
// for backward compatibility
#include <cub/util_temporary_storage.cuh>

#include <cuda/discard_memory>
#include <cuda/std/type_traits>
#include <cuda/std/utility>

@@ -70,7 +68,6 @@ CUB_NAMESPACE_BEGIN

namespace detail
{

/**
* @brief Helper class template that allows overwriting the `BLOCK_THREAD` and `ITEMS_PER_THREAD`
* configurations of a given policy.
@@ -82,199 +79,6 @@ struct policy_wrapper_t : PolicyT
static constexpr int BLOCK_THREADS = BLOCK_THREADS_;
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
};

/**
* @brief Helper struct to wrap all the information needed to implement virtual shared memory that's passed to a kernel.
*
*/
struct vsmem_t
{
void* gmem_ptr;
};

// The maximum amount of static shared memory available per thread block
// Note that in contrast to dynamic shared memory, static shared memory is still limited to 48 KB
static constexpr std::size_t max_smem_per_block = 48 * 1024;

/**
* @brief Class template that helps to prevent exceeding the available shared memory per thread block.
*
* @tparam AgentT The agent for which we check whether per-thread block shared memory is sufficient or whether virtual
* shared memory is needed.
*/
template <typename AgentT>
class vsmem_helper_impl
{
private:
// Per-block virtual shared memory may be padded to make sure vsmem is an integer multiple of `line_size`
static constexpr std::size_t line_size = 128;

// The amount of shared memory or virtual shared memory required by the algorithm's agent
static constexpr std::size_t required_smem = sizeof(typename AgentT::TempStorage);

// Whether we need to allocate global memory-backed virtual shared memory
static constexpr bool needs_vsmem = required_smem > max_smem_per_block;

// Padding bytes to an integer multiple of `line_size`. Only applies to virtual shared memory
static constexpr std::size_t padding_bytes =
(required_smem % line_size == 0) ? 0 : (line_size - (required_smem % line_size));

public:
// Type alias to be used for static temporary storage declaration within the algorithm's kernel
using static_temp_storage_t = cub::detail::conditional_t<needs_vsmem, cub::NullType, typename AgentT::TempStorage>;

// The amount of global memory-backed virtual shared memory needed, padded to an integer multiple of 128 bytes
static constexpr std::size_t vsmem_per_block = needs_vsmem ? (required_smem + padding_bytes) : 0;

/**
* @brief Used from within the device algorithm's kernel to get the temporary storage that can be
* passed to the agent, specialized for the case when we can use native shared memory as temporary
* storage.
*/
static _CCCL_DEVICE _CCCL_FORCEINLINE typename AgentT::TempStorage&
get_temp_storage(typename AgentT::TempStorage& static_temp_storage, vsmem_t&)
{
return static_temp_storage;
}

/**
* @brief Used from within the device algorithm's kernel to get the temporary storage that can be
* passed to the agent, specialized for the case when we can use native shared memory as temporary
* storage and taking a linear block id.
*/
static __device__ __forceinline__ typename AgentT::TempStorage&
get_temp_storage(typename AgentT::TempStorage& static_temp_storage, vsmem_t&, std::size_t)
{
return static_temp_storage;
}

/**
* @brief Used from within the device algorithm's kernel to get the temporary storage that can be
* passed to the agent, specialized for the case when we have to use global memory-backed
* virtual shared memory as temporary storage.
*/
static _CCCL_DEVICE _CCCL_FORCEINLINE typename AgentT::TempStorage&
get_temp_storage(cub::NullType& static_temp_storage, vsmem_t& vsmem)
{
return *reinterpret_cast<typename AgentT::TempStorage*>(
static_cast<char*>(vsmem.gmem_ptr) + (vsmem_per_block * blockIdx.x));
}

/**
* @brief Used from within the device algorithm's kernel to get the temporary storage that can be
* passed to the agent, specialized for the case when we have to use global memory-backed
* virtual shared memory as temporary storage and taking a linear block id.
*/
static __device__ __forceinline__ typename AgentT::TempStorage&
get_temp_storage(cub::NullType& static_temp_storage, vsmem_t& vsmem, std::size_t linear_block_id)
{
return *reinterpret_cast<typename AgentT::TempStorage*>(
static_cast<char*>(vsmem.gmem_ptr) + (vsmem_per_block * linear_block_id));
}

/**
* @brief Hints to discard modified cache lines of the used virtual shared memory.
*
* @note Needs to be followed by `__syncthreads()` if the function returns true and the virtual shared memory is
* supposed to be reused after this function call.
*/
template <bool needs_vsmem_ = needs_vsmem, typename ::cuda::std::enable_if<!needs_vsmem_, int>::type = 0>
static _CCCL_DEVICE _CCCL_FORCEINLINE bool discard_temp_storage(typename AgentT::TempStorage& temp_storage)
{
return false;
}

/**
* @brief Hints to discard modified cache lines of the used virtual shared memory.
*
* @note Needs to be followed by `__syncthreads()` if the function returns true and the virtual shared memory is
* supposed to be reused after this function call.
*/
template <bool needs_vsmem_ = needs_vsmem, typename ::cuda::std::enable_if<needs_vsmem_, int>::type = 0>
static _CCCL_DEVICE _CCCL_FORCEINLINE bool discard_temp_storage(typename AgentT::TempStorage& temp_storage)
{
// Ensure all threads finished using temporary storage
CTA_SYNC();

const std::size_t linear_tid = threadIdx.x;
const std::size_t block_stride = line_size * blockDim.x;

char* ptr = reinterpret_cast<char*>(&temp_storage);
auto ptr_end = ptr + vsmem_per_block;

// 128 byte-aligned virtual shared memory discard
for (auto thread_ptr = ptr + (linear_tid * line_size); thread_ptr < ptr_end; thread_ptr += block_stride)
{
cuda::discard_memory(thread_ptr, line_size);
}

return true;
}
};

template <class DefaultAgentT, class FallbackAgentT>
constexpr bool use_fallback_agent()
{
return (sizeof(typename DefaultAgentT::TempStorage) > max_smem_per_block)
&& (sizeof(typename FallbackAgentT::TempStorage) <= max_smem_per_block);
}

/**
* @brief Class template that helps to prevent exceeding the available shared memory per thread block with two measures:
* (1) If an agent's `TempStorage` declaration exceeds the maximum amount of shared memory per thread block, we check
* whether using a fallback policy, e.g., with a smaller tile size, would fit into shared memory.
* (2) If the fallback still doesn't fit into shared memory, we make use of virtual shared memory that is backed by
* global memory.
*
* @tparam DefaultAgentPolicyT The default tuning policy that is used if the default agent's shared memory requirements
* fall within the bounds of `max_smem_per_block` or when virtual shared memory is needed
* @tparam DefaultAgentT The default agent, instantiated with the given default tuning policy
* @tparam FallbackAgentPolicyT A fallback tuning policy that may exhibit lower shared memory requirements, e.g., by
* using a smaller tile size, than the default. This fallback policy is used if and only if the shared memory
* requirements of the default agent exceed `max_smem_per_block`, yet the shared memory requirements of the fallback
* agent falls within the bounds of `max_smem_per_block`.
* @tparam FallbackAgentT The fallback agent, instantiated with the given fallback tuning policy
*/
template <typename DefaultAgentPolicyT,
typename DefaultAgentT,
typename FallbackAgentPolicyT = DefaultAgentPolicyT,
typename FallbackAgentT = DefaultAgentT,
bool UseFallbackPolicy = use_fallback_agent<DefaultAgentT, FallbackAgentT>()>
struct vsmem_helper_with_fallback_impl : public vsmem_helper_impl<DefaultAgentT>
{
using agent_t = DefaultAgentT;
using agent_policy_t = DefaultAgentPolicyT;
};
template <typename DefaultAgentPolicyT, typename DefaultAgentT, typename FallbackAgentPolicyT, typename FallbackAgentT>
struct vsmem_helper_with_fallback_impl<DefaultAgentPolicyT, DefaultAgentT, FallbackAgentPolicyT, FallbackAgentT, true>
: public vsmem_helper_impl<FallbackAgentT>
{
using agent_t = FallbackAgentT;
using agent_policy_t = FallbackAgentPolicyT;
};

/**
* @brief Alias template for the `vsmem_helper_with_fallback_impl` that instantiates the given AgentT template with the
* respective policy as first template parameter, followed by the parameters captured by the `AgentParamsT` template
* parameter pack.
*/
template <typename DefaultPolicyT, typename FallbackPolicyT, template <typename...> class AgentT, typename... AgentParamsT>
using vsmem_helper_fallback_policy_t =
vsmem_helper_with_fallback_impl<DefaultPolicyT,
AgentT<DefaultPolicyT, AgentParamsT...>,
FallbackPolicyT,
AgentT<FallbackPolicyT, AgentParamsT...>>;

/**
* @brief Alias template for the `vsmem_helper_t` by using a simple fallback policy that uses `DefaultPolicyT` as basis,
* overwriting `64` threads per block and `1` item per thread.
*/
template <typename DefaultPolicyT, template <typename...> class AgentT, typename... AgentParamsT>
using vsmem_helper_default_fallback_policy_t =
vsmem_helper_fallback_policy_t<DefaultPolicyT, policy_wrapper_t<DefaultPolicyT, 64, 1>, AgentT, AgentParamsT...>;

} // namespace detail

/**
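
For readers tracking the move of the `vsmem` helpers out of `util_device.cuh` and into `cub/util_vsmem.cuh`, here is a minimal sketch, not part of this diff, of how a dispatch-layer kernel typically consumes the helper shown above. The kernel, the agent type, its constructor arguments, and its `Process()` entry point are hypothetical placeholders; only the `vsmem_t` / `vsmem_helper_impl` calls mirror the interface in the removed code:

```cuda
#include <cub/util_vsmem.cuh> // new home of vsmem_t / vsmem_helper_impl after this commit

// Hypothetical kernel illustrating the virtual-shared-memory protocol.
template <typename AgentT, typename... Args>
__global__ void example_kernel(cub::detail::vsmem_t vsmem, Args... args)
{
  using vsmem_helper_t = cub::detail::vsmem_helper_impl<AgentT>;

  // Declared as AgentT::TempStorage when it fits into shared memory,
  // otherwise as cub::NullType so the storage comes from global memory.
  __shared__ typename vsmem_helper_t::static_temp_storage_t static_temp_storage;

  // Selects either the shared-memory storage or this block's slice of the
  // global-memory-backed ("virtual") shared memory.
  auto& temp_storage = vsmem_helper_t::get_temp_storage(static_temp_storage, vsmem);

  AgentT agent(temp_storage, args...); // hypothetical agent constructor
  agent.Process();                     // hypothetical agent entry point

  // When virtual shared memory was used, hint that its cache lines may be discarded.
  vsmem_helper_t::discard_temp_storage(temp_storage);
}
```

On the host side, the dispatch layer would allocate `vsmem_helper_t::vsmem_per_block * grid_size` bytes of global memory whenever `vsmem_per_block` is non-zero and pass that allocation to the kernel through `vsmem_t`.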