Skip to content

Commit

Permalink
Fixes compiler error for extended fp type data gen (#666)
Browse files Browse the repository at this point in the history
* fixes ctor compiler error for extended fp type data gen

* fixes msvc ctor invalid redeclaration
  • Loading branch information
elstehle authored Nov 7, 2023
1 parent a9d7d8e commit 591dc78
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
19 changes: 16 additions & 3 deletions cub/test/bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@
* Utilities for interacting with the opaque CUDA __nv_bfloat16 type
*/

#include <stdint.h>
#include <cub/util_type.cuh>

#include <cuda_bf16.h>
#include <iosfwd>

#include <cub/util_type.cuh>
#include <cuda/std/type_traits>

#include <cstdint>
#include <iosfwd>

#ifdef __GNUC__
// There's a ton of type-punning going on in this file.
Expand Down Expand Up @@ -77,6 +80,16 @@ struct bfloat16_t
*this = bfloat16_t(float(a));
}

/// Converting constructor for `unsigned long long int` values.
///
/// Only enabled when `T` is exactly `unsigned long long int` and `std::size_t` is a
/// different type than `unsigned long long int`. The argument is converted to `float`
/// and the `bfloat16_t` built from that `float` is assigned to `*this`.
///
/// NOTE(review): the `std::size_t` exclusion presumably avoids clashing with another
/// constructor of this struct on platforms where both types coincide -- confirm
/// against the rest of the struct.
template <typename T,
          typename = typename ::cuda::std::enable_if<
            (!::cuda::std::is_same<std::size_t, unsigned long long int>::value)
            && ::cuda::std::is_same<T, unsigned long long int>::value>::type>
__host__ __device__ __forceinline__ bfloat16_t(T a)
{
  // Go through float first, then reuse the conversion from float.
  *this = bfloat16_t(static_cast<float>(a));
}

/// Default constructor (explicitly defaulted, so the compiler generates it)
bfloat16_t() = default;

Expand Down
7 changes: 4 additions & 3 deletions cub/test/c2h/generators.cu
Original file line number Diff line number Diff line change
Expand Up @@ -259,16 +259,17 @@ void generator_t::operator()(seed_t seed, thrust::device_vector<T> &data, T min,
/// Functor that turns a counter value into an item of type `T` by reducing the
/// counter modulo `n` and casting the remainder to `T`.
template <typename T>
struct count_to_item_t
{
  /// Modulus applied to every counter value. It is used as a divisor, so it must be non-zero.
  unsigned long long int n;

  /// Stores the modulus `n` used by the call operator.
  count_to_item_t(unsigned long long int n)
      : n(n)
  {}

  /// Returns `id % n` converted to `T`; `id` is first cast to `unsigned long long int`.
  template <typename CounterT>
  __device__ T operator()(CounterT id)
  {
    // This has to be a type for which extended floating point types like __nv_fp8_e5m2 provide an overload
    return static_cast<T>(static_cast<unsigned long long int>(id) % n);
  }
};

Expand Down
12 changes: 12 additions & 0 deletions cub/test/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

#include <cuda_fp16.h>

#include <cuda/std/type_traits>

#include <cstdint>
#include <cstring>
#include <iosfwd>
Expand Down Expand Up @@ -80,6 +82,16 @@ struct half_t
*this = half_t(float(a));
}

/// Converting constructor for `unsigned long long int` values.
///
/// Only enabled when `T` is exactly `unsigned long long int` and `std::size_t` is a
/// different type than `unsigned long long int`. The argument is converted to `float`
/// and the `half_t` built from that `float` is assigned to `*this`.
///
/// NOTE(review): the `std::size_t` exclusion presumably avoids clashing with another
/// constructor of this struct on platforms where both types coincide -- confirm
/// against the rest of the struct.
template <typename T,
          typename = typename ::cuda::std::enable_if<
            (!::cuda::std::is_same<std::size_t, unsigned long long int>::value)
            && ::cuda::std::is_same<T, unsigned long long int>::value>::type>
__host__ __device__ __forceinline__ half_t(T a)
{
  // Go through float first, then reuse the conversion from float.
  *this = half_t(static_cast<float>(a));
}

/// Default constructor (explicitly defaulted, so the compiler generates it)
half_t() = default;

Expand Down

0 comments on commit 591dc78

Please sign in to comment.