Skip to content

Commit f0f74b8

Browse files
committed
Add view_ptr that allows viewing a pointer U* through a different data type T
1 parent 4d0d49c commit f0f74b8

File tree

5 files changed

+129
-62
lines changed

5 files changed

+129
-62
lines changed

include/kernel_float/approx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ KERNEL_FLOAT_DEVICE half2_t normalize_trig_input(half2_t x) {
127127
static constexpr double ONE_OVER_TWOPI = 0.15915494309189535;
128128
static constexpr double OFFSET = -2042.0;
129129

130+
// ws = (x / 2pi) - ((x / 2pi + OFFSET) - OFFSET)
130131
half2_t ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET);
131132
return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws);
132133
}

include/kernel_float/binops.h

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,14 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
118118
template<typename L, typename R, typename C = promote_t<L, vector_value_type<R>>, typename E> \
119119
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, vector<L, E>, R> operator OP( \
120120
const vector<L, E>& left, \
121-
const R& right) { \
121+
const R \
122+
& right) { \
122123
return zip_common(ops::NAME<C> {}, left, right); \
123124
} \
124125
template<typename L, typename R, typename C = promote_t<vector_value_type<L>, R>, typename E> \
125126
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, vector<R, E>> operator OP( \
126-
const L& left, \
127+
const L \
128+
& left, \
127129
const vector<R, E>& right) { \
128130
return zip_common(ops::NAME<C> {}, left, right); \
129131
}
@@ -164,16 +166,16 @@ static constexpr bool is_vector_assign_allowed =
164166
>;
165167
// clang-format on
166168

