diff --git a/.gitignore b/.gitignore
index c24ecf89b..506ef150c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -100,4 +100,5 @@ ep/figs
ep/deep_ep_wrapper/deep_ep.egg-info/
*.json
+!.devcontainer/devcontainer.json
*result.jsonl
\ No newline at end of file
diff --git a/ep/.gitignore b/ep/.gitignore
index 12a320ce6..ec02ea881 100644
--- a/ep/.gitignore
+++ b/ep/.gitignore
@@ -1,4 +1,5 @@
benchmark
*.hip
*_hip.*
-*.egg-info
\ No newline at end of file
+*.egg-info
+.lam_dev_state
\ No newline at end of file
diff --git a/ep/Makefile b/ep/Makefile
index 93aa112a6..28c121871 100644
--- a/ep/Makefile
+++ b/ep/Makefile
@@ -55,8 +55,29 @@ LDFLAGS := -lpthread -lglog -libverbs -lnl-3 -lnl-route-3 -lnuma -Xlinker -rpath
NVCCFLAGS := -O3 -std=c++17 -Xcompiler "-Wall -pthread -fPIC -fvisibility=hidden" -ccbin /usr/bin/g++ --expt-relaxed-constexpr
INCLUDES := -Iinclude -I$(CUDA_PATH)/include -I/usr/include -I../include
-CXXFLAGS += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS)
-NVCCFLAGS += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS)
+# LAM_DEV: Enable Lam's development/debug code (set LAM_DEV=1 to enable)
+# Auto-detects changes and only rebuilds affected files
+LAM_DEV ?= 0
+LAM_DEV_STATE_FILE := .lam_dev_state
+LAM_DEV_CURRENT := $(shell cat $(LAM_DEV_STATE_FILE) 2>/dev/null || echo 0)
+# Files affected by LAM_DEV: internode_ll.cu, ep_config.hpp (used by rdma.cpp, uccl_ep.cc)
+LAM_DEV_AFFECTED_OBJS := src/internode_ll.o src/rdma.o src/uccl_ep.o
+
+ifneq ($(LAM_DEV),$(LAM_DEV_CURRENT))
+ $(info LAM_DEV changed from $(LAM_DEV_CURRENT) to $(LAM_DEV), will rebuild affected files)
+ $(shell rm -f $(LAM_DEV_AFFECTED_OBJS))
+ $(shell echo $(LAM_DEV) > $(LAM_DEV_STATE_FILE))
+endif
+
+ifeq ($(LAM_DEV),1)
+ LAM_DEV_FLAGS := -DLAM_DEV
+ $(info LAM_DEV enabled)
+else
+ LAM_DEV_FLAGS :=
+endif
+
+CXXFLAGS += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS) $(LAM_DEV_FLAGS)
+NVCCFLAGS += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS) $(LAM_DEV_FLAGS)
LDFLAGS += $(EFA_LDFLAGS)
INCLUDES += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS)
@@ -106,7 +127,7 @@ install: $(EP_EXT)
# Clean all generated files
clean:
rm -f $(OBJ_CPP) $(OBJ_CC) $(OBJ_CU) \
- $(EP_EXT) \
+ $(EP_EXT) $(LAM_DEV_STATE_FILE) \
*.d src/*.d bench/*.d src/*.o **/*.hip **/*_hip.*
# Automatically include dependency files if they exist
diff --git a/ep/bench/buffer.py b/ep/bench/buffer.py
index b49a528f9..e2e5c65a1 100644
--- a/ep/bench/buffer.py
+++ b/ep/bench/buffer.py
@@ -281,6 +281,24 @@ def low_latency_dispatch(
hidden=x.shape[1],
num_experts=num_experts,
)
+
+ # DEBUG: Print data before CUDA kernel launch
+ if os.environ.get("DEBUG_DISPATCH", "0") == "1":
+ print(f"\n{'='*60}", flush=True)
+ print(f"[DEBUG] low_latency_dispatch - BEFORE CUDA kernel", flush=True)
+ print(f"{'='*60}", flush=True)
+ print(f" rank: {self.rank}", flush=True)
+ print(f" x.shape: {x.shape}, dtype: {x.dtype}", flush=True)
+ print(f" x.data_ptr: {hex(x.data_ptr())}", flush=True)
+ print(f" x[:3, :8]:\n{x[:3, :8]}", flush=True)
+ print(f" topk_idx.shape: {topk_idx.shape}, dtype: {topk_idx.dtype}", flush=True)
+ print(f" topk_idx[:5]:\n{topk_idx[:5]}", flush=True)
+ print(f" num_max_dispatch_tokens_per_rank: {num_max_dispatch_tokens_per_rank}", flush=True)
+ print(f" num_experts: {num_experts}", flush=True)
+ print(f" use_fp8: {use_fp8}, round_scale: {round_scale}, use_ue8m0: {use_ue8m0}", flush=True)
+ print(f" async_finish: {async_finish}, return_recv_hook: {return_recv_hook}", flush=True)
+ print(f"{'='*60}\n", flush=True)
+
(
packed_recv_x,
packed_recv_x_scales,
@@ -309,6 +327,22 @@ def low_latency_dispatch(
x.size(1),
num_experts,
)
+
+ # DEBUG: Print data after CUDA kernel returns
+ if os.environ.get("DEBUG_DISPATCH", "0") == "1":
+ import torch
+ torch.cuda.synchronize() # Wait for kernel to complete
+ print(f"\n{'='*60}", flush=True)
+ print(f"[DEBUG] low_latency_dispatch - AFTER CUDA kernel", flush=True)
+ print(f"{'='*60}", flush=True)
+ print(f" packed_recv_x.shape: {packed_recv_x.shape}, dtype: {packed_recv_x.dtype}", flush=True)
+ print(f" packed_recv_count: {packed_recv_count}", flush=True)
+ print(f" packed_recv_src_info.shape: {packed_recv_src_info.shape}", flush=True)
+ print(f" packed_recv_layout_range.shape: {packed_recv_layout_range.shape}", flush=True)
+ if packed_recv_x_scales is not None:
+ print(f" packed_recv_x_scales.shape: {packed_recv_x_scales.shape}", flush=True)
+ print(f"{'='*60}\n", flush=True)
+
tensors_to_record = (
x,
topk_idx,
diff --git a/ep/bench/launch_distributed_ll.sh b/ep/bench/launch_distributed_ll.sh
new file mode 100755
index 000000000..3c7ed4232
--- /dev/null
+++ b/ep/bench/launch_distributed_ll.sh
@@ -0,0 +1,298 @@
+#!/bin/bash
+#
+# Launch Low Latency Test Script (Distributed)
+#
+# Launches test_low_latency.py across multiple nodes via SSH
+# Supports both Docker and Conda execution modes
+#
+# Usage:
+# ./launch_distributed_ll.sh [--docker | --conda] # Launch the test (default: conda)
+# ./launch_distributed_ll.sh kill # Kill running processes
+#
+# Options:
+# --docker Run inside Docker container (default container: lam_rocm)
+# --conda Run inside conda environment (default: /home/ubuntu/lam/uccl_lam_local)
+#
+# Environment variables:
+# MASTER_PORT - Master port (default: random 29500-30499)
+# CONTAINER - Docker container name (default: lam_rocm, only for --docker)
+# CONDA_ENV - Conda environment path (default: /home/ubuntu/lam/uccl_lam_local, only for --conda)
+# NUM_TOKENS - Number of tokens (default: 128)
+# HIDDEN - Hidden size (default: 7168)
+# NUM_TOPK - Top-k value (default: 8)
+# NUM_EXPERTS - Number of experts (default: 288)
+#
+# This script passes --stop-after-first to test_low_latency.py so only the first
+# experiment (return_recv_hook=False, dispatch_use_fp8=False, ...) is run.
+#
+# Examples:
+# ./launch_distributed_ll.sh --conda
+# ./launch_distributed_ll.sh --docker
+# ./launch_distributed_ll.sh --conda NUM_TOKENS=256
+# ./launch_distributed_ll.sh --docker CONTAINER=my_container kill
+#
+
+# Parse command line options
+RUN_MODE="conda" # default mode
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --docker)
+ RUN_MODE="docker"
+ shift
+ ;;
+ --conda)
+ RUN_MODE="conda"
+ shift
+ ;;
+ *)
+ break
+ ;;
+ esac
+done
+
+# Node IPs (in order of node_rank)
+NODES=(
+ "172.31.5.88" # master (ip-172-31-5-88)
+ "172.31.6.215" # ip-172-31-6-215
+)
+
+# Configuration - MASTER_ADDR automatically set to first node
+MASTER_ADDR="${NODES[0]}"
+MASTER_PORT="${MASTER_PORT:-$((29500 + RANDOM % 1000))}" # Random port 29500-30499 if not specified
+NNODES=2
+NPROC_PER_NODE=8
+
+# Docker container name
+CONTAINER="${CONTAINER:-lam_rocm}"
+
+# Conda environment path
+CONDA_ENV="${CONDA_ENV:-/home/ubuntu/lam/uccl_lam_local}"
+
+# Kill command - kills processes based on mode
+if [ "$1" == "kill" ]; then
+ echo "Scanning for processes on all nodes (mode: ${RUN_MODE})..."
+
+ # Declare associative array to store PIDs per node
+ declare -A NODE_PIDS
+ TOTAL_PIDS=0
+
+ # Stage 1: Collect PIDs from all nodes
+ for node_ip in "${NODES[@]}"; do
+ echo " Checking ${node_ip}..."
+ if [ "${RUN_MODE}" == "docker" ]; then
+ pids=$(ssh -o StrictHostKeyChecking=no "${node_ip}" "docker exec ${CONTAINER} pgrep -f 'python.*torch'" 2>/dev/null | tr '\n' ' ')
+ else
+ pids=$(ssh -o StrictHostKeyChecking=no "${node_ip}" "pgrep -f 'python.*test_low_latency'" 2>/dev/null | tr '\n' ' ')
+ fi
+ pids=$(echo "$pids" | xargs) # trim whitespace
+ if [ -n "$pids" ]; then
+ NODE_PIDS["$node_ip"]="$pids"
+ pid_count=$(echo "$pids" | wc -w)
+ TOTAL_PIDS=$((TOTAL_PIDS + pid_count))
+ echo " Found PIDs: $pids"
+ else
+ echo " No processes found"
+ fi
+ done
+
+ # Check if any PIDs found
+ if [ "$TOTAL_PIDS" -eq 0 ]; then
+ echo ""
+ echo "No processes found on any node."
+ exit 0
+ fi
+
+ # Stage 2: Ask for confirmation
+ echo ""
+ echo "=========================================="
+ echo "Summary: $TOTAL_PIDS process(es) to kill"
+ echo "=========================================="
+ for node_ip in "${!NODE_PIDS[@]}"; do
+ echo " ${node_ip}: ${NODE_PIDS[$node_ip]}"
+ done
+ echo ""
+ read -p "Kill these processes? [y/N]: " confirm
+
+ if [ "$confirm" != "y" ] && [ "$confirm" != "Y" ]; then
+ echo "Aborted."
+ exit 0
+ fi
+
+ # Stage 3: Kill confirmed PIDs
+ echo ""
+ echo "Killing processes..."
+ for node_ip in "${!NODE_PIDS[@]}"; do
+ pids="${NODE_PIDS[$node_ip]}"
+ echo " Killing on ${node_ip}: $pids"
+ if [ "${RUN_MODE}" == "docker" ]; then
+ ssh -o StrictHostKeyChecking=no "${node_ip}" "for pid in $pids; do docker exec ${CONTAINER} kill -9 \$pid 2>/dev/null; done"
+ else
+ ssh -o StrictHostKeyChecking=no "${node_ip}" "for pid in $pids; do kill -9 \$pid 2>/dev/null; done"
+ fi
+ done
+ echo "Done."
+ exit 0
+fi
+
+# Default benchmark parameters
+NUM_TOKENS="${NUM_TOKENS:-128}"
+HIDDEN="${HIDDEN:-7168}"
+NUM_TOPK="${NUM_TOPK:-8}"
+NUM_EXPERTS="${NUM_EXPERTS:-288}"
+
+# Log directory (shared filesystem)
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+LOG_DIR="${SCRIPT_DIR}/logs"
+RUN_LOG_DIR="${LOG_DIR}/latest"
+
+# Create log directory
+mkdir -p "${RUN_LOG_DIR}"
+
+# Clean old logs before starting
+echo "Cleaning old logs in ${RUN_LOG_DIR}..."
+rm -rf "${RUN_LOG_DIR:?}"/*
+
+echo "=============================================="
+echo "Launching distributed test_low_latency.py"
+echo "=============================================="
+echo "Master: ${MASTER_ADDR}:${MASTER_PORT}"
+echo "Nodes: ${NNODES}, GPUs per node: ${NPROC_PER_NODE}"
+echo "Total GPUs: $((NNODES * NPROC_PER_NODE))"
+echo "Mode: ${RUN_MODE}"
+if [ "${RUN_MODE}" == "docker" ]; then
+ echo "Container: ${CONTAINER}"
+else
+ echo "Conda env: ${CONDA_ENV}"
+fi
+echo "Parameters: tokens=${NUM_TOKENS}, hidden=${HIDDEN}, topk=${NUM_TOPK}, experts=${NUM_EXPERTS}"
+echo "Logs: ${RUN_LOG_DIR}/"
+echo " rank{0..$((NNODES * NPROC_PER_NODE - 1))}.log + node{0..$((NNODES - 1))}.log"
+echo "=============================================="
+
+# Function to launch on a node
+launch_node() {
+ local node_rank=$1
+ local node_ip=${NODES[$node_rank]}
+
+ echo "[$(date '+%H:%M:%S')] Launching on node ${node_rank} (${node_ip})..."
+
+    # Build command based on mode
+    if [ "${RUN_MODE}" == "docker" ]; then
+        # Docker mode (NOTE: reconstructed from the conda branch below after
+        # extraction garbling — verify container invocation before use)
+        local docker_cmd="export UCCL_LOG_DIR=${RUN_LOG_DIR} && \
+            cd ${SCRIPT_DIR} && python -m torch.distributed.run \
+            --nnodes=${NNODES} \
+            --nproc_per_node=${NPROC_PER_NODE} \
+            --node_rank=${node_rank} \
+            --master_addr=${MASTER_ADDR} \
+            --master_port=${MASTER_PORT} \
+            ${SCRIPT_DIR}/test_low_latency.py \
+            --num-tokens=${NUM_TOKENS} \
+            --hidden=${HIDDEN} \
+            --num-topk=${NUM_TOPK} \
+            --num-experts=${NUM_EXPERTS} \
+            --stop-after-first"
+        echo "[$(date '+%H:%M:%S')] Command to launch on node ${node_rank}:"
+        echo "${docker_cmd}"
+        ssh -o StrictHostKeyChecking=no \
+            -o ServerAliveInterval=30 \
+            -o ServerAliveCountMax=10 \
+            "${node_ip}" "nohup docker exec ${CONTAINER} bash -c '${docker_cmd}' > /dev/null 2>&1 &"
+    else
+ # Conda mode
+ local conda_cmd="export UCCL_LOG_DIR=${RUN_LOG_DIR} && \
+ cd ${SCRIPT_DIR} && ${CONDA_ENV}/bin/python -m torch.distributed.run \
+ --nnodes=${NNODES} \
+ --nproc_per_node=${NPROC_PER_NODE} \
+ --node_rank=${node_rank} \
+ --master_addr=${MASTER_ADDR} \
+ --master_port=${MASTER_PORT} \
+ ${SCRIPT_DIR}/test_low_latency.py \
+ --num-tokens=${NUM_TOKENS} \
+ --hidden=${HIDDEN} \
+ --num-topk=${NUM_TOPK} \
+ --num-experts=${NUM_EXPERTS} \
+ --stop-after-first"
+ echo "[$(date '+%H:%M:%S')] Command to launch on node ${node_rank}:"
+ echo "${conda_cmd}"
+ ssh -o StrictHostKeyChecking=no \
+ -o ServerAliveInterval=30 \
+ -o ServerAliveCountMax=10 \
+        "${node_ip}" "nohup bash -c '${conda_cmd}' > /dev/null 2>&1 &"
+ fi
+
+ echo "[$(date '+%H:%M:%S')] Node ${node_rank} launched"
+}
+
+# Launch all nodes
+for rank in $(seq 0 $((NNODES - 1))); do
+ launch_node ${rank}
+done
+
+echo ""
+echo "All nodes launched."
+echo "Waiting for completion (checking marker files)..."
+echo ""
+
+TOTAL_GPUS=$((NNODES * NPROC_PER_NODE))
+
+# Poll until all ranks are done (check for .done.{rank} marker files)
+POLL_INTERVAL=3
+while true; do
+ done_count=0
+ for rank in $(seq 0 $((TOTAL_GPUS - 1))); do
+ if [ -f "${RUN_LOG_DIR}/.done.${rank}" ]; then
+ done_count=$((done_count + 1))
+ fi
+ done
+
+ if [ "$done_count" -eq "$TOTAL_GPUS" ]; then
+ break
+ fi
+
+ echo "[$(date '+%H:%M:%S')] Progress: ${done_count}/${TOTAL_GPUS} ranks done"
+ sleep ${POLL_INTERVAL}
+done
+
+# Clean up marker files
+rm -f "${RUN_LOG_DIR}"/.done.*
+
+# Print summary
+echo ""
+echo "=============================================="
+echo "Execution completed"
+echo "=============================================="
+for i in $(seq 0 $((NNODES - 1))); do
+ echo "Node ${i} (${NODES[$i]}): DONE"
+done
+
+# Create per-node combined logs
+echo ""
+echo "Creating per-node combined logs..."
+for node_rank in $(seq 0 $((NNODES - 1))); do
+ node_log="${RUN_LOG_DIR}/node${node_rank}.log"
+ > "${node_log}" # Clear/create file
+ for local_rank in $(seq 0 $((NPROC_PER_NODE - 1))); do
+ global_rank=$((node_rank * NPROC_PER_NODE + local_rank))
+ rank_log="${RUN_LOG_DIR}/rank${global_rank}.log"
+ if [ -f "${rank_log}" ]; then
+ echo "=== rank${global_rank} ===" >> "${node_log}"
+ cat "${rank_log}" >> "${node_log}"
+ echo "" >> "${node_log}"
+ fi
+ done
+done
+
+echo ""
+echo "Logs saved to: ${RUN_LOG_DIR}"
+echo " Per-rank: rank0.log ... rank$((TOTAL_GPUS - 1)).log (${TOTAL_GPUS} files)"
+echo " Per-node: node0.log ... node$((NNODES - 1)).log (${NNODES} files)"
+echo " Total: $((TOTAL_GPUS + NNODES)) files"
diff --git a/ep/bench/plots/dispatch_bytes_per_message_breakdown.svg b/ep/bench/plots/dispatch_bytes_per_message_breakdown.svg
new file mode 100644
index 000000000..8199baa87
--- /dev/null
+++ b/ep/bench/plots/dispatch_bytes_per_message_breakdown.svg
@@ -0,0 +1,2762 @@
+
+
+
diff --git a/ep/bench/plots/plot_dispatch_message_sizes.py b/ep/bench/plots/plot_dispatch_message_sizes.py
new file mode 100644
index 000000000..8cbfc14d9
--- /dev/null
+++ b/ep/bench/plots/plot_dispatch_message_sizes.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+import argparse
+import math
+import re
+from collections import defaultdict
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+MSG_RE = re.compile(
+    r"\[RDMA MSG dispatch\]\s+th=(?P<th>\d+)\s+src=(?P<src>\d+)\s+dst=(?P<dst>\d+)\s+msg_bytes=(?P<msg>\d+)"
+)
+POST_RE = re.compile(
+    r"\[RDMA POST dispatch-only\]\s+th=(?P<th>\d+)\s+src=(?P<src>\d+)\s+dst=(?P<dst>\d+)\s+dispatch_wrs=(?P<wrs>\d+)\s+dispatch_bytes=(?P<bytes>\d+)"
+)
+
+
+def parse_message_sizes(path: Path):
+ data = []
+ source_mode = None
+
+ for line in path.read_text().splitlines():
+ m = MSG_RE.search(line)
+ if m:
+ source_mode = "msg_bytes"
+ data.append((int(m.group("dst")), int(m.group("msg"))))
+
+ if data:
+ return data, source_mode
+
+ # Fallback for older logs: use per-post average when per-message logs are absent.
+ for line in path.read_text().splitlines():
+ m = POST_RE.search(line)
+ if not m:
+ continue
+ wrs = int(m.group("wrs"))
+ if wrs <= 0:
+ continue
+ source_mode = "dispatch_bytes/dispatch_wrs"
+ msg_size = int(round(int(m.group("bytes")) / wrs))
+ data.append((int(m.group("dst")), msg_size))
+
+ return data, source_mode
+
+
+def percentile(values, p):
+ arr = np.asarray(values, dtype=float)
+ return float(np.percentile(arr, p))
+
+
+def format_stats(values):
+ arr = np.asarray(values, dtype=float)
+ return {
+ "n": arr.size,
+ "min": int(arr.min()),
+ "p50": int(percentile(arr, 50)),
+ "p90": int(percentile(arr, 90)),
+ "max": int(arr.max()),
+ "mean": float(arr.mean()),
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Plot RDMA dispatch message-size breakdown from logs."
+ )
+ parser.add_argument(
+ "--w-batch",
+ default="ep/bench/logs/w_batch_data.txt",
+ help="Path to with-batch log file.",
+ )
+ parser.add_argument(
+ "--wo-batch",
+ default="ep/bench/logs/wo_batch_data.txt",
+ help="Path to without-batch log file.",
+ )
+ parser.add_argument(
+ "--out",
+ default="ep/bench/logs/dispatch_message_size_breakdown.png",
+ help="Output figure path.",
+ )
+ parser.add_argument(
+ "--bucket-size",
+ type=int,
+ default=14352,
+ help="Bucket width in bytes for message-size histogram.",
+ )
+ args = parser.parse_args()
+
+ w_path = Path(args.w_batch)
+ wo_path = Path(args.wo_batch)
+ out_path = Path(args.out)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+
+ series = {}
+ modes = {}
+ by_dst = {}
+ for name, path in [("w_batch", w_path), ("wo_batch", wo_path)]:
+ parsed, mode = parse_message_sizes(path)
+ if not parsed:
+ raise ValueError(f"No parseable dispatch entries found in {path}")
+ modes[name] = mode
+ sizes = [v for _, v in parsed]
+ series[name] = sizes
+ dst_map = defaultdict(list)
+ for dst, v in parsed:
+ dst_map[dst].append(v)
+ by_dst[name] = {d: float(np.mean(vs)) for d, vs in sorted(dst_map.items())}
+
+ # Bucketize by fixed byte range and render side-by-side bars.
+ bucket_size = max(1, int(args.bucket_size))
+ max_size = max(max(series["w_batch"]), max(series["wo_batch"]))
+ num_buckets = int(math.ceil(max_size / bucket_size))
+ bins = np.arange(0, (num_buckets + 1) * bucket_size + 1, bucket_size)
+ hist_wo, _ = np.histogram(series["wo_batch"], bins=bins)
+ hist_w, _ = np.histogram(series["w_batch"], bins=bins)
+ bucket_indices = np.arange(len(bins) - 1)
+ bar_width = 0.82
+ bucket_labels = [f"{int(bins[i])//1024}-{int(bins[i+1])//1024}KB" for i in range(len(bins) - 1)]
+
+ fig = plt.figure(figsize=(16, 10), dpi=160, constrained_layout=True)
+ gs = fig.add_gridspec(2, 2, height_ratios=[1, 1.1], hspace=0.35, wspace=0.2)
+ ax_top_left = fig.add_subplot(gs[0, 0])
+ ax_top_right = fig.add_subplot(gs[0, 1])
+ ax_bottom = fig.add_subplot(gs[1, :])
+
+ ax_top_left.bar(
+ bucket_indices,
+ hist_wo,
+ width=bar_width,
+ color="#ff7f0e",
+ alpha=0.85,
+ )
+ ax_top_left.set_title(f"without batch message-size buckets (total WR={len(series['wo_batch'])})")
+ ax_top_left.set_xlabel("message-size bucket")
+ ax_top_left.set_ylabel("count")
+ ax_top_left.set_xticks(bucket_indices)
+ ax_top_left.set_xticklabels(bucket_labels, rotation=30, ha="right")
+ ax_top_left.grid(axis="y", alpha=0.25)
+ ax_top_left.set_xlim(-0.5, len(bucket_labels) - 0.5)
+ ax_top_left.text(
+ 0.98,
+ 0.95,
+ f"Total WR: {len(series['wo_batch'])}",
+ transform=ax_top_left.transAxes,
+ ha="right",
+ va="top",
+ fontsize=14,
+ bbox={"boxstyle": "round,pad=0.25", "facecolor": "white", "alpha": 0.85, "edgecolor": "#cccccc"},
+ )
+
+ ax_top_right.bar(
+ bucket_indices, hist_w, width=bar_width, color="#1f77b4", alpha=0.85
+ )
+ ax_top_right.set_title(f"with batch message-size buckets (total WR={len(series['w_batch'])})")
+ ax_top_right.set_xlabel("message-size bucket")
+ ax_top_right.set_ylabel("count")
+ ax_top_right.set_xticks(bucket_indices)
+ ax_top_right.set_xticklabels(bucket_labels, rotation=30, ha="right")
+ ax_top_right.grid(axis="y", alpha=0.25)
+ ax_top_right.set_xlim(-0.5, len(bucket_labels) - 0.5)
+ ax_top_right.text(
+ 0.98,
+ 0.95,
+ f"Total WR: {len(series['w_batch'])}",
+ transform=ax_top_right.transAxes,
+ ha="right",
+ va="top",
+ fontsize=14,
+ bbox={"boxstyle": "round,pad=0.25", "facecolor": "white", "alpha": 0.85, "edgecolor": "#cccccc"},
+ )
+
+ all_dsts = sorted(set(by_dst["w_batch"]) | set(by_dst["wo_batch"]))
+ x = np.arange(len(all_dsts))
+ width = 0.38
+ vals_w = [by_dst["w_batch"].get(d, np.nan) for d in all_dsts]
+ vals_wo = [by_dst["wo_batch"].get(d, np.nan) for d in all_dsts]
+ ax_bottom.bar(x - width / 2, vals_wo, width=width, label="without_batch", color="#ff7f0e")
+ ax_bottom.bar(x + width / 2, vals_w, width=width, label="with_batch", color="#1f77b4")
+ ax_bottom.set_xticks(x)
+ ax_bottom.set_xticklabels([str(d) for d in all_dsts])
+ ax_bottom.set_title("Average message size by dst rank")
+ ax_bottom.set_xlabel("dst rank")
+ ax_bottom.set_ylabel("avg bytes/message")
+ ax_bottom.grid(axis="y", alpha=0.25)
+ ax_bottom.legend()
+
+ fig.suptitle(
+ f"Dispatch message-size breakdown (without batch WR={len(series['wo_batch'])}, with batch WR={len(series['w_batch'])})"
+ )
+ fig.savefig(out_path)
+
+ print(f"Saved figure: {out_path}")
+ for name in ["w_batch", "wo_batch"]:
+ st = format_stats(series[name])
+ print(
+ f"{name}: mode={modes[name]}, n={st['n']}, min={st['min']}, "
+ f"p50={st['p50']}, p90={st['p90']}, max={st['max']}, mean={st['mean']:.1f}"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ep/bench/plots/w_batch_data.txt b/ep/bench/plots/w_batch_data.txt
new file mode 100644
index 000000000..046331948
--- /dev/null
+++ b/ep/bench/plots/w_batch_data.txt
@@ -0,0 +1,142 @@
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=28704
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=100464
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=71760
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=71760
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=71760
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=57408
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=57408
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=71760
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=28704
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=43056
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=57408
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=71760
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=100464
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=28704
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=28704
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=43056
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=57408
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=57408
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=28704
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=43056
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=71760
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=71760
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=71760
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=71760
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=100464
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=71760
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=86112
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=28704
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=28704
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=86112
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=114816
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=28704
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=71760
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=43056
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=129168
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=43056
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=71760
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=71760
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=71760
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=43056
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=43056
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=43056
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=57408
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=57408
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=28704
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=86112
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=28704
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=71760
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=43056
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=86112
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=28704
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=43056
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=71760
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=43056
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=28704
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=43056
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=71760
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=86112
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=43056
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=43056
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=43056
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=57408
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=57408
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=43056
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=28704
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=57408
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=28704
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=28704
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=114816
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=71760
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=43056
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=28704
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=71760
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=43056
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=43056
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=86112
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=100464
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=100464
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=57408
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=28704
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=43056
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=57408
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=43056
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=43056
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=86112
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=28704
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=86112
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=28704
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=71760
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=28704
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=71760
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=28704
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=86112
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=57408
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=86112
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=71760
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=28704
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=43056
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=43056
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=43056
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=43056
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=86112
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=43056
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=86112
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=57408
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=57408
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=100464
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=71760
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=43056
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=57408
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=57408
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=57408
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=57408
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=28704
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=43056
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=57408
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=28704
\ No newline at end of file
diff --git a/ep/bench/plots/wo_batch_data.txt b/ep/bench/plots/wo_batch_data.txt
new file mode 100644
index 000000000..780992cfe
--- /dev/null
+++ b/ep/bench/plots/wo_batch_data.txt
@@ -0,0 +1,505 @@
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=8 msg_bytes=14352
+[RDMA MSG dispatch] th=0 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=11 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=2 src=0 dst=14 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=3 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=15 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=13 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=12 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=10 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=9 msg_bytes=14352
+[RDMA MSG dispatch] th=1 src=0 dst=14 msg_bytes=14352
\ No newline at end of file
diff --git a/ep/bench/test_low_latency.py b/ep/bench/test_low_latency.py
index 45667be72..31c19da04 100644
--- a/ep/bench/test_low_latency.py
+++ b/ep/bench/test_low_latency.py
@@ -17,18 +17,75 @@
import random
import time
import os
+import sys
import torch
import torch.distributed as dist
import numpy as np
-from functools import partial
from typing import Optional
+
+class TeeWriter:
+ """Write to multiple file handles simultaneously."""
+ def __init__(self, *files):
+ self.files = files
+ self._fileno = files[0].fileno() if files else None
+
+ def write(self, data):
+ for f in self.files:
+ f.write(data)
+ f.flush()
+
+ def flush(self):
+ for f in self.files:
+ f.flush()
+
+ def fileno(self):
+ """Return fileno of the first file (needed for some redirections)."""
+ return self._fileno
+
+
+def setup_rank_logging():
+ """Redirect stdout/stderr to rank{N}.log if UCCL_LOG_DIR is set.
+
+ This does fd-level redirection so C/CUDA code output is also captured.
+ """
+ log_dir = os.environ.get("UCCL_LOG_DIR")
+ if log_dir is None:
+ return
+
+ rank = int(os.environ.get("RANK", 0))
+
+ # Per-rank log file
+ rank_log = os.path.join(log_dir, f"rank{rank}.log")
+ rank_file = open(rank_log, "w", buffering=1)
+
+ # Redirect fd 1 (stdout) and fd 2 (stderr) to the rank log file
+ # This captures C/CUDA printf output as well
+ os.dup2(rank_file.fileno(), 1) # stdout
+ os.dup2(rank_file.fileno(), 2) # stderr
+
+ # Also update Python's sys.stdout/sys.stderr to use the file
+ sys.stdout = rank_file
+ sys.stderr = rank_file
+
+
+def mark_done():
+ """Create a marker file to indicate this rank is done."""
+ log_dir = os.environ.get("UCCL_LOG_DIR")
+ if log_dir is None:
+ return
+ rank = int(os.environ.get("RANK", 0))
+ marker_file = os.path.join(log_dir, f".done.{rank}")
+ open(marker_file, "w").close()
+
+
+# Setup logging as early as possible
+setup_rank_logging()
+
from buffer import Buffer
from utils import (
init_dist,
init_dist_under_torchrun,
- bench,
- bench_kineto,
calc_diff,
hash_tensor,
per_token_cast_back,
@@ -85,9 +142,13 @@ def test_main(
group: dist.ProcessGroup,
buffer: Buffer,
use_logfmt: bool = False,
+ dispatch_use_fp8: bool = True,
seed: int = 0,
skip_benchmark: bool = False,
debug_hash: bool = False,
+ stop_after_first: bool = False,
+ num_warmup: int = 10000,
+ num_repeats: int = 10000,
):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
@@ -150,25 +211,41 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True):
# Preserve the XOR aggregation behavior at per-label granularity.
hash_details[label] = hash_details.get(label, 0) ^ hv
+ stop_now = False
for current_x in x_list:
+ if stop_now:
+ break
for return_recv_hook in (False, True):
- for dispatch_use_fp8 in (False, True):
+ if stop_now:
+ break
+ for dispatch_use_fp8_case in (False, True):
+ if stop_now:
+ break
for round_scale in (False,):
- for round_scale in (False, True) if dispatch_use_fp8 else (False,):
+ if stop_now:
+ break
+ for round_scale in (False, True) if dispatch_use_fp8_case else (False,):
+ if stop_now:
+ break
for use_ue8m0 in (False, True) if round_scale else (False,):
print(
"Start experiment with settings:"
f" return_recv_hook={return_recv_hook}"
- f" dispatch_use_fp8={dispatch_use_fp8}"
+ f" dispatch_use_fp8={dispatch_use_fp8_case}"
f" round_scale={round_scale}"
f" use_ue8m0={use_ue8m0}",
flush=True,
)
num_times += 1
- for i in range((num_times % 2) + 1):
+ running_count = (num_times % 2) + 1
+ # Lam: If stop_after_first, just run once to check the code
+ if stop_after_first:
+ running_count = 1
+ for i in range(running_count):
cumulative_local_expert_recv_stats = torch.zeros(
(num_local_experts,), dtype=torch.int, device="cuda"
)
+ # print(f"[python] dispatch called", flush=True)
(
packed_recv_x,
packed_recv_count,
@@ -180,7 +257,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True):
topk_idx,
num_tokens,
num_experts,
- use_fp8=dispatch_use_fp8,
+ use_fp8=dispatch_use_fp8_case,
round_scale=round_scale,
use_ue8m0=use_ue8m0,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
@@ -194,7 +271,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True):
)
packed_recv_x = (
(packed_recv_x[0], packed_recv_x[1].contiguous())
- if dispatch_use_fp8
+ if dispatch_use_fp8_case
else packed_recv_x
)
simulated_gemm_x = (
@@ -202,7 +279,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True):
packed_recv_x[0].view(-1, hidden),
packed_recv_x[1].view(-1, hidden // 128),
).view(packed_recv_x[0].shape)
- if dispatch_use_fp8
+ if dispatch_use_fp8_case
else packed_recv_x.clone()
)
all_topk_idx = torch.empty(
@@ -219,7 +296,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True):
per_token_cast_back(
packed_recv_x[0][i], packed_recv_x[1][i]
)
- if dispatch_use_fp8
+ if dispatch_use_fp8_case
else packed_recv_x[i]
)
recv_count, recv_src_info, recv_layout_range = (
@@ -284,11 +361,11 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True):
- j
+ rank_offset
).sum().item() == 0
- if dispatch_use_fp8:
+ if dispatch_use_fp8_case:
tag = (
f"x={'x' if current_x is x else 'rand'}"
f"|hook={return_recv_hook}"
- f"|fp8={dispatch_use_fp8}"
+ f"|fp8={dispatch_use_fp8_case}"
f"|rs={round_scale}"
f"|ue={use_ue8m0}"
f"|le={i}"
@@ -306,7 +383,7 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True):
tag = (
f"x={'x' if current_x is x else 'rand'}"
f"|hook={return_recv_hook}"
- f"|fp8={dispatch_use_fp8}"
+ f"|fp8={dispatch_use_fp8_case}"
f"|rs={round_scale}"
f"|ue={use_ue8m0}"
f"|le={i}"
@@ -368,52 +445,34 @@ def _record_hash(label: str, t: torch.Tensor, include_in_overall: bool = True):
)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < (
- 9e-4 if dispatch_use_fp8 else 1e-5
- ), f"Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}"
+ 9e-4 if dispatch_use_fp8_case else 1e-5
+ ), f"Error: {diff=}, {dispatch_use_fp8_case=}, {zero_copy=}"
tag = (
f"x={'x' if current_x is x else 'rand'}"
f"|hook={return_recv_hook}"
- f"|fp8={dispatch_use_fp8}"
+ f"|fp8={dispatch_use_fp8_case}"
f"|rs={round_scale}"
f"|ue={use_ue8m0}"
f"|zc={zero_copy}"
f"|logfmt={use_logfmt}"
)
_record_hash(f"combine_out|{tag}", combined_x)
-
- # noinspection PyShadowingNames
- def large_gemm_with_hook(hook):
- mat_0 = torch.randn((8192, 8192), dtype=torch.float)
- mat_1 = torch.randn((8192, 8192), dtype=torch.float)
- mat_0 @ mat_1
- hook()
-
- # noinspection PyShadowingNames
- def test_func(return_recv_hook: bool):
- recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
- current_x,
- topk_idx,
- num_tokens,
- num_experts,
- cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
- use_fp8=True,
- async_finish=False,
- return_recv_hook=return_recv_hook,
- )
- large_gemm_with_hook(hook) if return_recv_hook else None
- combined_x, event, hook = buffer.low_latency_combine(
- simulated_gemm_x,
- topk_idx,
- topk_weights,
- handle,
- use_logfmt=use_logfmt,
- return_recv_hook=return_recv_hook,
- )
- large_gemm_with_hook(hook) if return_recv_hook else None
+ if stop_after_first:
+ stop_now = True
+ break
+ if stop_now:
+ break
+ if stop_now:
+ break
+ if stop_now:
+ break
+ if stop_now:
+ break
print("✓ All correctness tests passed!", flush=True)
- if skip_benchmark:
+ # Lam: If stop_after_first, just run once to check the code
+ if skip_benchmark or stop_after_first:
return (hash_value, hash_details) if debug_hash else hash_value
# Calculate bandwidth
@@ -422,42 +481,298 @@ def test_func(return_recv_hook: bool):
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
- num_dispatch_comm_bytes += num_fp8_bytes * num_selections
+ num_dispatch_comm_bytes += (
+ num_fp8_bytes if dispatch_use_fp8 else num_bf16_bytes
+ ) * num_selections
num_combine_comm_bytes += (
num_logfmt10_bytes if use_logfmt else num_bf16_bytes
) * num_selections
- # Dispatch + combine testing
- avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
- print(
- f"[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, "
- f"avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us",
- flush=True,
+ # Benchmark with the same timing structure as pplx/benchmarks/bench_all_to_all.py
+ out_dummy = torch.empty((1,), dtype=torch.float32, device="cuda")
+ gemm = torch.empty(
+ (2048, 2048) if num_tokens <= 128 else (8192, 8192),
+ dtype=torch.float32,
+ device="cuda",
)
- # Separate profiling
- for return_recv_hook in (False, True):
- dispatch_t, combine_t = bench_kineto(
- partial(test_func, return_recv_hook=return_recv_hook),
- kernel_names=("dispatch", "combine"),
- barrier_comm_profiling=True,
- suppress_kineto_output=True,
- num_kernels_per_period=2 if return_recv_hook else 1,
+ rng = torch.Generator(device="cuda")
+ rng.manual_seed(rank + seed + 123)
+
+ pending_dispatch_hook = None
+ pending_combine_hook = None
+ pending_recv_x = None
+ pending_handle = None
+ bench_topk_idx = topk_idx
+
+ def wait():
+ # Same "wait" structure as bench_all_to_all.py
+ dist.all_reduce(out_dummy, group=group)
+ _ = gemm @ gemm
+ dist.all_reduce(out_dummy, group=group)
+
+ def _rand_topk_idx() -> torch.Tensor:
+ scores = torch.randn(
+ (num_tokens, num_experts),
+ dtype=torch.float32,
+ device="cuda",
+ generator=rng,
)
- # kineto profiling failed.
- if dispatch_t == 0 or combine_t == 0:
- continue
- if not return_recv_hook:
- print(
- f"[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | "
- f"Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us",
- flush=True,
+ scores = scores.abs() + 1
+ return torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
+
+ def dispatch(do_send: bool, do_recv: bool):
+ nonlocal pending_dispatch_hook
+ nonlocal pending_recv_x
+ nonlocal pending_handle
+ if do_send:
+ recv_x, _, handle, _, hook = buffer.low_latency_dispatch(
+ current_x,
+ bench_topk_idx,
+ num_tokens,
+ num_experts,
+ cumulative_local_expert_recv_stats=None,
+ use_fp8=dispatch_use_fp8,
+ async_finish=False,
+ return_recv_hook=not do_recv,
)
- else:
+ if do_recv:
+ return recv_x, handle
+ pending_dispatch_hook = hook
+ pending_recv_x = recv_x
+ pending_handle = handle
+ return None, None
+ assert do_recv, "Invalid dispatch mode"
+ assert pending_dispatch_hook is not None
+ pending_dispatch_hook()
+ out = (pending_recv_x, pending_handle)
+ pending_dispatch_hook = None
+ pending_recv_x = None
+ pending_handle = None
+ return out
+
+ def materialize_for_combine(recv_x):
+ if dispatch_use_fp8:
+ return per_token_cast_back(
+ recv_x[0].view(-1, hidden),
+ recv_x[1].contiguous().view(-1, hidden // 128),
+ ).view(recv_x[0].shape)
+ return recv_x
+
+ def combine(simulated_x, handle, do_send: bool, do_recv: bool):
+ nonlocal pending_combine_hook
+ if do_send:
+ _, _, hook = buffer.low_latency_combine(
+ simulated_x,
+ bench_topk_idx,
+ topk_weights,
+ handle,
+ use_logfmt=use_logfmt,
+ return_recv_hook=not do_recv,
+ )
+ if not do_recv:
+ pending_combine_hook = hook
+ return
+ assert do_recv, "Invalid combine mode"
+ assert pending_combine_hook is not None
+ pending_combine_hook()
+ pending_combine_hook = None
+
+ events = []
+ for _ in range(num_warmup + num_repeats):
+ dispatch_start = torch.cuda.Event(enable_timing=True)
+ dispatch_end = torch.cuda.Event(enable_timing=True)
+ combine_start = torch.cuda.Event(enable_timing=True)
+ combine_end = torch.cuda.Event(enable_timing=True)
+ dispatch_send_start = torch.cuda.Event(enable_timing=True)
+ dispatch_send_end = torch.cuda.Event(enable_timing=True)
+ dispatch_recv_start = torch.cuda.Event(enable_timing=True)
+ dispatch_recv_end = torch.cuda.Event(enable_timing=True)
+ combine_send_start = torch.cuda.Event(enable_timing=True)
+ combine_send_end = torch.cuda.Event(enable_timing=True)
+ combine_recv_start = torch.cuda.Event(enable_timing=True)
+ combine_recv_end = torch.cuda.Event(enable_timing=True)
+ dispatch_start.record()
+ dispatch_end.record()
+ combine_start.record()
+ combine_end.record()
+ dispatch_send_start.record()
+ dispatch_send_end.record()
+ dispatch_recv_start.record()
+ dispatch_recv_end.record()
+ combine_send_start.record()
+ combine_send_end.record()
+ combine_recv_start.record()
+ combine_recv_end.record()
+ events.append(
+ (
+ dispatch_start,
+ dispatch_end,
+ combine_start,
+ combine_end,
+ dispatch_send_start,
+ dispatch_send_end,
+ dispatch_recv_start,
+ dispatch_recv_end,
+ combine_send_start,
+ combine_send_end,
+ combine_recv_start,
+ combine_recv_end,
+ )
+ )
+
+ last_report_time = time.time()
+ profiler_started = False
+ for i in range(num_warmup + num_repeats):
+ if i + 1 == num_warmup and num_warmup > 0:
+ torch.cuda.profiler.start()
+ profiler_started = True
+ now = time.time()
+ if rank == 0 and (
+ now - last_report_time > 1 or i + 1 == num_warmup + num_repeats
+ ):
print(
- f"[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | "
- f"Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us",
+ f"[rank 0] Iteration {i + 1}/{num_warmup + num_repeats}",
flush=True,
)
+ last_report_time = now
+
+ (
+ dispatch_start,
+ dispatch_end,
+ combine_start,
+ combine_end,
+ dispatch_send_start,
+ dispatch_send_end,
+ dispatch_recv_start,
+ dispatch_recv_end,
+ combine_send_start,
+ combine_send_end,
+ combine_recv_start,
+ combine_recv_end,
+ ) = events[i]
+
+ bench_topk_idx = _rand_topk_idx()
+
+ # Send + recv back-to-back
+ wait()
+ dispatch_start.record()
+ recv_x, handle = dispatch(do_send=True, do_recv=True)
+ dispatch_end.record()
+ simulated_x = materialize_for_combine(recv_x)
+
+ wait()
+ combine_start.record()
+ combine(simulated_x, handle, do_send=True, do_recv=True)
+ combine_end.record()
+
+ # Send and recv split by long kernels
+ wait()
+ dispatch_send_start.record()
+ dispatch(do_send=True, do_recv=False)
+ dispatch_send_end.record()
+
+ wait()
+ dispatch_recv_start.record()
+ recv_x, handle = dispatch(do_send=False, do_recv=True)
+ dispatch_recv_end.record()
+ simulated_x = materialize_for_combine(recv_x)
+
+ wait()
+ combine_send_start.record()
+ combine(simulated_x, handle, do_send=True, do_recv=False)
+ combine_send_end.record()
+
+ wait()
+ combine_recv_start.record()
+ combine(None, None, do_send=False, do_recv=True)
+ combine_recv_end.record()
+
+ torch.cuda.synchronize()
+ if profiler_started:
+ torch.cuda.profiler.stop()
+
+ dispatch_times_us = []
+ dispatch_send_times_us = []
+ dispatch_recv_times_us = []
+ combine_times_us = []
+ combine_send_times_us = []
+ combine_recv_times_us = []
+ for (
+ dispatch_st,
+ dispatch_en,
+ combine_st,
+ combine_en,
+ dispatch_send_st,
+ dispatch_send_en,
+ dispatch_recv_st,
+ dispatch_recv_en,
+ combine_send_st,
+ combine_send_en,
+ combine_recv_st,
+ combine_recv_en,
+ ) in events[num_warmup:]:
+ dispatch_times_us.append(dispatch_st.elapsed_time(dispatch_en) * 1000.0)
+ combine_times_us.append(combine_st.elapsed_time(combine_en) * 1000.0)
+ dispatch_send_times_us.append(
+ dispatch_send_st.elapsed_time(dispatch_send_en) * 1000.0
+ )
+ dispatch_recv_times_us.append(
+ dispatch_recv_st.elapsed_time(dispatch_recv_en) * 1000.0
+ )
+ combine_send_times_us.append(
+ combine_send_st.elapsed_time(combine_send_en) * 1000.0
+ )
+ combine_recv_times_us.append(
+ combine_recv_st.elapsed_time(combine_recv_en) * 1000.0
+ )
+
+ gathered = [None for _ in range(num_ranks)]
+ dist.all_gather_object(gathered, dispatch_times_us, group=group)
+ dispatch_times_us = [v for per_rank in gathered for v in per_rank]
+ dist.all_gather_object(gathered, dispatch_send_times_us, group=group)
+ dispatch_send_times_us = [v for per_rank in gathered for v in per_rank]
+ dist.all_gather_object(gathered, dispatch_recv_times_us, group=group)
+ dispatch_recv_times_us = [v for per_rank in gathered for v in per_rank]
+ dist.all_gather_object(gathered, combine_times_us, group=group)
+ combine_times_us = [v for per_rank in gathered for v in per_rank]
+ dist.all_gather_object(gathered, combine_send_times_us, group=group)
+ combine_send_times_us = [v for per_rank in gathered for v in per_rank]
+ dist.all_gather_object(gathered, combine_recv_times_us, group=group)
+ combine_recv_times_us = [v for per_rank in gathered for v in per_rank]
+
+ def _p50(values):
+ return float(np.percentile(np.asarray(values), 50))
+
+ if rank == 0:
+ dispatch_p50_s = _p50(dispatch_times_us) / 1e6
+ combine_p50_s = _p50(combine_times_us) / 1e6
+ dispatch_bw = num_dispatch_comm_bytes / 1e9 / dispatch_p50_s
+ combine_bw = num_combine_comm_bytes / 1e9 / combine_p50_s
+
+ print(
+ f"[rank 0] Dispatch both p50: {_p50(dispatch_times_us):.2f} us, {dispatch_bw:.2f} GB/s",
+ flush=True,
+ )
+ print(
+ f"[rank 0] Dispatch send p50: {_p50(dispatch_send_times_us):.2f} us",
+ flush=True,
+ )
+ print(
+ f"[rank 0] Dispatch recv p50: {_p50(dispatch_recv_times_us):.2f} us",
+ flush=True,
+ )
+ print(
+ f"[rank 0] Combine both p50: {_p50(combine_times_us):.2f} us, {combine_bw:.2f} GB/s",
+ flush=True,
+ )
+ print(
+ f"[rank 0] Combine send p50: {_p50(combine_send_times_us):.2f} us",
+ flush=True,
+ )
+ print(
+ f"[rank 0] Combine recv p50: {_p50(combine_recv_times_us):.2f} us",
+ flush=True,
+ )
return (hash_value, hash_details) if debug_hash else hash_value
@@ -494,9 +809,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
+ dispatch_use_fp8=args.dispatch_use_fp8,
seed=seed,
skip_benchmark=args.pressure_test_mode == 1,
debug_hash=args.debug_hash,
+ stop_after_first=args.stop_after_first,
+ num_warmup=args.num_warmup,
+ num_repeats=args.num_repeats,
)
if args.debug_hash:
ref_hash, ref_hash_details = ref_out
@@ -521,9 +840,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
+ dispatch_use_fp8=args.dispatch_use_fp8,
seed=seed,
skip_benchmark=args.pressure_test_mode == 1,
debug_hash=args.debug_hash,
+ stop_after_first=args.stop_after_first,
+ num_warmup=args.num_warmup,
+ num_repeats=args.num_repeats,
)
if args.debug_hash:
current_hash, current_hash_details = cur_out
@@ -568,6 +891,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
+
+ # Mark this rank as done
+ mark_done()
if __name__ == "__main__":
@@ -601,6 +927,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
parser.add_argument(
"--use-logfmt", action="store_true", help="Whether to test LogFMT combine"
)
+ parser.add_argument(
+ "--dispatch-use-fp8",
+ type=bool,
+ default=True,
+ action=argparse.BooleanOptionalAction,
+ help="Whether dispatch path uses FP8 casting (default: true).",
+ )
parser.add_argument(
"--pressure-test-mode",
type=int,
@@ -612,6 +945,23 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
action="store_true",
help="Print per-tensor hash breakdown when non-determinism is detected.",
)
+ parser.add_argument(
+ "--stop-after-first",
+ action="store_true",
+ help="Stop after the first experiment (return_recv_hook=False, dispatch_use_fp8=False, round_scale=False, use_ue8m0=False).",
+ )
+ parser.add_argument(
+ "--num-warmup",
+ type=int,
+ default=200,
+ help="Number of warmup iterations for event timing benchmark.",
+ )
+ parser.add_argument(
+ "--num-repeats",
+ type=int,
+ default=500,
+ help="Number of measured iterations for event timing benchmark.",
+ )
args = parser.parse_args()
num_processes = args.num_processes
diff --git a/ep/include/ep_config.hpp b/ep/include/ep_config.hpp
index 451b3d27b..aeb4b2880 100644
--- a/ep/include/ep_config.hpp
+++ b/ep/include/ep_config.hpp
@@ -196,8 +196,25 @@ struct LowLatencyLayout {
size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16);
// Send buffer
+#ifdef LAM_DEV
+ constexpr size_t kNumMaxTopK = 9;
+ // Lam: Buffer layout for batched RDMA sends:
+ // ┌─────────────────────────────────┬──────────────────────────────────────────────┬──────────────────────────────────────────────┐
+ // │ Temp buffer (offset 0) │ Expert batch buffer │ Rank batch buffer │
+ // │ rdma_x[token_idx] │ rdma_x[num_max_tokens + expert*max + slot] │ rdma_x[... + rank*rank_cap + rank_slot] │
+ // │ Size: num_max_tokens * msg_size │ Size: num_experts * num_max_tokens * msg_size│ Size: num_ranks*max*kNumMaxTopK*msg_size │
+ // └─────────────────────────────────┴──────────────────────────────────────────────┴──────────────────────────────────────────────┘
+ // Flow: FP8 cast -> temp -> expert batch (legacy path) + rank batch (new path)
+ size_t rank_batch_capacity_per_rank =
+ num_max_dispatch_tokens_per_rank * kNumMaxTopK;
+ size_t dispatch_send_buffer_bytes =
+ ((1 + num_experts) * num_max_dispatch_tokens_per_rank +
+ num_ranks * rank_batch_capacity_per_rank) *
+ num_bytes_per_dispatch_msg;
+#else
size_t dispatch_send_buffer_bytes =
num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
+#endif
size_t combine_send_buffer_bytes = num_experts *
num_max_dispatch_tokens_per_rank *
num_bytes_per_combine_msg;
diff --git a/ep/include/uccl_ibgda.cuh b/ep/include/uccl_ibgda.cuh
index 41b48d1cd..753a5e63f 100644
--- a/ep/include/uccl_ibgda.cuh
+++ b/ep/include/uccl_ibgda.cuh
@@ -29,7 +29,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
int expert_idx, int lane_id, int message_idx,
uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
bool is_combine, int low_latency_buffer_idx = 0, uint64_t atomic_offset = 0,
- uint64_t atomic_val = 0) {
+ uint64_t atomic_val = 0, int num_tokens = 1) {
// NOTE(MaoZiming): different from the nvshmemi_ibgda_put_nbi_warp in
// ibgda_device.cuh, we don't do warp-cooperation.
if (lane_id != 0) return;
@@ -67,6 +67,10 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
cmd.atomic_val = atomic_val;
} else {
cmd.expert_idx = expert_idx;
+ // Low-latency WRITE: use atomic_val byte for num_tokens (1..255).
+      cmd.atomic_val = (num_tokens <= 0 || num_tokens > 255)
+                           ? 1
+                           : static_cast<uint64_t>(num_tokens);
}
h->atomic_set_and_commit(cmd, &slot);
}
@@ -115,6 +119,10 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp(
cmd.atomic_val = atomic_val;
} else {
cmd.expert_idx = expert_idx;
+ // Low-latency WRITE: use atomic_val byte for num_tokens (1..255).
+      cmd.atomic_val = (num_tokens <= 0 || num_tokens > 255)
+                           ? 1
+                           : static_cast<uint64_t>(num_tokens);
}
h->atomic_set_and_commit(cmd, &slot);
break;
diff --git a/ep/lam/docs/data.md b/ep/lam/docs/data.md
new file mode 100644
index 000000000..db856551f
--- /dev/null
+++ b/ep/lam/docs/data.md
@@ -0,0 +1,21 @@
+# The original EFA performance
+
+[rank 0] Dispatch bandwidth: 2.23 GB/s, avg_t=3373.00 us | Combine bandwidth: 19.58 GB/s, avg_t=742.49 us
+[rank 1] Dispatch bandwidth: 2.14 GB/s, avg_t=3509.00 us | Combine bandwidth: 23.94 GB/s, avg_t=607.13 us
+[rank 2] Dispatch bandwidth: 2.26 GB/s, avg_t=3324.00 us | Combine bandwidth: 18.35 GB/s, avg_t=792.12 us
+[rank 3] Dispatch bandwidth: 2.22 GB/s, avg_t=3381.00 us | Combine bandwidth: 19.78 GB/s, avg_t=734.92 us
+[rank 4] Dispatch bandwidth: 2.17 GB/s, avg_t=3468.00 us | Combine bandwidth: 22.40 GB/s, avg_t=648.84 us
+[rank 5] Dispatch bandwidth: 2.14 GB/s, avg_t=3517.00 us | Combine bandwidth: 24.33 GB/s, avg_t=597.45 us
+[rank 6] Dispatch bandwidth: 2.10 GB/s, avg_t=3575.00 us | Combine bandwidth: 26.78 GB/s, avg_t=542.78 us
+[rank 7] Dispatch bandwidth: 2.09 GB/s, avg_t=3590.00 us | Combine bandwidth: 27.53 GB/s, avg_t=528.00 us
+
+# Lam
+
+[rank 0] Dispatch bandwidth: 2.32 GB/s, avg_t=3233.00 us | Combine bandwidth: 17.23 GB/s, avg_t=843.55 us
+[rank 1] Dispatch bandwidth: 2.19 GB/s, avg_t=3423.00 us | Combine bandwidth: 22.30 GB/s, avg_t=651.91 us
+[rank 2] Dispatch bandwidth: 2.31 GB/s, avg_t=3246.00 us | Combine bandwidth: 17.51 GB/s, avg_t=830.24 us
+[rank 3] Dispatch bandwidth: 2.22 GB/s, avg_t=3377.00 us | Combine bandwidth: 20.75 GB/s, avg_t=700.61 us
+[rank 4] Dispatch bandwidth: 2.22 GB/s, avg_t=3380.00 us | Combine bandwidth: 20.81 GB/s, avg_t=698.59 us
+[rank 5] Dispatch bandwidth: 2.26 GB/s, avg_t=3325.00 us | Combine bandwidth: 19.31 GB/s, avg_t=752.92 us
+[rank 6] Dispatch bandwidth: 2.14 GB/s, avg_t=3509.00 us | Combine bandwidth: 25.64 GB/s, avg_t=567.05 us
+[rank 7] Dispatch bandwidth: 2.12 GB/s, avg_t=3545.00 us | Combine bandwidth: 27.31 GB/s, avg_t=532.25 us
\ No newline at end of file
diff --git a/ep/lam/docs/dispatch_buffer_trace.md b/ep/lam/docs/dispatch_buffer_trace.md
new file mode 100644
index 000000000..c60932e1f
--- /dev/null
+++ b/ep/lam/docs/dispatch_buffer_trace.md
@@ -0,0 +1,150 @@
+# Trace: `x` and `rdma_channel_data.send_buffer` (high-throughput dispatch)
+
+## 1. Trace: `x` (original memory — read source)
+
+### With code lines
+
+| 層級 | 變數 / 來源 | File | Line |
+|------|-------------|------|------|
+| **Python** | 使用者傳入的 `x`(`torch.Tensor`,shape `[num_tokens, hidden]`) | `buffer.py` | 906-907 (參數 `x`) |
+| **Python** | `x, x_scales = x if isinstance(x, tuple) else (x, None)` | `buffer.py` | 935 |
+| **Python** | `self.runtime.internode_dispatch(x, x_scales, ...)` | `buffer.py` | 954, 1007 |
+| **C++** | `Buffer::internode_dispatch(torch::Tensor const& x, ...)` | `uccl_ep.cc` | 401 (宣告), 416 (參數 `x`) |
+| **C++** | 傳給 kernel 的 pointer:`x.data_ptr()` | `uccl_ep.cc` | 697 |
+| **C++ (host)** | `dispatch(..., void const* x, ...)` | `internode.cu` | 1516 (參數 `x`) |
+| **C++ (host)** | launch 時轉型:`reinterpret_cast<int4 const*>(x)` | `internode.cu` | 1565 |
+| **C++ (kernel)** | 參數 `int4 const* x` | `internode.cu` | 482 |
+| **C++ (kernel)** | 讀取位址 `x + token_idx * hidden_int4` | `internode.cu` | 854 |
+
+### In kernel (device)
+
+| Layer | Variable / Expression | File:Line |
+|-------|----------------------|-----------|
+| Dispatch kernel param | `x` (type: `int4 const*`) | `internode.cu:482` |
+| Copy source | `x + token_idx * hidden_int4` | `internode.cu:854` |
+
+### Host launch (C++)
+
+| Layer | Variable / Expression | File:Line |
+|-------|----------------------|-----------|
+| Kernel launch | `reinterpret_cast<int4 const*>(x)` | `internode.cu:1565` |
+| Host `dispatch()` param | `x` (type: `void const*`) | `internode.cu:1516` |
+| Caller | `x.data_ptr()` | `uccl_ep.cc:697` |
+| Caller method | `Buffer::internode_dispatch(torch::Tensor const& x, ...)` | `uccl_ep.cc:401, 416` |
+
+### Python
+
+| Layer | Variable / Expression | File:Line |
+|-------|----------------------|-----------|
+| `Buffer.internode_dispatch` arg | `x` (type: `Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]`) | `buffer.py:906-907` |
+| Unpack if tuple | `x, x_scales = x if isinstance(x, tuple) else (x, None)` | `buffer.py:935` |
+| Passed to C++ | `self.runtime.internode_dispatch(x, x_scales, ...)` | `buffer.py:954, 1007` |
+
+### Chain summary (x)
+
+```
+User / caller
+ → Python: x (torch.Tensor, shape [num_tokens, hidden])
+ → buffer.py: internode_dispatch(self, x, ...) → self.runtime.internode_dispatch(x, ...)
+ → uccl_ep.cc: Buffer::internode_dispatch(torch::Tensor const& x, ...) → x.data_ptr()
+  → internode.cu: dispatch(..., void const* x, ...) → reinterpret_cast<int4 const*>(x)
+ → kernel: int4 const* x → read at x + token_idx * hidden_int4
+```
+
+So **`x`** is the **input activation tensor** from the MoE layer (user-provided), in GPU memory; the kernel reads from it.
+
+---
+
+## 2. Trace: `rdma_channel_data.send_buffer` (new contiguous memory — write destination)
+
+### With code lines
+
+| 層級 | 變數 / 來源 | File | Line |
+|------|-------------|------|------|
+| **Python** | `num_rdma_bytes` 傳入 `Buffer(..., num_rdma_bytes=...)` | `buffer.py` | 59 (參數) |
+| **Python** | `self.scratch = torch.zeros(num_rdma_bytes, dtype=torch.uint8, device=...)` | `buffer.py` | 92-94 |
+| **Python** | (ROCm) `self.scratch = ep.get_rdma_buffer(num_rdma_bytes, device_index)` | `buffer.py` | 96 |
+| **Python** | `rdma_buffer_ptr = self.scratch.data_ptr()` | `buffer.py` | 98 |
+| **Python** | `self.runtime.set_rdma_buffer_raw(rdma_buffer_ptr)` | `buffer.py` | 128 |
+| **C++** | Python binding 呼叫 `self.set_rdma_buffer_raw(addr)` → `Buffer::set_rdma_buffer_raw(void* ptr)` | `uccl_ep.cc` | 2059-2063 |
+| **C++** | `void Buffer::set_rdma_buffer_raw(void* ptr)` → `rdma_buffer_ptr = ptr` | `uccl_ep.cc` | 1865-1869 |
+| **C++** | Member:`void* Buffer::rdma_buffer_ptr` | `uccl_ep.cc` | 1912 |
+| **C++** | `uccl::internode::dispatch(..., rdma_buffer_ptr, ...)` 傳入的參數 | `uccl_ep.cc` | 710 |
+| **C++ (host)** | `void dispatch(..., void* rdma_buffer_ptr, ...)` | `internode.cu` | 1523 (參數) |
+| **C++ (host)** | Kernel launch:`dispatch_func(..., rdma_buffer_ptr, ...)` | `internode.cu` | 1569-1570 |
+| **C++ (kernel)** | 參數 `void* rdma_buffer_ptr` | `internode.cu` | 494 |
+| **C++ (kernel)** | `auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, ...)` | `internode.cu` | 552-554 |
+| **C++ (kernel)** | RDMASender 用:`send_buffer` = `rdma_channel_data.send_buffer(lane_id)` 或 `.recv_buffer(lane_id)` | `internode.cu` | 725-727 |
+| **C++ (kernel)** | Coordinator put 用:`rdma_channel_data.send_buffer(dst_rdma_rank)` | `internode.cu` | 1034 |
+| **C++ (buffer.cuh)** | `SymBuffer::send_ptr` = `gbl_ptr + per_channel_bytes * sm_id` | `buffer.cuh` | 125 |
+| **C++ (buffer.cuh)** | `dtype_t* send_buffer(int idx)` → `send_ptr + num_bytes * idx` | `buffer.cuh` | 132-135 |
+
+### In kernel (device)
+
+| Layer | Variable / Expression | File:Line |
+|-------|----------------------|-----------|
+| Send buffer (per lane) | `send_buffer` = `rdma_channel_data.recv_buffer(lane_id)` or `.send_buffer(lane_id)` | `internode.cu:725-727` |
+| Per-rank send buffer | `rdma_channel_data.send_buffer(dst_rdma_rank)` | `internode.cu:1034` |
+| `rdma_channel_data` | `SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels)` | `internode.cu:552-554` |
+| Kernel param | `rdma_buffer_ptr` (type: `void*`) | `internode.cu:494` |
+
+### SymBuffer: where `send_buffer` points (buffer.cuh)
+
+| Member | Meaning |
+|--------|--------|
+| `SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id, int num_sms)` | Carves a region from `gbl_ptr`; **advances** `gbl_ptr` by `total_bytes`. |
+| `send_ptr` | `gbl_ptr + per_channel_bytes * sm_id` (per-channel offset). |
+| `send_buffer(int idx)` | `send_ptr + num_bytes * idx` → one contiguous region per rank index. |
+
+So **`rdma_channel_data.send_buffer(idx)`** is a **subregion of the memory pointed to by `rdma_buffer_ptr`** when the kernel was launched.
+
+### Host launch (C++)
+
+| Layer | Variable / Expression | File:Line |
+|-------|----------------------|-----------|
+| Kernel launch | `rdma_buffer_ptr` passed to `dispatch(...)` | `internode.cu:1569` |
+| Host `dispatch()` param | `void* rdma_buffer_ptr` | `internode.cu:1523` |
+| Caller | `rdma_buffer_ptr` (no cast) | `uccl_ep.cc:710` |
+| Source | `Buffer` member: `rdma_buffer_ptr` | `uccl_ep.cc:1912` |
+
+### Where `Buffer::rdma_buffer_ptr` is set
+
+| Layer | Variable / Expression | File:Line |
+|-------|----------------------|-----------|
+| Setter | `void Buffer::set_rdma_buffer_raw(void* ptr)` → `rdma_buffer_ptr = ptr` | `uccl_ep.cc:1865-1869` |
+| Called from | Python binding / C++ init code that holds the actual allocation | `uccl_ep.cc:2060-2064` |
+
+### Python
+
+| Layer | Variable / Expression | File:Line |
+|-------|----------------------|-----------|
+| Allocation (CUDA) | `self.scratch = torch.zeros(num_rdma_bytes, dtype=torch.uint8, device=f"cuda:{device_index}")` | `buffer.py:92-94` |
+| Allocation (ROCm) | `self.scratch = ep.get_rdma_buffer(num_rdma_bytes, device_index)` | `buffer.py:96` |
+| Pointer | `rdma_buffer_ptr = self.scratch.data_ptr()` | `buffer.py:98` |
+| Passed to C++ | `initialize_uccl(rdma_buffer_ptr, num_rdma_bytes, ...)` and later `self.runtime.set_rdma_buffer_raw(rdma_buffer_ptr)` | `buffer.py:100, 128` |
+| `num_rdma_bytes` | Passed into `Buffer.__init__(..., num_rdma_bytes=...)`; typically from `config.get_rdma_buffer_size_hint(...)` at call sites | `buffer.py:59` |
+
+### Chain summary (rdma_channel_data.send_buffer)
+
+```
+User / config
+ → Python: num_rdma_bytes (e.g. from config.get_rdma_buffer_size_hint(hidden_bytes, num_ranks))
+ → buffer.py: Buffer(..., num_rdma_bytes) → self.scratch = torch.zeros(num_rdma_bytes, ...) → rdma_buffer_ptr = self.scratch.data_ptr()
+ → buffer.py: self.runtime.set_rdma_buffer_raw(rdma_buffer_ptr)
+ → uccl_ep.cc: Buffer::rdma_buffer_ptr (member)
+ → uccl_ep.cc: internode_dispatch(..., rdma_buffer_ptr, ...) → uccl::internode::dispatch(..., rdma_buffer_ptr, ...)
+ → internode.cu: dispatch(..., void* rdma_buffer_ptr, ...) → kernel param rdma_buffer_ptr
+ → kernel: rdma_channel_data = SymBuffer(rdma_buffer_ptr, ...) → send_ptr/send_buffer(idx) are offsets into that block
+ → rdma_channel_data.send_buffer(dst_rdma_rank) = base of the contiguous send region for that rank
+```
+
+So **`rdma_channel_data.send_buffer`** is the **contiguous send staging buffer** for each destination rank, carved out of the **RDMA buffer** that was allocated in Python (`self.scratch`) and set on the C++ `Buffer` via **`set_rdma_buffer_raw`**.
+
+---
+
+## 3. Quick reference
+
+| What | Origin |
+|------|--------|
+| **x** | User’s input tensor (Python `x` → `x.data_ptr()` → kernel `x`). |
+| **rdma_channel_data.send_buffer** | A region inside the RDMA buffer whose base pointer is set by Python (`scratch.data_ptr()` → `set_rdma_buffer_raw`) and passed into the kernel as `rdma_buffer_ptr`; layout is built by `SymBuffer` in the kernel. |
diff --git a/ep/lam/docs/lam_understanding.md b/ep/lam/docs/lam_understanding.md
new file mode 100644
index 000000000..98dc37cd4
--- /dev/null
+++ b/ep/lam/docs/lam_understanding.md
@@ -0,0 +1,497 @@
+# rdma_recv_x Layout and Usage
+
+`rdma_recv_x` is the receive data buffer used in the low-latency dispatch phase (`ep/src/internode_ll.cu`). Each rank has its **own** buffer in its **own** GPU memory; when rank A sends to an expert on rank B, A does an RDMA put **into B’s** `rdma_recv_x`.
+
+---
+
+## Dispatch send flow (sender side)
+
+- **Input** goes into **`rdma_x`** (send buffer): each token’s hidden is read from input `x`, cast/copied into `rdma_x[token_idx]` (one message per token, `num_bytes_per_msg` each).
+- **Multiple SMs**: the grid runs many blocks (SMs). Tokens are split across SMs with:
+ ```cpp
+ for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms)
+ ```
+ So each SM handles every `num_sms`-th token (e.g. SM 0 → token 0, num_sms, 2*num_sms, …).
+- **Per token, one warp per top-k**: for each token, `num_topk` warps (warp_id 0 .. num_topk-1) each take one top-k destination. Each warp reads `dst_expert_idx = topk_idx[token_idx][warp_id]`, allocates a `slot_idx` for that expert, then puts one message from `rdma_x[token_idx]` to the receiver’s `rdma_recv_x` at the slot given by the layout below.
+- **Write to dst**: the remote address is computed with the **rdma_recv_x layout** (see next section): `dst_ptr = rdma_recv_x + dst_expert_local_idx * ... + rank * ... + slot_idx * num_bytes_per_msg`, and the sender does an RDMA put (or IPC copy) of that one message to `dst_ptr`.
+
+---
+
+## Count send (after token sends)
+
+After all token puts to each expert are done, the sender issues a **count send** for that expert.
+
+- **Not a broadcast to all ranks**: one count is sent **per expert**, to the **rank that owns that expert** (`dst_rank = expert_idx / num_local_experts`). So we send one count to the owner of expert 0, one to the owner of expert 1, etc. Each rank only receives counts for its own local experts (from each source rank).
+- **Content**: “I (this rank) sent you (dst_rank) this many tokens for expert `dst_expert_local_idx`” — i.e. the number of tokens this rank routed to that expert (`num_tokens_sent`), encoded as `-num_tokens_sent - 1` and written into the receiver’s `rdma_recv_count[dst_expert_local_idx][rank]` (or `rdma_recv_count_internode` for cross-node). The receiver polls until non-zero, then decodes `num_tokens = -value - 1` and uses that to know how many messages to read from `rdma_recv_x` for that (expert, source_rank).
+
+---
+
+## Layout
+
+**Logical 3D layout** (on the rank that owns the buffer):
+
+```
+rdma_recv_x[expert_local_idx][source_rank][slot_idx]
+```
+
+| Dimension | Index | Size | Meaning |
+|-----------|--------|------|---------|
+| 0 | `expert_local_idx` | `num_local_experts` | Which local expert on this rank |
+| 1 | `source_rank` | `num_ranks` | Which rank sent the data |
+| 2 | `slot_idx` | `num_max_dispatch_tokens_per_rank` | Which message slot from that source |
+
+- **One slot** = `num_bytes_per_msg` bytes (control `int4` + hidden + scales).
+- **Total size per GPU:** `num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg`.
+
+Declared in `ep/include/ep_config.hpp` as `LowLatencyBuffer::dispatch_rdma_recv_data_buffer`; allocated from `rdma_buffer` in `LowLatencyLayout`, passed to the kernel as `rdma_recv_x`.
+
+---
+
+## Usage
+
+**Sender (writing):** The sender computes the **remote** address on the receiver’s `rdma_recv_x` with:
+
+```cpp
+dst_ptr = rdma_recv_x
+ + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg
+ + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg
+ + slot_idx * num_bytes_per_msg;
+```
+
+So the sender uses `(dst_expert_local_idx, rank, slot_idx)` to pick one slot in the receiver’s buffer and does an RDMA put (or IPC copy) of one message there.
+
+**Receiver (reading):** The receiver waits for the send count for each `(local_expert_idx, src_rank)`, then reads that many messages from the corresponding block:
+
+```cpp
+rdma_recv_x_uint8 = rdma_recv_x
+ + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg
+ + src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+// then read slots 0 .. num_recv_tokens - 1 (each num_bytes_per_msg)
+```
+
+So the receiver uses the same layout: for each of its local experts and each source rank, it reads `num_recv_tokens` contiguous slots from that block and packs them into `packed_recv_x` etc.
+
+---
+
+## Dispatch Receive Flow (receiver side)
+
+### Overview
+
+The receiver side of dispatch:
+1. **Polls** for incoming tokens from each `(local_expert, src_rank)` pair
+2. **Copies** tokens from `rdma_recv_x` (RDMA buffer, may have gaps) to `packed_recv_x` (contiguous output buffer)
+3. Records metadata for the combine phase
+
+### Input / Output
+
+| Buffer | Type | Description |
+|--------|------|-------------|
+| **Input:** `rdma_recv_x` | RDMA recv buffer | Organized by `[local_expert][src_rank][slot]`, may have unused slots |
+| **Input:** `rdma_recv_count` | Atomic counters | Token counts per `(local_expert, src_rank)`, set by senders |
+| **Output:** `packed_recv_x` | Packed buffer | Contiguous tokens per expert, ready for Expert Forward |
+| **Output:** `packed_recv_src_info` | int array | Original token index at source rank (for combine phase) |
+| **Output:** `packed_recv_count` | int array | Total tokens received per local expert |
+| **Output:** `packed_recv_layout_range` | int64 array | `(num_tokens, begin_idx)` per `(local_expert, src_rank)` |
+
+### Parallelization Hierarchy
+
+```
+Block (SM)
+└── Warp Group (handles one (local_expert, src_rank) pair)
+ └── Sub-warp (= 1 warp = 32 threads, handles multiple tokens in strided fashion)
+ └── Lane (= 1 thread, handles part of one token's hidden data)
+```
+
+> **Note:** "Sub-warp" and "warp" are the **same thing** (32 threads). The naming depends on perspective:
+> - From **block** perspective: called "warp" (`warp_id = thread_id / 32`)
+> - From **warp group** perspective: called "sub-warp" (`sub_warp_id = warp_id % num_warps_per_group`)
+>
+> A warp group contains multiple warps; each warp within the group is called a "sub-warp" to indicate it's a sub-unit of the group.
+
+### ASCII Diagram: Block Structure
+
+```
+┌─────────────────────────────────────────────────────────────────────────────┐
+│ Block 0 (SM 0) │
+│ blockDim.x = 1024 threads │
+│ │
+│ ┌─────────────────────────────────────┐ ┌─────────────────────────────────┐│
+│ │ Warp Group 0 │ │ Warp Group 1 ││
+│ │ (handles Expert Pair 0) │ │ (handles Expert Pair 1) ││
+│ │ num_warps_per_group = 4 │ │ num_warps_per_group = 4 ││
+│ │ │ │ ││
+│ │ ┌─────────────────────────────┐ │ │ ┌─────────────────────────────┐││
+│ │ │ Sub-warp 0 (warp_id=0) │ │ │ │ Sub-warp 0 (warp_id=4) │││
+│ │ │ ┌──┬──┬──┬──┬─────┬──┬──┐ │ │ │ │ 32 threads │││
+│ │ │ │T0│T1│T2│T3│ ... │T30│T31│ │ │ │ └─────────────────────────────┘││
+│ │ │ └──┴──┴──┴──┴─────┴──┴──┘ │ │ │ ┌─────────────────────────────┐││
+│ │ │ 32 threads (lanes) │ │ │ │ Sub-warp 1 (warp_id=5) │││
+│ │ └─────────────────────────────┘ │ │ │ 32 threads (waiter+copier) │││
+│ │ ┌─────────────────────────────┐ │ │ └─────────────────────────────┘││
+│ │ │ Sub-warp 1 (warp_id=1) │ │ │ ┌─────────────────────────────┐││
+│ │ │ 32 threads │ │ │ │ Sub-warp 2 (warp_id=6) │││
+│ │ │ ★ waiter + copy tokens │ │ │ │ 32 threads │││
+│ │ └─────────────────────────────┘ │ │ └─────────────────────────────┘││
+│ │ ┌─────────────────────────────┐ │ │ ┌─────────────────────────────┐││
+│ │ │ Sub-warp 2 (warp_id=2) │ │ │ │ Sub-warp 3 (warp_id=7) │││
+│ │ │ 32 threads │ │ │ │ 32 threads │││
+│ │ └─────────────────────────────┘ │ │ └─────────────────────────────┘││
+│ │ ┌─────────────────────────────┐ │ │ ││
+│ │ │ Sub-warp 3 (warp_id=3) │ │ │ ││
+│ │ │ 32 threads │ │ │ ││
+│ │ └─────────────────────────────┘ │ │ ││
+│ └─────────────────────────────────────┘ └─────────────────────────────────┘│
+└─────────────────────────────────────────────────────────────────────────────┘
+
+Index calculations:
+┌────────────────────────────────────────────────────────────────┐
+│ thread_id = threadIdx.x (0 ~ 1023) │
+│ warp_id = thread_id / 32 (0 ~ 31) │
+│ lane_id = thread_id % 32 (0 ~ 31) │
+│ │
+│ warp_group_id = warp_id / num_warps_per_group (which group) │
+│ sub_warp_id = warp_id % num_warps_per_group (which in group) │
+│ │
+│ responsible_expert_idx = sm_id * num_warp_groups + warp_group_id│
+└────────────────────────────────────────────────────────────────┘
+```
+
+### ASCII Diagram: (local_expert, src_rank) Distribution
+
+Each `responsible_expert_idx` maps to one `(local_expert, src_rank)` pair:
+
+```
+responsible_expert_idx ∈ [0, num_experts - 1]
+
+src_rank = responsible_expert_idx / num_local_experts
+local_expert_idx = responsible_expert_idx % num_local_experts
+
+Example: num_local_experts = 9, num_ranks = 32, num_experts = 288
+
+┌───────────────────────────────────────────────────────────────┐
+│ idx │ src_rank │ local_expert │ Meaning │
+├─────┼──────────┼──────────────┼───────────────────────────────┤
+│ 0 │ 0 │ 0 │ Recv from Rank0 for Expert0 │
+│ 1 │ 0 │ 1 │ Recv from Rank0 for Expert1 │
+│ ... │ ... │ ... │ │
+│ 8 │ 0 │ 8 │ Recv from Rank0 for Expert8 │
+│ 9 │ 1 │ 0 │ Recv from Rank1 for Expert0 │
+│ 10 │ 1 │ 1 │ Recv from Rank1 for Expert1 │
+│ ... │ ... │ ... │ │
+│ 287 │ 31 │ 8 │ Recv from Rank31 for Expert8 │
+└───────────────────────────────────────────────────────────────┘
+
+Total 288 pairs = 288 warp groups needed
+```
+
+### ASCII Diagram: Receive Matrix (on one rank)
+
+```
+Rank 0 as Receiver, has 9 local experts (Expert 0~8)
+Needs to receive from 32 src_ranks (including itself)
+
+┌─────────────────────────────────────────────────────────────────────────┐
+│ Rank 0's Receive Matrix │
+│ │
+│ src_rank │
+│ 0 1 2 3 ... 31 │
+│ ┌────┬────┬────┬────┬─────┬────┐ │
+│ 0 │ WG │ WG │ WG │ WG │ ... │ WG │ ← recv from each rank for Exp0 │
+│ ├────┼────┼────┼────┼─────┼────┤ │
+│ 1 │ WG │ WG │ WG │ WG │ ... │ WG │ ← recv from each rank for Exp1 │
+│ l ├────┼────┼────┼────┼─────┼────┤ │
+│ o 2 │ WG │ WG │ WG │ WG │ ... │ WG │ │
+│ c ├────┼────┼────┼────┼─────┼────┤ │
+│ a 3 │ WG │ WG │ WG │ WG │ ... │ WG │ │
+│ l ├────┼────┼────┼────┼─────┼────┤ │
+│ 4 │ WG │ WG │ WG │ WG │ ... │ WG │ │
+│ e ├────┼────┼────┼────┼─────┼────┤ │
+│ x 5 │ WG │ WG │ WG │ WG │ ... │ WG │ │
+│ p ├────┼────┼────┼────┼─────┼────┤ │
+│ e 6 │ WG │ WG │ WG │ WG │ ... │ WG │ │
+│ r ├────┼────┼────┼────┼─────┼────┤ │
+│ t 7 │ WG │ WG │ WG │ WG │ ... │ WG │ │
+│ ├────┼────┼────┼────┼─────┼────┤ │
+│ 8 │ WG │ WG │ WG │ WG │ ... │ WG │ │
+│ └────┴────┴────┴────┴─────┴────┘ │
+│ │
+│ Each cell = 1 Warp Group = 1 (local_expert, src_rank) pair │
+│ Total: 9 × 32 = 288 Warp Groups │
+│ │
+└─────────────────────────────────────────────────────────────────────────┘
+```
+
+### ASCII Diagram: Sub-warp Token Distribution
+
+Within a warp group, sub-warps divide tokens by index:
+
+```
+Warp Group for (local_expert=2, src_rank=5) receives 30 tokens
+num_warps_per_group = 10
+
+Token distribution (for loop: i = sub_warp_id; i < num_recv_tokens; i += 10):
+┌─────────────────────────────────────────────────────────────────────┐
+│ │
+│ Sub-warp 0: processes token 0, 10, 20 │
+│ Sub-warp 1: processes token 1, 11, 21 (also polls for arrival) │
+│ Sub-warp 2: processes token 2, 12, 22 │
+│ Sub-warp 3: processes token 3, 13, 23 │
+│ Sub-warp 4: processes token 4, 14, 24 │
+│ Sub-warp 5: processes token 5, 15, 25 │
+│ Sub-warp 6: processes token 6, 16, 26 │
+│ Sub-warp 7: processes token 7, 17, 27 │
+│ Sub-warp 8: processes token 8, 18, 28 │
+│ Sub-warp 9: processes token 9, 19, 29 │
+│ │
+│ All 10 sub-warps process 10 different tokens in PARALLEL! │
+│ │
+└─────────────────────────────────────────────────────────────────────┘
+```
+
+### ASCII Diagram: Single Token Copy (within one sub-warp)
+
+```
+┌─────────────────────────────────────────────────────────────────────┐
+│ Sub-warp processes 1 Token │
+│ │
+│ Token structure: [src_info: 16B] [hidden: 7168B] [scales: ~56B] │
+│ │
+│ 32 threads copy in parallel: │
+│ ┌────────────────────────────────────────────────────────────┐ │
+│ │ Lane 0: src_info + hidden[0:16] │ │
+│ │ Lane 1: hidden[16:32] │ │
+│ │ Lane 2: hidden[32:48] │ │
+│ │ ... │ │
+│ │ Lane 31: hidden[496:512] │ │
+│ │ │ │
+│ │ (then loop: Lane 0 copies hidden[512:528], ...) │ │
+│ │ │ │
+│ │ UNROLLED_WARP_COPY: 32 lanes copy in parallel │ │
+│ └────────────────────────────────────────────────────────────┘ │
+│ │
+│     7168B / 32 threads / 16B per load = 14 iterations               │
+│ │
+└─────────────────────────────────────────────────────────────────────┘
+```
+
+### Understanding token_idx: Local Index on Sender
+
+**IMPORTANT:** `token_idx` in the recv logs is the **sender's local token index**, NOT a global index.
+
+```
+Each rank has its OWN set of input tokens (e.g., 128 tokens per rank):
+
+┌─────────────────────────────────────────────────────────────────────┐
+│ Rank 0: token 0, 1, 2, ..., 83, ..., 127 (local to Rank 0) │
+│ Rank 1: token 0, 1, 2, ..., 83, ..., 127 (local to Rank 1) │
+│ Rank 2: token 0, 1, 2, ..., 83, ..., 127 (local to Rank 2) │
+│ ... │
+│ Rank 8: token 0, 1, 2, ..., 83, ..., 127 (local to Rank 8) │
+│ │
+│ These are DIFFERENT tokens with the same local index! │
+└─────────────────────────────────────────────────────────────────────┘
+
+Example log interpretation:
+ [RECV] rank=0 expert=156 src_rank=8 slot=3 token_idx=83 ...
+
+ This means:
+ - Rank 0 received a token
+ - The token came from Rank 8
+ - On Rank 8, this token was token #83 (out of Rank 8's 128 tokens)
+ - token_idx=83 is Rank 8's LOCAL index, not global
+```
+
+**Why multiple logs show the same token_idx (e.g., token_idx=83)?**
+
+```
+If you see 11 recv logs all with token_idx=83, they are 11 DIFFERENT tokens:
+ - Rank 0's token #83, Rank 2's token #83, Rank 6's token #83, etc.
+ - They happen to share the same local index on their respective ranks
+ - They are routed to different experts on Rank 0 based on top-k routing
+
+Example (top-k=2, 128 tokens per rank):
+┌─────────────────────────────────────────────────────────────────────┐
+│ Rank 6's token #83: │
+│ → top-k routing sends to Expert 3 (on Rank 0) → expert=111 │
+│ → top-k routing sends to Expert 10 (on Rank 0) → expert=118 │
+│ │
+│ Rank 8's token #83: │
+│ → top-k routing sends to Expert 5 (on Rank 0) → expert=149 │
+│ → top-k routing sends to Expert 12 (on Rank 0) → expert=156 │
+│                                                                     │
+│  These are 4 different recv operations for 2 different tokens       │
+│  (each token sent to 2 experts due to top-k=2)                      │
+│                                                                     │
+│  NOTE: the "expert" value in these logs appears to be               │
+│  responsible_expert_idx = src_rank * num_local_experts +            │
+│  local_expert_idx (here num_local_experts = 18, e.g. 8*18+12 = 156) │
+│  rather than a global expert id — verify against the logging code.  │
+└─────────────────────────────────────────────────────────────────────┘
+```
+
+### Complete Data Flow Example
+
+```
+Rank 5 (Sender) has 128 local tokens (token 0 ~ 127):
+┌─────────────────────────────────────────────────────────────────────┐
+│ Input tokens: T0, T1, T2, T3, T4, T5, ..., T83, ..., T127 │
+│ (these are Rank 5's LOCAL tokens) │
+│ │
+│ Top-k routing results (top-k=2, each token → 2 experts): │
+│ T0 → Expert 2 (on Rank 0), Expert 15 (on Rank 1) │
+│ T1 → Expert 7 (on Rank 0), Expert 22 (on Rank 2) │
+│ T2 → Expert 2 (on Rank 0), Expert 31 (on Rank 3) │
+│ T83 → Expert 4 (on Rank 0), Expert 19 (on Rank 1) │
+│ ... │
+│ │
+│ Sent to Rank 0: T0, T1, T2, T83, ... (with their local indices) │
+└─────────────────────────────────────────────────────────────────────┘
+ │
+ ▼
+Rank 0 (Receiver):
+┌─────────────────────────────────────────────────────────────────────┐
+│ Warp Group for (local_expert=2, src_rank=5) │
+│ │
+│ Received tokens from Rank 5 for Expert 2: │
+│ slot[0] → token_idx=0 (Rank 5's T0) │
+│ slot[1] → token_idx=2 (Rank 5's T2) │
+│ ... │
+│ │
+│ token_idx is preserved so combine phase can send results back! │
+│ │
+│ Sub-warps divide work (assuming 10 sub-warps): │
+│ Sub-warp 0: slot[0], slot[10], ... │
+│ Sub-warp 1: slot[1], slot[11], ... │
+│ ... │
+└─────────────────────────────────────────────────────────────────────┘
+```
+
+**Key formulas:**
+```cpp
+warp_group_id = warp_id / num_warps_per_group;
+sub_warp_id = warp_id % num_warps_per_group;
+responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
+src_rank = responsible_expert_idx / num_local_experts;
+local_expert_idx = responsible_expert_idx % num_local_experts;
+```
+
+**Configuration (computed at launch):**
+```cpp
+num_warp_groups = ceil_div(num_experts, num_device_sms); // warp groups per block
+num_warps_per_group = kNumMaxWarpGroups / num_warp_groups; // sub-warps per warp group
+num_sms = ceil_div(num_experts, num_warp_groups); // blocks to launch
+// kNumMaxWarpGroups = 32 (NVIDIA) or 16 (AMD)
+```
+
+**Example (256 experts, H100 with 132 SMs):**
+- `num_warp_groups = ceil(256/132) = 2` warp groups per block
+- `num_warps_per_group = 32/2 = 16` sub-warps per warp group
+- `num_sms = ceil(256/2) = 128` blocks launched
+- Total: 128 blocks × 2 warp groups = 256 responsible_expert_idx values
+
+### What Each Level Does
+
+| Level | Handles | Description |
+|-------|---------|-------------|
+| **Warp Group** | One `(local_expert, src_rank)` pair | Polls for tokens, then copies all tokens from that src_rank to that expert |
+| **Sub-warp** | Multiple tokens (strided) | Each sub-warp copies tokens `s, s+n, s+2n, ...` where `s = sub_warp_id` and `n = num_warps_per_group` |
+| **Lane** | Part of one token's hidden | 32 lanes cooperatively copy one token's ~7KB hidden data |
+
+### Key Insight: Warp Group = (local_expert, src_rank)
+
+**Each warp group processes exactly one `(local_expert, src_rank)` pair:**
+
+1. **One warp group** is assigned one `responsible_expert_idx` which maps to one `(local_expert, src_rank)` pair
+2. **Many tokens** may arrive from this `src_rank` to this `local_expert` (stored in `num_recv_tokens`)
+3. **All sub-warps in this warp group** cooperatively process these tokens in parallel
+
+```
+Code verification (internode_ll.cu):
+
+Line 527-529: Map responsible_expert_idx → (local_expert, src_rank)
+┌─────────────────────────────────────────────────────────────────────┐
+│ if (responsible_expert_idx < num_experts) { │
+│ src_rank = responsible_expert_idx / num_local_experts; │
+│ local_expert_idx = responsible_expert_idx % num_local_experts; │
+│ } │
+└─────────────────────────────────────────────────────────────────────┘
+
+Line 620: Count tokens from this (local_expert, src_rank) pair
+┌─────────────────────────────────────────────────────────────────────┐
+│ num_recv_tokens = num_recv_tokens_internode + num_recv_tokens_ipc; │
+│ // This is the total tokens from src_rank to local_expert │
+└─────────────────────────────────────────────────────────────────────┘
+
+Line 656: All sub-warps in warp group process these tokens together
+┌─────────────────────────────────────────────────────────────────────┐
+│ for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) │
+│ // sub_warp 0: token 0, 10, 20, ... │
+│ // sub_warp 1: token 1, 11, 21, ... │
+│ // ... │
+│ } │
+└─────────────────────────────────────────────────────────────────────┘
+```
+
+**Summary:**
+- **Warp Group scope:** one `(local_expert, src_rank)` pair
+- **What it receives:** all tokens that `src_rank` sent to `local_expert`
+- **How it processes:** all sub-warps (warps) in the group work together, dividing tokens by index
+
+### recv_token_begin_idx: Dynamic Slot Allocation
+
+Multiple warp groups (different `src_rank`) write to the same expert's output buffer. They use **atomicAdd** for thread-safe slot allocation:
+
+```cpp
+// Line 444-445 in internode_ll.cu
+recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
+```
+
+**How it works:**
+1. `packed_recv_count[local_expert_idx]` is an atomic counter, initialized to 0
+2. `atomicAdd` returns the value **before** adding, and atomically increments the counter
+3. The returned value is the starting index for this warp group's tokens
+
+**Example (Expert 0 receives from 4 src_ranks, in arbitrary order):**
+
+| Order | Warp Group | src_rank | num_tokens | atomicAdd returns | counter after | Write slots |
+|-------|------------|----------|------------|-------------------|---------------|-------------|
+| 1st | B | 2 | 8 | 0 | 8 | [0..7] |
+| 2nd | A | 0 | 5 | 8 | 13 | [8..12] |
+| 3rd | D | 5 | 3 | 13 | 16 | [13..15] |
+| 4th | C | 1 | 4 | 16 | 20 | [16..19] |
+
+**Key points:**
+- **No ordering required**: First-come, first-served. Order depends on network latency and GPU scheduling.
+- **Correctness**: `recv_src_info[i]` records the original token index, so combine phase can send results back correctly.
+- **recv_range**: `recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx)` records where each src_rank's tokens ended up.
+
+### Receive Flow Detail
+
+```
+1. Grid Sync
+ └── Wait for send phase to complete (makes packed_recv_count visible)
+
+2. Compute buffer pointers
+ └── Each warp group calculates its slice of rdma_recv_x based on (local_expert, src_rank)
+
+3. Poll for token count (sub_warp_id == 1, lane_id == 0 only)
+ └── Spin on rdma_recv_count[local_expert][src_rank] until non-zero
+ └── Decode: num_recv_tokens = -value - 1
+
+4. Allocate output slots (sub_warp_id == 1, lane_id == 0 only)
+ └── recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens)
+ └── Store to shared memory for other sub-warps
+
+5. Warp group barrier
+ └── All sub-warps now have num_recv_tokens and recv_token_begin_idx
+
+6. Copy tokens (all sub-warps in parallel)
+ └── for (i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group):
+ ├── Lane 0: copy src_info (4 bytes)
+ ├── All 32 lanes: copy hidden data (~7KB, strided by lane_id)
+ └── All 32 lanes: copy FP8 scales (if applicable)
+```
+
+### Why Pack?
+
+| rdma_recv_x (Source) | packed_recv_x (Destination) |
+|----------------------|-----------------------------|
+| Organized by `[expert][src_rank][slot]` | Organized by `[expert][token]` |
+| Fixed slot size, may have unused slots | Contiguous, no gaps |
+| Scattered across src_ranks | All tokens for one expert together |
+| Not suitable for batched compute | Ready for Expert Forward (batched MLP) |
diff --git a/ep/lam/setup_scripts/setup_lam_conda_local.sh b/ep/lam/setup_scripts/setup_lam_conda_local.sh
new file mode 100755
index 000000000..3760a19fe
--- /dev/null
+++ b/ep/lam/setup_scripts/setup_lam_conda_local.sh
@@ -0,0 +1,17 @@
+#!/usr/bin/env bash
+set -e
+
+LAM_DIR="/home/ubuntu/lam"
+ENV_PATH="${LAM_DIR}/uccl_lam_local"
+
+echo "Creating ${LAM_DIR} if it doesn't exist..."
+mkdir -p "$LAM_DIR"
+
+echo "Creating conda env at ${ENV_PATH} with Python 3.10..."
+conda create -p "$ENV_PATH" python=3.10 -y
+
+echo "Activating environment and starting shell..."
+# shellcheck source=/dev/null
+source "$(conda info --base)/etc/profile.d/conda.sh"
+conda activate "$ENV_PATH"
+exec "$SHELL"
diff --git a/ep/src/internode_ll.cu b/ep/src/internode_ll.cu
index 7cb551bba..16ba7d5ae 100644
--- a/ep/src/internode_ll.cu
+++ b/ep/src/internode_ll.cu
@@ -12,6 +12,15 @@ namespace cg = cooperative_groups;
namespace uccl {
namespace internode_ll {
+#ifdef LAM_DEV
+// Lam: Global lock for debug printing (ensures printf calls don't interleave)
+__device__ int g_print_lock = 0;
+// Lam: Helper macro for conditional kernel arguments
+#define LAM_DEV_ARG(x) x,
+#else
+#define LAM_DEV_ARG(x)
+#endif
+
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
constexpr int kNumMaxWarpGroups = 16;
#else
@@ -53,6 +62,10 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x,
int* rdma_recv_count, void* rdma_x, void const* x, int64_t const* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
+#ifdef LAM_DEV
+ int* atomic_send_counter_per_expert,
+ int* atomic_send_counter_per_rank, int* rank_send_prefix,
+#endif
int* next_clean, int64_t* next_clean_second, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank, int num_topk,
int num_experts, int rank, int num_ranks, int num_warp_groups,
@@ -63,6 +76,11 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
void* atomic_buffer_ptr = nullptr,
int64_t* rdma_recv_count_internode = nullptr,
int* grid_sync_barrier_ptr = nullptr) {
+// #ifdef LAM_DEV
+ // if (blockIdx.x == 0 && threadIdx.x == 0) {
+ // printf("[LAM_DEV] dispatch called\n");
+ // }
+// #endif
auto const sm_id = static_cast<int>(blockIdx.x);
auto const thread_id = static_cast<int>(threadIdx.x);
auto const warp_id = thread_id / WARP_SIZE, lane_id = get_lane_id();
@@ -98,6 +116,13 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
// Expert counts
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
+#ifdef LAM_DEV
+ // Lam: Send slots for each topk destination (for batched send buffer layout)
+ constexpr int kNumMaxTopK = 9;
+ __shared__ int shared_send_slots[kNumMaxTopK];
+ __shared__ int shared_dst_experts[kNumMaxTopK];
+#endif
+
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
// initialize barrier
amd::barrier_init(1);
@@ -136,6 +161,23 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
: -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
+#ifdef LAM_DEV
+ // Lam: Allocate send slots for each topk destination
+ // Each warp (warp_id < num_topk) allocates a slot for its destination expert
+ if (warp_id < num_topk && lane_id == 0) {
+ shared_dst_experts[warp_id] = dst_expert_idx;
+ if (dst_expert_idx >= 0) {
+ shared_send_slots[warp_id] =
+ atomicAdd(atomic_send_counter_per_expert + dst_expert_idx, 1);
+ } else {
+ shared_send_slots[warp_id] = -1;
+ }
+ }
+ // Sync to make shared_send_slots visible to all threads
+ sync_barrier_1((num_warps - 1) * WARP_SIZE);
+
+#endif // LAM_DEV (slot allocation + debug print)
+
// FP8 cast
#pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
@@ -211,6 +253,82 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
dst_rank, max_nvl_peers, 0)
: 0;
if (dst_p2p_ptr == 0) {
+#ifdef LAM_DEV
+ int lam_rank_slot = lane_id == 0
+ ? atomicAdd(atomic_send_counter_per_rank +
+ dst_rank,
+ 1)
+ : 0;
+ lam_rank_slot = __shfl_sync(WARP_MASK, lam_rank_slot, 0);
+ // Lam: IBGDA -> copy temp to rdma_batch_buffer, batch send later
+ auto const lam_slot = shared_send_slots[warp_id];
+ auto const rank_batch_capacity =
+ num_max_dispatch_tokens_per_rank * kNumMaxTopK;
+ EP_DEVICE_ASSERT(lam_rank_slot >= 0);
+ EP_DEVICE_ASSERT(lam_rank_slot < rank_batch_capacity);
+ auto const batch_buf_offset =
+ num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+ auto const rank_batch_buf_offset =
+ (1 + num_experts) * num_max_dispatch_tokens_per_rank *
+ num_bytes_per_msg;
+ auto const batch_buf_ptr =
+          static_cast<uint8_t*>(rdma_x) + batch_buf_offset +
+          (dst_expert_idx * num_max_dispatch_tokens_per_rank + lam_slot) *
+              num_bytes_per_msg;
+      auto const rank_batch_buf_ptr =
+          static_cast<uint8_t*>(rdma_x) + rank_batch_buf_offset +
+ (dst_rank * rank_batch_capacity + lam_rank_slot) *
+ num_bytes_per_msg;
+      auto const* src_int4_ptr = reinterpret_cast<int4 const*>(rdma_x_src_idx);
+      auto* batch_buf_int4_ptr = reinterpret_cast<int4*>(batch_buf_ptr);
+      auto* rank_batch_buf_int4_ptr =
+          reinterpret_cast<int4*>(rank_batch_buf_ptr);
+ UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, batch_buf_int4_ptr,
+ src_int4_ptr, ld_nc_global, st_na_global);
+ // New path: write into per-dst-rank contiguous staging buffer.
+ UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, rank_batch_buf_int4_ptr,
+ src_int4_ptr, ld_nc_global, st_na_global);
+ if (lane_id == 0) {
+ // Populate reserved header fields for future rank-batched recv
+ // demux:
+ // [0] src_token_idx (already written),
+ // [1] dst_local_expert_idx,
+ // [2] dst_rank,
+ // [3] rank_slot.
+        auto* batch_msg_header = reinterpret_cast<int*>(batch_buf_ptr);
+        batch_msg_header[1] = dst_expert_local_idx;
+        batch_msg_header[2] = dst_rank;
+        batch_msg_header[3] = lam_rank_slot;
+        auto* rank_batch_msg_header = reinterpret_cast<int*>(rank_batch_buf_ptr);
+ rank_batch_msg_header[1] = dst_expert_local_idx;
+ rank_batch_msg_header[2] = dst_rank;
+ rank_batch_msg_header[3] = lam_rank_slot;
+
+ // Full mapping print (rank0 only): print every inter-node mapped
+ // entry at write time, so we can inspect exact slot assignments.
+ if (rank == 0) {
+ while (atomicCAS(&g_print_lock, 0, 1) != 0)
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
+ __builtin_amdgcn_s_sleep(1);
+#else
+ ;
+#endif
+ printf(
+ "[LAM_DEV][dispatch] write src_idx=%d dst_expert=%d "
+ "dst_local=%d dst_rank=%d expert_slot=%d local_rank_slot=%d "
+ "expert_buf_index=%d rank_buf_index=%d "
+ "dst_rank_expert_begin=%d dst_rank_expert_end=%d\n",
+ batch_msg_header[0], dst_expert_idx, batch_msg_header[1],
+ batch_msg_header[2], lam_slot, batch_msg_header[3],
+ dst_expert_idx * num_max_dispatch_tokens_per_rank + lam_slot,
+ dst_rank * rank_batch_capacity + lam_rank_slot,
+ dst_rank * num_local_experts,
+ min((dst_rank + 1) * num_local_experts, num_experts) - 1);
+ __threadfence_system();
+ atomicExch(&g_print_lock, 0);
+ }
+ }
+#else
__threadfence_system();
uccl::nvshmemi_ibgda_put_nbi_warp(
dst_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
@@ -220,6 +338,7 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
// rb.
lane_id, slot_idx, d2h_channel_addrs, num_d2h_channel_addrs,
false, low_latency_buffer_idx);
+#endif
} else {
// Intra-node: use direct memory copy via IPC
auto const* src_int4_ptr = reinterpret_cast<int4 const*>(src_ptr);
@@ -278,6 +397,117 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
}
}
__syncthreads();
+
+#ifdef LAM_DEV
+ // Lam: Grid-wide sync before batch send phase.
+ // __syncthreads() only syncs within a single thread block (SM).
+ // The token loop distributes tokens round-robin across SMs
+ // (token_idx = sm_id, stepping by num_sms). When num_tokens < num_sms,
+ // most SMs skip the token loop and pass __syncthreads() immediately,
+ // while the SMs processing tokens are still writing to the batch buffer
+ // and incrementing atomic_send_counter_per_expert.
+ // Without grid sync, the batch send phase can read a stale/partial
+ // counter and send fewer tokens than actually produced, causing the
+ // receiver to hang waiting for data that never arrives.
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
+ amd::grid_sync(grid_sync_barrier_ptr, num_sms);
+#else
+ cg::this_grid().sync();
+#endif
+
+ // Lam: Batch RDMA send phase - send one rank-contiguous buffer per dst rank.
+ // Each warp group handles one destination rank (only first sub_warp sends).
+ if (responsible_expert_idx < num_ranks && sub_warp_id == 0) {
+ auto const dst_rank = responsible_expert_idx;
+ auto const rank_batch_capacity =
+ num_max_dispatch_tokens_per_rank * kNumMaxTopK;
+
+ // Check if this destination is inter-node (needs IBGDA batch send)
+ // IPC destinations were already sent in the token loop.
+    auto const test_dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x);
+ auto const dst_p2p_ptr =
+ ipc_rdma_base_ptrs
+ ? uccl::get_ipc_p2p_ptr(test_dst_ptr, ipc_rdma_base_ptrs, rank,
+ dst_rank, max_nvl_peers, 0)
+ : 0;
+
+ if (dst_p2p_ptr == 0) {
+ // Inter-node: batch send ALL tokens for this destination rank in ONE call.
+ auto const num_tokens_to_send = atomic_send_counter_per_rank[dst_rank];
+ if (num_tokens_to_send > 0) {
+ auto const rank_batch_buf_offset =
+ (1 + num_experts) * num_max_dispatch_tokens_per_rank *
+ num_bytes_per_msg;
+ // Source: start of this dst-rank's rank-batch buffer (contiguous).
+ auto const rank_batch_buf_ptr =
+            static_cast<uint8_t*>(rdma_x) + rank_batch_buf_offset +
+            dst_rank * rank_batch_capacity * num_bytes_per_msg;
+        auto const src_ptr = reinterpret_cast<uint64_t>(rank_batch_buf_ptr);
+ auto const rank_batch_recv_offset =
+ num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+ // Destination: start of this src-rank's slot in remote rank-batch area.
+ auto const dst_ptr =
+            reinterpret_cast<uint64_t>(rdma_recv_x) +
+ rank_batch_recv_offset +
+ rank * rank_batch_capacity * num_bytes_per_msg;
+ auto const total_bytes = num_tokens_to_send * num_bytes_per_msg;
+
+#ifdef LAM_DEV
+ if (rank == 0 && lane_id == 0) {
+ while (atomicCAS(&g_print_lock, 0, 1) != 0)
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
+ __builtin_amdgcn_s_sleep(1);
+#else
+ ;
+#endif
+ printf(
+ "[LAM_DEV][dispatch] send_ready src_rank=%d dst_rank=%d "
+ "tokens=%d bytes=%llu rank_capacity=%d src_rank_base_idx=%d\n",
+ rank, dst_rank, num_tokens_to_send,
+              static_cast<unsigned long long>(total_bytes), rank_batch_capacity,
+ dst_rank * rank_batch_capacity);
+ __threadfence_system();
+ atomicExch(&g_print_lock, 0);
+ }
+#endif
+
+ __threadfence_system();
+
+ uccl::nvshmemi_ibgda_put_nbi_warp(
+            dst_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+            src_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+ total_bytes, dst_rank,
+ /*warp_id=*/
+ dst_rank % ((num_local_experts > 0) ? num_local_experts : 1),
+ lane_id, /*slot=*/0, d2h_channel_addrs, num_d2h_channel_addrs,
+ false, low_latency_buffer_idx, 0, 0, num_tokens_to_send);
+
+#ifdef LAM_DEV
+ if (rank == 0 && lane_id == 0) {
+ while (atomicCAS(&g_print_lock, 0, 1) != 0)
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
+ __builtin_amdgcn_s_sleep(1);
+#else
+ ;
+#endif
+ printf(
+ "[LAM_DEV][dispatch] send_done src_rank=%d dst_rank=%d "
+ "tokens=%d bytes=%llu\n",
+ rank, dst_rank, num_tokens_to_send,
+              static_cast<unsigned long long>(total_bytes));
+ __threadfence_system();
+ atomicExch(&g_print_lock, 0);
+ }
+#endif
+ }
+ }
+ // IPC: already sent in the token loop, nothing to do here.
+ }
+
+ __threadfence_system(); // Ensure batch sends are visible before count sends
+
+#endif // LAM_DEV batch send
+
// Issue count sends
if (responsible_expert_idx < num_experts and sub_warp_id == 0 and
lane_id == 0) {
@@ -319,10 +549,18 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
st_release_sys_global(reinterpret_cast(dst_p2p_ptr),
-num_tokens_sent - 1);
}
+
// Clean workspace for next use
atomic_counter_per_expert[responsible_expert_idx] = 0;
atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
+#ifdef LAM_DEV
+ atomic_send_counter_per_expert[responsible_expert_idx] = 0;
+ if (responsible_expert_idx < num_ranks) {
+ atomic_send_counter_per_rank[responsible_expert_idx] = 0;
+ rank_send_prefix[responsible_expert_idx] = 0;
+ }
+#endif
// Clean `packed_recv_count`
if (dst_rank == 0) packed_recv_count[dst_expert_local_idx] = 0;
}
@@ -430,13 +668,6 @@ LOW_LATENCY_DISPATCH_RECV:
num_recv_tokens_internode != 0 ? -num_recv_tokens_internode - 1 : 0;
num_recv_tokens_ipc =
num_recv_tokens_ipc != 0 ? -num_recv_tokens_ipc - 1 : 0;
- // printf(
- // "num_recv_tokens_internode: %d, num_recv_tokens_ipc: %d, src_rank:"
- // "%d, rank: %d, max_nvl_peers: %d, responsible_expert_idx: %d,"
- // "num_experts: %d, num_local_experts: %d\n",
- // num_recv_tokens_internode, num_recv_tokens_ipc, src_rank, rank,
- // max_nvl_peers, responsible_expert_idx, num_experts,
- // num_local_experts);
num_recv_tokens = num_recv_tokens_internode + num_recv_tokens_ipc;
recv_token_begin_idx =
atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
@@ -464,10 +695,40 @@ LOW_LATENCY_DISPATCH_RECV:
// Copy tokens
EP_DEVICE_ASSERT(num_scales <= 64);
+ auto const rank_batch_capacity =
+ num_max_dispatch_tokens_per_rank * kNumMaxTopK;
+ auto const rank_batch_recv_offset =
+ num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
+ auto const rank_batch_recv_x_uint8 =
+ static_cast(rdma_recv_x) +
+ rank_batch_recv_offset +
+ src_rank * rank_batch_capacity * num_bytes_per_msg;
+ int inter_rank_scan_cursor = 0;
for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
// Copy source info
- auto const src_src_idx =
- reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg);
+ int matched_rank_slot = i;
+ if (src_rank / max_nvl_peers != rank / max_nvl_peers) {
+ if (lane_id == 0) {
+ matched_rank_slot = -1;
+ for (int s = inter_rank_scan_cursor; s < rank_batch_capacity; ++s) {
+ auto const* hdr = reinterpret_cast(
+ rank_batch_recv_x_uint8 + s * num_bytes_per_msg);
+ auto const dst_local_expert = ld_nc_global(hdr + 1);
+ if (dst_local_expert == local_expert_idx) {
+ matched_rank_slot = s;
+ inter_rank_scan_cursor = s + 1;
+ break;
+ }
+ }
+ EP_DEVICE_ASSERT(matched_rank_slot >= 0);
+ }
+ matched_rank_slot = __shfl_sync(WARP_MASK, matched_rank_slot, 0);
+ }
+ auto const src_src_idx = reinterpret_cast(
+ (src_rank / max_nvl_peers == rank / max_nvl_peers)
+ ? (rdma_recv_x_uint8 + i * num_bytes_per_msg)
+ : (rank_batch_recv_x_uint8 +
+ matched_rank_slot * num_bytes_per_msg));
if (lane_id == 0)
recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
__syncwarp();
@@ -512,8 +773,6 @@ LOW_LATENCY_DISPATCH_RECV:
}
}
}
- // if (blockIdx.x == 0 && threadIdx.x == 0)
- // printf("[dispatch] RECV finished\n");
}
}
@@ -546,8 +805,19 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto atomic_counter_per_expert = static_cast(workspace);
auto atomic_finish_counter_per_expert =
atomic_counter_per_expert + num_experts;
+#ifdef LAM_DEV
+ auto atomic_send_counter_per_expert =
+ atomic_finish_counter_per_expert + num_experts;
+ auto atomic_send_counter_per_rank =
+ atomic_send_counter_per_expert + num_experts;
+ auto rank_send_prefix = atomic_send_counter_per_rank + num_ranks;
+ auto grid_sync_barrier_ptr = rank_send_prefix + num_ranks;
+ EP_HOST_ASSERT((num_experts * 3 + num_ranks * 2 + 1) * sizeof(int) <=
+ NUM_WORKSPACE_BYTES);
+#else
auto grid_sync_barrier_ptr = atomic_finish_counter_per_expert + num_experts;
EP_HOST_ASSERT((num_experts * 2 + 1) * sizeof(int) <= NUM_WORKSPACE_BYTES);
+#endif
// FP8 checks
if (use_ue8m0)
@@ -565,6 +835,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
cumulative_local_expert_recv_stats, dispatch_wait_recv_cost_stats, \
rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
+ LAM_DEV_ARG(atomic_send_counter_per_expert) \
+ LAM_DEV_ARG(atomic_send_counter_per_rank) \
+ LAM_DEV_ARG(rank_send_prefix) \
next_clean, next_clean_second, num_next_clean_int, num_tokens, \
num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, \
num_ranks, num_warp_groups, num_warps_per_group, round_scale, phases, \
@@ -936,8 +1209,6 @@ __global__ __launch_bounds__(1024, 1) void combine(
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0) {
- // if (blockIdx.x == 0 && threadIdx.x == 0)
- // printf("[combine] SEND finished\n");
return;
}
// Wait all ranks to arrive
@@ -1050,8 +1321,6 @@ LOW_LATENCY_COMBINE_RECV:
token_idx * hidden_bf16_int4)[hidden_idx] = combined_int4;
}
- // if (blockIdx.x == 0 && threadIdx.x == 0)
- // printf("[combine] RECV finished\n");
}
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
@@ -1093,7 +1362,6 @@ void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag,
constexpr int kNumTMABytesPerWarp = 12 * (512 + 16);
int const smem_size = kNumTMABytesPerWarp * num_warps;
- // printf("Combine launched\n");
#define COMBINE_LAUNCH_CASE(hidden) \
{ \
diff --git a/ep/src/rdma.cpp b/ep/src/rdma.cpp
index 8874509ef..679481305 100644
--- a/ep/src/rdma.cpp
+++ b/ep/src/rdma.cpp
@@ -38,6 +38,10 @@
#include
#include
+#ifndef UCCL_RDMA_POST_DEBUG
+#define UCCL_RDMA_POST_DEBUG 0
+#endif
+
void recv_connection_info_as_server(int my_rank, int* actual_peer,
int listen_fd,
RDMAConnectionInfo* remote_array) {
@@ -732,6 +736,67 @@ void post_receive_buffer_for_imm(ProxyCtx& S) {
}
}
+static inline void log_dispatch_rdma_post(int thread_idx, int src_rank,
+                                          int dst_rank, size_t dispatch_wrs,
+                                          uint64_t dispatch_bytes,
+                                          size_t total_wrs,
+                                          uint64_t total_bytes) {
+#if UCCL_RDMA_POST_DEBUG
+  // Debug trace for one ibv post: logs the dispatch (non-combine) portion of
+  // the posted WR batch and accumulates process-wide totals that are printed
+  // once at exit.  Entirely compiled out (parameters voided) unless
+  // UCCL_RDMA_POST_DEBUG is enabled at build time.
+  if (dispatch_wrs == 0) return;  // this post carried no dispatch WRs
+
+  // NOTE(review): template arguments restored — the original text had bare
+  // `std::atomic`, which does not compile.  uint64_t matches the load sites
+  // below and the %llu-style printing.  Requires <atomic>, <mutex> and
+  // <cstdlib> in this translation unit.
+  static std::atomic<uint64_t> g_dispatch_post_count{0};
+  static std::atomic<uint64_t> g_dispatch_wr_count{0};
+  static std::atomic<uint64_t> g_dispatch_bytes_sum{0};
+  static std::once_flag g_dispatch_summary_once;
+  // Register the atexit summary hook exactly once, thread-safely.  The
+  // lambdas are capture-free, so the inner one converts to the plain
+  // function pointer std::atexit expects.
+  std::call_once(g_dispatch_summary_once, []() {
+    std::atexit([]() {
+      uint64_t post_count = g_dispatch_post_count.load(std::memory_order_relaxed);
+      uint64_t wr_count = g_dispatch_wr_count.load(std::memory_order_relaxed);
+      uint64_t bytes_sum = g_dispatch_bytes_sum.load(std::memory_order_relaxed);
+      if (post_count == 0) return;
+      fprintf(stderr,
+              "[RDMA POST dispatch summary] posts=%llu dispatch_wrs=%llu "
+              "dispatch_bytes=%llu\n",
+              (unsigned long long)post_count, (unsigned long long)wr_count,
+              (unsigned long long)bytes_sum);
+    });
+  });
+
+  // Relaxed ordering is sufficient: these are statistics counters with no
+  // synchronizing role.
+  g_dispatch_post_count.fetch_add(1, std::memory_order_relaxed);
+  g_dispatch_wr_count.fetch_add(dispatch_wrs, std::memory_order_relaxed);
+  g_dispatch_bytes_sum.fetch_add(dispatch_bytes, std::memory_order_relaxed);
+  fprintf(stderr,
+          "[RDMA POST dispatch-only] th=%d src=%d dst=%d dispatch_wrs=%zu "
+          "dispatch_bytes=%llu total_wrs=%zu total_bytes=%llu\n",
+          thread_idx, src_rank, dst_rank, dispatch_wrs,
+          (unsigned long long)dispatch_bytes, total_wrs,
+          (unsigned long long)total_bytes);
+#else
+  (void)thread_idx;
+  (void)src_rank;
+  (void)dst_rank;
+  (void)dispatch_wrs;
+  (void)dispatch_bytes;
+  (void)total_wrs;
+  (void)total_bytes;
+#endif
+}
+
+static inline void log_dispatch_rdma_message_size(int thread_idx, int src_rank,
+                                                  int dst_rank,
+                                                  uint32_t msg_bytes) {
+  // Per-WR debug trace for dispatch-path RDMA writes: which proxy thread
+  // posted how many bytes from src_rank to dst_rank.  No-op build unless
+  // UCCL_RDMA_POST_DEBUG is enabled.
+#if UCCL_RDMA_POST_DEBUG
+  fprintf(stderr, "[RDMA MSG dispatch] th=%d src=%d dst=%d msg_bytes=%u\n",
+          thread_idx, src_rank, dst_rank, msg_bytes);
+#else
+  (void)thread_idx; (void)src_rank;  // silence -Wunused-parameter
+  (void)dst_rank; (void)msg_bytes;
+#endif
+}
+
// Normal mode implementation
static void post_rdma_async_batched_normal_mode(
ProxyCtx& S, void* buf, size_t num_wrs,
@@ -873,6 +938,22 @@ static void post_rdma_async_batched_normal_mode(
ring_wrids.push_back(wrs_to_post[i]);
}
+#if UCCL_RDMA_POST_DEBUG
+ uint64_t total_bytes = 0;
+ uint64_t dispatch_bytes = 0;
+ size_t dispatch_wrs = 0;
+ for (size_t i : idxs) {
+ total_bytes += cmds_to_post[i].bytes;
+ if (!get_is_combine(cmds_to_post[i].cmd_type)) {
+ dispatch_bytes += cmds_to_post[i].bytes;
+ ++dispatch_wrs;
+ log_dispatch_rdma_message_size(thread_idx, my_rank, dst_rank,
+ cmds_to_post[i].bytes);
+ }
+ }
+ log_dispatch_rdma_post(thread_idx, my_rank, dst_rank, dispatch_wrs,
+ dispatch_bytes, idxs.size(), total_bytes);
+#endif
int ret = ibv_wr_complete(qpx);
if (ret) {
fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n",
@@ -974,6 +1055,22 @@ static void post_rdma_async_batched_normal_mode(
}
// Post the chain
+#if UCCL_RDMA_POST_DEBUG
+ uint64_t total_bytes = 0;
+ uint64_t dispatch_bytes = 0;
+ size_t dispatch_wrs = 0;
+ for (size_t i : idxs) {
+ total_bytes += cmds_to_post[i].bytes;
+ if (!get_is_combine(cmds_to_post[i].cmd_type)) {
+ dispatch_bytes += cmds_to_post[i].bytes;
+ ++dispatch_wrs;
+ log_dispatch_rdma_message_size(thread_idx, my_rank, dst_rank,
+ cmds_to_post[i].bytes);
+ }
+ }
+ log_dispatch_rdma_post(thread_idx, my_rank, dst_rank, dispatch_wrs,
+ dispatch_bytes, kgroup, total_bytes);
+#endif
ibv_send_wr* bad = nullptr;
int ret = ibv_post_send(qp, &wrs[0], &bad);
if (ret) {
@@ -1083,6 +1180,22 @@ static void post_rdma_async_batched_normal_mode(
}
// Post the chain
+#if UCCL_RDMA_POST_DEBUG
+ uint64_t total_bytes = 0;
+ uint64_t dispatch_bytes = 0;
+ size_t dispatch_wrs = 0;
+ for (size_t i : idxs) {
+ total_bytes += cmds_to_post[i].bytes;
+ if (!get_is_combine(cmds_to_post[i].cmd_type)) {
+ dispatch_bytes += cmds_to_post[i].bytes;
+ ++dispatch_wrs;
+ log_dispatch_rdma_message_size(thread_idx, my_rank, dst_rank,
+ cmds_to_post[i].bytes);
+ }
+ }
+ log_dispatch_rdma_post(thread_idx, my_rank, dst_rank, dispatch_wrs,
+ dispatch_bytes, kgroup, total_bytes);
+#endif
ibv_send_wr* bad = nullptr;
int ret = ibv_post_send(qp, &wrs[0], &bad);
if (ret) {
@@ -1148,6 +1261,20 @@ static void post_rdma_async_batched_fast_mode(
std::abort();
}
size_t const k = wr_ids.size();
+#if UCCL_RDMA_POST_DEBUG
+ uint64_t dst_total_bytes = 0;
+ for (size_t i : wr_ids) dst_total_bytes += cmds_to_post[i].bytes;
+ uint64_t dst_dispatch_bytes = 0;
+ size_t dst_dispatch_wrs = 0;
+ for (size_t i : wr_ids) {
+ if (!get_is_combine(cmds_to_post[i].cmd_type)) {
+ dst_dispatch_bytes += cmds_to_post[i].bytes;
+ ++dst_dispatch_wrs;
+ log_dispatch_rdma_message_size(thread_idx, my_rank, dst_rank,
+ cmds_to_post[i].bytes);
+ }
+ }
+#endif
#ifdef EFA
struct ibv_qp_ex* qpx = (struct ibv_qp_ex*)ctx->qp;
ibv_wr_start(qpx);
@@ -1207,9 +1334,16 @@ static void post_rdma_async_batched_fast_mode(
get_low_latency(cmd.cmd_type)};
#endif
#ifdef USE_RECEIVER_BARRIER
+ // Lam: Low-latency: num_tokens from cmd.atomic_val (GPU sets it); else 1.
+ // Use atomic_val whenever GPU set it (non-zero), not only when
+ // get_low_latency(cmd_type); dispatch can also set atomic_val.
+ uint32_t num_tokens_imm =
+ cmd.atomic_val ? static_cast(cmd.atomic_val) : 1u;
+ // get_is_combine(cmd.cmd_type) ? printf("Receiving this combine imm? num_tokens_imm: %d, cmd.atomic_val: %d\n", num_tokens_imm, cmd.atomic_val) : printf("Receiving this dispatch imm? num_tokens_imm: %d, cmd.atomic_val: %d\n", num_tokens_imm, cmd.atomic_val);
+ // fflush(stdout);
uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type),
get_low_latency(cmd.cmd_type),
- cmd.expert_idx, 1, my_rank)
+ cmd.expert_idx, num_tokens_imm, my_rank)
.GetImmData();
ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
#else
@@ -1259,6 +1393,10 @@ static void post_rdma_async_batched_fast_mode(
}
#endif
+#if UCCL_RDMA_POST_DEBUG
+ log_dispatch_rdma_post(thread_idx, my_rank, dst_rank, dst_dispatch_wrs,
+ dst_dispatch_bytes, k, dst_total_bytes);
+#endif
int ret = ibv_wr_complete(qpx);
if (ret) {
fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n",
@@ -1310,6 +1448,10 @@ static void post_rdma_async_batched_fast_mode(
wrs[last].send_flags |= IBV_SEND_SIGNALED;
wrs[last].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wrs[last].imm_data = htonl(static_cast(batch_tail_wr));
+#if UCCL_RDMA_POST_DEBUG
+ log_dispatch_rdma_post(thread_idx, my_rank, dst_rank, dst_dispatch_wrs,
+ dst_dispatch_bytes, k, dst_total_bytes);
+#endif
ibv_send_wr* bad = nullptr;
int ret = ibv_post_send(ctx->qp, &wrs[0], &bad);
if (ret) {
| |