Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/reference/python/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ Containers
Map


Stream Context
--------------
.. autosummary::
:toctree: generated/

StreamContext
use_torch_stream
use_raw_stream


Utility
-------

Expand Down
1 change: 1 addition & 0 deletions python/tvm_ffi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ._tensor import from_dlpack, Tensor, Shape
from .container import Array, Map
from .module import Module, system_lib, load_module
from .stream import StreamContext, use_raw_stream, use_torch_stream
from . import serialization
from . import access_path
from . import testing
Expand Down
9 changes: 9 additions & 0 deletions python/tvm_ffi/cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,15 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
TVMFFIStreamHandle stream,
TVMFFIStreamHandle* opt_out_original_stream) nogil

def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
    """Set the ffi environment stream for (device_type, device_id).

    `stream` is a raw stream handle passed as an integer; it is cast to a
    void* before being handed to the C environment API.  Returns the handle
    of the stream that was previously installed, as an integer, so the
    caller can restore it later (see StreamContext in stream.py).
    """
    cdef TVMFFIStreamHandle prev_stream = NULL
    # CHECK_CALL raises if the C API reports an error.
    CHECK_CALL(TVMFFIEnvSetStream(
        device_type,
        device_id,
        <void*>stream,
        &prev_stream))
    return <uint64_t>prev_stream


