Skip to content

Commit 8f39507

Browse files
committed
Rename wrap_ptr and assert_aligned to make_vec_ptr
1 parent aa2f07b commit 8f39507

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

examples/vector_add/main.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ void run_kernel(int n) {
5454
int grid_size = (n + items_per_block - 1) / items_per_block;
5555
my_kernel<items_per_thread><<<grid_size, block_size>>>(
5656
n,
57-
kf::assert_aligned(input_dev),
57+
kf::make_vec_ptr(input_dev),
5858
constant,
59-
kf::assert_aligned(output_dev));
59+
kf::make_vec_ptr(output_dev));
6060

6161
// Copy results back
6262
cuda_check(cudaMemcpy(output_dev, output_result.data(), sizeof(float) * n, cudaMemcpyDefault));

include/kernel_float/memory.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ constexpr size_t gcd(size_t a, size_t b) {
119119
* Returns true if a pointer having alignment of `a` bytes also has an alignment of `b` bytes. Returns false otherwise.
120120
*/
121121
KERNEL_FLOAT_INLINE
122-
constexpr size_t alignment_divisible(size_t a, size_t b) {
122+
constexpr bool is_alignment_divisible(size_t a, size_t b) {
123123
return gcd(a, KERNEL_FLOAT_MAX_ALIGNMENT) % gcd(b, KERNEL_FLOAT_MAX_ALIGNMENT) == 0;
124124
}
125125

@@ -482,7 +482,7 @@ struct vector_ptr {
482482
typename T2,
483483
size_t N2,
484484
size_t A2,
485-
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = 0>
485+
enable_if_t<detail::is_alignment_divisible(A2, Alignment), int> = 0>
486486
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, U, A2> p) : data_(p.get()) {}
487487

488488
/**
@@ -567,14 +567,14 @@ struct vector_ptr<T, N, const U, Alignment> {
567567
typename T2,
568568
size_t N2,
569569
size_t A2,
570-
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = 0>
570+
enable_if_t<detail::is_alignment_divisible(A2, Alignment), int> = 0>
571571
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, const U, A2> p) : data_(p.get()) {}
572572

573573
template<
574574
typename T2,
575575
size_t N2,
576576
size_t A2,
577-
enable_if_t<detail::alignment_divisible(A2, Alignment), int> = 0>
577+
enable_if_t<detail::is_alignment_divisible(A2, Alignment), int> = 0>
578578
KERNEL_FLOAT_INLINE vector_ptr(vector_ptr<T2, N2, U, A2> p) : data_(p.get()) {}
579579

580580
KERNEL_FLOAT_INLINE vector_ref<value_type, N, const U, Alignment> operator*() const {
@@ -621,7 +621,7 @@ template<
621621
size_t N,
622622
typename U,
623623
size_t A,
624-
typename = enable_if_t<(N * sizeof(U)) % A == 0>>
624+
typename = enable_if_t<detail::is_alignment_divisible(N * sizeof(U), A)>>
625625
KERNEL_FLOAT_INLINE vector_ptr<T, N, U, A>& operator+=(vector_ptr<T, N, U, A>& p, size_t i) {
626626
return p = p + i;
627627
}
@@ -634,30 +634,30 @@ KERNEL_FLOAT_INLINE vector_ptr<T, N, U, A>& operator+=(vector_ptr<T, N, U, A>& p
634634
* @tparam U The type of the elements pointed to by the raw pointer.
635635
*/
636636
template<typename T, size_t N = 1, typename U>
637-
KERNEL_FLOAT_INLINE vector_ptr<T, N, U> wrap_ptr(U* ptr) {
637+
KERNEL_FLOAT_INLINE vector_ptr<T, N, U> make_vec_ptr(U* ptr) {
638638
return vector_ptr<T, N, U> {ptr};
639639
}
640640

641+
// Doxygen cannot deal with the `make_vec_ptr` being defined multiple times, we ignore the second definition.
642+
/// @cond IGNORE
641643
/**
642644
* Creates a `vector_ptr<T, N>` from a raw pointer `T*` by asserting a specific alignment `N`.
643645
*
644646
* @tparam N The alignment constraint for the vector_ptr.
645647
* @tparam T The type of the elements pointed to by the raw pointer.
646648
*/
647649
template<size_t N, typename T>
648-
KERNEL_FLOAT_INLINE vector_ptr<T, N> assert_aligned(T* ptr) {
650+
KERNEL_FLOAT_INLINE vector_ptr<T, N> make_vec_ptr(T* ptr) {
649651
return vector_ptr<T, N> {ptr};
650652
}
651653

652-
// Doxygen cannot deal with the `assert_aligned` being defined twice, we ignore the second definition.
653-
/// @cond IGNORE
654654
/**
655655
* Creates a `vector_ptr<T, 1>` from a raw pointer `T*`. The alignment is assumed to be KERNEL_FLOAT_MAX_ALIGNMENT.
656656
*
657657
* @tparam T The type of the elements pointed to by the raw pointer.
658658
*/
659-
template<typename T>
660-
KERNEL_FLOAT_INLINE vector_ptr<T, 1, T, KERNEL_FLOAT_MAX_ALIGNMENT> assert_aligned(T* ptr) {
659+
template<decltype(nullptr) = nullptr, typename T>
660+
KERNEL_FLOAT_INLINE vector_ptr<T, 1, T, KERNEL_FLOAT_MAX_ALIGNMENT> make_vec_ptr(T* ptr) {
661661
return vector_ptr<T, 1, T, KERNEL_FLOAT_MAX_ALIGNMENT> {ptr};
662662
}
663663
/// @endcond

tests/memory.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ struct vector_ptr_test {
156156
};
157157

158158
{
159-
kf::vector_ptr<const U, N> storage_ptr = kf::assert_aligned(storage.data);
159+
kf::vector_ptr<const U, N> storage_ptr = kf::make_vec_ptr(storage.data);
160160
kf::vector_ptr<T, N, const U> ptr = storage_ptr;
161161
ASSERT_EQ(ptr.get(), static_cast<const U*>(storage.data));
162162

@@ -175,7 +175,7 @@ struct vector_ptr_test {
175175
}
176176

177177
{
178-
kf::vector_ptr<U, N> storage_ptr = kf::assert_aligned(storage.data);
178+
kf::vector_ptr<U, N> storage_ptr = kf::make_vec_ptr(storage.data);
179179
kf::vector_ptr<T, N, U> ptr = storage_ptr;
180180
ASSERT_EQ(ptr.get(), static_cast<U*>(storage.data));
181181

@@ -237,18 +237,18 @@ struct vector_ptr_test {
237237
{
238238
U* ptr = nullptr;
239239

240-
auto a1 = kf::wrap_ptr<T>(ptr);
240+
auto a1 = kf::make_vec_ptr<T>(ptr);
241241
ASSERT(std::is_same<decltype(a1), kf::vector_ptr<T, 1, U>>::value);
242242

243-
auto a2 = kf::wrap_ptr<T, 2>(ptr);
243+
auto a2 = kf::make_vec_ptr<T, 2>(ptr);
244244
ASSERT(std::is_same<decltype(a2), kf::vector_ptr<T, 2, U>>::value);
245245

246-
auto a3 = kf::assert_aligned(ptr);
246+
auto a3 = kf::make_vec_ptr(ptr);
247247
ASSERT(
248248
std::is_same<decltype(a3), kf::vector_ptr<U, 1, U, KERNEL_FLOAT_MAX_ALIGNMENT>>::
249249
value);
250250

251-
auto a4 = kf::assert_aligned<2>(ptr);
251+
auto a4 = kf::make_vec_ptr<2>(ptr);
252252
ASSERT(std::is_same<decltype(a4), kf::vector_ptr<U, 2>>::value);
253253
}
254254
}

0 commit comments

Comments
 (0)