141 changes: 125 additions & 16 deletions python/pplx_garden/kernels/p2p_all_to_all.py
@@ -1,3 +1,4 @@
import os
import pickle
from dataclasses import dataclass
from typing import Optional, override
@@ -24,7 +25,93 @@
logger = logging_utils.get_logger(__name__)

_PAGE_SIZE = 4096

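# Opt-in debug flag: when set, dispatch() logs the actual recv-buffer usage
# against the allocated capacity (this costs a D2H sync; see dispatch()).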
_CHECK_RECV_BUF_USAGE = os.environ.get("PPLX_CHECK_RECV_BUF_USAGE", "0") == "1"


def compute_max_recv_tokens(
*,
max_num_tokens: int,
num_experts: int,
num_experts_per_token: int,
num_local_experts: int,
num_dp_groups: int,
max_private_tokens: int,
expert_padding: int,
max_recv_tokens_override: Optional[int] = None,
recv_buffer_factor: Optional[float] = None,
) -> int:
"""Compute the recv buffer size for P2PAllToAll.

The recv buffer has two logical parts:
1. Private buffer — tokens from the local rank delivered via NVLink.
Sized by max_private_tokens per DP group.
2. RDMA send/recv buffer — tokens from all ranks delivered via the RDMA
all-to-all.

The worst-case size depends on the routing pattern.

Three sizing strategies are supported (in order of priority):
    - max_recv_tokens_override: explicit override for the recv buffer token
      capacity. Recommended when the routing behavior is well understood;
      it gives the tightest sizing and the best memory efficiency.
- recv_buffer_factor: multiplier on the balanced estimate.
- Default: worst-case upper bound.

Returns:
        The effective max_recv_tokens (recv buffer capacity, in tokens).
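
    Example (illustrative numbers only, assuming the usual ceil_div/round_up
    semantics; not taken from any real deployment). The same knobs are
    exposed as the max_recv_tokens / recv_buffer_factor constructor
    arguments of P2PAllToAll:

        >>> kwargs = dict(
        ...     max_num_tokens=128, num_experts=64, num_experts_per_token=8,
        ...     num_local_experts=8, num_dp_groups=8, max_private_tokens=152,
        ...     expert_padding=16,
        ... )
        >>> compute_max_recv_tokens(**kwargs)  # worst-case default
        9408
        >>> compute_max_recv_tokens(**kwargs, recv_buffer_factor=2.0)
        4480
        >>> compute_max_recv_tokens(**kwargs, max_recv_tokens_override=4096)
        4096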
"""
# Total tokens dispatched across all DP groups.
num_tokens = max_num_tokens * num_dp_groups

    # Worst-case upper bound: assumes maximally imbalanced routing where
    # all tokens land on this rank's experts. This is always safe, but can
    # be many times larger than what is actually required, wasting memory.
default_max_recv_tokens = max_private_tokens * num_dp_groups + round_up(
max(
min(
                # Case 1: every token's top-k experts all land on this rank,
                # plus padding per local expert
num_tokens * num_experts_per_token
+ num_local_experts * (expert_padding - 1),
# Case 2: every token is routed to every local expert on this rank
num_tokens * num_local_experts,
),
            # Floor: each local expert reserves at least expert_padding slots
num_local_experts * expert_padding,
),
expert_padding,
)

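    # Balanced estimate: tokens spread evenly across all experts, plus the
    # private-token budget; recv_buffer_factor scales this up below to
    # accommodate routing imbalance.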
balanced_recv_tokens = (
ceil_div(num_tokens * num_experts_per_token, num_experts)
* num_local_experts
+ max_private_tokens * num_dp_groups
)

if max_recv_tokens_override is not None:
        # Explicit override, clamped to the worst-case bound;
        # allocating more than the worst case is pure waste.
if max_recv_tokens_override > default_max_recv_tokens:
logger.warning(
"max_recv_tokens (%d) exceeds worst-case bound (%d); "
"clamping to worst-case",
max_recv_tokens_override,
default_max_recv_tokens,
)
max_recv_tokens = min(max_recv_tokens_override, default_max_recv_tokens)
elif recv_buffer_factor is not None:
        # Scale the balanced estimate by a factor to accommodate routing
        # imbalance, clamped to the worst-case bound.
max_recv_tokens = min(
round_up(
int(balanced_recv_tokens * recv_buffer_factor),
expert_padding,
),
default_max_recv_tokens,
)
else:
max_recv_tokens = default_max_recv_tokens
return max_recv_tokens

@dataclass
class _RdmaRankData:
@@ -66,6 +153,8 @@ def __init__(
dp_group: Optional[ParallelGroup],
node_group: Optional[ParallelGroup],
global_group: ParallelGroup,
max_recv_tokens: Optional[int] = None,
recv_buffer_factor: Optional[float] = None,
) -> None:
self._hidden_dim = hidden_dim
self._hidden_dim_scale = hidden_dim_scale
@@ -91,28 +180,31 @@ def __init__(
num_dp_groups = world_size // self._dp_size
self._num_local_experts = ceil_div(num_experts, world_size)

# Determine the size of the recv buffers.
# Average tokens per expert assuming perfect load balance with headroom
avg_tokens_per_expert = int(
ceil_div(max_num_tokens * num_experts_per_token, num_experts) * 1.2
)
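        # e.g. max_num_tokens=128, num_experts_per_token=8, num_experts=64:
        # ceil_div(1024, 64) = 16 tokens/expert, * 1.2 headroom -> 19.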

# Default private token budget for the NVLink fast-path.
if max_private_tokens is None:
max_private_tokens = avg_tokens_per_expert * self._num_local_experts
assert max_private_tokens >= 0

num_tokens = max_num_tokens * num_dp_groups
max_recv_tokens = max_private_tokens * num_dp_groups + round_up(
max(
min(
num_tokens * num_experts_per_token
+ self._num_local_experts * (expert_padding - 1),
num_tokens * self._num_local_experts,
),
self._num_local_experts * expert_padding,
),
expert_padding,
# Compute the recv buffer size.
effective_max_recv_tokens = compute_max_recv_tokens(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
num_experts_per_token=num_experts_per_token,
num_local_experts=self._num_local_experts,
num_dp_groups=num_dp_groups,
max_private_tokens=max_private_tokens,
expert_padding=expert_padding,
max_recv_tokens_override=max_recv_tokens,
recv_buffer_factor=recv_buffer_factor,
)

self._max_recv_tokens = effective_max_recv_tokens

self._transfer_engine: Optional[TransferEngine] = None
self._all_to_all: Optional[AllToAllContext] = None

@@ -168,7 +260,7 @@ def __init__(
token_dim = max(token_dim_dispatch, token_dim_combine)

# Allocate a buffer to send data from.
send_buffer_bytes = round_up(max_recv_tokens * token_dim, _PAGE_SIZE)
send_buffer_bytes = round_up(effective_max_recv_tokens * token_dim, _PAGE_SIZE)
self._send_buffer_handle = CUMemAllocHandle(
send_buffer_bytes,
self._device,
@@ -183,7 +275,7 @@
)

# Allocate a buffer to receive into.
recv_buffer_bytes = round_up(max_recv_tokens * token_dim, _PAGE_SIZE)
recv_buffer_bytes = round_up(effective_max_recv_tokens * token_dim, _PAGE_SIZE)
self._recv_buffer_handle = CUMemAllocHandle(
recv_buffer_bytes,
self._device,
@@ -286,7 +378,7 @@ def __init__(
out_dtype=out_dtype,
scale_elemsize=scale_dtype.itemsize if scale_dtype else None,
max_num_tokens=max_num_tokens,
max_recv_tokens=max_recv_tokens,
max_recv_tokens=effective_max_recv_tokens,
max_private_tokens=max_private_tokens,
num_experts=num_experts,
expert_padding=expert_padding,
@@ -439,6 +531,23 @@ def dispatch(
stream=stream,
)

        # Recv buffer usage monitor: enable with PPLX_CHECK_RECV_BUF_USAGE=1
# to log real buffer usage vs the allocated capacity.
# Useful for tuning max_recv_tokens / recv_buffer_factor.
# Requires a D2H sync, so this is off by default.
# Note: actual buffer overflow will panic in the Rust RDMA
# handler before this point; this is for tuning only.
if _CHECK_RECV_BUF_USAGE:
recv_total = out_expert_num_tokens.sum().item()
logger.info(
"P2PAllToAll rank %d recv buffer usage: "
"%d / %d tokens (%.1f%%)",
self._global_group.rank,
recv_total,
self._max_recv_tokens,
recv_total / max(self._max_recv_tokens, 1) * 100,
)
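            # One possible tuning workflow (a suggestion, not project
            # policy): run representative traffic with the flag set, note
            # the peak usage percentage, then resize the buffer via
            # recv_buffer_factor or max_recv_tokens with some headroom.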

@override
def combine(
self,