167-
#define KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(NAME, OP) \
168-
template< \
169-
typename T, \
170-
typename E, \
171-
typename R, \
172-
typename = enable_if_t<is_vector_assign_allowed<ops::NAME, T, E, R>>> \
173-
KERNEL_FLOAT_INLINE vector<T, E>& operator OP(vector<T, E>& lhs, const R& rhs) { \
174-
using F = ops::NAME<T>; \
175-
lhs = zip_common(F {}, lhs, rhs); \
176-
return lhs; \
169+
#define KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(NAME, OP) \
170+
template< \
171+
typename T, \
172+
typename E, \
173+
typename R, \
174+
typename = enable_if_t<is_vector_assign_allowed<ops::NAME, T, E, R>>> \
175+
KERNEL_FLOAT_INLINE vector<T, E>& operator OP(vector<T, E>& lhs, const R & rhs) { \
176+
using F = ops::NAME<T>; \
177+
lhs = zip_common(F {}, lhs, rhs); \
178+
return lhs; \
177179
}
178180

179181
KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(add, +=)

include/kernel_float/memory.h

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,8 @@ struct vector_ref<T, N, const U, Alignment> {
414414
template<typename T, size_t N, typename U, size_t Alignment, typename V> \
415415
KERNEL_FLOAT_INLINE vector_ref<T, N, U, Alignment> operator OP_ASSIGN( \
416416
vector_ref<T, N, U, Alignment> ptr, \
417-
const V& value) { \
417+
const V \
418+
& value) { \
418419
ptr.write(ptr.read() OP value); \
419420
return ptr; \
420421
}
@@ -462,19 +463,27 @@ struct vector_ptr {
462463
vector_ptr() = default;
463464

464465
/**
465-
* Constructor from a given pointer. It is up to the user to assert that the pointer is aligned to `Align` elements.
466+
* Constructor from a given pointer. It is up to the user to assert that the pointer is aligned to `Alignment`.
466467
*/
468+
template<typename V = U, enable_if_t<Alignment != alignof(V), int> = 0>
467469
KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {}
468470

471+
/**
472+
* Constructor from a given pointer. This assumes that the alignment of the pointer equals `Alignment`.
473+
*/
474+
template<typename V = U, enable_if_t<Alignment == alignof(V), int> = 0>
475+
KERNEL_FLOAT_INLINE vector_ptr(pointer_type p) : data_(p) {}
476+
469477
/**
470478
* Constructs a vector_ptr from another vector_ptr with potentially different alignment and type. This constructor
471479
* only allows conversion if the alignment of the source is greater than or equal to the alignment of the target.
472480
*/
473-
template<typename T2, size_t N2, size_t A2>
474-
KERNEL_FLOAT_INLINE vector_ptr(
475-
vector_ptr<T2, N2, U, A2> p,
476-
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = {}) :
477-
data_(p.get()) {}
481+
template<
482+
typename T2,
483+
size_t N2,
484+
size_t A2,
485+
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = 0>
486+
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, U, A2> p) : data_(p.get()) {}
478487

479488
/**
480489
* Shorthand for `at(0)`.
@@ -548,19 +557,25 @@ struct vector_ptr<T, N, const U, Alignment> {
548557

549558
vector_ptr() = default;
550559

560+
template<typename V = U, enable_if_t<Alignment != alignof(V), int> = 0>
551561
KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {}
552562

553-
template<typename T2, size_t N2, size_t A2>
554-
KERNEL_FLOAT_INLINE vector_ptr(
555-
vector_ptr<T2, N2, const U, A2> p,
556-
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = {}) :
557-
data_(p.get()) {}
563+
template<typename V = U, enable_if_t<Alignment == alignof(V), int> = 0>
564+
KERNEL_FLOAT_INLINE vector_ptr(pointer_type p) : data_(p) {}
558565

559-
template<typename T2, size_t N2, size_t A2>
560-
KERNEL_FLOAT_INLINE vector_ptr(
561-
vector_ptr<T2, N2, U, A2> p,
562-
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = {}) :
563-
data_(p.get()) {}
566+
template<
567+
typename T2,
568+
size_t N2,
569+
size_t A2,
570+
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = 0>
571+
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, const U, A2> p) : data_(p.get()) {}
572+
573+
template<
574+
typename T2,
575+
size_t N2,
576+
size_t A2,
577+
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = 0>
578+
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, U, A2> p) : data_(p.get()) {}
564579

565580
KERNEL_FLOAT_INLINE vector_ref<value_type, N, const U, Alignment> operator*() const {
566581
return vector_ref<value_type, N, const U, Alignment> {data_};
@@ -614,7 +629,7 @@ KERNEL_FLOAT_INLINE vector_ptr<T, N, U, A>& operator+=(vector_ptr<T, N, U, A>& p
614629
/**
615630
* Creates a `vector_ptr<T, N>` from a raw pointer `T*` by asserting a specific alignment `N`.
616631
*
617-
* @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT.
632+
* @tparam N The alignment constraint for the vector_ptr.
618633
* @tparam T The type of the elements pointed to by the raw pointer.
619634
*/
620635
template<size_t N, typename T>
@@ -638,6 +653,9 @@ KERNEL_FLOAT_INLINE vector_ptr<T, 1, T, KERNEL_FLOAT_MAX_ALIGNMENT> assert_align
638653
template<typename T, size_t N = 1, typename U = T, size_t Align = N>
639654
using vec_ptr = vector_ptr<T, N, U, Align * sizeof(U)>;
640655

656+
template<typename T, typename U = T>
657+
using view_ptr = vector_ptr<T, 1, U, alignof(U)>;
658+
641659
#if defined(__cpp_deduction_guides)
642660
template<typename T>
643661
vector_ptr(T*) -> vector_ptr<T, 1, T>;

single_include/kernel_float.h

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2025-08-12 13:55:51.042675
20-
// git hash: 714ca6b5fd63ef3497d80ef018cb9a9460c91391
19+
// date: 2025-08-21 10:13:04.148230
20+
// git hash: 4d0d49cad7962d3f9ba4f2a0abfa2faea3ec7efa
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -1823,12 +1823,14 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
18231823
template<typename L, typename R, typename C = promote_t<L, vector_value_type<R>>, typename E> \
18241824
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, vector<L, E>, R> operator OP( \
18251825
const vector<L, E>& left, \
1826-
const R& right) { \
1826+
const R \
1827+
& right) { \
18271828
return zip_common(ops::NAME<C> {}, left, right); \
18281829
} \
18291830
template<typename L, typename R, typename C = promote_t<vector_value_type<L>, R>, typename E> \
18301831
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, vector<R, E>> operator OP( \
1831-
const L& left, \
1832+
const L \
1833+
& left, \
18321834
const vector<R, E>& right) { \
18331835
return zip_common(ops::NAME<C> {}, left, right); \
18341836
}
@@ -1869,16 +1871,16 @@ static constexpr bool is_vector_assign_allowed =
18691871
>;
18701872
// clang-format on
18711873

1872-
#define KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(NAME, OP) \
1873-
template< \
1874-
typename T, \
1875-
typename E, \
1876-
typename R, \
1877-
typename = enable_if_t<is_vector_assign_allowed<ops::NAME, T, E, R>>> \
1878-
KERNEL_FLOAT_INLINE vector<T, E>& operator OP(vector<T, E>& lhs, const R& rhs) { \
1879-
using F = ops::NAME<T>; \
1880-
lhs = zip_common(F {}, lhs, rhs); \
1881-
return lhs; \
1874+
#define KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(NAME, OP) \
1875+
template< \
1876+
typename T, \
1877+
typename E, \
1878+
typename R, \
1879+
typename = enable_if_t<is_vector_assign_allowed<ops::NAME, T, E, R>>> \
1880+
KERNEL_FLOAT_INLINE vector<T, E>& operator OP(vector<T, E>& lhs, const R & rhs) { \
1881+
using F = ops::NAME<T>; \
1882+
lhs = zip_common(F {}, lhs, rhs); \
1883+
return lhs; \
18821884
}
18831885

18841886
KERNEL_FLOAT_DEFINE_BINARY_ASSIGN_OP(add, +=)
@@ -2975,7 +2977,8 @@ struct vector_ref<T, N, const U, Alignment> {
29752977
template<typename T, size_t N, typename U, size_t Alignment, typename V> \
29762978
KERNEL_FLOAT_INLINE vector_ref<T, N, U, Alignment> operator OP_ASSIGN( \
29772979
vector_ref<T, N, U, Alignment> ptr, \
2978-
const V& value) { \
2980+
const V \
2981+
& value) { \
29792982
ptr.write(ptr.read() OP value); \
29802983
return ptr; \
29812984
}
@@ -3023,19 +3026,27 @@ struct vector_ptr {
30233026
vector_ptr() = default;
30243027

30253028
/**
3026-
* Constructor from a given pointer. It is up to the user to assert that the pointer is aligned to `Align` elements.
3029+
* Constructor from a given pointer. It is up to the user to assert that the pointer is aligned to `Alignment`.
30273030
*/
3031+
template<typename V = U, enable_if_t<Alignment != alignof(V), int> = 0>
30283032
KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {}
30293033

3034+
/**
3035+
* Constructor from a given pointer. This assumes that the alignment of the pointer equals `Alignment`.
3036+
*/
3037+
template<typename V = U, enable_if_t<Alignment == alignof(V), int> = 0>
3038+
KERNEL_FLOAT_INLINE vector_ptr(pointer_type p) : data_(p) {}
3039+
30303040
/**
30313041
* Constructs a vector_ptr from another vector_ptr with potentially different alignment and type. This constructor
30323042
* only allows conversion if the alignment of the source is greater than or equal to the alignment of the target.
30333043
*/
3034-
template<typename T2, size_t N2, size_t A2>
3035-
KERNEL_FLOAT_INLINE vector_ptr(
3036-
vector_ptr<T2, N2, U, A2> p,
3037-
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = {}) :
3038-
data_(p.get()) {}
3044+
template<
3045+
typename T2,
3046+
size_t N2,
3047+
size_t A2,
3048+
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = 0>
3049+
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, U, A2> p) : data_(p.get()) {}
30393050

30403051
/**
30413052
* Shorthand for `at(0)`.
@@ -3109,19 +3120,25 @@ struct vector_ptr<T, N, const U, Alignment> {
31093120

31103121
vector_ptr() = default;
31113122

3123+
template<typename V = U, enable_if_t<Alignment != alignof(V), int> = 0>
31123124
KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {}
31133125

3114-
template<typename T2, size_t N2, size_t A2>
3115-
KERNEL_FLOAT_INLINE vector_ptr(
3116-
vector_ptr<T2, N2, const U, A2> p,
3117-
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = {}) :
3118-
data_(p.get()) {}
3126+
template<typename V = U, enable_if_t<Alignment == alignof(V), int> = 0>
3127+
KERNEL_FLOAT_INLINE vector_ptr(pointer_type p) : data_(p) {}
31193128

3120-
template<typename T2, size_t N2, size_t A2>
3121-
KERNEL_FLOAT_INLINE vector_ptr(
3122-
vector_ptr<T2, N2, U, A2> p,
3123-
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = {}) :
3124-
data_(p.get()) {}
3129+
template<
3130+
typename T2,
3131+
size_t N2,
3132+
size_t A2,
3133+
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = 0>
3134+
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, const U, A2> p) : data_(p.get()) {}
3135+
3136+
template<
3137+
typename T2,
3138+
size_t N2,
3139+
size_t A2,
3140+
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = 0>
3141+
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, U, A2> p) : data_(p.get()) {}
31253142

31263143
KERNEL_FLOAT_INLINE vector_ref<value_type, N, const U, Alignment> operator*() const {
31273144
return vector_ref<value_type, N, const U, Alignment> {data_};
@@ -3175,7 +3192,7 @@ KERNEL_FLOAT_INLINE vector_ptr<T, N, U, A>& operator+=(vector_ptr<T, N, U, A>& p
31753192
/**
31763193
* Creates a `vector_ptr<T, N>` from a raw pointer `T*` by asserting a specific alignment `N`.
31773194
*
3178-
* @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT.
3195+
* @tparam N The alignment constraint for the vector_ptr.
31793196
* @tparam T The type of the elements pointed to by the raw pointer.
31803197
*/
31813198
template<size_t N, typename T>
@@ -3199,6 +3216,9 @@ KERNEL_FLOAT_INLINE vector_ptr<T, 1, T, KERNEL_FLOAT_MAX_ALIGNMENT> assert_align
31993216
template<typename T, size_t N = 1, typename U = T, size_t Align = N>
32003217
using vec_ptr = vector_ptr<T, N, U, Align * sizeof(U)>;
32013218

3219+
template<typename T, typename U = T>
3220+
using view_ptr = vector_ptr<T, 1, U, alignof(U)>;
3221+
32023222
#if defined(__cpp_deduction_guides)
32033223
template<typename T>
32043224
vector_ptr(T*) -> vector_ptr<T, 1, T>;
@@ -4749,6 +4769,7 @@ KERNEL_FLOAT_DEVICE half2_t normalize_trig_input(half2_t x) {
47494769
static constexpr double ONE_OVER_TWOPI = 0.15915494309189535;
47504770
static constexpr double OFFSET = -2042.0;
47514771

4772+
// ws = (x / 2pi) - ((x / 2pi + OFFSET) - OFFSET)
47524773
half2_t ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET);
47534774
return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws);
47544775
}

tests/memory.cu

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,31 @@ struct vector_ptr_test {
208208
kf::vec<T, N> h = ptr[1];
209209
ASSERT_EQ_ALL(h[I], T(3.14));
210210
}
211+
212+
{
213+
// This does *not* require an explicit constructor (N == 1)
214+
kf::vector_ptr<T, 1, U> a1_ptr = storage.data;
215+
kf::vector_ptr<const T, 1, U> a2_ptr = storage.data;
216+
kf::vector_ptr<T, 1, const U> a3_ptr = storage.data;
217+
kf::vector_ptr<const T, 1, const U> a4_ptr = storage.data;
218+
219+
ASSERT_EQ(a1_ptr.get(), static_cast<U*>(storage.data));
220+
ASSERT_EQ(a2_ptr.get(), static_cast<U*>(storage.data));
221+
ASSERT_EQ(a3_ptr.get(), static_cast<const U*>(storage.data));
222+
ASSERT_EQ(a4_ptr.get(), static_cast<const U*>(storage.data));
223+
224+
// This *does* require an explicit constructor (N > 1)
225+
kf::vector_ptr<T, 2, U> b1_ptr = kf::vector_ptr<T, 2, U>(storage.data);
226+
kf::vector_ptr<const T, 2, U> b2_ptr = kf::vector_ptr<const T, 2, U>(storage.data);
227+
kf::vector_ptr<T, 2, const U> b3_ptr = kf::vector_ptr<T, 2, const U>(storage.data);
228+
kf::vector_ptr<const T, 2, const U> b4_ptr =
229+
kf::vector_ptr<const T, 2, const U>(storage.data);
230+
231+
ASSERT_EQ(b1_ptr.get(), static_cast<U*>(storage.data));
232+
ASSERT_EQ(b2_ptr.get(), static_cast<U*>(storage.data));
233+
ASSERT_EQ(b3_ptr.get(), static_cast<const U*>(storage.data));
234+
ASSERT_EQ(b4_ptr.get(), static_cast<const U*>(storage.data));
235+
}
211236
}
212237
};
213238

0 commit comments

Comments
 (0)