diff --git a/libcudacxx/include/cuda/__ptx/instructions/shfl_sync.h b/libcudacxx/include/cuda/__ptx/instructions/shfl_sync.h index f44b860ff32..32f8ed6257c 100644 --- a/libcudacxx/include/cuda/__ptx/instructions/shfl_sync.h +++ b/libcudacxx/include/cuda/__ptx/instructions/shfl_sync.h @@ -83,11 +83,11 @@ _CCCL_DEVICE static inline void __shfl_sync_checks( [[maybe_unused]] uint32_t __lane_mask) { static_assert(sizeof(_Tp) == 4, "shfl.sync only accepts 4-byte data types"); + _CCCL_ASSERT(__lane_mask & (1u << ::cuda::ptx::get_sreg_laneid()), "lane_mask must contain the current lane"); if (__shfl_mode != __dot_shfl_mode::__idx) { _CCCL_ASSERT(__lane_idx_offset < 32, "the lane index or offset must be less than the warp size"); } - _CCCL_ASSERT(__lane_mask != 0, "lane_mask must be non-zero"); _CCCL_ASSERT((__clamp_segmask | 0b1111100011111) == 0b1111100011111, "clamp value + segmentation mask must use the bit positions [0:4] and [8:12]"); _CCCL_ASSERT(::cuda::ptx::__shfl_sync_dst_lane(__shfl_mode, __lane_idx_offset, __clamp_segmask) & __lane_mask, diff --git a/libcudacxx/test/libcudacxx/cuda/ptx/ptx.shfl.compile.pass.cpp b/libcudacxx/test/libcudacxx/cuda/ptx/ptx.shfl.compile.pass.cpp index 25d105acbc4..676a260cd70 100644 --- a/libcudacxx/test/libcudacxx/cuda/ptx/ptx.shfl.compile.pass.cpp +++ b/libcudacxx/test/libcudacxx/cuda/ptx/ptx.shfl.compile.pass.cpp @@ -45,7 +45,7 @@ __host__ __device__ void test_shfl_full_mask() } auto res4 = cuda::ptx::shfl_sync_bfly(data, pred4, 2 /*offset*/, 0b11111 /*clamp*/, FullMask); - assert(res4 == threadIdx.x ^ 2 && pred4); + assert(res4 == (threadIdx.x ^ 2) && pred4); #endif // __cccl_ptx_isa >= 600 && _CCCL_DEVICE_COMPILATION() } @@ -78,7 +78,7 @@ __host__ __device__ void test_shfl_full_mask_no_pred() } auto res4 = cuda::ptx::shfl_sync_bfly(data, 2 /*offset*/, 0b11111 /*clamp*/, FullMask); - assert(res4 == threadIdx.x ^ 2); + assert(res4 == (threadIdx.x ^ 2)); #endif // __cccl_ptx_isa >= 600 && _CCCL_DEVICE_COMPILATION() } @@ -136,7 +136,7 @@ __host__ __device__ void test_shfl_partial_warp() } auto res4 = cuda::ptx::shfl_sync_bfly(data, pred4, 2 /*offset*/, clamp_segmark, FullMask); - assert(res4 == threadIdx.x ^ 2 && pred4); + assert(res4 == (threadIdx.x ^ 2) && pred4); #endif // __cccl_ptx_isa >= 600 && _CCCL_DEVICE_COMPILATION() }