From 3e6a29119e97866eeefdaec0553f1951b142b0a5 Mon Sep 17 00:00:00 2001
From: Max Hu
Date: Sun, 22 Dec 2024 18:29:38 -0500
Subject: [PATCH] [Enhancement] Add option and functionality to set torch
 stream as the current stream (#629)

Use the torch C++ API to set hidet's current stream to the current torch
stream.

Implementation:
- Build a hidet-torch shared library to wrap the original torch C++ API
  (the original API involves torch-defined structures such as `CUDAStream`,
  so it cannot easily be dlopened and accessed at runtime)
- dlopen the newly added hidet-torch library and access torch's current
  stream through it
- Add an option "use_torch_stream" to hidet's options to dynamically choose
  at runtime between the current torch stream and hidet's own stream
- When hidet's CUDA graph mode is on, hidet will still create a new hidet
  stream and capture the graph on that stream instead of using the torch
  stream.

Benefits:
- Removes the overhead of querying and calling torch's current stream API
  from the Python side
- Could also reduce the overhead incurred in the Hexcute integration,
  because `set_to_torch_stream` is called in the launch function, so the
  stream query/switch on the Python side can be removed.
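Example usage of the new option from Python (a minimal sketch; the flag only
controls which stream the runtime's `get_cuda_stream()` returns on the C++
side):

```python
import torch
import hidet

# the flag defaults to True: hidet kernels launch on torch's current stream
assert hidet.option.is_use_torch_stream()

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # hidet-compiled kernels launched here follow torch's current stream (s),
    # with no stream query/switch needed on the Python side
    ...

# fall back to hidet-managed streams (done internally while capturing a CUDA
# graph, which is recorded on a freshly created hidet stream)
hidet.option.use_torch_stream(False)
assert not hidet.option.is_use_torch_stream()
```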
Performance improvement (measured on L4, locked frequency: 6250MHz compute /
1500MHz memory):

1. For the Hexcute kernel (without CUDA graph), I manually disabled CUDA
   graph on the DMWL (vLLM) side, so both the prefill and decoding stages
   use the generic model and call the Hexcute kernel directly.

   Command:
   `python3 benchmark_latency.py --model hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 --input-len 1024 --output-len 128 --batch-size 8 --num-iters-warmup 5 --num-iters 10 --max-model-len 32768 --quantization awq_hidet`

   Comparison before and after removing the stream query and stream switch
   before the Hexcute kernel call (https://github.com/CentML/DMWL/pull/121):

   Before avg latency: 12.624572871897545 seconds
   After avg latency: 11.764245539499097 seconds

2. Profile small kernels in hidet and measure latency:

   - CUDA graph enabled:
     `python bench_op_torch_api.py --params 16x16,16x16 --mode max-autotune matmul`

     Before: 0.27151119 seconds
     After: 0.25410826999999997 seconds

   - CUDA graph disabled:
     `python bench_op_torch_api.py --params 16x16,16x16 --mode max-autotune-no-cudagraphs matmul`

     Before: 0.14555310999999999 seconds
     After: 0.11648335 seconds

This is related to #563
---
 CMakeLists.txt                                | 30 ++++++++++++-
 apps/compile_server/requirements.txt          |  2 +-
 .../getting-started/build-from-source.rst     |  4 ++
 include/hidet/runtime/cuda/context.h          | 13 ++++++
 include/hidet/runtime/torch/stream.h          | 15 +++++++
 python/hidet/cuda/graph.py                    | 11 ++++-
 python/hidet/cuda/stream.py                   |  9 +++-
 python/hidet/ffi/ffi.py                       | 21 +++++++--
 python/hidet/ffi/runtime_api.py               | 16 +++++--
 .../graph/frontend/torch/dynamo_backends.py   |  2 -
 python/hidet/option.py                        | 28 ++++++++++++
 requirements-dev.txt                          |  1 -
 requirements.txt                              |  3 ++
 scripts/wheel/build_wheel.sh                  |  5 +++
 src/hidet/runtime/cuda/context.cpp            | 12 +++++
 src/hidet/runtime/hidet_torch/cuda_stream.cpp |  6 +++
 src/hidet/runtime/torch/stream.cpp            | 37 ++++++++++++++++
 tests/frontends/torch/test_torch_stream.py    | 44 +++++++++++++++++++
 18 files changed, 245 insertions(+), 14 deletions(-)
 create mode 100644 include/hidet/runtime/torch/stream.h
 create mode 100644 src/hidet/runtime/hidet_torch/cuda_stream.cpp
 create mode 100644 src/hidet/runtime/torch/stream.cpp
 create mode 100644 tests/frontends/torch/test_torch_stream.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3e982c11b..19d2d2beb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.19)
 
 project(hidet C CXX)
 
-set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 # config hidet
@@ -34,6 +34,7 @@ add_library(hidet_runtime SHARED
         src/hidet/runtime/llm/tokenizer/pretokenizers.cpp
         src/hidet/runtime/llm/tokenizer/tokenizer.cpp
         src/hidet/runtime/llm/tokenizer/utf8.cpp
+        src/hidet/runtime/torch/stream.cpp
 )
 target_include_directories(hidet_runtime PRIVATE ${CMAKE_SOURCE_DIR}/include /usr/include)
 set_target_properties(hidet_runtime PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
@@ -44,4 +45,31 @@ add_library(hidet SHARED
 )
 target_include_directories(hidet PRIVATE ${CMAKE_SOURCE_DIR}/include)
 set_target_properties(hidet PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+
+execute_process(
+    COMMAND python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))"
+    OUTPUT_VARIABLE TORCH_LIBRARY
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+set(TORCH_INCLUDE_DIR "${TORCH_LIBRARY}/include")
+execute_process(
+    COMMAND python3 -c "import nvidia; import os; print(os.path.dirname(nvidia.__file__))"
+    OUTPUT_VARIABLE NVIDIA_LIBRARY
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+set(CUDARUNTIME_INCLUDE_DIR "${NVIDIA_LIBRARY}/cuda_runtime/include")
+execute_process(
+    COMMAND python3 -c "import triton; import os; print(os.path.dirname(triton.__file__))"
+    OUTPUT_VARIABLE TRITON_LIBRARY
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+set(TRITON_INCLUDE_DIR "${TRITON_LIBRARY}/backends/nvidia/include/")
+
+add_library(hidet_torch_wrapper SHARED
+    src/hidet/runtime/hidet_torch/cuda_stream.cpp
+)
+target_include_directories(hidet_torch_wrapper PRIVATE ${CMAKE_SOURCE_DIR}/include ${TORCH_INCLUDE_DIR} ${CUDARUNTIME_INCLUDE_DIR} ${TRITON_INCLUDE_DIR})
+target_link_libraries(hidet_torch_wrapper ${TORCH_LIBRARY}/lib/libc10_cuda.so)
+set_target_properties(hidet_torch_wrapper PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+
 target_link_libraries(hidet "-Wl,--no-as-needed" hidet_runtime)
diff --git a/apps/compile_server/requirements.txt b/apps/compile_server/requirements.txt
index 917fcbc26..65f532f2d 100644
--- a/apps/compile_server/requirements.txt
+++ b/apps/compile_server/requirements.txt
@@ -52,4 +52,4 @@ lark
 scipy
 
 # for torch runtime api dependency
-torch>=2.3.0
+torch>=2.3.0
\ No newline at end of file
diff --git a/docs/source/getting-started/build-from-source.rst b/docs/source/getting-started/build-from-source.rst
index 4b35652d8..2c90c1f87 100644
--- a/docs/source/getting-started/build-from-source.rst
+++ b/docs/source/getting-started/build-from-source.rst
@@ -16,6 +16,10 @@ First clone the repository to local:
 
 Build shared libraries
 ~~~~~~~~~~~~~~~~~~~~~~
+Before building the runtime library, make sure you have ``torch`` installed in your python environment:
+.. code-block:: console
+
+   $ pip install torch
 
 The runtime library is written in C++ and compiled into a shared library. To build the shared library, you
 need to have a C++ compiler installed (as well as build tools like ``cmake``, and ``make``). The following command will build the
diff --git a/include/hidet/runtime/cuda/context.h b/include/hidet/runtime/cuda/context.h
index a14a652e9..3f0a52813 100644
--- a/include/hidet/runtime/cuda/context.h
+++ b/include/hidet/runtime/cuda/context.h
@@ -19,6 +19,9 @@ struct CudaContext: BaseContext {
     /* The cuda stream the kernels will be launched on. */
     void *stream = nullptr;
 
+    /* whether to use torch stream */
+    bool use_torch_stream = true;
+
     /* NCCL Comunicators*/
     void **nccl_comms = nullptr;
 
@@ -35,6 +38,16 @@ struct CudaContext: BaseContext {
  */
DLL void set_cuda_stream(void *stream);
 
+/**
+ * Get the use torch stream flag
+ */
+DLL bool get_use_torch_cuda_stream();
+
+/**
+ * Set the flag of whether to use torch stream
+ */
+DLL void use_torch_cuda_stream(bool use);
+
 /**
  * Get the cuda stream of cuda context.
  */
diff --git a/include/hidet/runtime/torch/stream.h b/include/hidet/runtime/torch/stream.h
new file mode 100644
index 000000000..4e495dda2
--- /dev/null
+++ b/include/hidet/runtime/torch/stream.h
@@ -0,0 +1,15 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include <hidet/runtime/common.h>
+
+DLL void *hidet_get_current_torch_stream();
diff --git a/python/hidet/cuda/graph.py b/python/hidet/cuda/graph.py
index ea6b52ae2..4fb279c5f 100644
--- a/python/hidet/cuda/graph.py
+++ b/python/hidet/cuda/graph.py
@@ -13,6 +13,7 @@
 from typing import List, Sequence, Optional, Any, Callable
 from cuda import cudart
 from cuda.cudart import cudaGraphExec_t
+from hidet.option import use_torch_stream, is_use_torch_stream
 from hidet.graph.tensor import Tensor
 from hidet.runtime.storage import MemoryPool, CudaMemoryAPI, memory_pool
 from hidet.runtime.device import Device
@@ -116,7 +117,6 @@ def __init__(
         self._inputs: List[Tensor] = []
         self._outputs: List[Tensor] = []
         self._ref_objs: List[Any] = ref_objs
-
         with memory_pool(self._memory_pool):
             # create the input tensors
             self._inputs = f_create_inputs()
@@ -126,11 +126,18 @@ def __init__(
             for _ in range(num_warmup):
                 f_run(self._inputs)
 
+            # There are two scenarios:
+            # 1. if torch or hidet is using the default stream, we capture on a new hidet-created stream
+            # 2. if torch is using its own non-default stream, we still capture on a new hidet-created
+            #    stream, to avoid interfering with the torch stream
+            # In both cases we restore the previous setting once the capture is done
+            prev_flag = is_use_torch_stream()
+            use_torch_stream(False)
             # capture the cuda graph
             self._memory_api.freeze()
             with self._graph_capture:
                 self._outputs = f_run(self._inputs)
-
+            use_torch_stream(prev_flag)
             # instantiate the cuda graph
             self._graph_exec: cudaGraphExec_t = self._graph_capture.instantiate()
 
diff --git a/python/hidet/cuda/stream.py b/python/hidet/cuda/stream.py
index ce20b9308..c641d7f9f 100644
--- a/python/hidet/cuda/stream.py
+++ b/python/hidet/cuda/stream.py
@@ -210,8 +210,15 @@ def current_stream(device=None) -> Stream:
     stream: Stream
         The current stream.
     """
+    from hidet.ffi import runtime_api
+
     device_id = _get_device_id(device)
-    if device_id not in _current_streams:
+    c_stream = runtime_api.get_current_stream()
+    if c_stream is not None:
+        # return the current stream, whether it is a hidet or a torch stream
+        _current_streams[device_id] = ExternalStream(handle=c_stream, device_id=device_id)
+    else:
+        # if no current stream is set, use the default stream
         _current_streams[device_id] = ExternalStream(handle=0, device_id=device_id)
     return _current_streams[_get_device_id(device)]
diff --git a/python/hidet/ffi/ffi.py b/python/hidet/ffi/ffi.py
index 6771b9157..51db7500d 100644
--- a/python/hidet/ffi/ffi.py
+++ b/python/hidet/ffi/ffi.py
@@ -9,33 +9,46 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import List, Dict, Optional
 import os
 import os.path
 import ctypes
+import torch
 from hidet.libinfo import get_library_search_dirs
 
 _LIB: Optional[ctypes.CDLL] = None
 _LIB_RUNTIME: Optional[ctypes.CDLL] = None
+_LIB_HIDET_TORCH_WRAPPER: Optional[ctypes.CDLL] = None
 
-library_paths: Dict[str, Optional[str]] = {'hidet': None, 'hidet_runtime': None}
+library_paths: Dict[str, Optional[str]] = {'hidet': None, 'hidet_runtime': None, 'hidet_torch_wrapper': None}
 
 
 def load_library():
-    global _LIB, _LIB_RUNTIME
+    global _LIB, _LIB_RUNTIME, _LIB_HIDET_TORCH_WRAPPER
     if _LIB:
         return
+
+    libc10_path = os.path.join(os.path.dirname(torch.__file__), 'lib/libc10_cuda.so')
+    ctypes.cdll.LoadLibrary(libc10_path)
+
     library_dirs = get_library_search_dirs()
     for library_dir in library_dirs:
         libhidet_path = os.path.join(library_dir, 'libhidet.so')
         libhidet_runtime_path = os.path.join(library_dir, 'libhidet_runtime.so')
-        if not os.path.exists(libhidet_path) or not os.path.exists(libhidet_runtime_path):
+        libhidet_torch_wrapper_path = os.path.join(library_dir, 'libhidet_torch_wrapper.so')
+        if (
+            not os.path.exists(libhidet_path)
+            or not os.path.exists(libhidet_runtime_path)
+            or not os.path.exists(libhidet_torch_wrapper_path)
+        ):
             continue
         _LIB_RUNTIME = ctypes.cdll.LoadLibrary(libhidet_runtime_path)
+        _LIB_HIDET_TORCH_WRAPPER = ctypes.cdll.LoadLibrary(libhidet_torch_wrapper_path)
         _LIB = ctypes.cdll.LoadLibrary(libhidet_path)
         library_paths['hidet_runtime'] = libhidet_runtime_path
         library_paths['hidet'] = libhidet_path
+        library_paths['hidet_torch_wrapper'] = libhidet_torch_wrapper_path
         break
     if _LIB is None:
         raise OSError('Can not find library in the following directory: \n' + '\n'.join(library_dirs))
@@ -71,6 +84,8 @@ def get_func(func_name, arg_types: List, restype, lib=None):
         func = getattr(_LIB, func_name)
     elif func_exists(func_name, _LIB_RUNTIME):
         func = getattr(_LIB_RUNTIME, func_name)
+    elif func_exists(func_name, _LIB_HIDET_TORCH_WRAPPER):
+        func = getattr(_LIB_HIDET_TORCH_WRAPPER, func_name)
     elif func_exists(func_name, lib):
         func = getattr(lib, func_name)
     else:
diff --git a/python/hidet/ffi/runtime_api.py b/python/hidet/ffi/runtime_api.py
index c9ba7aecc..3bf1c94f6 100644
--- a/python/hidet/ffi/runtime_api.py
+++ b/python/hidet/ffi/runtime_api.py
@@ -10,7 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Union
-from ctypes import c_void_p, c_char_p, c_uint64, c_int32
+from ctypes import c_void_p, c_char_p, c_uint64, c_int32, c_bool
 from hidet.cuda import Stream
 from .ffi import get_func
 from .array import Array
@@ -26,15 +26,17 @@ class RuntimeAPI:
     _get_symbol_value = get_func('get_symbol_value', [c_char_p], c_int32)
     _set_symbol_value = get_func('set_symbol_value', [c_char_p, c_int32], None)
     _set_nccl_comms = get_func('set_nccl_comms', [c_int32, c_void_p], None)
+    _get_use_torch_stream = get_func('get_use_torch_cuda_stream', [], c_bool)
+    _use_torch_cuda_stream = get_func('use_torch_cuda_stream', [c_bool], None)
 
     @staticmethod
     def set_current_stream(stream: Union[Stream, int]) -> None:
         RuntimeAPI._set_current_stream(c_void_p(int(stream)))
 
     @staticmethod
-    def get_current_stream() -> int:
+    def get_current_stream() -> Union[int, None]:
         p = RuntimeAPI._get_current_stream()
-        return p.value
+        return p if p else None
 
     @staticmethod
     def register_callback(name: str, cfunc):
@@ -68,5 +70,13 @@ def set_nccl_comms(comms: Array) -> None:
         comms_array_t = c_void_p * comms.length
         RuntimeAPI._set_nccl_comms(comms.length, comms_array_t.from_buffer(comms.buffer))
 
+    @staticmethod
+    def get_use_torch_cuda_stream() -> bool:
+        return RuntimeAPI._get_use_torch_stream()
+
+    @staticmethod
+    def use_torch_cuda_stream(use: bool) -> None:
+        RuntimeAPI._use_torch_cuda_stream(use)
+
 
 runtime_api = RuntimeAPI()
diff --git a/python/hidet/graph/frontend/torch/dynamo_backends.py b/python/hidet/graph/frontend/torch/dynamo_backends.py
index 906df720e..2d49ceb82 100644
--- a/python/hidet/graph/frontend/torch/dynamo_backends.py
+++ b/python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -195,8 +195,6 @@ def __call__(self, *args):
             else:
                 # ignore constant
                 pass
-        # Inherited cuda stream from torch
-        runtime_api.set_current_stream(torch.cuda.current_stream().cuda_stream)
         # Prepare inputs
         tensor_args = preprocess_inputs(tensor_args)
         # Run graph/model
diff --git a/python/hidet/option.py b/python/hidet/option.py
index 67011d9fe..f69cbb80f 100644
--- a/python/hidet/option.py
+++ b/python/hidet/option.py
@@ -965,6 +965,34 @@ def is_option_exist(name: str) -> bool:
     return name in OptionRegistry.registered_options
 
 
+def use_torch_stream(use: bool):
+    """Set the flag that controls whether to use the torch stream
+
+    Parameters
+    ----------
+    use: bool
+        True to use the current torch stream, False to use hidet's own stream.
+    """
+    from hidet.ffi.runtime_api import runtime_api
+
+    # set the use-torch-stream flag in the runtime accordingly
+    runtime_api.use_torch_cuda_stream(use)
+
+
+def is_use_torch_stream() -> bool:
+    """Check whether the torch stream is currently being used
+
+
+    Returns
+    -------
+    ret: bool
+        True if using the torch stream, False otherwise.
+    """
+    from hidet.ffi.runtime_api import runtime_api
+
+    return runtime_api.get_use_torch_cuda_stream()
+
+
 class cuda:
     """
     The CUDA related options.
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 6f1199d63..d04d9b1a3 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -12,7 +12,6 @@ black==22.10.0
 pylint==2.13.9
 
 # for models to test
-torch>=2.3.0
 torchvision
 datasets
 diffusers
diff --git a/requirements.txt b/requirements.txt
index b0f7bdf08..e6834b8c9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -42,3 +42,6 @@ tomlkit
 
 # for performance measurements
 scipy
+
+# for torch runtime api dependency
+torch>=2.3.0
\ No newline at end of file
diff --git a/scripts/wheel/build_wheel.sh b/scripts/wheel/build_wheel.sh
index f567ec300..4201d880a 100644
--- a/scripts/wheel/build_wheel.sh
+++ b/scripts/wheel/build_wheel.sh
@@ -13,11 +13,16 @@ set -e  # exit immediately if a command exits with a non-zero status.
 # This script builds a wheel for the current platform and Python version.
 ###############################################################################
 
+
 # work in the same directory of this script
 CURRENT_SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
 ROOT_DIR=$(cd -- "$CURRENT_SCRIPT_DIR/../.." &> /dev/null && pwd)
 cd $CURRENT_SCRIPT_DIR
 
+pip3 install torch
+pip3 install nvidia-cuda-runtime-cu12
+pip3 install triton
+
 # create a new build directory
 rm -rf build;
 mkdir build;
diff --git a/src/hidet/runtime/cuda/context.cpp b/src/hidet/runtime/cuda/context.cpp
index e100f1b73..7dc448165 100644
--- a/src/hidet/runtime/cuda/context.cpp
+++ b/src/hidet/runtime/cuda/context.cpp
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include <hidet/runtime/torch/stream.h>
 
 CudaContext *CudaContext::global() {
     static CudaContext instance;
@@ -36,7 +37,18 @@ DLL void set_cuda_stream(void *stream) {
     CudaContext::global()->stream = stream;
 }
 
+DLL bool get_use_torch_cuda_stream() {
+    return CudaContext::global()->use_torch_stream;
+}
+
+DLL void use_torch_cuda_stream(bool use) {
+    CudaContext::global()->use_torch_stream = use;
+}
+
 DLL void *get_cuda_stream() {
+    if (CudaContext::global()->use_torch_stream) {
+        return hidet_get_current_torch_stream();
+    }
     return CudaContext::global()->stream;
 }
 
diff --git a/src/hidet/runtime/hidet_torch/cuda_stream.cpp b/src/hidet/runtime/hidet_torch/cuda_stream.cpp
new file mode 100644
index 000000000..e302a0fc7
--- /dev/null
+++ b/src/hidet/runtime/hidet_torch/cuda_stream.cpp
@@ -0,0 +1,6 @@
+#include
+#include
+
+DLL void *hidet_get_current_torch_cuda_stream() {
+    return at::cuda::getCurrentCUDAStream().stream();
+}
diff --git a/src/hidet/runtime/torch/stream.cpp b/src/hidet/runtime/torch/stream.cpp
new file mode 100644
index 000000000..1f238b76d
--- /dev/null
+++ b/src/hidet/runtime/torch/stream.cpp
@@ -0,0 +1,37 @@
+#include <dlfcn.h>
+#include "../cuda/utils.h"
+
+typedef void *(*hidetGetCurrentCUDAStream_t)();
+static hidetGetCurrentCUDAStream_t hidetGetCurrentTorchCUDAStream = nullptr;
+
+static std::string library_path;
+static void *libhidet_torch_wrapper = nullptr;
+// load torch runtime APIs
+static inline void lazy_load_torch_runtime() {
+    if (libhidet_torch_wrapper == nullptr) {
+        const char *libpath;
+        if (library_path.empty()) {
+            libpath = "libhidet_torch_wrapper.so";
+        } else {
+            libpath = library_path.c_str();
+        }
+        libhidet_torch_wrapper = dlopen(libpath, RTLD_LAZY);
+
+        if (libhidet_torch_wrapper == nullptr) {
+            LOG(FATAL) << "Failed to load libhidet_torch_wrapper.so: " << dlerror();
+        }
+
+        hidetGetCurrentTorchCUDAStream =
+            get_symbol<hidetGetCurrentCUDAStream_t>(libhidet_torch_wrapper, "hidet_get_current_torch_cuda_stream");
+    }
+}
+
+DLL void *hidet_get_current_torch_stream() {
+    lazy_load_torch_runtime();
+    return hidetGetCurrentTorchCUDAStream();
+}
+
+// Hidet exported APIs
+DLL void hidet_torch_set_library_path(const char *path) {
+    library_path = path;
+}
\ No newline at end of file
diff --git a/tests/frontends/torch/test_torch_stream.py b/tests/frontends/torch/test_torch_stream.py
new file mode 100644
index 000000000..fa5914861
--- /dev/null
+++ b/tests/frontends/torch/test_torch_stream.py
@@ -0,0 +1,44 @@
+import torch
+import numpy
+import pytest
+
+
+@pytest.mark.parametrize("size", [(32, 32)])
+@pytest.mark.parametrize("mode", ["max-autotune", "max-autotune-no-cudagraphs"])
+def test_default_stream(size, mode):
+    device = torch.device(0)
+
+    x = torch.rand(size=size, dtype=torch.float32).to(device)
+    w = torch.rand(size=size, dtype=torch.float32).to(device)
+
+    def matmul(x):
+        return x.matmul(w)
+
+    matmul_opt = torch.compile(matmul, backend='hidet', mode=mode)
+    hidet_output = matmul_opt(x)
+    torch_output = matmul(x)
+    torch_output = torch_output.detach().cpu().numpy()
+    hidet_output = hidet_output.detach().cpu().numpy()
+    numpy.testing.assert_allclose(torch_output, hidet_output, atol=1e-3, rtol=1e-3)
+
+
+@pytest.mark.parametrize("size", [(32, 32)])
+@pytest.mark.parametrize("mode", ["max-autotune", "max-autotune-no-cudagraphs"])
+def test_new_torch_stream(size, mode):
+    device = torch.device(0)
+
+    x = torch.rand(size=size, dtype=torch.float32).to(device)
+    w = torch.rand(size=size, dtype=torch.float32).to(device)
+    s = torch.cuda.Stream(device=device)
+
+    def matmul(x):
+        return x.matmul(w)
+
+    with torch.cuda.stream(s):
+        matmul_opt = torch.compile(matmul, backend='hidet', mode=mode)
+        hidet_output = matmul_opt(x)
+        torch_output = matmul(x)
+    s.synchronize()
+    torch_output = torch_output.detach().cpu().numpy()
+    hidet_output = hidet_output.detach().cpu().numpy()
+    numpy.testing.assert_allclose(torch_output, hidet_output, atol=1e-3, rtol=1e-3)