Skip to content

Commit f43dc02

Browse files
fix(unary-cuda): improve acos/reciprocal/tan numeric correctness
1 parent 3e09f2a commit f43dc02

File tree

3 files changed

+39
-40
lines changed

3 files changed

+39
-40
lines changed

src/infiniop/ops/acos/cuda/kernel.cuh

Lines changed: 11 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -7,17 +7,6 @@
77

88
namespace op::acos::cuda {
99

10-
// ----------------------
11-
// Fast acos approximation
12-
// ----------------------
13-
__device__ __forceinline__ float fast_acosf(float x) {
14-
// 高性能多项式近似 acos(x)
15-
float ax = fabsf(x);
16-
float t = sqrtf(1.0f - ax);
17-
float r = ((-0.0187293f * ax + 0.0742610f) * ax - 0.2121144f) * ax + 1.5707288f;
18-
return (x >= 0.0f ? t * r : 3.14159265358979323846f - t * r);
19-
}
20-
2110
// ----------------------
2211
// float kernel (F32)
2312
// ----------------------
@@ -26,39 +15,27 @@ __device__ __forceinline__ T acos_impl(T val);
2615

2716
template <>
2817
__device__ __forceinline__ float acos_impl<float>(float val) {
29-
return fast_acosf(val);
18+
return ::acosf(val);
3019
}
3120

3221
// ----------------------
3322
// half kernel (F16)
3423
// ----------------------
3524
template <>
3625
__device__ __forceinline__ half acos_impl<half>(half val) {
37-
#if (__CUDA_ARCH__ >= 530)
38-
float f = __half2float(val);
39-
return __float2half(fast_acosf(f));
40-
#else
4126
float f = __half2float(val);
42-
return __float2half(fast_acosf(f));
43-
#endif
27+
return __float2half(::acosf(f));
4428
}
4529

4630
// ----------------------
4731
// half2 kernel (F16x2 vectorized)
4832
// ----------------------
4933
template <>
5034
__device__ __forceinline__ half2 acos_impl<half2>(half2 val) {
51-
#if (__CUDA_ARCH__ >= 530)
5235
float2 f = __half22float2(val);
53-
f.x = fast_acosf(f.x);
54-
f.y = fast_acosf(f.y);
36+
f.x = ::acosf(f.x);
37+
f.y = ::acosf(f.y);
5538
return __float22half2_rn(f);
56-
#else
57-
float2 f = __half22float2(val);
58-
f.x = fast_acosf(f.x);
59-
f.y = fast_acosf(f.y);
60-
return __float22half2_rn(f);
61-
#endif
6239
}
6340

6441
// ----------------------
@@ -67,15 +44,20 @@ __device__ __forceinline__ half2 acos_impl<half2>(half2 val) {
6744
template <>
6845
__device__ __forceinline__ cuda_bfloat16 acos_impl<cuda_bfloat16>(cuda_bfloat16 val) {
6946
float f = __bfloat162float(val);
70-
return __float2bfloat16(fast_acosf(f));
47+
return __float2bfloat16(::acosf(f));
48+
}
49+
50+
template <>
51+
__device__ __forceinline__ double acos_impl<double>(double val) {
52+
return ::acos(val);
7153
}
7254

7355
// ----------------------
7456
// Fallback kernel
7557
// ----------------------
7658
template <typename T>
7759
__device__ __forceinline__ T acos_impl(T val) {
78-
return static_cast<T>(fast_acosf(static_cast<float>(val)));
60+
return static_cast<T>(::acos(static_cast<double>(val)));
7961
}
8062

8163
// ----------------------

src/infiniop/ops/reciprocal/cuda/kernel.cuh

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,25 @@
11
#ifndef __RECIPROCAL_CUDA_H__
22
#define __RECIPROCAL_CUDA_H__
33

4+
#include <type_traits>
5+
46
namespace op::reciprocal::cuda {
57
typedef struct ReciprocalOp {
68
public:
79
static constexpr size_t num_inputs = 1;
810
template <typename T>
911
__device__ __forceinline__ T operator()(const T &x) const {
1012
if constexpr (std::is_same_v<T, half2>) {
11-
return h2rcp(x);
13+
float2 vf = __half22float2(x);
14+
vf.x = 1.0f / vf.x;
15+
vf.y = 1.0f / vf.y;
16+
return __float22half2_rn(vf);
1217
} else if constexpr (std::is_same_v<T, half>) {
13-
return hrcp(x);
18+
return __float2half(1.0f / __half2float(x));
1419
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
15-
// bfloat16 does not have a direct hrcp intrinsic in some versions,
16-
// often handled by converting to float or using specific bf16 intrinsics
17-
return __float2bfloat16(1.0f / __bfloat162float(x));
20+
return __float2bfloat16_rn(1.0f / __bfloat162float(x));
1821
} else if constexpr (std::is_same_v<T, float>) {
19-
return __frcp_rd(x);
22+
return 1.0f / x;
2023
} else {
2124
return static_cast<T>(1) / x;
2225
}

src/infiniop/ops/tan/cuda/kernel.cuh

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,26 +1,40 @@
11
#ifndef __TAN_CUDA_H__
22
#define __TAN_CUDA_H__
33

4+
#include <cmath>
5+
#include <type_traits>
6+
47
namespace op::tan::cuda {
58

69
typedef struct TanOp {
710
public:
811
static constexpr size_t num_inputs = 1;
912
template <typename T>
1013
__device__ __forceinline__ T operator()(const T &x) const {
11-
if constexpr (std::is_same_v<T, cuda_bfloat16>) {
14+
if constexpr (std::is_same_v<T, half2>) {
15+
float2 vf = __half22float2(x);
16+
vf.x = ::tanf(vf.x);
17+
vf.y = ::tanf(vf.y);
18+
return __float22half2_rn(vf);
19+
} else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
20+
float f0 = __bfloat162float(__low2bfloat16(x));
21+
float f1 = __bfloat162float(__high2bfloat16(x));
22+
return __floats2bfloat162_rn(::tanf(f0), ::tanf(f1));
23+
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
1224
// BF16
1325
const float x_f = __bfloat162float(x);
14-
return __float2bfloat16(__tanf(x_f));
26+
return __float2bfloat16_rn(::tanf(x_f));
1527
} else if constexpr (std::is_same_v<T, half>) {
1628
// FP16
1729
const float x_f = __half2float(x);
18-
return __float2half(__tanf(x_f));
30+
return __float2half(::tanf(x_f));
1931
} else if constexpr (std::is_same_v<T, float>) {
2032
// FP32
21-
return __tanf(x);
33+
return ::tanf(x);
34+
} else if constexpr (std::is_same_v<T, double>) {
35+
return ::tan(x);
2236
} else {
23-
return __tanf(x);
37+
return static_cast<T>(::tan(static_cast<double>(x)));
2438
}
2539
}
2640
} TanOp;

0 commit comments

Comments
 (0)