
Commit aa916f4

[Feat] Support internode copy with intranode copy (#26)
* [Feat] support internode copy with intranode copy
* [Feat] format code
* format code
1 parent 190723c commit aa916f4

File tree

5 files changed: +164 -7 lines changed

examples/distributed/example_overlapping_allgather.py (new file)

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
import torch
import torch.distributed as dist
import pynvshmem
import tilelang
import tilelang.language as T
import os
from tilelang.distributed.utils import init_distributed
from tilelang.env import env
from packaging import version
import importlib.metadata

cuda_python_version = importlib.metadata.version("cuda-python")
if version.parse(cuda_python_version) >= version.parse("12.8.0"):
    from cuda.bindings import runtime as cudart
else:
    from cuda import cudart

# NODES=2 NODE_RANK=0 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
# NODES=2 NODE_RANK=1 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py


def internode_gather(M, local_world_size, block_M, threads):

    @T.prim_func
    def main(
            dst: T.Tensor((M), "float32"),
            src: T.Tensor((M), "float32"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
            rank = T.alloc_local([1], "uint64")
            rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)  # 2 nodes
            T.putmem_nbi_block(
                T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]), block_M * 4,
                rank[0])

    return main


def intranode_gather(M, world_size, block_M, threads):

    @T.prim_func
    def main(
            dst: T.Tensor((M * world_size), "float32"),
            src: T.Tensor((M * 2), "float32"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
            rank = T.alloc_local([1], "uint64")
            num_rank = T.alloc_local([1], "uint64")
            rank[0] = T.get_rank()
            num_rank[0] = T.get_num_ranks()
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(T.cast(rank[0], "int32"), msg="signal")
                T.print(T.cast(num_rank[0], "int32"), msg="signal")
            for k in T.serial(world_size // 2):  # 2 nodes
                T.put_block(
                    src=T.address_of(src[bx * block_M]),
                    dst=T.address_of(dst[bx * block_M + rank[0] * M]),
                    size=block_M,
                    dst_pe=k,
                )
                T.put_block(
                    src=T.address_of(src[bx * block_M + M]),
                    dst=T.address_of(dst[bx * block_M + M * num_rank[0] + rank[0] * M]),
                    size=block_M,
                    dst_pe=k,
                )

    return main


if __name__ == '__main__':
    tilelang.disable_cache()

    M = 2
    K = 12288
    # For 2 nodes (16 GPUs): world_size=16, ranks are 0-15, local ranks are 0-7
    WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
        return_tp_group=True, return_lc_group=True)
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))

    allocator = tilelang.get_allocator(
        size=2**25,
        device="cuda",
        is_distributed=True,
        local_rank=LOCAL_RANK,
        num_local_ranks=local_world_size,
        group=LC_GROUP)
    print(local_world_size, LOCAL_RANK)

    # Each rank sends its local tensor to the rank on the other node with the same local rank.
    # Assuming there are 2 nodes, each with 4 workers:
    # 0-th local tensor ([0] -> [4]), 4-th local tensor ([4] -> [0])
    # 1-th local tensor ([1] -> [5]), 5-th local tensor ([5] -> [1])
    # 2-th local tensor ([2] -> [6]), 6-th local tensor ([6] -> [2])
    # 3-th local tensor ([3] -> [7]), 7-th local tensor ([7] -> [3])
    interkernel = tilelang.compile(internode_gather(M, local_world_size, M, 128))
    if LOCAL_RANK == 0:
        print(interkernel.get_kernel_source())
    src = pynvshmem.nvshmem_create_tensor([M], torch.float32)
    dst = pynvshmem.nvshmem_create_tensor([M], torch.float32)
    input_data = torch.ones([M], dtype=torch.float32, device='cuda') * RANK
    src.copy_(input_data)

    pynvshmem.nvshmem_barrier_all()
    dist.barrier(TP_GROUP)
    interkernel(dst, src)
    pynvshmem.nvshmem_barrier_all()

    # Each rank sends its local tensor and the received internode tensor to the intranode ranks.
    # 0-th and 4-th local tensors ([0] -> [1,2,3])
    # 1-th and 5-th local tensors ([1] -> [0,2,3])
    # 2-th and 6-th local tensors ([2] -> [0,1,3])
    # 3-th and 7-th local tensors ([3] -> [0,1,2])
    # 0-th and 4-th local tensors ([4] -> [5,6,7])
    # 1-th and 5-th local tensors ([5] -> [4,6,7])
    # 2-th and 6-th local tensors ([6] -> [4,5,7])
    # 3-th and 7-th local tensors ([7] -> [4,5,6])
    src_intra = tilelang.tensor((M * 2), torch.float32, allocator=allocator).normal_()
    dst_intra = tilelang.tensor((M * WORLD_SIZE), torch.float32, allocator=allocator)
    if RANK < WORLD_SIZE / 2:
        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
    else:
        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)

    env.USE_NVSHMEM = False
    intrakernel = tilelang.compile(
        intranode_gather(M, WORLD_SIZE, M, 128),
        pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True})
    intrakernel.initialize(allocator=allocator)
    if LOCAL_RANK == 0:
        print(intrakernel.get_kernel_source())
    torch.cuda.synchronize()
    torch.distributed.barrier(LC_GROUP)
    intrakernel(dst_intra, src_intra)
    torch.cuda.synchronize()
    torch.distributed.barrier(LC_GROUP)

    print(dst_intra)
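
For reference, a minimal sanity check could be appended after the final print. This is only a sketch, not part of the commit: it assumes the 2-node setup described above and that T.get_rank()/T.get_num_ranks() resolve to the local rank and local world size inside the LC_GROUP-backed allocator, in which case every rank should end up holding the full allgather result (the value r repeated M times in the r-th slot).

# Hypothetical verification (not in the commit): each rank contributed rank-valued
# tensors, so the gathered buffer should read 0,0,...,1,1,...,WORLD_SIZE-1,...
expected = torch.arange(
    WORLD_SIZE, dtype=torch.float32, device='cuda').repeat_interleave(M)
torch.testing.assert_close(dst_intra, expected)
print("overlapping allgather verified")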

src/target/codegen_cuda.cc

Lines changed: 2 additions & 2 deletions
@@ -270,8 +270,8 @@ std::string CodeGenTileLangCUDA::Finish() {
   }

   if (use_nvshmem_) {
-    decl_stream << "#include <nvshmem.h>>\n";
-    decl_stream << "#include <nvshmemx.h>>\n";
+    decl_stream << "#include <nvshmem.h>\n";
+    decl_stream << "#include <nvshmemx.h>\n";
   }

   if (need_cooperative_groups_) {

tilelang/distributed/launch.sh

Lines changed: 6 additions & 2 deletions
@@ -20,8 +20,12 @@ nproc_per_node=${GPUS:=$(nvidia-smi --list-gpus | wc -l)} # set env var. `GPUS`
 nnodes=${NODES:=1} # set env var. `NODES` to # of nodes
 node_rank=${NODE_RANK:=0} # set env var. `NODE_RANK` to the rank of current node

-master_addr="127.0.0.1"
-master_port="$(( RANDOM % 100 + 23400 ))" # randomly choose a port between 23400 and 23499
+master_addr=${ARNOLD_WORKER_0_HOST:="127.0.0.1"}
+if [ -z ${ARNOLD_WORKER_0_PORT} ]; then
+    master_port="8361"
+else
+    master_port=$(echo "$ARNOLD_WORKER_0_PORT" | cut -d "," -f 1)
+fi
 additional_args="--rdzv_endpoint=${master_addr}:${master_port}"
 IB_HCA=mlx5

tilelang/distributed/utils.py

Lines changed: 10 additions & 2 deletions
@@ -58,13 +58,14 @@ def init_dist(local_rank: int, num_local_ranks: int):
         list(range(num_local_ranks * num_nodes)))


-def init_distributed(return_tp_group=False, init_nvshmem=True):
+def init_distributed(return_tp_group=False, init_nvshmem=True, return_lc_group=False):
     WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
     RANK = int(os.environ.get("RANK", 0))
     LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))

     torch.distributed.init_process_group(
         backend="nccl",
+        device_id=torch.device(f'cuda:{LOCAL_RANK}'),
         world_size=WORLD_SIZE,
         rank=RANK,
         timeout=datetime.timedelta(seconds=1800),

@@ -78,7 +79,14 @@ def init_distributed(return_tp_group=False, init_nvshmem=True):
         import pynvshmem
         pynvshmem.init_nvshmem_by_uniqueid(TP_GROUP)

-    if return_tp_group:
+    if return_lc_group:
+        local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
+        base = (RANK // local_world_size) * local_world_size
+        LC_GROUP = torch.distributed.new_group(
+            list(range(base, base + local_world_size)), backend="nccl")
+
+        return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP
+    elif return_tp_group:
         return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP
     else:
         return WORLD_SIZE, RANK, LOCAL_RANK
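
With return_lc_group=True, init_distributed now also builds a per-node NCCL group over the ranks that share a node and returns it alongside the TP group. A minimal usage sketch, assuming the usual torchrun environment variables (WORLD_SIZE, RANK, LOCAL_RANK, LOCAL_WORLD_SIZE) set by launch.sh:

import torch
from tilelang.distributed.utils import init_distributed

# Get the global TP group plus the intra-node LC group (as the new example does).
WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
    return_tp_group=True, return_lc_group=True)
# LC_GROUP spans only the ranks on this node (e.g. 0-7 on node 0, 8-15 on node 1
# for a 2-node x 8-GPU job); use it for intra-node barriers and collectives.
torch.distributed.barrier(LC_GROUP)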

tilelang/utils/allocator.py

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ def _init_table(self):
         ] * self._group.size()
         local_ipc_handle = _create_ipc_handle(self._base_ptr.value)
         dist.all_gather_object(ipc_handles, local_ipc_handle, self._group)
-        buffer_ptrs = torch.empty(self._group.size(), dtype=torch.uint64)
+        buffer_ptrs = torch.empty(self._group.size(), dtype=torch.uint64, device='cuda')
         _sync_ipc_handles(self._local_rank, device_ids,
                           ctypes.c_void_p(buffer_ptrs.data_ptr()).value, ipc_handles, None)
         buffer_ptrs[self._local_rank] = self._base_ptr.value
