[Enhancement] Add option and functionality to set torch stream as the current stream (#629)

Use the torch C++ API to set hidet's current CUDA stream to the current torch stream.

Implementation:
- Build a hidet-torch shared library that wraps the original torch C++ API
(the original API uses torch-defined structures like `CUDAStream`, so it
cannot easily be dlopened and accessed at runtime); see the loading sketch
after this list
- dlopen the newly added hidet-torch library and access torch's current
stream through it
- Add a "use_torch_stream" option to hidet's options to dynamically select
either the current torch stream or hidet's own stream at runtime
- When hidet's CUDA graph mode is on, hidet still creates a new hidet
stream and captures the graph on that stream instead of using the torch
stream.
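
For illustration, a minimal Python sketch of how the wrapper library can be
loaded and queried once built (mirroring what `python/hidet/ffi/ffi.py` does
below; the build path here is an assumption):

```python
import ctypes
import os

import torch

# torch's libc10_cuda.so must be loaded first so the wrapper's symbols resolve
ctypes.cdll.LoadLibrary(os.path.join(os.path.dirname(torch.__file__), 'lib', 'libc10_cuda.so'))

# assumed location of the built wrapper library
lib = ctypes.cdll.LoadLibrary('build/lib/libhidet_torch_wrapper.so')
lib.hidet_get_current_torch_cuda_stream.restype = ctypes.c_void_p

ptr = lib.hidet_get_current_torch_cuda_stream()  # cudaStream_t as a void pointer
print(hex(ptr or 0))  # 0 corresponds to the default stream
```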

Benefits:
- Removes the overhead of querying and calling torch's current-stream API
from the Python side
- Could also reduce the overhead incurred in the Hexcute integration,
because `set_to_torch_stream` is called in the launch function; the stream
query/switch on the Python side can be removed (see the usage sketch below)
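
Once merged, the flag can be flipped from Python through the new option
(a usage sketch based on the `python/hidet/option.py` changes in this
commit):

```python
import hidet

prev = hidet.option.is_use_torch_stream()   # defaults to True
hidet.option.use_torch_stream(False)        # launch on hidet's own stream instead
# ... run hidet kernels ...
hidet.option.use_torch_stream(prev)         # restore the previous setting
```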

Performance improvement (measured on an L4 with locked frequencies: 6250 MHz
compute / 1500 MHz memory):
1. For the Hexcute kernel (without CUDA graph), I manually disabled CUDA
graph on the DMWL (vLLM) side, so both the prefill and decoding stages use
the generic model and call the Hexcute kernel directly.
command: `python3 benchmark_latency.py --model
hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 --input-len 1024
--output-len 128 --batch-size 8 --num-iters-warmup 5 --num-iters 10
--max-model-len 32768 --quantization awq_hidet`
Comparison before and after removing the stream query and stream switch
before the Hexcute kernel call (CentML/DMWL#121):
Before avg latency: 12.624572871897545 seconds
After avg latency: 11.764245539499097 seconds
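This is roughly a 6.8% reduction in average latency.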

2. Profiling small kernels in hidet and measuring latency:
- CUDA graph enabled
`python bench_op_torch_api.py --params 16x16,16x16 --mode max-autotune
matmul`
Before: 0.27151119 seconds
After: 0.25410826999999997 seconds
- CUDA graph disabled
`python bench_op_torch_api.py --params 16x16,16x16 --mode
max-autotune-no-cudagraphs matmul`
Before: 0.14555310999999999 seconds
After: 0.11648335 seconds
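This is roughly a 6.4% latency reduction with CUDA graph enabled, and
roughly a 20% reduction with it disabled.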

This is related to #563
maxyanghu authored and vadiklyutiy committed Dec 24, 2024
1 parent c583fb2 commit 3e6a291
Showing 18 changed files with 245 additions and 14 deletions.
30 changes: 29 additions & 1 deletion CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.19)

project(hidet C CXX)

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# config hidet
@@ -34,6 +34,7 @@ add_library(hidet_runtime SHARED
    src/hidet/runtime/llm/tokenizer/pretokenizers.cpp
    src/hidet/runtime/llm/tokenizer/tokenizer.cpp
    src/hidet/runtime/llm/tokenizer/utf8.cpp
    src/hidet/runtime/torch/stream.cpp
)
target_include_directories(hidet_runtime PRIVATE ${CMAKE_SOURCE_DIR}/include /usr/include)
set_target_properties(hidet_runtime PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
@@ -44,4 +45,31 @@ add_library(hidet SHARED
)
target_include_directories(hidet PRIVATE ${CMAKE_SOURCE_DIR}/include)
set_target_properties(hidet PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)

execute_process(
    COMMAND python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))"
    OUTPUT_VARIABLE TORCH_LIBRARY
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(TORCH_INCLUDE_DIR "${TORCH_LIBRARY}/include")
execute_process(
    COMMAND python3 -c "import nvidia; import os; print(os.path.dirname(nvidia.__file__))"
    OUTPUT_VARIABLE NVIDIA_LIBRARY
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(CUDARUNTIME_INCLUDE_DIR "${NVIDIA_LIBRARY}/cuda_runtime/include")
execute_process(
    COMMAND python3 -c "import triton; import os; print(os.path.dirname(triton.__file__))"
    OUTPUT_VARIABLE TRITON_LIBRARY
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(TRITON_INCLUDE_DIR "${TRITON_LIBRARY}/backends/nvidia/include/")

add_library(hidet_torch_wrapper SHARED
    src/hidet/runtime/hidet_torch/cuda_stream.cpp
)
target_include_directories(hidet_torch_wrapper PRIVATE ${CMAKE_SOURCE_DIR}/include ${TORCH_INCLUDE_DIR} ${CUDARUNTIME_INCLUDE_DIR} ${TRITON_INCLUDE_DIR})
target_link_libraries(hidet_torch_wrapper ${TORCH_LIBRARY}/lib/libc10_cuda.so)
set_target_properties(hidet_torch_wrapper PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)

target_link_libraries(hidet "-Wl,--no-as-needed" hidet_runtime)
2 changes: 1 addition & 1 deletion apps/compile_server/requirements.txt
@@ -52,4 +52,4 @@ lark
scipy

# for torch runtime api dependency
torch>=2.3.0
torch>=2.3.0
4 changes: 4 additions & 0 deletions docs/source/getting-started/build-from-source.rst
@@ -16,6 +16,10 @@ First clone the repository to local:
Build shared libraries
~~~~~~~~~~~~~~~~~~~~~~
Before building the runtime library, make sure you have ``torch`` installed in your python environment:

.. code-block:: console

   $ pip install torch

The runtime library is written in C++ and compiled into a shared library. To build the shared library, you need to have
a C++ compiler installed (as well as build tools like ``cmake``, and ``make``). The following command will build the
13 changes: 13 additions & 0 deletions include/hidet/runtime/cuda/context.h
@@ -19,6 +19,9 @@ struct CudaContext: BaseContext {
    /* The cuda stream the kernels will be launched on. */
    void *stream = nullptr;

    /* Whether to use the torch stream. */
    bool use_torch_stream = true;

    /* NCCL Communicators */
    void **nccl_comms = nullptr;

@@ -35,6 +38,16 @@
 */
DLL void set_cuda_stream(void *stream);

/**
 * Get the use-torch-stream flag.
 */
DLL bool get_use_torch_cuda_stream();

/**
 * Set the flag of whether to use the torch stream.
 */
DLL void use_torch_cuda_stream(bool use);

/**
 * Get the cuda stream of cuda context.
 */
15 changes: 15 additions & 0 deletions include/hidet/runtime/torch/stream.h
@@ -0,0 +1,15 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <hidet/runtime/common.h>

DLL void *hidet_get_current_torch_stream();
11 changes: 9 additions & 2 deletions python/hidet/cuda/graph.py
@@ -13,6 +13,7 @@
from typing import List, Sequence, Optional, Any, Callable
from cuda import cudart
from cuda.cudart import cudaGraphExec_t
from hidet.option import use_torch_stream, is_use_torch_stream
from hidet.graph.tensor import Tensor
from hidet.runtime.storage import MemoryPool, CudaMemoryAPI, memory_pool
from hidet.runtime.device import Device
@@ -116,7 +117,6 @@ def __init__(
        self._inputs: List[Tensor] = []
        self._outputs: List[Tensor] = []
        self._ref_objs: List[Any] = ref_objs

        with memory_pool(self._memory_pool):
            # create the input tensors
            self._inputs = f_create_inputs()
@@ -126,11 +126,18 @@
            for _ in range(num_warmup):
                f_run(self._inputs)

            # There are two scenarios:
            # 1. If torch or hidet is using the default stream, we use a new hidet-created stream.
            # 2. If torch is using its own non-default stream, we still use a new hidet-created
            #    stream, to avoid interfering with the torch stream.
            # In both cases we switch to the hidet stream for graph capture.
            prev_flag = is_use_torch_stream()
            use_torch_stream(False)
            # capture the cuda graph
            self._memory_api.freeze()
            with self._graph_capture:
                self._outputs = f_run(self._inputs)

            use_torch_stream(prev_flag)
            # instantiate the cuda graph
            self._graph_exec: cudaGraphExec_t = self._graph_capture.instantiate()

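
The save/restore pattern above could also be written as a small context
manager (a hypothetical helper for illustration, not part of this commit):

```python
from contextlib import contextmanager

from hidet.option import is_use_torch_stream, use_torch_stream


@contextmanager
def hidet_stream_scope():
    # temporarily disable torch-stream mode, e.g. around CUDA graph capture
    prev = is_use_torch_stream()
    use_torch_stream(False)
    try:
        yield
    finally:
        use_torch_stream(prev)
```
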
9 changes: 8 additions & 1 deletion python/hidet/cuda/stream.py
@@ -210,8 +210,15 @@ def current_stream(device=None) -> Stream:
    stream: Stream
        The current stream.
    """
    from hidet.ffi import runtime_api

    device_id = _get_device_id(device)
    if device_id not in _current_streams:
        c_stream = runtime_api.get_current_stream()
        if c_stream is not None:
            # return the current stream, whether it is a hidet or a torch stream
            _current_streams[device_id] = ExternalStream(handle=c_stream, device_id=device_id)
        else:
            # if no current stream is set, fall back to the default stream
            _current_streams[device_id] = ExternalStream(handle=0, device_id=device_id)
    return _current_streams[_get_device_id(device)]

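
With the flag enabled, `hidet.cuda.current_stream()` now reflects whatever
stream torch has made current; a usage sketch (assuming a CUDA-enabled build
of hidet with the wrapper library on the search path):

```python
import torch
import hidet

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # the runtime resolves the current stream through the torch wrapper,
    # so hidet kernels launch on the same stream torch is using
    print(hidet.cuda.current_stream())
```
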
21 changes: 18 additions & 3 deletions python/hidet/ffi/ffi.py
@@ -9,33 +9,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Dict, Optional
import os
import os.path
import ctypes
import torch
from hidet.libinfo import get_library_search_dirs

_LIB: Optional[ctypes.CDLL] = None
_LIB_RUNTIME: Optional[ctypes.CDLL] = None
_LIB_HIDET_TORCH_WRAPPER: Optional[ctypes.CDLL] = None


library_paths: Dict[str, Optional[str]] = {'hidet': None, 'hidet_runtime': None}
library_paths: Dict[str, Optional[str]] = {'hidet': None, 'hidet_runtime': None, 'hidet_torch': None}


def load_library():
    global _LIB, _LIB_RUNTIME
    global _LIB, _LIB_RUNTIME, _LIB_HIDET_TORCH_WRAPPER
    if _LIB:
        return
    libc10_path = os.path.join(os.path.dirname(torch.__file__), 'lib/libc10_cuda.so')
    ctypes.cdll.LoadLibrary(libc10_path)

    library_dirs = get_library_search_dirs()
    for library_dir in library_dirs:
        libhidet_path = os.path.join(library_dir, 'libhidet.so')
        libhidet_runtime_path = os.path.join(library_dir, 'libhidet_runtime.so')
        if not os.path.exists(libhidet_path) or not os.path.exists(libhidet_runtime_path):
        libhidet_torch_wrapper_path = os.path.join(library_dir, 'libhidet_torch_wrapper.so')
        if (
            not os.path.exists(libhidet_path)
            or not os.path.exists(libhidet_runtime_path)
            or not os.path.exists(libhidet_torch_wrapper_path)
        ):
            continue
        _LIB_RUNTIME = ctypes.cdll.LoadLibrary(libhidet_runtime_path)
        _LIB_HIDET_TORCH_WRAPPER = ctypes.cdll.LoadLibrary(libhidet_torch_wrapper_path)
        _LIB = ctypes.cdll.LoadLibrary(libhidet_path)
        library_paths['hidet_runtime'] = libhidet_runtime_path
        library_paths['hidet'] = libhidet_path
        library_paths['hidet_torch_wrapper'] = libhidet_torch_wrapper_path
        break
    if _LIB is None:
        raise OSError('Can not find library in the following directory: \n' + '\n'.join(library_dirs))
@@ -71,6 +84,8 @@ def get_func(func_name, arg_types: List, restype, lib=None):
        func = getattr(_LIB, func_name)
    elif func_exists(func_name, _LIB_RUNTIME):
        func = getattr(_LIB_RUNTIME, func_name)
    elif func_exists(func_name, _LIB_HIDET_TORCH_WRAPPER):
        func = getattr(_LIB_HIDET_TORCH_WRAPPER, func_name)
    elif func_exists(func_name, lib):
        func = getattr(lib, func_name)
    else:
16 changes: 13 additions & 3 deletions python/hidet/ffi/runtime_api.py
@@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from ctypes import c_void_p, c_char_p, c_uint64, c_int32
from ctypes import c_void_p, c_char_p, c_uint64, c_int32, c_bool
from hidet.cuda import Stream
from .ffi import get_func
from .array import Array
@@ -26,15 +26,17 @@ class RuntimeAPI:
    _get_symbol_value = get_func('get_symbol_value', [c_char_p], c_int32)
    _set_symbol_value = get_func('set_symbol_value', [c_char_p, c_int32], None)
    _set_nccl_comms = get_func('set_nccl_comms', [c_int32, c_void_p], None)
    _get_use_torch_stream = get_func('get_use_torch_cuda_stream', [], c_bool)
    _use_torch_cuda_stream = get_func('use_torch_cuda_stream', [c_bool], None)

    @staticmethod
    def set_current_stream(stream: Union[Stream, int]) -> None:
        RuntimeAPI._set_current_stream(c_void_p(int(stream)))

    @staticmethod
    def get_current_stream() -> int:
    def get_current_stream() -> Union[int, None]:
        p = RuntimeAPI._get_current_stream()
        return p.value
        return p if p else None

    @staticmethod
    def register_callback(name: str, cfunc):
@@ -68,5 +70,13 @@ def set_nccl_comms(comms: Array) -> None:
        comms_array_t = c_void_p * comms.length
        RuntimeAPI._set_nccl_comms(comms.length, comms_array_t.from_buffer(comms.buffer))

    @staticmethod
    def get_use_torch_cuda_stream() -> bool:
        return RuntimeAPI._get_use_torch_stream()

    @staticmethod
    def use_torch_cuda_stream(use: bool) -> None:
        RuntimeAPI._use_torch_cuda_stream(use)


runtime_api = RuntimeAPI()
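
A quick sanity check of the new bindings (assuming a built hidet with the
wrapper library on the library search path):

```python
from hidet.ffi.runtime_api import runtime_api

print(runtime_api.get_use_torch_cuda_stream())  # True by default
runtime_api.use_torch_cuda_stream(False)
print(runtime_api.get_use_torch_cuda_stream())  # False
```
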
2 changes: 0 additions & 2 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -195,8 +195,6 @@ def __call__(self, *args):
            else:
                # ignore constant
                pass
        # Inherited cuda stream from torch
        runtime_api.set_current_stream(torch.cuda.current_stream().cuda_stream)
        # Prepare inputs
        tensor_args = preprocess_inputs(tensor_args)
        # Run graph/model
28 changes: 28 additions & 0 deletions python/hidet/option.py
@@ -965,6 +965,34 @@ def is_option_exist(name: str) -> bool:
    return name in OptionRegistry.registered_options


def use_torch_stream(use: bool):
    """Set the flag for whether to use the torch stream.

    Parameters
    ----------
    use: bool
        Whether to use the torch stream.
    """
    from hidet.ffi.runtime_api import runtime_api

    # set the use-torch-stream flag in the runtime
    runtime_api.use_torch_cuda_stream(use)


def is_use_torch_stream() -> bool:
    """Check whether the torch stream is currently being used.

    Returns
    -------
    ret: bool
        True if using the torch stream, False otherwise.
    """
    from hidet.ffi.runtime_api import runtime_api

    return runtime_api.get_use_torch_cuda_stream()


class cuda:
    """
    The CUDA related options.
1 change: 0 additions & 1 deletion requirements-dev.txt
@@ -12,7 +12,6 @@ black==22.10.0
pylint==2.13.9

# for models to test
torch>=2.3.0
torchvision
datasets
diffusers
3 changes: 3 additions & 0 deletions requirements.txt
@@ -42,3 +42,6 @@ tomlkit

# for performance measurements
scipy

# for torch runtime api dependency
torch>=2.3.0
5 changes: 5 additions & 0 deletions scripts/wheel/build_wheel.sh
@@ -13,11 +13,16 @@ set -e # exit immediately if a command exits with a non-zero status.
# This script builds a wheel for the current platform and Python version.
###############################################################################


# work in the same directory of this script
CURRENT_SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
ROOT_DIR=$(cd -- "$CURRENT_SCRIPT_DIR/../.." &> /dev/null && pwd)
cd $CURRENT_SCRIPT_DIR

pip3 install torch
pip3 install nvidia-cuda-runtime-cu12
pip3 install triton

# create a new build directory
rm -rf build; mkdir build;

12 changes: 12 additions & 0 deletions src/hidet/runtime/cuda/context.cpp
@@ -12,6 +12,7 @@
#include <hidet/runtime/callbacks.h>
#include <hidet/runtime/cuda/context.h>
#include <hidet/runtime/logging.h>
#include <hidet/runtime/torch/stream.h>

CudaContext *CudaContext::global() {
    static CudaContext instance;
@@ -36,7 +37,18 @@ DLL void set_cuda_stream(void *stream) {
    CudaContext::global()->stream = stream;
}

DLL bool get_use_torch_cuda_stream() {
    return CudaContext::global()->use_torch_stream;
}

DLL void use_torch_cuda_stream(bool use) {
    CudaContext::global()->use_torch_stream = use;
}

DLL void *get_cuda_stream() {
    if (CudaContext::global()->use_torch_stream) {
        return hidet_get_current_torch_stream();
    }
    return CudaContext::global()->stream;
}

6 changes: 6 additions & 0 deletions src/hidet/runtime/hidet_torch/cuda_stream.cpp
@@ -0,0 +1,6 @@
#include <c10/cuda/CUDAStream.h>
#include <hidet/runtime/common.h>

DLL void *hidet_get_current_torch_cuda_stream() {
    return at::cuda::getCurrentCUDAStream().stream();
}
