diff --git a/p2p-all-to-all/a2a-kernels/src/a2a/a2a_dispatch_send.cu b/p2p-all-to-all/a2a-kernels/src/a2a/a2a_dispatch_send.cu
index 806b46c..33d1dd6 100644
--- a/p2p-all-to-all/a2a-kernels/src/a2a/a2a_dispatch_send.cu
+++ b/p2p-all-to-all/a2a-kernels/src/a2a/a2a_dispatch_send.cu
@@ -190,55 +190,75 @@ __global__ __launch_bounds__(NUM_WARPS * WARP_SIZE, 1) void a2a_dispatch_send_ke
     }
     __syncthreads();
 
-    // Find the start offset of each rank by computing a cumulative sum within tokens_per_rank.
-    // Compute sums within each warp and store the sums in shared memory.
-    const uint32_t i = threadIdx.x;
-    const uint32_t num_warps = ceil_div(num_experts, WARP_SIZE);
-    uint32_t *expert_sums = (uint32_t*)shared_memory;
-
+    // Compute cumulative sum over tokens_per_expert to produce expert_offsets.
+    // Use a chunked two-level warp scan: each iteration processes up to num_threads experts,
+    // carrying a running total across chunks to support scaling up beyond block size.
+    // Shared-memory layout: tokens_per_expert[num_experts] | expert_sums[NUM_WARPS].
+    uint32_t *expert_sums = (uint32_t*)(shared_memory + num_experts * sizeof(uint32_t));
     uint32_t *local_num_routed = num_routed + dp_group * num_experts;
-    uint32_t expert_offset = 0;
-    if (i < num_experts) {
-      expert_offset = tokens_per_expert[i];
-      local_num_routed[i] = expert_offset;
-    }
-    __syncthreads();
+
+    __shared__ uint32_t s_running_total;
     if (threadIdx.x == 0) {
-      st_mmio_b8(dispatch_route_done, 1);
-    }
-    for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
-      unsigned warp_sum_expert = __shfl_up_sync(0xFFFFFFFF, expert_offset, offset);
-      if (lane_id >= offset) {
-        expert_offset += warp_sum_expert;
-      }
-    }
-    if (lane_id == WARP_SIZE - 1) {
-      expert_sums[warp_id] = expert_offset;
+      s_running_total = 0;
     }
     __syncthreads();
 
-    // Sum up the warp sums in the first warp.
-    if (warp_id == 0) {
-      uint32_t total_expert_sum = (lane_id < num_warps) ? expert_sums[lane_id] : 0;
-      for (int offset = 1; offset < num_warps; offset <<= 1) {
-        unsigned warp_sum = __shfl_up_sync(0xFFFFFFFF, total_expert_sum, offset);
+    // Loop trip count is uniform across the block, so the barriers inside are safe.
+    for (uint32_t base_idx = 0; base_idx < num_experts; base_idx += blockDim.x) {
+      uint32_t i = base_idx + threadIdx.x;
+      uint32_t expert_count = 0;
+
+      if (i < num_experts) {
+        expert_count = tokens_per_expert[i];
+        local_num_routed[i] = expert_count;
+      }
+      __syncthreads();
+
+      // Compute sums within each warp
+      for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
+        unsigned shfl_val = __shfl_up_sync(0xFFFFFFFF, expert_count, offset);
         if (lane_id >= offset) {
-          total_expert_sum += warp_sum;
+          expert_count += shfl_val;
         }
       }
-      if (lane_id < num_warps) {
-        expert_sums[lane_id] = total_expert_sum;
+
+      // Store the sums in shared memory
+      if (lane_id == WARP_SIZE - 1) {
+        expert_sums[warp_id] = expert_count;
       }
-    }
-    __syncthreads();
+      __syncthreads();
+
+      // Sum up the warp sums in the first warp
+      if (warp_id == 0) {
+        uint32_t total_expert_sum = (lane_id < NUM_WARPS) ? expert_sums[lane_id] : 0;
+        for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
+          unsigned shfl_val = __shfl_up_sync(0xFFFFFFFF, total_expert_sum, offset);
+          if (lane_id >= offset) {
+            total_expert_sum += shfl_val;
+          }
+        }
+        if (lane_id < NUM_WARPS) {
+          expert_sums[lane_id] = total_expert_sum;
+        }
+      }
+      __syncthreads();
 
-    // Add the sums to the token counts to find the start offset of each expert.
-    if (i < num_experts) {
-      if (warp_id > 0) {
-        expert_offsets[i] = expert_sums[warp_id - 1] + expert_offset;
-      } else {
-        expert_offsets[i] = expert_offset;
+      // Compute global offset by adding the running total
+      if (i < num_experts) {
+        uint32_t chunk_inclusive_sum = (warp_id > 0 ? expert_sums[warp_id - 1] : 0) + expert_count;
+        expert_offsets[i] = s_running_total + chunk_inclusive_sum;
       }
+
+      // Barrier: every warp must have read s_running_total above before thread 0
+      // advances it, otherwise later warps can observe the next chunk's total (race).
+      __syncthreads();
+
+      // Update and carry the running total across chunks
+      if (threadIdx.x == 0) {
+        s_running_total += expert_sums[NUM_WARPS - 1];
+      }
+      __syncthreads();
+    }
+
+    if (threadIdx.x == 0) {
+      st_mmio_b8(dispatch_route_done, 1);
     }
   }
   __syncthreads();
@@ -579,9 +599,7 @@ int a2a_kernels::a2a_dispatch_send(
   dim3 dimGrid(num_blocks, 1, 1);
   dim3 dimBlock(NUM_THREADS, 1, 1);
 
-  // There should be enough warps to do a horizontal reduction across ranks.
   assert(world_size <= NUM_THREADS);
-  assert(num_experts <= NUM_THREADS);
 
   const size_t token_dim = round_up(hidden_dim * x_elemsize, sizeof(int4));
   const size_t token_scale_dim = round_up(hidden_dim_scale * x_scale_elemsize, sizeof(int4));
@@ -627,7 +645,8 @@ int a2a_kernels::a2a_dispatch_send(
     &recv_ptrs,
   };
 
-  const size_t shared_memory_send = std::max(num_experts, NUM_WARPS) * sizeof(uint32_t);
+  // Requires tokens_per_expert[num_experts] + expert_sums[NUM_WARPS]
+  const size_t shared_memory_send = (num_experts + NUM_WARPS) * sizeof(uint32_t);
 
   nvtxRangePush("dispatch_send");
   cudaError_t status;
@@ -674,4 +693,4 @@ int a2a_kernels::a2a_dispatch_send(
   });
   nvtxRangePop();
   return status;
-}
+}