From bb1612e86436450731797f1082dc690200037fa5 Mon Sep 17 00:00:00 2001
From: Qidong Su
Date: Mon, 19 Jun 2023 00:39:35 -0400
Subject: [PATCH] [Distributed] add nccl primitives (#280)

What we have now:

* automatic discovery of the headers and libraries provided by the PyPI
  packages `nvidia-nccl-cu11` and `nvidia-nccl-cu12`
* NCCL's initialization APIs exposed to the Python runtime
* a Python launch script to start single-machine, multi-GPU tasks
  (`./examples/distributed/test.py`)
* a CudaContext extended to hold a pool of NCCL communicators
* a primitive for all_reduce

---------

Co-authored-by: Yaoyao Ding
---
 examples/distributed/test.py                  |  95 +++++++++++++++
 .../hidet-script-dynamic-kernel.py            |   1 -
 include/hidet/runtime/cuda/context.h          |  14 +++
 python/hidet/backend/build.py                 |  71 +++++++++---
 python/hidet/backend/codegen.py               |   5 +
 python/hidet/cuda/nccl/__init__.py            |  13 +++
 python/hidet/cuda/nccl/comm.py                | 100 ++++++++++++++++
 python/hidet/cuda/nccl/ffi.py                 | 108 ++++++++++++++++++
 python/hidet/cuda/nccl/libinfo.py             |  26 +++++
 python/hidet/drivers/build_graph.py           |   3 -
 python/hidet/drivers/build_module.py          |  21 ++--
 python/hidet/drivers/build_task.py            |  11 +-
 python/hidet/ffi/ffi.py                       |   4 +-
 python/hidet/ffi/runtime_api.py               |   7 ++
 .../hidet/graph/ops/matmul/matmul_f32_x86.py  |   2 -
 python/hidet/graph/ops/reduce/reduce_f16.py   |   1 -
 python/hidet/ir/functors/compute_functor.py   |   1 -
 python/hidet/ir/functors/module_functor.py    |   2 +-
 python/hidet/ir/module.py                     |  44 ++++++-
 python/hidet/ir/primitives/cuda/nccl.py       |  30 +++++
 python/hidet/ir/primitives/runtime.py         |   7 ++
 python/hidet/ir/tools/printer.py              |   9 ++
 python/hidet/testing/models/llama.py          |   6 +-
 python/hidet/transforms/base.py               |   7 +-
 .../hidet/transforms/generate_launch_func.py  |   1 -
 .../transforms/import_primitive_functions.py  |   2 +-
 python/hidet/transforms/inline_function.py    |   2 +-
 .../hidet/transforms/instantiate_symbols.py   |   2 +-
 .../hidet/utils/model_translator/examples.py  |   6 +
 src/hidet/runtime/cuda_context.cpp            |  13 +++
 tests/operators/test_operator.py              |   2 +
 31 files changed, 560 insertions(+), 56 deletions(-)
 create mode 100644 examples/distributed/test.py
 create mode 100644 python/hidet/cuda/nccl/__init__.py
 create mode 100644 python/hidet/cuda/nccl/comm.py
 create mode 100644 python/hidet/cuda/nccl/ffi.py
 create mode 100644 python/hidet/cuda/nccl/libinfo.py
 create mode 100644 python/hidet/ir/primitives/cuda/nccl.py
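The pieces above compose into a small per-process bootstrap. A condensed sketch of the intended flow (synchronization elided; the full version is the example script that follows):

```python
# Minimal per-rank bootstrap sketch; `shared_id` is an NcclUniqueId living in
# shared memory, and process synchronization (e.g. a Barrier) is elided.
from hidet.cuda import nccl
from hidet.ffi import runtime_api

def bootstrap(world_size: int, rank: int, shared_id):
    if rank == 0:
        nccl.init_unique_id(shared_id)       # rank 0 fills the id in place
    # ... wait here until rank 0 has initialized shared_id ...
    comm = nccl.create_comm(world_size, shared_id, rank)
    comms_array = nccl.comms_to_array([comm])
    runtime_api.set_nccl_comms(comms_array)  # index 0 of this pool backs all_reduce(0, ...)
    return comm, comms_array                 # keep both alive while kernels run
```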
diff --git a/examples/distributed/test.py b/examples/distributed/test.py
new file mode 100644
index 000000000..bad43ef23
--- /dev/null
+++ b/examples/distributed/test.py
@@ -0,0 +1,95 @@
+"""
+Testing script for hidet's distributed components.
+To debug, set the environment variable NCCL_DEBUG=INFO
+"""
+import hidet
+import multiprocessing
+from multiprocessing import Process
+import numpy
+import argparse
+
+import hidet.cuda.nccl
+from hidet.cuda import nccl
+from hidet.cuda.nccl import NcclUniqueId, NcclDataType, NcclRedOp, nccl_library_filename
+from hidet.ffi import runtime_api
+from hidet.lang import attrs
+from hidet.ir.primitives.cuda.nccl import all_reduce
+from hidet.ir.type import data_type
+from hidet.utils import prod
+from hidet.drivers import build_ir_module
+from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs
+from hidet.runtime import load_compiled_module
+
+print("NCCL version:", nccl.nccl_version())
+
+parser = argparse.ArgumentParser()
+parser.add_argument("n_gpus", type=int)
+parser.add_argument("reduce_op", choices=['sum', 'prod', 'max', 'min', 'avg'])
+args = parser.parse_args()
+
+def run(world_size, rank, shared_id, barrier):
+    numpy.random.seed(rank)
+
+    # Initialize unique id
+    if rank == 0:
+        nccl.init_unique_id(shared_id)
+
+    barrier.wait()
+    hidet.cuda.set_device(rank)
+
+    print('initialize', rank)
+    # Create an NcclCommunicator and register it with the CUDA context;
+    # this part should be moved into CompiledGraph in the future
+    comm = nccl.create_comm(world_size, shared_id, rank)
+    comms_array = nccl.comms_to_array([comm])
+    runtime_api.set_nccl_comms(comms_array)
+
+    # Initialize send and receive buffers
+    device = f"cuda:{rank}"
+    send = hidet.randn([2, 2], device=device)
+    recv = hidet.empty([2, 2], device=device)
+
+    print(rank, send)
+
+    dtype = data_type('float32')
+    shape = [2, 2]
+    count = prod(shape)
+
+    # Define IRModule
+    with hidet.script_module() as script_module:
+        @hidet.script
+        def launch(send: dtype[shape], recv: dtype[shape]):
+            attrs.func_kind = 'public'
+            all_reduce(0, send, recv, count, dtype, getattr(NcclRedOp, args.reduce_op))
+
+    # Build
+    ir_module = script_module.ir_module()
+    ir_module.target = 'cuda'
+    ir_module.include_dirs.extend(get_nccl_include_dirs())
+    ir_module.linking_dirs.extend(get_nccl_library_search_dirs())
+    ir_module.include_headers.append("nccl.h")
+    ir_module.linking_libs.append(":" + nccl_library_filename())
+    out_dir = f'./.cache/all_reduce_{rank}'
+
+    build_ir_module(ir_module, out_dir, target='cuda')
+    compiled_module = load_compiled_module(out_dir)
+
+    compiled_module(send, recv)
+    s = hidet.cuda.current_stream()
+    s.synchronize()
+    print(rank, recv)
+
+world_size = args.n_gpus
+
+# Barrier to ensure unique id is created
+barrier = multiprocessing.Barrier(world_size)
+
+# Create a unique id object in shared memory
+shared_id = multiprocessing.Value(NcclUniqueId, lock=False)
+
+processes = [Process(target=run, args=(world_size, i, shared_id, barrier)) for i in range(world_size)]
+
+for p in processes:
+    p.start()
+for p in processes:
+    p.join()
\ No newline at end of file
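The four `ir_module.*` assignments in the example form the general recipe for linking any third-party library into a generated module. Condensed into a helper, purely for illustration (`use_nccl` is hypothetical, not part of this patch):

```python
from hidet.cuda.nccl import nccl_library_filename
from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs

def use_nccl(ir_module):
    # hypothetical convenience wrapper over the flags set in test.py
    ir_module.include_dirs.extend(get_nccl_include_dirs())          # -I
    ir_module.linking_dirs.extend(get_nccl_library_search_dirs())   # -L
    ir_module.include_headers.append("nccl.h")                      # emits: #include <nccl.h>
    ir_module.linking_libs.append(":" + nccl_library_filename())    # -l:libnccl.so.2 (exact-name link)
```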
diff --git a/gallery/developer-guides/hidet-script-dynamic-kernel.py b/gallery/developer-guides/hidet-script-dynamic-kernel.py
index ddf71c649..97655469c 100644
--- a/gallery/developer-guides/hidet-script-dynamic-kernel.py
+++ b/gallery/developer-guides/hidet-script-dynamic-kernel.py
@@ -83,7 +83,6 @@ def matmul_kernel(
         # iterate over the k tiles
         num_k_tiles = (k_size + block_k_size - 1) // block_k_size
         for k_tile in range(num_k_tiles):
-
             # load smem_a [block_m_size, block_k_size] from global memory
             for i, k in auto_map(block_m_size, block_k_size, workers=num_threads).on(
                 threadIdx.x
diff --git a/include/hidet/runtime/cuda/context.h b/include/hidet/runtime/cuda/context.h
index dbcff8138..e1aa732c9 100644
--- a/include/hidet/runtime/cuda/context.h
+++ b/include/hidet/runtime/cuda/context.h
@@ -19,6 +19,11 @@ struct CudaContext: BaseContext {
     /* The cuda stream the kernels will be launched on. */
     void* stream = nullptr;
 
+    /* NCCL communicators. */
+    void** nccl_comms = nullptr;
+
+    int num_comms = 0;
+
     /**
      * Get the instance of cuda context.
      */
@@ -40,3 +45,12 @@ DLL void* get_cuda_stream();
  */
 DLL void* request_cuda_workspace(size_t nbytes, bool require_clean);
 
+/**
+ * Set the NCCL communicators of the context.
+ */
+DLL void set_nccl_comms(int num_comms, void** comms);
+
+/**
+ * Get the NCCL communicator by index.
+ */
+DLL void* get_nccl_comm(int idx);
\ No newline at end of file
diff --git a/python/hidet/backend/build.py b/python/hidet/backend/build.py
index 497ca7adf..02dd71f8d 100644
--- a/python/hidet/backend/build.py
+++ b/python/hidet/backend/build.py
@@ -39,7 +39,15 @@ class SourceCompiler:
     The base class of source compiler.
     """
 
-    def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
+    def compile(
+        self,
+        src_path: str,
+        out_lib_path: str,
+        include_dirs: Sequence[str] = (),
+        linking_dirs: Sequence[str] = (),
+        linking_libraries: Sequence[str] = (),
+        object_files: Sequence[str] = (),
+    ) -> None:
         raise NotImplementedError()
 
     def run_compile_command(self, command: str, src_path, out_lib_path: str):
@@ -104,8 +112,16 @@ def _resolve_nvcc_path():
                 return path
         raise FileNotFoundError('Can not find nvcc compiler.')
 
-    def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
-        if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
+    def compile(
+        self,
+        src_path: str,
+        out_lib_path: str,
+        include_dirs: Sequence[str] = (),
+        linking_dirs: Sequence[str] = (),
+        linking_libraries: Sequence[str] = (),
+        object_files: Sequence[str] = (),
+    ) -> None:
+        if len(object_files) > 0 and out_lib_path.endswith('.o'):
             raise ValueError('Can not compile multiple objects into a single object file.')
 
         cc = hidet.cuda.compute_capability()
@@ -118,9 +134,10 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
             # the path to nvcc compiler
             self.nvcc_path,
             # the included directories.
-            *['-I{}'.format(include_dir) for include_dir in self.include_dirs],
+            *['-I{}'.format(include_dir) for include_dir in self.include_dirs + list(include_dirs)],
             # the library directories.
-            *['-L{}'.format(library_dir) for library_dir in self.library_dirs],
+            *['-L{}'.format(library_dir) for library_dir in self.library_dirs + list(linking_dirs)],
+            *['-l{}'.format(library) for library in linking_libraries],
             # optimize host side code via -O3
             '-O3',
             # host compiler options: enable openmp, avx2, unroll loops and fast math
@@ -153,7 +170,7 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
             # generate shared library (lib.so).
             '--shared' if out_lib_path.endswith('.so') else '--compile',
             # the linking objects.
-            ' '.join(linking_objects),
+            ' '.join(object_files),
             # the source path.
             src_path,
             # the output library path.
@@ -179,16 +196,25 @@ def _resolve_gcc_path():
                 return path
         raise FileNotFoundError('Can not find g++ compiler.')
 
-    def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
-        if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
+    def compile(
+        self,
+        src_path: str,
+        out_lib_path: str,
+        include_dirs: Sequence[str] = (),
+        linking_dirs: Sequence[str] = (),
+        linking_libraries: Sequence[str] = (),
+        object_files: Sequence[str] = (),
+    ) -> None:
+        if len(object_files) > 0 and out_lib_path.endswith('.o'):
             raise ValueError('Can not compile multiple objects into a single object file.')
 
         command = [
             # the path to nvcc compiler
             self.gcc_path,
             # the included directories.
-            *['-I{}'.format(include_dir) for include_dir in self.include_dirs],
+            *['-I{}'.format(include_dir) for include_dir in self.include_dirs + list(include_dirs)],
             # the library directories.
-            *['-L{}'.format(library_dir) for library_dir in self.library_dirs],
+            *['-L{}'.format(library_dir) for library_dir in self.library_dirs + list(linking_dirs)],
+            *['-l{}'.format(library) for library in linking_libraries],
             # apply -O3 optimization.
             '-O3',
             # support avx intrinsics
@@ -204,7 +230,7 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
             # generate shared library (lib.so).
             '-shared' if out_lib_path.endswith('.so') else '--compile',
             # the linking objects.
-            ' '.join(linking_objects),
+            ' '.join(object_files),
             # the source path.
             src_path,
             # the output library path.
@@ -216,7 +242,13 @@ def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[st
 
 
 def compile_source(
-    source_file: str, output_library_file: str, target: str, object_files: Optional[Sequence[str]]
+    source_file: str,
+    output_library_file: str,
+    target: str,
+    include_dirs: Sequence[str] = (),
+    linking_dirs: Sequence[str] = (),
+    linking_libraries: Sequence[str] = (),
+    object_files: Sequence[str] = (),
 ) -> None:
     """
     Compile the source code in 'src_path' file and output the library to 'out_lib_path'.
@@ -229,6 +261,12 @@ def compile_source(
         The path to output library.
     target: str
         The target platform. Currently only support 'cpu' and 'gpu'.
+    include_dirs: Optional[Sequence[str]]
+        The include directories.
+    linking_dirs: Optional[Sequence[str]]
+        The library directories.
+    linking_libraries: Optional[Sequence[str]]
+        The libraries to link to the output library.
     object_files: Optional[Sequence[str]]
         The path to object files. If not None, the object files will be linked to the output library.
     """
@@ -247,4 +285,11 @@ def compile_source(
         raise ValueError('Unknown target platform: {}'.format(target))
 
     object_files = object_files or []
-    compiler.compile(source_file, output_library_file, object_files)
+    compiler.compile(
+        source_file,
+        output_library_file,
+        include_dirs=include_dirs,
+        linking_dirs=linking_dirs,
+        linking_libraries=linking_libraries,
+        object_files=object_files,
+    )
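With the widened `compile` signature, per-module flags are forwarded to nvcc/g++ next to the compiler's built-in ones (`-I` for headers, `-L` for search paths, `-l` for libraries). A direct-call sketch, with illustrative paths:

```python
from hidet.backend.build import compile_source

# Roughly what the build driver now does for a module that links against NCCL.
compile_source(
    source_file='./out/source.cu',                             # illustrative paths
    output_library_file='./out/lib.so',
    target='cuda',
    include_dirs=['/.../site-packages/nvidia/nccl/include'],   # -> -I...
    linking_dirs=['/.../site-packages/nvidia/nccl/lib'],       # -> -L...
    linking_libraries=[':libnccl.so.2'],                       # -> -l:libnccl.so.2
)
```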
diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py
index fa15efca8..c094b8944 100644
--- a/python/hidet/backend/codegen.py
+++ b/python/hidet/backend/codegen.py
@@ -682,6 +682,9 @@ def require_headers(self) -> Doc:
         doc += Text('#include ') + NewLine()
         doc += Text("#include ") + NewLine()
 
+        for header in self.ir_module.include_headers:
+            doc += Text('#include <{}>'.format(header)) + NewLine()
+
         if self.require_tf32:
             # nvcc use float to 'store' tfloat32 data
             doc += Text('typedef float tfloat32_t;') + NewLine()
@@ -768,6 +771,8 @@ def require_headers(self) -> Doc:
         doc += Text('#include ') + NewLine()
         if self.require_bf16:
             doc += Text('#include ') + NewLine()
+        for header in self.ir_module.include_headers:
+            doc += Text('#include <{}>'.format(header)) + NewLine()
         if self.require_tf32:
             doc += Text('typedef float tfloat32_t;') + NewLine()
         doc += NewLine()
diff --git a/python/hidet/cuda/nccl/__init__.py b/python/hidet/cuda/nccl/__init__.py
new file mode 100644
index 000000000..6ff476d71
--- /dev/null
+++ b/python/hidet/cuda/nccl/__init__.py
@@ -0,0 +1,13 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl
+from .ffi import nccl_version, nccl_library_filename
diff --git a/python/hidet/cuda/nccl/comm.py b/python/hidet/cuda/nccl/comm.py
new file mode 100644
index 000000000..241ffc18f
--- /dev/null
+++ b/python/hidet/cuda/nccl/comm.py
@@ -0,0 +1,100 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from enum import IntEnum
+from typing import List
+import struct
+
+from hidet.ffi.utils import Array
+from hidet.ir.type import void_p, DataType
+from .ffi import nccl_runtime_api, NcclUniqueId
+
+
+class NcclDataType(IntEnum):
+    int8 = 0
+    char = 0
+    uint8 = 1
+    int32 = 2
+    int = 2
+    uint32 = 3
+    int64 = 4
+    uint64 = 5
+    float16 = 6
+    half = 6
+    float32 = 7
+    float = 7
+    float64 = 8
+    double = 8
+    bfloat = 9
+
+
+class NcclRedOp(IntEnum):
+    sum = 0
+    prod = 1
+    max = 2
+    min = 3
+    avg = 4
+
+
+class NcclCommunicator:
+    def __init__(self, handle: int):
+        """
+        Users should not call this constructor directly, because there are two ways of creating
+        a new communicator: 1) using a unique_id and a rank; 2) using split().
+        """
+
+        self._handle = handle
+
+    def __del__(self):
+        nccl_runtime_api.comm_destroy(self._handle)
+
+    @property
+    def handle(self):
+        return self._handle
+
+    def split(self):
+        raise NotImplementedError()
+
+
+def create_comm(nranks: int, unique_id: NcclUniqueId, rank: int) -> NcclCommunicator:
+    handle = nccl_runtime_api.comm_init_rank(nranks, unique_id, rank)
+    return NcclCommunicator(handle)
+
+
+def comms_to_array(comms: List[NcclCommunicator]) -> Array:
+    handles = [comm.handle for comm in comms]
+    array = Array(void_p, len(comms))
+    struct.pack_into(array.format, array.buffer, 0, *handles)
+    return array
+
+
+def init_unique_id(unique_id: NcclUniqueId) -> None:
+    nccl_runtime_api.get_unique_id(unique_id)
+
+
+def dtype_to_nccl(dtype: DataType) -> NcclDataType:
+    sname_dict = {
+        'f64': NcclDataType.float64,
+        'f32': NcclDataType.float32,
+        'bf16': NcclDataType.bfloat,
+        'f16': NcclDataType.float16,
+        'i64': NcclDataType.int64,
+        'i32': NcclDataType.int32,
+        'i8': NcclDataType.int8,
+        'u64': NcclDataType.uint64,
+        'u32': NcclDataType.uint32,
+        'u8': NcclDataType.uint8,
+    }
+    sname = dtype.short_name
+    nccl_type = sname_dict.get(sname, None)
+    if nccl_type is None:
+        raise RuntimeError(f"Data type {dtype.name} is not supported in NCCL")
+    return nccl_type
diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py
new file mode 100644
index 000000000..66a7dbc74
--- /dev/null
+++ b/python/hidet/cuda/nccl/ffi.py
@@ -0,0 +1,108 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+import ctypes
+from ctypes import c_void_p, c_int, pointer, Structure, c_byte, POINTER
+import glob
+import os
+
+from hidet.ffi.ffi import get_func
+from .libinfo import get_nccl_library_search_dirs
+
+_LIB_NCCL: Optional[ctypes.CDLL] = None
+nccl_library_path = None
+
+
+class NcclUniqueId(Structure):
+    """
+    Defined as in nccl.h
+    """
+
+    _fields_ = [("internal", c_byte * 128)]
+
+
+def nccl_available():
+    return _LIB_NCCL is not None
+
+
+def nccl_version():
+    return nccl_runtime_api.get_version()
+
+
+def load_nccl_library():
+    global _LIB_NCCL, nccl_library_path
+    library_dirs = get_nccl_library_search_dirs()
+    for library_dir in library_dirs:
+        lib_nccl_paths = glob.glob(os.path.join(library_dir, 'libnccl.so*'))
+        if len(lib_nccl_paths) == 0:
+            continue
+        _LIB_NCCL = ctypes.cdll.LoadLibrary(lib_nccl_paths[0])
+        nccl_library_path = lib_nccl_paths[0]
+        break
+    if _LIB_NCCL is None:
+        raise OSError('Can not find nccl library in the following directories: \n' + '\n'.join(library_dirs))
+
+
+load_nccl_library()
+
+
+def nccl_library_filename():
+    return os.path.basename(nccl_library_path)
+
+
+if not nccl_available():
+    raise RuntimeError("NCCL Library not found.")
+
+
+class NCCLRuntimeAPI:
+    """
+    Runtime APIs regarding NCCL
+    TODO: Exception handling
+    """
+
+    _get_version = get_func('ncclGetVersion', [c_void_p], c_int, lib=_LIB_NCCL)
+    _get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int, lib=_LIB_NCCL)
+    _comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int, lib=_LIB_NCCL)
+    _comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int, lib=_LIB_NCCL)
+
+    _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)
+    _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL)
+
+    @staticmethod
+    def get_version() -> int:
+        version = c_int(0)
+        NCCLRuntimeAPI._get_version(pointer(version))
+        return version.value
+
+    @staticmethod
+    def get_unique_id(comm_id: NcclUniqueId) -> None:
+        """
+        In-place initialization of the NcclUniqueId object
+        """
+        ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id))
+        assert ret == 0, ret
+
+    @staticmethod
+    def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int:
+        comm = c_void_p()
+        ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank)
+        assert ret == 0, ret
+        return comm.value
+
+    @staticmethod
+    def comm_destroy(comm_handle) -> None:
+        ret = NCCLRuntimeAPI._comm_destroy(comm_handle)
+        assert ret == 0
+
+
+nccl_runtime_api = NCCLRuntimeAPI()
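`get_func` (see the `hidet/ffi/ffi.py` hunk later in this patch) gained a `lib=` fallback so that symbols can be resolved from libraries other than hidet's own, which is how the bindings above work. Binding one more NCCL entry point would look like this (a sketch; `ncclCommCuDevice` is a standard NCCL symbol, but this binding is not part of the patch):

```python
from ctypes import c_void_p, c_int, POINTER
from hidet.ffi.ffi import get_func
import hidet.cuda.nccl.ffi as nccl_ffi

# ncclCommCuDevice(comm, &dev): query the CUDA device backing a communicator.
_comm_cu_device = get_func('ncclCommCuDevice', [c_void_p, POINTER(c_int)], c_int, lib=nccl_ffi._LIB_NCCL)
```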
diff --git a/python/hidet/cuda/nccl/libinfo.py b/python/hidet/cuda/nccl/libinfo.py
new file mode 100644
index 000000000..4e70df8c2
--- /dev/null
+++ b/python/hidet/cuda/nccl/libinfo.py
@@ -0,0 +1,26 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+
+def _get_nccl_dirs():
+    import site
+
+    return [os.path.join(root, 'nvidia', 'nccl') for root in site.getsitepackages()]
+
+
+def get_nccl_include_dirs():
+    return [os.path.join(root, 'include') for root in _get_nccl_dirs()]
+
+
+def get_nccl_library_search_dirs():
+    return [os.path.join(root, 'lib') for root in _get_nccl_dirs()]
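These helpers mirror the layout of the NCCL pip wheels (`site-packages/nvidia/nccl/{include,lib}`); a quick sanity check:

```python
import os
from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs

# Both should print an existing path when nvidia-nccl-cu11 or -cu12 is installed.
for d in get_nccl_include_dirs():
    print(d, os.path.isfile(os.path.join(d, 'nccl.h')))
for d in get_nccl_library_search_dirs():
    print(d, os.path.isdir(d))
```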
diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py
index b1bb84138..994d13d88 100644
--- a/python/hidet/drivers/build_graph.py
+++ b/python/hidet/drivers/build_graph.py
@@ -73,7 +73,6 @@ def create_graph_execution(graph: FlowGraph, weights: List[Tensor], node2kernel:
     instructions: List[GraphExecutionInstruction] = []
 
     for node_idx, node in enumerate(graph.nodes):
-
         inst_task_idx = node2kernel[node_idx]
         inst_inputs = [tensor_index[x] for x in node.inputs]
 
@@ -155,7 +154,6 @@ def build_graph_module(graph: FlowGraph, graph_weights: List[Tensor], node2kerne
     graph_nodes: List[Operator] = graph.nodes
 
     with hidet.script_module() as script_module:
-
         cpu_workspace = script_module.define_global_var('cpu_workspace', byte_p)
         cuda_workspace = script_module.define_global_var('cuda_workspace', byte_p)
         weights = script_module.define_global_var('weights', void_p[len(graph_weights)])
@@ -301,7 +299,6 @@ def save_to_graph_cache(cgraph: CompiledGraph):
 
 
 def build_flow_graph(graph, *, space=0) -> CompiledGraph:
-
     assert isinstance(graph, FlowGraph)
 
     # get the graph weights
diff --git a/python/hidet/drivers/build_module.py b/python/hidet/drivers/build_module.py
index 0defe3c6d..05671bec8 100644
--- a/python/hidet/drivers/build_module.py
+++ b/python/hidet/drivers/build_module.py
@@ -9,7 +9,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Sequence, Dict
+from typing import Sequence, Dict
 import logging
 import os
 import pickle
@@ -31,14 +31,7 @@
 logger.addHandler(logging.StreamHandler())
 
 
-def build_ir_module(
-    ir_module: IRModule,
-    output_dir: str,
-    *,
-    target: str,
-    output_kind: str = '.so',  # '.so', '.o'
-    object_files: Optional[Sequence[str]] = None,
-):
+def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output_kind: str = '.so'):  # '.so', '.o'
     if target == 'cuda':
         src_path = os.path.join(output_dir, 'source.cu')
     elif target == 'cpu':
@@ -66,7 +59,15 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output
     codegen(ir_module, src_out_path=src_path, target=target)
 
     # compile source code
-    compile_source(src_path, output_library_file=lib_path, target=target, object_files=object_files)
+    compile_source(
+        src_path,
+        output_library_file=lib_path,
+        target=target,
+        include_dirs=ir_module.include_dirs,
+        linking_dirs=ir_module.linking_dirs,
+        linking_libraries=ir_module.linking_libs,
+        object_files=ir_module.object_files,
+    )
 
     # write the function types
     if output_kind == '.so':
diff --git a/python/hidet/drivers/build_task.py b/python/hidet/drivers/build_task.py
index 572317014..584ffe9ab 100644
--- a/python/hidet/drivers/build_task.py
+++ b/python/hidet/drivers/build_task.py
@@ -76,8 +76,6 @@ def get_output_shape(idx: int32, dims: ~int32):
 
         launch_func.name = 'launch_0'
         task_ir_module.functions['launch_0'] = launch_func
-
-        object_files = []
     else:
         # otherwise, build each candidate to a .o file, and link them into the task's ir module
         for i, candidate in enumerate(candidates):
@@ -94,7 +92,6 @@ def get_output_shape(idx: int32, dims: ~int32):
         param_types = [~t.type.dtype for t in task.params]
 
         with hidet.script_module() as script_module:
-
             launch_candidates = []
             for i in range(len(candidates)):
                 launch_candidates.append(
@@ -115,8 +112,10 @@ def launch(arg: meta.types(param_types)):
             ir_module = script_module.ir_module()
             ir_module.add_function(get_input_shape.name, get_input_shape)
             ir_module.add_function(get_output_shape.name, get_output_shape)
+            ir_module.object_files.extend(
+                [os.path.join(task_dir, 'candidates', str(i), 'lib.o') for i in range(len(candidates))]
+            )
             task_ir_module = ir_module
-            object_files = [os.path.join(task_dir, 'candidates', str(i), 'lib.o') for i in range(len(candidates))]
 
     # add assertions to the launch function
     if len(task.assertions) > 0:
@@ -131,9 +130,7 @@ def launch(arg: meta.types(param_types)):
         body.seq = assertions + body.seq
 
     # build task ir module
-    build_ir_module(
-        ir_module=task_ir_module, output_dir=task_dir, output_kind='.so', object_files=object_files, target=target
-    )
+    build_ir_module(ir_module=task_ir_module, output_dir=task_dir, output_kind='.so', target=target)
 
 
 def generate_meta_data(task: Task, task_dir: str, build_target: str, num_candidates: int):
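Net effect of the driver changes: candidate `.o` files now travel on the IRModule itself (`ir_module.object_files`) instead of threading through `build_ir_module` as an extra argument, so any later rebuild of the module keeps its link inputs. A sketch of the resulting flow:

```python
import hidet
from hidet.drivers import build_ir_module

with hidet.script_module() as script_module:
    ...  # define a public 'launch' function here

ir_module = script_module.ir_module()
ir_module.object_files.append('./task/candidates/0/lib.o')  # illustrative candidate object
build_ir_module(ir_module, './task', target='cuda')         # lib.o is linked into lib.so
```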
diff --git a/python/hidet/ffi/ffi.py b/python/hidet/ffi/ffi.py
index b8e968ae4..6771b9157 100644
--- a/python/hidet/ffi/ffi.py
+++ b/python/hidet/ffi/ffi.py
@@ -66,11 +66,13 @@ def func_exists(func_name: str, shared_lib: ctypes.CDLL) -> bool:
         return False
 
 
-def get_func(func_name, arg_types: List, restype):
+def get_func(func_name, arg_types: List, restype, lib=None):
     if func_exists(func_name, _LIB):
         func = getattr(_LIB, func_name)
     elif func_exists(func_name, _LIB_RUNTIME):
         func = getattr(_LIB_RUNTIME, func_name)
+    elif func_exists(func_name, lib):
+        func = getattr(lib, func_name)
     else:
         raise ValueError(
             'Can not find function "{}" in hidet libraries:\n{}\n{}'.format(
diff --git a/python/hidet/ffi/runtime_api.py b/python/hidet/ffi/runtime_api.py
index 0880f45ea..3b2e3af24 100644
--- a/python/hidet/ffi/runtime_api.py
+++ b/python/hidet/ffi/runtime_api.py
@@ -13,6 +13,7 @@
 from ctypes import c_void_p, c_char_p, c_uint64, c_int32
 from hidet.cuda import Stream
 from .ffi import get_func
+from .utils import Array
 
 
 class RuntimeAPI:
@@ -24,6 +25,7 @@ class RuntimeAPI:
     _reset_symbol_table = get_func('reset_symbol_table', [], None)
     _get_symbol_value = get_func('get_symbol_value', [c_char_p], c_int32)
     _set_symbol_value = get_func('set_symbol_value', [c_char_p, c_int32], None)
+    _set_nccl_comms = get_func('set_nccl_comms', [c_int32, c_void_p], None)
 
     @staticmethod
     def set_current_stream(stream: Union[Stream, int]) -> None:
@@ -61,5 +63,10 @@ def set_symbol_value(name: str, value: int) -> None:
         name = name.encode('utf-8')
         RuntimeAPI._set_symbol_value(name, value)
 
+    @staticmethod
+    def set_nccl_comms(comms: Array) -> None:
+        comms_array_t = c_void_p * comms.length
+        RuntimeAPI._set_nccl_comms(comms.length, comms_array_t.from_buffer(comms.buffer))
+
 
 runtime_api = RuntimeAPI()
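One subtlety: `from_buffer` shares memory with the Python `Array`, and the C side keeps the raw pointer, so the array must stay alive for as long as kernels may call `get_nccl_comm`. A sketch (with `comm` created as in `comm.py` above):

```python
from hidet.cuda import nccl
from hidet.ffi import runtime_api

comms_array = nccl.comms_to_array([comm])
runtime_api.set_nccl_comms(comms_array)
# keep a reference to comms_array: the runtime does not copy the buffer
```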
diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py
index 85b0ff174..b94d8becf 100644
--- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py
+++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py
@@ -186,7 +186,6 @@ def micro_kernel_6x16(
         def micro_kernel_4x8(
             a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32
         ):
-
             c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize])
             c0 = avx_f32x8_load(~c[0, 0])
             c1 = avx_f32x8_load(~c[1, 0])
@@ -213,7 +212,6 @@ def micro_kernel_4x8(
         def micro_kernel_8x8(
             a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32
         ):
-
             c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize])
             c0 = avx_f32x8_load(~c[0, 0])
             c1 = avx_f32x8_load(~c[1, 0])
diff --git a/python/hidet/graph/ops/reduce/reduce_f16.py b/python/hidet/graph/ops/reduce/reduce_f16.py
index cc921780d..aede58ba8 100644
--- a/python/hidet/graph/ops/reduce/reduce_f16.py
+++ b/python/hidet/graph/ops/reduce/reduce_f16.py
@@ -26,7 +26,6 @@ class ReduceF16Task(Task):
     def __init__(
         self, x: TensorNode, dims: List[int], keep_dim: bool, reduce_type: ReduceType, accumulate_dtype: str = 'float32'
     ):
-
         y_shape = []
         for i in range(len(x.shape)):
             if i in dims:
diff --git a/python/hidet/ir/functors/compute_functor.py b/python/hidet/ir/functors/compute_functor.py
index 63edb7eaa..86420a905 100644
--- a/python/hidet/ir/functors/compute_functor.py
+++ b/python/hidet/ir/functors/compute_functor.py
@@ -131,7 +131,6 @@ def visit_ReduceCompute(self, node: ReduceCompute):
         ):
             return node
         else:
-
             return ReduceCompute(node.name, shape, axes, value, node.reduce_operation, accumulate_dtype)
 
     def visit_ArgReduceCompute(self, node: ArgReduceCompute):
diff --git a/python/hidet/ir/functors/module_functor.py b/python/hidet/ir/functors/module_functor.py
index 2124ed5e9..5e51cc6fb 100644
--- a/python/hidet/ir/functors/module_functor.py
+++ b/python/hidet/ir/functors/module_functor.py
@@ -52,7 +52,7 @@ def visit_IRModule(self, module: IRModule):
         if same_list(global_vars, module.global_vars) and functions is module.functions:
             return module
         else:
-            return IRModule(functions, global_vars, module.namespace, module.extern_functions)
+            return module.copy().reset_funcs(functions, global_vars)
 
     def visit_Function(self, func: Function):
         params = self.visit(func.params)
diff --git a/python/hidet/ir/module.py b/python/hidet/ir/module.py
index 1aacefa2b..751c05a24 100644
--- a/python/hidet/ir/module.py
+++ b/python/hidet/ir/module.py
@@ -10,7 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import annotations
-from typing import Dict
+from typing import Dict, List
 from hidet.ir.node import Node
 from hidet.ir.type import FuncType
 from hidet.ir.expr import Var
@@ -24,11 +24,33 @@ class IRModule(Node):
     An IRModule contains one or more functions. It is the basic compilation unit of hidet.
     """
 
-    def __init__(self, functions=None, global_vars=None, namespace='', extern_functions: Dict[str, Var] = None):
+    def __init__(
+        self,
+        functions=None,
+        global_vars=None,
+        namespace='',
+        extern_functions: Dict[str, Var] = None,
+        include_headers: List[str] = None,
+        include_dirs: List[str] = None,
+        linking_dirs: List[str] = None,
+        linking_libs: List[str] = None,
+        object_files: List[str] = None,
+    ):
+        # the functions defined in this module
         self.functions: Dict[str, Function] = functions if functions else {}
+        # the global variables defined in this module
         self.global_vars: Dict[str, Var] = global_vars if global_vars else {}
+        # the namespace of the module; all the functions and global variables will be defined in this namespace
         self.namespace: str = namespace
+        # the external functions that are used in this module; each Var must have a function type
         self.extern_functions: Dict[str, Var] = {} if extern_functions is None else extern_functions
+        # '#include ...' preprocessor directives
+        self.include_headers: List[str] = include_headers if include_headers else []
+        # flags that will be passed to the underlying compiler; can be used to add 3rd-party libraries
+        self.include_dirs: List[str] = include_dirs if include_dirs else []  # -I flags
+        self.linking_dirs: List[str] = linking_dirs if linking_dirs else []  # -L flags
+        self.linking_libs: List[str] = linking_libs if linking_libs else []  # -l flags
+        self.object_files: List[str] = object_files if object_files else []  # .o files
 
         assert all(isinstance(func, Function) for func in self.functions.values()) and all(
             isinstance(var, Var) for var in self.global_vars.values()
@@ -51,6 +73,24 @@ def add_function(self, name, func: Function):
         else:
             self.functions[name] = func
 
+    def copy(self):
+        return IRModule(
+            functions=self.functions.copy(),
+            global_vars=self.global_vars.copy(),
+            namespace=self.namespace,
+            extern_functions=self.extern_functions.copy(),
+            include_headers=self.include_headers.copy(),
+            include_dirs=self.include_dirs.copy(),
+            linking_dirs=self.linking_dirs.copy(),
+            linking_libs=self.linking_libs.copy(),
+            object_files=self.object_files.copy(),
+        )
+
+    def reset_funcs(self, functions: Dict[str, Function] = None, global_vars: Dict[str, Var] = None):
+        self.functions = functions if functions else {}
+        self.global_vars = global_vars if global_vars else {}
+        return self
+
     def build(self):
         """
         Build the module.
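`copy()` plus `reset_funcs()` replaces the bare `IRModule(...)` construction inside passes, so build metadata (headers, `-I`/`-L`/`-l` flags, object files) now survives every transform. The idiom, as the updated passes below use it:

```python
def process_module(ir_module):
    # shallow-copy the module (metadata included), then rebuild its function table
    new_ir_module = ir_module.copy().reset_funcs()
    for func_name, func in ir_module.functions.items():
        new_ir_module.add_function(func_name, func)  # or a rewritten version of func
    return new_ir_module
```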
diff --git a/python/hidet/ir/primitives/cuda/nccl.py b/python/hidet/ir/primitives/cuda/nccl.py
new file mode 100644
index 000000000..96b2b2a2d
--- /dev/null
+++ b/python/hidet/ir/primitives/cuda/nccl.py
@@ -0,0 +1,30 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from hidet.ir.expr import Expr
+from hidet.ir.stmt import BlackBoxStmt
+from hidet.ir.type import DataType
+
+# TODO: we should not put nccl-related types here, since hidet.cuda.nccl depends on
+# the existence of the nccl library
+from hidet.cuda.nccl import NcclRedOp, dtype_to_nccl
+
+
+def all_reduce(comm_id: int, sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: DataType, op: NcclRedOp):
+    from hidet.ir.primitives.runtime import get_cuda_stream, get_nccl_comm
+
+    comm = get_nccl_comm(comm_id)
+    return BlackBoxStmt(
+        'ncclAllReduce({}, {}, {}, (ncclDataType_t){}, (ncclRedOp_t){}, '
+        '(ncclComm_t){}, (cudaStream_t){});'.format(
+            sendbuff, recvbuff, count, int(dtype_to_nccl(dtype)), int(op), comm, get_cuda_stream()
+        )
+    )
diff --git a/python/hidet/ir/primitives/runtime.py b/python/hidet/ir/primitives/runtime.py
index fd891fa43..e16e55490 100644
--- a/python/hidet/ir/primitives/runtime.py
+++ b/python/hidet/ir/primitives/runtime.py
@@ -53,6 +53,9 @@ def register_functions():
     register_primitive_function(
         name='memory_planner_used', func_or_type=FuncType([int32], int64), codegen_name='memory_planner_used'
     )
+    register_primitive_function(
+        name='get_nccl_comm', func_or_type=FuncType([int32], void_p), codegen_name='get_nccl_comm'
+    )
 
 
 def get_cuda_stream() -> void_p:
@@ -89,3 +92,7 @@ def memory_planner_free(idx: Union[int, Expr], ptr: Union[int, Expr]):
 
 def memory_planner_used(idx: Union[int, Expr]):
     return call_primitive_func('memory_planner_used', [idx])
+
+
+def get_nccl_comm(idx: int) -> void_p:
+    return call_primitive_func('get_nccl_comm', [idx])
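`all_reduce` lowers to a `BlackBoxStmt`, i.e. a literal statement pasted into the generated CUDA source. For a float32 sum over communicator 0, it emits roughly the following (a sketch; `send`, `recv` and `count` stand for IR expressions from the enclosing kernel):

```python
from hidet.ir.type import data_type
from hidet.cuda.nccl import NcclRedOp
from hidet.ir.primitives.cuda.nccl import all_reduce

stmt = all_reduce(0, send, recv, count, data_type('float32'), NcclRedOp.sum)
# Generated C, roughly:
#   ncclAllReduce(send, recv, count, (ncclDataType_t)7, (ncclRedOp_t)0,
#                 (ncclComm_t)get_nccl_comm(0), (cudaStream_t)get_cuda_stream());
```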
diff --git a/python/hidet/ir/tools/printer.py b/python/hidet/ir/tools/printer.py
index 1ae70a47d..455b1f734 100644
--- a/python/hidet/ir/tools/printer.py
+++ b/python/hidet/ir/tools/printer.py
@@ -90,6 +90,15 @@ def visit_IRModule(self, ir_module: IRModule):
         doc = Doc()
         self.ir_module = ir_module
 
+        for linking_lib in ir_module.linking_libs:
+            doc += Text('link lib: ') + linking_lib + NewLine()
+        for object_file in ir_module.object_files:
+            doc += Text('external object: ') + object_file + NewLine()
+        for header in ir_module.include_headers:
+            doc += Text('#include <{}>'.format(header)) + NewLine()
+        if len(ir_module.include_headers) + len(ir_module.linking_libs) + len(ir_module.object_files) > 0:
+            doc += NewLine()
+
         for name, var in ir_module.global_vars.items():
             if name in ir_module.functions:
                 continue
diff --git a/python/hidet/testing/models/llama.py b/python/hidet/testing/models/llama.py
index c546fd943..6b4f2b7a0 100644
--- a/python/hidet/testing/models/llama.py
+++ b/python/hidet/testing/models/llama.py
@@ -161,7 +161,6 @@ def forward(
         position_ids: Optional[hidet.Tensor] = None,
         past_key_value: Optional[Tuple[hidet.Tensor]] = None,
     ) -> Tuple[hidet.Tensor, Tuple[hidet.Tensor, hidet.Tensor]]:
-
         bsz, q_len, _ = hidden_states.shape
 
         query_states = self.q_proj(hidden_states).reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose(1, 2)
@@ -370,12 +369,12 @@ def build_flow_graph(model, batch_size=1, device='cuda', dtype='float16'):
     y = model(input_ids, position_ids=position_ids, past_key_values=key_value_cache)
 
     inputs = [input_ids, position_ids]
-    for (q, k) in key_value_cache:
+    for q, k in key_value_cache:
         inputs.append(q)
         inputs.append(k)
 
     outputs = [y['new_ids']]
-    for (q, k) in y['past_key_values']:
+    for q, k in y['past_key_values']:
         outputs.append(q)
         outputs.append(k)
@@ -437,7 +436,6 @@ def generate_torch(input_ids: str, tokenizer, torch_model, num_tokens, device='c
 
 def get_compiled_model(name='decapoda-research/llama-7b-hf', device='cuda', opt=False):
-
     tok = LlamaTokenizer.from_pretrained(name)
 
     with torch.device("cuda"):
         # reduce the time to load the model
diff --git a/python/hidet/transforms/base.py b/python/hidet/transforms/base.py
index 2c10d546c..a8bf32edb 100644
--- a/python/hidet/transforms/base.py
+++ b/python/hidet/transforms/base.py
@@ -59,12 +59,7 @@ def process_module(self, ir_module: IRModule) -> IRModule:
         if all(new_funcs[name] is ir_module.functions[name] for name in new_funcs):
             return ir_module
         else:
-            return IRModule(
-                functions=new_funcs,
-                global_vars=ir_module.global_vars,
-                namespace=ir_module.namespace,
-                extern_functions=ir_module.extern_functions,
-            )
+            return ir_module.copy().reset_funcs(new_funcs, ir_module.global_vars)
 
     def process_func(self, func: Function) -> Function:
         return func
diff --git a/python/hidet/transforms/generate_launch_func.py b/python/hidet/transforms/generate_launch_func.py
index 355fb32df..3a7103f7d 100644
--- a/python/hidet/transforms/generate_launch_func.py
+++ b/python/hidet/transforms/generate_launch_func.py
@@ -40,7 +40,6 @@ def _rewrite_dim3(dim3: Tuple[Expr, Expr, Expr], param2arg: Dict[Expr, Expr]) ->
 
 
 def add_launch_func(ir_module: IRModule, kernel_func: Function):
-
     with FunctionBuilder(name='launch', kind='public') as fb:
         params = [Var(param.hint, param.type) for param in kernel_func.params]
         param_remap = {a: b for a, b in zip(kernel_func.params, params)}
diff --git a/python/hidet/transforms/import_primitive_functions.py b/python/hidet/transforms/import_primitive_functions.py
index e444c7ef7..45c6c284b 100644
--- a/python/hidet/transforms/import_primitive_functions.py
+++ b/python/hidet/transforms/import_primitive_functions.py
@@ -37,7 +37,7 @@ def process_module(self, ir_module: IRModule) -> IRModule:
         if len(primitive_funcs) == 0:
             return ir_module
         else:
-            new_ir_module = IRModule(namespace=ir_module.namespace, extern_functions=ir_module.extern_functions)
+            new_ir_module = ir_module.copy().reset_funcs()
             for func_name, func in ir_module.functions.items():
                 new_ir_module.add_function(func_name, func)
             for func in primitive_funcs:
diff --git a/python/hidet/transforms/inline_function.py b/python/hidet/transforms/inline_function.py
index 4b361caef..faf26b307 100644
--- a/python/hidet/transforms/inline_function.py
+++ b/python/hidet/transforms/inline_function.py
@@ -145,7 +145,7 @@ def prune_unused_functions(ir_module: IRModule):
 class InlineFunctionPass(Pass):
     def process_module(self, ir_module: IRModule) -> IRModule:
         call_graph = CallGraph(ir_module, allow_missing=True)
-        updated_ir_module = IRModule(namespace=ir_module.namespace, extern_functions=ir_module.extern_functions)
+        updated_ir_module = ir_module.copy().reset_funcs()
         for node in call_graph.reversed_order:
             assert isinstance(node, CallGraphNode)
             func = inline_callees(node.func, updated_ir_module)
diff --git a/python/hidet/transforms/instantiate_symbols.py b/python/hidet/transforms/instantiate_symbols.py
index d4c0c5e24..ea8e7e9a3 100644
--- a/python/hidet/transforms/instantiate_symbols.py
+++ b/python/hidet/transforms/instantiate_symbols.py
@@ -35,7 +35,7 @@ def __init__(self):
         self.current_func: Optional[str] = None
 
     def visit_IRModule(self, module: IRModule):
-        updated_module = IRModule(namespace=module.namespace, extern_functions=module.extern_functions)
+        updated_module = module.copy().reset_funcs()
         call_graph = CallGraph(module, allow_missing=True)
         for node in call_graph.reversed_order:
             updated_module.functions[node.func.name] = self.visit(node.func)
diff --git a/python/hidet/utils/model_translator/examples.py b/python/hidet/utils/model_translator/examples.py
index 129181e7d..66f5744fd 100644
--- a/python/hidet/utils/model_translator/examples.py
+++ b/python/hidet/utils/model_translator/examples.py
@@ -18,6 +18,7 @@
 import torch
 import numpy as np
 
+
 # transpile a simple function
 def orig_func(a, b: torch.Tensor, c: float, dim: int = 1):
     """
@@ -40,6 +41,7 @@ def orig_func(a, b: torch.Tensor, c: float, dim: int = 1):
 print("final result:\n")
 print(transpiled_str(interpreter))
 
+
 # %%
 # conditionals with multiple branches may need to have all branches covered
 def conditional(a: torch.Tensor, b: torch.Tensor, c):
@@ -63,6 +65,7 @@ def conditional(a: torch.Tensor, b: torch.Tensor, c):
 
 print(transpiled_str(interpreter))
 
+
 # %%
 def forloop(a: torch.Tensor, c, l):
     for i in range(l):
@@ -79,6 +82,7 @@ def forloop(a: torch.Tensor, c, l):
 vis_interpreter(interpreter)
 print(transpiled_str(interpreter))
 
+
 # %%
 # calling a function that is not in the torch name space will trigger a recursive trace
 # asking for user input whether to trace or not
@@ -100,6 +104,7 @@ def raw_function(a, b, c, d):
 vis_interpreter(interpreter)
 print(transpiled_str(interpreter))
 
+
 # %%
 # tracing classes works a bit differently
 class TestClass(torch.nn.Module):
@@ -140,6 +145,7 @@ def test2(self):
 vis_interpreter(intp)
 print(transpiled_str(intp))
 
+
 # %%
 # inheritance would trigger more recursive traces
 class TestClass2(TestClass):
diff --git a/src/hidet/runtime/cuda_context.cpp b/src/hidet/runtime/cuda_context.cpp
index 2c9449be1..0a6273acf 100644
--- a/src/hidet/runtime/cuda_context.cpp
+++ b/src/hidet/runtime/cuda_context.cpp
@@ -55,3 +55,16 @@ DLL void* request_cuda_workspace(size_t nbytes, bool require_clean) {
         return nullptr;
     }
 }
+
+DLL void set_nccl_comms(int num_comms, void** comms) {
+    CudaContext::global()->num_comms = num_comms;
+    CudaContext::global()->nccl_comms = comms;
+}
+
+DLL void* get_nccl_comm(int idx) {
+    const int num_comms = CudaContext::global()->num_comms;
+    if (idx >= num_comms) {
+        LOG(FATAL) << "Index of NCCL communicator out of bounds (" << idx << " vs " << num_comms << ")";
+    }
+    return CudaContext::global()->nccl_comms[idx];
+}
\ No newline at end of file
diff --git a/tests/operators/test_operator.py b/tests/operators/test_operator.py
index cb7d993d3..0fff63830 100644
--- a/tests/operators/test_operator.py
+++ b/tests/operators/test_operator.py
@@ -12,6 +12,8 @@
 import hidet
 import pytest
 
+hidet.option.save_lower_ir()
+
 
 def test_profile_config():
     a = hidet.randn([1, 10, 10], device='cuda')