diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py index 720968d6..e46a7aa5 100644 --- a/python/tvm_ffi/__init__.py +++ b/python/tvm_ffi/__init__.py @@ -42,7 +42,7 @@ from ._tensor import from_dlpack, Tensor, Shape from .container import Array, Map from .module import Module, system_lib, load_module -from .stream import StreamContext, use_raw_stream, use_torch_stream +from .stream import StreamContext, get_raw_stream, use_raw_stream, use_torch_stream from . import serialization from . import access_path from . import testing diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi index 54b44cf6..45a7d28c 100644 --- a/python/tvm_ffi/core.pyi +++ b/python/tvm_ffi/core.pyi @@ -259,6 +259,7 @@ def _convert_to_ffi_error(error: BaseException) -> Error: ... def _env_set_current_stream( device_type: int, device_id: int, stream: int | c_void_p ) -> int | c_void_p: ... +def _env_get_current_stream(device_type: int, device_id: int) -> int: ... class DataType: """Internal wrapper around ``DLDataType``. diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi index 633ace4d..a8b4212e 100644 --- a/python/tvm_ffi/cython/base.pxi +++ b/python/tvm_ffi/cython/base.pxi @@ -261,6 +261,12 @@ def _env_set_current_stream(int device_type, int device_id, uint64_t stream): return prev_stream +def _env_get_current_stream(int device_type, int device_id): + cdef void* current_stream + current_stream = TVMFFIEnvGetStream(device_type, device_id) + return current_stream + + cdef extern from "tvm_ffi_python_helpers.h": # no need to expose fields of the call context setter data structure ctypedef int (*DLPackFromPyObject)( diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py index 7f2dde52..e00d4225 100644 --- a/python/tvm_ffi/stream.py +++ b/python/tvm_ffi/stream.py @@ -161,3 +161,20 @@ def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]) -> StreamC "try use_torch_stream when using torch.cuda.Stream or torch.cuda.graph" ) return StreamContext(device, stream) + + +def get_raw_stream(device: core.Device) -> int: + """Get the current ffi stream of given device. + + Parameters + ---------- + device : tvm_ffi.Device + The device to which the stream belongs. + + Returns + ------- + stream : int + The current ffi stream. + + """ + return core._env_get_current_stream(device.dlpack_device_type(), device.index) diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py index 9280aabb..fbe1a0a2 100644 --- a/tests/python/test_stream.py +++ b/tests/python/test_stream.py @@ -50,11 +50,14 @@ def test_raw_stream() -> None: stream_2 = 987654321 with tvm_ffi.use_raw_stream(device, stream_1): mod.check_stream(device.dlpack_device_type(), device.index, stream_1) + assert tvm_ffi.get_raw_stream(device) == stream_1 with tvm_ffi.use_raw_stream(device, stream_2): mod.check_stream(device.dlpack_device_type(), device.index, stream_2) + assert tvm_ffi.get_raw_stream(device) == stream_2 mod.check_stream(device.dlpack_device_type(), device.index, stream_1) + assert tvm_ffi.get_raw_stream(device) == stream_1 @pytest.mark.skipif(