[Distributed] add nccl primitives (#280)
What we have now:

* automatically find headers and libraries provided by the PyPI
  package `nvidia-nccl-cu11` or `nvidia-nccl-cu12`
* expose NCCL's initialization APIs to the Python runtime (see the
  sketch after this list)
* a Python launch script to start single-machine, multi-GPU tasks
  (`./examples/distributed/test.py`)
* modified CudaContext to include a pool of NCCL communicators
* a primitive for all_reduce
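For orientation, here is a minimal single-process sketch of how these pieces fit together, condensing `examples/distributed/test.py` below (a world size of 1 and communicator index 0 are illustrative; instantiating `NcclUniqueId` directly is assumed to work, since it is the same ctypes structure the example shares via `multiprocessing.Value`):

from hidet.cuda import nccl
from hidet.ffi import runtime_api

# Rank 0 creates the NCCL unique id; all ranks must share the same id.
shared_id = nccl.NcclUniqueId()
nccl.init_unique_id(shared_id)

# Each rank builds a communicator from (world_size, unique_id, rank) and
# registers it with the CUDA context, where generated kernels look it up.
comm = nccl.create_comm(1, shared_id, 0)
runtime_api.set_nccl_comms(nccl.comms_to_array([comm]))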

---------

Co-authored-by: Yaoyao Ding <[email protected]>
soodoshll and yaoyaoding committed Jun 19, 2023
1 parent a795526 commit bb1612e
Showing 31 changed files with 560 additions and 56 deletions.
95 changes: 95 additions & 0 deletions examples/distributed/test.py
@@ -0,0 +1,95 @@
"""
Testing script for hidet's distributed components.
To debug, set the environment variable NCCL_DEBUG=INFO.
"""
import hidet
import multiprocessing
from multiprocessing import Process
import numpy
import argparse

import hidet.cuda.nccl
from hidet.cuda import nccl
from hidet.cuda.nccl import NcclUniqueId, NcclDataType, NcclRedOp, nccl_library_filename
from hidet.ffi import runtime_api
from hidet.lang import attrs
from hidet.ir.primitives.cuda.nccl import all_reduce
from hidet.ir.type import data_type
from hidet.utils import prod
from hidet.drivers import build_ir_module
from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs
from hidet.runtime import load_compiled_module

print("NCCL version:", nccl.nccl_version())

parser = argparse.ArgumentParser()
parser.add_argument("n_gpus", type=int)
parser.add_argument("reduce_op", choices=['sum', 'prod', 'max', 'min', 'avg'])
args = parser.parse_args()

def run(world_size, rank, shared_id, barrier):
    numpy.random.seed(rank)

    # Initialize unique id
    if rank == 0:
        nccl.init_unique_id(shared_id)

    barrier.wait()
    hidet.cuda.set_device(rank)

    print('initialize', rank)
    # Create NcclCommunicator and set the cuda context
    # this part should be moved into CompiledGraph in the future
    comm = nccl.create_comm(world_size, shared_id, rank)
    comms_array = nccl.comms_to_array([comm])
    runtime_api.set_nccl_comms(comms_array)

    # Initialize send and receive buffers
    device = f"cuda:{rank}"
    send = hidet.randn([2, 2], device=device)
    recv = hidet.empty([2, 2], device=device)

    print(rank, send)

    dtype = data_type('float32')
    shape = [2, 2]
    nbytes = dtype.nbytes * prod(shape)

    # Define the IRModule
    with hidet.script_module() as script_module:
        @hidet.script
        def launch(send: dtype[shape], recv: dtype[shape]):
            attrs.func_kind = 'public'
            all_reduce(0, send, recv, nbytes, dtype, getattr(NcclRedOp, args.reduce_op))

    # Build
    ir_module = script_module.ir_module()
    ir_module.target = 'cuda'
    ir_module.include_dirs.extend(get_nccl_include_dirs())
    ir_module.linking_dirs.extend(get_nccl_library_search_dirs())
    ir_module.include_headers.append("nccl.h")
    ir_module.linking_libs.append(":" + nccl_library_filename())
    out_dir = f'./.cache/all_reduce_{rank}'

    build_ir_module(ir_module, out_dir, target='cuda')
    compiled_module = load_compiled_module(out_dir)

    compiled_module(send, recv)
    s = hidet.cuda.current_stream()
    s.synchronize()
    print(rank, recv)

world_size = args.n_gpus

# Barrier to ensure unique id is created
barrier = multiprocessing.Barrier(world_size)

# Create a unique id object in shared memory
shared_id = multiprocessing.Value(NcclUniqueId, lock=False)

processes = [Process(target=run, args=(world_size, i, shared_id, barrier)) for i in range(world_size)]

for p in processes:
p.start()
for p in processes:
p.join()
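With at least two visible GPUs, the script can be launched as, for example:

    python examples/distributed/test.py 2 sum

which spawns two worker processes; each prints its random send buffer followed by the sum-reduced result.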
1 change: 0 additions & 1 deletion gallery/developer-guides/hidet-script-dynamic-kernel.py
@@ -83,7 +83,6 @@ def matmul_kernel(
# iterate over the k tiles
num_k_tiles = (k_size + block_k_size - 1) // block_k_size
for k_tile in range(num_k_tiles):

# load smem_a [block_m_size, block_k_size] from global memory
for i, k in auto_map(block_m_size, block_k_size, workers=num_threads).on(
threadIdx.x
14 changes: 14 additions & 0 deletions include/hidet/runtime/cuda/context.h
@@ -19,6 +19,11 @@ struct CudaContext: BaseContext {
/* The cuda stream the kernels will be launched on. */
void* stream = nullptr;

/* NCCL communicators */
void** nccl_comms = nullptr;

int num_comms = 0;

/**
* Get the instance of cuda context.
*/
@@ -40,3 +45,12 @@ DLL void* get_cuda_stream();
*/
DLL void* request_cuda_workspace(size_t nbytes, bool require_clean);

/**
* Set required NCCL communicators of the context.
*/
DLL void set_nccl_comms(int num_comms, void** comm);

/**
* Get the NCCL communicator by index.
*/
DLL void* get_nccl_comm(int idx);
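The context thus keeps communicators as an opaque, index-addressed pool. A hedged sketch of the assumed correspondence on the Python side (`comm_a` and `comm_b` are hypothetical; the assumption, consistent with the `0` passed to `all_reduce` in the example script, is that the primitive's first argument selects from this pool via `get_nccl_comm`):

# Hypothetical: register a pool of two communicators.
comms_array = nccl.comms_to_array([comm_a, comm_b])
runtime_api.set_nccl_comms(comms_array)  # num_comms == 2 on the C side
# A kernel calling all_reduce(1, ...) would then use comm_b, i.e. the
# communicator returned by get_nccl_comm(1).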
71 changes: 58 additions & 13 deletions python/hidet/backend/build.py
@@ -39,7 +39,15 @@ class SourceCompiler:
The base class of source compiler.
"""

def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
def compile(
self,
src_path: str,
out_lib_path: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = (),
) -> None:
raise NotImplementedError()

def run_compile_command(self, command: str, src_path, out_lib_path: str):
@@ -104,8 +112,16 @@ def _resolve_nvcc_path():
return path
raise FileNotFoundError('Can not find nvcc compiler.')

def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
def compile(
self,
src_path: str,
out_lib_path: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = (),
) -> None:
if len(object_files) > 0 and out_lib_path.endswith('.o'):
raise ValueError('Can not compile multiple objects into a single object file.')

cc = hidet.cuda.compute_capability()
@@ -118,9 +134,10 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
# the path to nvcc compiler
self.nvcc_path,
# the included directories.
*['-I{}'.format(include_dir) for include_dir in self.include_dirs],
*['-I{}'.format(include_dir) for include_dir in self.include_dirs + list(include_dirs)],
# the library directories.
*['-L{}'.format(library_dir) for library_dir in self.library_dirs],
*['-L{}'.format(library_dir) for library_dir in self.library_dirs + list(linking_dirs)],
*['-l{}'.format(library) for library in linking_libraries],
# optimize host side code via -O3
'-O3',
# host compiler options: enable openmp, avx2, unroll loops and fast math
@@ -153,7 +170,7 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
# generate shared library (lib.so).
'--shared' if out_lib_path.endswith('.so') else '--compile',
# the linking objects.
' '.join(linking_objects),
' '.join(object_files),
# the source path.
src_path,
# the output library path.
@@ -179,16 +196,25 @@ def _resolve_gcc_path():
return path
raise FileNotFoundError('Can not find g++ compiler.')

def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
def compile(
self,
src_path: str,
out_lib_path: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = (),
) -> None:
if len(object_files) > 0 and out_lib_path.endswith('.o'):
raise ValueError('Can not compile multiple objects into a single object file.')
command = [
# the path to the g++ compiler
self.gcc_path,
# the included directories.
*['-I{}'.format(include_dir) for include_dir in self.include_dirs],
*['-I{}'.format(include_dir) for include_dir in self.include_dirs + list(include_dirs)],
# the library directories.
*['-L{}'.format(library_dir) for library_dir in self.library_dirs],
*['-L{}'.format(library_dir) for library_dir in self.library_dirs + list(linking_dirs)],
*['-l{}'.format(library) for library in linking_libraries],
# apply -O3 optimization.
'-O3',
# support avx intrinsics
@@ -204,7 +230,7 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
# generate shared library (lib.so).
'-shared' if out_lib_path.endswith('.so') else '--compile',
# the linking objects.
' '.join(linking_objects),
' '.join(object_files),
# the source path.
src_path,
# the output library path.
@@ -216,7 +242,13 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:


def compile_source(
source_file: str, output_library_file: str, target: str, object_files: Optional[Sequence[str]]
source_file: str,
output_library_file: str,
target: str,
include_dirs: Sequence[str] = (),
linking_dirs: Sequence[str] = (),
linking_libraries: Sequence[str] = (),
object_files: Sequence[str] = (),
) -> None:
"""
Compile the source code in 'src_path' file and output the library to 'out_lib_path'.
@@ -229,6 +261,12 @@ def compile_source(
The path to output library.
target: str
The target platform. Currently only support 'cpu' and 'gpu'.
include_dirs: Sequence[str]
The extra include directories to search for header files.
linking_dirs: Sequence[str]
The extra directories to search for libraries at link time.
linking_libraries: Sequence[str]
The libraries to link to the output library.
object_files: Optional[Sequence[str]]
The path to object files. If not None, the object files will be linked to the output library.
"""
@@ -247,4 +285,11 @@
raise ValueError('Unknown target platform: {}'.format(target))

object_files = object_files or []
compiler.compile(source_file, output_library_file, object_files)
compiler.compile(
source_file,
output_library_file,
include_dirs=include_dirs,
linking_dirs=linking_dirs,
linking_libraries=linking_libraries,
object_files=object_files,
)
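The new parameters let a caller point the compiler at the NCCL headers and shared library discovered from the pip package. A hedged sketch using the helpers added in this PR (the input and output paths are illustrative, and 'cuda' is assumed to be the accepted target name, matching the `build_ir_module(..., target='cuda')` call in the example script):

from hidet.backend.build import compile_source
from hidet.cuda.nccl import nccl_library_filename
from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs

compile_source(
    'kernel.cu',  # illustrative source path
    'lib.so',     # illustrative output path
    target='cuda',
    include_dirs=get_nccl_include_dirs(),
    linking_dirs=get_nccl_library_search_dirs(),
    # the ':' prefix makes the linker match the literal file name, e.g. -l:libnccl.so.2
    linking_libraries=[':' + nccl_library_filename()],
)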
5 changes: 5 additions & 0 deletions python/hidet/backend/codegen.py
@@ -682,6 +682,9 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cuda/context.h>') + NewLine()
doc += Text("#include <hidet/runtime/logging.h>") + NewLine()

for header in self.ir_module.include_headers:
doc += Text('#include <{}>').format(header) + NewLine()

if self.require_tf32:
# nvcc use float to 'store' tfloat32 data
doc += Text('typedef float tfloat32_t;') + NewLine()
@@ -768,6 +771,8 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cpu/float16.h>') + NewLine()
if self.require_bf16:
doc += Text('#include <hidet/runtime/cpu/bfloat16.h>') + NewLine()
for header in self.ir_module.include_headers:
doc += Text('#include <{}>').format(header) + NewLine()
if self.require_tf32:
doc += Text('typedef float tfloat32_t;') + NewLine()
doc += NewLine()
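With this change, any header recorded in an IRModule's include_headers list is emitted into the generated source. For example (mirroring the example script):

ir_module.include_headers.append("nccl.h")
# the loop above then emits, near the top of the generated file:
#   #include <nccl.h>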
13 changes: 13 additions & 0 deletions python/hidet/cuda/nccl/__init__.py
@@ -0,0 +1,13 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl
from .ffi import nccl_version, nccl_library_filename