Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm_ffi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from ._tensor import from_dlpack, Tensor, Shape
from .container import Array, Map
from .module import Module, system_lib, load_module
from .stream import StreamContext, use_raw_stream, use_torch_stream
from .stream import StreamContext, get_raw_stream, use_raw_stream, use_torch_stream
from . import serialization
from . import access_path
from . import testing
Expand Down
1 change: 1 addition & 0 deletions python/tvm_ffi/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def _convert_to_ffi_error(error: BaseException) -> Error: ...
def _env_set_current_stream(
device_type: int, device_id: int, stream: int | c_void_p
) -> int | c_void_p: ...
def _env_get_current_stream(device_type: int, device_id: int) -> int: ...

class DataType:
"""Internal wrapper around ``DLDataType``.
Expand Down
6 changes: 6 additions & 0 deletions python/tvm_ffi/cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
return <uint64_t>prev_stream


def _env_get_current_stream(int device_type, int device_id):
cdef void* current_stream
current_stream = TVMFFIEnvGetStream(device_type, device_id)
return <uint64_t>current_stream


cdef extern from "tvm_ffi_python_helpers.h":
# no need to expose fields of the call context setter data structure
ctypedef int (*DLPackFromPyObject)(
Expand Down
17 changes: 17 additions & 0 deletions python/tvm_ffi/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,20 @@ def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]) -> StreamC
"try use_torch_stream when using torch.cuda.Stream or torch.cuda.graph"
)
return StreamContext(device, stream)


def get_raw_stream(device: core.Device) -> int:
"""Get the current ffi stream of given device.

Parameters
----------
device : tvm_ffi.Device
The device to which the stream belongs.

Returns
-------
stream : int
The current ffi stream.

"""
return core._env_get_current_stream(device.dlpack_device_type(), device.index)
3 changes: 3 additions & 0 deletions tests/python/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ def test_raw_stream() -> None:
stream_2 = 987654321
with tvm_ffi.use_raw_stream(device, stream_1):
mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
assert tvm_ffi.get_raw_stream(device) == stream_1

with tvm_ffi.use_raw_stream(device, stream_2):
mod.check_stream(device.dlpack_device_type(), device.index, stream_2)
assert tvm_ffi.get_raw_stream(device) == stream_2

mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
assert tvm_ffi.get_raw_stream(device) == stream_1


@pytest.mark.skipif(
Expand Down