Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed May 10, 2023
1 parent 474a084 commit cc6eded
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
1 change: 0 additions & 1 deletion cub/agent/agent_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ struct AgentScan

// Wait for all threads in the cluster to finish loading / dsmem initialization
cooperative_groups::cluster_group::barrier_wait(std::move(token));
CTA_SYNC(); // What, this sync fixes the race

// Perform tile scan
if (tile_idx == 0)
Expand Down
67 changes: 57 additions & 10 deletions cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ struct ScanClusterTileState<T, true>
}
};

static __device__ __forceinline__ uint4 dsmem_ld_relaxed(uint4 const *ptr)
static __device__ __forceinline__ uint4 lsmem_ld_relaxed(uint4 const *ptr)
{
uint4 retval;
asm volatile("ld.relaxed.shared.cluster.v4.u32 {%0, %1, %2, %3}, [%4];"
Expand All @@ -314,7 +314,7 @@ static __device__ __forceinline__ uint4 dsmem_ld_relaxed(uint4 const *ptr)
return retval;
}

static __device__ __forceinline__ ulonglong2 dsmem_ld_relaxed(ulonglong2 const *ptr)
static __device__ __forceinline__ ulonglong2 lsmem_ld_relaxed(ulonglong2 const *ptr)
{
ulonglong2 retval;
asm volatile("ld.relaxed.shared.cluster.v2.u64 {%0, %1}, [%2];"
Expand All @@ -324,7 +324,7 @@ static __device__ __forceinline__ ulonglong2 dsmem_ld_relaxed(ulonglong2 const *
return retval;
}

static __device__ __forceinline__ ushort4 dsmem_ld_relaxed(ushort4 const *ptr)
static __device__ __forceinline__ ushort4 lsmem_ld_relaxed(ushort4 const *ptr)
{
ushort4 retval;
asm volatile("ld.relaxed.shared.cluster.v4.u16 {%0, %1, %2, %3}, [%4];"
Expand All @@ -334,7 +334,7 @@ static __device__ __forceinline__ ushort4 dsmem_ld_relaxed(ushort4 const *ptr)
return retval;
}

static __device__ __forceinline__ uint2 dsmem_ld_relaxed(uint2 const *ptr)
static __device__ __forceinline__ uint2 lsmem_ld_relaxed(uint2 const *ptr)
{
uint2 retval;
asm volatile("ld.relaxed.shared.cluster.v2.u32 {%0, %1}, [%2];"
Expand All @@ -344,7 +344,7 @@ static __device__ __forceinline__ uint2 dsmem_ld_relaxed(uint2 const *ptr)
return retval;
}

static __device__ __forceinline__ unsigned long long dsmem_ld_relaxed(unsigned long long const *ptr)
static __device__ __forceinline__ unsigned long long lsmem_ld_relaxed(unsigned long long const *ptr)
{
unsigned long long retval;
asm volatile("ld.relaxed.shared.cluster.u64 %0, [%1];"
Expand All @@ -354,7 +354,7 @@ static __device__ __forceinline__ unsigned long long dsmem_ld_relaxed(unsigned l
return retval;
}

static __device__ __forceinline__ unsigned int dsmem_ld_relaxed(unsigned int const *ptr)
static __device__ __forceinline__ unsigned int lsmem_ld_relaxed(unsigned int const *ptr)
{
unsigned int retval;
asm volatile("ld.relaxed.shared.cluster.u32 %0, [%1];"
Expand Down Expand Up @@ -413,6 +413,55 @@ static __device__ __forceinline__ void dsmem_st_relaxed(unsigned int *ptr, unsig
: "memory");
}

// Stores the four 32-bit components of `val` to `*ptr` with one
// `st.relaxed.shared.cluster.v4.u32` PTX instruction.
//
// ptr: destination address, handed to the asm through _CUB_ASM_PTR_.
// val: value to store; x, y, z and w are passed as 32-bit register
//      operands ("r" constraint).
//
// `asm volatile` together with the "memory" clobber keeps the compiler from
// dropping the store or moving other memory accesses across it.
// NOTE(review): the PTX ISA writes the relaxed store as
// `st.relaxed.<scope>.<state-space>`; confirm that the `.shared.cluster`
// qualifier order used here assembles to the intended instruction.
static __device__ __forceinline__ void lsmem_st_relaxed(uint4 *ptr, uint4 val)
{
asm volatile("st.relaxed.shared.cluster.v4.u32 [%0], {%1, %2, %3, %4};"
:
: _CUB_ASM_PTR_(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)
: "memory");
}

// Stores the two 64-bit components of `val` to `*ptr` with one
// `st.relaxed.shared.cluster.v2.u64` PTX instruction.
//
// ptr: destination address, handed to the asm through _CUB_ASM_PTR_.
// val: value to store; x and y are passed as 64-bit register operands
//      ("l" constraint).
//
// `asm volatile` together with the "memory" clobber keeps the compiler from
// dropping the store or moving other memory accesses across it.
// NOTE(review): the PTX ISA writes the relaxed store as
// `st.relaxed.<scope>.<state-space>`; confirm that the `.shared.cluster`
// qualifier order used here assembles to the intended instruction.
static __device__ __forceinline__ void lsmem_st_relaxed(ulonglong2 *ptr, ulonglong2 val)
{
asm volatile("st.relaxed.shared.cluster.v2.u64 [%0], {%1, %2};"
:
: _CUB_ASM_PTR_(ptr), "l"(val.x), "l"(val.y)
: "memory");
}

// Stores the four 16-bit components of `val` to `*ptr` with one
// `st.relaxed.shared.cluster.v4.u16` PTX instruction.
//
// ptr: destination address, handed to the asm through _CUB_ASM_PTR_.
// val: value to store; x, y, z and w are passed as 16-bit register
//      operands ("h" constraint).
//
// `asm volatile` together with the "memory" clobber keeps the compiler from
// dropping the store or moving other memory accesses across it.
// NOTE(review): the PTX ISA writes the relaxed store as
// `st.relaxed.<scope>.<state-space>`; confirm that the `.shared.cluster`
// qualifier order used here assembles to the intended instruction.
static __device__ __forceinline__ void lsmem_st_relaxed(ushort4 *ptr, ushort4 val)
{
asm volatile("st.relaxed.shared.cluster.v4.u16 [%0], {%1, %2, %3, %4};"
:
: _CUB_ASM_PTR_(ptr), "h"(val.x), "h"(val.y), "h"(val.z), "h"(val.w)
: "memory");
}

// Stores the two 32-bit components of `val` to `*ptr` with one
// `st.relaxed.shared.cluster.v2.u32` PTX instruction.
//
// ptr: destination address, handed to the asm through _CUB_ASM_PTR_.
// val: value to store; x and y are passed as 32-bit register operands
//      ("r" constraint).
//
// `asm volatile` together with the "memory" clobber keeps the compiler from
// dropping the store or moving other memory accesses across it.
// NOTE(review): the PTX ISA writes the relaxed store as
// `st.relaxed.<scope>.<state-space>`; confirm that the `.shared.cluster`
// qualifier order used here assembles to the intended instruction.
static __device__ __forceinline__ void lsmem_st_relaxed(uint2 *ptr, uint2 val)
{
asm volatile("st.relaxed.shared.cluster.v2.u32 [%0], {%1, %2};"
:
: _CUB_ASM_PTR_(ptr), "r"(val.x), "r"(val.y)
: "memory");
}

// Stores the 64-bit scalar `val` to `*ptr` with one
// `st.relaxed.shared.cluster.u64` PTX instruction.
//
// ptr: destination address, handed to the asm through _CUB_ASM_PTR_.
// val: value to store, passed as a 64-bit register operand ("l" constraint).
//
// `asm volatile` together with the "memory" clobber keeps the compiler from
// dropping the store or moving other memory accesses across it.
// NOTE(review): the PTX ISA writes the relaxed store as
// `st.relaxed.<scope>.<state-space>`; confirm that the `.shared.cluster`
// qualifier order used here assembles to the intended instruction.
static __device__ __forceinline__ void lsmem_st_relaxed(unsigned long long *ptr,
unsigned long long val)
{
asm volatile("st.relaxed.shared.cluster.u64 [%0], %1;"
:
: _CUB_ASM_PTR_(ptr), "l"(val)
: "memory");
}

// Stores the 32-bit scalar `val` to `*ptr` with one
// `st.relaxed.shared.cluster.u32` PTX instruction.
//
// ptr: destination address, handed to the asm through _CUB_ASM_PTR_.
// val: value to store, passed as a 32-bit register operand ("r" constraint).
//
// `asm volatile` together with the "memory" clobber keeps the compiler from
// dropping the store or moving other memory accesses across it.
// NOTE(review): the PTX ISA writes the relaxed store as
// `st.relaxed.<scope>.<state-space>`; confirm that the `.shared.cluster`
// qualifier order used here assembles to the intended instruction.
static __device__ __forceinline__ void lsmem_st_relaxed(unsigned int *ptr, unsigned int val)
{
asm volatile("st.relaxed.shared.cluster.u32 [%0], %1;"
:
: _CUB_ASM_PTR_(ptr), "r"(val)
: "memory");
}

template <
typename T,
typename ScanOpT,
Expand Down Expand Up @@ -530,7 +579,7 @@ struct ClusterTilePrefixCallbackOp
__device__ __forceinline__
void LoadTileDescriptor(unsigned int src_cta, TileDescriptor &tile_descriptor)
{
TxnWord alias = dsmem_ld_relaxed(temp_storage.dsmem + src_cta);
TxnWord alias = lsmem_ld_relaxed(temp_storage.dsmem + src_cta);
tile_descriptor = reinterpret_cast<TileDescriptor &>(alias);
}

Expand Down Expand Up @@ -593,8 +642,6 @@ struct ClusterTilePrefixCallbackOp
}
exclusive_prefix = Reduce(cta_rank, src_cta, tile_descriptor.value);

// second thread of block 31 reads 2399 instead of 1920

if (__shfl_sync(CUB_DETAIL_CLUSTER_WARP_MASK,
tile_descriptor.status == SCAN_TILE_PARTIAL,
0,
Expand Down Expand Up @@ -667,7 +714,7 @@ struct ClusterTilePrefixCallbackOp
descriptor->status = StatusWord(SCAN_TILE_INVALID);

if (threadIdx.x < CUB_DETAIL_CLUSTER_SIZE) {
temp_storage.dsmem[threadIdx.x] = val;
lsmem_st_relaxed(temp_storage.dsmem + threadIdx.x, val);
}
}
};
Expand Down

0 comments on commit cc6eded

Please sign in to comment.