Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,5 @@ ep/figs

ep/deep_ep_wrapper/deep_ep.egg-info/
*.json
!.devcontainer/devcontainer.json
*result.jsonl
3 changes: 2 additions & 1 deletion ep/.gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
benchmark
*.hip
*_hip.*
*.egg-info
*.egg-info
.lam_dev_state
27 changes: 24 additions & 3 deletions ep/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,29 @@ LDFLAGS := -lpthread -lglog -libverbs -lnl-3 -lnl-route-3 -lnuma -Xlinker -rpath
NVCCFLAGS := -O3 -std=c++17 -Xcompiler "-Wall -pthread -fPIC -fvisibility=hidden" -ccbin /usr/bin/g++ --expt-relaxed-constexpr
INCLUDES := -Iinclude -I$(CUDA_PATH)/include -I/usr/include -I../include

CXXFLAGS += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS)
NVCCFLAGS += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS)
# LAM_DEV: Enable Lam's development/debug code (set LAM_DEV=1 to enable)
# Auto-detects changes and only rebuilds affected files
# LAM_DEV is a user-tunable knob; ?= keeps a command-line `make LAM_DEV=1` override working.
LAM_DEV ?= 0
# State file persisting the LAM_DEV value used by the previous build.
LAM_DEV_STATE_FILE := .lam_dev_state
# Read the previous value at parse time; defaults to 0 when the state file
# does not exist yet (first build, or after `make clean` removes it).
# NOTE: must be read BEFORE the ifneq below overwrites the state file.
LAM_DEV_CURRENT := $(shell cat $(LAM_DEV_STATE_FILE) 2>/dev/null || echo 0)
# Files affected by LAM_DEV: internode_ll.cu, ep_config.hpp (used by rdma.cpp, uccl_ep.cc)
LAM_DEV_AFFECTED_OBJS := src/internode_ll.o src/rdma.o src/uccl_ep.o

# If LAM_DEV flipped since the last build, force a rebuild of only the affected
# objects by deleting them, then record the new value. Both $(shell …) calls are
# side effects executed at parse time, before any rule runs.
ifneq ($(LAM_DEV),$(LAM_DEV_CURRENT))
$(info LAM_DEV changed from $(LAM_DEV_CURRENT) to $(LAM_DEV), will rebuild affected files)
$(shell rm -f $(LAM_DEV_AFFECTED_OBJS))
$(shell echo $(LAM_DEV) > $(LAM_DEV_STATE_FILE))
endif

# Translate the knob into a preprocessor define consumed by CXXFLAGS/NVCCFLAGS
# further down; empty when disabled so the flag lists stay clean.
ifeq ($(LAM_DEV),1)
LAM_DEV_FLAGS := -DLAM_DEV
$(info LAM_DEV enabled)
else
LAM_DEV_FLAGS :=
endif

CXXFLAGS += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS) $(LAM_DEV_FLAGS)
NVCCFLAGS += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS) $(LAM_DEV_FLAGS)
LDFLAGS += $(EFA_LDFLAGS)
INCLUDES += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS)

Expand Down Expand Up @@ -106,7 +127,7 @@ install: $(EP_EXT)
# Clean all generated files
clean:
rm -f $(OBJ_CPP) $(OBJ_CC) $(OBJ_CU) \
$(EP_EXT) \
$(EP_EXT) $(LAM_DEV_STATE_FILE) \
*.d src/*.d bench/*.d src/*.o **/*.hip **/*_hip.*

# Automatically include dependency files if they exist
Expand Down
34 changes: 34 additions & 0 deletions ep/bench/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,24 @@ def low_latency_dispatch(
hidden=x.shape[1],
num_experts=num_experts,
)

# DEBUG: Print data before CUDA kernel launch
if os.environ.get("DEBUG_DISPATCH", "0") == "1":
print(f"\n{'='*60}", flush=True)
print(f"[DEBUG] low_latency_dispatch - BEFORE CUDA kernel", flush=True)
print(f"{'='*60}", flush=True)
print(f" rank: {self.rank}", flush=True)
print(f" x.shape: {x.shape}, dtype: {x.dtype}", flush=True)
print(f" x.data_ptr: {hex(x.data_ptr())}", flush=True)
print(f" x[:3, :8]:\n{x[:3, :8]}", flush=True)
print(f" topk_idx.shape: {topk_idx.shape}, dtype: {topk_idx.dtype}", flush=True)
print(f" topk_idx[:5]:\n{topk_idx[:5]}", flush=True)
print(f" num_max_dispatch_tokens_per_rank: {num_max_dispatch_tokens_per_rank}", flush=True)
print(f" num_experts: {num_experts}", flush=True)
print(f" use_fp8: {use_fp8}, round_scale: {round_scale}, use_ue8m0: {use_ue8m0}", flush=True)
print(f" async_finish: {async_finish}, return_recv_hook: {return_recv_hook}", flush=True)
print(f"{'='*60}\n", flush=True)

(
packed_recv_x,
packed_recv_x_scales,
Expand Down Expand Up @@ -309,6 +327,22 @@ def low_latency_dispatch(
x.size(1),
num_experts,
)

# DEBUG: Print data after CUDA kernel returns
if os.environ.get("DEBUG_DISPATCH", "0") == "1":
import torch
torch.cuda.synchronize() # Wait for kernel to complete
print(f"\n{'='*60}", flush=True)
print(f"[DEBUG] low_latency_dispatch - AFTER CUDA kernel", flush=True)
print(f"{'='*60}", flush=True)
print(f" packed_recv_x.shape: {packed_recv_x.shape}, dtype: {packed_recv_x.dtype}", flush=True)
print(f" packed_recv_count: {packed_recv_count}", flush=True)
print(f" packed_recv_src_info.shape: {packed_recv_src_info.shape}", flush=True)
print(f" packed_recv_layout_range.shape: {packed_recv_layout_range.shape}", flush=True)
if packed_recv_x_scales is not None:
print(f" packed_recv_x_scales.shape: {packed_recv_x_scales.shape}", flush=True)
print(f"{'='*60}\n", flush=True)

tensors_to_record = (
x,
topk_idx,
Expand Down
Loading