diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst index 482c19db..756af4c4 100644 --- a/docs/reference/python/index.rst +++ b/docs/reference/python/index.rst @@ -68,6 +68,16 @@ Containers Map +Stream Context +-------------- +.. autosummary:: + :toctree: generated/ + + StreamContext + use_torch_stream + use_raw_stream + + Utility ------- diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py index 16f035ed..9bafe2b7 100644 --- a/python/tvm_ffi/__init__.py +++ b/python/tvm_ffi/__init__.py @@ -37,6 +37,7 @@ from ._tensor import from_dlpack, Tensor, Shape from .container import Array, Map from .module import Module, system_lib, load_module +from .stream import StreamContext, use_raw_stream, use_torch_stream from . import serialization from . import access_path from . import testing diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi index 08f7df2c..77c9c7e8 100644 --- a/python/tvm_ffi/cython/base.pxi +++ b/python/tvm_ffi/cython/base.pxi @@ -245,6 +245,15 @@ cdef extern from "tvm/ffi/extra/c_env_api.h": TVMFFIStreamHandle stream, TVMFFIStreamHandle* opt_out_original_stream) nogil +def _env_set_current_stream(int device_type, int device_id, uint64_t stream): + cdef TVMFFIStreamHandle prev_stream = NULL + CHECK_CALL(TVMFFIEnvSetStream( + device_type, + device_id, + stream, + &prev_stream)) + return prev_stream + cdef extern from "tvm_ffi_python_helpers.h": # no need to expose fields of the call context diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py new file mode 100644 index 00000000..598afcac --- /dev/null +++ b/python/tvm_ffi/stream.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Stream context.""" +from ctypes import c_void_p +from typing import Any, Optional, Union + +from . import core +from ._tensor import device + + +class StreamContext: + """StreamContext represents a stream context in the ffi system. + StreamContext helps setup ffi environment stream by python `with` statement. + When entering `with` scope, it caches the current environment stream and + setup the given new stream. + When exiting `with` scope, it recovers the stream to the cached environment stream. + + Parameters + ---------- + device : Device + The device to which the stream belongs. + + stream : Union[int, c_void_p] + The stream handle. 
+ + See Also + -------- + :py:func:`tvm_ffi.use_raw_stream`, :py:func:`tvm_ffi.use_torch_stream` + """ + + def __init__(self, device: core.Device, stream: Union[int, c_void_p]): + self.device_type = device.dlpack_device_type() + self.device_id = device.index + self.stream = stream + + def __enter__(self): + self.prev_stream = core._env_set_current_stream( + self.device_type, self.device_id, self.stream + ) + + def __exit__(self, *args): + self.prev_stream = core._env_set_current_stream( + self.device_type, self.device_id, self.prev_stream + ) + + +try: + import torch + + class TorchStreamContext: + def __init__(self, context: Optional[Any]): + self.torch_context = context + + def __enter__(self): + if self.torch_context: + self.torch_context.__enter__() + current_stream = torch.cuda.current_stream() + self.ffi_context = StreamContext( + device(str(current_stream.device)), current_stream.cuda_stream + ) + self.ffi_context.__enter__() + + def __exit__(self, *args): + if self.torch_context: + self.torch_context.__exit__(*args) + self.ffi_context.__exit__(*args) + + def use_torch_stream(context: Optional[Any] = None): + """ + Create an ffi stream context with the given torch stream, + cuda graph, or the current stream if `None` is provided. + + Parameters + ---------- + context : Optional[Any] + The wrapped torch stream or cuda graph. + + Returns + ------- + context : tvm_ffi.TorchStreamContext + The ffi stream context wrapping torch stream context. + + Examples + -------- + .. code-block:: python + + s = torch.cuda.Stream() + with tvm_ffi.use_torch_stream(torch.cuda.stream(s)): + ... + + g = torch.cuda.CUDAGraph() + with tvm_ffi.use_torch_stream(torch.cuda.graph(g)): + ... + + Note + ---- + When working with a raw cudaStream_t handle, use :py:func:`tvm_ffi.use_raw_stream` instead. 
+ """ + return TorchStreamContext(context) + +except ImportError: + + def use_torch_stream(context: Optional[Any] = None): + raise ImportError("Cannot import torch") + + +def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]): + """ + Create a ffi stream context with given device and stream handle. + + Parameters + ---------- + device : tvm_ffi.Device + The device to which the stream belongs. + + stream : Union[int, c_void_p] + The stream handle. + + Returns + ------- + context : tvm_ffi.StreamContext + The ffi stream context. + + Note + ---- + When working with torch stram or cuda graph, using :py:func:`tvm_ffi.use_torch_stream` instead. + """ + if not isinstance(stream, (int, c_void_p)): + raise ValueError( + "use_raw_stream only accepts int or c_void_p as stram input, " + "try use_torch_stream when using torch.cuda.Stream or torch.cuda.graph" + ) + return StreamContext(device, stream) diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py new file mode 100644 index 00000000..c7b81a82 --- /dev/null +++ b/tests/python/test_stream.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest + +import tvm_ffi +import tvm_ffi.cpp + +try: + import torch +except ImportError: + torch = None + + +def gen_check_stream_mod(): + return tvm_ffi.cpp.load_inline( + name="check_stream", + cpp_sources=""" + void check_stream(int device_type, int device_id, uint64_t stream) { + uint64_t cur_stream = reinterpret_cast<uint64_t>(TVMFFIEnvGetStream(device_type, device_id)); + TVM_FFI_ICHECK_EQ(cur_stream, stream); + } + """, + functions=["check_stream"], + ) + + +def test_raw_stream(): + mod = gen_check_stream_mod() + device = tvm_ffi.device("cuda:0") + stream_1 = 123456789 + stream_2 = 987654321 + with tvm_ffi.use_raw_stream(device, stream_1): + mod.check_stream(device.dlpack_device_type(), device.index, stream_1) + + with tvm_ffi.use_raw_stream(device, stream_2): + mod.check_stream(device.dlpack_device_type(), device.index, stream_2) + + mod.check_stream(device.dlpack_device_type(), device.index, stream_1) + + +@pytest.mark.skipif( + torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" +) +def test_torch_stream(): + mod = gen_check_stream_mod() + device_id = torch.cuda.current_device() + device = tvm_ffi.device("cuda", device_id) + device_type = device.dlpack_device_type() + stream_1 = torch.cuda.Stream(device_id) + stream_2 = torch.cuda.Stream(device_id) + with tvm_ffi.use_torch_stream(torch.cuda.stream(stream_1)): + assert torch.cuda.current_stream() == stream_1 + mod.check_stream(device_type, device_id, stream_1.cuda_stream) + + with tvm_ffi.use_torch_stream(torch.cuda.stream(stream_2)): + assert torch.cuda.current_stream() == stream_2 + mod.check_stream(device_type, device_id, stream_2.cuda_stream) + + assert torch.cuda.current_stream() == stream_1 + mod.check_stream(device_type, device_id, stream_1.cuda_stream) + + +@pytest.mark.skipif( + torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" +) +def test_torch_current_stream(): + mod = gen_check_stream_mod() + device_id = torch.cuda.current_device() + 
device = tvm_ffi.device("cuda", device_id) + device_type = device.dlpack_device_type() + stream_1 = torch.cuda.Stream(device_id) + stream_2 = torch.cuda.Stream(device_id) + with torch.cuda.stream(stream_1): + assert torch.cuda.current_stream() == stream_1 + with tvm_ffi.use_torch_stream(): + mod.check_stream(device_type, device_id, stream_1.cuda_stream) + + with torch.cuda.stream(stream_2): + assert torch.cuda.current_stream() == stream_2 + with tvm_ffi.use_torch_stream(): + mod.check_stream(device_type, device_id, stream_2.cuda_stream) + + assert torch.cuda.current_stream() == stream_1 + with tvm_ffi.use_torch_stream(): + mod.check_stream(device_type, device_id, stream_1.cuda_stream) + + +@pytest.mark.skipif( + torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" +) +def test_torch_graph(): + mod = gen_check_stream_mod() + device_id = torch.cuda.current_device() + device = tvm_ffi.device("cuda", device_id) + device_type = device.dlpack_device_type() + graph = torch.cuda.CUDAGraph() + stream = torch.cuda.Stream(device_id) + with tvm_ffi.use_torch_stream(torch.cuda.graph(graph, stream=stream)): + assert torch.cuda.current_stream() == stream + mod.check_stream(device_type, device_id, stream.cuda_stream)