16 changes: 11 additions & 5 deletions flashinfer/comm/cuda_ipc.py
@@ -17,9 +17,10 @@
import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import paddle
paddle.compat.enable_torch_proxy()
import torch.distributed as dist
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

# NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings.
# However, cuda-python's API is not stable yet, so we use ctypes bindings instead.
@@ -207,9 +208,14 @@ def create_shared_buffer(
group = dist.group.WORLD
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
# handles = [None] * world_size
# dist.all_gather_object(handles, handle, group=group)

# Paddle's all_gather_object appends the gathered objects to the output list,
# whereas torch fills a pre-sized list, so start from an empty list here.
handles = []
dist.all_gather_object(handles, handle, group=group)

pointers: List[int] = []
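For context, a minimal sketch of the behavioral difference this hunk works around, assuming paddle's all_gather_object appends to the output list while torch's fills a pre-sized one; the helper name gather_handles is illustrative and not part of the patch:

def gather_handles(dist, handle, group, world_size, paddle_backend=True):
    # torch.distributed.all_gather_object expects a pre-sized output list;
    # paddle's all_gather_object (reached through the torch proxy) appends
    # gathered objects instead, so it must start from an empty list.
    if paddle_backend:
        handles = []
    else:
        handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=group)
    return handles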
2 changes: 1 addition & 1 deletion flashinfer/comm/nvshmem_allreduce.py
@@ -17,7 +17,7 @@
from typing import Optional

import torch
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

from .nvshmem import get_nvshmem_module

12 changes: 10 additions & 2 deletions flashinfer/comm/trtllm_ar.py
@@ -20,9 +20,11 @@
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union

import paddle
paddle.compat.enable_torch_proxy()
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

from ..jit.comm import gen_trtllm_comm_module
from ..utils import register_custom_op, round_up
@@ -602,8 +604,14 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
print(f"Rank {tp_rank} workspace[{i}] {hex(workspace[i])}")

# Store workspace pointers in device tensor
# workspace_tensor = torch.tensor(
# workspace, dtype=torch.int64, device=torch.device("cuda")
# )

# Paddle currently has a bug when device="cuda" is passed to the tensor
# constructor through the torch proxy, so the device argument is omitted
# here as a workaround.
workspace_tensor = torch.tensor(
workspace, dtype=torch.int64, device=torch.device("cuda")
workspace, dtype=torch.int64
)

dist.barrier(group=group) # must sync after create_workspace
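As a side note, a minimal sketch of one alternative that still places the pointer tensor on the GPU without passing device="cuda" through the proxy, assuming paddle's public to_tensor / Tensor.cuda API; workspace_ptrs and gpu_id are illustrative names, not part of the patch:

import paddle

def workspace_ptrs_to_gpu(workspace_ptrs, gpu_id):
    # Build the int64 pointer tensor on the CPU first, then move it
    # explicitly, avoiding the device= keyword that currently trips
    # the paddle torch proxy.
    t = paddle.to_tensor(workspace_ptrs, dtype='int64', place=paddle.CPUPlace())
    return t.cuda(gpu_id)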
160 changes: 160 additions & 0 deletions tests/comm/test_trtllm_allreduce_fusion_paddle.py
@@ -0,0 +1,160 @@
import socket
import pytest

import flashinfer.comm as comm

import paddle
import paddle.distributed as dist_pp
paddle.compat.enable_torch_proxy()

import os
import numpy as np

# test parameters
token_num = 128
hidden_dim = 1024
dtype = paddle.float16
pattern_code = comm.AllReduceFusionPattern.kAllReduce
layout_code = comm.QuantizationSFLayout.LINEAR
launch_with_pdl = False
use_oneshot = True
trigger_completion_at_end = True
fp32_acc = False

def kernel(workspace_tensor, rank, world_size):
device = f"cuda:{rank}"
message_size = token_num * hidden_dim
dtype = paddle.float16
# Create input data
allreduce_in = paddle.randn(message_size, dtype=dtype, device=device)
# allreduce_in_clone = allreduce_in.clone()
all_reduce_out = paddle.zeros(message_size, dtype=dtype, device=device)

# Add missing required parameters
residual_in = paddle.randn(message_size, dtype=dtype, device=device)
residual_out = paddle.zeros(message_size, dtype=dtype, device=device)
norm_out = paddle.zeros(message_size, dtype=dtype, device=device)
quant_out = paddle.zeros(message_size, dtype=dtype, device=device)
scale_out = paddle.zeros(message_size // 16, dtype=dtype, device=device) # SF_VEC_SIZE = 16
rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device)
rms_eps = 1e-3
scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device)

# Run fusion operation
print("Running fusion operation...")
comm.trtllm_allreduce_fusion(
allreduce_in=allreduce_in,
world_size=world_size,
world_rank=rank,
token_num=token_num,
hidden_dim=hidden_dim,
workspace_ptrs=workspace_tensor,
launch_with_pdl=launch_with_pdl,
use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=pattern_code,
allreduce_out=all_reduce_out,
residual_in=residual_in,
residual_out=residual_out,
norm_out=norm_out,
quant_out=quant_out,
scale_out=scale_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_factor,
layout_code=layout_code,
)

paddle.cuda.synchronize()

return allreduce_in, all_reduce_out

def _run_simple_worker(world_size, rank, distributed_init_port):

# Create workspace
# paddle.compat.enable_torch_proxy()
# Set all required environment variables
os.environ['FLAGS_selected_gpus'] = str(rank) # Key: set GPU ID
os.environ['PADDLE_TRAINER_ID'] = str(rank)
os.environ['PADDLE_TRAINERS_NUM'] = str(world_size)
os.environ['PADDLE_RANK_IN_NODE'] = str(rank)

# Build endpoint list
endpoints = ','.join([f'127.0.0.1:{distributed_init_port+i+10}' for i in range(world_size)])
os.environ['PADDLE_TRAINER_ENDPOINTS'] = endpoints
os.environ['PADDLE_CURRENT_ENDPOINT'] = f'127.0.0.1:{distributed_init_port+rank+10}'
# Set NCCL related environment variables (optional but recommended)
os.environ['FLAGS_sync_nccl_allreduce'] = '1'

# Set device
paddle.set_device(f"gpu:{rank}")

# Initialize distributed environment
dist_pp.init_parallel_env()
group_pp = dist_pp.get_group()

try:
# Create workspace
ipc_handles, workspace_tensor = (
comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
rank,
world_size,
token_num,
hidden_dim,
group=group_pp,
use_fp32_lamport=False,
)
)

dist_pp.barrier(group=group_pp)

# Run fusion operation
allreduce_in_clone, all_reduce_out = kernel(workspace_tensor, rank, world_size)

# Calculate reference result
dist_pp.all_reduce(allreduce_in_clone, group=group_pp)
ref_allreduce_out = allreduce_in_clone.clone()

# Verify results
tolerance = 8e-2
np.testing.assert_allclose(all_reduce_out.numpy(),
ref_allreduce_out.numpy(), atol=tolerance, rtol=1e-2)

print(f"Rank {rank}: Test passed!")

finally:
dist_pp.barrier(group=group_pp)
comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group_pp)
dist_pp.destroy_process_group(group=group_pp)


def get_open_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


def test_trtllm_allreduce_fusion_simple():
# Fixed test parameters
world_size = 2

paddle.manual_seed(42)
paddle.cuda.manual_seed_all(42)

available_gpus = paddle.cuda.device_count()
if world_size > available_gpus:
pytest.skip(f"Requires {world_size} GPUs, but only {available_gpus} available")

procs = []
distributed_init_port = get_open_port()
rank = dist_pp.get_rank()
_run_simple_worker(world_size, rank, distributed_init_port)

print("Simple allreduce fusion test: passed")


# test cmd: python -m paddle.distributed.launch --log_dir=log --devices=0,1
# ./test_trtllm_allreduce_fusion_paddle.py
if __name__ == "__main__":
test_trtllm_allreduce_fusion_simple()