Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm_ffi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from ._tensor import from_dlpack, Tensor, Shape
from .container import Array, Map
from .module import Module, system_lib, load_module
from .stream import StreamContext, use_raw_stream, use_torch_stream
from .stream import StreamContext, get_raw_stream, use_raw_stream, use_torch_stream
from . import serialization
from . import access_path
from . import testing
Expand Down
6 changes: 6 additions & 0 deletions python/tvm_ffi/cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
return <uint64_t>prev_stream


def _env_get_stream(int device_type, int device_id):
cdef void* current_stream
current_stream = TVMFFIEnvGetStream(device_type, device_id)
return <uint64_t>current_stream


cdef extern from "tvm_ffi_python_helpers.h":
# no need to expose fields of the call context setter data structure
ctypedef int (*DLPackFromPyObject)(
Expand Down
16 changes: 16 additions & 0 deletions python/tvm_ffi/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,19 @@ def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]) -> StreamC
"try use_torch_stream when using torch.cuda.Stream or torch.cuda.graph"
)
return StreamContext(device, stream)


def get_raw_stream(device: core.Device) -> int:
"""Get the current ffi stream of given device.

Parameters
----------
device : tvm_ffi.Device
The device to which the stream belongs.

Returns
-------
stream : int
The current ffi stream.
"""
return core._env_get_stream(device.dlpack_device_type(), device.index)
3 changes: 3 additions & 0 deletions tests/python/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ def test_raw_stream() -> None:
stream_2 = 987654321
with tvm_ffi.use_raw_stream(device, stream_1):
mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
assert tvm_ffi.get_raw_stream(device) == stream_1

with tvm_ffi.use_raw_stream(device, stream_2):
mod.check_stream(device.dlpack_device_type(), device.index, stream_2)
assert tvm_ffi.get_raw_stream(device) == stream_2

mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
assert tvm_ffi.get_raw_stream(device) == stream_1


@pytest.mark.skipif(
Expand Down
Loading