diff --git a/cub/agent/single_pass_scan_operators.cuh b/cub/agent/single_pass_scan_operators.cuh index e20462648..360168f0e 100644 --- a/cub/agent/single_pass_scan_operators.cuh +++ b/cub/agent/single_pass_scan_operators.cuh @@ -499,6 +499,28 @@ struct ClusterTilePrefixCallbackOp } } + __device__ __forceinline__ void + BroadcastInclusiveAggregate(T block_aggregate, ScanTileStatus status) + { + const unsigned int cta_rank = cooperative_groups::cluster_group::block_rank(); + const unsigned int dst_cta = cta_rank + 1 + threadIdx.x; + + // Notify last CTA first + for (int dst_cta = CUB_DETAIL_CLUSTER_SIZE - 1; dst_cta > 0; dst_cta--) + { + TxnWord * dsmem = cooperative_groups::cluster_group::map_shared_rank(temp_storage.dsmem, dst_cta); + + TileDescriptor tile_descriptor; + tile_descriptor.status = status; + tile_descriptor.value = block_aggregate; + + TxnWord alias; + *reinterpret_cast(&alias) = tile_descriptor; + + dsmem_st_relaxed(dsmem + cta_rank, alias); + } + } + __device__ __forceinline__ T Reduce(unsigned int cta_rank, unsigned int src_cta, T value) { @@ -543,8 +565,16 @@ struct ClusterTilePrefixCallbackOp exclusive_prefix = scan_op(window_aggregate, exclusive_prefix); } - T inclusive_prefix = scan_op(exclusive_prefix, block_aggregate); - BroadcastBlockAggregate(inclusive_prefix, SCAN_TILE_INCLUSIVE); + // T inclusive_prefix = scan_op(exclusive_prefix, block_aggregate); + // TODO Different values!!! + // inclusive_prefix = __shfl_sync(CUB_DETAIL_CLUSTER_WARP_MASK, inclusive_prefix, 0, CUB_DETAIL_CLUSTER_SIZE); + // BroadcastBlockAggregate(inclusive_prefix, SCAN_TILE_INCLUSIVE); + + if (threadIdx.x == 0) + { + T inclusive_prefix = scan_op(exclusive_prefix, block_aggregate); + BroadcastInclusiveAggregate(inclusive_prefix, SCAN_TILE_INCLUSIVE); + } } else {