Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/__ptx/instructions/shfl_sync.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ _CCCL_DEVICE static inline void __shfl_sync_checks(
[[maybe_unused]] uint32_t __lane_mask)
{
static_assert(sizeof(_Tp) == 4, "shfl.sync only accepts 4-byte data types");
_CCCL_ASSERT(__lane_mask & (1u << ::cuda::ptx::get_sreg_laneid()), "lane_mask must contain the current lane");
if (__shfl_mode != __dot_shfl_mode::__idx)
{
_CCCL_ASSERT(__lane_idx_offset < 32, "the lane index or offset must be less than the warp size");
}
_CCCL_ASSERT(__lane_mask != 0, "lane_mask must be non-zero");
_CCCL_ASSERT((__clamp_segmask | 0b1111100011111) == 0b1111100011111,
"clamp value + segmentation mask must use the bit positions [0:4] and [8:12]");
_CCCL_ASSERT(::cuda::ptx::__shfl_sync_dst_lane(__shfl_mode, __lane_idx_offset, __clamp_segmask) & __lane_mask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ __host__ __device__ void test_shfl_full_mask()
}

auto res4 = cuda::ptx::shfl_sync_bfly(data, pred4, 2 /*offset*/, 0b11111 /*clamp*/, FullMask);
assert(res4 == threadIdx.x ^ 2 && pred4);
assert(res4 == (threadIdx.x ^ 2) && pred4);
#endif // __cccl_ptx_isa >= 600 && _CCCL_DEVICE_COMPILATION()
}

Expand Down Expand Up @@ -78,7 +78,7 @@ __host__ __device__ void test_shfl_full_mask_no_pred()
}

auto res4 = cuda::ptx::shfl_sync_bfly(data, 2 /*offset*/, 0b11111 /*clamp*/, FullMask);
assert(res4 == threadIdx.x ^ 2);
assert(res4 == (threadIdx.x ^ 2));
#endif // __cccl_ptx_isa >= 600 && _CCCL_DEVICE_COMPILATION()
}

Expand Down Expand Up @@ -136,7 +136,7 @@ __host__ __device__ void test_shfl_partial_warp()
}

auto res4 = cuda::ptx::shfl_sync_bfly(data, pred4, 2 /*offset*/, clamp_segmark, FullMask);
assert(res4 == threadIdx.x ^ 2 && pred4);
assert(res4 == (threadIdx.x ^ 2) && pred4);
#endif // __cccl_ptx_isa >= 600 && _CCCL_DEVICE_COMPILATION()
}

Expand Down