diff --git a/cub/test/bfloat16.h b/cub/test/bfloat16.h index dbd735db83..328fb644a5 100644 --- a/cub/test/bfloat16.h +++ b/cub/test/bfloat16.h @@ -32,11 +32,14 @@ * Utilities for interacting with the opaque CUDA __nv_bfloat16 type */ -#include +#include + #include -#include -#include +#include + +#include +#include #ifdef __GNUC__ // There's a ton of type-punning going on in this file. @@ -77,6 +80,16 @@ struct bfloat16_t *this = bfloat16_t(float(a)); } + /// Constructor from unsigned long long int + template < typename T, + typename = typename ::cuda::std::enable_if< + ::cuda::std::is_same::value + && (!::cuda::std::is_same::value)>::type> + __host__ __device__ __forceinline__ bfloat16_t(T a) + { + *this = bfloat16_t(float(a)); + } + /// Default constructor bfloat16_t() = default; diff --git a/cub/test/c2h/generators.cu b/cub/test/c2h/generators.cu index 9e0f318811..67bf81e558 100644 --- a/cub/test/c2h/generators.cu +++ b/cub/test/c2h/generators.cu @@ -259,16 +259,17 @@ void generator_t::operator()(seed_t seed, thrust::device_vector &data, T min, template struct count_to_item_t { - std::size_t n; + unsigned long long int n; - count_to_item_t(std::size_t n) + count_to_item_t(unsigned long long int n) : n(n) {} template __device__ T operator()(CounterT id) { - return static_cast(static_cast(id) % n); + // This has to be a type for which extended floating point types like __nv_fp8_e5m2 provide an overload + return static_cast(static_cast(id) % n); } }; diff --git a/cub/test/half.h b/cub/test/half.h index a009049cc7..74e507a57c 100644 --- a/cub/test/half.h +++ b/cub/test/half.h @@ -37,6 +37,8 @@ #include +#include + #include #include #include @@ -80,6 +82,16 @@ struct half_t *this = half_t(float(a)); } + /// Constructor from unsigned long long int + template < typename T, + typename = typename ::cuda::std::enable_if< + ::cuda::std::is_same::value + && (!::cuda::std::is_same::value)>::type> + __host__ __device__ __forceinline__ half_t(T a) + { + *this = half_t(float(a)); + } + /// Default constructor half_t() = default;