Skip to content

Commit c1d132f

Browse files
committed
Only use reciprocal functions on device
1 parent 0e5f524 commit c1d132f

File tree

3 files changed

+28
-14
lines changed

3 files changed

+28
-14
lines changed

include/kernel_float/meta.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
namespace kernel_float {
77

8+
using size_t = decltype(sizeof(int));
9+
810
template<size_t... Is>
911
struct index_sequence {
1012
static constexpr size_t size = sizeof...(Is);

include/kernel_float/unops.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,18 +170,12 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(log10)
170170
KERNEL_FLOAT_DEFINE_UNARY_MATH(log1p)
171171

172172
KERNEL_FLOAT_DEFINE_UNARY_MATH(erf)
173-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfinv)
174173
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfc)
175-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcx)
176-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcinv)
177-
KERNEL_FLOAT_DEFINE_UNARY_MATH(normcdf)
178174
KERNEL_FLOAT_DEFINE_UNARY_MATH(lgamma)
179175
KERNEL_FLOAT_DEFINE_UNARY_MATH(tgamma)
180176

181177
KERNEL_FLOAT_DEFINE_UNARY_MATH(sqrt)
182-
KERNEL_FLOAT_DEFINE_UNARY_MATH(rsqrt)
183178
KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt)
184-
KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt)
185179

186180
KERNEL_FLOAT_DEFINE_UNARY_MATH(abs)
187181
KERNEL_FLOAT_DEFINE_UNARY_MATH(floor)
@@ -190,6 +184,13 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil)
190184
KERNEL_FLOAT_DEFINE_UNARY_MATH(trunc)
191185
KERNEL_FLOAT_DEFINE_UNARY_MATH(rint)
192186

187+
#if KERNEL_FLOAT_IS_DEVICE
188+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfinv)
189+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcx)
190+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcinv)
191+
KERNEL_FLOAT_DEFINE_UNARY_MATH(normcdf)
192+
#endif
193+
193194
// These are not supported on HIP
194195
#if !KERNEL_FLOAT_IS_HIP
195196
KERNEL_FLOAT_DEFINE_UNARY_MATH(isnan)
@@ -200,8 +201,12 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(isfinite)
200201
// CUDA offers special reciprocal functions (rcp), but only on the device.
201202
#if KERNEL_FLOAT_IS_DEVICE
202203
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcp, __drcp_rn(input), __frcp_rn(input))
204+
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rsqrt, ::rsqrt(input), ::rsqrtf(input))
205+
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcbrt, ::rcbrt(input), ::rcbrtf(input))
203206
#else
204207
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcp, 1.0 / input, 1.0f / input)
208+
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rsqrt, 1.0 / ::sqrt(input), 1.0f / ::sqrtf(input))
209+
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcbrt, 1.0 / ::cbrt(input), 1.0f / ::cbrtf(input))
205210
#endif
206211

207212
KERNEL_FLOAT_DEFINE_UNARY_FUN(rcp)

single_include/kernel_float.h

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2025-09-15 16:14:44.345265
20-
// git hash: e824b62e2e7d40e70322cae48a0b652fbec3803c
19+
// date: 2025-10-14 16:18:28.846436
20+
// git hash: 0e5f52493c7b7027921243e813c434b6cd55e42b
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -122,6 +122,8 @@
122122

123123
namespace kernel_float {
124124

125+
using size_t = decltype(sizeof(int));
126+
125127
template<size_t... Is>
126128
struct index_sequence {
127129
static constexpr size_t size = sizeof...(Is);
@@ -1381,18 +1383,12 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(log10)
13811383
KERNEL_FLOAT_DEFINE_UNARY_MATH(log1p)
13821384

13831385
KERNEL_FLOAT_DEFINE_UNARY_MATH(erf)
1384-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfinv)
13851386
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfc)
1386-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcx)
1387-
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcinv)
1388-
KERNEL_FLOAT_DEFINE_UNARY_MATH(normcdf)
13891387
KERNEL_FLOAT_DEFINE_UNARY_MATH(lgamma)
13901388
KERNEL_FLOAT_DEFINE_UNARY_MATH(tgamma)
13911389

13921390
KERNEL_FLOAT_DEFINE_UNARY_MATH(sqrt)
1393-
KERNEL_FLOAT_DEFINE_UNARY_MATH(rsqrt)
13941391
KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt)
1395-
KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt)
13961392

13971393
KERNEL_FLOAT_DEFINE_UNARY_MATH(abs)
13981394
KERNEL_FLOAT_DEFINE_UNARY_MATH(floor)
@@ -1401,6 +1397,13 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil)
14011397
KERNEL_FLOAT_DEFINE_UNARY_MATH(trunc)
14021398
KERNEL_FLOAT_DEFINE_UNARY_MATH(rint)
14031399

1400+
#if KERNEL_FLOAT_IS_DEVICE
1401+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfinv)
1402+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcx)
1403+
KERNEL_FLOAT_DEFINE_UNARY_MATH(erfcinv)
1404+
KERNEL_FLOAT_DEFINE_UNARY_MATH(normcdf)
1405+
#endif
1406+
14041407
// These are not supported on HIP
14051408
#if !KERNEL_FLOAT_IS_HIP
14061409
KERNEL_FLOAT_DEFINE_UNARY_MATH(isnan)
@@ -1411,8 +1414,12 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(isfinite)
14111414
// CUDA offers special reciprocal functions (rcp), but only on the device.
14121415
#if KERNEL_FLOAT_IS_DEVICE
14131416
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcp, __drcp_rn(input), __frcp_rn(input))
1417+
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rsqrt, ::rsqrt(input), ::rsqrtf(input))
1418+
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcbrt, ::rcbrt(input), ::rcbrtf(input))
14141419
#else
14151420
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcp, 1.0 / input, 1.0f / input)
1421+
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rsqrt, 1.0 / ::sqrt(input), 1.0f / ::sqrtf(input))
1422+
KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcbrt, 1.0 / ::cbrt(input), 1.0f / ::cbrtf(input))
14161423
#endif
14171424

14181425
KERNEL_FLOAT_DEFINE_UNARY_FUN(rcp)

0 commit comments

Comments
 (0)