cdef extern from "tvm_ffi_python_helpers.h":
# no need to expose fields of the call context
Expand Down
148 changes: 148 additions & 0 deletions python/tvm_ffi/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Stream context."""
from ctypes import c_void_p
from typing import Any, Optional, Union

from . import core
from ._tensor import device


class StreamContext:
    """StreamContext represents a stream context in the ffi system.

    StreamContext helps setup the ffi environment stream via a python
    ``with`` statement.  On entering the ``with`` scope it caches the
    current environment stream and installs the given new stream; on
    exiting the scope it restores the cached environment stream.

    Parameters
    ----------
    device : Device
        The device to which the stream belongs.

    stream : Union[int, c_void_p]
        The stream handle.

    See Also
    --------
    :py:func:`tvm_ffi.use_raw_stream`, :py:func:`tvm_ffi.use_torch_stream`
    """

    def __init__(self, device: core.Device, stream: Union[int, c_void_p]):
        self.device_type = device.dlpack_device_type()
        self.device_id = device.index
        self.stream = stream

    def __enter__(self):
        # Install the new stream and remember the one it replaced so that
        # __exit__ can restore it.
        self.prev_stream = core._env_set_current_stream(
            self.device_type, self.device_id, self.stream
        )
        # Return self so `with ... as ctx:` binds the context, not None.
        return self

    def __exit__(self, *args):
        # Restore the environment stream captured in __enter__.  Do not
        # overwrite self.prev_stream: keeping it intact leaves the context
        # safely re-enterable.
        core._env_set_current_stream(
            self.device_type, self.device_id, self.prev_stream
        )


try:
    import torch

    class TorchStreamContext:
        """Adapter that keeps the ffi environment stream in sync with torch.

        Enters the wrapped torch context (``torch.cuda.stream`` or
        ``torch.cuda.graph``) if one was given, then mirrors torch's current
        CUDA stream into the ffi environment via :py:class:`StreamContext`.
        """

        def __init__(self, context: Optional[Any]):
            # The wrapped torch stream/graph context; may be None, in which
            # case only the current torch stream is mirrored.
            self.torch_context = context

        def __enter__(self):
            if self.torch_context:
                self.torch_context.__enter__()
            # Query torch *after* entering its context so we pick up the
            # stream that the context made current.
            current_stream = torch.cuda.current_stream()
            self.ffi_context = StreamContext(
                device(str(current_stream.device)), current_stream.cuda_stream
            )
            self.ffi_context.__enter__()
            return self

        def __exit__(self, *args):
            # Unwind in reverse order of __enter__: restore the ffi
            # environment stream first, then leave the torch context.
            self.ffi_context.__exit__(*args)
            if self.torch_context:
                self.torch_context.__exit__(*args)

    def use_torch_stream(context: Optional[Any] = None):
        """
        Create a ffi stream context with given torch stream,
        cuda graph or current stream if `None` provided.

        Parameters
        ----------
        context : Optional[Any]
            The wrapped torch stream or cuda graph.

        Returns
        -------
        context : tvm_ffi.TorchStreamContext
            The ffi stream context wrapping torch stream context.

        Examples
        --------
        .. code-block:: python

            s = torch.cuda.Stream()
            with tvm_ffi.use_torch_stream(torch.cuda.stream(s)):
                ...

            g = torch.cuda.CUDAGraph()
            with tvm_ffi.use_torch_stream(torch.cuda.graph(g)):
                ...

        Note
        ----
        When working with a raw cudaStream_t handle, use
        :py:func:`tvm_ffi.use_raw_stream` instead.
        """
        return TorchStreamContext(context)

except ImportError:

    def use_torch_stream(context: Optional[Any] = None):
        """Fallback used when torch is not installed; always raises."""
        raise ImportError("Cannot import torch")


def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]):
    """
    Create a ffi stream context with given device and stream handle.

    Parameters
    ----------
    device : tvm_ffi.Device
        The device to which the stream belongs.

    stream : Union[int, c_void_p]
        The stream handle.

    Returns
    -------
    context : tvm_ffi.StreamContext
        The ffi stream context.

    Note
    ----
    When working with a torch stream or cuda graph, use
    :py:func:`tvm_ffi.use_torch_stream` instead.
    """
    # Validate eagerly so the error points at the caller rather than
    # surfacing later inside the C environment API.
    if not isinstance(stream, (int, c_void_p)):
        raise ValueError(
            "use_raw_stream only accepts int or c_void_p as stream input, "
            "try use_torch_stream when using torch.cuda.Stream or torch.cuda.graph"
        )
    return StreamContext(device, stream)
115 changes: 115 additions & 0 deletions tests/python/test_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest

import tvm_ffi
import tvm_ffi.cpp

try:
import torch
except ImportError:
torch = None


def gen_check_stream_mod():
    """Compile an inline C++ module exposing ``check_stream``.

    ``check_stream(device_type, device_id, stream)`` reads the current ffi
    environment stream via TVMFFIEnvGetStream and asserts it equals the
    expected handle.
    """
    return tvm_ffi.cpp.load_inline(
        name="check_stream",
        cpp_sources="""
        void check_stream(int device_type, int device_id, uint64_t stream) {
            uint64_t cur_stream = reinterpret_cast<uint64_t>(TVMFFIEnvGetStream(device_type, device_id));
            TVM_FFI_ICHECK_EQ(cur_stream, stream);
        }
        """,
        functions=["check_stream"],
    )


def test_raw_stream():
    """Nested raw-stream contexts install and then restore the env stream."""
    mod = gen_check_stream_mod()
    dev = tvm_ffi.device("cuda:0")
    dev_type = dev.dlpack_device_type()
    dev_id = dev.index
    outer_stream = 123456789
    inner_stream = 987654321
    with tvm_ffi.use_raw_stream(dev, outer_stream):
        mod.check_stream(dev_type, dev_id, outer_stream)

        with tvm_ffi.use_raw_stream(dev, inner_stream):
            mod.check_stream(dev_type, dev_id, inner_stream)

        # Leaving the inner context must bring back the outer stream.
        mod.check_stream(dev_type, dev_id, outer_stream)


@pytest.mark.skipif(
    torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
)
def test_torch_stream():
    """use_torch_stream mirrors explicit torch streams into the ffi env."""
    mod = gen_check_stream_mod()
    dev_id = torch.cuda.current_device()
    dev = tvm_ffi.device("cuda", dev_id)
    dev_type = dev.dlpack_device_type()
    outer = torch.cuda.Stream(dev_id)
    inner = torch.cuda.Stream(dev_id)
    with tvm_ffi.use_torch_stream(torch.cuda.stream(outer)):
        assert torch.cuda.current_stream() == outer
        mod.check_stream(dev_type, dev_id, outer.cuda_stream)

        with tvm_ffi.use_torch_stream(torch.cuda.stream(inner)):
            assert torch.cuda.current_stream() == inner
            mod.check_stream(dev_type, dev_id, inner.cuda_stream)

        # Exiting the inner context restores both torch and ffi streams.
        assert torch.cuda.current_stream() == outer
        mod.check_stream(dev_type, dev_id, outer.cuda_stream)


@pytest.mark.skipif(
    torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
)
def test_torch_current_stream():
    """use_torch_stream() with no argument mirrors torch's current stream."""
    mod = gen_check_stream_mod()
    dev_id = torch.cuda.current_device()
    dev = tvm_ffi.device("cuda", dev_id)
    dev_type = dev.dlpack_device_type()
    outer = torch.cuda.Stream(dev_id)
    inner = torch.cuda.Stream(dev_id)
    with torch.cuda.stream(outer):
        assert torch.cuda.current_stream() == outer
        with tvm_ffi.use_torch_stream():
            mod.check_stream(dev_type, dev_id, outer.cuda_stream)

        with torch.cuda.stream(inner):
            assert torch.cuda.current_stream() == inner
            with tvm_ffi.use_torch_stream():
                mod.check_stream(dev_type, dev_id, inner.cuda_stream)

        # Back on the outer torch stream; the ffi env must follow it.
        assert torch.cuda.current_stream() == outer
        with tvm_ffi.use_torch_stream():
            mod.check_stream(dev_type, dev_id, outer.cuda_stream)


@pytest.mark.skipif(
    torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
)
def test_torch_graph():
    """use_torch_stream propagates the capture stream of a CUDA graph."""
    mod = gen_check_stream_mod()
    dev_id = torch.cuda.current_device()
    dev = tvm_ffi.device("cuda", dev_id)
    dev_type = dev.dlpack_device_type()
    capture_graph = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream(dev_id)
    with tvm_ffi.use_torch_stream(torch.cuda.graph(capture_graph, stream=capture_stream)):
        # During capture torch switches to the capture stream; the ffi env
        # stream must match it.
        assert torch.cuda.current_stream() == capture_stream
        mod.check_stream(dev_type, dev_id, capture_stream.cuda_stream)
Loading