Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 59 additions & 42 deletions p2p-all-to-all/a2a-kernels/src/a2a/a2a_dispatch_send.cu
Original file line number Diff line number Diff line change
Expand Up @@ -190,55 +190,73 @@ __global__ __launch_bounds__(NUM_WARPS * WARP_SIZE, 1) void a2a_dispatch_send_ke
}
__syncthreads();

// Find the start offset of each rank by computing a cumulative sum within tokens_per_rank.
// Compute sums within each warp and store the sums in shared memory.
const uint32_t i = threadIdx.x;
const uint32_t num_warps = ceil_div<size_t>(num_experts, WARP_SIZE);
uint32_t *expert_sums = (uint32_t*)shared_memory;

// Compute cumulative sum over tokens_per_expert to produce expert_offsets.
// Use a chunked two-level warp scan: each iteration processes up to num_threads experts,
// carrying a running total across chunks to support scaling up beyond block size.
uint32_t *expert_sums = (uint32_t*)(shared_memory + num_experts * sizeof(uint32_t));

uint32_t *local_num_routed = num_routed + dp_group * num_experts;
uint32_t expert_offset = 0;
if (i < num_experts) {
expert_offset = tokens_per_expert[i];
local_num_routed[i] = expert_offset;
}
__syncthreads();

__shared__ uint32_t s_running_total;
if (threadIdx.x == 0) {
st_mmio_b8(dispatch_route_done, 1);
}
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
unsigned warp_sum_expert = __shfl_up_sync(0xFFFFFFFF, expert_offset, offset);
if (lane_id >= offset) {
expert_offset += warp_sum_expert;
}
}
if (lane_id == WARP_SIZE - 1) {
expert_sums[warp_id] = expert_offset;
s_running_total = 0;
}
__syncthreads();

// Sum up the warp sums in the first warp.
if (warp_id == 0) {
uint32_t total_expert_sum = (lane_id < num_warps) ? expert_sums[lane_id] : 0;
for (int offset = 1; offset < num_warps; offset <<= 1) {
unsigned warp_sum = __shfl_up_sync(0xFFFFFFFF, total_expert_sum, offset);
for (uint32_t base_idx = 0; base_idx < num_experts; base_idx += blockDim.x) {
uint32_t i = base_idx + threadIdx.x;
uint32_t expert_count = 0;

if (i < num_experts) {
expert_count = tokens_per_expert[i];
local_num_routed[i] = expert_count;
}
__syncthreads();

// Compute sums within each warp
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
unsigned shfl_val = __shfl_up_sync(0xFFFFFFFF, expert_count, offset);
if (lane_id >= offset) {
total_expert_sum += warp_sum;
expert_count += shfl_val;
}
}
if (lane_id < num_warps) {
expert_sums[lane_id] = total_expert_sum;

// Store the sums in shared memory
if (lane_id == WARP_SIZE - 1) {
expert_sums[warp_id] = expert_count;
}
}
__syncthreads();
__syncthreads();

// Sum up the warp sums in the first warp
if (warp_id == 0) {
uint32_t total_expert_sum = (lane_id < NUM_WARPS) ? expert_sums[lane_id] : 0;
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
unsigned shfl_val = __shfl_up_sync(0xFFFFFFFF, total_expert_sum, offset);
if (lane_id >= offset) {
total_expert_sum += shfl_val;
}
}
if (lane_id < NUM_WARPS) {
expert_sums[lane_id] = total_expert_sum;
}
}
__syncthreads();

// Add the sums to the token counts to find the start offset of each expert.
if (i < num_experts) {
if (warp_id > 0) {
expert_offsets[i] = expert_sums[warp_id - 1] + expert_offset;
} else {
expert_offsets[i] = expert_offset;
// Compute global offset by adding the running total
if (i < num_experts) {
uint32_t chunk_inclusive_sum = (warp_id > 0 ? expert_sums[warp_id - 1] : 0) + expert_count;
expert_offsets[i] = s_running_total + chunk_inclusive_sum;
}

// Update and carry the running total across chunks
if (threadIdx.x == 0) {
s_running_total += expert_sums[NUM_WARPS - 1];
}
__syncthreads();
}

if (threadIdx.x == 0) {
st_mmio_b8(dispatch_route_done, 1);
}
}
__syncthreads();
Expand Down Expand Up @@ -579,9 +597,7 @@ int a2a_kernels::a2a_dispatch_send(
dim3 dimGrid(num_blocks, 1, 1);
dim3 dimBlock(NUM_THREADS, 1, 1);

// There should be enough warps to do a horizontal reduction across ranks.
assert(world_size <= NUM_THREADS);
assert(num_experts <= NUM_THREADS);

const size_t token_dim = round_up<size_t>(hidden_dim * x_elemsize, sizeof(int4));
const size_t token_scale_dim = round_up<size_t>(hidden_dim_scale * x_scale_elemsize, sizeof(int4));
Expand Down Expand Up @@ -627,7 +643,8 @@ int a2a_kernels::a2a_dispatch_send(
&recv_ptrs,
};

const size_t shared_memory_send = std::max(num_experts, NUM_WARPS) * sizeof(uint32_t);
// Requires tokens_per_expert[num_experts] + expert_sums[NUM_WARPS]
const size_t shared_memory_send = (num_experts + NUM_WARPS) * sizeof(uint32_t);

nvtxRangePush("dispatch_send");
cudaError_t status;
Expand Down Expand Up @@ -674,4 +691,4 @@ int a2a_kernels::a2a_dispatch_send(
});
nvtxRangePop();
return status;
}
}