Skip to content

Commit

Permalink
Use the float flavors of the cmath functions in the extended floating…
Browse files Browse the repository at this point in the history
… point fallbacks (#2106)

Fixes #2078
  • Loading branch information
miscco authored Jul 31, 2024
1 parent a2a3824 commit bddcd20
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
18 changes: 9 additions & 9 deletions libcudacxx/include/cuda/std/__cuda/cmath_nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,47 @@ _LIBCUDACXX_BEGIN_NAMESPACE_STD
// trigonometric functions
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sin(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __float2bfloat16(::sin(__bfloat162float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __float2bfloat16(::sinf(__bfloat162float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sinh(__nv_bfloat16 __v)
{
return __float2bfloat16(::sinh(__bfloat162float(__v)));
return __float2bfloat16(::sinhf(__bfloat162float(__v)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cos(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __float2bfloat16(::cos(__bfloat162float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __float2bfloat16(::cosf(__bfloat162float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cosh(__nv_bfloat16 __v)
{
return __float2bfloat16(::cosh(__bfloat162float(__v)));
return __float2bfloat16(::coshf(__bfloat162float(__v)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 exp(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __float2bfloat16(::exp(__bfloat162float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __float2bfloat16(::expf(__bfloat162float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 hypot(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
return __float2bfloat16(::hypot(__bfloat162float(__x), __bfloat162float(__y)));
return __float2bfloat16(::hypotf(__bfloat162float(__x), __bfloat162float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 atan2(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
return __float2bfloat16(::atan2(__bfloat162float(__x), __bfloat162float(__y)));
return __float2bfloat16(::atan2f(__bfloat162float(__x), __bfloat162float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 log(__nv_bfloat16 __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __float2bfloat16(::log(__bfloat162float(__x)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __float2bfloat16(::logf(__bfloat162float(__x)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sqrt(__nv_bfloat16 __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2bfloat16(::sqrt(__bfloat162float(__x)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2bfloat16(::sqrtf(__bfloat162float(__x)));))
}

// floating point helper
Expand Down
18 changes: 9 additions & 9 deletions libcudacxx/include/cuda/std/__cuda/cmath_nvfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half sin(__half __v)
{
NV_IF_ELSE_TARGET(NV_PROVIDES_SM_53, (return ::hsin(__v);), ({
float __vf = __half2float(__v);
__vf = ::sin(__vf);
__vf = ::sinf(__vf);
__half_raw __ret_repr = ::__float2half_rn(__vf);

uint16_t __repr = __half_raw(__v).x;
Expand All @@ -61,7 +61,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half sin(__half __v)

inline _LIBCUDACXX_INLINE_VISIBILITY __half sinh(__half __v)
{
return __float2half(::sinh(__half2float(__v)));
return __float2half(::sinhf(__half2float(__v)));
}

// clang-format off
Expand All @@ -72,7 +72,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half cos(__half __v)
), (
{
float __vf = __half2float(__v);
__vf = ::cos(__vf);
__vf = ::cosf(__vf);
__half_raw __ret_repr = ::__float2half_rn(__vf);

uint16_t __repr = __half_raw(__v).x;
Expand All @@ -94,7 +94,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half cos(__half __v)

inline _LIBCUDACXX_INLINE_VISIBILITY __half cosh(__half __v)
{
return __float2half(::cosh(__half2float(__v)));
return __float2half(::coshf(__half2float(__v)));
}

// clang-format off
Expand All @@ -105,7 +105,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half exp(__half __v)
), (
{
float __vf = __half2float(__v);
__vf = ::exp(__vf);
__vf = ::expf(__vf);
__half_raw __ret_repr = ::__float2half_rn(__vf);

uint16_t __repr = __half_raw(__v).x;
Expand All @@ -127,12 +127,12 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half exp(__half __v)

inline _LIBCUDACXX_INLINE_VISIBILITY __half hypot(__half __x, __half __y)
{
return __float2half(::hypot(__half2float(__x), __half2float(__y)));
return __float2half(::hypotf(__half2float(__x), __half2float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __half atan2(__half __x, __half __y)
{
return __float2half(::atan2(__half2float(__x), __half2float(__y)));
return __float2half(::atan2f(__half2float(__x), __half2float(__y)));
}

// clang-format off
Expand All @@ -143,7 +143,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half log(__half __x)
), (
{
float __vf = __half2float(__x);
__vf = ::log(__vf);
__vf = ::logf(__vf);
__half_raw __ret_repr = ::__float2half_rn(__vf);

uint16_t __repr = __half_raw(__x).x;
Expand All @@ -164,7 +164,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half log(__half __x)

inline _LIBCUDACXX_INLINE_VISIBILITY __half sqrt(__half __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2half(::sqrt(__half2float(__x)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2half(::sqrtf(__half2float(__x)));))
}

// floating point helper
Expand Down

0 comments on commit bddcd20

Please sign in to comment.